summaryrefslogtreecommitdiff
path: root/tests/test_scrypt.py
blob: 4c5a207db33c68ec1d8b086bdd166ae7c8f8ece8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# -*- coding: utf-8 -*-

from os import urandom
from os.path import dirname, abspath, sep
from sys import version_info
from csv import reader
from binascii import a2b_hex, b2a_hex
import base64
import json

# unittest2 backports the unittest API that shipped with Python 2.7 and 3.2,
# so those releases (and newer ones) use the standard library directly; only
# older interpreters fall back to the backport.
if version_info[:2] >= (3, 2) or (2, 7) <= version_info[:2] < (3, 0):
    import unittest as testm
else:
    try:
        import unittest2 as testm
    except ImportError:
        print("Please install unittest2 to run the test suite")
        # ``exit`` is injected by the site module and is not guaranteed to
        # exist (e.g. under ``python -S``); SystemExit is always available.
        raise SystemExit(-1)

import scrypt


class TestScrypt(testm.TestCase):

    def setUp(self):
        self.input = "message"
        self.password = "password"
        self.longinput = str(urandom(100000))
        self.five_minutes = 300.0
        self.five_seconds = 5.0
        self.one_byte = 1  # in Bytes
        self.one_megabyte = 1024 * 1024  # in Bytes
        self.ten_megabytes = 10 * self.one_megabyte
        base_dir = dirname(abspath(__file__)) + sep
        cvf = open(base_dir + "ciphertexts.csv", "r")
        ciphertxt_reader = reader(cvf, dialect="excel")
        self.ciphertexts = []
        for row in ciphertxt_reader:
            self.ciphertexts.append(row)
        cvf.close()
        self.ciphertext = a2b_hex(bytes(self.ciphertexts[1][5].encode('ascii')))

    def test_encrypt_decrypt(self):
        """Test encrypt for simple encryption and decryption"""
        s = scrypt.encrypt(self.input, self.password, 0.1)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt(self):
        """Test encrypt takes input and password strings as
        positional arguments and produces ciphertext"""
        s = scrypt.encrypt(self.input, self.password)
        self.assertEqual(len(s), 128 + len(self.input))

    def test_encrypt_input_and_password_as_keywords(self):
        """Test encrypt for input and password accepted as keywords"""
        s = scrypt.encrypt(password=self.password, input=self.input)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_missing_input_keyword_argument(self):
        """Test encrypt raises TypeError if keyword argument missing input"""
        self.assertRaises(TypeError, lambda: scrypt.encrypt(password=self.password))

    def test_encrypt_missing_password_positional_argument(self):
        """Test encrypt raises TypeError if second positional argument missing
        (password)"""
        self.assertRaises(TypeError, lambda: scrypt.encrypt(self.input))

    def test_encrypt_missing_both_required_positional_arguments(self):
        """Test encrypt raises TypeError if both positional arguments missing
        (input and password)"""
        self.assertRaises(TypeError, lambda: scrypt.encrypt())

    def test_encrypt_maxtime_positional(self):
        """Test encrypt maxtime accepts maxtime at position 3"""
        s = scrypt.encrypt(self.input, self.password, 0.01)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxtime_key(self):
        """Test encrypt maxtime accepts maxtime as keyword argument"""
        s = scrypt.encrypt(self.input, self.password, maxtime=0.01)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmem_positional(self):
        """Test encrypt maxmem accepts 4th positional argument and exactly
        (1 megabyte) of storage to use for V array"""
        s = scrypt.encrypt(self.input, self.password, 0.01, self.one_megabyte)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmem_undersized(self):
        """Test encrypt maxmem accepts (< 1 megabyte) of storage to use for V array"""
        s = scrypt.encrypt(self.input, self.password, 0.01, self.one_byte)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmem_in_normal_range(self):
        """Test encrypt maxmem accepts (> 1 megabyte) of storage to use for V array"""
        s = scrypt.encrypt(self.input,
                           self.password,
                           0.01,
                           self.ten_megabytes)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmem_keyword_argument(self):
        """Test encrypt maxmem accepts exactly (1 megabyte) of storage to use for
        V array"""
        s = scrypt.encrypt(self.input,
                           self.password,
                           maxmem=self.one_megabyte,
                           maxtime=0.01)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmemfrac_positional(self):
        """Test encrypt maxmemfrac accepts 5th positional argument of 1/16 total
        memory for V array"""
        s = scrypt.encrypt(self.input, self.password, 0.01, 0, 0.0625)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_maxmemfrac_keyword_argument(self):
        """Test encrypt maxmemfrac accepts keyword argument of 1/16 total memory for
        V array"""
        s = scrypt.encrypt(self.input, self.password, maxmemfrac=0.0625,
                           maxtime=0.01)
        m = scrypt.decrypt(s, self.password)
        self.assertEqual(m, self.input)

    def test_encrypt_long_input(self):
        """Test encrypt accepts long input for encryption"""
        s = scrypt.encrypt(self.longinput, self.password, 0.1)
        self.assertEqual(len(s), 128 + len(self.longinput))

    def test_encrypt_raises_error_on_invalid_keyword(self):
        """Test encrypt raises TypeError if invalid keyword used in argument"""
        self.assertRaises(TypeError, lambda: scrypt.encrypt(self.input,
            self.password, nonsense="Raise error"))

    def test_decrypt_from_csv_ciphertexts(self):
        """Test decrypt function with precalculated combinations"""
        for row in self.ciphertexts[1:]:
            h = scrypt.decrypt(a2b_hex(bytes(row[5].encode('ascii'))), row[1])
            self.assertEqual(bytes(h.encode("ascii")), row[0].encode("ascii"))

    def test_decrypt_maxtime_positional(self):
        """Test decrypt function accepts third positional argument"""
        m = scrypt.decrypt(self.ciphertext, self.password, self.five_seconds)
        self.assertEqual(m, self.input)

    def test_decrypt_maxtime_keyword_argument(self):
        """Test decrypt function accepts maxtime keyword argument"""
        m = scrypt.decrypt(maxtime=1.0, input=self.ciphertext, password=self.password)
        self.assertEqual(m, self.input)

    def test_decrypt_maxmem_positional(self):
        """Test decrypt function accepts fourth positional argument"""
        m = scrypt.decrypt(self.ciphertext, self.password, self.five_minutes, self.ten_megabytes)
        self.assertEqual(m, self.input)

    def test_decrypt_maxmem_keyword_argument(self):
        """Test decrypt function accepts maxmem keyword argument"""
        m = scrypt.decrypt(maxmem=self.ten_megabytes, input=self.ciphertext, password=self.password)
        self.assertEqual(m, self.input)

    def test_decrypt_maxmemfrac_positional(self):
        """Test decrypt function accepts maxmem keyword argument"""
        m = scrypt.decrypt(self.ciphertext, self.password, self.five_minutes, self.one_megabyte, 0.0625)
        self.assertEqual(m, self.input)

    def test_decrypt_maxmemfrac_keyword_argument(self):
        """Test decrypt function accepts maxmem keyword argument"""
        m = scrypt.decrypt(maxmemfrac=0.625, input=self.ciphertext, password=self.password)
        self.assertEqual(m, self.input)

    def test_decrypt_raises_error_on_too_little_time(self):
        """Test decrypt function raises scrypt.error raised if insufficient time allowed for
        ciphertext decryption"""
        s = scrypt.encrypt(self.input, self.password, 0.1)
        self.assertRaises(scrypt.error,
                          lambda: scrypt.decrypt(s, self.password, .01))


class TestScryptHash(testm.TestCase):

    def setUp(self):
        self.input = "message"
        self.password = "password"
        self.salt = "NaCl"
        self.hashes = []
        base_dir = dirname(abspath(__file__)) + sep
        hvf = open(base_dir + "hashvectors.csv", "r")
        hash_reader = reader(hvf, dialect="excel")
        for row in hash_reader:
            self.hashes.append(row)
        hvf.close()

    def test_hash_vectors_from_csv(self):
        """Test hash function with precalculated combinations"""
        for row in self.hashes[1:]:
            h = scrypt.hash(row[0], row[1], int(row[2]), int(row[3]), int(row[4]))
            hhex = b2a_hex(h)
            self.assertEqual(hhex, bytes(row[5].encode("utf-8")))

    def test_hash_buflen_keyword(self):
        """Test hash takes keyword valid buflen"""
        h64 = scrypt.hash(self.input, self.salt, buflen=64)
        h128 = scrypt.hash(self.input, self.salt, buflen=128)
        self.assertEqual(len(h64), 64)
        self.assertEqual(len(h128), 128)

    def test_hash_n_positional(self):
        """Test hash accepts valid N in position 3"""
        h = scrypt.hash(self.input, self.salt, 256)
        self.assertEqual(len(h), 64)

    def test_hash_n_keyword(self):
        """Test hash takes keyword valid N"""
        h = scrypt.hash(N=256, password=self.input, salt=self.salt)
        self.assertEqual(len(h), 64)

    def test_hash_r_positional(self):
        """Test hash accepts valid r in position 4"""
        h = scrypt.hash(self.input, self.salt, 256, 16)
        self.assertEqual(len(h), 64)

    def test_hash_r_keyword(self):
        """Test hash takes keyword valid r"""
        h = scrypt.hash(r=16, password=self.input, salt=self.salt)
        self.assertEqual(len(h), 64)

    def test_hash_p_positional(self):
        """Test hash accepts valid p in position 5"""
        h = scrypt.hash(self.input, self.salt, 256, 8, 2)
        self.assertEqual(len(h), 64)

    def test_hash_p_keyword(self):
        """Test hash takes keyword valid p"""
        h = scrypt.hash(p=4, password=self.input, salt=self.salt)
        self.assertEqual(len(h), 64)

    def test_hash_raises_error_on_p_equals_zero(self):
        """Test hash raises scrypt error on illegal parameter value (p = 0)"""
        self.assertRaises(scrypt.error,
                          lambda: scrypt.hash(self.input, self.salt, p=0))

    def test_hash_raises_error_on_negative_p(self):
        """Test hash raises scrypt error on illegal parameter value (p < 0)"""
        self.assertRaises(scrypt.error,
                          lambda: scrypt.hash(self.input, self.salt, p=-1))

    def test_hash_raises_error_on_r_equals_zero(self):
        """Test hash raises scrypt error on illegal parameter value (r = 0)"""
        self.assertRaises(scrypt.error,
                          lambda: scrypt.hash(self.input, self.salt, r=0))

    def test_hash_raises_error_on_negative_r(self):
        """Test hash raises scrypt error on illegal parameter value (r < 1)"""
        self.assertRaises(scrypt.error,
                          lambda: scrypt.hash(self.input, self.salt, r=-1))

    def test_hash_raises_error_r_p_over_limit(self):
        """Test hash raises scrypt error when parameters r multiplied by p over limit 2**30"""
        self.assertRaises(scrypt.error,
                          lambda: scrypt.hash(self.input, self.salt, r=2, p=2 ** 29))

    def test_hash_raises_error_n_not_power_of_two(self):
        """Test hash raises scrypt error when parameter N is not a power of two {2, 4, 8, 16, etc}"""
        self.assertRaises(scrypt.error,
                          lambda: scrypt.hash(self.input, self.salt, N=3))

    def test_hash_raises_error_n_under_limit(self):
        """Test hash raises scrypt error when parameter N under limit of 1"""
        self.assertRaises(scrypt.error,
                          lambda: scrypt.hash(self.input, self.salt, N=1))
        self.assertRaises(scrypt.error,
                          lambda: scrypt.hash(self.input, self.salt, N=-1))

if __name__ == "__main__":
    # Executed directly rather than imported by a test runner: hand control
    # to the unittest command-line entry point, which runs the TestCases above.
    testm.main()