[feature] add ranges to blobs backend
author	drebs <drebs@leap.se>
Tue, 19 Dec 2017 20:00:23 +0000 (18:00 -0200)
committer	drebs <drebs@leap.se>
Tue, 26 Dec 2017 11:04:50 +0000 (09:04 -0200)
src/leap/soledad/server/_blobs/errors.py
src/leap/soledad/server/_blobs/fs_backend.py
src/leap/soledad/server/_blobs/resource.py
src/leap/soledad/server/interfaces.py
tests/blobs/test_fs_backend.py
tests/server/test_blobs_server.py
tests/server/test_incoming_server.py

index 8c1c532..0cd2059 100644 (file)
@@ -42,3 +42,10 @@ class ImproperlyConfiguredException(Exception):
     """
     Raised when there is a problem with the configuration of a backend.
     """
+
+
+class RangeNotSatisfiable(Exception):
+    """
+    Raised when the Range: HTTP header was sent but the server doesn't know how
+    to satisfy it.
+    """
index 6274a1c..e1c4991 100644 (file)
@@ -27,7 +27,8 @@ from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet import utils
-from twisted.web.client import FileBodyProducer
+from twisted.web.static import NoRangeStaticProducer
+from twisted.web.static import SingleRangeStaticProducer
 
 from leap.common.files import mkdir_p
 from leap.soledad.common.blobs import ACCEPTED_FLAGS
@@ -44,6 +45,43 @@ from .util import VALID_STRINGS
 logger = getLogger(__name__)
 
 
+class NoRangeProducer(NoRangeStaticProducer):
+    """
+    A static file producer that fires a deferred when it's finished.
+    """
+
+    def start(self):
+        NoRangeStaticProducer.start(self)
+        if self.request is None:
+            return defer.succeed(None)
+        self.deferred = defer.Deferred()
+        return self.deferred
+
+    def stopProducing(self):
+        NoRangeStaticProducer.stopProducing(self)
+        if hasattr(self, 'deferred'):
+            self.deferred.callback(None)
+
+
+class SingleRangeProducer(SingleRangeStaticProducer):
+    """
+    A static file producer of a single file range that fires a deferred when
+    it's finished.
+    """
+
+    def start(self):
+        SingleRangeStaticProducer.start(self)
+        if self.request is None:
+            return defer.succeed(None)
+        self.deferred = defer.Deferred()
+        return self.deferred
+
+    def stopProducing(self):
+        SingleRangeStaticProducer.stopProducing(self)
+        if hasattr(self, 'deferred'):
+            self.deferred.callback(None)
+
+
 @implementer(interfaces.IBlobsBackend)
 class FilesystemBlobsBackend(object):
 
@@ -63,13 +101,20 @@ class FilesystemBlobsBackend(object):
         open(path, 'a')
 
     @defer.inlineCallbacks
-    def read_blob(self, user, blob_id, consumer, namespace=''):
+    def read_blob(self, user, blob_id, consumer, namespace='', range=None):
         logger.info('reading blob: %s - %s@%s' % (user, blob_id, namespace))
         path = self._get_path(user, blob_id, namespace)
         logger.debug('blob path: %s' % path)
         with open(path) as fd:
-            producer = FileBodyProducer(fd)
-            yield producer.startProducing(consumer)
+            if range is None:
+                producer = NoRangeProducer(consumer, fd)
+            else:
+                start, end = range
+                offset = start
+                size = end - start
+                args = (consumer, fd, offset, size)
+                producer = SingleRangeProducer(*args)
+            yield producer.start()
 
     def get_flags(self, user, blob_id, namespace=''):
         try:
index a6c209f..dd9af86 100644 (file)
@@ -19,6 +19,8 @@ A Twisted Web resource for blobs.
 """
 import json
 
+from twisted.python.compat import intToBytes
+from twisted.python.compat import networkString
 from twisted.web import resource
 from twisted.web.client import FileBodyProducer
 from twisted.web.server import NOT_DONE_YET
@@ -31,6 +33,7 @@ from .errors import BlobNotFound
 from .errors import BlobExists
 from .errors import ImproperlyConfiguredException
 from .errors import QuotaExceeded
+from .errors import RangeNotSatisfiable
 from .util import VALID_STRINGS
 
 from leap.soledad.common.log import getLogger
@@ -44,7 +47,7 @@ def _catchBlobNotFound(failure, request, user, blob_id):
     logger.error("Error 404: Blob %s does not exist for user %s"
                  % (blob_id, user))
     request.setResponseCode(404)
-    request.write("Blob doesn't exists: %s" % blob_id)
+    request.write("Blob doesn't exist: %s" % blob_id)
     request.finish()
 
 
@@ -128,7 +131,7 @@ class BlobsResource(resource.Resource):
         d.addErrback(_catchAllErrors, request)
         return NOT_DONE_YET
 
-    def _get_blob(self, request, user, blob_id, namespace):
+    def _get_blob(self, request, user, blob_id, namespace, range):
 
         def _set_tag_header(tag):
             request.responseHeaders.setRawHeaders('Tag', [tag])
@@ -136,18 +139,32 @@ class BlobsResource(resource.Resource):
         def _read_blob(_):
             handler = self._handler
             consumer = request
-            d = handler.read_blob(user, blob_id, consumer, namespace=namespace)
+            d = handler.read_blob(
+                user, blob_id, consumer, namespace=namespace, range=range)
             return d
 
         d = self._handler.get_tag(user, blob_id, namespace)
         d.addCallback(_set_tag_header)
         d.addCallback(_read_blob)
-        d.addCallback(lambda _: request.finish())
         d.addErrback(_catchBlobNotFound, request, user, blob_id)
         d.addErrback(_catchAllErrors, request, finishRequest=True)
 
         return NOT_DONE_YET
 
+    def _parseRange(self, range):
+        if not range:
+            return None
+        try:
+            kind, value = range.split(b'=', 1)
+            if kind.strip() != b'bytes':
+                raise Exception('Unknown unit: %s' % kind)
+            start, end = value.split('-')
+            start = int(start) if start else None
+            end = int(end) if end else None
+            return start, end
+        except Exception as e:
+            raise RangeNotSatisfiable(e)
+
     def render_GET(self, request):
         logger.info("http get: %s" % request.path)
         user, blob_id, namespace = self._validate(request)
@@ -162,7 +179,34 @@ class BlobsResource(resource.Resource):
         if only_flags:
             return self._only_flags(request, user, blob_id, namespace)
 
-        return self._get_blob(request, user, blob_id, namespace)
+        def _handleRangeHeader(size):
+            try:
+                range = self._parseRange(request.getHeader('Range'))
+            except RangeNotSatisfiable:
+                content_range = 'bytes */%d' % size
+                content_range = networkString(content_range)
+                request.setResponseCode(416)
+                request.setHeader(b'content-range', content_range)
+                request.finish()
+                return
+
+            if not range:
+                start = end = None
+                request.setResponseCode(200)
+                request.setHeader(b'content-length', intToBytes(size))
+            else:
+                start, end = range
+                content_range = 'bytes %d-%d/%d' % (start, end, size)
+                content_range = networkString(content_range)
+                length = intToBytes(end - start)
+                request.setResponseCode(206)
+                request.setHeader(b'content-range', content_range)
+                request.setHeader(b'content-length', length)
+            return self._get_blob(request, user, blob_id, namespace, range)
+
+        d = self._handler.get_blob_size(user, blob_id, namespace=namespace)
+        d.addCallback(_handleRangeHeader)
+        return NOT_DONE_YET
 
     def render_DELETE(self, request):
         logger.info("http put: %s" % request.path)
index c2e7985..089111e 100644 (file)
@@ -25,7 +25,7 @@ class IBlobsBackend(Interface):
     An interface for a backend that can store blobs.
     """
 
-    def read_blob(user, blob_id, consumer, namespace=''):
+    def read_blob(user, blob_id, consumer, namespace='', range=None):
         """
         Read a blob from the backend storage.
 
@@ -37,6 +37,9 @@ class IBlobsBackend(Interface):
         :type consumer: twisted.internet.interfaces.IConsumer provider
         :param namespace: An optional namespace for the blob.
         :type namespace: str
+        :param range: An optional tuple indicating start and end position of
+            the blob to be produced.
+        :type range: (int, int)
 
         :return: A deferred that fires when the blob has been written to the
             consumer.
index 5c086a5..2485ccd 100644 (file)
@@ -20,6 +20,7 @@ Tests for blobs backend on server side.
 from twisted.trial import unittest
 from twisted.internet import defer
 from twisted.web.client import FileBodyProducer
+from twisted.web.test.requesthelper import DummyRequest
 from leap.common.files import mkdir_p
 from leap.soledad.server import _blobs
 from mock import Mock
@@ -61,18 +62,18 @@ class FilesystemBackendTestCase(unittest.TestCase):
         self.assertEquals(10, size)
 
     @pytest.mark.usefixtures("method_tmpdir")
-    @mock.patch('leap.soledad.server._blobs.fs_backend.open')
     @mock.patch('leap.soledad.server._blobs.fs_backend'
                 '.FilesystemBlobsBackend._get_path')
     @defer.inlineCallbacks
-    def test_read_blob(self, get_path, open):
-        get_path.return_value = 'path'
-        open.return_value = io.BytesIO('content')
+    def test_read_blob(self, get_path):
+        path = os.path.join(self.tempdir, 'blob')
+        with open(path, 'w') as f:
+            f.write('bl0b')
+        get_path.return_value = path
         backend = _blobs.FilesystemBlobsBackend(blobs_path=self.tempdir)
-        consumer = Mock()
+        consumer = DummyRequest([''])
         yield backend.read_blob('user', 'blob_id', consumer)
-        consumer.write.assert_called_with('content')
-        get_path.assert_called_once_with('user', 'blob_id', '')
+        self.assertEqual(['bl0b'], consumer.written)
 
     @pytest.mark.usefixtures("method_tmpdir")
     @mock.patch.object(os.path, 'isfile')
@@ -246,3 +247,13 @@ class FilesystemBackendTestCase(unittest.TestCase):
 
         count = yield backend.count('user', namespace='xfiles')
         self.assertEqual(2, count)
+
+    @pytest.mark.usefixtures("method_tmpdir")
+    @defer.inlineCallbacks
+    def test_read_range(self):
+        backend = _blobs.FilesystemBlobsBackend(blobs_path=self.tempdir)
+        producer = FileBodyProducer(io.BytesIO("0123456789"))
+        yield backend.write_blob('user', 'blob-id', producer)
+        consumer = DummyRequest([''])
+        yield backend.read_blob('user', 'blob-id', consumer, range=(1, 3))
+        self.assertEqual(['12'], consumer.written)
index bf92938..6fed6d6 100644 (file)
@@ -19,6 +19,8 @@ Integration tests for blobs server
 """
 import os
 import pytest
+import re
+import treq
 from urlparse import urljoin
 from uuid import uuid4
 from io import BytesIO
@@ -49,6 +51,11 @@ def sleep(x):
     return d
 
 
+def _get(*args, **kwargs):
+    kwargs.update({'persistent': False})
+    return treq.get(*args, **kwargs)
+
+
 class BlobServerTestCase(unittest.TestCase):
 
     def setUp(self):
@@ -455,3 +462,48 @@ class BlobServerTestCase(unittest.TestCase):
         self.addCleanup(manager.close)
         with pytest.raises(SoledadError):
             yield manager.delete('missing_id')
+
+    @defer.inlineCallbacks
+    @pytest.mark.usefixtures("method_tmpdir")
+    def test_get_range(self):
+        user_id = uuid4().hex
+        manager = BlobManager(self.tempdir, self.uri, self.secret,
+                              self.secret, user_id)
+        self.addCleanup(manager.close)
+        blob_id, content = 'blob_id', '0123456789'
+        doc = BlobDoc(BytesIO(content), blob_id)
+        yield manager.put(doc, len(content))
+        uri = urljoin(self.uri, '%s/%s' % (user_id, blob_id))
+        res = yield _get(uri, headers={'Range': 'bytes=10-20'})
+        text = yield res.text()
+        self.assertTrue(res.headers.hasHeader('content-range'))
+        content_range = res.headers.getRawHeaders('content-range').pop()
+        self.assertIsNotNone(re.match('^bytes 10-20/[0-9]+$', content_range))
+        self.assertEqual(10, len(text))
+
+    @defer.inlineCallbacks
+    @pytest.mark.usefixtures("method_tmpdir")
+    def test_get_range_not_satisfiable(self):
+        # put a blob in place
+        user_id = uuid4().hex
+        manager = BlobManager(self.tempdir, self.uri, self.secret,
+                              self.secret, user_id)
+        self.addCleanup(manager.close)
+        blob_id, content = uuid4().hex, 'content'
+        doc = BlobDoc(BytesIO(content), blob_id)
+        yield manager.put(doc, len(content))
+        # and check possible parsing errors
+        uri = urljoin(self.uri, '%s/%s' % (user_id, blob_id))
+        ranges = [
+            'bytes',
+            'bytes=',
+            'bytes=1',
+            'bytes=blah-100',
+            'potatoes=10-100',
+            'blah'
+        ]
+        for range in ranges:
+            res = yield _get(uri, headers={'Range': range})
+            self.assertEqual(416, res.code)
+            content_range = res.headers.getRawHeaders('content-range').pop()
+            self.assertIsNotNone(re.match('^bytes \*/[0-9]+$', content_range))
index 23c0aa9..16d5d5e 100644 (file)
 Integration tests for incoming API
 """
 import pytest
-import mock
 import treq
 from io import BytesIO
 from uuid import uuid4
 from twisted.web.server import Site
 from twisted.internet import reactor
 from twisted.internet import defer
+from twisted.web.test.requesthelper import DummyRequest
 
 from leap.soledad.server._incoming import IncomingResource
 from leap.soledad.server._blobs import BlobsServerState
@@ -83,10 +83,10 @@ class IncomingOnCouchServerTestCase(CouchDBTestCase):
         yield treq.put(incoming_endpoint, BytesIO(content), persistent=False)
 
         db = self.state.open_database(user_id)
-        consumer = mock.Mock()
+        consumer = DummyRequest([''])
         yield db.read_blob(user_id, doc_id, consumer, namespace='MX')
         flags = yield db.get_flags(user_id, doc_id, namespace='MX')
-        data = consumer.write.call_args[0][0]
+        data = consumer.written.pop()
         expected_preamble = formatter.preamble(content, doc_id)
         expected_preamble = decode_preamble(expected_preamble, True)
         written_preamble, written_content = data.split()