summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Victor Shyba <victor1984@riseup.net> 2017-11-27 13:05:33 -0300
committer Victor Shyba <victor1984@riseup.net> 2017-12-01 01:38:08 -0300
commit c3d079de4675b0fceca130ed3c6b8890ec28d873 (patch)
tree 93bafca777d257943e0fd10bedbaa0392e5f3ac4
parent d574e734e19d5350992bc1aeb00014c41a444add (diff)
[feature] adds a stream downloader
First version, still missing consumer/producer model and some tweaks, but working. -- Related: #8809
-rw-r--r-- src/leap/soledad/client/_db/blobs/__init__.py 72
-rw-r--r-- src/leap/soledad/server/_streaming_resource.py 9
-rw-r--r-- tests/server/test_blobs_server.py 20
3 files changed, 98 insertions, 3 deletions
diff --git a/src/leap/soledad/client/_db/blobs/__init__.py b/src/leap/soledad/client/_db/blobs/__init__.py
index ddd22b4b..3daf8d1a 100644
--- a/src/leap/soledad/client/_db/blobs/__init__.py
+++ b/src/leap/soledad/client/_db/blobs/__init__.py
@@ -23,6 +23,7 @@ from urlparse import urljoin
import os
import json
import base64
+import struct
from collections import defaultdict
from io import BytesIO
@@ -104,6 +105,55 @@ class DecrypterBuffer(object):
return self.raw_data, self.raw_data.tell()
+class StreamDecrypterBuffer(object):
+    """
+    Split a multi-blob download stream into one DecrypterBuffer per blob.
+
+    Every blob in the stream is framed as a 4-byte little-endian payload
+    size, a 16-byte tag, a single space separator and then the payload
+    itself. The frame header has a fixed length, so it is parsed by length
+    and never by searching for the separator: the size, the tag and the
+    payload are binary and may legitimately contain a space byte.
+    """
+    size_pack = struct.Struct('<I')
+    tag_size = 16
+    separator = b' '
+    header_size = size_pack.size + tag_size + len(separator)
+
+    def __init__(self, secret, blobs_list, done_callback):
+        """
+        :param secret: The secret handed to every DecrypterBuffer.
+        :param blobs_list: Ids of the expected blobs, in stream order. Ids
+                           are popped from this list as headers are parsed.
+        :type blobs_list: list
+        :param done_callback: Called as ``done_callback(blob_id, fd, size)``
+                              each time a blob has been fully received.
+        :type done_callback: callable
+        """
+        self.blobs_list = blobs_list
+        self.secret = secret
+        self.done_callback = done_callback
+        # Raw bytes not consumed yet; between writes this only ever holds
+        # an incomplete frame header.
+        self.buf = b''
+        self.reset()
+
+    def reset(self):
+        """Forget the blob in progress and wait for the next header."""
+        # None (not False/0) marks "no blob in progress", so that a
+        # zero-length blob is still a valid, distinguishable state.
+        self.current_blob_size = None
+        self.current_blob_id = None
+        self.received = 0
+        self._decrypter = None
+
+    def write(self, data):
+        """
+        Consume a chunk of the stream.
+
+        A chunk may hold any part of a frame, or several frames, so the
+        data is processed in a loop until it is exhausted or until only an
+        incomplete header is left.
+
+        :param data: The next bytes received from the stream.
+        :type data: str
+        """
+        pending = self.buf + data
+        self.buf = b''
+        while True:
+            if self._decrypter is None:
+                if len(pending) < self.header_size:
+                    # wait for the rest of the header
+                    break
+                pending = self._start_blob(pending)
+            missing = self.current_blob_size - self.received
+            chunk, pending = pending[:missing], pending[missing:]
+            if chunk:
+                self._decrypter.write(chunk)
+                self.received += len(chunk)
+            if self.received < self.current_blob_size:
+                # blob still incomplete, so pending is empty by now
+                break
+            self._finish_blob()
+        self.buf = pending
+
+    def _start_blob(self, pending):
+        """Parse a frame header and return the bytes that follow it."""
+        marker_size = self.size_pack.size + self.tag_size
+        marker = pending[:marker_size]
+        if pending[marker_size:self.header_size] != self.separator:
+            raise Exception("Malformed stream: missing header separator")
+        if not self.blobs_list:
+            raise Exception("Unexpected extra blob in stream")
+        size = marker[:self.size_pack.size]
+        tag = marker[self.size_pack.size:]
+        self.current_blob_size = self.size_pack.unpack(size)[0]
+        self.current_blob_id = self.blobs_list.pop(0)
+        self.received = 0
+        self._decrypter = DecrypterBuffer(
+            self.current_blob_id, self.secret, tag)
+        return pending[self.header_size:]
+
+    def _finish_blob(self):
+        """Close the current decrypter and hand its result to the callback."""
+        blob_id = self.current_blob_id
+        fd, size = self._decrypter.close()
+        self.reset()
+        self.done_callback(blob_id, fd, size)
+
+    def close(self):
+        """
+        Check that the stream ended exactly on a blob boundary.
+
+        :raise Exception: if a blob or a header was cut short, or if some
+                          of the expected blobs never showed up.
+        """
+        if self._decrypter is not None:
+            # Checked through the decrypter rather than the byte count, so
+            # a blob whose header arrived without payload is caught too.
+            missing = self.current_blob_size - self.received
+            raise Exception("Incomplete download! missing: %s" % missing)
+        if self.buf:
+            missing = self.header_size - len(self.buf)
+            raise Exception("Incomplete download! missing: %s" % missing)
+        if self.blobs_list:
+            raise Exception("Missing from stream: %s" % self.blobs_list)
+
+
class BlobManager(BlobsSynchronizer):
"""
The BlobManager can list, put, get, set flags and synchronize blobs stored
@@ -115,7 +165,7 @@ class BlobManager(BlobsSynchronizer):
def __init__(
self, local_path, remote, key, secret, user, token=None,
- cert_file=None):
+ cert_file=None, remote_stream=None):
"""
Initialize the blob manager.
@@ -131,12 +181,15 @@ class BlobManager(BlobsSynchronizer):
:type token: str
:param cert_file: The path to the CA certificate file.
:type cert_file: str
+ :param remote_stream: Remote storage stream URL, if supported.
+ :type remote_stream: str
"""
super(BlobsSynchronizer, self).__init__()
if local_path:
mkdir_p(os.path.dirname(local_path))
self.local = SQLiteBlobBackend(local_path, key=key, user=user)
self.remote = remote
+ self.remote_stream = remote_stream
self.secret = secret
self.user = user
self._client = HTTPClient(user, token, cert_file)
@@ -424,6 +477,23 @@ class BlobManager(BlobsSynchronizer):
logger.info("Finished upload: %s" % (blob_id,))
@defer.inlineCallbacks
+ def _downstream(self, blobs_id_list, namespace=''):
+     """
+     Download several blobs through the stream endpoint and store them
+     locally.
+
+     :param blobs_id_list: Ids of the blobs to fetch, in the order they
+                           are requested from the server.
+     :type blobs_id_list: iterable of str
+     :param namespace: Optional namespace the blobs belong to.
+     :type namespace: str
+     :return: A deferred that fires once every blob has been stored.
+     :rtype: twisted.internet.defer.Deferred
+     """
+     # Materialize once, so that any iterable is accepted and the request
+     # body and the stream parser see the very same ids.
+     blobs_id_list = list(blobs_id_list)
+     uri = urljoin(self.remote_stream, self.user)
+     params = {'namespace': namespace} if namespace else None
+     data = BytesIO(json.dumps(blobs_id_list))
+     response = yield self._client.post(uri, params=params, data=data)
+     deferreds = []
+
+     def done_cb(blob_id, blobfd, size):
+         d = self.local.put(blob_id, blobfd, size=size, namespace=namespace)
+         deferreds.append(d)
+     # The buffer pops ids as blobs arrive: hand it a private copy so the
+     # caller's list is left untouched.
+     buf = StreamDecrypterBuffer(self.secret, list(blobs_id_list), done_cb)
+
+     yield treq.collect(response, buf.write)
+     yield defer.gatherResults(deferreds, consumeErrors=True)
+     # raises if the stream was truncated or blobs are missing
+     buf.close()
+
+ @defer.inlineCallbacks
def _download_and_decrypt(self, blob_id, namespace=''):
logger.info("Staring download of blob: %s" % blob_id)
# TODO this needs to be connected in a tube
diff --git a/src/leap/soledad/server/_streaming_resource.py b/src/leap/soledad/server/_streaming_resource.py
index 18e67401..05f2bab6 100644
--- a/src/leap/soledad/server/_streaming_resource.py
+++ b/src/leap/soledad/server/_streaming_resource.py
@@ -35,7 +35,7 @@ __all__ = ['StreamingResource']
logger = getLogger(__name__)
-SIZE_PACKER = struct.Struct("I")
+SIZE_PACKER = struct.Struct('<I')
class StreamingResource(Resource):
@@ -64,7 +64,12 @@ class StreamingResource(Resource):
size = db.get_blob_size(user, blob_id, namespace)
request.write(SIZE_PACKER.pack(size))
with open(path, 'rb') as blob_fd:
- request.content.write(blob_fd.read())
+ # TODO: use a producer
+ blob_fd.seek(-16, 2)
+ request.write(blob_fd.read()) # sends tag
+ blob_fd.seek(0)
+ request.write(' ')
+ request.write(blob_fd.read())
request.finish()
return NOT_DONE_YET
diff --git a/tests/server/test_blobs_server.py b/tests/server/test_blobs_server.py
index eabf3ee7..5a895ddc 100644
--- a/tests/server/test_blobs_server.py
+++ b/tests/server/test_blobs_server.py
@@ -238,6 +238,26 @@ class BlobServerTestCase(unittest.TestCase):
@defer.inlineCallbacks
@pytest.mark.usefixtures("method_tmpdir")
+ def test_downstream_from_namespace(self):
+     """Blobs uploaded to a namespace come back intact via the stream."""
+     manager = BlobManager(self.tempdir, self.uri, self.secret,
+                           self.secret, uuid4().hex,
+                           remote_stream=self.stream_uri)
+     self.addCleanup(manager.close)
+     namespace = 'incoming'
+     uploads = [('blob_id1', 'test'), ('blob_id2', 'second test')]
+     # upload both blobs, one after the other
+     for blob_id, content in uploads:
+         yield manager._encrypt_and_upload(blob_id, BytesIO(content),
+                                           namespace=namespace)
+     # fetch them back in a single streamed request
+     blobs_list = [blob_id for blob_id, _ in uploads]
+     yield manager._downstream(blobs_list, namespace)
+     # each blob must now be available locally with its original content
+     for blob_id, content in uploads:
+         result = yield manager.local.get(blob_id, namespace)
+         self.assertEquals(content, result.getvalue())
+
+ @defer.inlineCallbacks
+ @pytest.mark.usefixtures("method_tmpdir")
def test_download_from_namespace(self):
manager = BlobManager('', self.uri, self.secret,
self.secret, uuid4().hex)