diff options
| -rw-r--r-- | src/leap/soledad/client/_db/blobs.py | 101 | ||||
| -rw-r--r-- | testing/tests/blobs/test_blob_manager.py | 5 | ||||
| -rw-r--r-- | testing/tests/blobs/test_sqlcipher_client_backend.py | 2 | 
3 files changed, 66 insertions, 42 deletions
diff --git a/src/leap/soledad/client/_db/blobs.py b/src/leap/soledad/client/_db/blobs.py index 9e97fe8c..1b4d9114 100644 --- a/src/leap/soledad/client/_db/blobs.py +++ b/src/leap/soledad/client/_db/blobs.py @@ -224,46 +224,50 @@ class BlobManager(object):          data = yield self._client.get(uri, params=params)          defer.returnValue((yield data.json())) -    def local_list(self): -        return self.local.list() +    def local_list(self, namespace='default'): +        assert namespace +        return self.local.list(namespace)      @defer.inlineCallbacks -    def send_missing(self): -        our_blobs = yield self.local_list() -        server_blobs = yield self.remote_list() +    def send_missing(self, namespace='default'): +        assert namespace +        our_blobs = yield self.local_list(namespace) +        server_blobs = yield self.remote_list(namespace=namespace)          missing = [b_id for b_id in our_blobs if b_id not in server_blobs]          logger.info("Amount of documents missing on server: %s" % len(missing))          # TODO: Send concurrently when we are able to stream directly from db          for blob_id in missing: -            fd = yield self.local.get(blob_id) +            fd = yield self.local.get(blob_id, namespace)              logger.info("Upload local blob: %s" % blob_id)              yield self._encrypt_and_upload(blob_id, fd)      @defer.inlineCallbacks -    def fetch_missing(self): +    def fetch_missing(self, namespace='default'): +        assert namespace          # TODO: Use something to prioritize user requests over general new docs -        our_blobs = yield self.local_list() -        server_blobs = yield self.remote_list() +        our_blobs = yield self.local_list(namespace) +        server_blobs = yield self.remote_list(namespace=namespace)          docs_we_want = [b_id for b_id in server_blobs if b_id not in our_blobs]          logger.info("Fetching new docs from server: %s" % len(docs_we_want))          # TODO: Fetch concurrently when we are able to stream directly into db          for blob_id in docs_we_want:              logger.info("Fetching new doc: %s" % blob_id) -            yield self.get(blob_id) +            yield self.get(blob_id, namespace)      @defer.inlineCallbacks -    def put(self, doc, size): -        if (yield self.local.exists(doc.blob_id)): +    def put(self, doc, size, namespace='default'): +        assert namespace +        if (yield self.local.exists(doc.blob_id, namespace)):              error_message = "Blob already exists: %s" % doc.blob_id              raise BlobAlreadyExistsError(error_message)          fd = doc.blob_fd          # TODO this is a tee really, but ok... could do db and upload          # concurrently. not sure if we'd gain something. -        yield self.local.put(doc.blob_id, fd, size=size) +        yield self.local.put(doc.blob_id, fd, size=size, namespace=namespace)          # In fact, some kind of pipe is needed here, where each write on db          # handle gets forwarded into a write on the connection handle -        fd = yield self.local.get(doc.blob_id) -        yield self._encrypt_and_upload(doc.blob_id, fd) +        fd = yield self.local.get(doc.blob_id, namespace) +        yield self._encrypt_and_upload(doc.blob_id, fd, namespace=namespace)      @defer.inlineCallbacks      def set_flags(self, blob_id, flags, **params): @@ -314,8 +318,9 @@ class BlobManager(object):          defer.returnValue((yield response.json()))      @defer.inlineCallbacks -    def get(self, blob_id, namespace=None): -        local_blob = yield self.local.get(blob_id) +    def get(self, blob_id, namespace='default'): +        assert namespace +        local_blob = yield self.local.get(blob_id, namespace=namespace)          if local_blob:              logger.info("Found blob in local database: %s" % blob_id)              defer.returnValue(local_blob) @@ -329,8 +334,8 @@ class BlobManager(object):          if blob:              logger.info("Got decrypted blob of type: %s" % type(blob))              blob.seek(0) -            yield self.local.put(blob_id, blob, size=size) -            defer.returnValue((yield self.local.get(blob_id))) +            yield self.local.put(blob_id, blob, size=size, namespace=namespace) +            defer.returnValue((yield self.local.get(blob_id, namespace)))          else:              # XXX we shouldn't get here, but we will...              # lots of ugly error handling possible: @@ -359,12 +364,13 @@ class BlobManager(object):          logger.info("Finished upload: %s" % (blob_id,))      @defer.inlineCallbacks -    def _download_and_decrypt(self, blob_id, namespace=None): +    def _download_and_decrypt(self, blob_id, namespace='default'): +        assert namespace          logger.info("Staring download of blob: %s" % blob_id)          # TODO this needs to be connected in a tube          uri = urljoin(self.remote, self.user + '/' + blob_id)          params = {'namespace': namespace} if namespace else None -        data = yield self._client.get(uri, params=params) +        data = yield self._client.get(uri, params=params, namespace=namespace)          if data.code == 404:              logger.warn("Blob not found in server: %s" % blob_id) @@ -383,7 +389,7 @@ class BlobManager(object):          defer.returnValue((fd, size))      @defer.inlineCallbacks -    def delete(self, blob_id, **params): +    def delete(self, blob_id, namespace='default', **params):          """          Deletes a blob from local and remote storages.          :param blob_id: @@ -395,10 +401,12 @@ class BlobManager(object):          :return: A deferred that fires when the operation finishes.          :rtype: twisted.internet.defer.Deferred          """ +        assert namespace +        params['namespace'] = namespace          logger.info("Staring deletion of blob: %s" % blob_id)          yield self._delete_from_remote(blob_id, **params) -        if (yield self.local.exists(blob_id)): -            yield self.local.delete(blob_id) +        if (yield self.local.exists(blob_id, namespace)): +            yield self.local.delete(blob_id, namespace)      def _delete_from_remote(self, blob_id, **params):          # TODO this needs to be connected in a tube @@ -433,10 +441,13 @@ class SQLiteBlobBackend(object):              pass      @defer.inlineCallbacks -    def put(self, blob_id, blob_fd, size=None): +    def put(self, blob_id, blob_fd, size=None, namespace='default'): +        assert namespace          logger.info("Saving blob in local database...") -        insert = 'INSERT INTO blobs (blob_id, payload) VALUES (?, zeroblob(?))' -        irow = yield self.dbpool.insertAndGetLastRowid(insert, (blob_id, size)) +        insert = 'INSERT INTO blobs (blob_id, namespace, payload) ' +        insert += 'VALUES (?, ?, zeroblob(?))' +        values = (blob_id, namespace, size) +        irow = yield self.dbpool.insertAndGetLastRowid(insert, values)          handle = yield self.dbpool.blob('blobs', 'payload', irow, 1)          blob_fd.seek(0)          # XXX I have to copy the buffer here so that I'm able to @@ -450,32 +461,44 @@ class SQLiteBlobBackend(object):          defer.returnValue(done)      @defer.inlineCallbacks -    def get(self, blob_id): +    def get(self, blob_id, namespace='default'):          # TODO we can also stream the blob value using sqlite          # incremental interface for blobs - and just return the raw fd instead -        select = 'SELECT payload FROM blobs WHERE blob_id = ?' -        result = yield self.dbpool.runQuery(select, (blob_id,)) +        assert namespace +        select = 'SELECT payload FROM blobs WHERE blob_id = ? AND namespace= ?' +        result = yield self.dbpool.runQuery(select, (blob_id, namespace,))          if result:              defer.returnValue(BytesIO(str(result[0][0])))      @defer.inlineCallbacks -    def list(self): -        query = 'select blob_id from blobs' -        result = yield self.dbpool.runQuery(query) +    def list(self, namespace='default'): +        assert namespace +        query = 'select blob_id from blobs where namespace = ?' +        result = yield self.dbpool.runQuery(query, (namespace,))          if result:              defer.returnValue([b_id[0] for b_id in result])          else:              defer.returnValue([])      @defer.inlineCallbacks -    def exists(self, blob_id): -        query = 'SELECT blob_id from blobs WHERE blob_id = ?' -        result = yield self.dbpool.runQuery(query, (blob_id,)) +    def list_namespaces(self): +        query = 'select namespace from blobs' +        result = yield self.dbpool.runQuery(query) +        if result: +            defer.returnValue([namespace[0] for namespace in result]) +        else: +            defer.returnValue([]) + +    @defer.inlineCallbacks +    def exists(self, blob_id, namespace='default'): +        query = 'SELECT blob_id from blobs WHERE blob_id = ? AND namespace= ?' +        result = yield self.dbpool.runQuery(query, (blob_id, namespace,))          defer.returnValue(bool(len(result))) -    def delete(self, blob_id): -        query = 'DELETE FROM blobs WHERE blob_id = ?' -        return self.dbpool.runQuery(query, (blob_id,)) +    def delete(self, blob_id, namespace='default'): +        assert namespace +        query = 'DELETE FROM blobs WHERE blob_id = ? AND namespace = ?' +        return self.dbpool.runQuery(query, (blob_id, namespace,))  def _init_blob_table(conn): diff --git a/testing/tests/blobs/test_blob_manager.py b/testing/tests/blobs/test_blob_manager.py index 56bea87a..3f0bf8c4 100644 --- a/testing/tests/blobs/test_blob_manager.py +++ b/testing/tests/blobs/test_blob_manager.py @@ -50,7 +50,7 @@ class BlobManagerTestCase(unittest.TestCase):          bad_blob_id = 'inexsistent_id'          result = yield self.manager.get(bad_blob_id)          self.assertIsNone(result) -        args = bad_blob_id, None +        args = bad_blob_id, 'default'          self.manager._download_and_decrypt.assert_called_once_with(*args)      @defer.inlineCallbacks @@ -143,4 +143,5 @@ class BlobManagerTestCase(unittest.TestCase):          yield self.manager.delete(blob_id)          local_list = yield self.manager.local_list()          self.assertEquals(0, len(local_list)) -        self.manager._delete_from_remote.assert_called_with(blob_id) +        params = {'namespace': 'default'} +        self.manager._delete_from_remote.assert_called_with(blob_id, **params) diff --git a/testing/tests/blobs/test_sqlcipher_client_backend.py b/testing/tests/blobs/test_sqlcipher_client_backend.py index b67215e8..6193b486 100644 --- a/testing/tests/blobs/test_sqlcipher_client_backend.py +++ b/testing/tests/blobs/test_sqlcipher_client_backend.py @@ -71,4 +71,4 @@ class SQLBackendTestCase(unittest.TestCase):                               len(content)))          yield defer.gatherResults(deferreds)          result = yield self.local.list() -        self.assertEquals(blob_ids, result) +        self.assertEquals(set(blob_ids), set(result))  | 
