diff options
Diffstat (limited to 'src/leap/soledad/client/_db')
-rw-r--r-- | src/leap/soledad/client/_db/blobs/__init__.py | 93 | ||||
-rw-r--r-- | src/leap/soledad/client/_db/blobs/sql.py | 68 | ||||
-rw-r--r-- | src/leap/soledad/client/_db/blobs/sync.py | 106 |
3 files changed, 204 insertions, 63 deletions
diff --git a/src/leap/soledad/client/_db/blobs/__init__.py b/src/leap/soledad/client/_db/blobs/__init__.py index 4699e0a0..df801850 100644 --- a/src/leap/soledad/client/_db/blobs/__init__.py +++ b/src/leap/soledad/client/_db/blobs/__init__.py @@ -24,6 +24,7 @@ import os import json import base64 +from collections import defaultdict from io import BytesIO from twisted.logger import Logger @@ -44,6 +45,7 @@ from leap.soledad.client._pipes import TruncatedTailPipe from leap.soledad.client._pipes import PreamblePipe from .sql import SyncStatus +from .sql import Priority from .sql import SQLiteBlobBackend from .sync import BlobsSynchronizer from .errors import ( @@ -139,6 +141,7 @@ class BlobManager(BlobsSynchronizer): self.user = user self._client = HTTPClient(user, token, cert_file) self.semaphore = defer.DeferredSemaphore(self.concurrent_writes_limit) + self.locks = defaultdict(defer.DeferredLock) def close(self): if hasattr(self, 'local') and self.local: @@ -207,7 +210,8 @@ class BlobManager(BlobsSynchronizer): def local_list_status(self, status, namespace=''): return self.local.list_status(status, namespace) - def put(self, doc, size, namespace='', local_only=False): + def put(self, doc, size, namespace='', local_only=False, + priority=Priority.DEFAULT): """ Put a blob in local storage and upload it to server. @@ -220,12 +224,15 @@ class BlobManager(BlobsSynchronizer): :param namespace: Optional parameter to restrict operation to a given namespace. :type namespace: str + + :return: A deferred that fires when the blob has been put. + :rtype: twisted.internet.defer.Deferred """ return self.semaphore.run( - self._put, doc, size, namespace, local_only=local_only) + self._put, doc, size, namespace, local_only, priority) @defer.inlineCallbacks - def _put(self, doc, size, namespace, local_only=False): + def _put(self, doc, size, namespace, local_only, priority): if (yield self.local.exists(doc.blob_id, namespace=namespace)): error_message = "Blob already exists: %s" % doc.blob_id raise BlobAlreadyExistsError(error_message) @@ -233,17 +240,31 @@ class BlobManager(BlobsSynchronizer): # TODO this is a tee really, but ok... could do db and upload # concurrently. not sure if we'd gain something. yield self.local.put(doc.blob_id, fd, size=size, namespace=namespace) + if local_only: yield self.local.update_sync_status( - doc.blob_id, SyncStatus.LOCAL_ONLY) + doc.blob_id, SyncStatus.LOCAL_ONLY, namespace=namespace) defer.returnValue(None) + yield self.local.update_sync_status( - doc.blob_id, SyncStatus.PENDING_UPLOAD) + doc.blob_id, SyncStatus.PENDING_UPLOAD, namespace=namespace, + priority=priority) + yield self._send(doc.blob_id, namespace, 1, 1) + + def _send(self, blob_id, namespace, i, total): + lock = self.locks[blob_id] + d = lock.run(self.__send, blob_id, namespace, i, total) + return d + + @defer.inlineCallbacks + def __send(self, blob_id, namespace, i, total): + logger.info("Sending blob to server (%d/%d): %s" + % (i, total, blob_id)) # In fact, some kind of pipe is needed here, where each write on db # handle gets forwarded into a write on the connection handle - fd = yield self.local.get(doc.blob_id, namespace=namespace) - yield self._encrypt_and_upload(doc.blob_id, fd, namespace=namespace) - yield self.local.update_sync_status(doc.blob_id, SyncStatus.SYNCED) + fd = yield self.local.get(blob_id, namespace=namespace) + yield self._encrypt_and_upload(blob_id, fd) + yield self.local.update_sync_status(blob_id, SyncStatus.SYNCED) def set_flags(self, blob_id, flags, namespace=''): """ @@ -294,7 +315,7 @@ class BlobManager(BlobsSynchronizer): defer.returnValue((yield response.json())) @defer.inlineCallbacks - def get(self, blob_id, namespace=''): + def get(self, blob_id, namespace='', priority=Priority.DEFAULT): """ Get the blob from local storage or, if not available, from the server. @@ -304,6 +325,10 @@ class BlobManager(BlobsSynchronizer): :param namespace: Optional parameter to restrict operation to a given namespace. :type namespace: str + + :return: A deferred that fires with the file descriptor for the + contents of the blob. + :rtype: twisted.internet.defer.Deferred """ local_blob = yield self.local.get(blob_id, namespace=namespace) if local_blob: @@ -311,8 +336,19 @@ class BlobManager(BlobsSynchronizer): defer.returnValue(local_blob) yield self.local.update_sync_status( - blob_id, SyncStatus.PENDING_DOWNLOAD, namespace=namespace) + blob_id, SyncStatus.PENDING_DOWNLOAD, namespace=namespace, + priority=priority) + + fd = yield self._fetch(blob_id, namespace) + defer.returnValue(fd) + + def _fetch(self, blob_id, namespace): + lock = self.locks[blob_id] + d = lock.run(self.__fetch, blob_id, namespace) + return d + @defer.inlineCallbacks + def __fetch(self, blob_id, namespace): try: result = yield self._download_and_decrypt(blob_id, namespace) except Exception as e: @@ -347,6 +383,8 @@ class BlobManager(BlobsSynchronizer): logger.info("Got decrypted blob of type: %s" % type(blob)) blob.seek(0) yield self.local.put(blob_id, blob, size=size, namespace=namespace) + yield self.local.update_sync_status(blob_id, SyncStatus.SYNCED, + namespace=namespace) local_blob = yield self.local.get(blob_id, namespace=namespace) defer.returnValue(local_blob) else: @@ -435,3 +473,38 @@ class BlobManager(BlobsSynchronizer): response = yield self._client.delete(uri, params=params) check_http_status(response.code, blob_id=blob_id) defer.returnValue(response) + + def set_priority(self, blob_id, priority, namespace=''): + """ + Set the transfer priority for a certain blob. + + :param blob_id: Unique identifier of a blob. + :type blob_id: str + :param priority: The priority to be set. + :type priority: int + :param namespace: Optional parameter to restrict operation to a given + namespace. + :type namespace: str + + :return: A deferred that fires after the priority has been set. + :rtype: twisted.internet.defer.Deferred + """ + d = self.local.update_priority(blob_id, priority, namespace=namespace) + return d + + def get_priority(self, blob_id, namespace=''): + """ + Get the transfer priority for a certain blob. + + :param blob_id: Unique identifier of a blob. + :type blob_id: str + :param namespace: Optional parameter to restrict operation to a given + namespace. + :type namespace: str + + :return: A deferred that fires with the current transfer priority of + the blob. + :rtype: twisted.internet.defer.Deferred + """ + d = self.local.get_priority(blob_id, namespace=namespace) + return d diff --git a/src/leap/soledad/client/_db/blobs/sql.py b/src/leap/soledad/client/_db/blobs/sql.py index ebd6c095..a89802d8 100644 --- a/src/leap/soledad/client/_db/blobs/sql.py +++ b/src/leap/soledad/client/_db/blobs/sql.py @@ -46,6 +46,14 @@ class SyncStatus: UNAVAILABLE_STATUSES = (3, 5) +class Priority: + LOW = 1 + MEDIUM = 2 + HIGH = 3 + URGENT = 4 + DEFAULT = 2 + + class SQLiteBlobBackend(object): concurrency_limit = 10 @@ -130,7 +138,7 @@ class SQLiteBlobBackend(object): @defer.inlineCallbacks def list_status(self, sync_status, namespace=''): query = 'select blob_id from sync_state where sync_status = ?' - query += 'AND namespace = ?' + query += 'AND namespace = ? ORDER BY priority DESC' values = (sync_status, namespace,) result = yield self.dbpool.runQuery(query, values) if result: @@ -139,21 +147,42 @@ class SQLiteBlobBackend(object): defer.returnValue([]) @defer.inlineCallbacks - def update_sync_status(self, blob_id, sync_status, namespace=""): - query = 'SELECT sync_status FROM sync_state WHERE blob_id = ?' - result = yield self.dbpool.runQuery(query, (blob_id,)) + def update_sync_status(self, blob_id, sync_status, namespace="", + priority=None): + retries = '(SELECT retries from sync_state' \ + ' WHERE blob_id="%s" and namespace="%s")' \ + % (blob_id, namespace) + if not priority: + priority = '(SELECT priority FROM sync_state' \ + ' WHERE blob_id="%s" AND namespace="%s")' \ + % (blob_id, namespace) + fields = 'blob_id, namespace, sync_status, retries, priority' + markers = '?, ?, ?, %s, %s' % (retries, priority) + values = [blob_id, namespace, sync_status] + insert = 'INSERT or REPLACE INTO sync_state (%s) VALUES (%s)' \ + % (fields, markers) + yield self.dbpool.runOperation(insert, tuple(values)) + @defer.inlineCallbacks + def get_priority(self, blob_id, namespace=""): + query = 'SELECT priority FROM sync_state WHERE blob_id = ?' + result = yield self.dbpool.runQuery(query, (blob_id,)) if not result: - insert = 'INSERT INTO sync_state' - insert += ' (blob_id, namespace, sync_status)' - insert += ' VALUES (?, ?, ?)' - values = (blob_id, namespace, sync_status) - yield self.dbpool.runOperation(insert, values) - return + defer.returnValue(None) + priority = result.pop()[0] + defer.returnValue(priority) - update = 'UPDATE sync_state SET sync_status = ? WHERE blob_id = ?' - values = (sync_status, blob_id,) - result = yield self.dbpool.runOperation(update, values) + @defer.inlineCallbacks + def update_priority(self, blob_id, priority, namespace=""): + old_priority = self.get_priority(blob_id, namespace=namespace) + if not old_priority: + logger.error("Can't update priority of %s: no sync status entry.") + return + if old_priority == priority: + return + update = 'UPDATE sync_state SET priority = ? WHERE blob_id = ?' + values = (priority, blob_id,) + yield self.dbpool.runOperation(update, values) def update_batch_sync_status(self, blob_id_list, sync_status, namespace=''): @@ -161,11 +190,12 @@ class SQLiteBlobBackend(object): return insert = 'INSERT or REPLACE INTO sync_state' first_blob_id, blob_id_list = blob_id_list[0], blob_id_list[1:] - insert += ' (blob_id, namespace, sync_status) VALUES (?, ?, ?)' - values = (first_blob_id, namespace, sync_status) + insert += ' (blob_id, namespace, sync_status, priority)' + insert += ' VALUES (?, ?, ?, ?)' + values = (first_blob_id, namespace, sync_status, Priority.DEFAULT) for blob_id in blob_id_list: - insert += ', (?, ?, ?)' - values += (blob_id, namespace, sync_status) + insert += ', (?, ?, ?, ?)' + values += (blob_id, namespace, sync_status, Priority.DEFAULT) return self.dbpool.runQuery(insert, values) def increment_retries(self, blob_id): @@ -212,9 +242,11 @@ def _init_sync_table(conn): blob_id PRIMARY KEY, namespace TEXT, sync_status INT default %s, + priority INT default %d, retries INT default 0)""" default_status = SyncStatus.PENDING_UPLOAD - maybe_create %= default_status + default_priority = Priority.DEFAULT + maybe_create %= (default_status, default_priority) conn.execute(maybe_create) diff --git a/src/leap/soledad/client/_db/blobs/sync.py b/src/leap/soledad/client/_db/blobs/sync.py index 3ee60305..e6397ede 100644 --- a/src/leap/soledad/client/_db/blobs/sync.py +++ b/src/leap/soledad/client/_db/blobs/sync.py @@ -17,12 +17,15 @@ """ Synchronization between blobs client/server """ +from collections import defaultdict from twisted.internet import defer from twisted.internet import reactor from twisted.logger import Logger from twisted.internet import error from .sql import SyncStatus from .errors import RetriableTransferError + + logger = Logger() @@ -56,6 +59,9 @@ def with_retry(func, *args, **kwargs): class BlobsSynchronizer(object): + def __init__(self): + self.locks = defaultdict(defer.DeferredLock) + @defer.inlineCallbacks def refresh_sync_status_from_server(self, namespace=''): d1 = self.remote_list(namespace=namespace) @@ -82,7 +88,6 @@ class BlobsSynchronizer(object): SyncStatus.SYNCED, namespace=namespace) - @defer.inlineCallbacks def send_missing(self, namespace=''): """ Compare local and remote blobs and send what's missing in server. @@ -90,30 +95,41 @@ class BlobsSynchronizer(object): :param namespace: Optional parameter to restrict operation to a given namespace. :type namespace: str + + :return: A deferred that fires when all local blobs were sent to + server. + :rtype: twisted.internet.defer.Deferred """ - missing = yield self.local.list_status( - SyncStatus.PENDING_UPLOAD, namespace) - total = len(missing) - logger.info("Will send %d blobs to server." % total) - deferreds = [] - semaphore = defer.DeferredSemaphore(self.concurrent_transfers_limit) - - for i in xrange(total): - blob_id = missing.pop() - d = semaphore.run( - with_retry, self.__send_one, blob_id, namespace, i, total) - deferreds.append(d) - yield defer.gatherResults(deferreds, consumeErrors=True) + lock = self.locks['send_missing'] + d = lock.run(self._send_missing, namespace) + return d @defer.inlineCallbacks - def __send_one(self, blob_id, namespace, i, total): - logger.info("Sending blob to server (%d/%d): %s" - % (i, total, blob_id)) - fd = yield self.local.get(blob_id, namespace=namespace) - yield self._encrypt_and_upload(blob_id, fd) - yield self.local.update_sync_status(blob_id, SyncStatus.SYNCED) + def _send_missing(self, namespace): + max_transfers = self.concurrent_transfers_limit + semaphore = defer.DeferredSemaphore(max_transfers) + # the list of blobs should be refreshed often, so we run as many + # concurrent transfers as we can and then refresh the list + while True: + d = self.local_list_status(SyncStatus.PENDING_UPLOAD, namespace) + missing = yield d + + if not missing: + break + + total = len(missing) + now = min(total, max_transfers) + logger.info("There are %d pending blob uploads." % total) + logger.info("Will send %d blobs to server now." % now) + missing = missing[:now] + deferreds = [] + for i in xrange(now): + blob_id = missing.pop(0) + d = semaphore.run( + with_retry, self._send, blob_id, namespace, i, total) + deferreds.append(d) + yield defer.gatherResults(deferreds, consumeErrors=True) - @defer.inlineCallbacks def fetch_missing(self, namespace=''): """ Compare local and remote blobs and fetch what's missing in local @@ -122,21 +138,41 @@ class BlobsSynchronizer(object): :param namespace: Optional parameter to restrict operation to a given namespace. :type namespace: str + + :return: A deferred that fires when all remote blobs were received from + server. + :rtype: twisted.internet.defer.Deferred """ - # TODO: Use something to prioritize user requests over general new docs - d = self.local_list_status(SyncStatus.PENDING_DOWNLOAD, namespace) - docs_we_want = yield d - total = len(docs_we_want) - logger.info("Will fetch %d blobs from server." % total) - deferreds = [] - semaphore = defer.DeferredSemaphore(self.concurrent_transfers_limit) - - for i in xrange(len(docs_we_want)): - blob_id = docs_we_want.pop() - logger.info("Fetching blob (%d/%d): %s" % (i, total, blob_id)) - d = semaphore.run(with_retry, self.get, blob_id, namespace) - deferreds.append(d) - yield defer.gatherResults(deferreds, consumeErrors=True) + lock = self.locks['fetch_missing'] + d = lock.run(self._fetch_missing, namespace) + return d + + @defer.inlineCallbacks + def _fetch_missing(self, namespace=''): + max_transfers = self.concurrent_transfers_limit + semaphore = defer.DeferredSemaphore(max_transfers) + # in order to make sure that transfer priorities will be met, the list + # of blobs to transfer should be refreshed often. What we do is run as + # many concurrent transfers as we can and then refresh the list + while True: + d = self.local_list_status(SyncStatus.PENDING_DOWNLOAD, namespace) + docs_we_want = yield d + + if not docs_we_want: + break + + total = len(docs_we_want) + now = min(total, max_transfers) + logger.info("There are %d pending blob downloads." % total) + logger.info("Will fetch %d blobs from server now." % now) + docs_we_want = docs_we_want[:now] + deferreds = [] + for i in xrange(now): + blob_id = docs_we_want.pop(0) + logger.info("Fetching blob (%d/%d): %s" % (i, now, blob_id)) + d = semaphore.run(with_retry, self._fetch, blob_id, namespace) + deferreds.append(d) + yield defer.gatherResults(deferreds, consumeErrors=True) @defer.inlineCallbacks def sync(self, namespace=''): |