diff options
Diffstat (limited to 'client/src/leap/soledad/client/encdecpool.py')
-rw-r--r-- | client/src/leap/soledad/client/encdecpool.py | 384 |
1 files changed, 110 insertions, 274 deletions
diff --git a/client/src/leap/soledad/client/encdecpool.py b/client/src/leap/soledad/client/encdecpool.py index 34667a1e..a6d49b21 100644 --- a/client/src/leap/soledad/client/encdecpool.py +++ b/client/src/leap/soledad/client/encdecpool.py @@ -22,12 +22,11 @@ during synchronization. """ -import multiprocessing -import Queue import json import logging +from uuid import uuid4 -from twisted.internet import reactor +from twisted.internet.task import LoopingCall from twisted.internet import threads from twisted.internet import defer from twisted.python import log @@ -51,9 +50,6 @@ class SyncEncryptDecryptPool(object): Base class for encrypter/decrypter pools. """ - # TODO implement throttling to reduce cpu usage?? - WORKERS = multiprocessing.cpu_count() - def __init__(self, crypto, sync_db): """ Initialize the pool of encryption-workers. @@ -66,21 +62,14 @@ class SyncEncryptDecryptPool(object): """ self._crypto = crypto self._sync_db = sync_db - self._pool = None self._delayed_call = None self._started = False def start(self): - if self.running: - return - self._create_pool() self._started = True def stop(self): - if not self.running: - return self._started = False - self._destroy_pool() # maybe cancel the next delayed call if self._delayed_call \ and not self._delayed_call.called: @@ -90,27 +79,6 @@ class SyncEncryptDecryptPool(object): def running(self): return self._started - def _create_pool(self): - self._pool = multiprocessing.Pool(self.WORKERS) - - def _destroy_pool(self): - """ - Cleanly close the pool of workers. - """ - logger.debug("Closing %s" % (self.__class__.__name__,)) - self._pool.close() - try: - self._pool.join() - except Exception: - pass - - def terminate(self): - """ - Terminate the pool of workers. - """ - logger.debug("Terminating %s" % (self.__class__.__name__,)) - self._pool.terminate() - def _runOperation(self, query, *args): """ Run an operation on the sync db. @@ -180,7 +148,6 @@ class SyncEncrypterPool(SyncEncryptDecryptPool): Initialize the sync encrypter pool. """ SyncEncryptDecryptPool.__init__(self, *args, **kwargs) - self._encr_queue = defer.DeferredQueue() # TODO delete already synced files from database def start(self): @@ -189,73 +156,33 @@ class SyncEncrypterPool(SyncEncryptDecryptPool): """ SyncEncryptDecryptPool.start(self) logger.debug("Starting the encryption loop...") - reactor.callWhenRunning(self._maybe_encrypt_and_recurse) def stop(self): """ Stop the encrypter pool. """ - # close the sync queue - if self._encr_queue: - q = self._encr_queue - for d in q.pending: - d.cancel() - del q - self._encr_queue = None SyncEncryptDecryptPool.stop(self) - def enqueue_doc_for_encryption(self, doc): + def encrypt_doc(self, doc): """ - Enqueue a document for encryption. + Encrypt document asynchronously then insert it on + local staging database. :param doc: The document to be encrypted. :type doc: SoledadDocument """ - try: - self._encr_queue.put(doc) - except Queue.Full: - # do not asynchronously encrypt this file if the queue is full - pass - - @defer.inlineCallbacks - def _maybe_encrypt_and_recurse(self): - """ - Process one document from the encryption queue. - - Asynchronously encrypt a document that will then be stored in the sync - db. Processed documents will be read by the SoledadSyncTarget during - the sync_exchange. - """ - try: - while self.running: - doc = yield self._encr_queue.get() - self._encrypt_doc(doc) - except defer.QueueUnderflow: - self._delayed_call = reactor.callLater( - self.ENCRYPT_LOOP_PERIOD, - self._maybe_encrypt_and_recurse) - - def _encrypt_doc(self, doc): - """ - Symmetrically encrypt a document. - - :param doc: The document with contents to be encrypted. - :type doc: SoledadDocument - - :param workers: Whether to defer the decryption to the multiprocess - pool of workers. Useful for debugging purposes. - :type workers: bool - """ soledad_assert(self._crypto is not None, "need a crypto object") docstr = doc.get_json() key = self._crypto.doc_passphrase(doc.doc_id) secret = self._crypto.secret args = doc.doc_id, doc.rev, docstr, key, secret # encrypt asynchronously - self._pool.apply_async( - encrypt_doc_task, args, - callback=self._encrypt_doc_cb) + # TODO use dedicated threadpool / move to ampoule + d = threads.deferToThread( + encrypt_doc_task, *args) + d.addCallback(self._encrypt_doc_cb) + return d def _encrypt_doc_cb(self, result): """ @@ -336,8 +263,8 @@ def decrypt_doc_task(doc_id, doc_rev, content, gen, trans_id, key, secret, :type doc_id: str :param doc_rev: The document revision. :type doc_rev: str - :param content: The encrypted content of the document. - :type content: str + :param content: The encrypted content of the document as JSON dict. + :type content: dict :param gen: The generation corresponding to the modification of that document. :type gen: int @@ -384,7 +311,7 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): """ TABLE_NAME = "docs_received" FIELD_NAMES = "doc_id PRIMARY KEY, rev, content, gen, " \ - "trans_id, encrypted, idx" + "trans_id, encrypted, idx, sync_id" """ Period of recurrence of the periodic decrypting task, in seconds. @@ -414,46 +341,8 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): self._docs_to_process = None self._processed_docs = 0 self._last_inserted_idx = 0 - self._decrypting_docs = [] - - # a list that holds the asynchronous decryption results so they can be - # collected when they are ready - self._async_results = [] - # initialize db and make sure any database operation happens after - # db initialization - self._deferred_init = self._init_db() - self._wait_init_db('_runOperation', '_runQuery') - - def _wait_init_db(self, *methods): - """ - Methods that need to wait for db initialization. - - :param methods: methods that need to wait for initialization - :type methods: tuple(str) - """ - self._waiting = [] - self._stored = {} - - def _restore(_): - for method in self._stored: - setattr(self, method, self._stored[method]) - for d in self._waiting: - d.callback(None) - - def _makeWrapper(method): - def wrapper(*args, **kw): - d = defer.Deferred() - d.addCallback(lambda _: self._stored[method](*args, **kw)) - self._waiting.append(d) - return d - return wrapper - - for method in methods: - self._stored[method] = getattr(self, method) - setattr(self, method, _makeWrapper(method)) - - self._deferred_init.addCallback(_restore) + self._loop = LoopingCall(self._decrypt_and_recurse) def start(self, docs_to_process): """ @@ -466,13 +355,39 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): :type docs_to_process: int """ SyncEncryptDecryptPool.start(self) + self._decrypted_docs_indexes = set() + self._sync_id = uuid4().hex self._docs_to_process = docs_to_process self._deferred = defer.Deferred() - reactor.callWhenRunning(self._launch_decrypt_and_recurse) + d = self._init_db() + d.addCallback(lambda _: self._loop.start(self.DECRYPT_LOOP_PERIOD)) + return d + + def stop(self): + if self._loop.running: + self._loop.stop() + self._finish() + SyncEncryptDecryptPool.stop(self) + + def _init_db(self): + """ + Ensure sync_id column is present then + Empty the received docs table of the sync database. - def _launch_decrypt_and_recurse(self): - d = self._decrypt_and_recurse() - d.addErrback(self._errback) + :return: A deferred that will fire when the operation in the database + has finished. + :rtype: twisted.internet.defer.Deferred + """ + ensure_sync_id_column = ("ALTER TABLE %s ADD COLUMN sync_id" % + self.TABLE_NAME) + d = self._runQuery(ensure_sync_id_column) + + def empty_received_docs(_): + query = "DELETE FROM %s WHERE sync_id <> ?" % (self.TABLE_NAME,) + return self._runOperation(query, (self._sync_id,)) + + d.addCallbacks(empty_received_docs, empty_received_docs) + return d def _errback(self, failure): log.err(failure) @@ -491,8 +406,8 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): def insert_encrypted_received_doc( self, doc_id, doc_rev, content, gen, trans_id, idx): """ - Insert a received message with encrypted content, to be decrypted later - on. + Decrypt and insert a received document into local staging area to be + processed later on. :param doc_id: The document ID. :type doc_id: str @@ -507,15 +422,22 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): :param idx: The index of this document in the current sync process. :type idx: int - :return: A deferred that will fire when the operation in the database - has finished. + :return: A deferred that will fire after the decrypted document has + been inserted in the sync db. :rtype: twisted.internet.defer.Deferred """ - docstr = json.dumps(content) - query = "INSERT OR REPLACE INTO '%s' VALUES (?, ?, ?, ?, ?, ?, ?)" \ - % self.TABLE_NAME - return self._runOperation( - query, (doc_id, doc_rev, docstr, gen, trans_id, 1, idx)) + soledad_assert(self._crypto is not None, "need a crypto object") + + key = self._crypto.doc_passphrase(doc_id) + secret = self._crypto.secret + args = doc_id, doc_rev, content, gen, trans_id, key, secret, idx + # decrypt asynchronously + # TODO use dedicated threadpool / move to ampoule + d = threads.deferToThread( + decrypt_doc_task, *args) + # callback will insert it for later processing + d.addCallback(self._decrypt_doc_cb) + return d def insert_received_doc( self, doc_id, doc_rev, content, gen, trans_id, idx): @@ -543,56 +465,29 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): """ if not isinstance(content, str): content = json.dumps(content) - query = "INSERT OR REPLACE INTO '%s' VALUES (?, ?, ?, ?, ?, ?, ?)" \ + query = "INSERT OR REPLACE INTO '%s' VALUES (?, ?, ?, ?, ?, ?, ?, ?)" \ % self.TABLE_NAME - return self._runOperation( - query, (doc_id, doc_rev, content, gen, trans_id, 0, idx)) + d = self._runOperation( + query, (doc_id, doc_rev, content, gen, trans_id, 0, + idx, self._sync_id)) + d.addCallback(lambda _: self._decrypted_docs_indexes.add(idx)) + return d - def _delete_received_doc(self, doc_id): + def _delete_received_docs(self, doc_ids): """ - Delete a received doc after it was inserted into the local db. + Delete a list of received docs after get them inserted into the db. - :param doc_id: Document ID. - :type doc_id: str + :param doc_id: Document ID list. + :type doc_id: list :return: A deferred that will fire when the operation in the database has finished. :rtype: twisted.internet.defer.Deferred """ - query = "DELETE FROM '%s' WHERE doc_id=?" \ - % self.TABLE_NAME - return self._runOperation(query, (doc_id,)) - - def _async_decrypt_doc(self, doc_id, rev, content, gen, trans_id, idx): - """ - Dispatch an asynchronous document decrypting routine and save the - result object. - - :param doc_id: The ID for the document with contents to be encrypted. - :type doc: str - :param rev: The revision of the document. - :type rev: str - :param content: The serialized content of the document. - :type content: str - :param gen: The generation corresponding to the modification of that - document. - :type gen: int - :param trans_id: The transaction id corresponding to the modification - of that document. - :type trans_id: str - :param idx: The index of this document in the current sync process. - :type idx: int - """ - soledad_assert(self._crypto is not None, "need a crypto object") - - content = json.loads(content) - key = self._crypto.doc_passphrase(doc_id) - secret = self._crypto.secret - args = doc_id, rev, content, gen, trans_id, key, secret, idx - # decrypt asynchronously - self._async_results.append( - self._pool.apply_async( - decrypt_doc_task, args)) + placeholders = ', '.join('?' for _ in doc_ids) + query = "DELETE FROM '%s' WHERE doc_id in (%s)" \ + % (self.TABLE_NAME, placeholders) + return self._runOperation(query, (doc_ids)) def _decrypt_doc_cb(self, result): """ @@ -610,11 +505,10 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): doc_id, rev, content, gen, trans_id, idx = result logger.debug("Sync decrypter pool: decrypted doc %s: %s %s %s" % (doc_id, rev, gen, trans_id)) - self._decrypting_docs.remove((doc_id, rev)) return self.insert_received_doc( doc_id, rev, content, gen, trans_id, idx) - def _get_docs(self, encrypted=None, order_by='idx', order='ASC'): + def _get_docs(self, encrypted=None, sequence=None): """ Get documents from the received docs table in the sync db. @@ -622,9 +516,6 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): field equal to given parameter. :type encrypted: bool or None :param order_by: The name of the field to order results. - :type order_by: str - :param order: Whether the order should be ASC or DESC. - :type order: str :return: A deferred that will fire with the results of the database query. @@ -632,10 +523,18 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): """ query = "SELECT doc_id, rev, content, gen, trans_id, encrypted, " \ "idx FROM %s" % self.TABLE_NAME - if encrypted is not None: - query += " WHERE encrypted = %d" % int(encrypted) - query += " ORDER BY %s %s" % (order_by, order) - return self._runQuery(query) + parameters = [] + if encrypted or sequence: + query += " WHERE sync_id = ? and" + parameters += [self._sync_id] + if encrypted: + query += " encrypted = ?" + parameters += [int(encrypted)] + if sequence: + query += " idx in (" + ', '.join('?' * len(sequence)) + ")" + parameters += [int(i) for i in sequence] + query += " ORDER BY idx ASC" + return self._runQuery(query, parameters) @defer.inlineCallbacks def _get_insertable_docs(self): @@ -646,35 +545,19 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): documents. :rtype: twisted.internet.defer.Deferred """ - # here, we fetch the list of decrypted documents and compare with the - # index of the last succesfully processed document. - decrypted_docs = yield self._get_docs(encrypted=False) - insertable = [] - last_idx = self._last_inserted_idx - for doc_id, rev, content, gen, trans_id, encrypted, idx in \ - decrypted_docs: - if (idx != last_idx + 1): - break - insertable.append((doc_id, rev, content, gen, trans_id, idx)) - last_idx += 1 - defer.returnValue(insertable) - - @defer.inlineCallbacks - def _async_decrypt_received_docs(self): - """ - Get all the encrypted documents from the sync database and dispatch a - decrypt worker to decrypt each one of them. - - :return: A deferred that will fire after all documents have been - decrypted and inserted back in the sync db. - :rtype: twisted.internet.defer.Deferred - """ - docs = yield self._get_docs(encrypted=True) - for doc_id, rev, content, gen, trans_id, _, idx in docs: - if (doc_id, rev) not in self._decrypting_docs: - self._decrypting_docs.append((doc_id, rev)) - self._async_decrypt_doc( - doc_id, rev, content, gen, trans_id, idx) + # Here, check in memory what are the insertable indexes that can + # form a sequence starting from the last inserted index + sequence = [] + insertable_docs = [] + next_index = self._last_inserted_idx + 1 + while next_index in self._decrypted_docs_indexes: + sequence.append(str(next_index)) + next_index += 1 + # Then fetch all the ones ready for insertion. + if sequence: + insertable_docs = yield self._get_docs(encrypted=False, + sequence=sequence) + defer.returnValue(insertable_docs) @defer.inlineCallbacks def _process_decrypted_docs(self): @@ -687,36 +570,18 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): :rtype: twisted.internet.defer.Deferred """ insertable = yield self._get_insertable_docs() + processed_docs_ids = [] for doc_fields in insertable: method = self._insert_decrypted_local_doc # FIXME: This is used only because SQLCipherU1DBSync is synchronous # When adbapi is used there is no need for an external thread # Without this the reactor can freeze and fail docs download yield threads.deferToThread(method, *doc_fields) - defer.returnValue(insertable) - - def _delete_processed_docs(self, inserted): - """ - Delete from the sync db documents that have been processed. - - :param inserted: List of documents inserted in the previous process - step. - :type inserted: list - - :return: A list of deferreds that will fire when each operation in the - database has finished. - :rtype: twisted.internet.defer.DeferredList - """ - deferreds = [] - for doc_id, doc_rev, _, _, _, _ in inserted: - deferreds.append( - self._delete_received_doc(doc_id)) - if not deferreds: - return defer.succeed(None) - return defer.gatherResults(deferreds) + processed_docs_ids.append(doc_fields[0]) + yield self._delete_received_docs(processed_docs_ids) def _insert_decrypted_local_doc(self, doc_id, doc_rev, content, - gen, trans_id, idx): + gen, trans_id, encrypted, idx): """ Insert the decrypted document into the local replica. @@ -751,32 +616,6 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): self._last_inserted_idx = idx self._processed_docs += 1 - def _init_db(self): - """ - Empty the received docs table of the sync database. - - :return: A deferred that will fire when the operation in the database - has finished. - :rtype: twisted.internet.defer.Deferred - """ - query = "DELETE FROM %s WHERE 1" % (self.TABLE_NAME,) - return self._runOperation(query) - - @defer.inlineCallbacks - def _collect_async_decryption_results(self): - """ - Collect the results of the asynchronous doc decryptions and re-raise - any exception raised by a multiprocessing async decryption call. - - :raise Exception: Raised if an async call has raised an exception. - """ - async_results = self._async_results[:] - for res in async_results: - if res.ready(): - # XXX: might raise an exception! - yield self._decrypt_doc_cb(res.get()) - self._async_results.remove(res) - @defer.inlineCallbacks def _decrypt_and_recurse(self): """ @@ -792,22 +631,19 @@ class SyncDecrypterPool(SyncEncryptDecryptPool): delete operations have been executed. :rtype: twisted.internet.defer.Deferred """ + if not self.running: + defer.returnValue(None) processed = self._processed_docs pending = self._docs_to_process if processed < pending: - yield self._async_decrypt_received_docs() - yield self._collect_async_decryption_results() - docs = yield self._process_decrypted_docs() - yield self._delete_processed_docs(docs) - # recurse - self._delayed_call = reactor.callLater( - self.DECRYPT_LOOP_PERIOD, - self._launch_decrypt_and_recurse) + yield self._process_decrypted_docs() else: self._finish() def _finish(self): self._processed_docs = 0 self._last_inserted_idx = 0 - self._deferred.callback(None) + self._decrypted_docs_indexes = set() + if not self._deferred.called: + self._deferred.callback(None) |