diff options
author | Victor Shyba <victor1984@riseup.net> | 2017-10-25 12:49:06 -0300 |
---|---|---|
committer | Victor Shyba <victor1984@riseup.net> | 2017-10-27 16:01:38 -0300 |
commit | 83e4e5e65aca6307923d92a7915f3851f1206501 (patch) | |
tree | 27bc63ae5382966ea9dec05912ce11c085a92cd0 /src/leap/soledad/client/_db/blobs | |
parent | 4a852d9d78f715c7f51327f13275cd72e77e9907 (diff) |
[refactor] split blobs into modules
So we can have manager, sync, sql and errors in its own places.
--Related: #8970
Diffstat (limited to 'src/leap/soledad/client/_db/blobs')
-rw-r--r-- | src/leap/soledad/client/_db/blobs/__init__.py | 414 | ||||
-rw-r--r-- | src/leap/soledad/client/_db/blobs/__main__.py | 195 | ||||
-rw-r--r-- | src/leap/soledad/client/_db/blobs/errors.py | 40 | ||||
-rw-r--r-- | src/leap/soledad/client/_db/blobs/sql.py | 266 | ||||
-rw-r--r-- | src/leap/soledad/client/_db/blobs/sync.py | 150 |
5 files changed, 1065 insertions, 0 deletions
diff --git a/src/leap/soledad/client/_db/blobs/__init__.py b/src/leap/soledad/client/_db/blobs/__init__.py new file mode 100644 index 00000000..0120b222 --- /dev/null +++ b/src/leap/soledad/client/_db/blobs/__init__.py @@ -0,0 +1,414 @@ +# -*- coding: utf-8 -*- +# __init__.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Clientside BlobBackend Storage. +""" + +from urlparse import urljoin + +import os +import json +import base64 + +from io import BytesIO + +from twisted.logger import Logger +from twisted.internet import defer + +import treq + +from leap.soledad.common.errors import SoledadError +from leap.common.files import mkdir_p + +from leap.soledad.client._crypto import DocInfo +from leap.soledad.client._crypto import InvalidBlob +from leap.soledad.client._crypto import BlobEncryptor +from leap.soledad.client._crypto import BlobDecryptor +from leap.soledad.client._crypto import EncryptionSchemeNotImplementedException +from leap.soledad.client._http import HTTPClient +from leap.soledad.client._pipes import TruncatedTailPipe +from leap.soledad.client._pipes import PreamblePipe + +from .sql import SyncStatus +from .sql import SQLiteBlobBackend +from .sync import BlobsSynchronizer +from .errors import ( + BlobAlreadyExistsError, MaximumRetriesError, + RetriableTransferError, BlobNotFoundError, InvalidFlagsError) + + +logger = Logger() +FIXED_REV = 'ImmutableRevision' # Blob content is immutable + + +def check_http_status(code, blob_id=None, flags=None): + if code == 404: + raise BlobNotFoundError(blob_id) + if code == 409: + raise BlobAlreadyExistsError(blob_id) + elif code == 406: + raise InvalidFlagsError((blob_id, flags)) + elif code != 200: + raise SoledadError("Server Error: %s" % code) + + +class DecrypterBuffer(object): + + def __init__(self, blob_id, secret, tag): + self.doc_info = DocInfo(blob_id, FIXED_REV) + self.secret = secret + self.tag = tag + self.preamble_pipe = PreamblePipe(self._make_decryptor) + self.decrypter = None + + def _make_decryptor(self, preamble): + try: + self.decrypter = BlobDecryptor( + self.doc_info, preamble, + secret=self.secret, + armor=False, + start_stream=False, + tag=self.tag) + return TruncatedTailPipe(self.decrypter, tail_size=len(self.tag)) + except EncryptionSchemeNotImplementedException: + # If we do not support the provided encryption scheme, than that's + # something for the application using soledad to handle. This is + # the case on asymmetrically encrypted documents on IncomingBox. + self.raw_data = BytesIO() + return self.raw_data + + def write(self, data): + self.preamble_pipe.write(data) + + def close(self): + if self.decrypter: + real_size = self.decrypter.decrypted_content_size + return self.decrypter.endStream(), real_size + else: + return self.raw_data, self.raw_data.tell() + + +class BlobManager(BlobsSynchronizer): + """ + The BlobManager can list, put, get, set flags and synchronize blobs stored + in local and remote storages. + """ + + def __init__( + self, local_path, remote, key, secret, user, token=None, + cert_file=None): + """ + Initialize the blob manager. + + :param local_path: The path for the local blobs database. + :type local_path: str + :param remote: The URL of the remote storage. + :type remote: str + :param secret: The secret used to encrypt/decrypt blobs. + :type secret: str + :param user: The uuid of the user. + :type user: str + :param token: The access token for interacting with remote storage. + :type token: str + :param cert_file: The path to the CA certificate file. + :type cert_file: str + """ + super(BlobsSynchronizer, self).__init__() + self.max_retries = 3 + self.concurrent_transfers_limit = 3 + self.concurrent_writes_limit = 100 + if local_path: + mkdir_p(os.path.dirname(local_path)) + self.local = SQLiteBlobBackend(local_path, key=key, user=user) + self.remote = remote + self.secret = secret + self.user = user + self._client = HTTPClient(user, token, cert_file) + self.semaphore = defer.DeferredSemaphore(self.concurrent_writes_limit) + + def close(self): + if hasattr(self, 'local') and self.local: + return self.local.close() + + def count(self, namespace=''): + """ + Count the number of blobs. + + :param namespace: + Optional parameter to restrict operation to a given namespace. + :type namespace: str + :return: A deferred that fires with a dict parsed from the JSON + response, which `count` key has the number of blobs as value. + Eg.: {"count": 42} + :rtype: twisted.internet.defer.Deferred + """ + return self.remote_list(namespace=namespace, only_count=True) + + @defer.inlineCallbacks + def remote_list(self, namespace='', order_by=None, + filter_flag=False, only_count=False): + """ + List blobs from server, with filtering and ordering capabilities. + + :param namespace: + Optional parameter to restrict operation to a given namespace. + :type namespace: str + :param order_by: + Optional parameter to order results. Possible values are: + date or +date - Ascending order (older first) + -date - Descending order (newer first) + :type order_by: str + :param filter_flag: + Optional parameter to filter listing to results containing the + specified tag. + :type filter_flag: leap.soledad.common.blobs.Flags + :param only_count: + Optional paramter to return only the number of blobs found. + :type only_count: bool + :return: A deferred that fires with a list parsed from the JSON + response, holding the requested list of blobs. + Eg.: ['blob_id1', 'blob_id2'] + :rtype: twisted.internet.defer.Deferred + """ + uri = urljoin(self.remote, self.user + '/') + params = {'namespace': namespace} if namespace else {} + if order_by: + params['order_by'] = order_by + if filter_flag: + params['filter_flag'] = filter_flag + if only_count: + params['only_count'] = only_count + response = yield self._client.get(uri, params=params) + check_http_status(response.code) + defer.returnValue((yield response.json())) + + def local_list(self, namespace='', sync_status=None): + return self.local.list(namespace, sync_status) + + def put(self, doc, size, namespace=''): + """ + Put a blob in local storage and upload it to server. + + :param doc: A BlobDoc representing the blob. + :type doc: leap.soledad.client._document.BlobDoc + :param size: The size of the blob. + :type size: int + :param namespace: + Optional parameter to restrict operation to a given namespace. + :type namespace: str + """ + return self.semaphore.run(self._put, doc, size, namespace) + + @defer.inlineCallbacks + def _put(self, doc, size, namespace): + if (yield self.local.exists(doc.blob_id, namespace=namespace)): + error_message = "Blob already exists: %s" % doc.blob_id + raise BlobAlreadyExistsError(error_message) + fd = doc.blob_fd + # TODO this is a tee really, but ok... could do db and upload + # concurrently. not sure if we'd gain something. + yield self.local.put(doc.blob_id, fd, size=size, namespace=namespace) + # In fact, some kind of pipe is needed here, where each write on db + # handle gets forwarded into a write on the connection handle + fd = yield self.local.get(doc.blob_id, namespace=namespace) + yield self._encrypt_and_upload(doc.blob_id, fd, namespace=namespace) + yield self.local.update_sync_status(doc.blob_id, SyncStatus.SYNCED) + + def set_flags(self, blob_id, flags, namespace=''): + """ + Set flags for a given blob_id. + + :param blob_id: + Unique identifier of a blob. + :type blob_id: str + :param flags: + List of flags to be set. + :type flags: [leap.soledad.common.blobs.Flags] + :param namespace: + Optional parameter to restrict operation to a given namespace. + :type namespace: str + :return: A deferred that fires when the operation finishes. + :rtype: twisted.internet.defer.Deferred + """ + return self.semaphore.run(self._set_flags, blob_id, flags, namespace) + + @defer.inlineCallbacks + def _set_flags(self, blob_id, flags, namespace): + params = {'namespace': namespace} if namespace else None + flagsfd = BytesIO(json.dumps(flags)) + uri = urljoin(self.remote, self.user + "/" + blob_id) + response = yield self._client.post(uri, data=flagsfd, params=params) + check_http_status(response.code, blob_id=blob_id, flags=flags) + + @defer.inlineCallbacks + def get_flags(self, blob_id, namespace=''): + """ + Get flags from a given blob_id. + + :param blob_id: + Unique identifier of a blob. + :type blob_id: str + :param namespace: + Optional parameter to restrict operation to a given namespace. + :type namespace: str + :return: A deferred that fires with a list parsed from JSON response. + Eg.: [Flags.PENDING] + :rtype: twisted.internet.defer.Deferred + """ + uri = urljoin(self.remote, self.user + "/" + blob_id) + params = {'namespace': namespace} if namespace else {} + params['only_flags'] = True + response = yield self._client.get(uri, params=params) + check_http_status(response.code, blob_id=blob_id) + defer.returnValue((yield response.json())) + + @defer.inlineCallbacks + def get(self, blob_id, namespace=''): + """ + Get the blob from local storage or, if not available, from the server. + + :param blob_id: + Unique identifier of a blob. + :type blob_id: str + :param namespace: + Optional parameter to restrict operation to a given namespace. + :type namespace: str + """ + local_blob = yield self.local.get(blob_id, namespace=namespace) + if local_blob: + logger.info("Found blob in local database: %s" % blob_id) + defer.returnValue(local_blob) + + yield self.local.update_sync_status( + blob_id, SyncStatus.PENDING_DOWNLOAD, namespace=namespace) + + try: + result = yield self._download_and_decrypt(blob_id, namespace) + except Exception as e: + _, retries = yield self.local.get_sync_status(blob_id) + + if isinstance(e, InvalidBlob): + message = "Corrupted blob received from server! ID: %s\n" + message += "Error: %r\n" + message += "Retries: %s - Attempts left: %s\n" + message += "This is either a bug or the contents of the " + message += "blob have been tampered with. Please, report to " + message += "your provider's sysadmin and submit a bug report." + message %= (blob_id, e, retries, (self.max_retries - retries)) + logger.error(message) + + yield self.local.increment_retries(blob_id) + + if (retries + 1) >= self.max_retries: + failed_download = SyncStatus.FAILED_DOWNLOAD + yield self.local.update_sync_status( + blob_id, failed_download, namespace=namespace) + raise MaximumRetriesError(e) + + raise RetriableTransferError(e) + + if not result: + defer.returnValue(None) + blob, size = result + + if blob: + logger.info("Got decrypted blob of type: %s" % type(blob)) + blob.seek(0) + yield self.local.put(blob_id, blob, size=size, namespace=namespace) + local_blob = yield self.local.get(blob_id, namespace=namespace) + defer.returnValue(local_blob) + else: + # XXX we shouldn't get here, but we will... + # lots of ugly error handling possible: + # 1. retry, might be network error + # 2. try later, maybe didn't finished streaming + # 3.. resignation, might be error while verifying + logger.error('sorry, dunno what happened') + + @defer.inlineCallbacks + def _encrypt_and_upload(self, blob_id, fd, namespace=''): + # TODO ------------------------------------------ + # this is wrong, is doing 2 stages. + # the crypto producer can be passed to + # the uploader and react as data is written. + # try to rewrite as a tube: pass the fd to aes and let aes writer + # produce data to the treq request fd. + # ------------------------------------------------ + logger.info("Staring upload of blob: %s" % blob_id) + doc_info = DocInfo(blob_id, FIXED_REV) + uri = urljoin(self.remote, self.user + "/" + blob_id) + crypter = BlobEncryptor(doc_info, fd, secret=self.secret, + armor=False) + fd = yield crypter.encrypt() + params = {'namespace': namespace} if namespace else None + response = yield self._client.put(uri, data=fd, params=params) + check_http_status(response.code, blob_id) + logger.info("Finished upload: %s" % (blob_id,)) + + @defer.inlineCallbacks + def _download_and_decrypt(self, blob_id, namespace=''): + logger.info("Staring download of blob: %s" % blob_id) + # TODO this needs to be connected in a tube + uri = urljoin(self.remote, self.user + '/' + blob_id) + params = {'namespace': namespace} if namespace else None + response = yield self._client.get(uri, params=params) + check_http_status(response.code, blob_id=blob_id) + + if not response.headers.hasHeader('Tag'): + msg = "Server didn't send a tag header for: %s" % blob_id + logger.error(msg) + raise SoledadError(msg) + tag = response.headers.getRawHeaders('Tag')[0] + tag = base64.urlsafe_b64decode(tag) + buf = DecrypterBuffer(blob_id, self.secret, tag) + + # incrementally collect the body of the response + yield treq.collect(response, buf.write) + fd, size = buf.close() + logger.info("Finished download: (%s, %d)" % (blob_id, size)) + defer.returnValue((fd, size)) + + def delete(self, blob_id, namespace=''): + """ + Delete a blob from local and remote storages. + + :param blob_id: + Unique identifier of a blob. + :type blob_id: str + :param namespace: + Optional parameter to restrict operation to a given namespace. + :type namespace: str + :return: A deferred that fires when the operation finishes. + :rtype: twisted.internet.defer.Deferred + """ + return self.semaphore.run(self._delete, blob_id, namespace) + + @defer.inlineCallbacks + def _delete(self, blob_id, namespace): + logger.info("Staring deletion of blob: %s" % blob_id) + yield self._delete_from_remote(blob_id, namespace=namespace) + if (yield self.local.exists(blob_id, namespace=namespace)): + yield self.local.delete(blob_id, namespace=namespace) + + @defer.inlineCallbacks + def _delete_from_remote(self, blob_id, namespace=''): + # TODO this needs to be connected in a tube + uri = urljoin(self.remote, self.user + '/' + blob_id) + params = {'namespace': namespace} if namespace else None + response = yield self._client.delete(uri, params=params) + check_http_status(response.code, blob_id=blob_id) + defer.returnValue(response) diff --git a/src/leap/soledad/client/_db/blobs/__main__.py b/src/leap/soledad/client/_db/blobs/__main__.py new file mode 100644 index 00000000..ddaa8351 --- /dev/null +++ b/src/leap/soledad/client/_db/blobs/__main__.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# __main__.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +import os +from leap.common.files import mkdir_p +from twisted.internet import defer +from . import BlobManager +from leap.soledad.client._document import BlobDoc +from twisted.logger import Logger + +logger = Logger() + + +# +# testing facilities +# +@defer.inlineCallbacks +def testit(reactor): + # configure logging to stdout + from twisted.python import log + import sys + log.startLogging(sys.stdout) + + # parse command line arguments + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--url', default='http://localhost:9000/') + parser.add_argument('--path', default='/tmp/blobs') + parser.add_argument('--secret', default='secret') + parser.add_argument('--uuid', default='user') + parser.add_argument('--token', default=None) + parser.add_argument('--cert-file', default='') + + subparsers = parser.add_subparsers(help='sub-command help', dest='action') + + # parse upload command + parser_upload = subparsers.add_parser( + 'upload', help='upload blob and bypass local db') + parser_upload.add_argument('payload') + parser_upload.add_argument('blob_id') + + # parse download command + parser_download = subparsers.add_parser( + 'download', help='download blob and bypass local db') + parser_download.add_argument('blob_id') + parser_download.add_argument('--output-file', default='/tmp/incoming-file') + + # parse put command + parser_put = subparsers.add_parser( + 'put', help='put blob in local db and upload') + parser_put.add_argument('payload') + parser_put.add_argument('blob_id') + + # parse get command + parser_get = subparsers.add_parser( + 'get', help='get blob from local db, get if needed') + parser_get.add_argument('blob_id') + + # parse delete command + parser_get = subparsers.add_parser( + 'delete', help='delete blob from local and remote db') + parser_get.add_argument('blob_id') + + # parse list command + parser_get = subparsers.add_parser( + 'list', help='list local and remote blob ids') + + # parse send_missing command + parser_get = subparsers.add_parser( + 'send_missing', help='send all pending upload blobs') + + # parse send_missing command + parser_get = subparsers.add_parser( + 'fetch_missing', help='fetch all new server blobs') + + # parse arguments + args = parser.parse_args() + + # TODO convert these into proper unittests + + def _manager(): + mkdir_p(os.path.dirname(args.path)) + manager = BlobManager( + args.path, args.url, + 'A' * 32, args.secret, + args.uuid, args.token, args.cert_file) + return manager + + @defer.inlineCallbacks + def _upload(blob_id, payload): + logger.info(":: Starting upload only: %s" % str((blob_id, payload))) + manager = _manager() + with open(payload, 'r') as fd: + yield manager._encrypt_and_upload(blob_id, fd) + logger.info(":: Finished upload only: %s" % str((blob_id, payload))) + + @defer.inlineCallbacks + def _download(blob_id): + logger.info(":: Starting download only: %s" % blob_id) + manager = _manager() + result = yield manager._download_and_decrypt(blob_id) + logger.info(":: Result of download: %s" % str(result)) + if result: + fd, _ = result + with open(args.output_file, 'w') as f: + logger.info(":: Writing data to %s" % args.output_file) + f.write(fd.read()) + logger.info(":: Finished download only: %s" % blob_id) + + @defer.inlineCallbacks + def _put(blob_id, payload): + logger.info(":: Starting full put: %s" % blob_id) + manager = _manager() + size = os.path.getsize(payload) + with open(payload) as fd: + doc = BlobDoc(fd, blob_id) + result = yield manager.put(doc, size=size) + logger.info(":: Result of put: %s" % str(result)) + logger.info(":: Finished full put: %s" % blob_id) + + @defer.inlineCallbacks + def _get(blob_id): + logger.info(":: Starting full get: %s" % blob_id) + manager = _manager() + fd = yield manager.get(blob_id) + if fd: + logger.info(":: Result of get: " + fd.getvalue()) + logger.info(":: Finished full get: %s" % blob_id) + + @defer.inlineCallbacks + def _delete(blob_id): + logger.info(":: Starting deletion of: %s" % blob_id) + manager = _manager() + yield manager.delete(blob_id) + logger.info(":: Finished deletion of: %s" % blob_id) + + @defer.inlineCallbacks + def _list(): + logger.info(":: Listing local blobs") + manager = _manager() + local_list = yield manager.local_list() + logger.info(":: Local list: %s" % local_list) + logger.info(":: Listing remote blobs") + remote_list = yield manager.remote_list() + logger.info(":: Remote list: %s" % remote_list) + + @defer.inlineCallbacks + def _send_missing(): + logger.info(":: Sending local pending upload docs") + manager = _manager() + yield manager.send_missing() + logger.info(":: Finished sending missing docs") + + @defer.inlineCallbacks + def _fetch_missing(): + logger.info(":: Fetching remote new docs") + manager = _manager() + yield manager.fetch_missing() + logger.info(":: Finished fetching new docs") + + if args.action == 'upload': + yield _upload(args.blob_id, args.payload) + elif args.action == 'download': + yield _download(args.blob_id) + elif args.action == 'put': + yield _put(args.blob_id, args.payload) + elif args.action == 'get': + yield _get(args.blob_id) + elif args.action == 'delete': + yield _delete(args.blob_id) + elif args.action == 'list': + yield _list() + elif args.action == 'send_missing': + yield _send_missing() + elif args.action == 'fetch_missing': + yield _fetch_missing() + + +if __name__ == '__main__': + from twisted.internet.task import react + react(testit) diff --git a/src/leap/soledad/client/_db/blobs/errors.py b/src/leap/soledad/client/_db/blobs/errors.py new file mode 100644 index 00000000..90de9553 --- /dev/null +++ b/src/leap/soledad/client/_db/blobs/errors.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# errors.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Blobs exceptions +""" +from leap.soledad.common.errors import SoledadError + + +class BlobAlreadyExistsError(SoledadError): + pass + + +class BlobNotFoundError(SoledadError): + pass + + +class InvalidFlagsError(SoledadError): + pass + + +class RetriableTransferError(Exception): + pass + + +class MaximumRetriesError(Exception): + pass diff --git a/src/leap/soledad/client/_db/blobs/sql.py b/src/leap/soledad/client/_db/blobs/sql.py new file mode 100644 index 00000000..dcca1671 --- /dev/null +++ b/src/leap/soledad/client/_db/blobs/sql.py @@ -0,0 +1,266 @@ +# -*- coding: utf-8 -*- +# _blobs.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Local blobs backend on SQLCipher +""" +import os +import binascii +from functools import partial +from twisted.internet import defer +from twisted.logger import Logger +from twisted.enterprise import adbapi +from leap.common.files import mkdir_p +from .. import sqlcipher +from .. import pragmas +from io import BytesIO +logger = Logger() + + +class SyncStatus: + SYNCED = 1 + PENDING_UPLOAD = 2 + PENDING_DOWNLOAD = 3 + FAILED_UPLOAD = 4 + FAILED_DOWNLOAD = 5 + UNAVAILABLE_STATUSES = (3, 5) + + +class SQLiteBlobBackend(object): + + concurrency_limit = 10 + + def __init__(self, path, key=None, user=None): + dbname = '%s_blobs.db' % (user or 'soledad') + self.path = os.path.abspath( + os.path.join(path, dbname)) + mkdir_p(os.path.dirname(self.path)) + if not key: + raise ValueError('key cannot be None') + backend = 'pysqlcipher.dbapi2' + opts = sqlcipher.SQLCipherOptions( + '/tmp/ignored', binascii.b2a_hex(key), + is_raw_key=True, create=True) + openfun = partial(pragmas.set_init_pragmas, opts=opts, + schema_func=_init_blob_table) + + self.dbpool = ConnectionPool( + backend, self.path, check_same_thread=False, timeout=5, + cp_openfun=openfun, cp_min=2, cp_max=2, cp_name='blob_pool') + + def close(self): + from twisted._threads import AlreadyQuit + try: + self.dbpool.close() + except AlreadyQuit: + pass + + @defer.inlineCallbacks + def put(self, blob_id, blob_fd, size=None, + namespace='', status=SyncStatus.PENDING_UPLOAD): + previous_state = yield self.get_sync_status(blob_id) + unavailable = SyncStatus.UNAVAILABLE_STATUSES + if previous_state and previous_state[0] in unavailable: + yield self.delete(blob_id, namespace=namespace) + status = SyncStatus.SYNCED + logger.info("Saving blob in local database...") + insert = 'INSERT INTO blobs (blob_id, namespace, payload, sync_status)' + insert += ' VALUES (?, ?, zeroblob(?), ?)' + values = (blob_id, namespace, size, status) + irow = yield self.dbpool.insertAndGetLastRowid(insert, values) + yield self.dbpool.write_blob('blobs', 'payload', irow, blob_fd) + logger.info("Finished saving blob in local database.") + + @defer.inlineCallbacks + def get(self, blob_id, namespace=''): + # TODO we can also stream the blob value using sqlite + # incremental interface for blobs - and just return the raw fd instead + select = 'SELECT payload FROM blobs WHERE blob_id = ? AND namespace= ?' + values = (blob_id, namespace,) + avoid_values = SyncStatus.UNAVAILABLE_STATUSES + select += ' AND sync_status NOT IN (%s)' + select %= ','.join(['?' for _ in avoid_values]) + values += avoid_values + result = yield self.dbpool.runQuery(select, values) + if result: + defer.returnValue(BytesIO(str(result[0][0]))) + + @defer.inlineCallbacks + def get_sync_status(self, blob_id): + select = 'SELECT sync_status, retries FROM blobs WHERE blob_id = ?' + result = yield self.dbpool.runQuery(select, (blob_id,)) + if result: + defer.returnValue((result[0][0], result[0][1])) + + @defer.inlineCallbacks + def list(self, namespace='', sync_status=False): + query = 'select blob_id from blobs where namespace = ?' + values = (namespace,) + if sync_status: + query += ' and sync_status = ?' + values += (sync_status,) + else: + avoid_values = SyncStatus.UNAVAILABLE_STATUSES + query += ' AND sync_status NOT IN (%s)' + query %= ','.join(['?' for _ in avoid_values]) + values += avoid_values + result = yield self.dbpool.runQuery(query, values) + if result: + defer.returnValue([b_id[0] for b_id in result]) + else: + defer.returnValue([]) + + @defer.inlineCallbacks + def update_sync_status(self, blob_id, sync_status, namespace=""): + query = 'SELECT sync_status FROM blobs WHERE blob_id = ?' + result = yield self.dbpool.runQuery(query, (blob_id,)) + + if not result: + insert = 'INSERT INTO blobs' + insert += ' (blob_id, namespace, payload, sync_status)' + insert += ' VALUES (?, ?, zeroblob(0), ?)' + values = (blob_id, namespace, sync_status) + yield self.dbpool.runOperation(insert, values) + return + + update = 'UPDATE blobs SET sync_status = ? WHERE blob_id = ?' + values = (sync_status, blob_id,) + result = yield self.dbpool.runOperation(update, values) + + def update_batch_sync_status(self, blob_id_list, sync_status, + namespace=''): + insert = 'INSERT INTO blobs (blob_id, namespace, payload, sync_status)' + first_blob_id, blob_id_list = blob_id_list[0], blob_id_list[1:] + insert += ' VALUES (?, ?, zeroblob(0), ?)' + values = (first_blob_id, namespace, sync_status) + for blob_id in blob_id_list: + insert += ', (?, ?, zeroblob(0), ?)' + values += (blob_id, namespace, sync_status) + return self.dbpool.runQuery(insert, values) + + def increment_retries(self, blob_id): + query = 'update blobs set retries = retries + 1 where blob_id = ?' + return self.dbpool.runQuery(query, (blob_id,)) + + @defer.inlineCallbacks + def list_namespaces(self): + query = 'select namespace from blobs' + result = yield self.dbpool.runQuery(query) + if result: + defer.returnValue([namespace[0] for namespace in result]) + else: + defer.returnValue([]) + + @defer.inlineCallbacks + def exists(self, blob_id, namespace=''): + query = 'SELECT blob_id from blobs WHERE blob_id = ? AND namespace= ?' + result = yield self.dbpool.runQuery(query, (blob_id, namespace,)) + defer.returnValue(bool(len(result))) + + def delete(self, blob_id, namespace=''): + query = 'DELETE FROM blobs WHERE blob_id = ? AND namespace = ?' + return self.dbpool.runQuery(query, (blob_id, namespace,)) + + +def _init_sync_table(conn): + maybe_create = """ + CREATE TABLE IF NOT EXISTS + sync_state ( + blob_id PRIMARY KEY, + sync_status INT default %s, + retries INT default 0)""" + default_status = SyncStatus.PENDING_UPLOAD + maybe_create %= default_status + conn.execute(maybe_create) + + +def _init_blob_table(conn): + maybe_create = ( + "CREATE TABLE IF NOT EXISTS " + "blobs (" + "blob_id PRIMARY KEY, " + "payload BLOB)") + conn.execute(maybe_create) + columns = [row[1] for row in conn.execute("pragma" + " table_info(blobs)").fetchall()] + if 'namespace' not in columns: + # namespace migration + conn.execute('ALTER TABLE blobs ADD COLUMN namespace TEXT') + if 'sync_status' not in columns: + # sync status migration + default_status = SyncStatus.PENDING_UPLOAD + sync_column = 'ALTER TABLE blobs ADD COLUMN sync_status INT default %s' + sync_column %= default_status + conn.execute(sync_column) + conn.execute('ALTER TABLE blobs ADD COLUMN retries INT default 0') + + +class ConnectionPool(adbapi.ConnectionPool): + + def insertAndGetLastRowid(self, *args, **kwargs): + """ + Execute an SQL query and return the last rowid. + + See: https://sqlite.org/c3ref/last_insert_rowid.html + """ + return self.runInteraction( + self._insertAndGetLastRowid, *args, **kwargs) + + def _insertAndGetLastRowid(self, trans, *args, **kw): + trans.execute(*args, **kw) + return trans.lastrowid + + def blob(self, table, column, irow, flags): + """ + Open a BLOB for incremental I/O. + + Return a handle to the BLOB that would be selected by: + + SELECT column FROM table WHERE rowid = irow; + + See: https://sqlite.org/c3ref/blob_open.html + + :param table: The table in which to lookup the blob. + :type table: str + :param column: The column where the BLOB is located. + :type column: str + :param rowid: The rowid of the BLOB. + :type rowid: int + :param flags: If zero, BLOB is opened for read-only. If non-zero, + BLOB is opened for RW. + :type flags: int + + :return: A BLOB handle. + :rtype: pysqlcipher.dbapi.Blob + """ + return self.runInteraction(self._blob, table, column, irow, flags) + + def write_blob(self, table, column, irow, blob_fd): + return self.runInteraction(self._write_blob, table, column, irow, + blob_fd) + + def _write_blob(self, trans, table, column, irow, blob_fd): + blob_fd.seek(0) + with trans._connection.blob(table, column, irow, 1) as handle: + data = blob_fd.read(2**12) + while data: + handle.write(data) + data = blob_fd.read(2**12) + + def _blob(self, trans, table, column, irow, flags): + # TODO: should not use transaction private variable here + handle = trans._connection.blob(table, column, irow, flags) + return handle diff --git a/src/leap/soledad/client/_db/blobs/sync.py b/src/leap/soledad/client/_db/blobs/sync.py new file mode 100644 index 00000000..67df1d7f --- /dev/null +++ b/src/leap/soledad/client/_db/blobs/sync.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +# _blobs.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Synchronization between blobs client/server +""" +from twisted.internet import defer +from twisted.internet import reactor +from twisted.logger import Logger +from twisted.internet import error +from .sql import SyncStatus +from .errors import MaximumRetriesError, RetriableTransferError +logger = Logger() + + +def sleep(seconds): + d = defer.Deferred() + reactor.callLater(seconds, d.callback, None) + return d + + +MAX_WAIT = 60 # In seconds. Max time between retries + + +@defer.inlineCallbacks +def with_retry(func, *args, **kwargs): + retry_wait = 1 + retriable_errors = (error.ConnectError, error.ConnectionClosed, + RetriableTransferError,) + while True: + try: + yield func(*args, **kwargs) + break + except retriable_errors: + yield sleep(retry_wait) + retry_wait = min(retry_wait + 10, MAX_WAIT) + + +class BlobsSynchronizer(object): + + @defer.inlineCallbacks + def refresh_sync_status_from_server(self, namespace=''): + d1 = self.remote_list(namespace=namespace) + d2 = self.local_list(namespace=namespace) + remote_list, local_list = yield defer.gatherResults([d1, d2]) + pending_download_ids = tuple(set(remote_list) - set(local_list)) + yield self.local.update_batch_sync_status( + pending_download_ids, + SyncStatus.PENDING_DOWNLOAD, + namespace=namespace) + + @defer.inlineCallbacks + def send_missing(self, namespace=''): + """ + Compare local and remote blobs and send what's missing in server. + + :param namespace: + Optional parameter to restrict operation to a given namespace. + :type namespace: str + """ + missing = yield self.local.list(namespace, SyncStatus.PENDING_UPLOAD) + total = len(missing) + logger.info("Will send %d blobs to server." % total) + deferreds = [] + semaphore = defer.DeferredSemaphore(self.concurrent_transfers_limit) + + def release(result): + print result + semaphore.release() + return result + + for i in xrange(total): + yield semaphore.acquire() + blob_id = missing.pop() + d = with_retry(self.__send_one, blob_id, namespace, i, total) + d.addCallbacks(release, release) + deferreds.append(d) + yield defer.gatherResults(deferreds) + + @defer.inlineCallbacks + def __send_one(self, blob_id, namespace, i, total): + logger.info("Sending blob to server (%d/%d): %s" + % (i, total, blob_id)) + fd = yield self.local.get(blob_id, namespace=namespace) + try: + yield self._encrypt_and_upload(blob_id, fd) + yield self.local.update_sync_status(blob_id, SyncStatus.SYNCED) + except Exception as e: + yield self.local.increment_retries(blob_id) + res = yield self.local.get_sync_status(blob_id) + _, retries = res + if (retries + 1) > self.max_retries: + failed_upload = SyncStatus.FAILED_UPLOAD + yield self.local.update_sync_status(blob_id, failed_upload) + raise MaximumRetriesError(e) + raise e + + @defer.inlineCallbacks + def fetch_missing(self, namespace=''): + """ + Compare local and remote blobs and fetch what's missing in local + storage. + + :param namespace: + Optional parameter to restrict operation to a given namespace. + :type namespace: str + """ + # TODO: Use something to prioritize user requests over general new docs + d = self.local_list(namespace=namespace, + sync_status=SyncStatus.PENDING_DOWNLOAD) + docs_we_want = yield d + total = len(docs_we_want) + logger.info("Will fetch %d blobs from server." % total) + deferreds = [] + semaphore = defer.DeferredSemaphore(self.concurrent_transfers_limit) + + def release(result): + semaphore.release() + return result + + for i in xrange(len(docs_we_want)): + yield semaphore.acquire() + blob_id = docs_we_want.pop() + logger.info("Fetching blob (%d/%d): %s" % (i, total, blob_id)) + d = with_retry(self.get, blob_id, namespace) + d.addCallbacks(release, release) + deferreds.append(d) + yield defer.gatherResults(deferreds) + + @defer.inlineCallbacks + def sync(self, namespace=''): + try: + yield self.refresh_sync_status_from_server(namespace) + yield self.fetch_missing(namespace) + yield self.send_missing(namespace) + except defer.FirstError as e: + e.subFailure.raiseException() |