diff options
| -rw-r--r-- | src/leap/soledad/client/_db/blobs.py | 64 | ||||
| -rw-r--r-- | testing/tests/blobs/test_blob_manager.py | 30 | 
2 files changed, 84 insertions, 10 deletions
diff --git a/src/leap/soledad/client/_db/blobs.py b/src/leap/soledad/client/_db/blobs.py index ada72475..db9ce00d 100644 --- a/src/leap/soledad/client/_db/blobs.py +++ b/src/leap/soledad/client/_db/blobs.py @@ -63,6 +63,14 @@ class InvalidFlagsError(SoledadError):      pass +class SyncStatus: +    SYNCED = 1 +    PENDING_UPLOAD = 2 +    PENDING_DOWNLOAD = 3 +    FAILED_UPLOAD = 4 +    FAILED_DOWNLOAD = 5 + +  class ConnectionPool(adbapi.ConnectionPool):      def insertAndGetLastRowid(self, *args, **kwargs): @@ -169,6 +177,7 @@ class BlobManager(object):      - If preamble + payload verifies correctly, mark the blob as usable      """ +    max_retries = 3      def __init__(              self, local_path, remote, key, secret, user, token=None, @@ -224,8 +233,8 @@ class BlobManager(object):          data = yield self._client.get(uri, params=params)          defer.returnValue((yield data.json())) -    def local_list(self, namespace=''): -        return self.local.list(namespace) +    def local_list(self, namespace='', sync_status=False): +        return self.local.list(namespace, sync_status)      @defer.inlineCallbacks      def send_missing(self, namespace=''): @@ -237,7 +246,16 @@ class BlobManager(object):          for blob_id in missing:              fd = yield self.local.get(blob_id, namespace)              logger.info("Upload local blob: %s" % blob_id) -            yield self._encrypt_and_upload(blob_id, fd) +            try: +                yield self._encrypt_and_upload(blob_id, fd) +                yield self.local.update_sync_status(blob_id, SyncStatus.SYNCED) +            except Exception, e: +                yield self.local.increment_retries(blob_id) +                _, retries = yield self.local.get_sync_status(blob_id) +                if retries > self.max_retries: +                    failed_upload = SyncStatus.FAILED_UPLOAD +                    yield self.local.update_sync_status(blob_id, failed_upload) +                raise e      @defer.inlineCallbacks      def fetch_missing(self, namespace=''): @@ -264,6 +282,7 @@ class BlobManager(object):          # handle gets forwarded into a write on the connection handle          fd = yield self.local.get(doc.blob_id, namespace)          yield self._encrypt_and_upload(doc.blob_id, fd, namespace=namespace) +        yield self.local.update_sync_status(doc.blob_id, SyncStatus.SYNCED)      @defer.inlineCallbacks      def set_flags(self, blob_id, flags, **params): @@ -435,11 +454,12 @@ class SQLiteBlobBackend(object):              pass      @defer.inlineCallbacks -    def put(self, blob_id, blob_fd, size=None, namespace=''): +    def put(self, blob_id, blob_fd, size=None, +            namespace='', status=SyncStatus.PENDING_UPLOAD):          logger.info("Saving blob in local database...") -        insert = 'INSERT INTO blobs (blob_id, namespace, payload) ' -        insert += 'VALUES (?, ?, zeroblob(?))' -        values = (blob_id, namespace, size) +        insert = 'INSERT INTO blobs (blob_id, namespace, payload, sync_status)' +        insert += ' VALUES (?, ?, zeroblob(?), ?)' +        values = (blob_id, namespace, size, status)          irow = yield self.dbpool.insertAndGetLastRowid(insert, values)          handle = yield self.dbpool.blob('blobs', 'payload', irow, 1)          blob_fd.seek(0) @@ -463,14 +483,34 @@ class SQLiteBlobBackend(object):              defer.returnValue(BytesIO(str(result[0][0])))      @defer.inlineCallbacks -    def list(self, namespace=''): +    def get_sync_status(self, blob_id): +        select = 'SELECT sync_status, retries FROM blobs WHERE blob_id = ?' +        result = yield self.dbpool.runQuery(select, (blob_id,)) +        if result: +            defer.returnValue((result[0][0], result[0][1])) + +    @defer.inlineCallbacks +    def list(self, namespace='', sync_status=False):          query = 'select blob_id from blobs where namespace = ?' -        result = yield self.dbpool.runQuery(query, (namespace,)) +        values = (namespace,) +        if sync_status: +            query += ' and sync_status = ?' +            values += (sync_status,) +        result = yield self.dbpool.runQuery(query, values)          if result:              defer.returnValue([b_id[0] for b_id in result])          else:              defer.returnValue([]) +    def update_sync_status(self, blob_id, sync_status): +        query = 'update blobs set sync_status = ? where blob_id = ?' +        values = (sync_status, blob_id,) +        return self.dbpool.runQuery(query, values) + +    def increment_retries(self, blob_id): +        query = 'update blobs set retries = retries + 1 where blob_id = ?' +        return self.dbpool.runQuery(query, (blob_id,)) +      @defer.inlineCallbacks      def list_namespaces(self):          query = 'select namespace from blobs' @@ -501,8 +541,12 @@ def _init_blob_table(conn):      columns = [row[1] for row in conn.execute("pragma"                 " table_info(blobs)").fetchall()]      if 'namespace' not in columns: -        # migrate +        # namespace migration          conn.execute('ALTER TABLE blobs ADD COLUMN namespace TEXT') +    if 'sync_status' not in columns: +        # sync status migration +        conn.execute('ALTER TABLE blobs ADD COLUMN sync_status INT default 2') +        conn.execute('ALTER TABLE blobs ADD COLUMN retries INT default 0')  def _sqlcipherInitFactory(fun): diff --git a/testing/tests/blobs/test_blob_manager.py b/testing/tests/blobs/test_blob_manager.py index 087c17e6..dd57047d 100644 --- a/testing/tests/blobs/test_blob_manager.py +++ b/testing/tests/blobs/test_blob_manager.py @@ -19,8 +19,10 @@ Tests for BlobManager.  """  from twisted.trial import unittest  from twisted.internet import defer +from twisted.web.error import SchemeNotSupported  from leap.soledad.client._db.blobs import BlobManager, BlobDoc, FIXED_REV  from leap.soledad.client._db.blobs import BlobAlreadyExistsError +from leap.soledad.client._db.blobs import SyncStatus  from io import BytesIO  from mock import Mock  from uuid import uuid4 @@ -145,3 +147,31 @@ class BlobManagerTestCase(unittest.TestCase):          self.assertEquals(0, len(local_list))          params = {'namespace': ''}          self.manager._delete_from_remote.assert_called_with(blob_id, **params) + +    @defer.inlineCallbacks +    @pytest.mark.usefixtures("method_tmpdir") +    def test_local_sync_status_pending_upload(self): +        upload_failure = defer.fail(Exception()) +        self.manager._encrypt_and_upload = Mock(return_value=upload_failure) +        content, blob_id = "Blob content", uuid4().hex +        doc1 = BlobDoc(BytesIO(content), blob_id) +        with pytest.raises(Exception): +            yield self.manager.put(doc1, len(content)) +        pending_upload = SyncStatus.PENDING_UPLOAD +        local_list = yield self.manager.local_list(sync_status=pending_upload) +        self.assertIn(blob_id, local_list) + +    @defer.inlineCallbacks +    @pytest.mark.usefixtures("method_tmpdir") +    def test_upload_retry_limit(self): +        self.manager.remote_list = Mock(return_value=[]) +        content, blob_id = "Blob content", uuid4().hex +        doc1 = BlobDoc(BytesIO(content), blob_id) +        with pytest.raises(Exception): +            yield self.manager.put(doc1, len(content)) +        for _ in range(self.manager.max_retries + 1): +            with pytest.raises(SchemeNotSupported): +                yield self.manager.send_missing() +        failed_upload = SyncStatus.FAILED_UPLOAD +        local_list = yield self.manager.local_list(sync_status=failed_upload) +        self.assertIn(blob_id, local_list)  | 
