diff options
| -rw-r--r-- | src/leap/soledad/client/_db/blobs/__init__.py | 72 | ||||
| -rw-r--r-- | src/leap/soledad/server/_streaming_resource.py | 9 | ||||
| -rw-r--r-- | tests/server/test_blobs_server.py | 20 | 
3 files changed, 98 insertions, 3 deletions
diff --git a/src/leap/soledad/client/_db/blobs/__init__.py b/src/leap/soledad/client/_db/blobs/__init__.py index ddd22b4b..3daf8d1a 100644 --- a/src/leap/soledad/client/_db/blobs/__init__.py +++ b/src/leap/soledad/client/_db/blobs/__init__.py @@ -23,6 +23,7 @@ from urlparse import urljoin  import os  import json  import base64 +import struct  from collections import defaultdict  from io import BytesIO @@ -104,6 +105,55 @@ class DecrypterBuffer(object):              return self.raw_data, self.raw_data.tell() +class StreamDecrypterBuffer(object): +    size_pack = struct.Struct('<I') + +    def __init__(self, secret, blobs_list, done_callback): +        self.blobs_list = blobs_list +        self.secret = secret +        self.done_callback = done_callback +        self.buf = b'' +        self.reset() + +    def reset(self): +        self.current_blob_size = False +        self.current_blob_id = None +        self.received = 0 + +    def write(self, data): +        if not self.current_blob_size: +            self.buf += data +            if ' ' in self.buf: +                marker, self.buf = self.buf.split(' ') +                assert(len(marker) == 20)  # 16 byte tag + 4 byte size +                size, tag = marker[:4], marker[4:] +                self.current_blob_size = self.size_pack.unpack(size)[0] +                self.received = len(self.buf) +                blob_id = self.blobs_list.pop(0) +                buf = DecrypterBuffer(blob_id, self.secret, tag) +                self.current_blob_id = blob_id +                buf.write(self.buf) +                self.buf = buf +        elif (self.received + len(data)) < self.current_blob_size: +            self.buf.write(data) +            self.received += len(data) +        else: +            missing = self.current_blob_size - self.received +            self.buf.write(data[:missing]) +            blob_id = self.current_blob_id +            fd, size = self.buf.close() +            self.done_callback(blob_id, fd, size) +            self.buf = data[missing:] +            self.reset() + +    def close(self): +        if self.received != 0: +            missing = self.current_blob_size - self.received +            raise Exception("Incomplete download! missing: %s" % missing) +        if self.blobs_list: +            raise Exception("Missing from stream: %s" % self.blobs_list) + +  class BlobManager(BlobsSynchronizer):      """      The BlobManager can list, put, get, set flags and synchronize blobs stored @@ -115,7 +165,7 @@ class BlobManager(BlobsSynchronizer):      def __init__(              self, local_path, remote, key, secret, user, token=None, -            cert_file=None): +            cert_file=None, remote_stream=None):          """          Initialize the blob manager. @@ -131,12 +181,15 @@ class BlobManager(BlobsSynchronizer):          :type token: str          :param cert_file: The path to the CA certificate file.          :type cert_file: str +        :param cert_file: Remote storage stream URL, if supported. +        :type cert_file: str          """          super(BlobsSynchronizer, self).__init__()          if local_path:              mkdir_p(os.path.dirname(local_path))              self.local = SQLiteBlobBackend(local_path, key=key, user=user)          self.remote = remote +        self.remote_stream = remote_stream          self.secret = secret          self.user = user          self._client = HTTPClient(user, token, cert_file) @@ -424,6 +477,23 @@ class BlobManager(BlobsSynchronizer):          logger.info("Finished upload: %s" % (blob_id,))      @defer.inlineCallbacks +    def _downstream(self, blobs_id_list, namespace=''): +        uri = urljoin(self.remote_stream, self.user) +        params = {'namespace': namespace} if namespace else None +        data = BytesIO(json.dumps(list(blobs_id_list))) +        response = yield self._client.post(uri, params=params, data=data) +        deferreds = [] + +        def done_cb(blob_id, blobfd, size): +            d = self.local.put(blob_id, blobfd, size=size, namespace=namespace) +            deferreds.append(d) +        buf = StreamDecrypterBuffer(self.secret, blobs_id_list, done_cb) + +        yield treq.collect(response, buf.write) +        yield defer.gatherResults(deferreds, consumeErrors=True) +        buf.close() + +    @defer.inlineCallbacks      def _download_and_decrypt(self, blob_id, namespace=''):          logger.info("Staring download of blob: %s" % blob_id)          # TODO this needs to be connected in a tube diff --git a/src/leap/soledad/server/_streaming_resource.py b/src/leap/soledad/server/_streaming_resource.py index 18e67401..05f2bab6 100644 --- a/src/leap/soledad/server/_streaming_resource.py +++ b/src/leap/soledad/server/_streaming_resource.py @@ -35,7 +35,7 @@ __all__ = ['StreamingResource']  logger = getLogger(__name__) -SIZE_PACKER = struct.Struct("I") +SIZE_PACKER = struct.Struct('<I')  class StreamingResource(Resource): @@ -64,7 +64,12 @@ class StreamingResource(Resource):              size = db.get_blob_size(user, blob_id, namespace)              request.write(SIZE_PACKER.pack(size))              with open(path, 'rb') as blob_fd: -                request.content.write(blob_fd.read()) +                # TODO: use a producer +                blob_fd.seek(-16, 2) +                request.write(blob_fd.read())  # sends tag +                blob_fd.seek(0) +                request.write(' ') +                request.write(blob_fd.read())          request.finish()          return NOT_DONE_YET diff --git a/tests/server/test_blobs_server.py b/tests/server/test_blobs_server.py index eabf3ee7..5a895ddc 100644 --- a/tests/server/test_blobs_server.py +++ b/tests/server/test_blobs_server.py @@ -238,6 +238,26 @@ class BlobServerTestCase(unittest.TestCase):      @defer.inlineCallbacks      @pytest.mark.usefixtures("method_tmpdir") +    def test_downstream_from_namespace(self): +        manager = BlobManager(self.tempdir, self.uri, self.secret, +                              self.secret, uuid4().hex, +                              remote_stream=self.stream_uri) +        self.addCleanup(manager.close) +        namespace, blob_id, content = 'incoming', 'blob_id1', 'test' +        yield manager._encrypt_and_upload(blob_id, BytesIO(content), +                                          namespace=namespace) +        blob_id2, content2 = 'blob_id2', 'second test' +        yield manager._encrypt_and_upload(blob_id2, BytesIO(content2), +                                          namespace=namespace) +        blobs_list = [blob_id, blob_id2] +        yield manager._downstream(blobs_list, namespace) +        result = yield manager.local.get(blob_id, namespace) +        self.assertEquals(content, result.getvalue()) +        result = yield manager.local.get(blob_id2, namespace) +        self.assertEquals(content2, result.getvalue()) + +    @defer.inlineCallbacks +    @pytest.mark.usefixtures("method_tmpdir")      def test_download_from_namespace(self):          manager = BlobManager('', self.uri, self.secret,                                self.secret, uuid4().hex)  | 
