From 2d82d01a95f6da0c3206bf5de083a3aa465eb084 Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Sat, 26 Nov 2016 21:26:23 -0300 Subject: [refactor] simplify _crypto After adding the streaming decrypt, some classes were doing almost the same thing. Unified them. Also fixed some module level variables to upper case and some class name to camel case. --- client/src/leap/soledad/client/_crypto.py | 228 ++++++++++++------------------ testing/tests/client/test_crypto.py | 5 +- 2 files changed, 91 insertions(+), 142 deletions(-) diff --git a/client/src/leap/soledad/client/_crypto.py b/client/src/leap/soledad/client/_crypto.py index 163c9e4e..22335f9d 100644 --- a/client/src/leap/soledad/client/_crypto.py +++ b/client/src/leap/soledad/client/_crypto.py @@ -36,7 +36,6 @@ import six from twisted.internet import defer from twisted.internet import interfaces -from twisted.logger import Logger from twisted.web.client import FileBodyProducer from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes @@ -47,21 +46,16 @@ from cryptography.hazmat.backends.openssl.backend \ from zope.interface import implements -log = Logger() - MAC_KEY_LENGTH = 64 -crypto_backend = MultiBackend([OpenSSLBackend()]) +CRYPTO_BACKEND = MultiBackend([OpenSSLBackend()]) PACMAN = struct.Struct('cQbb16s255p255p') -class ENC_SCHEME: - symkey = 1 - - -class ENC_METHOD: - aes_256_ctr = 1 +ENC_SCHEME = namedtuple('SCHEME', 'symkey')(1) +ENC_METHOD = namedtuple('METHOD', 'aes_256_ctr')(1) +DocInfo = namedtuple('DocInfo', 'doc_id rev') class EncryptionDecryptionError(Exception): @@ -72,9 +66,6 @@ class InvalidBlob(Exception): pass -docinfo = namedtuple('docinfo', 'doc_id rev') - - class SoledadCrypto(object): """ This class provides convenient methods for document encryption and @@ -107,7 +98,7 @@ class SoledadCrypto(object): content = BytesIO() content.write(str(doc.get_json())) - info = docinfo(doc.doc_id, doc.rev) + info = DocInfo(doc.doc_id, doc.rev) del doc encryptor = BlobEncryptor(info, content, secret=self.secret) d = encryptor.encrypt() @@ -124,7 +115,7 @@ class SoledadCrypto(object): :return: The decrypted cleartext content of the document. :rtype: str """ - info = docinfo(doc.doc_id, doc.rev) + info = DocInfo(doc.doc_id, doc.rev) ciphertext = BytesIO() payload = doc.content['raw'] del doc @@ -146,10 +137,10 @@ def encrypt_sym(data, key): encoded as base64. :rtype: (str, str) """ - encryptor = AESEncryptor(key) + encryptor = AESConsumer(key) encryptor.write(data) encryptor.end() - ciphertext = encryptor.fd.getvalue() + ciphertext = encryptor.buffer.getvalue() return base64.b64encode(encryptor.iv), ciphertext @@ -169,10 +160,10 @@ def decrypt_sym(data, key, iv): :rtype: str """ _iv = base64.b64decode(str(iv)) - decryptor = AESDecryptor(key, _iv) + decryptor = AESConsumer(key, _iv, operation=AESConsumer.decrypt) decryptor.write(data) decryptor.end() - plaintext = decryptor.fd.getvalue() + plaintext = decryptor.buffer.getvalue() return plaintext @@ -205,15 +196,16 @@ class BlobEncryptor(object): mac_key = _get_mac_key_for_doc(doc_info.doc_id, secret) self._aes_fd = BytesIO() - self._aes = AESEncryptor(sym_key, self._aes_fd) - self._hmac = HMACWriter(mac_key) + _aes = AESConsumer(sym_key, _buffer=self._aes_fd) + self.__iv = _aes.iv + self._hmac_writer = HMACWriter(mac_key) self._write_preamble() - self._crypter = VerifiedEncrypter(self._aes, self._hmac) + self._crypter = VerifiedEncrypter(_aes, self._hmac_writer) @property def iv(self): - return self._aes.iv + return self.__iv def encrypt(self): """ @@ -224,26 +216,14 @@ class BlobEncryptor(object): :rtype: twisted.internet.defer.Deferred """ d = self._producer.startProducing(self._crypter) - d.addCallback(self._end_crypto_stream) + d.addCallback(lambda _: self._end_crypto_stream()) return d - def encrypt_whole(self): - """ - Encrypts the input data at once and returns the resulting ciphertext - wrapped into a JSON string under the "raw" key. - - :return: The resulting ciphertext JSON string. - :rtype: str - """ - self._crypter.write(self._content_fd.getvalue()) - self._end_crypto_stream(None) - return '{"raw":"' + self.result.getvalue() + '"}' - def _write_preamble(self): def write(data): self._preamble.write(data) - self._hmac.write(data) + self._hmac_writer.write(data) current_time = int(time.time()) @@ -256,23 +236,16 @@ class BlobEncryptor(object): str(self.doc_id), str(self.rev))) - def _end_crypto_stream(self, ignored): - self._aes.end() - self._hmac.end() - self._content_fd.close() + def _end_crypto_stream(self): + encrypted, content_hmac = self._crypter.end() preamble = self._preamble.getvalue() - encrypted = self._aes_fd.getvalue() - hmac = self._hmac.result.getvalue() self.result.write( base64.urlsafe_b64encode(preamble)) self.result.write(' ') self.result.write( - base64.urlsafe_b64encode(encrypted + hmac)) - self._preamble.close() - self._aes_fd.close() - self._hmac.result.close() + base64.urlsafe_b64encode(encrypted + content_hmac)) self.result.seek(0) return defer.succeed(self.result) @@ -289,62 +262,65 @@ class BlobDecryptor(object): secret=None): if not secret: raise EncryptionDecryptionError('no secret given') - ciphertext_fd.seek(0) self.doc_id = doc_info.doc_id self.rev = doc_info.rev - self.sym_key = _get_sym_key_for_doc(doc_info.doc_id, secret) - self.mac_key = _get_mac_key_for_doc(doc_info.doc_id, secret) - - self._read_preamble(ciphertext_fd) - - self._producer = FileBodyProducer(self.ciphertext, readSize=2**16) - self._content_fd = self.ciphertext + ciphertext_fd, preamble, iv = self._consume_preamble(ciphertext_fd) + mac_key = _get_mac_key_for_doc(doc_info.doc_id, secret) + self._current_hmac = BytesIO() + _hmac_writer = HMACWriter(mac_key, self._current_hmac) + _hmac_writer.write(preamble) self.result = result or BytesIO() + sym_key = _get_sym_key_for_doc(doc_info.doc_id, secret) + _aes = AESConsumer(sym_key, iv, self.result, + operation=AESConsumer.decrypt) + self._decrypter = VerifiedDecrypter(_aes, _hmac_writer) - self._aes_fd = BytesIO() - self._aes = AESDecryptor(self.sym_key, self.iv, self.result) - self._hmac = HMACWriter(self.mac_key) - self._hmac.write(self.preamble) - - self._decrypter = VerifiedDecrypter(self._aes, self._hmac) + self._producer = FileBodyProducer(ciphertext_fd, readSize=2**16) - def _read_preamble(self, ciphertext): + def _consume_preamble(self, ciphertext_fd): + ciphertext_fd.seek(0) try: - self.preamble, ciphertext = _split(ciphertext.getvalue()) - self.doc_hmac, self.ciphertext = ciphertext[-64:], ciphertext[:-64] + preamble, ciphertext = _split(ciphertext_fd.getvalue()) + self.doc_hmac, ciphertext = ciphertext[-64:], ciphertext[:-64] except (TypeError, binascii.Error): raise InvalidBlob - self.ciphertext = BytesIO(self.ciphertext) + ciphertext_fd.close() - if len(self.preamble) != PACMAN.size: + if len(preamble) != PACMAN.size: raise InvalidBlob try: - unpacked_data = PACMAN.unpack(self.preamble) + unpacked_data = PACMAN.unpack(preamble) pad, ts, sch, meth, iv, doc_id, rev = unpacked_data - self.iv = iv except struct.error: raise InvalidBlob + if pad != '\x80': raise InvalidBlob - # TODO check timestamp if sch != ENC_SCHEME.symkey: raise InvalidBlob('invalid scheme') # TODO should adapt the assymetric-gpg too, rigth? if meth != ENC_METHOD.aes_256_ctr: raise InvalidBlob('invalid encryption scheme') - if rev != self.rev: raise InvalidBlob('invalid revision') + if doc_id != self.doc_id: + raise InvalidBlob('invalid revision') + return BytesIO(ciphertext), preamble, iv def _check_hmac(self): - if self._hmac._hmac.digest() != self.doc_hmac: + if self._current_hmac.getvalue() != self.doc_hmac: raise InvalidBlob('HMAC could not be verifed') + def _end_stream(self): + self._decrypter.end() + self._check_hmac() + return self.result.getvalue() + def decrypt(self): """ Starts producing encrypted data from the cleartext data. @@ -354,50 +330,9 @@ class BlobDecryptor(object): :rtype: twisted.internet.defer.Deferred """ d = self._producer.startProducing(self._decrypter) - d.addCallback(lambda _: self._check_hmac()) - d.addCallback(lambda _: self.result.getvalue()) + d.addCallback(lambda _: self._end_stream()) return d - def decrypt_whole(self): - ciphertext = self.ciphertext.getvalue() - self.hmac_obj.update(ciphertext) - self._check_hmac() - decryptor = _get_aes_ctr_cipher(self.sym_key, self.iv).decryptor() - - self.result.write(decryptor.update(ciphertext)) - self.result.write(decryptor.finalize()) - return self.result - - -class AESEncryptor(object): - """ - A Twisted's Consumer implementation that takes an input file descriptor and - applies AES-256 cipher in CTR mode. - """ - implements(interfaces.IConsumer) - - def __init__(self, key, fd=None): - if len(key) != 32: - raise EncryptionDecryptionError('key is not 256 bits') - self.iv = os.urandom(16) - - cipher = _get_aes_ctr_cipher(key, self.iv) - self.encryptor = cipher.encryptor() - - self.fd = fd or BytesIO() - - self.done = False - - def write(self, data): - encrypted = self.encryptor.update(data) - self.fd.write(encrypted) - return encrypted - - def end(self): - if not self.done: - self.fd.write(self.encryptor.finalize()) - self.done = True - class HMACWriter(object): """ @@ -407,15 +342,16 @@ class HMACWriter(object): implements(interfaces.IConsumer) hashtype = 'sha512' - def __init__(self, key): + def __init__(self, key, result=None): self._hmac = hmac.new(key, '', getattr(hashlib, self.hashtype)) - self.result = BytesIO('') + self.result = result or BytesIO('') def write(self, data): self._hmac.update(data) def end(self): self.result.write(self._hmac.digest()) + return self.result.getvalue() class VerifiedEncrypter(object): @@ -425,13 +361,18 @@ class VerifiedEncrypter(object): """ implements(interfaces.IConsumer) - def __init__(self, crypter, hmac): + def __init__(self, crypter, hmac_writer): self.crypter = crypter - self.hmac = hmac + self.hmac_writer = hmac_writer def write(self, data): enc_chunk = self.crypter.write(data) - self.hmac.write(enc_chunk) + self.hmac_writer.write(enc_chunk) + + def end(self): + ciphertext = self.crypter.end() + content_hmac = self.hmac_writer.end() + return ciphertext, content_hmac class VerifiedDecrypter(object): @@ -442,46 +383,53 @@ class VerifiedDecrypter(object): """ implements(interfaces.IConsumer) - def __init__(self, decrypter, hmac): + def __init__(self, decrypter, hmac_writer): self.decrypter = decrypter - self.hmac = hmac + self.hmac_writer = hmac_writer def write(self, enc_chunk): - self.hmac.write(enc_chunk) + self.hmac_writer.write(enc_chunk) self.decrypter.write(enc_chunk) + def end(self): + self.decrypter.end() + self.hmac_writer.end() + -class AESDecryptor(object): +class AESConsumer(object): """ - A Twisted's Consumer implementation that consumes data encrypted with - AES-256 in CTR mode from a file descriptor and generates decrypted data. + A Twisted's Consumer implementation that takes an input file descriptor and + applies AES-256 cipher in CTR mode. """ implements(interfaces.IConsumer) + encrypt = 1 + decrypt = 2 - def __init__(self, key, iv, fd=None): - iv = iv or os.urandom(16) + def __init__(self, key, iv=None, _buffer=None, operation=encrypt): if len(key) != 32: raise EncryptionDecryptionError('key is not 256 bits') - if len(iv) != 16: - raise EncryptionDecryptionError('iv is not 128 bits') - - cipher = _get_aes_ctr_cipher(key, iv) - self.decryptor = cipher.decryptor() - - self.fd = fd or BytesIO() - self.done = False + self.iv = iv or os.urandom(16) + self.buffer = _buffer or BytesIO() self.deferred = defer.Deferred() + self.done = False + + cipher = _get_aes_ctr_cipher(key, self.iv) + if operation == self.encrypt: + self.operator = cipher.encryptor() + else: + self.operator = cipher.decryptor() def write(self, data): - decrypted = self.decryptor.update(data) - self.fd.write(decrypted) - return decrypted + consumed = self.operator.update(data) + self.buffer.write(consumed) + return consumed def end(self): if not self.done: - self.decryptor.finalize() - self.deferred.callback(self.fd) + self.buffer.write(self.operator.finalize()) + self.deferred.callback(self.buffer) self.done = True + return self.buffer.getvalue() def is_symmetrically_encrypted(doc): @@ -525,7 +473,7 @@ def _get_sym_key_for_doc(doc_id, secret): def _get_aes_ctr_cipher(key, iv): - return Cipher(algorithms.AES(key), modes.CTR(iv), backend=crypto_backend) + return Cipher(algorithms.AES(key), modes.CTR(iv), backend=CRYPTO_BACKEND) def _split(base64_raw_payload): diff --git a/testing/tests/client/test_crypto.py b/testing/tests/client/test_crypto.py index 863873f7..7643f75d 100644 --- a/testing/tests/client/test_crypto.py +++ b/testing/tests/client/test_crypto.py @@ -52,7 +52,7 @@ class AESTest(unittest.TestCase): key = 'A' * 32 fd = BytesIO() - aes = _crypto.AESEncryptor(key, fd) + aes = _crypto.AESConsumer(key, _buffer=fd) iv = aes.iv data = snowden1 @@ -78,7 +78,8 @@ class AESTest(unittest.TestCase): ciphertext = _aes_encrypt(key, iv, data) fd = BytesIO() - aes = _crypto.AESDecryptor(key, iv, fd) + operation = _crypto.AESConsumer.decrypt + aes = _crypto.AESConsumer(key, iv, fd, operation) for i in range(len(ciphertext) / block): chunk = ciphertext[i * block:(i + 1) * block] -- cgit v1.2.3