diff options
Diffstat (limited to 'src/leap')
-rw-r--r-- | src/leap/soledad/client/_db/blobs.py | 64 |
1 files changed, 54 insertions, 10 deletions
diff --git a/src/leap/soledad/client/_db/blobs.py b/src/leap/soledad/client/_db/blobs.py index ada72475..db9ce00d 100644 --- a/src/leap/soledad/client/_db/blobs.py +++ b/src/leap/soledad/client/_db/blobs.py @@ -63,6 +63,14 @@ class InvalidFlagsError(SoledadError): pass +class SyncStatus: + SYNCED = 1 + PENDING_UPLOAD = 2 + PENDING_DOWNLOAD = 3 + FAILED_UPLOAD = 4 + FAILED_DOWNLOAD = 5 + + class ConnectionPool(adbapi.ConnectionPool): def insertAndGetLastRowid(self, *args, **kwargs): @@ -169,6 +177,7 @@ class BlobManager(object): - If preamble + payload verifies correctly, mark the blob as usable """ + max_retries = 3 def __init__( self, local_path, remote, key, secret, user, token=None, @@ -224,8 +233,8 @@ class BlobManager(object): data = yield self._client.get(uri, params=params) defer.returnValue((yield data.json())) - def local_list(self, namespace=''): - return self.local.list(namespace) + def local_list(self, namespace='', sync_status=False): + return self.local.list(namespace, sync_status) @defer.inlineCallbacks def send_missing(self, namespace=''): @@ -237,7 +246,16 @@ class BlobManager(object): for blob_id in missing: fd = yield self.local.get(blob_id, namespace) logger.info("Upload local blob: %s" % blob_id) - yield self._encrypt_and_upload(blob_id, fd) + try: + yield self._encrypt_and_upload(blob_id, fd) + yield self.local.update_sync_status(blob_id, SyncStatus.SYNCED) + except Exception, e: + yield self.local.increment_retries(blob_id) + _, retries = yield self.local.get_sync_status(blob_id) + if retries > self.max_retries: + failed_upload = SyncStatus.FAILED_UPLOAD + yield self.local.update_sync_status(blob_id, failed_upload) + raise e @defer.inlineCallbacks def fetch_missing(self, namespace=''): @@ -264,6 +282,7 @@ class BlobManager(object): # handle gets forwarded into a write on the connection handle fd = yield self.local.get(doc.blob_id, namespace) yield self._encrypt_and_upload(doc.blob_id, fd, namespace=namespace) + yield self.local.update_sync_status(doc.blob_id, SyncStatus.SYNCED) @defer.inlineCallbacks def set_flags(self, blob_id, flags, **params): @@ -435,11 +454,12 @@ class SQLiteBlobBackend(object): pass @defer.inlineCallbacks - def put(self, blob_id, blob_fd, size=None, namespace=''): + def put(self, blob_id, blob_fd, size=None, + namespace='', status=SyncStatus.PENDING_UPLOAD): logger.info("Saving blob in local database...") - insert = 'INSERT INTO blobs (blob_id, namespace, payload) ' - insert += 'VALUES (?, ?, zeroblob(?))' - values = (blob_id, namespace, size) + insert = 'INSERT INTO blobs (blob_id, namespace, payload, sync_status)' + insert += ' VALUES (?, ?, zeroblob(?), ?)' + values = (blob_id, namespace, size, status) irow = yield self.dbpool.insertAndGetLastRowid(insert, values) handle = yield self.dbpool.blob('blobs', 'payload', irow, 1) blob_fd.seek(0) @@ -463,14 +483,34 @@ class SQLiteBlobBackend(object): defer.returnValue(BytesIO(str(result[0][0]))) @defer.inlineCallbacks - def list(self, namespace=''): + def get_sync_status(self, blob_id): + select = 'SELECT sync_status, retries FROM blobs WHERE blob_id = ?' + result = yield self.dbpool.runQuery(select, (blob_id,)) + if result: + defer.returnValue((result[0][0], result[0][1])) + + @defer.inlineCallbacks + def list(self, namespace='', sync_status=False): query = 'select blob_id from blobs where namespace = ?' - result = yield self.dbpool.runQuery(query, (namespace,)) + values = (namespace,) + if sync_status: + query += ' and sync_status = ?' + values += (sync_status,) + result = yield self.dbpool.runQuery(query, values) if result: defer.returnValue([b_id[0] for b_id in result]) else: defer.returnValue([]) + def update_sync_status(self, blob_id, sync_status): + query = 'update blobs set sync_status = ? where blob_id = ?' + values = (sync_status, blob_id,) + return self.dbpool.runQuery(query, values) + + def increment_retries(self, blob_id): + query = 'update blobs set retries = retries + 1 where blob_id = ?' + return self.dbpool.runQuery(query, (blob_id,)) + @defer.inlineCallbacks def list_namespaces(self): query = 'select namespace from blobs' @@ -501,8 +541,12 @@ def _init_blob_table(conn): columns = [row[1] for row in conn.execute("pragma" " table_info(blobs)").fetchall()] if 'namespace' not in columns: - # migrate + # namespace migration conn.execute('ALTER TABLE blobs ADD COLUMN namespace TEXT') + if 'sync_status' not in columns: + # sync status migration + conn.execute('ALTER TABLE blobs ADD COLUMN sync_status INT default 2') + conn.execute('ALTER TABLE blobs ADD COLUMN retries INT default 0') def _sqlcipherInitFactory(fun): |