From a2a41830f72ae014d3b4dba591c6bfc85ffeb062 Mon Sep 17 00:00:00 2001 From: drebs Date: Sat, 29 Apr 2017 16:40:55 +0200 Subject: [refactor] create client _database module --- client/src/leap/soledad/client/_blobs.py | 533 ------------ .../src/leap/soledad/client/_database/__init__.py | 0 client/src/leap/soledad/client/_database/adbapi.py | 298 +++++++ client/src/leap/soledad/client/_database/blobs.py | 522 ++++++++++++ .../src/leap/soledad/client/_database/dbschema.sql | 42 + .../src/leap/soledad/client/_database/pragmas.py | 379 +++++++++ .../src/leap/soledad/client/_database/sqlcipher.py | 633 ++++++++++++++ client/src/leap/soledad/client/_database/sqlite.py | 930 +++++++++++++++++++++ client/src/leap/soledad/client/adbapi.py | 297 ------- client/src/leap/soledad/client/api.py | 20 +- .../examples/benchmarks/measure_index_times.py | 2 +- .../benchmarks/measure_index_times_custom_docid.py | 2 +- .../src/leap/soledad/client/examples/use_adbapi.py | 2 +- .../src/leap/soledad/client/http_target/fetch.py | 2 +- client/src/leap/soledad/client/pragmas.py | 379 --------- client/src/leap/soledad/client/sqlcipher.py | 632 -------------- common/src/leap/soledad/common/l2db/__init__.py | 4 +- .../leap/soledad/common/l2db/backends/dbschema.sql | 42 - .../soledad/common/l2db/backends/sqlite_backend.py | 930 --------------------- .../soledad/common/l2db/remote/server_state.py | 14 +- testing/test_soledad/u1db_tests/__init__.py | 7 +- testing/test_soledad/u1db_tests/test_open.py | 11 +- testing/test_soledad/util.py | 4 +- testing/tests/blobs/test_blobs.py | 10 +- testing/tests/blobs/test_local_backend.py | 2 +- testing/tests/client/test_aux_methods.py | 2 +- testing/tests/server/test_blobs_server.py | 4 +- testing/tests/sqlcipher/test_async.py | 6 +- testing/tests/sqlcipher/test_backend.py | 9 +- testing/tests/sync/test_sync_target.py | 6 +- 30 files changed, 2860 insertions(+), 2864 deletions(-) delete mode 100644 client/src/leap/soledad/client/_blobs.py create mode 100644 client/src/leap/soledad/client/_database/__init__.py create mode 100644 client/src/leap/soledad/client/_database/adbapi.py create mode 100644 client/src/leap/soledad/client/_database/blobs.py create mode 100644 client/src/leap/soledad/client/_database/dbschema.sql create mode 100644 client/src/leap/soledad/client/_database/pragmas.py create mode 100644 client/src/leap/soledad/client/_database/sqlcipher.py create mode 100644 client/src/leap/soledad/client/_database/sqlite.py delete mode 100644 client/src/leap/soledad/client/adbapi.py delete mode 100644 client/src/leap/soledad/client/pragmas.py delete mode 100644 client/src/leap/soledad/client/sqlcipher.py delete mode 100644 common/src/leap/soledad/common/l2db/backends/dbschema.sql delete mode 100644 common/src/leap/soledad/common/l2db/backends/sqlite_backend.py diff --git a/client/src/leap/soledad/client/_blobs.py b/client/src/leap/soledad/client/_blobs.py deleted file mode 100644 index 90007427..00000000 --- a/client/src/leap/soledad/client/_blobs.py +++ /dev/null @@ -1,533 +0,0 @@ -# -*- coding: utf-8 -*- -# _blobs.py -# Copyright (C) 2017 LEAP -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. 
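# Illustrative sketch, not part of the patch: the diffstat above moves the
# adbapi, blobs, pragmas and sqlcipher modules into the new _database
# package, so import paths in callers change roughly as follows. The concrete
# call-site updates live in the modified files listed above (api.py,
# http_target/fetch.py, the examples and the tests).
#
# before this refactor:
#   from leap.soledad.client import adbapi, pragmas, sqlcipher
#   from leap.soledad.client._blobs import BlobManager
#
# after this refactor:
from leap.soledad.client._database import adbapi, pragmas, sqlcipher
from leap.soledad.client._database.blobs import BlobManager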
-# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -""" -Clientside BlobBackend Storage. -""" - -from urlparse import urljoin - -import errno -import os -import uuid -import base64 - -from io import BytesIO -from functools import partial - -from twisted.logger import Logger -from twisted.enterprise import adbapi -from twisted.internet import defer -from twisted.web.client import FileBodyProducer - -import treq - -from leap.soledad.client.sqlcipher import SQLCipherOptions -from leap.soledad.client import pragmas -from leap.soledad.client._pipes import TruncatedTailPipe, PreamblePipe -from leap.soledad.common.errors import SoledadError - -from _crypto import DocInfo, BlobEncryptor, BlobDecryptor -from _http import HTTPClient - - -logger = Logger() -FIXED_REV = 'ImmutableRevision' # Blob content is immutable - - -class BlobAlreadyExistsError(SoledadError): - pass - - -class ConnectionPool(adbapi.ConnectionPool): - - def insertAndGetLastRowid(self, *args, **kwargs): - """ - Execute an SQL query and return the last rowid. - - See: https://sqlite.org/c3ref/last_insert_rowid.html - """ - return self.runInteraction( - self._insertAndGetLastRowid, *args, **kwargs) - - def _insertAndGetLastRowid(self, trans, *args, **kw): - trans.execute(*args, **kw) - return trans.lastrowid - - def blob(self, table, column, irow, flags): - """ - Open a BLOB for incremental I/O. - - Return a handle to the BLOB that would be selected by: - - SELECT column FROM table WHERE rowid = irow; - - See: https://sqlite.org/c3ref/blob_open.html - - :param table: The table in which to lookup the blob. - :type table: str - :param column: The column where the BLOB is located. - :type column: str - :param rowid: The rowid of the BLOB. - :type rowid: int - :param flags: If zero, BLOB is opened for read-only. If non-zero, - BLOB is opened for RW. - :type flags: int - - :return: A BLOB handle. 
- :rtype: pysqlcipher.dbapi.Blob - """ - return self.runInteraction(self._blob, table, column, irow, flags) - - def _blob(self, trans, table, column, irow, flags): - # TODO: should not use transaction private variable here - handle = trans._connection.blob(table, column, irow, flags) - return handle - - -def check_http_status(code): - if code == 409: - raise BlobAlreadyExistsError() - elif code != 200: - raise SoledadError("Server Error") - - -class DecrypterBuffer(object): - - def __init__(self, blob_id, secret, tag): - self.doc_info = DocInfo(blob_id, FIXED_REV) - self.secret = secret - self.tag = tag - self.preamble_pipe = PreamblePipe(self._make_decryptor) - - def _make_decryptor(self, preamble): - self.decrypter = BlobDecryptor( - self.doc_info, preamble, - secret=self.secret, - armor=False, - start_stream=False, - tag=self.tag) - return TruncatedTailPipe(self.decrypter, tail_size=len(self.tag)) - - def write(self, data): - self.preamble_pipe.write(data) - - def close(self): - real_size = self.decrypter.decrypted_content_size - return self.decrypter._end_stream(), real_size - - -def mkdir_p(path): - try: - os.makedirs(path) - except OSError as exc: - if exc.errno == errno.EEXIST and os.path.isdir(path): - pass - else: - raise - - -class BlobManager(object): - """ - Ideally, the decrypting flow goes like this: - - - GET a blob from remote server. - - Decrypt the preamble - - Allocate a zeroblob in the sqlcipher sink - - Mark the blob as unusable (ie, not verified) - - Decrypt the payload incrementally, and write chunks to sqlcipher - ** Is it possible to use a small buffer for the aes writer w/o - ** allocating all the memory in openssl? - - Finalize the AES decryption - - If preamble + payload verifies correctly, mark the blob as usable - - """ - - def __init__( - self, local_path, remote, key, secret, user, token=None, - cert_file=None): - if local_path: - mkdir_p(os.path.dirname(local_path)) - self.local = SQLiteBlobBackend(local_path, key) - self.remote = remote - self.secret = secret - self.user = user - self._client = HTTPClient(user, token, cert_file) - - def close(self): - if hasattr(self, 'local') and self.local: - return self.local.close() - - @defer.inlineCallbacks - def remote_list(self): - uri = urljoin(self.remote, self.user + '/') - data = yield self._client.get(uri) - defer.returnValue((yield data.json())) - - def local_list(self): - return self.local.list() - - @defer.inlineCallbacks - def send_missing(self): - our_blobs = yield self.local_list() - server_blobs = yield self.remote_list() - missing = [b_id for b_id in our_blobs if b_id not in server_blobs] - logger.info("Amount of documents missing on server: %s" % len(missing)) - # TODO: Send concurrently when we are able to stream directly from db - for blob_id in missing: - fd = yield self.local.get(blob_id) - logger.info("Upload local blob: %s" % blob_id) - yield self._encrypt_and_upload(blob_id, fd) - - @defer.inlineCallbacks - def fetch_missing(self): - # TODO: Use something to prioritize user requests over general new docs - our_blobs = yield self.local_list() - server_blobs = yield self.remote_list() - docs_we_want = [b_id for b_id in server_blobs if b_id not in our_blobs] - logger.info("Fetching new docs from server: %s" % len(docs_we_want)) - # TODO: Fetch concurrently when we are able to stream directly into db - for blob_id in docs_we_want: - logger.info("Fetching new doc: %s" % blob_id) - yield self.get(blob_id) - - @defer.inlineCallbacks - def put(self, doc, size): - fd = doc.blob_fd - # TODO this is a tee 
really, but ok... could do db and upload - # concurrently. not sure if we'd gain something. - yield self.local.put(doc.blob_id, fd, size=size) - # In fact, some kind of pipe is needed here, where each write on db - # handle gets forwarded into a write on the connection handle - fd = yield self.local.get(doc.blob_id) - yield self._encrypt_and_upload(doc.blob_id, fd) - - @defer.inlineCallbacks - def get(self, blob_id): - local_blob = yield self.local.get(blob_id) - if local_blob: - logger.info("Found blob in local database: %s" % blob_id) - defer.returnValue(local_blob) - - result = yield self._download_and_decrypt(blob_id) - - if not result: - defer.returnValue(None) - blob, size = result - - if blob: - logger.info("Got decrypted blob of type: %s" % type(blob)) - blob.seek(0) - yield self.local.put(blob_id, blob, size=size) - defer.returnValue((yield self.local.get(blob_id))) - else: - # XXX we shouldn't get here, but we will... - # lots of ugly error handling possible: - # 1. retry, might be network error - # 2. try later, maybe didn't finished streaming - # 3.. resignation, might be error while verifying - logger.error('sorry, dunno what happened') - - @defer.inlineCallbacks - def _encrypt_and_upload(self, blob_id, fd): - # TODO ------------------------------------------ - # this is wrong, is doing 2 stages. - # the crypto producer can be passed to - # the uploader and react as data is written. - # try to rewrite as a tube: pass the fd to aes and let aes writer - # produce data to the treq request fd. - # ------------------------------------------------ - logger.info("Staring upload of blob: %s" % blob_id) - doc_info = DocInfo(blob_id, FIXED_REV) - uri = urljoin(self.remote, self.user + "/" + blob_id) - crypter = BlobEncryptor(doc_info, fd, secret=self.secret, - armor=False) - fd = yield crypter.encrypt() - response = yield self._client.put(uri, data=fd) - check_http_status(response.code) - logger.info("Finished upload: %s" % (blob_id,)) - - @defer.inlineCallbacks - def _download_and_decrypt(self, blob_id): - logger.info("Staring download of blob: %s" % blob_id) - # TODO this needs to be connected in a tube - uri = urljoin(self.remote, self.user + '/' + blob_id) - data = yield self._client.get(uri) - - if data.code == 404: - logger.warn("Blob not found in server: %s" % blob_id) - defer.returnValue(None) - elif not data.headers.hasHeader('Tag'): - logger.error("Server didn't send a tag header for: %s" % blob_id) - defer.returnValue(None) - tag = data.headers.getRawHeaders('Tag')[0] - tag = base64.urlsafe_b64decode(tag) - buf = DecrypterBuffer(blob_id, self.secret, tag) - - # incrementally collect the body of the response - yield treq.collect(data, buf.write) - fd, size = buf.close() - logger.info("Finished download: (%s, %d)" % (blob_id, size)) - defer.returnValue((fd, size)) - - -class SQLiteBlobBackend(object): - - def __init__(self, path, key=None): - self.path = os.path.abspath( - os.path.join(path, 'soledad_blob.db')) - mkdir_p(os.path.dirname(self.path)) - if not key: - raise ValueError('key cannot be None') - backend = 'pysqlcipher.dbapi2' - opts = SQLCipherOptions('/tmp/ignored', key) - pragmafun = partial(pragmas.set_init_pragmas, opts=opts) - openfun = _sqlcipherInitFactory(pragmafun) - - self.dbpool = ConnectionPool( - backend, self.path, check_same_thread=False, timeout=5, - cp_openfun=openfun, cp_min=1, cp_max=2, cp_name='blob_pool') - - def close(self): - from twisted._threads import AlreadyQuit - try: - self.dbpool.close() - except AlreadyQuit: - pass - - 
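# Hypothetical end-to-end usage of the BlobManager defined above -- a sketch
# only, mirroring the `testit` helpers further down in this module. It assumes
# this module's imports (os, defer) and the BlobDoc helper defined below; the
# path, URL, secret and user id are placeholder values.

@defer.inlineCallbacks
def store_and_fetch(payload_path, blob_id):
    manager = BlobManager(
        '/tmp/blobs', 'http://localhost:9000/',
        'A' * 32, 'secret', 'user', token=None, cert_file='')
    size = os.path.getsize(payload_path)
    with open(payload_path) as fd:
        # put() writes the blob into the local db, then encrypts and uploads it
        yield manager.put(BlobDoc(fd, blob_id), size=size)
    # get() returns a file-like object with the decrypted payload (or None)
    blob_fd = yield manager.get(blob_id)
    defer.returnValue(blob_fd.read() if blob_fd else None)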
@defer.inlineCallbacks - def put(self, blob_id, blob_fd, size=None): - logger.info("Saving blob in local database...") - insert = 'INSERT INTO blobs (blob_id, payload) VALUES (?, zeroblob(?))' - irow = yield self.dbpool.insertAndGetLastRowid(insert, (blob_id, size)) - handle = yield self.dbpool.blob('blobs', 'payload', irow, 1) - blob_fd.seek(0) - # XXX I have to copy the buffer here so that I'm able to - # return a non-closed file to the caller (blobmanager.get) - # FIXME should remove this duplication! - # have a look at how treq does cope with closing the handle - # for uploading a file - producer = FileBodyProducer(blob_fd) - done = yield producer.startProducing(handle) - logger.info("Finished saving blob in local database.") - defer.returnValue(done) - - @defer.inlineCallbacks - def get(self, blob_id): - # TODO we can also stream the blob value using sqlite - # incremental interface for blobs - and just return the raw fd instead - select = 'SELECT payload FROM blobs WHERE blob_id = ?' - result = yield self.dbpool.runQuery(select, (blob_id,)) - if result: - defer.returnValue(BytesIO(str(result[0][0]))) - - @defer.inlineCallbacks - def list(self): - query = 'select blob_id from blobs' - result = yield self.dbpool.runQuery(query) - if result: - defer.returnValue([b_id[0] for b_id in result]) - else: - defer.returnValue([]) - - -def _init_blob_table(conn): - maybe_create = ( - "CREATE TABLE IF NOT EXISTS " - "blobs (" - "blob_id PRIMARY KEY, " - "payload BLOB)") - conn.execute(maybe_create) - - -def _sqlcipherInitFactory(fun): - def _initialize(conn): - fun(conn) - _init_blob_table(conn) - return _initialize - - -class BlobDoc(object): - - # TODO probably not needed, but convenient for testing for now. - - def __init__(self, content, blob_id): - - self.blob_id = blob_id - self.is_blob = True - self.blob_fd = content - if blob_id is None: - blob_id = uuid.uuid4().get_hex() - self.blob_id = blob_id - - -# -# testing facilities -# - -@defer.inlineCallbacks -def testit(reactor): - # configure logging to stdout - from twisted.python import log - import sys - log.startLogging(sys.stdout) - - # parse command line arguments - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument('--url', default='http://localhost:9000/') - parser.add_argument('--path', default='/tmp/blobs') - parser.add_argument('--secret', default='secret') - parser.add_argument('--uuid', default='user') - parser.add_argument('--token', default=None) - parser.add_argument('--cert-file', default='') - - subparsers = parser.add_subparsers(help='sub-command help', dest='action') - - # parse upload command - parser_upload = subparsers.add_parser( - 'upload', help='upload blob and bypass local db') - parser_upload.add_argument('payload') - parser_upload.add_argument('blob_id') - - # parse download command - parser_download = subparsers.add_parser( - 'download', help='download blob and bypass local db') - parser_download.add_argument('blob_id') - parser_download.add_argument('--output-file', default='/tmp/incoming-file') - - # parse put command - parser_put = subparsers.add_parser( - 'put', help='put blob in local db and upload') - parser_put.add_argument('payload') - parser_put.add_argument('blob_id') - - # parse get command - parser_get = subparsers.add_parser( - 'get', help='get blob from local db, get if needed') - parser_get.add_argument('blob_id') - - # parse list command - parser_get = subparsers.add_parser( - 'list', help='list local and remote blob ids') - - # parse send_missing command - parser_get = 
subparsers.add_parser( - 'send_missing', help='send all pending upload blobs') - - # parse send_missing command - parser_get = subparsers.add_parser( - 'fetch_missing', help='fetch all new server blobs') - - # parse arguments - args = parser.parse_args() - - # TODO convert these into proper unittests - - def _manager(): - if not os.path.isdir(args.path): - mkdir_p(args.path) - manager = BlobManager( - args.path, args.url, - 'A' * 32, args.secret, - args.uuid, args.token, args.cert_file) - return manager - - @defer.inlineCallbacks - def _upload(blob_id, payload): - logger.info(":: Starting upload only: %s" % str((blob_id, payload))) - manager = _manager() - with open(payload, 'r') as fd: - yield manager._encrypt_and_upload(blob_id, fd) - logger.info(":: Finished upload only: %s" % str((blob_id, payload))) - - @defer.inlineCallbacks - def _download(blob_id): - logger.info(":: Starting download only: %s" % blob_id) - manager = _manager() - result = yield manager._download_and_decrypt(blob_id) - logger.info(":: Result of download: %s" % str(result)) - if result: - fd, _ = result - with open(args.output_file, 'w') as f: - logger.info(":: Writing data to %s" % args.output_file) - f.write(fd.read()) - logger.info(":: Finished download only: %s" % blob_id) - - @defer.inlineCallbacks - def _put(blob_id, payload): - logger.info(":: Starting full put: %s" % blob_id) - manager = _manager() - size = os.path.getsize(payload) - with open(payload) as fd: - doc = BlobDoc(fd, blob_id) - result = yield manager.put(doc, size=size) - logger.info(":: Result of put: %s" % str(result)) - logger.info(":: Finished full put: %s" % blob_id) - - @defer.inlineCallbacks - def _get(blob_id): - logger.info(":: Starting full get: %s" % blob_id) - manager = _manager() - fd = yield manager.get(blob_id) - if fd: - logger.info(":: Result of get: " + fd.getvalue()) - logger.info(":: Finished full get: %s" % blob_id) - - @defer.inlineCallbacks - def _list(): - logger.info(":: Listing local blobs") - manager = _manager() - local_list = yield manager.local_list() - logger.info(":: Local list: %s" % local_list) - logger.info(":: Listing remote blobs") - remote_list = yield manager.remote_list() - logger.info(":: Remote list: %s" % remote_list) - - @defer.inlineCallbacks - def _send_missing(): - logger.info(":: Sending local pending upload docs") - manager = _manager() - yield manager.send_missing() - logger.info(":: Finished sending missing docs") - - @defer.inlineCallbacks - def _fetch_missing(): - logger.info(":: Fetching remote new docs") - manager = _manager() - yield manager.fetch_missing() - logger.info(":: Finished fetching new docs") - - if args.action == 'upload': - yield _upload(args.blob_id, args.payload) - elif args.action == 'download': - yield _download(args.blob_id) - elif args.action == 'put': - yield _put(args.blob_id, args.payload) - elif args.action == 'get': - yield _get(args.blob_id) - elif args.action == 'list': - yield _list() - elif args.action == 'send_missing': - yield _send_missing() - elif args.action == 'fetch_missing': - yield _fetch_missing() - - -if __name__ == '__main__': - from twisted.internet.task import react - react(testit) diff --git a/client/src/leap/soledad/client/_database/__init__.py b/client/src/leap/soledad/client/_database/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/client/src/leap/soledad/client/_database/adbapi.py b/client/src/leap/soledad/client/_database/adbapi.py new file mode 100644 index 00000000..5c28d108 --- /dev/null +++ 
b/client/src/leap/soledad/client/_database/adbapi.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- +# adbapi.py +# Copyright (C) 2013, 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +An asyncrhonous interface to soledad using sqlcipher backend. +It uses twisted.enterprise.adbapi. +""" +import re +import sys + +from functools import partial + +from twisted.enterprise import adbapi +from twisted.internet.defer import DeferredSemaphore +from twisted.python import compat +from zope.proxy import ProxyBase, setProxiedObject + +from leap.soledad.common.log import getLogger +from leap.soledad.common.errors import DatabaseAccessError + +from . import sqlcipher +from . import pragmas + +if sys.version_info[0] < 3: + from pysqlcipher import dbapi2 +else: + from pysqlcipher3 import dbapi2 + + +logger = getLogger(__name__) + + +""" +How long the SQLCipher connection should wait for the lock to go away until +raising an exception. +""" +SQLCIPHER_CONNECTION_TIMEOUT = 10 + +""" +How many times a SQLCipher query should be retried in case of timeout. +""" +SQLCIPHER_MAX_RETRIES = 20 + + +def getConnectionPool(opts, openfun=None, driver="pysqlcipher"): + """ + Return a connection pool. + + :param opts: + Options for the SQLCipher connection. + :type opts: SQLCipherOptions + :param openfun: + Callback invoked after every connect() on the underlying DB-API + object. + :type openfun: callable + :param driver: + The connection driver. + :type driver: str + + :return: A U1DB connection pool. + :rtype: U1DBConnectionPool + """ + if openfun is None and driver == "pysqlcipher": + openfun = partial(pragmas.set_init_pragmas, opts=opts) + return U1DBConnectionPool( + opts, + # the following params are relayed "as is" to twisted's + # ConnectionPool. + "%s.dbapi2" % driver, opts.path, timeout=SQLCIPHER_CONNECTION_TIMEOUT, + check_same_thread=False, cp_openfun=openfun) + + +class U1DBConnection(adbapi.Connection): + """ + A wrapper for a U1DB connection instance. + """ + + u1db_wrapper = sqlcipher.SoledadSQLCipherWrapper + """ + The U1DB wrapper to use. + """ + + def __init__(self, pool, init_u1db=False): + """ + :param pool: The pool of connections to that owns this connection. + :type pool: adbapi.ConnectionPool + :param init_u1db: Wether the u1db database should be initialized. + :type init_u1db: bool + """ + self.init_u1db = init_u1db + try: + adbapi.Connection.__init__(self, pool) + except dbapi2.DatabaseError as e: + raise DatabaseAccessError( + 'Error initializing connection to sqlcipher database: %s' + % str(e)) + + def reconnect(self): + """ + Reconnect to the U1DB database. + """ + if self._connection is not None: + self._pool.disconnect(self._connection) + self._connection = self._pool.connect() + + if self.init_u1db: + self._u1db = self.u1db_wrapper( + self._connection, + self._pool.opts) + + def __getattr__(self, name): + """ + Route the requested attribute either to the U1DB wrapper or to the + connection. 
+ + :param name: The name of the attribute. + :type name: str + """ + if name.startswith('u1db_'): + attr = re.sub('^u1db_', '', name) + return getattr(self._u1db, attr) + else: + return getattr(self._connection, name) + + +class U1DBTransaction(adbapi.Transaction): + """ + A wrapper for a U1DB 'cursor' object. + """ + + def __getattr__(self, name): + """ + Route the requested attribute either to the U1DB wrapper of the + connection or to the actual connection cursor. + + :param name: The name of the attribute. + :type name: str + """ + if name.startswith('u1db_'): + attr = re.sub('^u1db_', '', name) + return getattr(self._connection._u1db, attr) + else: + return getattr(self._cursor, name) + + +class U1DBConnectionPool(adbapi.ConnectionPool): + """ + Represent a pool of connections to an U1DB database. + """ + + connectionFactory = U1DBConnection + transactionFactory = U1DBTransaction + + def __init__(self, opts, *args, **kwargs): + """ + Initialize the connection pool. + """ + self.opts = opts + try: + adbapi.ConnectionPool.__init__(self, *args, **kwargs) + except dbapi2.DatabaseError as e: + raise DatabaseAccessError( + 'Error initializing u1db connection pool: %s' % str(e)) + + # all u1db connections, hashed by thread-id + self._u1dbconnections = {} + + # The replica uid, primed by the connections on init. + self.replica_uid = ProxyBase(None) + + try: + conn = self.connectionFactory( + self, init_u1db=True) + replica_uid = conn._u1db._real_replica_uid + setProxiedObject(self.replica_uid, replica_uid) + except DatabaseAccessError as e: + self.threadpool.stop() + raise DatabaseAccessError( + "Error initializing connection factory: %s" % str(e)) + + def runU1DBQuery(self, meth, *args, **kw): + """ + Execute a U1DB query in a thread, using a pooled connection. + + Concurrent threads trying to update the same database may timeout + because of other threads holding the database lock. Because of this, + we will retry SQLCIPHER_MAX_RETRIES times and fail after that. + + :param meth: The U1DB wrapper method name. + :type meth: str + + :return: a Deferred which will fire the return value of + 'self._runU1DBQuery(Transaction(...), *args, **kw)', or a Failure. + :rtype: twisted.internet.defer.Deferred + """ + meth = "u1db_%s" % meth + semaphore = DeferredSemaphore(SQLCIPHER_MAX_RETRIES) + + def _run_interaction(): + return self.runInteraction( + self._runU1DBQuery, meth, *args, **kw) + + def _errback(failure): + failure.trap(dbapi2.OperationalError) + if failure.getErrorMessage() == "database is locked": + logger.warn("database operation timed out") + should_retry = semaphore.acquire() + if should_retry: + logger.warn("trying again...") + return _run_interaction() + logger.warn("giving up!") + return failure + + d = _run_interaction() + d.addErrback(_errback) + return d + + def _runU1DBQuery(self, trans, meth, *args, **kw): + """ + Execute a U1DB query. + + :param trans: An U1DB transaction. + :type trans: adbapi.Transaction + :param meth: the U1DB wrapper method name. + :type meth: str + """ + meth = getattr(trans, meth) + return meth(*args, **kw) + # XXX should return a fetchall? + + # XXX add _runOperation too + + def _runInteraction(self, interaction, *args, **kw): + """ + Interact with the database and return the result. + + :param interaction: + A callable object whose first argument is an + L{adbapi.Transaction}. + :type interaction: callable + :return: a Deferred which will fire the return value of + 'interaction(Transaction(...), *args, **kw)', or a Failure. 
+ :rtype: twisted.internet.defer.Deferred + """ + tid = self.threadID() + u1db = self._u1dbconnections.get(tid) + conn = self.connectionFactory( + self, init_u1db=not bool(u1db)) + + if self.replica_uid is None: + replica_uid = conn._u1db._real_replica_uid + setProxiedObject(self.replica_uid, replica_uid) + + if u1db is None: + self._u1dbconnections[tid] = conn._u1db + else: + conn._u1db = u1db + + trans = self.transactionFactory(self, conn) + try: + result = interaction(trans, *args, **kw) + trans.close() + conn.commit() + return result + except: + excType, excValue, excTraceback = sys.exc_info() + try: + conn.rollback() + except: + logger.error(None, "Rollback failed") + compat.reraise(excValue, excTraceback) + + def finalClose(self): + """ + A final close, only called by the shutdown trigger. + """ + self.shutdownID = None + if self.threadpool.started: + self.threadpool.stop() + self.running = False + for conn in self.connections.values(): + self._close(conn) + for u1db in self._u1dbconnections.values(): + self._close(u1db) + self.connections.clear() diff --git a/client/src/leap/soledad/client/_database/blobs.py b/client/src/leap/soledad/client/_database/blobs.py new file mode 100644 index 00000000..91914380 --- /dev/null +++ b/client/src/leap/soledad/client/_database/blobs.py @@ -0,0 +1,522 @@ +# -*- coding: utf-8 -*- +# _blobs.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +Clientside BlobBackend Storage. +""" + +from urlparse import urljoin + +import errno +import os +import base64 + +from io import BytesIO +from functools import partial + +from twisted.logger import Logger +from twisted.enterprise import adbapi +from twisted.internet import defer +from twisted.web.client import FileBodyProducer + +import treq + +from leap.soledad.common.errors import SoledadError + +from .._document import BlobDoc +from .._crypto import DocInfo +from .._crypto import BlobEncryptor +from .._crypto import BlobDecryptor +from .._http import HTTPClient +from .._pipes import TruncatedTailPipe +from .._pipes import PreamblePipe + +from . import pragmas +from . import sqlcipher + + +logger = Logger() +FIXED_REV = 'ImmutableRevision' # Blob content is immutable + + +class BlobAlreadyExistsError(SoledadError): + pass + + +class ConnectionPool(adbapi.ConnectionPool): + + def insertAndGetLastRowid(self, *args, **kwargs): + """ + Execute an SQL query and return the last rowid. + + See: https://sqlite.org/c3ref/last_insert_rowid.html + """ + return self.runInteraction( + self._insertAndGetLastRowid, *args, **kwargs) + + def _insertAndGetLastRowid(self, trans, *args, **kw): + trans.execute(*args, **kw) + return trans.lastrowid + + def blob(self, table, column, irow, flags): + """ + Open a BLOB for incremental I/O. 
+ + Return a handle to the BLOB that would be selected by: + + SELECT column FROM table WHERE rowid = irow; + + See: https://sqlite.org/c3ref/blob_open.html + + :param table: The table in which to lookup the blob. + :type table: str + :param column: The column where the BLOB is located. + :type column: str + :param rowid: The rowid of the BLOB. + :type rowid: int + :param flags: If zero, BLOB is opened for read-only. If non-zero, + BLOB is opened for RW. + :type flags: int + + :return: A BLOB handle. + :rtype: pysqlcipher.dbapi.Blob + """ + return self.runInteraction(self._blob, table, column, irow, flags) + + def _blob(self, trans, table, column, irow, flags): + # TODO: should not use transaction private variable here + handle = trans._connection.blob(table, column, irow, flags) + return handle + + +def check_http_status(code): + if code == 409: + raise BlobAlreadyExistsError() + elif code != 200: + raise SoledadError("Server Error") + + +class DecrypterBuffer(object): + + def __init__(self, blob_id, secret, tag): + self.doc_info = DocInfo(blob_id, FIXED_REV) + self.secret = secret + self.tag = tag + self.preamble_pipe = PreamblePipe(self._make_decryptor) + + def _make_decryptor(self, preamble): + self.decrypter = BlobDecryptor( + self.doc_info, preamble, + secret=self.secret, + armor=False, + start_stream=False, + tag=self.tag) + return TruncatedTailPipe(self.decrypter, tail_size=len(self.tag)) + + def write(self, data): + self.preamble_pipe.write(data) + + def close(self): + real_size = self.decrypter.decrypted_content_size + return self.decrypter._end_stream(), real_size + + +def mkdir_p(path): + try: + os.makedirs(path) + except OSError as exc: + if exc.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise + + +class BlobManager(object): + """ + Ideally, the decrypting flow goes like this: + + - GET a blob from remote server. + - Decrypt the preamble + - Allocate a zeroblob in the sqlcipher sink + - Mark the blob as unusable (ie, not verified) + - Decrypt the payload incrementally, and write chunks to sqlcipher + ** Is it possible to use a small buffer for the aes writer w/o + ** allocating all the memory in openssl? 
+ - Finalize the AES decryption + - If preamble + payload verifies correctly, mark the blob as usable + + """ + + def __init__( + self, local_path, remote, key, secret, user, token=None, + cert_file=None): + if local_path: + mkdir_p(os.path.dirname(local_path)) + self.local = SQLiteBlobBackend(local_path, key) + self.remote = remote + self.secret = secret + self.user = user + self._client = HTTPClient(user, token, cert_file) + + def close(self): + if hasattr(self, 'local') and self.local: + return self.local.close() + + @defer.inlineCallbacks + def remote_list(self): + uri = urljoin(self.remote, self.user + '/') + data = yield self._client.get(uri) + defer.returnValue((yield data.json())) + + def local_list(self): + return self.local.list() + + @defer.inlineCallbacks + def send_missing(self): + our_blobs = yield self.local_list() + server_blobs = yield self.remote_list() + missing = [b_id for b_id in our_blobs if b_id not in server_blobs] + logger.info("Amount of documents missing on server: %s" % len(missing)) + # TODO: Send concurrently when we are able to stream directly from db + for blob_id in missing: + fd = yield self.local.get(blob_id) + logger.info("Upload local blob: %s" % blob_id) + yield self._encrypt_and_upload(blob_id, fd) + + @defer.inlineCallbacks + def fetch_missing(self): + # TODO: Use something to prioritize user requests over general new docs + our_blobs = yield self.local_list() + server_blobs = yield self.remote_list() + docs_we_want = [b_id for b_id in server_blobs if b_id not in our_blobs] + logger.info("Fetching new docs from server: %s" % len(docs_we_want)) + # TODO: Fetch concurrently when we are able to stream directly into db + for blob_id in docs_we_want: + logger.info("Fetching new doc: %s" % blob_id) + yield self.get(blob_id) + + @defer.inlineCallbacks + def put(self, doc, size): + fd = doc.blob_fd + # TODO this is a tee really, but ok... could do db and upload + # concurrently. not sure if we'd gain something. + yield self.local.put(doc.blob_id, fd, size=size) + # In fact, some kind of pipe is needed here, where each write on db + # handle gets forwarded into a write on the connection handle + fd = yield self.local.get(doc.blob_id) + yield self._encrypt_and_upload(doc.blob_id, fd) + + @defer.inlineCallbacks + def get(self, blob_id): + local_blob = yield self.local.get(blob_id) + if local_blob: + logger.info("Found blob in local database: %s" % blob_id) + defer.returnValue(local_blob) + + result = yield self._download_and_decrypt(blob_id) + + if not result: + defer.returnValue(None) + blob, size = result + + if blob: + logger.info("Got decrypted blob of type: %s" % type(blob)) + blob.seek(0) + yield self.local.put(blob_id, blob, size=size) + defer.returnValue((yield self.local.get(blob_id))) + else: + # XXX we shouldn't get here, but we will... + # lots of ugly error handling possible: + # 1. retry, might be network error + # 2. try later, maybe didn't finished streaming + # 3.. resignation, might be error while verifying + logger.error('sorry, dunno what happened') + + @defer.inlineCallbacks + def _encrypt_and_upload(self, blob_id, fd): + # TODO ------------------------------------------ + # this is wrong, is doing 2 stages. + # the crypto producer can be passed to + # the uploader and react as data is written. + # try to rewrite as a tube: pass the fd to aes and let aes writer + # produce data to the treq request fd. 
+ # ------------------------------------------------ + logger.info("Staring upload of blob: %s" % blob_id) + doc_info = DocInfo(blob_id, FIXED_REV) + uri = urljoin(self.remote, self.user + "/" + blob_id) + crypter = BlobEncryptor(doc_info, fd, secret=self.secret, + armor=False) + fd = yield crypter.encrypt() + response = yield self._client.put(uri, data=fd) + check_http_status(response.code) + logger.info("Finished upload: %s" % (blob_id,)) + + @defer.inlineCallbacks + def _download_and_decrypt(self, blob_id): + logger.info("Staring download of blob: %s" % blob_id) + # TODO this needs to be connected in a tube + uri = urljoin(self.remote, self.user + '/' + blob_id) + data = yield self._client.get(uri) + + if data.code == 404: + logger.warn("Blob not found in server: %s" % blob_id) + defer.returnValue(None) + elif not data.headers.hasHeader('Tag'): + logger.error("Server didn't send a tag header for: %s" % blob_id) + defer.returnValue(None) + tag = data.headers.getRawHeaders('Tag')[0] + tag = base64.urlsafe_b64decode(tag) + buf = DecrypterBuffer(blob_id, self.secret, tag) + + # incrementally collect the body of the response + yield treq.collect(data, buf.write) + fd, size = buf.close() + logger.info("Finished download: (%s, %d)" % (blob_id, size)) + defer.returnValue((fd, size)) + + +class SQLiteBlobBackend(object): + + def __init__(self, path, key=None): + self.path = os.path.abspath( + os.path.join(path, 'soledad_blob.db')) + mkdir_p(os.path.dirname(self.path)) + if not key: + raise ValueError('key cannot be None') + backend = 'pysqlcipher.dbapi2' + opts = sqlcipher.SQLCipherOptions('/tmp/ignored', key) + pragmafun = partial(pragmas.set_init_pragmas, opts=opts) + openfun = _sqlcipherInitFactory(pragmafun) + + self.dbpool = ConnectionPool( + backend, self.path, check_same_thread=False, timeout=5, + cp_openfun=openfun, cp_min=1, cp_max=2, cp_name='blob_pool') + + def close(self): + from twisted._threads import AlreadyQuit + try: + self.dbpool.close() + except AlreadyQuit: + pass + + @defer.inlineCallbacks + def put(self, blob_id, blob_fd, size=None): + logger.info("Saving blob in local database...") + insert = 'INSERT INTO blobs (blob_id, payload) VALUES (?, zeroblob(?))' + irow = yield self.dbpool.insertAndGetLastRowid(insert, (blob_id, size)) + handle = yield self.dbpool.blob('blobs', 'payload', irow, 1) + blob_fd.seek(0) + # XXX I have to copy the buffer here so that I'm able to + # return a non-closed file to the caller (blobmanager.get) + # FIXME should remove this duplication! + # have a look at how treq does cope with closing the handle + # for uploading a file + producer = FileBodyProducer(blob_fd) + done = yield producer.startProducing(handle) + logger.info("Finished saving blob in local database.") + defer.returnValue(done) + + @defer.inlineCallbacks + def get(self, blob_id): + # TODO we can also stream the blob value using sqlite + # incremental interface for blobs - and just return the raw fd instead + select = 'SELECT payload FROM blobs WHERE blob_id = ?' 
+ result = yield self.dbpool.runQuery(select, (blob_id,)) + if result: + defer.returnValue(BytesIO(str(result[0][0]))) + + @defer.inlineCallbacks + def list(self): + query = 'select blob_id from blobs' + result = yield self.dbpool.runQuery(query) + if result: + defer.returnValue([b_id[0] for b_id in result]) + else: + defer.returnValue([]) + + +def _init_blob_table(conn): + maybe_create = ( + "CREATE TABLE IF NOT EXISTS " + "blobs (" + "blob_id PRIMARY KEY, " + "payload BLOB)") + conn.execute(maybe_create) + + +def _sqlcipherInitFactory(fun): + def _initialize(conn): + fun(conn) + _init_blob_table(conn) + return _initialize + + +# +# testing facilities +# + +@defer.inlineCallbacks +def testit(reactor): + # configure logging to stdout + from twisted.python import log + import sys + log.startLogging(sys.stdout) + + # parse command line arguments + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--url', default='http://localhost:9000/') + parser.add_argument('--path', default='/tmp/blobs') + parser.add_argument('--secret', default='secret') + parser.add_argument('--uuid', default='user') + parser.add_argument('--token', default=None) + parser.add_argument('--cert-file', default='') + + subparsers = parser.add_subparsers(help='sub-command help', dest='action') + + # parse upload command + parser_upload = subparsers.add_parser( + 'upload', help='upload blob and bypass local db') + parser_upload.add_argument('payload') + parser_upload.add_argument('blob_id') + + # parse download command + parser_download = subparsers.add_parser( + 'download', help='download blob and bypass local db') + parser_download.add_argument('blob_id') + parser_download.add_argument('--output-file', default='/tmp/incoming-file') + + # parse put command + parser_put = subparsers.add_parser( + 'put', help='put blob in local db and upload') + parser_put.add_argument('payload') + parser_put.add_argument('blob_id') + + # parse get command + parser_get = subparsers.add_parser( + 'get', help='get blob from local db, get if needed') + parser_get.add_argument('blob_id') + + # parse list command + parser_get = subparsers.add_parser( + 'list', help='list local and remote blob ids') + + # parse send_missing command + parser_get = subparsers.add_parser( + 'send_missing', help='send all pending upload blobs') + + # parse send_missing command + parser_get = subparsers.add_parser( + 'fetch_missing', help='fetch all new server blobs') + + # parse arguments + args = parser.parse_args() + + # TODO convert these into proper unittests + + def _manager(): + mkdir_p(os.path.dirname(args.path)) + manager = BlobManager( + args.path, args.url, + 'A' * 32, args.secret, + args.uuid, args.token, args.cert_file) + return manager + + @defer.inlineCallbacks + def _upload(blob_id, payload): + logger.info(":: Starting upload only: %s" % str((blob_id, payload))) + manager = _manager() + with open(payload, 'r') as fd: + yield manager._encrypt_and_upload(blob_id, fd) + logger.info(":: Finished upload only: %s" % str((blob_id, payload))) + + @defer.inlineCallbacks + def _download(blob_id): + logger.info(":: Starting download only: %s" % blob_id) + manager = _manager() + result = yield manager._download_and_decrypt(blob_id) + logger.info(":: Result of download: %s" % str(result)) + if result: + fd, _ = result + with open(args.output_file, 'w') as f: + logger.info(":: Writing data to %s" % args.output_file) + f.write(fd.read()) + logger.info(":: Finished download only: %s" % blob_id) + + @defer.inlineCallbacks + def _put(blob_id, 
payload): + logger.info(":: Starting full put: %s" % blob_id) + manager = _manager() + size = os.path.getsize(payload) + with open(payload) as fd: + doc = BlobDoc(fd, blob_id) + result = yield manager.put(doc, size=size) + logger.info(":: Result of put: %s" % str(result)) + logger.info(":: Finished full put: %s" % blob_id) + + @defer.inlineCallbacks + def _get(blob_id): + logger.info(":: Starting full get: %s" % blob_id) + manager = _manager() + fd = yield manager.get(blob_id) + if fd: + logger.info(":: Result of get: " + fd.getvalue()) + logger.info(":: Finished full get: %s" % blob_id) + + @defer.inlineCallbacks + def _list(): + logger.info(":: Listing local blobs") + manager = _manager() + local_list = yield manager.local_list() + logger.info(":: Local list: %s" % local_list) + logger.info(":: Listing remote blobs") + remote_list = yield manager.remote_list() + logger.info(":: Remote list: %s" % remote_list) + + @defer.inlineCallbacks + def _send_missing(): + logger.info(":: Sending local pending upload docs") + manager = _manager() + yield manager.send_missing() + logger.info(":: Finished sending missing docs") + + @defer.inlineCallbacks + def _fetch_missing(): + logger.info(":: Fetching remote new docs") + manager = _manager() + yield manager.fetch_missing() + logger.info(":: Finished fetching new docs") + + if args.action == 'upload': + yield _upload(args.blob_id, args.payload) + elif args.action == 'download': + yield _download(args.blob_id) + elif args.action == 'put': + yield _put(args.blob_id, args.payload) + elif args.action == 'get': + yield _get(args.blob_id) + elif args.action == 'list': + yield _list() + elif args.action == 'send_missing': + yield _send_missing() + elif args.action == 'fetch_missing': + yield _fetch_missing() + + +if __name__ == '__main__': + from twisted.internet.task import react + react(testit) diff --git a/client/src/leap/soledad/client/_database/dbschema.sql b/client/src/leap/soledad/client/_database/dbschema.sql new file mode 100644 index 00000000..ae027fc5 --- /dev/null +++ b/client/src/leap/soledad/client/_database/dbschema.sql @@ -0,0 +1,42 @@ +-- Database schema +CREATE TABLE transaction_log ( + generation INTEGER PRIMARY KEY AUTOINCREMENT, + doc_id TEXT NOT NULL, + transaction_id TEXT NOT NULL +); +CREATE TABLE document ( + doc_id TEXT PRIMARY KEY, + doc_rev TEXT NOT NULL, + content TEXT +); +CREATE TABLE document_fields ( + doc_id TEXT NOT NULL, + field_name TEXT NOT NULL, + value TEXT +); +CREATE INDEX document_fields_field_value_doc_idx + ON document_fields(field_name, value, doc_id); + +CREATE TABLE sync_log ( + replica_uid TEXT PRIMARY KEY, + known_generation INTEGER, + known_transaction_id TEXT +); +CREATE TABLE conflicts ( + doc_id TEXT, + doc_rev TEXT, + content TEXT, + CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev) +); +CREATE TABLE index_definitions ( + name TEXT, + offset INT, + field TEXT, + CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset) +); +create index index_definitions_field on index_definitions(field); +CREATE TABLE u1db_config ( + name TEXT PRIMARY KEY, + value TEXT +); +INSERT INTO u1db_config VALUES ('sql_schema', '0'); diff --git a/client/src/leap/soledad/client/_database/pragmas.py b/client/src/leap/soledad/client/_database/pragmas.py new file mode 100644 index 00000000..870ed63e --- /dev/null +++ b/client/src/leap/soledad/client/_database/pragmas.py @@ -0,0 +1,379 @@ +# -*- coding: utf-8 -*- +# pragmas.py +# Copyright (C) 2013, 2014 LEAP +# +# This program is free software: you can redistribute it 
and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +Different pragmas used in the initialization of the SQLCipher database. +""" +import string +import threading +import os + +from leap.soledad.common import soledad_assert +from leap.soledad.common.log import getLogger + + +logger = getLogger(__name__) + + +_db_init_lock = threading.Lock() + + +def set_init_pragmas(conn, opts=None, extra_queries=None): + """ + Set the initialization pragmas. + + This includes the crypto pragmas, and any other options that must + be passed early to sqlcipher db. + """ + soledad_assert(opts is not None) + extra_queries = [] if extra_queries is None else extra_queries + with _db_init_lock: + # only one execution path should initialize the db + _set_init_pragmas(conn, opts, extra_queries) + + +def _set_init_pragmas(conn, opts, extra_queries): + + sync_off = os.environ.get('LEAP_SQLITE_NOSYNC') + memstore = os.environ.get('LEAP_SQLITE_MEMSTORE') + nowal = os.environ.get('LEAP_SQLITE_NOWAL') + + set_crypto_pragmas(conn, opts) + + if not nowal: + set_write_ahead_logging(conn) + if sync_off: + set_synchronous_off(conn) + else: + set_synchronous_normal(conn) + if memstore: + set_mem_temp_store(conn) + + for query in extra_queries: + conn.cursor().execute(query) + + +def set_crypto_pragmas(db_handle, sqlcipher_opts): + """ + Set cryptographic params (key, cipher, KDF number of iterations and + cipher page size). + + :param db_handle: + :type db_handle: + :param sqlcipher_opts: options for the SQLCipherDatabase + :type sqlcipher_opts: SQLCipherOpts instance + """ + # XXX assert CryptoOptions + opts = sqlcipher_opts + _set_key(db_handle, opts.key, opts.is_raw_key) + _set_cipher(db_handle, opts.cipher) + _set_kdf_iter(db_handle, opts.kdf_iter) + _set_cipher_page_size(db_handle, opts.cipher_page_size) + + +def _set_key(db_handle, key, is_raw_key): + """ + Set the ``key`` for use with the database. + + The process of creating a new, encrypted database is called 'keying' + the database. SQLCipher uses just-in-time key derivation at the point + it is first needed for an operation. This means that the key (and any + options) must be set before the first operation on the database. As + soon as the database is touched (e.g. SELECT, CREATE TABLE, UPDATE, + etc.) and pages need to be read or written, the key is prepared for + use. + + Implementation Notes: + + * PRAGMA key should generally be called as the first operation on a + database. + + :param key: The key for use with the database. + :type key: str + :param is_raw_key: + Whether C{key} is a raw 64-char hex string or a passphrase that should + be hashed to obtain the encyrption key. + :type is_raw_key: bool + """ + if is_raw_key: + _set_key_raw(db_handle, key) + else: + _set_key_passphrase(db_handle, key) + + +def _set_key_passphrase(db_handle, passphrase): + """ + Set a passphrase for encryption key derivation. + + The key itself can be a passphrase, which is converted to a key using + PBKDF2 key derivation. The result is used as the encryption key for + the database. 
By using this method, there is no way to alter the KDF; + if you want to do so you should use a raw key instead and derive the + key using your own KDF. + + :param db_handle: A handle to the SQLCipher database. + :type db_handle: pysqlcipher.Connection + :param passphrase: The passphrase used to derive the encryption key. + :type passphrase: str + """ + db_handle.cursor().execute("PRAGMA key = '%s'" % passphrase) + + +def _set_key_raw(db_handle, key): + """ + Set a raw hexadecimal encryption key. + + It is possible to specify an exact byte sequence using a blob literal. + With this method, it is the calling application's responsibility to + ensure that the data provided is a 64 character hex string, which will + be converted directly to 32 bytes (256 bits) of key data. + + :param db_handle: A handle to the SQLCipher database. + :type db_handle: pysqlcipher.Connection + :param key: A 64 character hex string. + :type key: str + """ + if not all(c in string.hexdigits for c in key): + raise NotAnHexString(key) + db_handle.cursor().execute('PRAGMA key = "x\'%s"' % key) + + +def _set_cipher(db_handle, cipher='aes-256-cbc'): + """ + Set the cipher and mode to use for symmetric encryption. + + SQLCipher uses aes-256-cbc as the default cipher and mode of + operation. It is possible to change this, though not generally + recommended, using PRAGMA cipher. + + SQLCipher makes direct use of libssl, so all cipher options available + to libssl are also available for use with SQLCipher. See `man enc` for + OpenSSL's supported ciphers. + + Implementation Notes: + + * PRAGMA cipher must be called after PRAGMA key and before the first + actual database operation or it will have no effect. + + * If a non-default value is used PRAGMA cipher to create a database, + it must also be called every time that database is opened. + + * SQLCipher does not implement its own encryption. Instead it uses the + widely available and peer-reviewed OpenSSL libcrypto for all + cryptographic functions. + + :param db_handle: A handle to the SQLCipher database. + :type db_handle: pysqlcipher.Connection + :param cipher: The cipher and mode to use. + :type cipher: str + """ + db_handle.cursor().execute("PRAGMA cipher = '%s'" % cipher) + + +def _set_kdf_iter(db_handle, kdf_iter=4000): + """ + Set the number of iterations for the key derivation function. + + SQLCipher uses PBKDF2 key derivation to strengthen the key and make it + resistent to brute force and dictionary attacks. The default + configuration uses 4000 PBKDF2 iterations (effectively 16,000 SHA1 + operations). PRAGMA kdf_iter can be used to increase or decrease the + number of iterations used. + + Implementation Notes: + + * PRAGMA kdf_iter must be called after PRAGMA key and before the first + actual database operation or it will have no effect. + + * If a non-default value is used PRAGMA kdf_iter to create a database, + it must also be called every time that database is opened. + + * It is not recommended to reduce the number of iterations if a + passphrase is in use. + + :param db_handle: A handle to the SQLCipher database. + :type db_handle: pysqlcipher.Connection + :param kdf_iter: The number of iterations to use. + :type kdf_iter: int + """ + db_handle.cursor().execute("PRAGMA kdf_iter = '%d'" % kdf_iter) + + +def _set_cipher_page_size(db_handle, cipher_page_size=1024): + """ + Set the page size of the encrypted database. + + SQLCipher 2 introduced the new PRAGMA cipher_page_size that can be + used to adjust the page size for the encrypted database. 
The default + page size is 1024 bytes, but it can be desirable for some applications + to use a larger page size for increased performance. For instance, + some recent testing shows that increasing the page size can noticeably + improve performance (5-30%) for certain queries that manipulate a + large number of pages (e.g. selects without an index, large inserts in + a transaction, big deletes). + + To adjust the page size, call the pragma immediately after setting the + key for the first time and each subsequent time that you open the + database. + + Implementation Notes: + + * PRAGMA cipher_page_size must be called after PRAGMA key and before + the first actual database operation or it will have no effect. + + * If a non-default value is used PRAGMA cipher_page_size to create a + database, it must also be called every time that database is opened. + + :param db_handle: A handle to the SQLCipher database. + :type db_handle: pysqlcipher.Connection + :param cipher_page_size: The page size. + :type cipher_page_size: int + """ + db_handle.cursor().execute( + "PRAGMA cipher_page_size = '%d'" % cipher_page_size) + + +# XXX UNUSED ? +def set_rekey(db_handle, new_key, is_raw_key): + """ + Change the key of an existing encrypted database. + + To change the key on an existing encrypted database, it must first be + unlocked with the current encryption key. Once the database is + readable and writeable, PRAGMA rekey can be used to re-encrypt every + page in the database with a new key. + + * PRAGMA rekey must be called after PRAGMA key. It can be called at any + time once the database is readable. + + * PRAGMA rekey can not be used to encrypted a standard SQLite + database! It is only useful for changing the key on an existing + database. + + * Previous versions of SQLCipher provided a PRAGMA rekey_cipher and + code>PRAGMA rekey_kdf_iter. These are deprecated and should not be + used. Instead, use sqlcipher_export(). + + :param db_handle: A handle to the SQLCipher database. + :type db_handle: pysqlcipher.Connection + :param new_key: The new key. + :type new_key: str + :param is_raw_key: Whether C{password} is a raw 64-char hex string or a + passphrase that should be hashed to obtain the encyrption + key. + :type is_raw_key: bool + """ + if is_raw_key: + _set_rekey_raw(db_handle, new_key) + else: + _set_rekey_passphrase(db_handle, new_key) + + +def _set_rekey_passphrase(db_handle, passphrase): + """ + Change the passphrase for encryption key derivation. + + The key itself can be a passphrase, which is converted to a key using + PBKDF2 key derivation. The result is used as the encryption key for + the database. + + :param db_handle: A handle to the SQLCipher database. + :type db_handle: pysqlcipher.Connection + :param passphrase: The passphrase used to derive the encryption key. + :type passphrase: str + """ + db_handle.cursor().execute("PRAGMA rekey = '%s'" % passphrase) + + +def _set_rekey_raw(db_handle, key): + """ + Change the raw hexadecimal encryption key. + + It is possible to specify an exact byte sequence using a blob literal. + With this method, it is the calling application's responsibility to + ensure that the data provided is a 64 character hex string, which will + be converted directly to 32 bytes (256 bits) of key data. + + :param db_handle: A handle to the SQLCipher database. + :type db_handle: pysqlcipher.Connection + :param key: A 64 character hex string. 
+ :type key: str + """ + if not all(c in string.hexdigits for c in key): + raise NotAnHexString(key) + db_handle.cursor().execute('PRAGMA rekey = "x\'%s"' % key) + + +def set_synchronous_off(db_handle): + """ + Change the setting of the "synchronous" flag to OFF. + """ + logger.debug("sqlcipher: setting synchronous off") + db_handle.cursor().execute('PRAGMA synchronous=OFF') + + +def set_synchronous_normal(db_handle): + """ + Change the setting of the "synchronous" flag to NORMAL. + """ + logger.debug("sqlcipher: setting synchronous normal") + db_handle.cursor().execute('PRAGMA synchronous=NORMAL') + + +def set_mem_temp_store(db_handle): + """ + Use a in-memory store for temporary tables. + """ + logger.debug("sqlcipher: setting temp_store memory") + db_handle.cursor().execute('PRAGMA temp_store=MEMORY') + + +def set_write_ahead_logging(db_handle): + """ + Enable write-ahead logging, and set the autocheckpoint to 50 pages. + + Setting the autocheckpoint to a small value, we make the reads not + suffer too much performance degradation. + + From the sqlite docs: + + "There is a tradeoff between average read performance and average write + performance. To maximize the read performance, one wants to keep the + WAL as small as possible and hence run checkpoints frequently, perhaps + as often as every COMMIT. To maximize write performance, one wants to + amortize the cost of each checkpoint over as many writes as possible, + meaning that one wants to run checkpoints infrequently and let the WAL + grow as large as possible before each checkpoint. The decision of how + often to run checkpoints may therefore vary from one application to + another depending on the relative read and write performance + requirements of the application. The default strategy is to run a + checkpoint once the WAL reaches 1000 pages" + """ + logger.debug("sqlcipher: setting write-ahead logging") + db_handle.cursor().execute('PRAGMA journal_mode=WAL') + + # The optimum value can still use a little bit of tuning, but we favor + # small sizes of the WAL file to get fast reads, since we assume that + # the writes will be quick enough to not block too much. + + db_handle.cursor().execute('PRAGMA wal_autocheckpoint=50') + + +class NotAnHexString(Exception): + """ + Raised when trying to (raw) key the database with a non-hex string. + """ + pass diff --git a/client/src/leap/soledad/client/_database/sqlcipher.py b/client/src/leap/soledad/client/_database/sqlcipher.py new file mode 100644 index 00000000..d22017bd --- /dev/null +++ b/client/src/leap/soledad/client/_database/sqlcipher.py @@ -0,0 +1,633 @@ +# -*- coding: utf-8 -*- +# sqlcipher.py +# Copyright (C) 2013, 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +A U1DB backend that uses SQLCipher as its persistence layer. 
+
+The SQLCipher API (http://sqlcipher.net/sqlcipher-api/) is fully implemented,
+with the exception of the following statements:
+
+ * PRAGMA cipher_use_hmac
+ * PRAGMA cipher_default_use_hmac
+
+SQLCipher 2.0 introduced a per-page HMAC to validate that the page data has
+not been tampered with. By default, when creating or opening a database using
+SQLCipher 2, SQLCipher will attempt to use an HMAC check. This change in
+database format means that SQLCipher 2 can't operate on version 1.1.x
+databases by default. Thus, in order to provide backward compatibility with
+SQLCipher 1.1.x, PRAGMA cipher_use_hmac can be used to disable the HMAC
+functionality on specific databases.
+
+In some very specific cases, it is not possible to call PRAGMA cipher_use_hmac
+as one of the first operations on a database. An example of this is when
+trying to ATTACH a 1.1.x database to the main database. In these cases PRAGMA
+cipher_default_use_hmac can be used to globally alter the default use of HMAC
+when opening a database.
+
+So, as the statements above were introduced for backwards compatibility with
+SQLCipher 1.1 databases, we do not implement them, as all SQLCipher databases
+handled by Soledad should be created by SQLCipher >= 2.0.
+"""
+import os
+import sys
+
+from functools import partial
+
+from twisted.internet import reactor
+from twisted.internet import defer
+from twisted.enterprise import adbapi
+
+from leap.soledad.common.log import getLogger
+from leap.soledad.common.l2db import errors as u1db_errors
+from leap.soledad.common.errors import DatabaseAccessError
+
+from leap.soledad.client.http_target import SoledadHTTPSyncTarget
+from leap.soledad.client.sync import SoledadSynchronizer
+
+from .._document import Document
+from . import sqlite
+from . import pragmas
+
+if sys.version_info[0] < 3:
+    from pysqlcipher import dbapi2 as sqlcipher_dbapi2
+else:
+    from pysqlcipher3 import dbapi2 as sqlcipher_dbapi2
+
+logger = getLogger(__name__)
+
+
+# Monkey-patch u1db.backends.sqlite with pysqlcipher.dbapi2
+sqlite.dbapi2 = sqlcipher_dbapi2
+
+
+# we may want to collect statistics from the sync process
+DO_STATS = False
+if os.environ.get('SOLEDAD_STATS'):
+    DO_STATS = True
+
+
+def initialize_sqlcipher_db(opts, on_init=None, check_same_thread=True):
+    """
+    Initialize a SQLCipher database.
+
+    :param opts: options for the SQLCipher database.
+    :type opts: SQLCipherOptions
+    :param on_init: a tuple of queries to be executed on initialization
+    :type on_init: tuple
+    :return: pysqlcipher.dbapi2.Connection
+    """
+    # Note: There seemed to be a bug in sqlite 3.5.9 (with python2.6)
+    # where without re-opening the database on Windows, it
+    # doesn't see the transaction that was just committed
+    # Removing from here now, look at the pysqlite implementation if the
+    # bug shows up in windows.
+
+    if not os.path.isfile(opts.path) and not opts.create:
+        raise u1db_errors.DatabaseDoesNotExist()
+
+    conn = sqlcipher_dbapi2.connect(
+        opts.path, check_same_thread=check_same_thread)
+    pragmas.set_init_pragmas(conn, opts, extra_queries=on_init)
+    return conn
+
+
+def initialize_sqlcipher_adbapi_db(opts, extra_queries=None):
+    from leap.soledad.client import sqlcipher_adbapi
+    return sqlcipher_adbapi.getConnectionPool(
+        opts, extra_queries=extra_queries)
+
+
+class SQLCipherOptions(object):
+    """
+    A container with options for the initialization of an SQLCipher database.
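+
+    A minimal usage sketch (path and passphrase below are just illustrative,
+    not part of the API)::
+
+        opts = SQLCipherOptions(
+            '/tmp/soledad.db', 'a secret passphrase',
+            create=True, is_raw_key=False)
+        db = SQLCipherDatabase(opts)
+        # ... put/get documents ...
+        db.close()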
+    """
+
+    @classmethod
+    def copy(cls, source, path=None, key=None, create=None,
+             is_raw_key=None, cipher=None, kdf_iter=None,
+             cipher_page_size=None, sync_db_key=None):
+        """
+        Return a copy of C{source}, with any parameters that were passed as
+        non-None replacing the corresponding values of C{source}.
+        """
+        local_vars = locals()
+        args = []
+        kwargs = {}
+
+        for name in ["path", "key"]:
+            val = local_vars[name]
+            if val is not None:
+                args.append(val)
+            else:
+                args.append(getattr(source, name))
+
+        for name in ["create", "is_raw_key", "cipher", "kdf_iter",
+                     "cipher_page_size", "sync_db_key"]:
+            val = local_vars[name]
+            if val is not None:
+                kwargs[name] = val
+            else:
+                kwargs[name] = getattr(source, name)
+
+        return SQLCipherOptions(*args, **kwargs)
+
+    def __init__(self, path, key, create=True, is_raw_key=False,
+                 cipher='aes-256-cbc', kdf_iter=4000, cipher_page_size=1024,
+                 sync_db_key=None):
+        """
+        :param path: The filesystem path for the database to open.
+        :type path: str
+        :param key: The key used to unlock the database.
+        :type key: str
+        :param create:
+            True/False, should the database be created if it doesn't
+            already exist?
+        :type create: bool
+        :param is_raw_key:
+            Whether ``key`` is a raw 64-char hex string or a passphrase
+            that should be hashed to obtain the encryption key.
+        :type is_raw_key: bool
+        :param cipher: The cipher and mode to use.
+        :type cipher: str
+        :param kdf_iter: The number of iterations to use.
+        :type kdf_iter: int
+        :param cipher_page_size: The page size.
+        :type cipher_page_size: int
+        """
+        self.path = path
+        self.key = key
+        self.is_raw_key = is_raw_key
+        self.create = create
+        self.cipher = cipher
+        self.kdf_iter = kdf_iter
+        self.cipher_page_size = cipher_page_size
+        self.sync_db_key = sync_db_key
+
+    def __str__(self):
+        """
+        Return string representation of options, for easy debugging.
+
+        :return: String representation of options.
+        :rtype: str
+        """
+        attr_names = filter(lambda a: not a.startswith('_'), dir(self))
+        attr_str = []
+        for a in attr_names:
+            attr_str.append(a + "=" + str(getattr(self, a)))
+        name = self.__class__.__name__
+        return "%s(%s)" % (name, ', '.join(attr_str))
+
+
+#
+# The SQLCipher database
+#
+
+class SQLCipherDatabase(sqlite.SQLitePartialExpandDatabase):
+    """
+    A U1DB implementation that uses SQLCipher as its persistence layer.
+    """
+
+    # The attribute _index_storage_value will be used as the lookup key for
+    # the implementation of the SQLCipher storage backend.
+    _index_storage_value = 'expand referenced encrypted'
+
+    def __init__(self, opts):
+        """
+        Connect to an existing SQLCipher database, creating a new sqlcipher
+        database file if needed.
+
+        *** IMPORTANT ***
+
+        Don't forget to close the database after use by calling the close()
+        method, otherwise some resources might not be freed and you may
+        experience several kinds of leakages.
+
+        *** IMPORTANT ***
+
+        :param opts: options for initialization of the SQLCipher database.
+        :type opts: SQLCipherOptions
+        """
+        # ensure the db is encrypted if the file already exists
+        if os.path.isfile(opts.path):
+            self._db_handle = _assert_db_is_encrypted(opts)
+        else:
+            # connect to the sqlcipher database
+            self._db_handle = initialize_sqlcipher_db(opts)
+
+        # TODO ---------------------------------------------------
+        # Everything else in this initialization has to be factored
+        # out, so it can be used from SoledadSQLCipherWrapper.__init__
+        # too.
+ # --------------------------------------------------------- + + self._ensure_schema() + self.set_document_factory(doc_factory) + self._prime_replica_uid() + + def _prime_replica_uid(self): + """ + In the u1db implementation, _replica_uid is a property + that returns the value in _real_replica_uid, and does + a db query if no value found. + Here we prime the replica uid during initialization so + that we don't have to wait for the query afterwards. + """ + self._real_replica_uid = None + self._get_replica_uid() + + def _extra_schema_init(self, c): + """ + Add any extra fields, etc to the basic table definitions. + + This method is called by u1db.backends.sqlite_backend._initialize() + method, which is executed when the database schema is created. Here, + we use it to include the "syncable" property for LeapDocuments. + + :param c: The cursor for querying the database. + :type c: dbapi2.cursor + """ + c.execute( + 'ALTER TABLE document ' + 'ADD COLUMN syncable BOOL NOT NULL DEFAULT TRUE') + + # + # SQLCipher API methods + # + + # Extra query methods: extensions to the base u1db sqlite implmentation. + + def get_count_from_index(self, index_name, *key_values): + """ + Return the count for a given combination of index_name + and key values. + + Extension method made from similar methods in u1db version 13.09 + + :param index_name: The index to query + :type index_name: str + :param key_values: values to match. eg, if you have + an index with 3 fields then you would have: + get_from_index(index_name, val1, val2, val3) + :type key_values: tuple + :return: count. + :rtype: int + """ + c = self._db_handle.cursor() + definition = self._get_index_definition(index_name) + + if len(key_values) != len(definition): + raise u1db_errors.InvalidValueForIndex() + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = ["d.doc_id = d%d.doc_id" + " AND d%d.field_name = ?" + % (i, i) for i in range(len(definition))] + exact_where = [novalue_where[i] + (" AND d%d.value = ?" % (i,)) + for i in range(len(definition))] + args = [] + where = [] + for idx, (field, value) in enumerate(zip(definition, key_values)): + args.append(field) + where.append(exact_where[idx]) + args.append(value) + + tables = ["document_fields d%d" % i for i in range(len(definition))] + statement = ( + "SELECT COUNT(*) FROM document d, %s WHERE %s " % ( + ', '.join(tables), + ' AND '.join(where), + )) + try: + c.execute(statement, tuple(args)) + except sqlcipher_dbapi2.OperationalError as e: + raise sqlcipher_dbapi2.OperationalError( + str(e) + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + return res[0][0] + + def close(self): + """ + Close db connections. + """ + # TODO should be handled by adbapi instead + # TODO syncdb should be stopped first + + if logger is not None: # logger might be none if called from __del__ + logger.debug("SQLCipher backend: closing") + + # close the actual database + if getattr(self, '_db_handle', False): + self._db_handle.close() + self._db_handle = None + + # indexes + + def _put_and_update_indexes(self, old_doc, doc): + """ + Update a document and all indexes related to it. + + :param old_doc: The old version of the document. + :type old_doc: u1db.Document + :param doc: The new version of the document. + :type doc: u1db.Document + """ + sqlite.SQLitePartialExpandDatabase._put_and_update_indexes( + self, old_doc, doc) + c = self._db_handle.cursor() + c.execute('UPDATE document SET syncable=? 
WHERE doc_id=?',
+                  (doc.syncable, doc.doc_id))
+
+    def _get_doc(self, doc_id, check_for_conflicts=False):
+        """
+        Get just the document content, without fancy handling.
+
+        :param doc_id: The unique document identifier
+        :type doc_id: str
+        :param check_for_conflicts: If set to True, the document will also
+            be checked for conflicts and its has_conflicts flag will be set
+            accordingly.
+        :type check_for_conflicts: bool
+
+        :return: a Document object.
+        :rtype: u1db.Document
+        """
+        doc = sqlite.SQLitePartialExpandDatabase._get_doc(
+            self, doc_id, check_for_conflicts)
+        if doc:
+            c = self._db_handle.cursor()
+            c.execute('SELECT syncable FROM document WHERE doc_id=?',
+                      (doc.doc_id,))
+            result = c.fetchone()
+            doc.syncable = bool(result[0])
+        return doc
+
+    def __del__(self):
+        """
+        Free resources when deleting or garbage collecting the database.
+
+        This is only here to minimize problems if someone ever forgets to
+        call the close() method after using the database; you should not
+        rely on garbage collecting to free up the database resources.
+        """
+        self.close()
+
+
+class SQLCipherU1DBSync(SQLCipherDatabase):
+    """
+    Soledad syncer implementation.
+    """
+
+    """
+    The name of the local symmetrically encrypted documents to
+    sync database file.
+    """
+    LOCAL_SYMMETRIC_SYNC_FILE_NAME = 'sync.u1db'
+
+    """
+    Period (in seconds) of the Looping Call that will do the encryption
+    to the syncdb.
+    """
+    ENCRYPT_LOOP_PERIOD = 1
+
+    def __init__(self, opts, soledad_crypto, replica_uid, cert_file):
+        self._opts = opts
+        self._path = opts.path
+        self._crypto = soledad_crypto
+        self.__replica_uid = replica_uid
+        self._cert_file = cert_file
+
+        # storage for the documents received during a sync
+        self.received_docs = []
+
+        self.running = False
+        self._db_handle = None
+
+        # initialize the main db before scheduling a start
+        self._initialize_main_db()
+        self._reactor = reactor
+        self._reactor.callWhenRunning(self._start)
+
+        if DO_STATS:
+            self.sync_phase = None
+
+    def commit(self):
+        self._db_handle.commit()
+
+    @property
+    def _replica_uid(self):
+        return str(self.__replica_uid)
+
+    def _start(self):
+        if not self.running:
+            self.running = True
+
+    def _initialize_main_db(self):
+        try:
+            self._db_handle = initialize_sqlcipher_db(
+                self._opts, check_same_thread=False)
+            self._real_replica_uid = None
+            self._ensure_schema()
+            self.set_document_factory(doc_factory)
+        except sqlcipher_dbapi2.DatabaseError as e:
+            raise DatabaseAccessError(str(e))
+
+    @defer.inlineCallbacks
+    def sync(self, url, creds=None):
+        """
+        Synchronize documents with the remote replica exposed at url.
+
+        It is not safe to initiate more than one sync process and let them
+        run concurrently. It is the responsibility of the caller to ensure
+        that there are no concurrent sync processes running. This is
+        currently controlled by the main Soledad object because it may also
+        run post-sync hooks, which should be run while the lock is locked.
+
+        :param url: The url of the target replica to sync with.
+        :type url: str
+        :param creds: optional dictionary giving credentials to authorize
+            the operation with the server.
+        :type creds: dict
+
+        :return:
+            A Deferred that will fire with the local generation (an int) as
+            it was before the synchronisation was performed.
+        :rtype: Deferred
+        """
+        syncer = self._get_syncer(url, creds=creds)
+        if DO_STATS:
+            self.sync_phase = syncer.sync_phase
+            self.syncer = syncer
+            self.sync_exchange_phase = syncer.sync_exchange_phase
+        local_gen_before_sync = yield syncer.sync()
+        self.received_docs = syncer.received_docs
+        defer.returnValue(local_gen_before_sync)
+
+    def _get_syncer(self, url, creds=None):
+        """
+        Get a synchronizer for ``url`` using ``creds``.
+
+        :param url: The url of the target replica to sync with.
+        :type url: str
+        :param creds: optional dictionary giving credentials to authorize
+            the operation with the server.
+        :type creds: dict
+
+        :return: A synchronizer.
+        :rtype: Synchronizer
+        """
+        return SoledadSynchronizer(
+            self,
+            SoledadHTTPSyncTarget(
+                url,
+                # XXX is the replica_uid ready?
+                self._replica_uid,
+                creds=creds,
+                crypto=self._crypto,
+                cert_file=self._cert_file))
+
+    #
+    # Symmetric encryption of syncing docs
+    #
+
+    def get_generation(self):
+        # FIXME
+        # XXX this SHOULD BE a callback
+        return self._get_generation()
+
+
+class U1DBSQLiteBackend(sqlite.SQLitePartialExpandDatabase):
+    """
+    A very simple wrapper for u1db around sqlcipher backend.
+
+    Instead of initializing the database on the fly, it just uses an existing
+    connection that is passed to it in the initializer.
+
+    It can be used in tests and debug runs to initialize the adbapi with plain
+    sqlite connections, decoupled from the sqlcipher layer.
+    """
+
+    def __init__(self, conn):
+        self._db_handle = conn
+        self._real_replica_uid = None
+        self._ensure_schema()
+        self._factory = Document
+
+
+class SoledadSQLCipherWrapper(SQLCipherDatabase):
+    """
+    A wrapper for u1db that uses the Soledad-extended sqlcipher backend.
+
+    Instead of initializing the database on the fly, it just uses an existing
+    connection that is passed to it in the initializer.
+
+    It can be used from adbapi to initialize a soledad database after
+    getting a regular connection to a sqlcipher database.
+    """
+    def __init__(self, conn, opts):
+        self._db_handle = conn
+        self._real_replica_uid = None
+        self._ensure_schema()
+        self.set_document_factory(doc_factory)
+        self._prime_replica_uid()
+
+
+def _assert_db_is_encrypted(opts):
+    """
+    Assert that the sqlcipher file contains an encrypted database.
+
+    When opening an existing database, PRAGMA key will not immediately
+    throw an error if the key provided is incorrect. To test that the
+    database can be successfully opened with the provided key, it is
+    necessary to perform some operation on the database (e.g. read from
+    it) and confirm that it succeeds.
+
+    The easiest way to do this is to select from the sqlite_master table,
+    which will attempt to read the first page of the database and will
+    parse the schema.
+
+    :param opts: options for initialization of the SQLCipher database.
+    :type opts: SQLCipherOptions
+    """
+    # Trying to open an encrypted database with the regular u1db
+    # backend should raise a DatabaseError exception.
+    # If the regular backend succeeds, then we need to stop because
+    # the database was not properly initialized.
+    try:
+        sqlite.SQLitePartialExpandDatabase(opts.path)
+    except sqlcipher_dbapi2.DatabaseError:
+        # assert that we can access it using SQLCipher with the given
+        # key
+        dummy_query = ('SELECT count(*) FROM sqlite_master',)
+        return initialize_sqlcipher_db(opts, on_init=dummy_query)
+    else:
+        raise DatabaseIsNotEncrypted()
+
+#
+# Exceptions
+#
+
+
+class DatabaseIsNotEncrypted(Exception):
+    """
+    Exception raised when trying to open non-encrypted databases.
+ """ + pass + + +def doc_factory(doc_id=None, rev=None, json='{}', has_conflicts=False, + syncable=True): + """ + Return a default Soledad Document. + Used in the initialization for SQLCipherDatabase + """ + return Document(doc_id=doc_id, rev=rev, json=json, + has_conflicts=has_conflicts, syncable=syncable) + + +sqlite.SQLiteDatabase.register_implementation(SQLCipherDatabase) + + +# +# twisted.enterprise.adbapi SQLCipher implementation +# + +SQLCIPHER_CONNECTION_TIMEOUT = 10 + + +def getConnectionPool(opts, extra_queries=None): + openfun = partial( + pragmas.set_init_pragmas, + opts=opts, + extra_queries=extra_queries) + return SQLCipherConnectionPool( + database=opts.path, + check_same_thread=False, + cp_openfun=openfun, + timeout=SQLCIPHER_CONNECTION_TIMEOUT) + + +class SQLCipherConnection(adbapi.Connection): + pass + + +class SQLCipherTransaction(adbapi.Transaction): + pass + + +class SQLCipherConnectionPool(adbapi.ConnectionPool): + + connectionFactory = SQLCipherConnection + transactionFactory = SQLCipherTransaction + + def __init__(self, *args, **kwargs): + adbapi.ConnectionPool.__init__( + self, "pysqlcipher.dbapi2", *args, **kwargs) diff --git a/client/src/leap/soledad/client/_database/sqlite.py b/client/src/leap/soledad/client/_database/sqlite.py new file mode 100644 index 00000000..4f7b1259 --- /dev/null +++ b/client/src/leap/soledad/client/_database/sqlite.py @@ -0,0 +1,930 @@ +# Copyright 2011 Canonical Ltd. +# Copyright 2016 LEAP Encryption Access Project +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +""" +A L2DB implementation that uses SQLite as its persistence layer. 
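+
+A minimal usage sketch (the path and document content below are just
+illustrative)::
+
+    from leap.soledad.common.l2db import Document
+
+    db = SQLiteDatabase.open_database('/tmp/example.u1db', create=True)
+    db.put_doc(Document('some-doc-id', None, '{"key": "value"}'))
+    doc = db.get_doc('some-doc-id')
+    db.close()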
+""" + +import errno +import os +import json +import sys +import time +import uuid +import pkg_resources + +from sqlite3 import dbapi2 + +from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget +from leap.soledad.common.l2db import ( + Document, errors, + query_parser, vectorclock) + + +class SQLiteDatabase(CommonBackend): + """A U1DB implementation that uses SQLite as its persistence layer.""" + + _sqlite_registry = {} + + def __init__(self, sqlite_file, document_factory=None): + """Create a new sqlite file.""" + self._db_handle = dbapi2.connect(sqlite_file) + self._real_replica_uid = None + self._ensure_schema() + self._factory = document_factory or Document + + def set_document_factory(self, factory): + self._factory = factory + + def get_sync_target(self): + return SQLiteSyncTarget(self) + + @classmethod + def _which_index_storage(cls, c): + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'index_storage'") + except dbapi2.OperationalError as e: + # The table does not exist yet + return None, e + else: + return c.fetchone()[0], None + + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5 + + @classmethod + def _open_database(cls, sqlite_file, document_factory=None): + if not os.path.isfile(sqlite_file): + raise errors.DatabaseDoesNotExist() + tries = 2 + while True: + # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6) + # where without re-opening the database on Windows, it + # doesn't see the transaction that was just committed + db_handle = dbapi2.connect(sqlite_file) + c = db_handle.cursor() + v, err = cls._which_index_storage(c) + db_handle.close() + if v is not None: + break + # possibly another process is initializing it, wait for it to be + # done + if tries == 0: + raise err # go for the richest error? + tries -= 1 + time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL) + return SQLiteDatabase._sqlite_registry[v]( + sqlite_file, document_factory=document_factory) + + @classmethod + def open_database(cls, sqlite_file, create, backend_cls=None, + document_factory=None): + try: + return cls._open_database( + sqlite_file, document_factory=document_factory) + except errors.DatabaseDoesNotExist: + if not create: + raise + if backend_cls is None: + # default is SQLitePartialExpandDatabase + backend_cls = SQLitePartialExpandDatabase + return backend_cls(sqlite_file, document_factory=document_factory) + + @staticmethod + def delete_database(sqlite_file): + try: + os.unlink(sqlite_file) + except OSError as ex: + if ex.errno == errno.ENOENT: + raise errors.DatabaseDoesNotExist() + raise + + @staticmethod + def register_implementation(klass): + """Register that we implement an SQLiteDatabase. + + The attribute _index_storage_value will be used as the lookup key. + """ + SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass + + def _get_sqlite_handle(self): + """Get access to the underlying sqlite database. + + This should only be used by the test suite, etc, for examining the + state of the underlying database. 
+ """ + return self._db_handle + + def _close_sqlite_handle(self): + """Release access to the underlying sqlite database.""" + self._db_handle.close() + + def close(self): + self._close_sqlite_handle() + + def _is_initialized(self, c): + """Check if this database has been initialized.""" + c.execute("PRAGMA case_sensitive_like=ON") + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'sql_schema'") + except dbapi2.OperationalError: + # The table does not exist yet + val = None + else: + val = c.fetchone() + if val is not None: + return True + return False + + def _initialize(self, c): + """Create the schema in the database.""" + # read the script with sql commands + # TODO: Change how we set up the dependency. Most likely use something + # like lp:dirspec to grab the file from a common resource + # directory. Doesn't specifically need to be handled until we get + # to the point of packaging this. + schema_content = pkg_resources.resource_string( + __name__, 'dbschema.sql') + # Note: We'd like to use c.executescript() here, but it seems that + # executescript always commits, even if you set + # isolation_level = None, so if we want to properly handle + # exclusive locking and rollbacks between processes, we need + # to execute it line-by-line + for line in schema_content.split(';'): + if not line: + continue + c.execute(line) + # add extra fields + self._extra_schema_init(c) + # A unique identifier should be set for this replica. Implementations + # don't have to strictly use uuid here, but we do want the uid to be + # unique amongst all databases that will sync with each other. + # We might extend this to using something with hostname for easier + # debugging. + self._set_replica_uid_in_transaction(uuid.uuid4().hex) + c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)", + (self._index_storage_value,)) + + def _ensure_schema(self): + """Ensure that the database schema has been created.""" + old_isolation_level = self._db_handle.isolation_level + c = self._db_handle.cursor() + if self._is_initialized(c): + return + try: + # autocommit/own mgmt of transactions + self._db_handle.isolation_level = None + with self._db_handle: + # only one execution path should initialize the db + c.execute("begin exclusive") + if self._is_initialized(c): + return + self._initialize(c) + finally: + self._db_handle.isolation_level = old_isolation_level + + def _extra_schema_init(self, c): + """Add any extra fields, etc to the basic table definitions.""" + + def _parse_index_definition(self, index_field): + """Parse a field definition for an index, returning a Getter.""" + # Note: We may want to keep a Parser object around, and cache the + # Getter objects for a greater length of time. Specifically, if + # you create a bunch of indexes, and then insert 50k docs, you'll + # re-parse the indexes between puts. The time to insert the docs + # is still likely to dominate put_doc time, though. + parser = query_parser.Parser() + getter = parser.parse(index_field) + return getter + + def _update_indexes(self, doc_id, raw_doc, getters, db_cursor): + """Update document_fields for a single document. + + :param doc_id: Identifier for this document + :param raw_doc: The python dict representation of the document. + :param getters: A list of [(field_name, Getter)]. Getter.get will be + called to evaluate the index definition for this document, and the + results will be inserted into the db. + :param db_cursor: An sqlite Cursor. 
+ :return: None + """ + values = [] + for field_name, getter in getters: + for idx_value in getter.get(raw_doc): + values.append((doc_id, field_name, idx_value)) + if values: + db_cursor.executemany( + "INSERT INTO document_fields VALUES (?, ?, ?)", values) + + def _set_replica_uid(self, replica_uid): + """Force the replica_uid to be set.""" + with self._db_handle: + self._set_replica_uid_in_transaction(replica_uid) + + def _set_replica_uid_in_transaction(self, replica_uid): + """Set the replica_uid. A transaction should already be held.""" + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO u1db_config" + " VALUES ('replica_uid', ?)", + (replica_uid,)) + self._real_replica_uid = replica_uid + + def _get_replica_uid(self): + if self._real_replica_uid is not None: + return self._real_replica_uid + c = self._db_handle.cursor() + c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'") + val = c.fetchone() + if val is None: + return None + self._real_replica_uid = val[0] + return self._real_replica_uid + + _replica_uid = property(_get_replica_uid) + + def _get_generation(self): + c = self._db_handle.cursor() + c.execute('SELECT max(generation) FROM transaction_log') + val = c.fetchone()[0] + if val is None: + return 0 + return val + + def _get_generation_info(self): + c = self._db_handle.cursor() + c.execute( + 'SELECT max(generation), transaction_id FROM transaction_log ') + val = c.fetchone() + if val[0] is None: + return(0, '') + return val + + def _get_trans_id_for_gen(self, generation): + if generation == 0: + return '' + c = self._db_handle.cursor() + c.execute( + 'SELECT transaction_id FROM transaction_log WHERE generation = ?', + (generation,)) + val = c.fetchone() + if val is None: + raise errors.InvalidGeneration + return val[0] + + def _get_transaction_log(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, transaction_id FROM transaction_log" + " ORDER BY generation") + return c.fetchall() + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling.""" + c = self._db_handle.cursor() + if check_for_conflicts: + c.execute( + "SELECT document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN " + "conflicts ON conflicts.doc_id = document.doc_id WHERE " + "document.doc_id = ? GROUP BY document.doc_id, " + "document.doc_rev, document.content;", (doc_id,)) + else: + c.execute( + "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?", + (doc_id,)) + val = c.fetchone() + if val is None: + return None + doc_rev, content, conflicts = val + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + return doc + + def _has_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? 
LIMIT 1", + (doc_id,)) + val = c.fetchone() + if val is None: + return False + else: + return True + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + generation = self._get_generation() + results = [] + c = self._db_handle.cursor() + c.execute( + "SELECT document.doc_id, document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts " + "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, " + "document.doc_rev, document.content;") + rows = c.fetchall() + for doc_id, doc_rev, content, conflicts in rows: + if content is None and not include_deleted: + continue + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + results.append(doc) + return (generation, results) + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none): + """Convert a dict representation into named fields. + + So something like: {'key1': 'val1', 'key2': 'val2'} + gets converted into: [(doc_id, 'key1', 'val1', 0) + (doc_id, 'key2', 'val2', 0)] + :param doc_id: Just added to every record. + :param base_field: if set, these are nested keys, so each field should + be appropriately prefixed. + :param raw_doc: The python dictionary. + """ + # TODO: Handle lists + values = [] + for field_name, value in raw_doc.iteritems(): + if value is None and not save_none: + continue + if base_field: + full_name = base_field + '.' + field_name + else: + full_name = field_name + if value is None or isinstance(value, (int, float, basestring)): + values.append((doc_id, full_name, value, len(values))) + else: + subvalues = self._expand_to_fields(doc_id, full_name, value, + save_none) + for _, subfield_name, val, _ in subvalues: + values.append((doc_id, subfield_name, val, len(values))) + return values + + def _put_and_update_indexes(self, old_doc, doc): + """Actually insert a document into the database. + + This both updates the existing documents content, and any indexes that + refer to this document. + """ + raise NotImplementedError(self._put_and_update_indexes) + + def whats_changed(self, old_generation=0): + c = self._db_handle.cursor() + c.execute("SELECT generation, doc_id, transaction_id" + " FROM transaction_log" + " WHERE generation > ? 
ORDER BY generation DESC", + (old_generation,)) + results = c.fetchall() + cur_gen = old_generation + seen = set() + changes = [] + newest_trans_id = '' + for generation, doc_id, trans_id in results: + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + if changes: + cur_gen = changes[0][1] # max generation + newest_trans_id = changes[0][2] + changes.reverse() + else: + c.execute("SELECT generation, transaction_id" + " FROM transaction_log ORDER BY generation DESC LIMIT 1") + results = c.fetchone() + if not results: + cur_gen = 0 + newest_trans_id = '' + else: + cur_gen, newest_trans_id = results + + return cur_gen, newest_trans_id, changes + + def delete_doc(self, doc): + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc is None: + raise errors.DocumentDoesNotExist + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + if old_doc.is_tombstone(): + raise errors.DocumentAlreadyDeleted + if old_doc.has_conflicts: + raise errors.ConflictedDoc() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + doc.make_tombstone() + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _get_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?", + (doc_id,)) + return [self._factory(doc_id, doc_rev, content) + for doc_rev, content in c.fetchall()] + + def get_doc_conflicts(self, doc_id): + with self._db_handle: + conflict_docs = self._get_conflicts(doc_id) + if not conflict_docs: + return [] + this_doc = self._get_doc(doc_id) + this_doc.has_conflicts = True + return [this_doc] + conflict_docs + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + c = self._db_handle.cursor() + c.execute("SELECT known_generation, known_transaction_id FROM sync_log" + " WHERE replica_uid = ?", + (other_replica_uid,)) + val = c.fetchone() + if val is None: + other_gen = 0 + trans_id = '' + else: + other_gen = val[0] + trans_id = val[1] + return other_gen, trans_id + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + with self._db_handle: + self._do_set_replica_gen_and_trans_id( + other_replica_uid, other_generation, other_transaction_id) + + def _do_set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, + other_transaction_id): + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)", + (other_replica_uid, other_generation, + other_transaction_id)) + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, + replica_gen=None, replica_trans_id=None): + return super(SQLiteDatabase, self)._put_doc_if_newer( + doc, + save_conflict=save_conflict, + replica_uid=replica_uid, replica_gen=replica_gen, + replica_trans_id=replica_trans_id) + + def _add_conflict(self, c, doc_id, my_doc_rev, my_content): + c.execute("INSERT INTO conflicts VALUES (?, ?, ?)", + (doc_id, my_doc_rev, my_content)) + + def _delete_conflicts(self, c, doc, conflict_revs): + deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs] + c.executemany("DELETE FROM conflicts" + " WHERE doc_id=? 
AND doc_rev=?", deleting) + doc.has_conflicts = self._has_conflicts(doc.doc_id) + + def _prune_conflicts(self, doc, doc_vcr): + if self._has_conflicts(doc.doc_id): + autoresolved = False + c_revs_to_prune = [] + for c_doc in self._get_conflicts(doc.doc_id): + c_vcr = vectorclock.VectorClockRev(c_doc.rev) + if doc_vcr.is_newer(c_vcr): + c_revs_to_prune.append(c_doc.rev) + elif doc.same_content_as(c_doc): + c_revs_to_prune.append(c_doc.rev) + doc_vcr.maximize(c_vcr) + autoresolved = True + if autoresolved: + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + c = self._db_handle.cursor() + self._delete_conflicts(c, doc, c_revs_to_prune) + + def _force_doc_sync_conflict(self, doc): + my_doc = self._get_doc(doc.doc_id) + c = self._db_handle.cursor() + self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) + self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json()) + doc.has_conflicts = True + self._put_and_update_indexes(my_doc, doc) + + def resolve_doc(self, doc, conflicted_doc_revs): + with self._db_handle: + cur_doc = self._get_doc(doc.doc_id) + # TODO: https://bugs.launchpad.net/u1db/+bug/928274 + # I think we have a logic bug in resolve_doc + # Specifically, cur_doc.rev is always in the final vector + # clock of revisions that we supersede, even if it wasn't in + # conflicted_doc_revs. We still add it as a conflict, but the + # fact that _put_doc_if_newer propagates resolutions means I + # think that conflict could accidentally be resolved. We need + # to add a test for this case first. (create a rev, create a + # conflict, create another conflict, resolve the first rev + # and first conflict, then make sure that the resolved + # rev doesn't supersede the second conflict rev.) It *might* + # not matter, because the superseding rev is in as a + # conflict, but it does seem incorrect + new_rev = self._ensure_maximal_rev(cur_doc.rev, + conflicted_doc_revs) + superseded_revs = set(conflicted_doc_revs) + c = self._db_handle.cursor() + doc.rev = new_rev + if cur_doc.rev in superseded_revs: + self._put_and_update_indexes(cur_doc, doc) + else: + self._add_conflict(c, doc.doc_id, new_rev, doc.get_json()) + # TODO: Is there some way that we could construct a rev that would + # end up in superseded_revs, such that we add a conflict, and + # then immediately delete it? + self._delete_conflicts(c, doc, superseded_revs) + + def list_indexes(self): + """Return the list of indexes and their definitions.""" + c = self._db_handle.cursor() + # TODO: How do we test the ordering? + c.execute("SELECT name, field FROM index_definitions" + " ORDER BY name, offset") + definitions = [] + cur_name = None + for name, field in c.fetchall(): + if cur_name != name: + definitions.append((name, [])) + cur_name = name + definitions[-1][-1].append(field) + return definitions + + def _get_index_definition(self, index_name): + """Return the stored definition for a given index_name.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions" + " WHERE name = ? ORDER BY offset", (index_name,)) + fields = [x[0] for x in c.fetchall()] + if not fields: + raise errors.IndexDoesNotExist + return fields + + @staticmethod + def _strip_glob(value): + """Remove the trailing * from a value.""" + assert value[-1] == '*' + return value[:-1] + + def _format_query(self, definition, key_values): + # First, build the definition. We join the document_fields table + # against itself, as many times as the 'width' of our definition. + # We then do a query for each key_value, one-at-a-time. 
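+        # As an illustration (not used by the code below), a two-field index
+        # queried with two exact values ends up generating a statement
+        # roughly like:
+        #
+        #   SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev)
+        #       FROM document d, document_fields d0, document_fields d1
+        #       LEFT OUTER JOIN conflicts c ON c.doc_id = d.doc_id
+        #       WHERE d.doc_id = d0.doc_id AND d0.field_name = ?
+        #           AND d0.value = ?
+        #           AND d.doc_id = d1.doc_id AND d1.field_name = ?
+        #           AND d1.value = ?
+        #       GROUP BY d.doc_id, d.doc_rev, d.content
+        #       ORDER BY d0.value, d1.value;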
+ # Note: All of these strings are static, we could cache them, etc. + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = ["d.doc_id = d%d.doc_id" + " AND d%d.field_name = ?" + % (i, i) for i in range(len(definition))] + wildcard_where = [novalue_where[i] + + (" AND d%d.value NOT NULL" % (i,)) + for i in range(len(definition))] + exact_where = [novalue_where[i] + + (" AND d%d.value = ?" % (i,)) + for i in range(len(definition))] + like_where = [novalue_where[i] + + (" AND d%d.value GLOB ?" % (i,)) + for i in range(len(definition))] + is_wildcard = False + # Merge the lists together, so that: + # [field1, field2, field3], [val1, val2, val3] + # Becomes: + # (field1, val1, field2, val2, field3, val3) + args = [] + where = [] + for idx, (field, value) in enumerate(zip(definition, key_values)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(exact_where[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_from_index(self, index_name, *key_values): + definition = self._get_index_definition(index_name) + if len(key_values) != len(definition): + raise errors.InvalidValueForIndex() + statement, args = self._format_query(definition, key_values) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError as e: + raise dbapi2.OperationalError( + str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def _format_range_query(self, definition, start_value, end_value): + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in + range(len(definition))] + wildcard_where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + like_where = [ + novalue_where[i] + ( + " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in + range(len(definition))] + range_where_lower = [ + novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in + range(len(definition))] + range_where_upper = [ + novalue_where[i] + (" AND d%d.value <= ?" 
% (i,)) for i in + range(len(definition))] + args = [] + where = [] + if start_value: + if isinstance(start_value, basestring): + start_value = (start_value,) + if len(start_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, start_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(self._strip_glob(value)) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(value) + if end_value: + if isinstance(end_value, basestring): + end_value = (end_value,) + if len(end_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, end_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(self._strip_glob(value)) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_upper[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + """Return all documents with key values in the specified range.""" + definition = self._get_index_definition(index_name) + statement, args = self._format_range_query( + definition, start_value, end_value) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError as e: + raise dbapi2.OperationalError( + str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def get_index_keys(self, index_name): + c = self._db_handle.cursor() + definition = self._get_index_definition(index_name) + value_fields = ', '.join([ + 'd%d.value' % i for i in range(len(definition))]) + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" 
% (i, i) for i in + range(len(definition))] + where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + statement = ( + "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( + value_fields, ', '.join(tables), ' AND '.join(where), + value_fields)) + try: + c.execute(statement, tuple(definition)) + except dbapi2.OperationalError as e: + raise dbapi2.OperationalError( + str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) + return c.fetchall() + + def delete_index(self, index_name): + with self._db_handle: + c = self._db_handle.cursor() + c.execute("DELETE FROM index_definitions WHERE name = ?", + (index_name,)) + c.execute( + "DELETE FROM document_fields WHERE document_fields.field_name " + " NOT IN (SELECT field from index_definitions)") + + +class SQLiteSyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + +class SQLitePartialExpandDatabase(SQLiteDatabase): + """An SQLite Backend that expands documents into a document_field table. + + It stores the original document text in document.doc. For fields that are + indexed, the data goes into document_fields. + """ + + _index_storage_value = 'expand referenced' + + def _get_indexed_fields(self): + """Determine what fields are indexed.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions") + return set([x[0] for x in c.fetchall()]) + + def _evaluate_index(self, raw_doc, field): + parser = query_parser.Parser() + getter = parser.parse(field) + return getter.get(raw_doc) + + def _put_and_update_indexes(self, old_doc, doc): + c = self._db_handle.cursor() + if doc and not doc.is_tombstone(): + raw_doc = json.loads(doc.get_json()) + else: + raw_doc = {} + if old_doc is not None: + c.execute("UPDATE document SET doc_rev=?, content=?" 
+ " WHERE doc_id = ?", + (doc.rev, doc.get_json(), doc.doc_id)) + c.execute("DELETE FROM document_fields WHERE doc_id = ?", + (doc.doc_id,)) + else: + c.execute("INSERT INTO document (doc_id, doc_rev, content)" + " VALUES (?, ?, ?)", + (doc.doc_id, doc.rev, doc.get_json())) + indexed_fields = self._get_indexed_fields() + if indexed_fields: + # It is expected that len(indexed_fields) is shorter than + # len(raw_doc) + getters = [(field, self._parse_index_definition(field)) + for field in indexed_fields] + self._update_indexes(doc.doc_id, raw_doc, getters, c) + trans_id = self._allocate_transaction_id() + c.execute("INSERT INTO transaction_log(doc_id, transaction_id)" + " VALUES (?, ?)", (doc.doc_id, trans_id)) + + def create_index(self, index_name, *index_expressions): + with self._db_handle: + c = self._db_handle.cursor() + cur_fields = self._get_indexed_fields() + definition = [(index_name, idx, field) + for idx, field in enumerate(index_expressions)] + try: + c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", + definition) + except dbapi2.IntegrityError as e: + stored_def = self._get_index_definition(index_name) + if stored_def == [x[-1] for x in definition]: + return + raise errors.IndexNameTakenError( + str(e) + + str(sys.exc_info()[2]) + ) + new_fields = set( + [f for f in index_expressions if f not in cur_fields]) + if new_fields: + self._update_all_indexes(new_fields) + + def _iter_all_docs(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, content FROM document") + while True: + next_rows = c.fetchmany() + if not next_rows: + break + for row in next_rows: + yield row + + def _update_all_indexes(self, new_fields): + """Iterate all the documents, and add content to document_fields. + + :param new_fields: The index definitions that need to be added. + """ + getters = [(field, self._parse_index_definition(field)) + for field in new_fields] + c = self._db_handle.cursor() + for doc_id, doc in self._iter_all_docs(): + if doc is None: + continue + raw_doc = json.loads(doc) + self._update_indexes(doc_id, raw_doc, getters, c) + + +SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase) diff --git a/client/src/leap/soledad/client/adbapi.py b/client/src/leap/soledad/client/adbapi.py deleted file mode 100644 index b002055e..00000000 --- a/client/src/leap/soledad/client/adbapi.py +++ /dev/null @@ -1,297 +0,0 @@ -# -*- coding: utf-8 -*- -# adbapi.py -# Copyright (C) 2013, 2014 LEAP -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -""" -An asyncrhonous interface to soledad using sqlcipher backend. -It uses twisted.enterprise.adbapi. 
-""" -import re -import sys - -from functools import partial - -from twisted.enterprise import adbapi -from twisted.internet.defer import DeferredSemaphore -from twisted.python import compat -from zope.proxy import ProxyBase, setProxiedObject - -from leap.soledad.common.log import getLogger -from leap.soledad.common.errors import DatabaseAccessError -from leap.soledad.client import sqlcipher as soledad_sqlcipher -from leap.soledad.client.pragmas import set_init_pragmas - -if sys.version_info[0] < 3: - from pysqlcipher import dbapi2 -else: - from pysqlcipher3 import dbapi2 - - -logger = getLogger(__name__) - - -""" -How long the SQLCipher connection should wait for the lock to go away until -raising an exception. -""" -SQLCIPHER_CONNECTION_TIMEOUT = 10 - -""" -How many times a SQLCipher query should be retried in case of timeout. -""" -SQLCIPHER_MAX_RETRIES = 20 - - -def getConnectionPool(opts, openfun=None, driver="pysqlcipher"): - """ - Return a connection pool. - - :param opts: - Options for the SQLCipher connection. - :type opts: SQLCipherOptions - :param openfun: - Callback invoked after every connect() on the underlying DB-API - object. - :type openfun: callable - :param driver: - The connection driver. - :type driver: str - - :return: A U1DB connection pool. - :rtype: U1DBConnectionPool - """ - if openfun is None and driver == "pysqlcipher": - openfun = partial(set_init_pragmas, opts=opts) - return U1DBConnectionPool( - opts, - # the following params are relayed "as is" to twisted's - # ConnectionPool. - "%s.dbapi2" % driver, opts.path, timeout=SQLCIPHER_CONNECTION_TIMEOUT, - check_same_thread=False, cp_openfun=openfun) - - -class U1DBConnection(adbapi.Connection): - """ - A wrapper for a U1DB connection instance. - """ - - u1db_wrapper = soledad_sqlcipher.SoledadSQLCipherWrapper - """ - The U1DB wrapper to use. - """ - - def __init__(self, pool, init_u1db=False): - """ - :param pool: The pool of connections to that owns this connection. - :type pool: adbapi.ConnectionPool - :param init_u1db: Wether the u1db database should be initialized. - :type init_u1db: bool - """ - self.init_u1db = init_u1db - try: - adbapi.Connection.__init__(self, pool) - except dbapi2.DatabaseError as e: - raise DatabaseAccessError( - 'Error initializing connection to sqlcipher database: %s' - % str(e)) - - def reconnect(self): - """ - Reconnect to the U1DB database. - """ - if self._connection is not None: - self._pool.disconnect(self._connection) - self._connection = self._pool.connect() - - if self.init_u1db: - self._u1db = self.u1db_wrapper( - self._connection, - self._pool.opts) - - def __getattr__(self, name): - """ - Route the requested attribute either to the U1DB wrapper or to the - connection. - - :param name: The name of the attribute. - :type name: str - """ - if name.startswith('u1db_'): - attr = re.sub('^u1db_', '', name) - return getattr(self._u1db, attr) - else: - return getattr(self._connection, name) - - -class U1DBTransaction(adbapi.Transaction): - """ - A wrapper for a U1DB 'cursor' object. - """ - - def __getattr__(self, name): - """ - Route the requested attribute either to the U1DB wrapper of the - connection or to the actual connection cursor. - - :param name: The name of the attribute. - :type name: str - """ - if name.startswith('u1db_'): - attr = re.sub('^u1db_', '', name) - return getattr(self._connection._u1db, attr) - else: - return getattr(self._cursor, name) - - -class U1DBConnectionPool(adbapi.ConnectionPool): - """ - Represent a pool of connections to an U1DB database. 
- """ - - connectionFactory = U1DBConnection - transactionFactory = U1DBTransaction - - def __init__(self, opts, *args, **kwargs): - """ - Initialize the connection pool. - """ - self.opts = opts - try: - adbapi.ConnectionPool.__init__(self, *args, **kwargs) - except dbapi2.DatabaseError as e: - raise DatabaseAccessError( - 'Error initializing u1db connection pool: %s' % str(e)) - - # all u1db connections, hashed by thread-id - self._u1dbconnections = {} - - # The replica uid, primed by the connections on init. - self.replica_uid = ProxyBase(None) - - try: - conn = self.connectionFactory( - self, init_u1db=True) - replica_uid = conn._u1db._real_replica_uid - setProxiedObject(self.replica_uid, replica_uid) - except DatabaseAccessError as e: - self.threadpool.stop() - raise DatabaseAccessError( - "Error initializing connection factory: %s" % str(e)) - - def runU1DBQuery(self, meth, *args, **kw): - """ - Execute a U1DB query in a thread, using a pooled connection. - - Concurrent threads trying to update the same database may timeout - because of other threads holding the database lock. Because of this, - we will retry SQLCIPHER_MAX_RETRIES times and fail after that. - - :param meth: The U1DB wrapper method name. - :type meth: str - - :return: a Deferred which will fire the return value of - 'self._runU1DBQuery(Transaction(...), *args, **kw)', or a Failure. - :rtype: twisted.internet.defer.Deferred - """ - meth = "u1db_%s" % meth - semaphore = DeferredSemaphore(SQLCIPHER_MAX_RETRIES) - - def _run_interaction(): - return self.runInteraction( - self._runU1DBQuery, meth, *args, **kw) - - def _errback(failure): - failure.trap(dbapi2.OperationalError) - if failure.getErrorMessage() == "database is locked": - logger.warn("database operation timed out") - should_retry = semaphore.acquire() - if should_retry: - logger.warn("trying again...") - return _run_interaction() - logger.warn("giving up!") - return failure - - d = _run_interaction() - d.addErrback(_errback) - return d - - def _runU1DBQuery(self, trans, meth, *args, **kw): - """ - Execute a U1DB query. - - :param trans: An U1DB transaction. - :type trans: adbapi.Transaction - :param meth: the U1DB wrapper method name. - :type meth: str - """ - meth = getattr(trans, meth) - return meth(*args, **kw) - # XXX should return a fetchall? - - # XXX add _runOperation too - - def _runInteraction(self, interaction, *args, **kw): - """ - Interact with the database and return the result. - - :param interaction: - A callable object whose first argument is an - L{adbapi.Transaction}. - :type interaction: callable - :return: a Deferred which will fire the return value of - 'interaction(Transaction(...), *args, **kw)', or a Failure. - :rtype: twisted.internet.defer.Deferred - """ - tid = self.threadID() - u1db = self._u1dbconnections.get(tid) - conn = self.connectionFactory( - self, init_u1db=not bool(u1db)) - - if self.replica_uid is None: - replica_uid = conn._u1db._real_replica_uid - setProxiedObject(self.replica_uid, replica_uid) - - if u1db is None: - self._u1dbconnections[tid] = conn._u1db - else: - conn._u1db = u1db - - trans = self.transactionFactory(self, conn) - try: - result = interaction(trans, *args, **kw) - trans.close() - conn.commit() - return result - except: - excType, excValue, excTraceback = sys.exc_info() - try: - conn.rollback() - except: - logger.error(None, "Rollback failed") - compat.reraise(excValue, excTraceback) - - def finalClose(self): - """ - A final close, only called by the shutdown trigger. 
- """ - self.shutdownID = None - if self.threadpool.started: - self.threadpool.stop() - self.running = False - for conn in self.connections.values(): - self._close(conn) - for u1db in self._u1dbconnections.values(): - self._close(u1db) - self.connections.clear() diff --git a/client/src/leap/soledad/client/api.py b/client/src/leap/soledad/client/api.py index fee9e160..3dd99227 100644 --- a/client/src/leap/soledad/client/api.py +++ b/client/src/leap/soledad/client/api.py @@ -51,14 +51,14 @@ from leap.soledad.common.l2db.remote import http_client from leap.soledad.common.l2db.remote.ssl_match_hostname import match_hostname from leap.soledad.common.errors import DatabaseAccessError -from leap.soledad.client import adbapi -from leap.soledad.client import events as soledad_events -from leap.soledad.client import interfaces as soledad_interfaces -from leap.soledad.client import sqlcipher -from leap.soledad.client._recovery_code import RecoveryCode -from leap.soledad.client._secrets import Secrets -from leap.soledad.client._crypto import SoledadCrypto -from ._blobs import BlobManager +from . import events as soledad_events +from . import interfaces as soledad_interfaces +from ._crypto import SoledadCrypto +from ._database import adbapi +from ._database import blobs +from ._database import sqlcipher +from ._recovery_code import RecoveryCode +from ._secrets import Secrets logger = getLogger(__name__) @@ -268,8 +268,8 @@ class Soledad(object): path = os.path.join(os.path.dirname(self._local_db_path), 'blobs') url = urlparse.urljoin(self.server_url, 'blobs/%s' % uuid) key = self._secrets.local_key - self.blobmanager = BlobManager(path, url, key, self.uuid, self.token, - SOLEDAD_CERT) + self.blobmanager = blobs.BlobManager(path, url, key, self.uuid, + self.token, SOLEDAD_CERT) # # Closing methods diff --git a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py index 92bc85d6..276b1200 100644 --- a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py +++ b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py @@ -28,7 +28,7 @@ from twisted.internet import defer, reactor from leap.soledad.common import l2db from leap.soledad.client import adbapi -from leap.soledad.client.sqlcipher import SQLCipherOptions +from leap.soledad.client._database.sqlcipher import SQLCipherOptions folder = os.environ.get("TMPDIR", "tmp") diff --git a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py index 429566c7..d288582b 100644 --- a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py +++ b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py @@ -27,7 +27,7 @@ import sys from twisted.internet import defer, reactor from leap.soledad.client import adbapi -from leap.soledad.client.sqlcipher import SQLCipherOptions +from leap.soledad.client._database.sqlcipher import SQLCipherOptions from leap.soledad.common import l2db diff --git a/client/src/leap/soledad/client/examples/use_adbapi.py b/client/src/leap/soledad/client/examples/use_adbapi.py index 39301b41..25b32307 100644 --- a/client/src/leap/soledad/client/examples/use_adbapi.py +++ b/client/src/leap/soledad/client/examples/use_adbapi.py @@ -24,7 +24,7 @@ import os from twisted.internet import defer, reactor from leap.soledad.client import adbapi 
-from leap.soledad.client.sqlcipher import SQLCipherOptions +from leap.soledad.client._database.sqlcipher import SQLCipherOptions from leap.soledad.common import l2db diff --git a/client/src/leap/soledad/client/http_target/fetch.py b/client/src/leap/soledad/client/http_target/fetch.py index c85c5bab..9d456830 100644 --- a/client/src/leap/soledad/client/http_target/fetch.py +++ b/client/src/leap/soledad/client/http_target/fetch.py @@ -23,10 +23,10 @@ from leap.soledad.client.events import emit_async from leap.soledad.client.http_target.support import RequestBody from leap.soledad.common.log import getLogger from leap.soledad.client._crypto import is_symmetrically_encrypted -from leap.soledad.common.document import Document from leap.soledad.common.l2db import errors from leap.soledad.client import crypto as old_crypto +from .._document import Document from . import fetch_protocol logger = getLogger(__name__) diff --git a/client/src/leap/soledad/client/pragmas.py b/client/src/leap/soledad/client/pragmas.py deleted file mode 100644 index 870ed63e..00000000 --- a/client/src/leap/soledad/client/pragmas.py +++ /dev/null @@ -1,379 +0,0 @@ -# -*- coding: utf-8 -*- -# pragmas.py -# Copyright (C) 2013, 2014 LEAP -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -""" -Different pragmas used in the initialization of the SQLCipher database. -""" -import string -import threading -import os - -from leap.soledad.common import soledad_assert -from leap.soledad.common.log import getLogger - - -logger = getLogger(__name__) - - -_db_init_lock = threading.Lock() - - -def set_init_pragmas(conn, opts=None, extra_queries=None): - """ - Set the initialization pragmas. - - This includes the crypto pragmas, and any other options that must - be passed early to sqlcipher db. - """ - soledad_assert(opts is not None) - extra_queries = [] if extra_queries is None else extra_queries - with _db_init_lock: - # only one execution path should initialize the db - _set_init_pragmas(conn, opts, extra_queries) - - -def _set_init_pragmas(conn, opts, extra_queries): - - sync_off = os.environ.get('LEAP_SQLITE_NOSYNC') - memstore = os.environ.get('LEAP_SQLITE_MEMSTORE') - nowal = os.environ.get('LEAP_SQLITE_NOWAL') - - set_crypto_pragmas(conn, opts) - - if not nowal: - set_write_ahead_logging(conn) - if sync_off: - set_synchronous_off(conn) - else: - set_synchronous_normal(conn) - if memstore: - set_mem_temp_store(conn) - - for query in extra_queries: - conn.cursor().execute(query) - - -def set_crypto_pragmas(db_handle, sqlcipher_opts): - """ - Set cryptographic params (key, cipher, KDF number of iterations and - cipher page size). 
- - :param db_handle: - :type db_handle: - :param sqlcipher_opts: options for the SQLCipherDatabase - :type sqlcipher_opts: SQLCipherOpts instance - """ - # XXX assert CryptoOptions - opts = sqlcipher_opts - _set_key(db_handle, opts.key, opts.is_raw_key) - _set_cipher(db_handle, opts.cipher) - _set_kdf_iter(db_handle, opts.kdf_iter) - _set_cipher_page_size(db_handle, opts.cipher_page_size) - - -def _set_key(db_handle, key, is_raw_key): - """ - Set the ``key`` for use with the database. - - The process of creating a new, encrypted database is called 'keying' - the database. SQLCipher uses just-in-time key derivation at the point - it is first needed for an operation. This means that the key (and any - options) must be set before the first operation on the database. As - soon as the database is touched (e.g. SELECT, CREATE TABLE, UPDATE, - etc.) and pages need to be read or written, the key is prepared for - use. - - Implementation Notes: - - * PRAGMA key should generally be called as the first operation on a - database. - - :param key: The key for use with the database. - :type key: str - :param is_raw_key: - Whether C{key} is a raw 64-char hex string or a passphrase that should - be hashed to obtain the encyrption key. - :type is_raw_key: bool - """ - if is_raw_key: - _set_key_raw(db_handle, key) - else: - _set_key_passphrase(db_handle, key) - - -def _set_key_passphrase(db_handle, passphrase): - """ - Set a passphrase for encryption key derivation. - - The key itself can be a passphrase, which is converted to a key using - PBKDF2 key derivation. The result is used as the encryption key for - the database. By using this method, there is no way to alter the KDF; - if you want to do so you should use a raw key instead and derive the - key using your own KDF. - - :param db_handle: A handle to the SQLCipher database. - :type db_handle: pysqlcipher.Connection - :param passphrase: The passphrase used to derive the encryption key. - :type passphrase: str - """ - db_handle.cursor().execute("PRAGMA key = '%s'" % passphrase) - - -def _set_key_raw(db_handle, key): - """ - Set a raw hexadecimal encryption key. - - It is possible to specify an exact byte sequence using a blob literal. - With this method, it is the calling application's responsibility to - ensure that the data provided is a 64 character hex string, which will - be converted directly to 32 bytes (256 bits) of key data. - - :param db_handle: A handle to the SQLCipher database. - :type db_handle: pysqlcipher.Connection - :param key: A 64 character hex string. - :type key: str - """ - if not all(c in string.hexdigits for c in key): - raise NotAnHexString(key) - db_handle.cursor().execute('PRAGMA key = "x\'%s"' % key) - - -def _set_cipher(db_handle, cipher='aes-256-cbc'): - """ - Set the cipher and mode to use for symmetric encryption. - - SQLCipher uses aes-256-cbc as the default cipher and mode of - operation. It is possible to change this, though not generally - recommended, using PRAGMA cipher. - - SQLCipher makes direct use of libssl, so all cipher options available - to libssl are also available for use with SQLCipher. See `man enc` for - OpenSSL's supported ciphers. - - Implementation Notes: - - * PRAGMA cipher must be called after PRAGMA key and before the first - actual database operation or it will have no effect. - - * If a non-default value is used PRAGMA cipher to create a database, - it must also be called every time that database is opened. - - * SQLCipher does not implement its own encryption. 
Instead it uses the - widely available and peer-reviewed OpenSSL libcrypto for all - cryptographic functions. - - :param db_handle: A handle to the SQLCipher database. - :type db_handle: pysqlcipher.Connection - :param cipher: The cipher and mode to use. - :type cipher: str - """ - db_handle.cursor().execute("PRAGMA cipher = '%s'" % cipher) - - -def _set_kdf_iter(db_handle, kdf_iter=4000): - """ - Set the number of iterations for the key derivation function. - - SQLCipher uses PBKDF2 key derivation to strengthen the key and make it - resistent to brute force and dictionary attacks. The default - configuration uses 4000 PBKDF2 iterations (effectively 16,000 SHA1 - operations). PRAGMA kdf_iter can be used to increase or decrease the - number of iterations used. - - Implementation Notes: - - * PRAGMA kdf_iter must be called after PRAGMA key and before the first - actual database operation or it will have no effect. - - * If a non-default value is used PRAGMA kdf_iter to create a database, - it must also be called every time that database is opened. - - * It is not recommended to reduce the number of iterations if a - passphrase is in use. - - :param db_handle: A handle to the SQLCipher database. - :type db_handle: pysqlcipher.Connection - :param kdf_iter: The number of iterations to use. - :type kdf_iter: int - """ - db_handle.cursor().execute("PRAGMA kdf_iter = '%d'" % kdf_iter) - - -def _set_cipher_page_size(db_handle, cipher_page_size=1024): - """ - Set the page size of the encrypted database. - - SQLCipher 2 introduced the new PRAGMA cipher_page_size that can be - used to adjust the page size for the encrypted database. The default - page size is 1024 bytes, but it can be desirable for some applications - to use a larger page size for increased performance. For instance, - some recent testing shows that increasing the page size can noticeably - improve performance (5-30%) for certain queries that manipulate a - large number of pages (e.g. selects without an index, large inserts in - a transaction, big deletes). - - To adjust the page size, call the pragma immediately after setting the - key for the first time and each subsequent time that you open the - database. - - Implementation Notes: - - * PRAGMA cipher_page_size must be called after PRAGMA key and before - the first actual database operation or it will have no effect. - - * If a non-default value is used PRAGMA cipher_page_size to create a - database, it must also be called every time that database is opened. - - :param db_handle: A handle to the SQLCipher database. - :type db_handle: pysqlcipher.Connection - :param cipher_page_size: The page size. - :type cipher_page_size: int - """ - db_handle.cursor().execute( - "PRAGMA cipher_page_size = '%d'" % cipher_page_size) - - -# XXX UNUSED ? -def set_rekey(db_handle, new_key, is_raw_key): - """ - Change the key of an existing encrypted database. - - To change the key on an existing encrypted database, it must first be - unlocked with the current encryption key. Once the database is - readable and writeable, PRAGMA rekey can be used to re-encrypt every - page in the database with a new key. - - * PRAGMA rekey must be called after PRAGMA key. It can be called at any - time once the database is readable. - - * PRAGMA rekey can not be used to encrypted a standard SQLite - database! It is only useful for changing the key on an existing - database. - - * Previous versions of SQLCipher provided a PRAGMA rekey_cipher and - code>PRAGMA rekey_kdf_iter. 
These are deprecated and should not be - used. Instead, use sqlcipher_export(). - - :param db_handle: A handle to the SQLCipher database. - :type db_handle: pysqlcipher.Connection - :param new_key: The new key. - :type new_key: str - :param is_raw_key: Whether C{password} is a raw 64-char hex string or a - passphrase that should be hashed to obtain the encyrption - key. - :type is_raw_key: bool - """ - if is_raw_key: - _set_rekey_raw(db_handle, new_key) - else: - _set_rekey_passphrase(db_handle, new_key) - - -def _set_rekey_passphrase(db_handle, passphrase): - """ - Change the passphrase for encryption key derivation. - - The key itself can be a passphrase, which is converted to a key using - PBKDF2 key derivation. The result is used as the encryption key for - the database. - - :param db_handle: A handle to the SQLCipher database. - :type db_handle: pysqlcipher.Connection - :param passphrase: The passphrase used to derive the encryption key. - :type passphrase: str - """ - db_handle.cursor().execute("PRAGMA rekey = '%s'" % passphrase) - - -def _set_rekey_raw(db_handle, key): - """ - Change the raw hexadecimal encryption key. - - It is possible to specify an exact byte sequence using a blob literal. - With this method, it is the calling application's responsibility to - ensure that the data provided is a 64 character hex string, which will - be converted directly to 32 bytes (256 bits) of key data. - - :param db_handle: A handle to the SQLCipher database. - :type db_handle: pysqlcipher.Connection - :param key: A 64 character hex string. - :type key: str - """ - if not all(c in string.hexdigits for c in key): - raise NotAnHexString(key) - db_handle.cursor().execute('PRAGMA rekey = "x\'%s"' % key) - - -def set_synchronous_off(db_handle): - """ - Change the setting of the "synchronous" flag to OFF. - """ - logger.debug("sqlcipher: setting synchronous off") - db_handle.cursor().execute('PRAGMA synchronous=OFF') - - -def set_synchronous_normal(db_handle): - """ - Change the setting of the "synchronous" flag to NORMAL. - """ - logger.debug("sqlcipher: setting synchronous normal") - db_handle.cursor().execute('PRAGMA synchronous=NORMAL') - - -def set_mem_temp_store(db_handle): - """ - Use a in-memory store for temporary tables. - """ - logger.debug("sqlcipher: setting temp_store memory") - db_handle.cursor().execute('PRAGMA temp_store=MEMORY') - - -def set_write_ahead_logging(db_handle): - """ - Enable write-ahead logging, and set the autocheckpoint to 50 pages. - - Setting the autocheckpoint to a small value, we make the reads not - suffer too much performance degradation. - - From the sqlite docs: - - "There is a tradeoff between average read performance and average write - performance. To maximize the read performance, one wants to keep the - WAL as small as possible and hence run checkpoints frequently, perhaps - as often as every COMMIT. To maximize write performance, one wants to - amortize the cost of each checkpoint over as many writes as possible, - meaning that one wants to run checkpoints infrequently and let the WAL - grow as large as possible before each checkpoint. The decision of how - often to run checkpoints may therefore vary from one application to - another depending on the relative read and write performance - requirements of the application. 
The default strategy is to run a - checkpoint once the WAL reaches 1000 pages" - """ - logger.debug("sqlcipher: setting write-ahead logging") - db_handle.cursor().execute('PRAGMA journal_mode=WAL') - - # The optimum value can still use a little bit of tuning, but we favor - # small sizes of the WAL file to get fast reads, since we assume that - # the writes will be quick enough to not block too much. - - db_handle.cursor().execute('PRAGMA wal_autocheckpoint=50') - - -class NotAnHexString(Exception): - """ - Raised when trying to (raw) key the database with a non-hex string. - """ - pass diff --git a/client/src/leap/soledad/client/sqlcipher.py b/client/src/leap/soledad/client/sqlcipher.py deleted file mode 100644 index 14fecd3b..00000000 --- a/client/src/leap/soledad/client/sqlcipher.py +++ /dev/null @@ -1,632 +0,0 @@ -# -*- coding: utf-8 -*- -# sqlcipher.py -# Copyright (C) 2013, 2014 LEAP -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -""" -A U1DB backend that uses SQLCipher as its persistence layer. - -The SQLCipher API (http://sqlcipher.net/sqlcipher-api/) is fully implemented, -with the exception of the following statements: - - * PRAGMA cipher_use_hmac - * PRAGMA cipher_default_use_mac - -SQLCipher 2.0 introduced a per-page HMAC to validate that the page data has -not be tampered with. By default, when creating or opening a database using -SQLCipher 2, SQLCipher will attempt to use an HMAC check. This change in -database format means that SQLCipher 2 can't operate on version 1.1.x -databases by default. Thus, in order to provide backward compatibility with -SQLCipher 1.1.x, PRAGMA cipher_use_hmac can be used to disable the HMAC -functionality on specific databases. - -In some very specific cases, it is not possible to call PRAGMA cipher_use_hmac -as one of the first operations on a database. An example of this is when -trying to ATTACH a 1.1.x database to the main database. In these cases PRAGMA -cipher_default_use_hmac can be used to globally alter the default use of HMAC -when opening a database. - -So, as the statements above were introduced for backwards compatibility with -SQLCipher 1.1 databases, we do not implement them as all SQLCipher databases -handled by Soledad should be created by SQLCipher >= 2.0. 
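
For reference, a minimal sketch of the per-connection initialization sequence that the pragmas module above implements, assuming pysqlcipher is importable and using the same defaults (passphrase keying, aes-256-cbc, 4000 KDF iterations, 1024-byte pages); the trailing SELECT only forces the key to be verified, in the same spirit as _assert_db_is_encrypted further below.

    from pysqlcipher import dbapi2 as sqlcipher_dbapi2

    def open_encrypted(path, passphrase):
        """Open `path` with `passphrase`, applying the init pragmas in order."""
        conn = sqlcipher_dbapi2.connect(path)
        c = conn.cursor()
        c.execute("PRAGMA key = '%s'" % passphrase)      # must come first
        c.execute("PRAGMA cipher = 'aes-256-cbc'")       # after key, before any I/O
        c.execute("PRAGMA kdf_iter = '4000'")
        c.execute("PRAGMA cipher_page_size = '1024'")
        c.execute('SELECT count(*) FROM sqlite_master')  # wrong key fails here
        return conn
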
-""" -import os -import sys - -from functools import partial - -from twisted.internet import reactor -from twisted.internet import defer -from twisted.enterprise import adbapi - -from leap.soledad.common.log import getLogger -from leap.soledad.common.l2db import errors as u1db_errors -from leap.soledad.common.l2db.backends import sqlite_backend -from leap.soledad.common.errors import DatabaseAccessError - -from leap.soledad.client.http_target import SoledadHTTPSyncTarget -from leap.soledad.client.sync import SoledadSynchronizer -from leap.soledad.client import pragmas -from leap.soledad.client._document import Document - -if sys.version_info[0] < 3: - from pysqlcipher import dbapi2 as sqlcipher_dbapi2 -else: - from pysqlcipher3 import dbapi2 as sqlcipher_dbapi2 - -logger = getLogger(__name__) - - -# Monkey-patch u1db.backends.sqlite_backend with pysqlcipher.dbapi2 -sqlite_backend.dbapi2 = sqlcipher_dbapi2 - - -# we may want to collect statistics from the sync process -DO_STATS = False -if os.environ.get('SOLEDAD_STATS'): - DO_STATS = True - - -def initialize_sqlcipher_db(opts, on_init=None, check_same_thread=True): - """ - Initialize a SQLCipher database. - - :param opts: - :type opts: SQLCipherOptions - :param on_init: a tuple of queries to be executed on initialization - :type on_init: tuple - :return: pysqlcipher.dbapi2.Connection - """ - # Note: There seemed to be a bug in sqlite 3.5.9 (with python2.6) - # where without re-opening the database on Windows, it - # doesn't see the transaction that was just committed - # Removing from here now, look at the pysqlite implementation if the - # bug shows up in windows. - - if not os.path.isfile(opts.path) and not opts.create: - raise u1db_errors.DatabaseDoesNotExist() - - conn = sqlcipher_dbapi2.connect( - opts.path, check_same_thread=check_same_thread) - pragmas.set_init_pragmas(conn, opts, extra_queries=on_init) - return conn - - -def initialize_sqlcipher_adbapi_db(opts, extra_queries=None): - from leap.soledad.client import sqlcipher_adbapi - return sqlcipher_adbapi.getConnectionPool( - opts, extra_queries=extra_queries) - - -class SQLCipherOptions(object): - """ - A container with options for the initialization of an SQLCipher database. - """ - - @classmethod - def copy(cls, source, path=None, key=None, create=None, - is_raw_key=None, cipher=None, kdf_iter=None, - cipher_page_size=None, sync_db_key=None): - """ - Return a copy of C{source} with parameters different than None - replaced by new values. - """ - local_vars = locals() - args = [] - kwargs = {} - - for name in ["path", "key"]: - val = local_vars[name] - if val is not None: - args.append(val) - else: - args.append(getattr(source, name)) - - for name in ["create", "is_raw_key", "cipher", "kdf_iter", - "cipher_page_size", "sync_db_key"]: - val = local_vars[name] - if val is not None: - kwargs[name] = val - else: - kwargs[name] = getattr(source, name) - - return SQLCipherOptions(*args, **kwargs) - - def __init__(self, path, key, create=True, is_raw_key=False, - cipher='aes-256-cbc', kdf_iter=4000, cipher_page_size=1024, - sync_db_key=None): - """ - :param path: The filesystem path for the database to open. - :type path: str - :param create: - True/False, should the database be created if it doesn't - already exist? - :param create: bool - :param is_raw_key: - Whether ``password`` is a raw 64-char hex string or a passphrase - that should be hashed to obtain the encyrption key. - :type raw_key: bool - :param cipher: The cipher and mode to use. 
- :type cipher: str - :param kdf_iter: The number of iterations to use. - :type kdf_iter: int - :param cipher_page_size: The page size. - :type cipher_page_size: int - """ - self.path = path - self.key = key - self.is_raw_key = is_raw_key - self.create = create - self.cipher = cipher - self.kdf_iter = kdf_iter - self.cipher_page_size = cipher_page_size - self.sync_db_key = sync_db_key - - def __str__(self): - """ - Return string representation of options, for easy debugging. - - :return: String representation of options. - :rtype: str - """ - attr_names = filter(lambda a: not a.startswith('_'), dir(self)) - attr_str = [] - for a in attr_names: - attr_str.append(a + "=" + str(getattr(self, a))) - name = self.__class__.__name__ - return "%s(%s)" % (name, ', '.join(attr_str)) - - -# -# The SQLCipher database -# - -class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase): - """ - A U1DB implementation that uses SQLCipher as its persistence layer. - """ - - # The attribute _index_storage_value will be used as the lookup key for the - # implementation of the SQLCipher storage backend. - _index_storage_value = 'expand referenced encrypted' - - def __init__(self, opts): - """ - Connect to an existing SQLCipher database, creating a new sqlcipher - database file if needed. - - *** IMPORTANT *** - - Don't forget to close the database after use by calling the close() - method otherwise some resources might not be freed and you may - experience several kinds of leakages. - - *** IMPORTANT *** - - :param opts: options for initialization of the SQLCipher database. - :type opts: SQLCipherOptions - """ - # ensure the db is encrypted if the file already exists - if os.path.isfile(opts.path): - self._db_handle = _assert_db_is_encrypted(opts) - else: - # connect to the sqlcipher database - self._db_handle = initialize_sqlcipher_db(opts) - - # TODO --------------------------------------------------- - # Everything else in this initialization has to be factored - # out, so it can be used from SoledadSQLCipherWrapper.__init__ - # too. - # --------------------------------------------------------- - - self._ensure_schema() - self.set_document_factory(doc_factory) - self._prime_replica_uid() - - def _prime_replica_uid(self): - """ - In the u1db implementation, _replica_uid is a property - that returns the value in _real_replica_uid, and does - a db query if no value found. - Here we prime the replica uid during initialization so - that we don't have to wait for the query afterwards. - """ - self._real_replica_uid = None - self._get_replica_uid() - - def _extra_schema_init(self, c): - """ - Add any extra fields, etc to the basic table definitions. - - This method is called by u1db.backends.sqlite_backend._initialize() - method, which is executed when the database schema is created. Here, - we use it to include the "syncable" property for LeapDocuments. - - :param c: The cursor for querying the database. - :type c: dbapi2.cursor - """ - c.execute( - 'ALTER TABLE document ' - 'ADD COLUMN syncable BOOL NOT NULL DEFAULT TRUE') - - # - # SQLCipher API methods - # - - # Extra query methods: extensions to the base u1db sqlite implmentation. - - def get_count_from_index(self, index_name, *key_values): - """ - Return the count for a given combination of index_name - and key values. - - Extension method made from similar methods in u1db version 13.09 - - :param index_name: The index to query - :type index_name: str - :param key_values: values to match. 
eg, if you have - an index with 3 fields then you would have: - get_from_index(index_name, val1, val2, val3) - :type key_values: tuple - :return: count. - :rtype: int - """ - c = self._db_handle.cursor() - definition = self._get_index_definition(index_name) - - if len(key_values) != len(definition): - raise u1db_errors.InvalidValueForIndex() - tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = ["d.doc_id = d%d.doc_id" - " AND d%d.field_name = ?" - % (i, i) for i in range(len(definition))] - exact_where = [novalue_where[i] + (" AND d%d.value = ?" % (i,)) - for i in range(len(definition))] - args = [] - where = [] - for idx, (field, value) in enumerate(zip(definition, key_values)): - args.append(field) - where.append(exact_where[idx]) - args.append(value) - - tables = ["document_fields d%d" % i for i in range(len(definition))] - statement = ( - "SELECT COUNT(*) FROM document d, %s WHERE %s " % ( - ', '.join(tables), - ' AND '.join(where), - )) - try: - c.execute(statement, tuple(args)) - except sqlcipher_dbapi2.OperationalError as e: - raise sqlcipher_dbapi2.OperationalError( - str(e) + '\nstatement: %s\nargs: %s\n' % (statement, args)) - res = c.fetchall() - return res[0][0] - - def close(self): - """ - Close db connections. - """ - # TODO should be handled by adbapi instead - # TODO syncdb should be stopped first - - if logger is not None: # logger might be none if called from __del__ - logger.debug("SQLCipher backend: closing") - - # close the actual database - if getattr(self, '_db_handle', False): - self._db_handle.close() - self._db_handle = None - - # indexes - - def _put_and_update_indexes(self, old_doc, doc): - """ - Update a document and all indexes related to it. - - :param old_doc: The old version of the document. - :type old_doc: u1db.Document - :param doc: The new version of the document. - :type doc: u1db.Document - """ - sqlite_backend.SQLitePartialExpandDatabase._put_and_update_indexes( - self, old_doc, doc) - c = self._db_handle.cursor() - c.execute('UPDATE document SET syncable=? WHERE doc_id=?', - (doc.syncable, doc.doc_id)) - - def _get_doc(self, doc_id, check_for_conflicts=False): - """ - Get just the document content, without fancy handling. - - :param doc_id: The unique document identifier - :type doc_id: str - :param include_deleted: If set to True, deleted documents will be - returned with empty content. Otherwise asking for a deleted - document will return None. - :type include_deleted: bool - - :return: a Document object. - :type: u1db.Document - """ - doc = sqlite_backend.SQLitePartialExpandDatabase._get_doc( - self, doc_id, check_for_conflicts) - if doc: - c = self._db_handle.cursor() - c.execute('SELECT syncable FROM document WHERE doc_id=?', - (doc.doc_id,)) - result = c.fetchone() - doc.syncable = bool(result[0]) - return doc - - def __del__(self): - """ - Free resources when deleting or garbage collecting the database. - - This is only here to minimze problems if someone ever forgets to call - the close() method after using the database; you should not rely on - garbage collecting to free up the database resources. - """ - self.close() - - -class SQLCipherU1DBSync(SQLCipherDatabase): - """ - Soledad syncer implementation. - """ - - """ - The name of the local symmetrically encrypted documents to - sync database file. - """ - LOCAL_SYMMETRIC_SYNC_FILE_NAME = 'sync.u1db' - - """ - Period or recurrence of the Looping Call that will do the encryption to the - syncdb (in seconds). 
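
The syncable handling above amounts to one extra column on the document table, written by _put_and_update_indexes and read back by _get_doc. A reduced sketch with the standard sqlite3 module (standing in for the SQLCipher connection, keeping only the columns needed here, and assuming a recent SQLite that accepts the TRUE keyword in DEFAULT clauses):

    import sqlite3

    conn = sqlite3.connect(':memory:')
    c = conn.cursor()
    c.execute('CREATE TABLE document (doc_id TEXT PRIMARY KEY, doc_rev TEXT, content TEXT)')
    c.execute('ALTER TABLE document ADD COLUMN syncable BOOL NOT NULL DEFAULT TRUE')
    c.execute("INSERT INTO document (doc_id, doc_rev, content) VALUES ('doc-1', 'rev-1', '{}')")
    c.execute('UPDATE document SET syncable=? WHERE doc_id=?', (False, 'doc-1'))
    c.execute('SELECT syncable FROM document WHERE doc_id=?', ('doc-1',))
    assert bool(c.fetchone()[0]) is False
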
- """ - ENCRYPT_LOOP_PERIOD = 1 - - def __init__(self, opts, soledad_crypto, replica_uid, cert_file): - self._opts = opts - self._path = opts.path - self._crypto = soledad_crypto - self.__replica_uid = replica_uid - self._cert_file = cert_file - - # storage for the documents received during a sync - self.received_docs = [] - - self.running = False - self._db_handle = None - - # initialize the main db before scheduling a start - self._initialize_main_db() - self._reactor = reactor - self._reactor.callWhenRunning(self._start) - - if DO_STATS: - self.sync_phase = None - - def commit(self): - self._db_handle.commit() - - @property - def _replica_uid(self): - return str(self.__replica_uid) - - def _start(self): - if not self.running: - self.running = True - - def _initialize_main_db(self): - try: - self._db_handle = initialize_sqlcipher_db( - self._opts, check_same_thread=False) - self._real_replica_uid = None - self._ensure_schema() - self.set_document_factory(doc_factory) - except sqlcipher_dbapi2.DatabaseError as e: - raise DatabaseAccessError(str(e)) - - @defer.inlineCallbacks - def sync(self, url, creds=None): - """ - Synchronize documents with remote replica exposed at url. - - It is not safe to initiate more than one sync process and let them run - concurrently. It is responsibility of the caller to ensure that there - are no concurrent sync processes running. This is currently controlled - by the main Soledad object because it may also run post-sync hooks, - which should be run while the lock is locked. - - :param url: The url of the target replica to sync with. - :type url: str - :param creds: optional dictionary giving credentials to authorize the - operation with the server. - :type creds: dict - - :return: - A Deferred, that will fire with the local generation (type `int`) - before the synchronisation was performed. - :rtype: Deferred - """ - syncer = self._get_syncer(url, creds=creds) - if DO_STATS: - self.sync_phase = syncer.sync_phase - self.syncer = syncer - self.sync_exchange_phase = syncer.sync_exchange_phase - local_gen_before_sync = yield syncer.sync() - self.received_docs = syncer.received_docs - defer.returnValue(local_gen_before_sync) - - def _get_syncer(self, url, creds=None): - """ - Get a synchronizer for ``url`` using ``creds``. - - :param url: The url of the target replica to sync with. - :type url: str - :param creds: optional dictionary giving credentials. - to authorize the operation with the server. - :type creds: dict - - :return: A synchronizer. - :rtype: Synchronizer - """ - return SoledadSynchronizer( - self, - SoledadHTTPSyncTarget( - url, - # XXX is the replica_uid ready? - self._replica_uid, - creds=creds, - crypto=self._crypto, - cert_file=self._cert_file)) - - # - # Symmetric encryption of syncing docs - # - - def get_generation(self): - # FIXME - # XXX this SHOULD BE a callback - return self._get_generation() - - -class U1DBSQLiteBackend(sqlite_backend.SQLitePartialExpandDatabase): - """ - A very simple wrapper for u1db around sqlcipher backend. - - Instead of initializing the database on the fly, it just uses an existing - connection that is passed to it in the initializer. - - It can be used in tests and debug runs to initialize the adbapi with plain - sqlite connections, decoupled from the sqlcipher layer. 
- """ - - def __init__(self, conn): - self._db_handle = conn - self._real_replica_uid = None - self._ensure_schema() - self._factory = Document - - -class SoledadSQLCipherWrapper(SQLCipherDatabase): - """ - A wrapper for u1db that uses the Soledad-extended sqlcipher backend. - - Instead of initializing the database on the fly, it just uses an existing - connection that is passed to it in the initializer. - - It can be used from adbapi to initialize a soledad database after - getting a regular connection to a sqlcipher database. - """ - def __init__(self, conn, opts): - self._db_handle = conn - self._real_replica_uid = None - self._ensure_schema() - self.set_document_factory(doc_factory) - self._prime_replica_uid() - - -def _assert_db_is_encrypted(opts): - """ - Assert that the sqlcipher file contains an encrypted database. - - When opening an existing database, PRAGMA key will not immediately - throw an error if the key provided is incorrect. To test that the - database can be successfully opened with the provided key, it is - necessary to perform some operation on the database (i.e. read from - it) and confirm it is success. - - The easiest way to do this is select off the sqlite_master table, - which will attempt to read the first page of the database and will - parse the schema. - - :param opts: - """ - # We try to open an encrypted database with the regular u1db - # backend should raise a DatabaseError exception. - # If the regular backend succeeds, then we need to stop because - # the database was not properly initialized. - try: - sqlite_backend.SQLitePartialExpandDatabase(opts.path) - except sqlcipher_dbapi2.DatabaseError: - # assert that we can access it using SQLCipher with the given - # key - dummy_query = ('SELECT count(*) FROM sqlite_master',) - return initialize_sqlcipher_db(opts, on_init=dummy_query) - else: - raise DatabaseIsNotEncrypted() - -# -# Exceptions -# - - -class DatabaseIsNotEncrypted(Exception): - """ - Exception raised when trying to open non-encrypted databases. - """ - pass - - -def doc_factory(doc_id=None, rev=None, json='{}', has_conflicts=False, - syncable=True): - """ - Return a default Soledad Document. - Used in the initialization for SQLCipherDatabase - """ - return Document(doc_id=doc_id, rev=rev, json=json, - has_conflicts=has_conflicts, syncable=syncable) - - -sqlite_backend.SQLiteDatabase.register_implementation(SQLCipherDatabase) - - -# -# twisted.enterprise.adbapi SQLCipher implementation -# - -SQLCIPHER_CONNECTION_TIMEOUT = 10 - - -def getConnectionPool(opts, extra_queries=None): - openfun = partial( - pragmas.set_init_pragmas, - opts=opts, - extra_queries=extra_queries) - return SQLCipherConnectionPool( - database=opts.path, - check_same_thread=False, - cp_openfun=openfun, - timeout=SQLCIPHER_CONNECTION_TIMEOUT) - - -class SQLCipherConnection(adbapi.Connection): - pass - - -class SQLCipherTransaction(adbapi.Transaction): - pass - - -class SQLCipherConnectionPool(adbapi.ConnectionPool): - - connectionFactory = SQLCipherConnection - transactionFactory = SQLCipherTransaction - - def __init__(self, *args, **kwargs): - adbapi.ConnectionPool.__init__( - self, "pysqlcipher.dbapi2", *args, **kwargs) diff --git a/common/src/leap/soledad/common/l2db/__init__.py b/common/src/leap/soledad/common/l2db/__init__.py index 568897c4..ebbf546a 100644 --- a/common/src/leap/soledad/common/l2db/__init__.py +++ b/common/src/leap/soledad/common/l2db/__init__.py @@ -37,8 +37,8 @@ def open(path, create, document_factory=None): parameters as Document.__init__. 
:return: An instance of Database. """ - from leap.soledad.common.l2db.backends import sqlite_backend - return sqlite_backend.SQLiteDatabase.open_database( + from leap.soledad.client._database import sqlite + return sqlite.SQLiteDatabase.open_database( path, create=create, document_factory=document_factory) diff --git a/common/src/leap/soledad/common/l2db/backends/dbschema.sql b/common/src/leap/soledad/common/l2db/backends/dbschema.sql deleted file mode 100644 index ae027fc5..00000000 --- a/common/src/leap/soledad/common/l2db/backends/dbschema.sql +++ /dev/null @@ -1,42 +0,0 @@ --- Database schema -CREATE TABLE transaction_log ( - generation INTEGER PRIMARY KEY AUTOINCREMENT, - doc_id TEXT NOT NULL, - transaction_id TEXT NOT NULL -); -CREATE TABLE document ( - doc_id TEXT PRIMARY KEY, - doc_rev TEXT NOT NULL, - content TEXT -); -CREATE TABLE document_fields ( - doc_id TEXT NOT NULL, - field_name TEXT NOT NULL, - value TEXT -); -CREATE INDEX document_fields_field_value_doc_idx - ON document_fields(field_name, value, doc_id); - -CREATE TABLE sync_log ( - replica_uid TEXT PRIMARY KEY, - known_generation INTEGER, - known_transaction_id TEXT -); -CREATE TABLE conflicts ( - doc_id TEXT, - doc_rev TEXT, - content TEXT, - CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev) -); -CREATE TABLE index_definitions ( - name TEXT, - offset INT, - field TEXT, - CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset) -); -create index index_definitions_field on index_definitions(field); -CREATE TABLE u1db_config ( - name TEXT PRIMARY KEY, - value TEXT -); -INSERT INTO u1db_config VALUES ('sql_schema', '0'); diff --git a/common/src/leap/soledad/common/l2db/backends/sqlite_backend.py b/common/src/leap/soledad/common/l2db/backends/sqlite_backend.py deleted file mode 100644 index 4f7b1259..00000000 --- a/common/src/leap/soledad/common/l2db/backends/sqlite_backend.py +++ /dev/null @@ -1,930 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# Copyright 2016 LEAP Encryption Access Project -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -""" -A L2DB implementation that uses SQLite as its persistence layer. 
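
After this change, the generic l2db.open() entry point routes to the relocated backend. A short usage sketch, with an example path and assuming the leap.soledad packages are installed:

    from leap.soledad.common import l2db

    # Equivalent, after this patch, to:
    #   from leap.soledad.client._database import sqlite
    #   sqlite.SQLiteDatabase.open_database('/tmp/example.u1db', create=True)
    db = l2db.open('/tmp/example.u1db', create=True)
    doc = db.create_doc({'k': 'v'})
    db.close()
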
-""" - -import errno -import os -import json -import sys -import time -import uuid -import pkg_resources - -from sqlite3 import dbapi2 - -from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget -from leap.soledad.common.l2db import ( - Document, errors, - query_parser, vectorclock) - - -class SQLiteDatabase(CommonBackend): - """A U1DB implementation that uses SQLite as its persistence layer.""" - - _sqlite_registry = {} - - def __init__(self, sqlite_file, document_factory=None): - """Create a new sqlite file.""" - self._db_handle = dbapi2.connect(sqlite_file) - self._real_replica_uid = None - self._ensure_schema() - self._factory = document_factory or Document - - def set_document_factory(self, factory): - self._factory = factory - - def get_sync_target(self): - return SQLiteSyncTarget(self) - - @classmethod - def _which_index_storage(cls, c): - try: - c.execute("SELECT value FROM u1db_config" - " WHERE name = 'index_storage'") - except dbapi2.OperationalError as e: - # The table does not exist yet - return None, e - else: - return c.fetchone()[0], None - - WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5 - - @classmethod - def _open_database(cls, sqlite_file, document_factory=None): - if not os.path.isfile(sqlite_file): - raise errors.DatabaseDoesNotExist() - tries = 2 - while True: - # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6) - # where without re-opening the database on Windows, it - # doesn't see the transaction that was just committed - db_handle = dbapi2.connect(sqlite_file) - c = db_handle.cursor() - v, err = cls._which_index_storage(c) - db_handle.close() - if v is not None: - break - # possibly another process is initializing it, wait for it to be - # done - if tries == 0: - raise err # go for the richest error? - tries -= 1 - time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL) - return SQLiteDatabase._sqlite_registry[v]( - sqlite_file, document_factory=document_factory) - - @classmethod - def open_database(cls, sqlite_file, create, backend_cls=None, - document_factory=None): - try: - return cls._open_database( - sqlite_file, document_factory=document_factory) - except errors.DatabaseDoesNotExist: - if not create: - raise - if backend_cls is None: - # default is SQLitePartialExpandDatabase - backend_cls = SQLitePartialExpandDatabase - return backend_cls(sqlite_file, document_factory=document_factory) - - @staticmethod - def delete_database(sqlite_file): - try: - os.unlink(sqlite_file) - except OSError as ex: - if ex.errno == errno.ENOENT: - raise errors.DatabaseDoesNotExist() - raise - - @staticmethod - def register_implementation(klass): - """Register that we implement an SQLiteDatabase. - - The attribute _index_storage_value will be used as the lookup key. - """ - SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass - - def _get_sqlite_handle(self): - """Get access to the underlying sqlite database. - - This should only be used by the test suite, etc, for examining the - state of the underlying database. 
- """ - return self._db_handle - - def _close_sqlite_handle(self): - """Release access to the underlying sqlite database.""" - self._db_handle.close() - - def close(self): - self._close_sqlite_handle() - - def _is_initialized(self, c): - """Check if this database has been initialized.""" - c.execute("PRAGMA case_sensitive_like=ON") - try: - c.execute("SELECT value FROM u1db_config" - " WHERE name = 'sql_schema'") - except dbapi2.OperationalError: - # The table does not exist yet - val = None - else: - val = c.fetchone() - if val is not None: - return True - return False - - def _initialize(self, c): - """Create the schema in the database.""" - # read the script with sql commands - # TODO: Change how we set up the dependency. Most likely use something - # like lp:dirspec to grab the file from a common resource - # directory. Doesn't specifically need to be handled until we get - # to the point of packaging this. - schema_content = pkg_resources.resource_string( - __name__, 'dbschema.sql') - # Note: We'd like to use c.executescript() here, but it seems that - # executescript always commits, even if you set - # isolation_level = None, so if we want to properly handle - # exclusive locking and rollbacks between processes, we need - # to execute it line-by-line - for line in schema_content.split(';'): - if not line: - continue - c.execute(line) - # add extra fields - self._extra_schema_init(c) - # A unique identifier should be set for this replica. Implementations - # don't have to strictly use uuid here, but we do want the uid to be - # unique amongst all databases that will sync with each other. - # We might extend this to using something with hostname for easier - # debugging. - self._set_replica_uid_in_transaction(uuid.uuid4().hex) - c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)", - (self._index_storage_value,)) - - def _ensure_schema(self): - """Ensure that the database schema has been created.""" - old_isolation_level = self._db_handle.isolation_level - c = self._db_handle.cursor() - if self._is_initialized(c): - return - try: - # autocommit/own mgmt of transactions - self._db_handle.isolation_level = None - with self._db_handle: - # only one execution path should initialize the db - c.execute("begin exclusive") - if self._is_initialized(c): - return - self._initialize(c) - finally: - self._db_handle.isolation_level = old_isolation_level - - def _extra_schema_init(self, c): - """Add any extra fields, etc to the basic table definitions.""" - - def _parse_index_definition(self, index_field): - """Parse a field definition for an index, returning a Getter.""" - # Note: We may want to keep a Parser object around, and cache the - # Getter objects for a greater length of time. Specifically, if - # you create a bunch of indexes, and then insert 50k docs, you'll - # re-parse the indexes between puts. The time to insert the docs - # is still likely to dominate put_doc time, though. - parser = query_parser.Parser() - getter = parser.parse(index_field) - return getter - - def _update_indexes(self, doc_id, raw_doc, getters, db_cursor): - """Update document_fields for a single document. - - :param doc_id: Identifier for this document - :param raw_doc: The python dict representation of the document. - :param getters: A list of [(field_name, Getter)]. Getter.get will be - called to evaluate the index definition for this document, and the - results will be inserted into the db. - :param db_cursor: An sqlite Cursor. 
- :return: None - """ - values = [] - for field_name, getter in getters: - for idx_value in getter.get(raw_doc): - values.append((doc_id, field_name, idx_value)) - if values: - db_cursor.executemany( - "INSERT INTO document_fields VALUES (?, ?, ?)", values) - - def _set_replica_uid(self, replica_uid): - """Force the replica_uid to be set.""" - with self._db_handle: - self._set_replica_uid_in_transaction(replica_uid) - - def _set_replica_uid_in_transaction(self, replica_uid): - """Set the replica_uid. A transaction should already be held.""" - c = self._db_handle.cursor() - c.execute("INSERT OR REPLACE INTO u1db_config" - " VALUES ('replica_uid', ?)", - (replica_uid,)) - self._real_replica_uid = replica_uid - - def _get_replica_uid(self): - if self._real_replica_uid is not None: - return self._real_replica_uid - c = self._db_handle.cursor() - c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'") - val = c.fetchone() - if val is None: - return None - self._real_replica_uid = val[0] - return self._real_replica_uid - - _replica_uid = property(_get_replica_uid) - - def _get_generation(self): - c = self._db_handle.cursor() - c.execute('SELECT max(generation) FROM transaction_log') - val = c.fetchone()[0] - if val is None: - return 0 - return val - - def _get_generation_info(self): - c = self._db_handle.cursor() - c.execute( - 'SELECT max(generation), transaction_id FROM transaction_log ') - val = c.fetchone() - if val[0] is None: - return(0, '') - return val - - def _get_trans_id_for_gen(self, generation): - if generation == 0: - return '' - c = self._db_handle.cursor() - c.execute( - 'SELECT transaction_id FROM transaction_log WHERE generation = ?', - (generation,)) - val = c.fetchone() - if val is None: - raise errors.InvalidGeneration - return val[0] - - def _get_transaction_log(self): - c = self._db_handle.cursor() - c.execute("SELECT doc_id, transaction_id FROM transaction_log" - " ORDER BY generation") - return c.fetchall() - - def _get_doc(self, doc_id, check_for_conflicts=False): - """Get just the document content, without fancy handling.""" - c = self._db_handle.cursor() - if check_for_conflicts: - c.execute( - "SELECT document.doc_rev, document.content, " - "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN " - "conflicts ON conflicts.doc_id = document.doc_id WHERE " - "document.doc_id = ? GROUP BY document.doc_id, " - "document.doc_rev, document.content;", (doc_id,)) - else: - c.execute( - "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?", - (doc_id,)) - val = c.fetchone() - if val is None: - return None - doc_rev, content, conflicts = val - doc = self._factory(doc_id, doc_rev, content) - doc.has_conflicts = conflicts > 0 - return doc - - def _has_conflicts(self, doc_id): - c = self._db_handle.cursor() - c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? 
LIMIT 1", - (doc_id,)) - val = c.fetchone() - if val is None: - return False - else: - return True - - def get_doc(self, doc_id, include_deleted=False): - doc = self._get_doc(doc_id, check_for_conflicts=True) - if doc is None: - return None - if doc.is_tombstone() and not include_deleted: - return None - return doc - - def get_all_docs(self, include_deleted=False): - """Get all documents from the database.""" - generation = self._get_generation() - results = [] - c = self._db_handle.cursor() - c.execute( - "SELECT document.doc_id, document.doc_rev, document.content, " - "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts " - "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, " - "document.doc_rev, document.content;") - rows = c.fetchall() - for doc_id, doc_rev, content, conflicts in rows: - if content is None and not include_deleted: - continue - doc = self._factory(doc_id, doc_rev, content) - doc.has_conflicts = conflicts > 0 - results.append(doc) - return (generation, results) - - def put_doc(self, doc): - if doc.doc_id is None: - raise errors.InvalidDocId() - self._check_doc_id(doc.doc_id) - self._check_doc_size(doc) - with self._db_handle: - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc and old_doc.has_conflicts: - raise errors.ConflictedDoc() - if old_doc and doc.rev is None and old_doc.is_tombstone(): - new_rev = self._allocate_doc_rev(old_doc.rev) - else: - if old_doc is not None: - if old_doc.rev != doc.rev: - raise errors.RevisionConflict() - else: - if doc.rev is not None: - raise errors.RevisionConflict() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - self._put_and_update_indexes(old_doc, doc) - return new_rev - - def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none): - """Convert a dict representation into named fields. - - So something like: {'key1': 'val1', 'key2': 'val2'} - gets converted into: [(doc_id, 'key1', 'val1', 0) - (doc_id, 'key2', 'val2', 0)] - :param doc_id: Just added to every record. - :param base_field: if set, these are nested keys, so each field should - be appropriately prefixed. - :param raw_doc: The python dictionary. - """ - # TODO: Handle lists - values = [] - for field_name, value in raw_doc.iteritems(): - if value is None and not save_none: - continue - if base_field: - full_name = base_field + '.' + field_name - else: - full_name = field_name - if value is None or isinstance(value, (int, float, basestring)): - values.append((doc_id, full_name, value, len(values))) - else: - subvalues = self._expand_to_fields(doc_id, full_name, value, - save_none) - for _, subfield_name, val, _ in subvalues: - values.append((doc_id, subfield_name, val, len(values))) - return values - - def _put_and_update_indexes(self, old_doc, doc): - """Actually insert a document into the database. - - This both updates the existing documents content, and any indexes that - refer to this document. - """ - raise NotImplementedError(self._put_and_update_indexes) - - def whats_changed(self, old_generation=0): - c = self._db_handle.cursor() - c.execute("SELECT generation, doc_id, transaction_id" - " FROM transaction_log" - " WHERE generation > ? 
ORDER BY generation DESC", - (old_generation,)) - results = c.fetchall() - cur_gen = old_generation - seen = set() - changes = [] - newest_trans_id = '' - for generation, doc_id, trans_id in results: - if doc_id not in seen: - changes.append((doc_id, generation, trans_id)) - seen.add(doc_id) - if changes: - cur_gen = changes[0][1] # max generation - newest_trans_id = changes[0][2] - changes.reverse() - else: - c.execute("SELECT generation, transaction_id" - " FROM transaction_log ORDER BY generation DESC LIMIT 1") - results = c.fetchone() - if not results: - cur_gen = 0 - newest_trans_id = '' - else: - cur_gen, newest_trans_id = results - - return cur_gen, newest_trans_id, changes - - def delete_doc(self, doc): - with self._db_handle: - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc is None: - raise errors.DocumentDoesNotExist - if old_doc.rev != doc.rev: - raise errors.RevisionConflict() - if old_doc.is_tombstone(): - raise errors.DocumentAlreadyDeleted - if old_doc.has_conflicts: - raise errors.ConflictedDoc() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - doc.make_tombstone() - self._put_and_update_indexes(old_doc, doc) - return new_rev - - def _get_conflicts(self, doc_id): - c = self._db_handle.cursor() - c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?", - (doc_id,)) - return [self._factory(doc_id, doc_rev, content) - for doc_rev, content in c.fetchall()] - - def get_doc_conflicts(self, doc_id): - with self._db_handle: - conflict_docs = self._get_conflicts(doc_id) - if not conflict_docs: - return [] - this_doc = self._get_doc(doc_id) - this_doc.has_conflicts = True - return [this_doc] + conflict_docs - - def _get_replica_gen_and_trans_id(self, other_replica_uid): - c = self._db_handle.cursor() - c.execute("SELECT known_generation, known_transaction_id FROM sync_log" - " WHERE replica_uid = ?", - (other_replica_uid,)) - val = c.fetchone() - if val is None: - other_gen = 0 - trans_id = '' - else: - other_gen = val[0] - trans_id = val[1] - return other_gen, trans_id - - def _set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - with self._db_handle: - self._do_set_replica_gen_and_trans_id( - other_replica_uid, other_generation, other_transaction_id) - - def _do_set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, - other_transaction_id): - c = self._db_handle.cursor() - c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)", - (other_replica_uid, other_generation, - other_transaction_id)) - - def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, - replica_gen=None, replica_trans_id=None): - return super(SQLiteDatabase, self)._put_doc_if_newer( - doc, - save_conflict=save_conflict, - replica_uid=replica_uid, replica_gen=replica_gen, - replica_trans_id=replica_trans_id) - - def _add_conflict(self, c, doc_id, my_doc_rev, my_content): - c.execute("INSERT INTO conflicts VALUES (?, ?, ?)", - (doc_id, my_doc_rev, my_content)) - - def _delete_conflicts(self, c, doc, conflict_revs): - deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs] - c.executemany("DELETE FROM conflicts" - " WHERE doc_id=? 
AND doc_rev=?", deleting) - doc.has_conflicts = self._has_conflicts(doc.doc_id) - - def _prune_conflicts(self, doc, doc_vcr): - if self._has_conflicts(doc.doc_id): - autoresolved = False - c_revs_to_prune = [] - for c_doc in self._get_conflicts(doc.doc_id): - c_vcr = vectorclock.VectorClockRev(c_doc.rev) - if doc_vcr.is_newer(c_vcr): - c_revs_to_prune.append(c_doc.rev) - elif doc.same_content_as(c_doc): - c_revs_to_prune.append(c_doc.rev) - doc_vcr.maximize(c_vcr) - autoresolved = True - if autoresolved: - doc_vcr.increment(self._replica_uid) - doc.rev = doc_vcr.as_str() - c = self._db_handle.cursor() - self._delete_conflicts(c, doc, c_revs_to_prune) - - def _force_doc_sync_conflict(self, doc): - my_doc = self._get_doc(doc.doc_id) - c = self._db_handle.cursor() - self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) - self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json()) - doc.has_conflicts = True - self._put_and_update_indexes(my_doc, doc) - - def resolve_doc(self, doc, conflicted_doc_revs): - with self._db_handle: - cur_doc = self._get_doc(doc.doc_id) - # TODO: https://bugs.launchpad.net/u1db/+bug/928274 - # I think we have a logic bug in resolve_doc - # Specifically, cur_doc.rev is always in the final vector - # clock of revisions that we supersede, even if it wasn't in - # conflicted_doc_revs. We still add it as a conflict, but the - # fact that _put_doc_if_newer propagates resolutions means I - # think that conflict could accidentally be resolved. We need - # to add a test for this case first. (create a rev, create a - # conflict, create another conflict, resolve the first rev - # and first conflict, then make sure that the resolved - # rev doesn't supersede the second conflict rev.) It *might* - # not matter, because the superseding rev is in as a - # conflict, but it does seem incorrect - new_rev = self._ensure_maximal_rev(cur_doc.rev, - conflicted_doc_revs) - superseded_revs = set(conflicted_doc_revs) - c = self._db_handle.cursor() - doc.rev = new_rev - if cur_doc.rev in superseded_revs: - self._put_and_update_indexes(cur_doc, doc) - else: - self._add_conflict(c, doc.doc_id, new_rev, doc.get_json()) - # TODO: Is there some way that we could construct a rev that would - # end up in superseded_revs, such that we add a conflict, and - # then immediately delete it? - self._delete_conflicts(c, doc, superseded_revs) - - def list_indexes(self): - """Return the list of indexes and their definitions.""" - c = self._db_handle.cursor() - # TODO: How do we test the ordering? - c.execute("SELECT name, field FROM index_definitions" - " ORDER BY name, offset") - definitions = [] - cur_name = None - for name, field in c.fetchall(): - if cur_name != name: - definitions.append((name, [])) - cur_name = name - definitions[-1][-1].append(field) - return definitions - - def _get_index_definition(self, index_name): - """Return the stored definition for a given index_name.""" - c = self._db_handle.cursor() - c.execute("SELECT field FROM index_definitions" - " WHERE name = ? ORDER BY offset", (index_name,)) - fields = [x[0] for x in c.fetchall()] - if not fields: - raise errors.IndexDoesNotExist - return fields - - @staticmethod - def _strip_glob(value): - """Remove the trailing * from a value.""" - assert value[-1] == '*' - return value[:-1] - - def _format_query(self, definition, key_values): - # First, build the definition. We join the document_fields table - # against itself, as many times as the 'width' of our definition. - # We then do a query for each key_value, one-at-a-time. 
- # Note: All of these strings are static, we could cache them, etc. - tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = ["d.doc_id = d%d.doc_id" - " AND d%d.field_name = ?" - % (i, i) for i in range(len(definition))] - wildcard_where = [novalue_where[i] + - (" AND d%d.value NOT NULL" % (i,)) - for i in range(len(definition))] - exact_where = [novalue_where[i] + - (" AND d%d.value = ?" % (i,)) - for i in range(len(definition))] - like_where = [novalue_where[i] + - (" AND d%d.value GLOB ?" % (i,)) - for i in range(len(definition))] - is_wildcard = False - # Merge the lists together, so that: - # [field1, field2, field3], [val1, val2, val3] - # Becomes: - # (field1, val1, field2, val2, field3, val3) - args = [] - where = [] - for idx, (field, value) in enumerate(zip(definition, key_values)): - args.append(field) - if value.endswith('*'): - if value == '*': - where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(like_where[idx]) - args.append(value) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(exact_where[idx]) - args.append(value) - statement = ( - "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " - "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " - "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " - "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( - ['d%d.value' % i for i in range(len(definition))]))) - return statement, args - - def get_from_index(self, index_name, *key_values): - definition = self._get_index_definition(index_name) - if len(key_values) != len(definition): - raise errors.InvalidValueForIndex() - statement, args = self._format_query(definition, key_values) - c = self._db_handle.cursor() - try: - c.execute(statement, tuple(args)) - except dbapi2.OperationalError as e: - raise dbapi2.OperationalError( - str(e) + - '\nstatement: %s\nargs: %s\n' % (statement, args)) - res = c.fetchall() - results = [] - for row in res: - doc = self._factory(row[0], row[1], row[2]) - doc.has_conflicts = row[3] > 0 - results.append(doc) - return results - - def _format_range_query(self, definition, start_value, end_value): - tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = [ - "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in - range(len(definition))] - wildcard_where = [ - novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in - range(len(definition))] - like_where = [ - novalue_where[i] + ( - " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in - range(len(definition))] - range_where_lower = [ - novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in - range(len(definition))] - range_where_upper = [ - novalue_where[i] + (" AND d%d.value <= ?" 
% (i,)) for i in - range(len(definition))] - args = [] - where = [] - if start_value: - if isinstance(start_value, basestring): - start_value = (start_value,) - if len(start_value) != len(definition): - raise errors.InvalidValueForIndex() - is_wildcard = False - for idx, (field, value) in enumerate(zip(definition, start_value)): - args.append(field) - if value.endswith('*'): - if value == '*': - where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(range_where_lower[idx]) - args.append(self._strip_glob(value)) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(range_where_lower[idx]) - args.append(value) - if end_value: - if isinstance(end_value, basestring): - end_value = (end_value,) - if len(end_value) != len(definition): - raise errors.InvalidValueForIndex() - is_wildcard = False - for idx, (field, value) in enumerate(zip(definition, end_value)): - args.append(field) - if value.endswith('*'): - if value == '*': - where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(like_where[idx]) - args.append(self._strip_glob(value)) - args.append(value) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(range_where_upper[idx]) - args.append(value) - statement = ( - "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " - "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " - "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " - "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( - ['d%d.value' % i for i in range(len(definition))]))) - return statement, args - - def get_range_from_index(self, index_name, start_value=None, - end_value=None): - """Return all documents with key values in the specified range.""" - definition = self._get_index_definition(index_name) - statement, args = self._format_range_query( - definition, start_value, end_value) - c = self._db_handle.cursor() - try: - c.execute(statement, tuple(args)) - except dbapi2.OperationalError as e: - raise dbapi2.OperationalError( - str(e) + - '\nstatement: %s\nargs: %s\n' % (statement, args)) - res = c.fetchall() - results = [] - for row in res: - doc = self._factory(row[0], row[1], row[2]) - doc.has_conflicts = row[3] > 0 - results.append(doc) - return results - - def get_index_keys(self, index_name): - c = self._db_handle.cursor() - definition = self._get_index_definition(index_name) - value_fields = ', '.join([ - 'd%d.value' % i for i in range(len(definition))]) - tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = [ - "d.doc_id = d%d.doc_id AND d%d.field_name = ?" 
% (i, i) for i in - range(len(definition))] - where = [ - novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in - range(len(definition))] - statement = ( - "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( - value_fields, ', '.join(tables), ' AND '.join(where), - value_fields)) - try: - c.execute(statement, tuple(definition)) - except dbapi2.OperationalError as e: - raise dbapi2.OperationalError( - str(e) + - '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) - return c.fetchall() - - def delete_index(self, index_name): - with self._db_handle: - c = self._db_handle.cursor() - c.execute("DELETE FROM index_definitions WHERE name = ?", - (index_name,)) - c.execute( - "DELETE FROM document_fields WHERE document_fields.field_name " - " NOT IN (SELECT field from index_definitions)") - - -class SQLiteSyncTarget(CommonSyncTarget): - - def get_sync_info(self, source_replica_uid): - source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( - source_replica_uid) - my_gen, my_trans_id = self._db._get_generation_info() - return ( - self._db._replica_uid, my_gen, my_trans_id, source_gen, - source_trans_id) - - def record_sync_info(self, source_replica_uid, source_replica_generation, - source_replica_transaction_id): - if self._trace_hook: - self._trace_hook('record_sync_info') - self._db._set_replica_gen_and_trans_id( - source_replica_uid, source_replica_generation, - source_replica_transaction_id) - - -class SQLitePartialExpandDatabase(SQLiteDatabase): - """An SQLite Backend that expands documents into a document_field table. - - It stores the original document text in document.doc. For fields that are - indexed, the data goes into document_fields. - """ - - _index_storage_value = 'expand referenced' - - def _get_indexed_fields(self): - """Determine what fields are indexed.""" - c = self._db_handle.cursor() - c.execute("SELECT field FROM index_definitions") - return set([x[0] for x in c.fetchall()]) - - def _evaluate_index(self, raw_doc, field): - parser = query_parser.Parser() - getter = parser.parse(field) - return getter.get(raw_doc) - - def _put_and_update_indexes(self, old_doc, doc): - c = self._db_handle.cursor() - if doc and not doc.is_tombstone(): - raw_doc = json.loads(doc.get_json()) - else: - raw_doc = {} - if old_doc is not None: - c.execute("UPDATE document SET doc_rev=?, content=?" 
- " WHERE doc_id = ?", - (doc.rev, doc.get_json(), doc.doc_id)) - c.execute("DELETE FROM document_fields WHERE doc_id = ?", - (doc.doc_id,)) - else: - c.execute("INSERT INTO document (doc_id, doc_rev, content)" - " VALUES (?, ?, ?)", - (doc.doc_id, doc.rev, doc.get_json())) - indexed_fields = self._get_indexed_fields() - if indexed_fields: - # It is expected that len(indexed_fields) is shorter than - # len(raw_doc) - getters = [(field, self._parse_index_definition(field)) - for field in indexed_fields] - self._update_indexes(doc.doc_id, raw_doc, getters, c) - trans_id = self._allocate_transaction_id() - c.execute("INSERT INTO transaction_log(doc_id, transaction_id)" - " VALUES (?, ?)", (doc.doc_id, trans_id)) - - def create_index(self, index_name, *index_expressions): - with self._db_handle: - c = self._db_handle.cursor() - cur_fields = self._get_indexed_fields() - definition = [(index_name, idx, field) - for idx, field in enumerate(index_expressions)] - try: - c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", - definition) - except dbapi2.IntegrityError as e: - stored_def = self._get_index_definition(index_name) - if stored_def == [x[-1] for x in definition]: - return - raise errors.IndexNameTakenError( - str(e) + - str(sys.exc_info()[2]) - ) - new_fields = set( - [f for f in index_expressions if f not in cur_fields]) - if new_fields: - self._update_all_indexes(new_fields) - - def _iter_all_docs(self): - c = self._db_handle.cursor() - c.execute("SELECT doc_id, content FROM document") - while True: - next_rows = c.fetchmany() - if not next_rows: - break - for row in next_rows: - yield row - - def _update_all_indexes(self, new_fields): - """Iterate all the documents, and add content to document_fields. - - :param new_fields: The index definitions that need to be added. - """ - getters = [(field, self._parse_index_definition(field)) - for field in new_fields] - c = self._db_handle.cursor() - for doc_id, doc in self._iter_all_docs(): - if doc is None: - continue - raw_doc = json.loads(doc) - self._update_indexes(doc_id, raw_doc, getters, c) - - -SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase) diff --git a/common/src/leap/soledad/common/l2db/remote/server_state.py b/common/src/leap/soledad/common/l2db/remote/server_state.py index e20b4679..89ac5742 100644 --- a/common/src/leap/soledad/common/l2db/remote/server_state.py +++ b/common/src/leap/soledad/common/l2db/remote/server_state.py @@ -42,10 +42,9 @@ class ServerState(object): def open_database(self, path): """Open a database at the given location.""" - from u1db.backends import sqlite_backend + from leap.soledad.client._database import sqlite full_path = self._relpath(path) - return sqlite_backend.SQLiteDatabase.open_database(full_path, - create=False) + return sqlite.SQLiteDatabase.open_database(full_path, create=False) def check_database(self, path): """Check if the database at the given location exists. 
@@ -57,14 +56,13 @@ class ServerState(object): def ensure_database(self, path): """Ensure database at the given location.""" - from u1db.backends import sqlite_backend + from leap.soledad.client._database import sqlite full_path = self._relpath(path) - db = sqlite_backend.SQLiteDatabase.open_database(full_path, - create=True) + db = sqlite.SQLiteDatabase.open_database(full_path, create=True) return db, db._replica_uid def delete_database(self, path): """Delete database at the given location.""" - from u1db.backends import sqlite_backend + from leap.soledad.client._database import sqlite full_path = self._relpath(path) - sqlite_backend.SQLiteDatabase.delete_database(full_path) + sqlite.SQLiteDatabase.delete_database(full_path) diff --git a/testing/test_soledad/u1db_tests/__init__.py b/testing/test_soledad/u1db_tests/__init__.py index 1575b859..ccbb6ca6 100644 --- a/testing/test_soledad/u1db_tests/__init__.py +++ b/testing/test_soledad/u1db_tests/__init__.py @@ -37,11 +37,12 @@ from twisted.internet import reactor from leap.soledad.common.l2db import errors from leap.soledad.common.l2db import Document from leap.soledad.common.l2db.backends import inmemory -from leap.soledad.common.l2db.backends import sqlite_backend from leap.soledad.common.l2db.remote import server_state from leap.soledad.common.l2db.remote import http_app from leap.soledad.common.l2db.remote import http_target +from leap.soledad.client._database import sqlite + if sys.version_info[0] < 3: from pysqlcipher import dbapi2 else: @@ -140,7 +141,7 @@ def copy_memory_database_for_test(test, db): def make_sqlite_partial_expanded_for_test(test, replica_uid): - db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + db = sqlite.SQLitePartialExpandDatabase(':memory:') db._set_replica_uid(replica_uid) return db @@ -151,7 +152,7 @@ def copy_sqlite_partial_expanded_for_test(test, db): # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR # HOUSE. 
- new_db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + new_db = sqlite.SQLitePartialExpandDatabase(':memory:') tmpfile = StringIO() for line in db._db_handle.iterdump(): if 'sqlite_sequence' not in line: # work around bug in iterdump diff --git a/testing/test_soledad/u1db_tests/test_open.py b/testing/test_soledad/u1db_tests/test_open.py index b572fba0..eaaee72f 100644 --- a/testing/test_soledad/u1db_tests/test_open.py +++ b/testing/test_soledad/u1db_tests/test_open.py @@ -27,7 +27,8 @@ from test_soledad.u1db_tests.test_backends import TestAlternativeDocument from leap.soledad.common.l2db import errors from leap.soledad.common.l2db import open as u1db_open -from leap.soledad.common.l2db.backends import sqlite_backend + +from leap.soledad.client._database import sqlite @skip("Skiping tests imported from U1DB.") @@ -47,7 +48,7 @@ class TestU1DBOpen(tests.TestCase): db = u1db_open(self.db_path, create=True) self.addCleanup(db.close) self.assertTrue(os.path.exists(self.db_path)) - self.assertIsInstance(db, sqlite_backend.SQLiteDatabase) + self.assertIsInstance(db, sqlite.SQLiteDatabase) def test_open_with_factory(self): db = u1db_open(self.db_path, create=True, @@ -56,7 +57,7 @@ class TestU1DBOpen(tests.TestCase): self.assertEqual(TestAlternativeDocument, db._factory) def test_open_existing(self): - db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path) + db = sqlite.SQLitePartialExpandDatabase(self.db_path) self.addCleanup(db.close) doc = db.create_doc_from_json(tests.simple_doc) # Even though create=True, we shouldn't wipe the db @@ -66,8 +67,8 @@ class TestU1DBOpen(tests.TestCase): self.assertEqual(doc, doc2) def test_open_existing_no_create(self): - db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path) + db = sqlite.SQLitePartialExpandDatabase(self.db_path) self.addCleanup(db.close) db2 = u1db_open(self.db_path, create=False) self.addCleanup(db2.close) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + self.assertIsInstance(db2, sqlite.SQLitePartialExpandDatabase) diff --git a/testing/test_soledad/util.py b/testing/test_soledad/util.py index 6ffb60b6..0335d544 100644 --- a/testing/test_soledad/util.py +++ b/testing/test_soledad/util.py @@ -46,9 +46,9 @@ from leap.soledad.common.couch.state import CouchServerState from leap.soledad.client import Soledad from leap.soledad.client import http_target from leap.soledad.client import auth -from leap.soledad.client.sqlcipher import SQLCipherDatabase -from leap.soledad.client.sqlcipher import SQLCipherOptions from leap.soledad.client._crypto import is_symmetrically_encrypted +from leap.soledad.client._database.sqlcipher import SQLCipherDatabase +from leap.soledad.client._database.sqlcipher import SQLCipherOptions from leap.soledad.server import SoledadApp diff --git a/testing/tests/blobs/test_blobs.py b/testing/tests/blobs/test_blobs.py index c99cc572..9a0feb61 100644 --- a/testing/tests/blobs/test_blobs.py +++ b/testing/tests/blobs/test_blobs.py @@ -17,12 +17,16 @@ """ Tests for blobs handling. 
""" +from io import BytesIO +from mock import Mock + from twisted.trial import unittest from twisted.internet import defer -from leap.soledad.client._blobs import DecrypterBuffer, BlobManager, FIXED_REV + +from leap.soledad.client._database.blobs import DecrypterBuffer +from leap.soledad.client._database.blobs import BlobManager +from leap.soledad.client._database.blobs import FIXED_REV from leap.soledad.client import _crypto -from io import BytesIO -from mock import Mock class BlobTestCase(unittest.TestCase): diff --git a/testing/tests/blobs/test_local_backend.py b/testing/tests/blobs/test_local_backend.py index 5caa7463..202d0611 100644 --- a/testing/tests/blobs/test_local_backend.py +++ b/testing/tests/blobs/test_local_backend.py @@ -19,7 +19,7 @@ Tests for sqlcipher backend on blobs client. """ from twisted.trial import unittest from twisted.internet import defer -from leap.soledad.client._blobs import BlobManager, BlobDoc, FIXED_REV +from leap.soledad.client._database.blobs import BlobManager, BlobDoc, FIXED_REV from io import BytesIO from mock import Mock import pytest diff --git a/testing/tests/client/test_aux_methods.py b/testing/tests/client/test_aux_methods.py index 729aa28a..da33255c 100644 --- a/testing/tests/client/test_aux_methods.py +++ b/testing/tests/client/test_aux_methods.py @@ -22,7 +22,7 @@ import os from pytest import inlineCallbacks from leap.soledad.client import Soledad -from leap.soledad.client.adbapi import U1DBConnectionPool +from leap.soledad.client._database.adbapi import U1DBConnectionPool from leap.soledad.client._secrets.util import SecretsError from test_soledad.util import BaseSoledadTest diff --git a/testing/tests/server/test_blobs_server.py b/testing/tests/server/test_blobs_server.py index 2e5af01f..cd39833f 100644 --- a/testing/tests/server/test_blobs_server.py +++ b/testing/tests/server/test_blobs_server.py @@ -24,8 +24,10 @@ from twisted.web.server import Site from twisted.internet import reactor from twisted.internet import defer from treq._utils import set_global_pool + from leap.soledad.server import _blobs as server_blobs -from leap.soledad.client._blobs import BlobManager, BlobAlreadyExistsError +from leap.soledad.client._database.blobs import BlobManager +from leap.soledad.client._database.blobs import BlobAlreadyExistsError class BlobServerTestCase(unittest.TestCase): diff --git a/testing/tests/sqlcipher/test_async.py b/testing/tests/sqlcipher/test_async.py index 42c315fe..dac6c6b9 100644 --- a/testing/tests/sqlcipher/test_async.py +++ b/testing/tests/sqlcipher/test_async.py @@ -20,8 +20,8 @@ import hashlib from twisted.internet import defer from test_soledad.util import BaseSoledadTest -from leap.soledad.client import adbapi -from leap.soledad.client.sqlcipher import SQLCipherOptions +from leap.soledad.client._database import adbapi +from leap.soledad.client._database import sqlcipher class ASyncSQLCipherRetryTestCase(BaseSoledadTest): @@ -42,7 +42,7 @@ class ASyncSQLCipherRetryTestCase(BaseSoledadTest): def _get_dbpool(self): tmpdb = os.path.join(self.tempdir, "test.soledad") - opts = SQLCipherOptions(tmpdb, "secret", create=True) + opts = sqlcipher.SQLCipherOptions(tmpdb, "secret", create=True) return adbapi.getConnectionPool(opts) def _get_sample(self): diff --git a/testing/tests/sqlcipher/test_backend.py b/testing/tests/sqlcipher/test_backend.py index 6e9595db..4f614fb3 100644 --- a/testing/tests/sqlcipher/test_backend.py +++ b/testing/tests/sqlcipher/test_backend.py @@ -26,14 +26,13 @@ import sys # l2db stuff. 
from leap.soledad.common.l2db import errors from leap.soledad.common.l2db import query_parser -from leap.soledad.common.l2db.backends.sqlite_backend \ - import SQLitePartialExpandDatabase # soledad stuff. from leap.soledad.common.document import SoledadDocument -from leap.soledad.client.sqlcipher import SQLCipherDatabase -from leap.soledad.client.sqlcipher import SQLCipherOptions -from leap.soledad.client.sqlcipher import DatabaseIsNotEncrypted +from leap.soledad.client._database.sqlite import SQLitePartialExpandDatabase +from leap.soledad.client._database.sqlcipher import SQLCipherDatabase +from leap.soledad.client._database.sqlcipher import SQLCipherOptions +from leap.soledad.client._database.sqlcipher import DatabaseIsNotEncrypted # u1db tests stuff. from test_soledad import u1db_tests as tests diff --git a/testing/tests/sync/test_sync_target.py b/testing/tests/sync/test_sync_target.py index a54a02ee..e1bd7de1 100644 --- a/testing/tests/sync/test_sync_target.py +++ b/testing/tests/sync/test_sync_target.py @@ -32,9 +32,9 @@ from twisted.internet import defer from leap.soledad.client import http_target as target from leap.soledad.client.http_target.fetch_protocol import DocStreamReceiver -from leap.soledad.client.sqlcipher import SQLCipherU1DBSync -from leap.soledad.client.sqlcipher import SQLCipherOptions -from leap.soledad.client.sqlcipher import SQLCipherDatabase +from leap.soledad.client._database.sqlcipher import SQLCipherU1DBSync +from leap.soledad.client._database.sqlcipher import SQLCipherOptions +from leap.soledad.client._database.sqlcipher import SQLCipherDatabase from leap.soledad.client import _crypto from leap.soledad.common import l2db -- cgit v1.2.3
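
For downstream code, the net effect of this refactor is a single import-path change: what used to live in leap.soledad.client.adbapi, leap.soledad.client.sqlcipher, leap.soledad.client._blobs and leap.soledad.common.l2db.backends.sqlite_backend now lives under the new leap.soledad.client._database package. The sketch below summarizes that mapping using only module paths and calls that appear in the hunks above; the two helper functions are illustrative names, not part of the patch.

# Old locations (removed by this patch):
#   from leap.soledad.client.adbapi import U1DBConnectionPool
#   from leap.soledad.client.sqlcipher import SQLCipherDatabase, SQLCipherOptions
#   from leap.soledad.client._blobs import BlobManager, BlobAlreadyExistsError
#   from leap.soledad.common.l2db.backends import sqlite_backend

# New locations, all under the _database package:
from leap.soledad.client._database import adbapi
from leap.soledad.client._database import sqlite
from leap.soledad.client._database import sqlcipher
from leap.soledad.client._database.blobs import BlobManager  # noqa: F401


def open_plaintext_replica(path, create=True):
    # Plain SQLite database, opened the way ServerState now does it.
    return sqlite.SQLiteDatabase.open_database(path, create=create)


def get_async_encrypted_pool(path, secret):
    # Asynchronous SQLCipher pool, mirroring the updated test_async.py.
    opts = sqlcipher.SQLCipherOptions(path, secret, create=True)
    return adbapi.getConnectionPool(opts)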
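
Most of the moved code is the index machinery of the plain SQLite backend. Its _format_query method, deleted from l2db above and recreated in client/_database/sqlite.py, joins document_fields against itself once per index column and chooses an exact, GLOB or NOT NULL clause for each key value. The standalone function below re-derives that construction for the exact and trailing-glob cases only (the pure '*' wildcard branch is omitted for brevity); the function name and the demo values are illustrative, not part of the patch.

def format_index_query(definition, key_values):
    # Build the statement the way _format_query does for exact and glob keys.
    tables = ["document_fields d%d" % i for i in range(len(definition))]
    args, where, order = [], [], []
    for idx, (field, value) in enumerate(zip(definition, key_values)):
        base = "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (idx, idx)
        args.append(field)
        if value.endswith('*') and value != '*':
            # Trailing glob: let SQLite's GLOB operator match the pattern.
            where.append(base + (" AND d%d.value GLOB ?" % idx))
        else:
            where.append(base + (" AND d%d.value = ?" % idx))
        args.append(value)
        order.append("d%d.value" % idx)
    statement = (
        "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
        "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
        "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
        "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(order)))
    return statement, args


if __name__ == '__main__':
    stmt, args = format_index_query(('lastname', 'firstname'), ('smith', 'j*'))
    print(stmt)
    print(args)
    # Produces one join per indexed field:
    #   ... FROM document d, document_fields d0, document_fields d1
    #   LEFT OUTER JOIN conflicts c ON c.doc_id = d.doc_id
    #   WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND d0.value = ?
    #     AND d.doc_id = d1.doc_id AND d1.field_name = ? AND d1.value GLOB ?
    #   GROUP BY ... ORDER BY d0.value, d1.value;
    # with args ['lastname', 'smith', 'firstname', 'j*'].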
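
The conflict handling that moves along with the backend is driven by vector clocks: a stored conflict is pruned when the incoming revision strictly supersedes it, or when both revisions carry the same content, in which case the clocks are merged and the caller bumps the revision once for the local replica. The sketch below isolates that per-conflict decision from _prune_conflicts; the vectorclock import path is an assumption based on the l2db layout, and prune_decision is an illustrative name.

# Assumed module path within l2db; the moved backend uses the same API.
from leap.soledad.common.l2db import vectorclock


def prune_decision(doc_rev, conflict_rev, same_content):
    # Mirror the choice _prune_conflicts makes for one stored conflict.
    doc_vcr = vectorclock.VectorClockRev(doc_rev)
    c_vcr = vectorclock.VectorClockRev(conflict_rev)
    if doc_vcr.is_newer(c_vcr):
        return 'prune'            # the conflict is strictly superseded
    if same_content:
        doc_vcr.maximize(c_vcr)   # merge clocks; the real method then calls
        return 'prune-autoresolve'  # increment(replica_uid) and as_str()
    return 'keep'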