diff options
-rw-r--r-- | src/leap/soledad/client/_db/blobs.py | 12 | ||||
-rw-r--r-- | tests/blobs/test_blob_manager.py | 19 |
2 files changed, 30 insertions, 1 deletions
diff --git a/src/leap/soledad/client/_db/blobs.py b/src/leap/soledad/client/_db/blobs.py index e23b1cf9..75f196c7 100644 --- a/src/leap/soledad/client/_db/blobs.py +++ b/src/leap/soledad/client/_db/blobs.py @@ -552,6 +552,11 @@ class SQLiteBlobBackend(object): @defer.inlineCallbacks def put(self, blob_id, blob_fd, size=None, namespace='', status=SyncStatus.PENDING_UPLOAD): + previous_state = yield self.get_sync_status(blob_id) + unavailable = SyncStatus.UNAVAILABLE_STATUSES + if previous_state and previous_state[0] in unavailable: + yield self.delete(blob_id, namespace) + status = SyncStatus.SYNCED logger.info("Saving blob in local database...") insert = 'INSERT INTO blobs (blob_id, namespace, payload, sync_status)' insert += ' VALUES (?, ?, zeroblob(?), ?)' @@ -565,7 +570,12 @@ class SQLiteBlobBackend(object): # TODO we can also stream the blob value using sqlite # incremental interface for blobs - and just return the raw fd instead select = 'SELECT payload FROM blobs WHERE blob_id = ? AND namespace= ?' - result = yield self.dbpool.runQuery(select, (blob_id, namespace,)) + values = (blob_id, namespace,) + avoid_values = SyncStatus.UNAVAILABLE_STATUSES + select += ' AND sync_status NOT IN (%s)' + select %= ','.join(['?' for _ in avoid_values]) + values += avoid_values + result = yield self.dbpool.runQuery(select, values) if result: defer.returnValue(BytesIO(str(result[0][0]))) diff --git a/tests/blobs/test_blob_manager.py b/tests/blobs/test_blob_manager.py index 1fe47864..81379c73 100644 --- a/tests/blobs/test_blob_manager.py +++ b/tests/blobs/test_blob_manager.py @@ -196,6 +196,25 @@ class BlobManagerTestCase(unittest.TestCase): @defer.inlineCallbacks @pytest.mark.usefixtures("method_tmpdir") + def test_get_doesnt_include_unavailable_blobs(self): + local = self.manager.local + unavailable_ids, deferreds = [], [] + for unavailable_status in SyncStatus.UNAVAILABLE_STATUSES: + current_blob_id = uuid4().hex + deferreds.append(local.put(current_blob_id, BytesIO(''), 0, + status=unavailable_status)) + unavailable_ids.append(current_blob_id) + available_blob_id = uuid4().hex + content, length = self.cleartext, len(self.cleartext.getvalue()) + deferreds.append(local.put(available_blob_id, content, length)) + yield defer.gatherResults(deferreds) + message = 'Unavailable blob showing up on GET!' + for blob_id in unavailable_ids: + blob = yield local.get(blob_id) + self.assertFalse(blob, message) + + @defer.inlineCallbacks + @pytest.mark.usefixtures("method_tmpdir") def test_persist_sync_statuses_listing_from_server(self): local = self.manager.local remote_ids = [uuid4().hex for _ in range(10)] |