summaryrefslogtreecommitdiff
path: root/src/leap/soledad/client/_db
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/soledad/client/_db')
-rw-r--r--src/leap/soledad/client/_db/__init__.py0
-rw-r--r--src/leap/soledad/client/_db/adbapi.py298
-rw-r--r--src/leap/soledad/client/_db/blobs.py554
-rw-r--r--src/leap/soledad/client/_db/dbschema.sql42
-rw-r--r--src/leap/soledad/client/_db/pragmas.py379
-rw-r--r--src/leap/soledad/client/_db/sqlcipher.py633
-rw-r--r--src/leap/soledad/client/_db/sqlite.py930
7 files changed, 2836 insertions, 0 deletions
diff --git a/src/leap/soledad/client/_db/__init__.py b/src/leap/soledad/client/_db/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/soledad/client/_db/__init__.py
diff --git a/src/leap/soledad/client/_db/adbapi.py b/src/leap/soledad/client/_db/adbapi.py
new file mode 100644
index 00000000..5c28d108
--- /dev/null
+++ b/src/leap/soledad/client/_db/adbapi.py
@@ -0,0 +1,298 @@
+# -*- coding: utf-8 -*-
+# adbapi.py
+# Copyright (C) 2013, 2014 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+An asyncrhonous interface to soledad using sqlcipher backend.
+It uses twisted.enterprise.adbapi.
+"""
+import re
+import sys
+
+from functools import partial
+
+from twisted.enterprise import adbapi
+from twisted.internet.defer import DeferredSemaphore
+from twisted.python import compat
+from zope.proxy import ProxyBase, setProxiedObject
+
+from leap.soledad.common.log import getLogger
+from leap.soledad.common.errors import DatabaseAccessError
+
+from . import sqlcipher
+from . import pragmas
+
+if sys.version_info[0] < 3:
+ from pysqlcipher import dbapi2
+else:
+ from pysqlcipher3 import dbapi2
+
+
+logger = getLogger(__name__)
+
+
+"""
+How long the SQLCipher connection should wait for the lock to go away until
+raising an exception.
+"""
+SQLCIPHER_CONNECTION_TIMEOUT = 10
+
+"""
+How many times a SQLCipher query should be retried in case of timeout.
+"""
+SQLCIPHER_MAX_RETRIES = 20
+
+
+def getConnectionPool(opts, openfun=None, driver="pysqlcipher"):
+ """
+ Return a connection pool.
+
+ :param opts:
+ Options for the SQLCipher connection.
+ :type opts: SQLCipherOptions
+ :param openfun:
+ Callback invoked after every connect() on the underlying DB-API
+ object.
+ :type openfun: callable
+ :param driver:
+ The connection driver.
+ :type driver: str
+
+ :return: A U1DB connection pool.
+ :rtype: U1DBConnectionPool
+ """
+ if openfun is None and driver == "pysqlcipher":
+ openfun = partial(pragmas.set_init_pragmas, opts=opts)
+ return U1DBConnectionPool(
+ opts,
+ # the following params are relayed "as is" to twisted's
+ # ConnectionPool.
+ "%s.dbapi2" % driver, opts.path, timeout=SQLCIPHER_CONNECTION_TIMEOUT,
+ check_same_thread=False, cp_openfun=openfun)
+
+
+class U1DBConnection(adbapi.Connection):
+ """
+ A wrapper for a U1DB connection instance.
+ """
+
+ u1db_wrapper = sqlcipher.SoledadSQLCipherWrapper
+ """
+ The U1DB wrapper to use.
+ """
+
+ def __init__(self, pool, init_u1db=False):
+ """
+ :param pool: The pool of connections to that owns this connection.
+ :type pool: adbapi.ConnectionPool
+ :param init_u1db: Wether the u1db database should be initialized.
+ :type init_u1db: bool
+ """
+ self.init_u1db = init_u1db
+ try:
+ adbapi.Connection.__init__(self, pool)
+ except dbapi2.DatabaseError as e:
+ raise DatabaseAccessError(
+ 'Error initializing connection to sqlcipher database: %s'
+ % str(e))
+
+ def reconnect(self):
+ """
+ Reconnect to the U1DB database.
+ """
+ if self._connection is not None:
+ self._pool.disconnect(self._connection)
+ self._connection = self._pool.connect()
+
+ if self.init_u1db:
+ self._u1db = self.u1db_wrapper(
+ self._connection,
+ self._pool.opts)
+
+ def __getattr__(self, name):
+ """
+ Route the requested attribute either to the U1DB wrapper or to the
+ connection.
+
+ :param name: The name of the attribute.
+ :type name: str
+ """
+ if name.startswith('u1db_'):
+ attr = re.sub('^u1db_', '', name)
+ return getattr(self._u1db, attr)
+ else:
+ return getattr(self._connection, name)
+
+
+class U1DBTransaction(adbapi.Transaction):
+ """
+ A wrapper for a U1DB 'cursor' object.
+ """
+
+ def __getattr__(self, name):
+ """
+ Route the requested attribute either to the U1DB wrapper of the
+ connection or to the actual connection cursor.
+
+ :param name: The name of the attribute.
+ :type name: str
+ """
+ if name.startswith('u1db_'):
+ attr = re.sub('^u1db_', '', name)
+ return getattr(self._connection._u1db, attr)
+ else:
+ return getattr(self._cursor, name)
+
+
+class U1DBConnectionPool(adbapi.ConnectionPool):
+ """
+ Represent a pool of connections to an U1DB database.
+ """
+
+ connectionFactory = U1DBConnection
+ transactionFactory = U1DBTransaction
+
+ def __init__(self, opts, *args, **kwargs):
+ """
+ Initialize the connection pool.
+ """
+ self.opts = opts
+ try:
+ adbapi.ConnectionPool.__init__(self, *args, **kwargs)
+ except dbapi2.DatabaseError as e:
+ raise DatabaseAccessError(
+ 'Error initializing u1db connection pool: %s' % str(e))
+
+ # all u1db connections, hashed by thread-id
+ self._u1dbconnections = {}
+
+ # The replica uid, primed by the connections on init.
+ self.replica_uid = ProxyBase(None)
+
+ try:
+ conn = self.connectionFactory(
+ self, init_u1db=True)
+ replica_uid = conn._u1db._real_replica_uid
+ setProxiedObject(self.replica_uid, replica_uid)
+ except DatabaseAccessError as e:
+ self.threadpool.stop()
+ raise DatabaseAccessError(
+ "Error initializing connection factory: %s" % str(e))
+
+ def runU1DBQuery(self, meth, *args, **kw):
+ """
+ Execute a U1DB query in a thread, using a pooled connection.
+
+ Concurrent threads trying to update the same database may timeout
+ because of other threads holding the database lock. Because of this,
+ we will retry SQLCIPHER_MAX_RETRIES times and fail after that.
+
+ :param meth: The U1DB wrapper method name.
+ :type meth: str
+
+ :return: a Deferred which will fire the return value of
+ 'self._runU1DBQuery(Transaction(...), *args, **kw)', or a Failure.
+ :rtype: twisted.internet.defer.Deferred
+ """
+ meth = "u1db_%s" % meth
+ semaphore = DeferredSemaphore(SQLCIPHER_MAX_RETRIES)
+
+ def _run_interaction():
+ return self.runInteraction(
+ self._runU1DBQuery, meth, *args, **kw)
+
+ def _errback(failure):
+ failure.trap(dbapi2.OperationalError)
+ if failure.getErrorMessage() == "database is locked":
+ logger.warn("database operation timed out")
+ should_retry = semaphore.acquire()
+ if should_retry:
+ logger.warn("trying again...")
+ return _run_interaction()
+ logger.warn("giving up!")
+ return failure
+
+ d = _run_interaction()
+ d.addErrback(_errback)
+ return d
+
+ def _runU1DBQuery(self, trans, meth, *args, **kw):
+ """
+ Execute a U1DB query.
+
+ :param trans: An U1DB transaction.
+ :type trans: adbapi.Transaction
+ :param meth: the U1DB wrapper method name.
+ :type meth: str
+ """
+ meth = getattr(trans, meth)
+ return meth(*args, **kw)
+ # XXX should return a fetchall?
+
+ # XXX add _runOperation too
+
+ def _runInteraction(self, interaction, *args, **kw):
+ """
+ Interact with the database and return the result.
+
+ :param interaction:
+ A callable object whose first argument is an
+ L{adbapi.Transaction}.
+ :type interaction: callable
+ :return: a Deferred which will fire the return value of
+ 'interaction(Transaction(...), *args, **kw)', or a Failure.
+ :rtype: twisted.internet.defer.Deferred
+ """
+ tid = self.threadID()
+ u1db = self._u1dbconnections.get(tid)
+ conn = self.connectionFactory(
+ self, init_u1db=not bool(u1db))
+
+ if self.replica_uid is None:
+ replica_uid = conn._u1db._real_replica_uid
+ setProxiedObject(self.replica_uid, replica_uid)
+
+ if u1db is None:
+ self._u1dbconnections[tid] = conn._u1db
+ else:
+ conn._u1db = u1db
+
+ trans = self.transactionFactory(self, conn)
+ try:
+ result = interaction(trans, *args, **kw)
+ trans.close()
+ conn.commit()
+ return result
+ except:
+ excType, excValue, excTraceback = sys.exc_info()
+ try:
+ conn.rollback()
+ except:
+ logger.error(None, "Rollback failed")
+ compat.reraise(excValue, excTraceback)
+
+ def finalClose(self):
+ """
+ A final close, only called by the shutdown trigger.
+ """
+ self.shutdownID = None
+ if self.threadpool.started:
+ self.threadpool.stop()
+ self.running = False
+ for conn in self.connections.values():
+ self._close(conn)
+ for u1db in self._u1dbconnections.values():
+ self._close(u1db)
+ self.connections.clear()
diff --git a/src/leap/soledad/client/_db/blobs.py b/src/leap/soledad/client/_db/blobs.py
new file mode 100644
index 00000000..10b90c71
--- /dev/null
+++ b/src/leap/soledad/client/_db/blobs.py
@@ -0,0 +1,554 @@
+# -*- coding: utf-8 -*-
+# _blobs.py
+# Copyright (C) 2017 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Clientside BlobBackend Storage.
+"""
+
+from urlparse import urljoin
+
+import binascii
+import os
+import base64
+
+from io import BytesIO
+from functools import partial
+
+from twisted.logger import Logger
+from twisted.enterprise import adbapi
+from twisted.internet import defer
+from twisted.web.client import FileBodyProducer
+
+import treq
+
+from leap.soledad.common.errors import SoledadError
+from leap.common.files import mkdir_p
+
+from .._document import BlobDoc
+from .._crypto import DocInfo
+from .._crypto import BlobEncryptor
+from .._crypto import BlobDecryptor
+from .._http import HTTPClient
+from .._pipes import TruncatedTailPipe
+from .._pipes import PreamblePipe
+
+from . import pragmas
+from . import sqlcipher
+
+
+logger = Logger()
+FIXED_REV = 'ImmutableRevision' # Blob content is immutable
+
+
+class BlobAlreadyExistsError(SoledadError):
+ pass
+
+
+class ConnectionPool(adbapi.ConnectionPool):
+
+ def insertAndGetLastRowid(self, *args, **kwargs):
+ """
+ Execute an SQL query and return the last rowid.
+
+ See: https://sqlite.org/c3ref/last_insert_rowid.html
+ """
+ return self.runInteraction(
+ self._insertAndGetLastRowid, *args, **kwargs)
+
+ def _insertAndGetLastRowid(self, trans, *args, **kw):
+ trans.execute(*args, **kw)
+ return trans.lastrowid
+
+ def blob(self, table, column, irow, flags):
+ """
+ Open a BLOB for incremental I/O.
+
+ Return a handle to the BLOB that would be selected by:
+
+ SELECT column FROM table WHERE rowid = irow;
+
+ See: https://sqlite.org/c3ref/blob_open.html
+
+ :param table: The table in which to lookup the blob.
+ :type table: str
+ :param column: The column where the BLOB is located.
+ :type column: str
+ :param rowid: The rowid of the BLOB.
+ :type rowid: int
+ :param flags: If zero, BLOB is opened for read-only. If non-zero,
+ BLOB is opened for RW.
+ :type flags: int
+
+ :return: A BLOB handle.
+ :rtype: pysqlcipher.dbapi.Blob
+ """
+ return self.runInteraction(self._blob, table, column, irow, flags)
+
+ def _blob(self, trans, table, column, irow, flags):
+ # TODO: should not use transaction private variable here
+ handle = trans._connection.blob(table, column, irow, flags)
+ return handle
+
+
+def check_http_status(code):
+ if code == 409:
+ raise BlobAlreadyExistsError()
+ elif code != 200:
+ raise SoledadError("Server Error")
+
+
+class DecrypterBuffer(object):
+
+ def __init__(self, blob_id, secret, tag):
+ self.doc_info = DocInfo(blob_id, FIXED_REV)
+ self.secret = secret
+ self.tag = tag
+ self.preamble_pipe = PreamblePipe(self._make_decryptor)
+
+ def _make_decryptor(self, preamble):
+ self.decrypter = BlobDecryptor(
+ self.doc_info, preamble,
+ secret=self.secret,
+ armor=False,
+ start_stream=False,
+ tag=self.tag)
+ return TruncatedTailPipe(self.decrypter, tail_size=len(self.tag))
+
+ def write(self, data):
+ self.preamble_pipe.write(data)
+
+ def close(self):
+ real_size = self.decrypter.decrypted_content_size
+ return self.decrypter._end_stream(), real_size
+
+
+class BlobManager(object):
+ """
+ Ideally, the decrypting flow goes like this:
+
+ - GET a blob from remote server.
+ - Decrypt the preamble
+ - Allocate a zeroblob in the sqlcipher sink
+ - Mark the blob as unusable (ie, not verified)
+ - Decrypt the payload incrementally, and write chunks to sqlcipher
+ ** Is it possible to use a small buffer for the aes writer w/o
+ ** allocating all the memory in openssl?
+ - Finalize the AES decryption
+ - If preamble + payload verifies correctly, mark the blob as usable
+
+ """
+
+ def __init__(
+ self, local_path, remote, key, secret, user, token=None,
+ cert_file=None):
+ if local_path:
+ mkdir_p(os.path.dirname(local_path))
+ self.local = SQLiteBlobBackend(local_path, key)
+ self.remote = remote
+ self.secret = secret
+ self.user = user
+ self._client = HTTPClient(user, token, cert_file)
+
+ def close(self):
+ if hasattr(self, 'local') and self.local:
+ return self.local.close()
+
+ @defer.inlineCallbacks
+ def remote_list(self):
+ uri = urljoin(self.remote, self.user + '/')
+ data = yield self._client.get(uri)
+ defer.returnValue((yield data.json()))
+
+ def local_list(self):
+ return self.local.list()
+
+ @defer.inlineCallbacks
+ def send_missing(self):
+ our_blobs = yield self.local_list()
+ server_blobs = yield self.remote_list()
+ missing = [b_id for b_id in our_blobs if b_id not in server_blobs]
+ logger.info("Amount of documents missing on server: %s" % len(missing))
+ # TODO: Send concurrently when we are able to stream directly from db
+ for blob_id in missing:
+ fd = yield self.local.get(blob_id)
+ logger.info("Upload local blob: %s" % blob_id)
+ yield self._encrypt_and_upload(blob_id, fd)
+
+ @defer.inlineCallbacks
+ def fetch_missing(self):
+ # TODO: Use something to prioritize user requests over general new docs
+ our_blobs = yield self.local_list()
+ server_blobs = yield self.remote_list()
+ docs_we_want = [b_id for b_id in server_blobs if b_id not in our_blobs]
+ logger.info("Fetching new docs from server: %s" % len(docs_we_want))
+ # TODO: Fetch concurrently when we are able to stream directly into db
+ for blob_id in docs_we_want:
+ logger.info("Fetching new doc: %s" % blob_id)
+ yield self.get(blob_id)
+
+ @defer.inlineCallbacks
+ def put(self, doc, size):
+ if (yield self.local.exists(doc.blob_id)):
+ error_message = "Blob already exists: %s" % doc.blob_id
+ raise BlobAlreadyExistsError(error_message)
+ fd = doc.blob_fd
+ # TODO this is a tee really, but ok... could do db and upload
+ # concurrently. not sure if we'd gain something.
+ yield self.local.put(doc.blob_id, fd, size=size)
+ # In fact, some kind of pipe is needed here, where each write on db
+ # handle gets forwarded into a write on the connection handle
+ fd = yield self.local.get(doc.blob_id)
+ yield self._encrypt_and_upload(doc.blob_id, fd)
+
+ @defer.inlineCallbacks
+ def get(self, blob_id):
+ local_blob = yield self.local.get(blob_id)
+ if local_blob:
+ logger.info("Found blob in local database: %s" % blob_id)
+ defer.returnValue(local_blob)
+
+ result = yield self._download_and_decrypt(blob_id)
+
+ if not result:
+ defer.returnValue(None)
+ blob, size = result
+
+ if blob:
+ logger.info("Got decrypted blob of type: %s" % type(blob))
+ blob.seek(0)
+ yield self.local.put(blob_id, blob, size=size)
+ defer.returnValue((yield self.local.get(blob_id)))
+ else:
+ # XXX we shouldn't get here, but we will...
+ # lots of ugly error handling possible:
+ # 1. retry, might be network error
+ # 2. try later, maybe didn't finished streaming
+ # 3.. resignation, might be error while verifying
+ logger.error('sorry, dunno what happened')
+
+ @defer.inlineCallbacks
+ def _encrypt_and_upload(self, blob_id, fd):
+ # TODO ------------------------------------------
+ # this is wrong, is doing 2 stages.
+ # the crypto producer can be passed to
+ # the uploader and react as data is written.
+ # try to rewrite as a tube: pass the fd to aes and let aes writer
+ # produce data to the treq request fd.
+ # ------------------------------------------------
+ logger.info("Staring upload of blob: %s" % blob_id)
+ doc_info = DocInfo(blob_id, FIXED_REV)
+ uri = urljoin(self.remote, self.user + "/" + blob_id)
+ crypter = BlobEncryptor(doc_info, fd, secret=self.secret,
+ armor=False)
+ fd = yield crypter.encrypt()
+ response = yield self._client.put(uri, data=fd)
+ check_http_status(response.code)
+ logger.info("Finished upload: %s" % (blob_id,))
+
+ @defer.inlineCallbacks
+ def _download_and_decrypt(self, blob_id):
+ logger.info("Staring download of blob: %s" % blob_id)
+ # TODO this needs to be connected in a tube
+ uri = urljoin(self.remote, self.user + '/' + blob_id)
+ data = yield self._client.get(uri)
+
+ if data.code == 404:
+ logger.warn("Blob not found in server: %s" % blob_id)
+ defer.returnValue(None)
+ elif not data.headers.hasHeader('Tag'):
+ logger.error("Server didn't send a tag header for: %s" % blob_id)
+ defer.returnValue(None)
+ tag = data.headers.getRawHeaders('Tag')[0]
+ tag = base64.urlsafe_b64decode(tag)
+ buf = DecrypterBuffer(blob_id, self.secret, tag)
+
+ # incrementally collect the body of the response
+ yield treq.collect(data, buf.write)
+ fd, size = buf.close()
+ logger.info("Finished download: (%s, %d)" % (blob_id, size))
+ defer.returnValue((fd, size))
+
+ @defer.inlineCallbacks
+ def delete(self, blob_id):
+ logger.info("Staring deletion of blob: %s" % blob_id)
+ yield self._delete_from_remote(blob_id)
+ if (yield self.local.exists(blob_id)):
+ yield self.local.delete(blob_id)
+
+ def _delete_from_remote(self, blob_id):
+ # TODO this needs to be connected in a tube
+ uri = urljoin(self.remote, self.user + '/' + blob_id)
+ return self._client.delete(uri)
+
+
+class SQLiteBlobBackend(object):
+
+ def __init__(self, path, key=None):
+ self.path = os.path.abspath(
+ os.path.join(path, 'soledad_blob.db'))
+ mkdir_p(os.path.dirname(self.path))
+ if not key:
+ raise ValueError('key cannot be None')
+ backend = 'pysqlcipher.dbapi2'
+ opts = sqlcipher.SQLCipherOptions(
+ '/tmp/ignored', binascii.b2a_hex(key),
+ is_raw_key=True, create=True)
+ pragmafun = partial(pragmas.set_init_pragmas, opts=opts)
+ openfun = _sqlcipherInitFactory(pragmafun)
+
+ self.dbpool = ConnectionPool(
+ backend, self.path, check_same_thread=False, timeout=5,
+ cp_openfun=openfun, cp_min=1, cp_max=2, cp_name='blob_pool')
+
+ def close(self):
+ from twisted._threads import AlreadyQuit
+ try:
+ self.dbpool.close()
+ except AlreadyQuit:
+ pass
+
+ @defer.inlineCallbacks
+ def put(self, blob_id, blob_fd, size=None):
+ logger.info("Saving blob in local database...")
+ insert = 'INSERT INTO blobs (blob_id, payload) VALUES (?, zeroblob(?))'
+ irow = yield self.dbpool.insertAndGetLastRowid(insert, (blob_id, size))
+ handle = yield self.dbpool.blob('blobs', 'payload', irow, 1)
+ blob_fd.seek(0)
+ # XXX I have to copy the buffer here so that I'm able to
+ # return a non-closed file to the caller (blobmanager.get)
+ # FIXME should remove this duplication!
+ # have a look at how treq does cope with closing the handle
+ # for uploading a file
+ producer = FileBodyProducer(blob_fd)
+ done = yield producer.startProducing(handle)
+ logger.info("Finished saving blob in local database.")
+ defer.returnValue(done)
+
+ @defer.inlineCallbacks
+ def get(self, blob_id):
+ # TODO we can also stream the blob value using sqlite
+ # incremental interface for blobs - and just return the raw fd instead
+ select = 'SELECT payload FROM blobs WHERE blob_id = ?'
+ result = yield self.dbpool.runQuery(select, (blob_id,))
+ if result:
+ defer.returnValue(BytesIO(str(result[0][0])))
+
+ @defer.inlineCallbacks
+ def list(self):
+ query = 'select blob_id from blobs'
+ result = yield self.dbpool.runQuery(query)
+ if result:
+ defer.returnValue([b_id[0] for b_id in result])
+ else:
+ defer.returnValue([])
+
+ @defer.inlineCallbacks
+ def exists(self, blob_id):
+ query = 'SELECT blob_id from blobs WHERE blob_id = ?'
+ result = yield self.dbpool.runQuery(query, (blob_id,))
+ defer.returnValue(bool(len(result)))
+
+ def delete(self, blob_id):
+ query = 'DELETE FROM blobs WHERE blob_id = ?'
+ return self.dbpool.runQuery(query, (blob_id,))
+
+
+def _init_blob_table(conn):
+ maybe_create = (
+ "CREATE TABLE IF NOT EXISTS "
+ "blobs ("
+ "blob_id PRIMARY KEY, "
+ "payload BLOB)")
+ conn.execute(maybe_create)
+
+
+def _sqlcipherInitFactory(fun):
+ def _initialize(conn):
+ fun(conn)
+ _init_blob_table(conn)
+ return _initialize
+
+
+#
+# testing facilities
+#
+
+@defer.inlineCallbacks
+def testit(reactor):
+ # configure logging to stdout
+ from twisted.python import log
+ import sys
+ log.startLogging(sys.stdout)
+
+ # parse command line arguments
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--url', default='http://localhost:9000/')
+ parser.add_argument('--path', default='/tmp/blobs')
+ parser.add_argument('--secret', default='secret')
+ parser.add_argument('--uuid', default='user')
+ parser.add_argument('--token', default=None)
+ parser.add_argument('--cert-file', default='')
+
+ subparsers = parser.add_subparsers(help='sub-command help', dest='action')
+
+ # parse upload command
+ parser_upload = subparsers.add_parser(
+ 'upload', help='upload blob and bypass local db')
+ parser_upload.add_argument('payload')
+ parser_upload.add_argument('blob_id')
+
+ # parse download command
+ parser_download = subparsers.add_parser(
+ 'download', help='download blob and bypass local db')
+ parser_download.add_argument('blob_id')
+ parser_download.add_argument('--output-file', default='/tmp/incoming-file')
+
+ # parse put command
+ parser_put = subparsers.add_parser(
+ 'put', help='put blob in local db and upload')
+ parser_put.add_argument('payload')
+ parser_put.add_argument('blob_id')
+
+ # parse get command
+ parser_get = subparsers.add_parser(
+ 'get', help='get blob from local db, get if needed')
+ parser_get.add_argument('blob_id')
+
+ # parse delete command
+ parser_get = subparsers.add_parser(
+ 'delete', help='delete blob from local and remote db')
+ parser_get.add_argument('blob_id')
+
+ # parse list command
+ parser_get = subparsers.add_parser(
+ 'list', help='list local and remote blob ids')
+
+ # parse send_missing command
+ parser_get = subparsers.add_parser(
+ 'send_missing', help='send all pending upload blobs')
+
+ # parse send_missing command
+ parser_get = subparsers.add_parser(
+ 'fetch_missing', help='fetch all new server blobs')
+
+ # parse arguments
+ args = parser.parse_args()
+
+ # TODO convert these into proper unittests
+
+ def _manager():
+ mkdir_p(os.path.dirname(args.path))
+ manager = BlobManager(
+ args.path, args.url,
+ 'A' * 32, args.secret,
+ args.uuid, args.token, args.cert_file)
+ return manager
+
+ @defer.inlineCallbacks
+ def _upload(blob_id, payload):
+ logger.info(":: Starting upload only: %s" % str((blob_id, payload)))
+ manager = _manager()
+ with open(payload, 'r') as fd:
+ yield manager._encrypt_and_upload(blob_id, fd)
+ logger.info(":: Finished upload only: %s" % str((blob_id, payload)))
+
+ @defer.inlineCallbacks
+ def _download(blob_id):
+ logger.info(":: Starting download only: %s" % blob_id)
+ manager = _manager()
+ result = yield manager._download_and_decrypt(blob_id)
+ logger.info(":: Result of download: %s" % str(result))
+ if result:
+ fd, _ = result
+ with open(args.output_file, 'w') as f:
+ logger.info(":: Writing data to %s" % args.output_file)
+ f.write(fd.read())
+ logger.info(":: Finished download only: %s" % blob_id)
+
+ @defer.inlineCallbacks
+ def _put(blob_id, payload):
+ logger.info(":: Starting full put: %s" % blob_id)
+ manager = _manager()
+ size = os.path.getsize(payload)
+ with open(payload) as fd:
+ doc = BlobDoc(fd, blob_id)
+ result = yield manager.put(doc, size=size)
+ logger.info(":: Result of put: %s" % str(result))
+ logger.info(":: Finished full put: %s" % blob_id)
+
+ @defer.inlineCallbacks
+ def _get(blob_id):
+ logger.info(":: Starting full get: %s" % blob_id)
+ manager = _manager()
+ fd = yield manager.get(blob_id)
+ if fd:
+ logger.info(":: Result of get: " + fd.getvalue())
+ logger.info(":: Finished full get: %s" % blob_id)
+
+ @defer.inlineCallbacks
+ def _delete(blob_id):
+ logger.info(":: Starting deletion of: %s" % blob_id)
+ manager = _manager()
+ yield manager.delete(blob_id)
+ logger.info(":: Finished deletion of: %s" % blob_id)
+
+ @defer.inlineCallbacks
+ def _list():
+ logger.info(":: Listing local blobs")
+ manager = _manager()
+ local_list = yield manager.local_list()
+ logger.info(":: Local list: %s" % local_list)
+ logger.info(":: Listing remote blobs")
+ remote_list = yield manager.remote_list()
+ logger.info(":: Remote list: %s" % remote_list)
+
+ @defer.inlineCallbacks
+ def _send_missing():
+ logger.info(":: Sending local pending upload docs")
+ manager = _manager()
+ yield manager.send_missing()
+ logger.info(":: Finished sending missing docs")
+
+ @defer.inlineCallbacks
+ def _fetch_missing():
+ logger.info(":: Fetching remote new docs")
+ manager = _manager()
+ yield manager.fetch_missing()
+ logger.info(":: Finished fetching new docs")
+
+ if args.action == 'upload':
+ yield _upload(args.blob_id, args.payload)
+ elif args.action == 'download':
+ yield _download(args.blob_id)
+ elif args.action == 'put':
+ yield _put(args.blob_id, args.payload)
+ elif args.action == 'get':
+ yield _get(args.blob_id)
+ elif args.action == 'delete':
+ yield _delete(args.blob_id)
+ elif args.action == 'list':
+ yield _list()
+ elif args.action == 'send_missing':
+ yield _send_missing()
+ elif args.action == 'fetch_missing':
+ yield _fetch_missing()
+
+
+if __name__ == '__main__':
+ from twisted.internet.task import react
+ react(testit)
diff --git a/src/leap/soledad/client/_db/dbschema.sql b/src/leap/soledad/client/_db/dbschema.sql
new file mode 100644
index 00000000..ae027fc5
--- /dev/null
+++ b/src/leap/soledad/client/_db/dbschema.sql
@@ -0,0 +1,42 @@
+-- Database schema
+CREATE TABLE transaction_log (
+ generation INTEGER PRIMARY KEY AUTOINCREMENT,
+ doc_id TEXT NOT NULL,
+ transaction_id TEXT NOT NULL
+);
+CREATE TABLE document (
+ doc_id TEXT PRIMARY KEY,
+ doc_rev TEXT NOT NULL,
+ content TEXT
+);
+CREATE TABLE document_fields (
+ doc_id TEXT NOT NULL,
+ field_name TEXT NOT NULL,
+ value TEXT
+);
+CREATE INDEX document_fields_field_value_doc_idx
+ ON document_fields(field_name, value, doc_id);
+
+CREATE TABLE sync_log (
+ replica_uid TEXT PRIMARY KEY,
+ known_generation INTEGER,
+ known_transaction_id TEXT
+);
+CREATE TABLE conflicts (
+ doc_id TEXT,
+ doc_rev TEXT,
+ content TEXT,
+ CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev)
+);
+CREATE TABLE index_definitions (
+ name TEXT,
+ offset INT,
+ field TEXT,
+ CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset)
+);
+create index index_definitions_field on index_definitions(field);
+CREATE TABLE u1db_config (
+ name TEXT PRIMARY KEY,
+ value TEXT
+);
+INSERT INTO u1db_config VALUES ('sql_schema', '0');
diff --git a/src/leap/soledad/client/_db/pragmas.py b/src/leap/soledad/client/_db/pragmas.py
new file mode 100644
index 00000000..870ed63e
--- /dev/null
+++ b/src/leap/soledad/client/_db/pragmas.py
@@ -0,0 +1,379 @@
+# -*- coding: utf-8 -*-
+# pragmas.py
+# Copyright (C) 2013, 2014 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Different pragmas used in the initialization of the SQLCipher database.
+"""
+import string
+import threading
+import os
+
+from leap.soledad.common import soledad_assert
+from leap.soledad.common.log import getLogger
+
+
+logger = getLogger(__name__)
+
+
+_db_init_lock = threading.Lock()
+
+
+def set_init_pragmas(conn, opts=None, extra_queries=None):
+ """
+ Set the initialization pragmas.
+
+ This includes the crypto pragmas, and any other options that must
+ be passed early to sqlcipher db.
+ """
+ soledad_assert(opts is not None)
+ extra_queries = [] if extra_queries is None else extra_queries
+ with _db_init_lock:
+ # only one execution path should initialize the db
+ _set_init_pragmas(conn, opts, extra_queries)
+
+
+def _set_init_pragmas(conn, opts, extra_queries):
+
+ sync_off = os.environ.get('LEAP_SQLITE_NOSYNC')
+ memstore = os.environ.get('LEAP_SQLITE_MEMSTORE')
+ nowal = os.environ.get('LEAP_SQLITE_NOWAL')
+
+ set_crypto_pragmas(conn, opts)
+
+ if not nowal:
+ set_write_ahead_logging(conn)
+ if sync_off:
+ set_synchronous_off(conn)
+ else:
+ set_synchronous_normal(conn)
+ if memstore:
+ set_mem_temp_store(conn)
+
+ for query in extra_queries:
+ conn.cursor().execute(query)
+
+
+def set_crypto_pragmas(db_handle, sqlcipher_opts):
+ """
+ Set cryptographic params (key, cipher, KDF number of iterations and
+ cipher page size).
+
+ :param db_handle:
+ :type db_handle:
+ :param sqlcipher_opts: options for the SQLCipherDatabase
+ :type sqlcipher_opts: SQLCipherOpts instance
+ """
+ # XXX assert CryptoOptions
+ opts = sqlcipher_opts
+ _set_key(db_handle, opts.key, opts.is_raw_key)
+ _set_cipher(db_handle, opts.cipher)
+ _set_kdf_iter(db_handle, opts.kdf_iter)
+ _set_cipher_page_size(db_handle, opts.cipher_page_size)
+
+
+def _set_key(db_handle, key, is_raw_key):
+ """
+ Set the ``key`` for use with the database.
+
+ The process of creating a new, encrypted database is called 'keying'
+ the database. SQLCipher uses just-in-time key derivation at the point
+ it is first needed for an operation. This means that the key (and any
+ options) must be set before the first operation on the database. As
+ soon as the database is touched (e.g. SELECT, CREATE TABLE, UPDATE,
+ etc.) and pages need to be read or written, the key is prepared for
+ use.
+
+ Implementation Notes:
+
+ * PRAGMA key should generally be called as the first operation on a
+ database.
+
+ :param key: The key for use with the database.
+ :type key: str
+ :param is_raw_key:
+ Whether C{key} is a raw 64-char hex string or a passphrase that should
+ be hashed to obtain the encyrption key.
+ :type is_raw_key: bool
+ """
+ if is_raw_key:
+ _set_key_raw(db_handle, key)
+ else:
+ _set_key_passphrase(db_handle, key)
+
+
+def _set_key_passphrase(db_handle, passphrase):
+ """
+ Set a passphrase for encryption key derivation.
+
+ The key itself can be a passphrase, which is converted to a key using
+ PBKDF2 key derivation. The result is used as the encryption key for
+ the database. By using this method, there is no way to alter the KDF;
+ if you want to do so you should use a raw key instead and derive the
+ key using your own KDF.
+
+ :param db_handle: A handle to the SQLCipher database.
+ :type db_handle: pysqlcipher.Connection
+ :param passphrase: The passphrase used to derive the encryption key.
+ :type passphrase: str
+ """
+ db_handle.cursor().execute("PRAGMA key = '%s'" % passphrase)
+
+
+def _set_key_raw(db_handle, key):
+ """
+ Set a raw hexadecimal encryption key.
+
+ It is possible to specify an exact byte sequence using a blob literal.
+ With this method, it is the calling application's responsibility to
+ ensure that the data provided is a 64 character hex string, which will
+ be converted directly to 32 bytes (256 bits) of key data.
+
+ :param db_handle: A handle to the SQLCipher database.
+ :type db_handle: pysqlcipher.Connection
+ :param key: A 64 character hex string.
+ :type key: str
+ """
+ if not all(c in string.hexdigits for c in key):
+ raise NotAnHexString(key)
+ db_handle.cursor().execute('PRAGMA key = "x\'%s"' % key)
+
+
+def _set_cipher(db_handle, cipher='aes-256-cbc'):
+ """
+ Set the cipher and mode to use for symmetric encryption.
+
+ SQLCipher uses aes-256-cbc as the default cipher and mode of
+ operation. It is possible to change this, though not generally
+ recommended, using PRAGMA cipher.
+
+ SQLCipher makes direct use of libssl, so all cipher options available
+ to libssl are also available for use with SQLCipher. See `man enc` for
+ OpenSSL's supported ciphers.
+
+ Implementation Notes:
+
+ * PRAGMA cipher must be called after PRAGMA key and before the first
+ actual database operation or it will have no effect.
+
+ * If a non-default value is used PRAGMA cipher to create a database,
+ it must also be called every time that database is opened.
+
+ * SQLCipher does not implement its own encryption. Instead it uses the
+ widely available and peer-reviewed OpenSSL libcrypto for all
+ cryptographic functions.
+
+ :param db_handle: A handle to the SQLCipher database.
+ :type db_handle: pysqlcipher.Connection
+ :param cipher: The cipher and mode to use.
+ :type cipher: str
+ """
+ db_handle.cursor().execute("PRAGMA cipher = '%s'" % cipher)
+
+
+def _set_kdf_iter(db_handle, kdf_iter=4000):
+ """
+ Set the number of iterations for the key derivation function.
+
+ SQLCipher uses PBKDF2 key derivation to strengthen the key and make it
+ resistent to brute force and dictionary attacks. The default
+ configuration uses 4000 PBKDF2 iterations (effectively 16,000 SHA1
+ operations). PRAGMA kdf_iter can be used to increase or decrease the
+ number of iterations used.
+
+ Implementation Notes:
+
+ * PRAGMA kdf_iter must be called after PRAGMA key and before the first
+ actual database operation or it will have no effect.
+
+ * If a non-default value is used PRAGMA kdf_iter to create a database,
+ it must also be called every time that database is opened.
+
+ * It is not recommended to reduce the number of iterations if a
+ passphrase is in use.
+
+ :param db_handle: A handle to the SQLCipher database.
+ :type db_handle: pysqlcipher.Connection
+ :param kdf_iter: The number of iterations to use.
+ :type kdf_iter: int
+ """
+ db_handle.cursor().execute("PRAGMA kdf_iter = '%d'" % kdf_iter)
+
+
+def _set_cipher_page_size(db_handle, cipher_page_size=1024):
+ """
+ Set the page size of the encrypted database.
+
+ SQLCipher 2 introduced the new PRAGMA cipher_page_size that can be
+ used to adjust the page size for the encrypted database. The default
+ page size is 1024 bytes, but it can be desirable for some applications
+ to use a larger page size for increased performance. For instance,
+ some recent testing shows that increasing the page size can noticeably
+ improve performance (5-30%) for certain queries that manipulate a
+ large number of pages (e.g. selects without an index, large inserts in
+ a transaction, big deletes).
+
+ To adjust the page size, call the pragma immediately after setting the
+ key for the first time and each subsequent time that you open the
+ database.
+
+ Implementation Notes:
+
+ * PRAGMA cipher_page_size must be called after PRAGMA key and before
+ the first actual database operation or it will have no effect.
+
+ * If a non-default value is used PRAGMA cipher_page_size to create a
+ database, it must also be called every time that database is opened.
+
+ :param db_handle: A handle to the SQLCipher database.
+ :type db_handle: pysqlcipher.Connection
+ :param cipher_page_size: The page size.
+ :type cipher_page_size: int
+ """
+ db_handle.cursor().execute(
+ "PRAGMA cipher_page_size = '%d'" % cipher_page_size)
+
+
+# XXX UNUSED ?
+def set_rekey(db_handle, new_key, is_raw_key):
+ """
+ Change the key of an existing encrypted database.
+
+ To change the key on an existing encrypted database, it must first be
+ unlocked with the current encryption key. Once the database is
+ readable and writeable, PRAGMA rekey can be used to re-encrypt every
+ page in the database with a new key.
+
+ * PRAGMA rekey must be called after PRAGMA key. It can be called at any
+ time once the database is readable.
+
+ * PRAGMA rekey can not be used to encrypted a standard SQLite
+ database! It is only useful for changing the key on an existing
+ database.
+
+ * Previous versions of SQLCipher provided a PRAGMA rekey_cipher and
+ code>PRAGMA rekey_kdf_iter. These are deprecated and should not be
+ used. Instead, use sqlcipher_export().
+
+ :param db_handle: A handle to the SQLCipher database.
+ :type db_handle: pysqlcipher.Connection
+ :param new_key: The new key.
+ :type new_key: str
+ :param is_raw_key: Whether C{password} is a raw 64-char hex string or a
+ passphrase that should be hashed to obtain the encyrption
+ key.
+ :type is_raw_key: bool
+ """
+ if is_raw_key:
+ _set_rekey_raw(db_handle, new_key)
+ else:
+ _set_rekey_passphrase(db_handle, new_key)
+
+
+def _set_rekey_passphrase(db_handle, passphrase):
+ """
+ Change the passphrase for encryption key derivation.
+
+ The key itself can be a passphrase, which is converted to a key using
+ PBKDF2 key derivation. The result is used as the encryption key for
+ the database.
+
+ :param db_handle: A handle to the SQLCipher database.
+ :type db_handle: pysqlcipher.Connection
+ :param passphrase: The passphrase used to derive the encryption key.
+ :type passphrase: str
+ """
+ db_handle.cursor().execute("PRAGMA rekey = '%s'" % passphrase)
+
+
+def _set_rekey_raw(db_handle, key):
+ """
+ Change the raw hexadecimal encryption key.
+
+ It is possible to specify an exact byte sequence using a blob literal.
+ With this method, it is the calling application's responsibility to
+ ensure that the data provided is a 64 character hex string, which will
+ be converted directly to 32 bytes (256 bits) of key data.
+
+ :param db_handle: A handle to the SQLCipher database.
+ :type db_handle: pysqlcipher.Connection
+ :param key: A 64 character hex string.
+ :type key: str
+ """
+ if not all(c in string.hexdigits for c in key):
+ raise NotAnHexString(key)
+ db_handle.cursor().execute('PRAGMA rekey = "x\'%s"' % key)
+
+
+def set_synchronous_off(db_handle):
+ """
+ Change the setting of the "synchronous" flag to OFF.
+ """
+ logger.debug("sqlcipher: setting synchronous off")
+ db_handle.cursor().execute('PRAGMA synchronous=OFF')
+
+
+def set_synchronous_normal(db_handle):
+ """
+ Change the setting of the "synchronous" flag to NORMAL.
+ """
+ logger.debug("sqlcipher: setting synchronous normal")
+ db_handle.cursor().execute('PRAGMA synchronous=NORMAL')
+
+
+def set_mem_temp_store(db_handle):
+ """
+ Use a in-memory store for temporary tables.
+ """
+ logger.debug("sqlcipher: setting temp_store memory")
+ db_handle.cursor().execute('PRAGMA temp_store=MEMORY')
+
+
+def set_write_ahead_logging(db_handle):
+ """
+ Enable write-ahead logging, and set the autocheckpoint to 50 pages.
+
+ Setting the autocheckpoint to a small value, we make the reads not
+ suffer too much performance degradation.
+
+ From the sqlite docs:
+
+ "There is a tradeoff between average read performance and average write
+ performance. To maximize the read performance, one wants to keep the
+ WAL as small as possible and hence run checkpoints frequently, perhaps
+ as often as every COMMIT. To maximize write performance, one wants to
+ amortize the cost of each checkpoint over as many writes as possible,
+ meaning that one wants to run checkpoints infrequently and let the WAL
+ grow as large as possible before each checkpoint. The decision of how
+ often to run checkpoints may therefore vary from one application to
+ another depending on the relative read and write performance
+ requirements of the application. The default strategy is to run a
+ checkpoint once the WAL reaches 1000 pages"
+ """
+ logger.debug("sqlcipher: setting write-ahead logging")
+ db_handle.cursor().execute('PRAGMA journal_mode=WAL')
+
+ # The optimum value can still use a little bit of tuning, but we favor
+ # small sizes of the WAL file to get fast reads, since we assume that
+ # the writes will be quick enough to not block too much.
+
+ db_handle.cursor().execute('PRAGMA wal_autocheckpoint=50')
+
+
+class NotAnHexString(Exception):
+ """
+ Raised when trying to (raw) key the database with a non-hex string.
+ """
+ pass
diff --git a/src/leap/soledad/client/_db/sqlcipher.py b/src/leap/soledad/client/_db/sqlcipher.py
new file mode 100644
index 00000000..d22017bd
--- /dev/null
+++ b/src/leap/soledad/client/_db/sqlcipher.py
@@ -0,0 +1,633 @@
+# -*- coding: utf-8 -*-
+# sqlcipher.py
+# Copyright (C) 2013, 2014 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+A U1DB backend that uses SQLCipher as its persistence layer.
+
+The SQLCipher API (http://sqlcipher.net/sqlcipher-api/) is fully implemented,
+with the exception of the following statements:
+
+ * PRAGMA cipher_use_hmac
+ * PRAGMA cipher_default_use_mac
+
+SQLCipher 2.0 introduced a per-page HMAC to validate that the page data has
+not be tampered with. By default, when creating or opening a database using
+SQLCipher 2, SQLCipher will attempt to use an HMAC check. This change in
+database format means that SQLCipher 2 can't operate on version 1.1.x
+databases by default. Thus, in order to provide backward compatibility with
+SQLCipher 1.1.x, PRAGMA cipher_use_hmac can be used to disable the HMAC
+functionality on specific databases.
+
+In some very specific cases, it is not possible to call PRAGMA cipher_use_hmac
+as one of the first operations on a database. An example of this is when
+trying to ATTACH a 1.1.x database to the main database. In these cases PRAGMA
+cipher_default_use_hmac can be used to globally alter the default use of HMAC
+when opening a database.
+
+So, as the statements above were introduced for backwards compatibility with
+SQLCipher 1.1 databases, we do not implement them as all SQLCipher databases
+handled by Soledad should be created by SQLCipher >= 2.0.
+"""
+import os
+import sys
+
+from functools import partial
+
+from twisted.internet import reactor
+from twisted.internet import defer
+from twisted.enterprise import adbapi
+
+from leap.soledad.common.log import getLogger
+from leap.soledad.common.l2db import errors as u1db_errors
+from leap.soledad.common.errors import DatabaseAccessError
+
+from leap.soledad.client.http_target import SoledadHTTPSyncTarget
+from leap.soledad.client.sync import SoledadSynchronizer
+
+from .._document import Document
+from . import sqlite
+from . import pragmas
+
+if sys.version_info[0] < 3:
+ from pysqlcipher import dbapi2 as sqlcipher_dbapi2
+else:
+ from pysqlcipher3 import dbapi2 as sqlcipher_dbapi2
+
+logger = getLogger(__name__)
+
+
+# Monkey-patch u1db.backends.sqlite with pysqlcipher.dbapi2
+sqlite.dbapi2 = sqlcipher_dbapi2
+
+
+# we may want to collect statistics from the sync process
+DO_STATS = False
+if os.environ.get('SOLEDAD_STATS'):
+ DO_STATS = True
+
+
+def initialize_sqlcipher_db(opts, on_init=None, check_same_thread=True):
+ """
+ Initialize a SQLCipher database.
+
+ :param opts:
+ :type opts: SQLCipherOptions
+ :param on_init: a tuple of queries to be executed on initialization
+ :type on_init: tuple
+ :return: pysqlcipher.dbapi2.Connection
+ """
+ # Note: There seemed to be a bug in sqlite 3.5.9 (with python2.6)
+ # where without re-opening the database on Windows, it
+ # doesn't see the transaction that was just committed
+ # Removing from here now, look at the pysqlite implementation if the
+ # bug shows up in windows.
+
+ if not os.path.isfile(opts.path) and not opts.create:
+ raise u1db_errors.DatabaseDoesNotExist()
+
+ conn = sqlcipher_dbapi2.connect(
+ opts.path, check_same_thread=check_same_thread)
+ pragmas.set_init_pragmas(conn, opts, extra_queries=on_init)
+ return conn
+
+
+def initialize_sqlcipher_adbapi_db(opts, extra_queries=None):
+ from leap.soledad.client import sqlcipher_adbapi
+ return sqlcipher_adbapi.getConnectionPool(
+ opts, extra_queries=extra_queries)
+
+
+class SQLCipherOptions(object):
+ """
+ A container with options for the initialization of an SQLCipher database.
+ """
+
+ @classmethod
+ def copy(cls, source, path=None, key=None, create=None,
+ is_raw_key=None, cipher=None, kdf_iter=None,
+ cipher_page_size=None, sync_db_key=None):
+ """
+ Return a copy of C{source} with parameters different than None
+ replaced by new values.
+ """
+ local_vars = locals()
+ args = []
+ kwargs = {}
+
+ for name in ["path", "key"]:
+ val = local_vars[name]
+ if val is not None:
+ args.append(val)
+ else:
+ args.append(getattr(source, name))
+
+ for name in ["create", "is_raw_key", "cipher", "kdf_iter",
+ "cipher_page_size", "sync_db_key"]:
+ val = local_vars[name]
+ if val is not None:
+ kwargs[name] = val
+ else:
+ kwargs[name] = getattr(source, name)
+
+ return SQLCipherOptions(*args, **kwargs)
+
+ def __init__(self, path, key, create=True, is_raw_key=False,
+ cipher='aes-256-cbc', kdf_iter=4000, cipher_page_size=1024,
+ sync_db_key=None):
+ """
+ :param path: The filesystem path for the database to open.
+ :type path: str
+ :param create:
+ True/False, should the database be created if it doesn't
+ already exist?
+ :param create: bool
+ :param is_raw_key:
+ Whether ``password`` is a raw 64-char hex string or a passphrase
+ that should be hashed to obtain the encyrption key.
+ :type raw_key: bool
+ :param cipher: The cipher and mode to use.
+ :type cipher: str
+ :param kdf_iter: The number of iterations to use.
+ :type kdf_iter: int
+ :param cipher_page_size: The page size.
+ :type cipher_page_size: int
+ """
+ self.path = path
+ self.key = key
+ self.is_raw_key = is_raw_key
+ self.create = create
+ self.cipher = cipher
+ self.kdf_iter = kdf_iter
+ self.cipher_page_size = cipher_page_size
+ self.sync_db_key = sync_db_key
+
+ def __str__(self):
+ """
+ Return string representation of options, for easy debugging.
+
+ :return: String representation of options.
+ :rtype: str
+ """
+ attr_names = filter(lambda a: not a.startswith('_'), dir(self))
+ attr_str = []
+ for a in attr_names:
+ attr_str.append(a + "=" + str(getattr(self, a)))
+ name = self.__class__.__name__
+ return "%s(%s)" % (name, ', '.join(attr_str))
+
+
+#
+# The SQLCipher database
+#
+
+class SQLCipherDatabase(sqlite.SQLitePartialExpandDatabase):
+ """
+ A U1DB implementation that uses SQLCipher as its persistence layer.
+ """
+
+ # The attribute _index_storage_value will be used as the lookup key for the
+ # implementation of the SQLCipher storage backend.
+ _index_storage_value = 'expand referenced encrypted'
+
+ def __init__(self, opts):
+ """
+ Connect to an existing SQLCipher database, creating a new sqlcipher
+ database file if needed.
+
+ *** IMPORTANT ***
+
+ Don't forget to close the database after use by calling the close()
+ method otherwise some resources might not be freed and you may
+ experience several kinds of leakages.
+
+ *** IMPORTANT ***
+
+ :param opts: options for initialization of the SQLCipher database.
+ :type opts: SQLCipherOptions
+ """
+ # ensure the db is encrypted if the file already exists
+ if os.path.isfile(opts.path):
+ self._db_handle = _assert_db_is_encrypted(opts)
+ else:
+ # connect to the sqlcipher database
+ self._db_handle = initialize_sqlcipher_db(opts)
+
+ # TODO ---------------------------------------------------
+ # Everything else in this initialization has to be factored
+ # out, so it can be used from SoledadSQLCipherWrapper.__init__
+ # too.
+ # ---------------------------------------------------------
+
+ self._ensure_schema()
+ self.set_document_factory(doc_factory)
+ self._prime_replica_uid()
+
+ def _prime_replica_uid(self):
+ """
+ In the u1db implementation, _replica_uid is a property
+ that returns the value in _real_replica_uid, and does
+ a db query if no value found.
+ Here we prime the replica uid during initialization so
+ that we don't have to wait for the query afterwards.
+ """
+ self._real_replica_uid = None
+ self._get_replica_uid()
+
+ def _extra_schema_init(self, c):
+ """
+ Add any extra fields, etc to the basic table definitions.
+
+ This method is called by u1db.backends.sqlite_backend._initialize()
+ method, which is executed when the database schema is created. Here,
+ we use it to include the "syncable" property for LeapDocuments.
+
+ :param c: The cursor for querying the database.
+ :type c: dbapi2.cursor
+ """
+ c.execute(
+ 'ALTER TABLE document '
+ 'ADD COLUMN syncable BOOL NOT NULL DEFAULT TRUE')
+
+ #
+ # SQLCipher API methods
+ #
+
+ # Extra query methods: extensions to the base u1db sqlite implmentation.
+
+ def get_count_from_index(self, index_name, *key_values):
+ """
+ Return the count for a given combination of index_name
+ and key values.
+
+ Extension method made from similar methods in u1db version 13.09
+
+ :param index_name: The index to query
+ :type index_name: str
+ :param key_values: values to match. eg, if you have
+ an index with 3 fields then you would have:
+ get_from_index(index_name, val1, val2, val3)
+ :type key_values: tuple
+ :return: count.
+ :rtype: int
+ """
+ c = self._db_handle.cursor()
+ definition = self._get_index_definition(index_name)
+
+ if len(key_values) != len(definition):
+ raise u1db_errors.InvalidValueForIndex()
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = ["d.doc_id = d%d.doc_id"
+ " AND d%d.field_name = ?"
+ % (i, i) for i in range(len(definition))]
+ exact_where = [novalue_where[i] + (" AND d%d.value = ?" % (i,))
+ for i in range(len(definition))]
+ args = []
+ where = []
+ for idx, (field, value) in enumerate(zip(definition, key_values)):
+ args.append(field)
+ where.append(exact_where[idx])
+ args.append(value)
+
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ statement = (
+ "SELECT COUNT(*) FROM document d, %s WHERE %s " % (
+ ', '.join(tables),
+ ' AND '.join(where),
+ ))
+ try:
+ c.execute(statement, tuple(args))
+ except sqlcipher_dbapi2.OperationalError as e:
+ raise sqlcipher_dbapi2.OperationalError(
+ str(e) + '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ return res[0][0]
+
+ def close(self):
+ """
+ Close db connections.
+ """
+ # TODO should be handled by adbapi instead
+ # TODO syncdb should be stopped first
+
+ if logger is not None: # logger might be none if called from __del__
+ logger.debug("SQLCipher backend: closing")
+
+ # close the actual database
+ if getattr(self, '_db_handle', False):
+ self._db_handle.close()
+ self._db_handle = None
+
+ # indexes
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ """
+ Update a document and all indexes related to it.
+
+ :param old_doc: The old version of the document.
+ :type old_doc: u1db.Document
+ :param doc: The new version of the document.
+ :type doc: u1db.Document
+ """
+ sqlite.SQLitePartialExpandDatabase._put_and_update_indexes(
+ self, old_doc, doc)
+ c = self._db_handle.cursor()
+ c.execute('UPDATE document SET syncable=? WHERE doc_id=?',
+ (doc.syncable, doc.doc_id))
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """
+ Get just the document content, without fancy handling.
+
+ :param doc_id: The unique document identifier
+ :type doc_id: str
+ :param include_deleted: If set to True, deleted documents will be
+ returned with empty content. Otherwise asking for a deleted
+ document will return None.
+ :type include_deleted: bool
+
+ :return: a Document object.
+ :type: u1db.Document
+ """
+ doc = sqlite.SQLitePartialExpandDatabase._get_doc(
+ self, doc_id, check_for_conflicts)
+ if doc:
+ c = self._db_handle.cursor()
+ c.execute('SELECT syncable FROM document WHERE doc_id=?',
+ (doc.doc_id,))
+ result = c.fetchone()
+ doc.syncable = bool(result[0])
+ return doc
+
+ def __del__(self):
+ """
+ Free resources when deleting or garbage collecting the database.
+
+ This is only here to minimze problems if someone ever forgets to call
+ the close() method after using the database; you should not rely on
+ garbage collecting to free up the database resources.
+ """
+ self.close()
+
+
+class SQLCipherU1DBSync(SQLCipherDatabase):
+ """
+ Soledad syncer implementation.
+ """
+
+ """
+ The name of the local symmetrically encrypted documents to
+ sync database file.
+ """
+ LOCAL_SYMMETRIC_SYNC_FILE_NAME = 'sync.u1db'
+
+ """
+ Period or recurrence of the Looping Call that will do the encryption to the
+ syncdb (in seconds).
+ """
+ ENCRYPT_LOOP_PERIOD = 1
+
+ def __init__(self, opts, soledad_crypto, replica_uid, cert_file):
+ self._opts = opts
+ self._path = opts.path
+ self._crypto = soledad_crypto
+ self.__replica_uid = replica_uid
+ self._cert_file = cert_file
+
+ # storage for the documents received during a sync
+ self.received_docs = []
+
+ self.running = False
+ self._db_handle = None
+
+ # initialize the main db before scheduling a start
+ self._initialize_main_db()
+ self._reactor = reactor
+ self._reactor.callWhenRunning(self._start)
+
+ if DO_STATS:
+ self.sync_phase = None
+
+ def commit(self):
+ self._db_handle.commit()
+
+ @property
+ def _replica_uid(self):
+ return str(self.__replica_uid)
+
+ def _start(self):
+ if not self.running:
+ self.running = True
+
+ def _initialize_main_db(self):
+ try:
+ self._db_handle = initialize_sqlcipher_db(
+ self._opts, check_same_thread=False)
+ self._real_replica_uid = None
+ self._ensure_schema()
+ self.set_document_factory(doc_factory)
+ except sqlcipher_dbapi2.DatabaseError as e:
+ raise DatabaseAccessError(str(e))
+
+ @defer.inlineCallbacks
+ def sync(self, url, creds=None):
+ """
+ Synchronize documents with remote replica exposed at url.
+
+ It is not safe to initiate more than one sync process and let them run
+ concurrently. It is responsibility of the caller to ensure that there
+ are no concurrent sync processes running. This is currently controlled
+ by the main Soledad object because it may also run post-sync hooks,
+ which should be run while the lock is locked.
+
+ :param url: The url of the target replica to sync with.
+ :type url: str
+ :param creds: optional dictionary giving credentials to authorize the
+ operation with the server.
+ :type creds: dict
+
+ :return:
+ A Deferred, that will fire with the local generation (type `int`)
+ before the synchronisation was performed.
+ :rtype: Deferred
+ """
+ syncer = self._get_syncer(url, creds=creds)
+ if DO_STATS:
+ self.sync_phase = syncer.sync_phase
+ self.syncer = syncer
+ self.sync_exchange_phase = syncer.sync_exchange_phase
+ local_gen_before_sync = yield syncer.sync()
+ self.received_docs = syncer.received_docs
+ defer.returnValue(local_gen_before_sync)
+
+ def _get_syncer(self, url, creds=None):
+ """
+ Get a synchronizer for ``url`` using ``creds``.
+
+ :param url: The url of the target replica to sync with.
+ :type url: str
+ :param creds: optional dictionary giving credentials.
+ to authorize the operation with the server.
+ :type creds: dict
+
+ :return: A synchronizer.
+ :rtype: Synchronizer
+ """
+ return SoledadSynchronizer(
+ self,
+ SoledadHTTPSyncTarget(
+ url,
+ # XXX is the replica_uid ready?
+ self._replica_uid,
+ creds=creds,
+ crypto=self._crypto,
+ cert_file=self._cert_file))
+
+ #
+ # Symmetric encryption of syncing docs
+ #
+
+ def get_generation(self):
+ # FIXME
+ # XXX this SHOULD BE a callback
+ return self._get_generation()
+
+
+class U1DBSQLiteBackend(sqlite.SQLitePartialExpandDatabase):
+ """
+ A very simple wrapper for u1db around sqlcipher backend.
+
+ Instead of initializing the database on the fly, it just uses an existing
+ connection that is passed to it in the initializer.
+
+ It can be used in tests and debug runs to initialize the adbapi with plain
+ sqlite connections, decoupled from the sqlcipher layer.
+ """
+
+ def __init__(self, conn):
+ self._db_handle = conn
+ self._real_replica_uid = None
+ self._ensure_schema()
+ self._factory = Document
+
+
+class SoledadSQLCipherWrapper(SQLCipherDatabase):
+ """
+ A wrapper for u1db that uses the Soledad-extended sqlcipher backend.
+
+ Instead of initializing the database on the fly, it just uses an existing
+ connection that is passed to it in the initializer.
+
+ It can be used from adbapi to initialize a soledad database after
+ getting a regular connection to a sqlcipher database.
+ """
+ def __init__(self, conn, opts):
+ self._db_handle = conn
+ self._real_replica_uid = None
+ self._ensure_schema()
+ self.set_document_factory(doc_factory)
+ self._prime_replica_uid()
+
+
+def _assert_db_is_encrypted(opts):
+ """
+ Assert that the sqlcipher file contains an encrypted database.
+
+ When opening an existing database, PRAGMA key will not immediately
+ throw an error if the key provided is incorrect. To test that the
+ database can be successfully opened with the provided key, it is
+ necessary to perform some operation on the database (i.e. read from
+ it) and confirm it is success.
+
+ The easiest way to do this is select off the sqlite_master table,
+ which will attempt to read the first page of the database and will
+ parse the schema.
+
+ :param opts:
+ """
+ # We try to open an encrypted database with the regular u1db
+ # backend should raise a DatabaseError exception.
+ # If the regular backend succeeds, then we need to stop because
+ # the database was not properly initialized.
+ try:
+ sqlite.SQLitePartialExpandDatabase(opts.path)
+ except sqlcipher_dbapi2.DatabaseError:
+ # assert that we can access it using SQLCipher with the given
+ # key
+ dummy_query = ('SELECT count(*) FROM sqlite_master',)
+ return initialize_sqlcipher_db(opts, on_init=dummy_query)
+ else:
+ raise DatabaseIsNotEncrypted()
+
+#
+# Exceptions
+#
+
+
+class DatabaseIsNotEncrypted(Exception):
+ """
+ Exception raised when trying to open non-encrypted databases.
+ """
+ pass
+
+
+def doc_factory(doc_id=None, rev=None, json='{}', has_conflicts=False,
+ syncable=True):
+ """
+ Return a default Soledad Document.
+ Used in the initialization for SQLCipherDatabase
+ """
+ return Document(doc_id=doc_id, rev=rev, json=json,
+ has_conflicts=has_conflicts, syncable=syncable)
+
+
+sqlite.SQLiteDatabase.register_implementation(SQLCipherDatabase)
+
+
+#
+# twisted.enterprise.adbapi SQLCipher implementation
+#
+
+SQLCIPHER_CONNECTION_TIMEOUT = 10
+
+
+def getConnectionPool(opts, extra_queries=None):
+ openfun = partial(
+ pragmas.set_init_pragmas,
+ opts=opts,
+ extra_queries=extra_queries)
+ return SQLCipherConnectionPool(
+ database=opts.path,
+ check_same_thread=False,
+ cp_openfun=openfun,
+ timeout=SQLCIPHER_CONNECTION_TIMEOUT)
+
+
+class SQLCipherConnection(adbapi.Connection):
+ pass
+
+
+class SQLCipherTransaction(adbapi.Transaction):
+ pass
+
+
+class SQLCipherConnectionPool(adbapi.ConnectionPool):
+
+ connectionFactory = SQLCipherConnection
+ transactionFactory = SQLCipherTransaction
+
+ def __init__(self, *args, **kwargs):
+ adbapi.ConnectionPool.__init__(
+ self, "pysqlcipher.dbapi2", *args, **kwargs)
diff --git a/src/leap/soledad/client/_db/sqlite.py b/src/leap/soledad/client/_db/sqlite.py
new file mode 100644
index 00000000..4f7b1259
--- /dev/null
+++ b/src/leap/soledad/client/_db/sqlite.py
@@ -0,0 +1,930 @@
+# Copyright 2011 Canonical Ltd.
+# Copyright 2016 LEAP Encryption Access Project
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""
+A L2DB implementation that uses SQLite as its persistence layer.
+"""
+
+import errno
+import os
+import json
+import sys
+import time
+import uuid
+import pkg_resources
+
+from sqlite3 import dbapi2
+
+from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget
+from leap.soledad.common.l2db import (
+ Document, errors,
+ query_parser, vectorclock)
+
+
+class SQLiteDatabase(CommonBackend):
+ """A U1DB implementation that uses SQLite as its persistence layer."""
+
+ _sqlite_registry = {}
+
+ def __init__(self, sqlite_file, document_factory=None):
+ """Create a new sqlite file."""
+ self._db_handle = dbapi2.connect(sqlite_file)
+ self._real_replica_uid = None
+ self._ensure_schema()
+ self._factory = document_factory or Document
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def get_sync_target(self):
+ return SQLiteSyncTarget(self)
+
+ @classmethod
+ def _which_index_storage(cls, c):
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'index_storage'")
+ except dbapi2.OperationalError as e:
+ # The table does not exist yet
+ return None, e
+ else:
+ return c.fetchone()[0], None
+
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5
+
+ @classmethod
+ def _open_database(cls, sqlite_file, document_factory=None):
+ if not os.path.isfile(sqlite_file):
+ raise errors.DatabaseDoesNotExist()
+ tries = 2
+ while True:
+ # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6)
+ # where without re-opening the database on Windows, it
+ # doesn't see the transaction that was just committed
+ db_handle = dbapi2.connect(sqlite_file)
+ c = db_handle.cursor()
+ v, err = cls._which_index_storage(c)
+ db_handle.close()
+ if v is not None:
+ break
+ # possibly another process is initializing it, wait for it to be
+ # done
+ if tries == 0:
+ raise err # go for the richest error?
+ tries -= 1
+ time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL)
+ return SQLiteDatabase._sqlite_registry[v](
+ sqlite_file, document_factory=document_factory)
+
+ @classmethod
+ def open_database(cls, sqlite_file, create, backend_cls=None,
+ document_factory=None):
+ try:
+ return cls._open_database(
+ sqlite_file, document_factory=document_factory)
+ except errors.DatabaseDoesNotExist:
+ if not create:
+ raise
+ if backend_cls is None:
+ # default is SQLitePartialExpandDatabase
+ backend_cls = SQLitePartialExpandDatabase
+ return backend_cls(sqlite_file, document_factory=document_factory)
+
+ @staticmethod
+ def delete_database(sqlite_file):
+ try:
+ os.unlink(sqlite_file)
+ except OSError as ex:
+ if ex.errno == errno.ENOENT:
+ raise errors.DatabaseDoesNotExist()
+ raise
+
+ @staticmethod
+ def register_implementation(klass):
+ """Register that we implement an SQLiteDatabase.
+
+ The attribute _index_storage_value will be used as the lookup key.
+ """
+ SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass
+
+ def _get_sqlite_handle(self):
+ """Get access to the underlying sqlite database.
+
+ This should only be used by the test suite, etc, for examining the
+ state of the underlying database.
+ """
+ return self._db_handle
+
+ def _close_sqlite_handle(self):
+ """Release access to the underlying sqlite database."""
+ self._db_handle.close()
+
+ def close(self):
+ self._close_sqlite_handle()
+
+ def _is_initialized(self, c):
+ """Check if this database has been initialized."""
+ c.execute("PRAGMA case_sensitive_like=ON")
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'sql_schema'")
+ except dbapi2.OperationalError:
+ # The table does not exist yet
+ val = None
+ else:
+ val = c.fetchone()
+ if val is not None:
+ return True
+ return False
+
+ def _initialize(self, c):
+ """Create the schema in the database."""
+ # read the script with sql commands
+ # TODO: Change how we set up the dependency. Most likely use something
+ # like lp:dirspec to grab the file from a common resource
+ # directory. Doesn't specifically need to be handled until we get
+ # to the point of packaging this.
+ schema_content = pkg_resources.resource_string(
+ __name__, 'dbschema.sql')
+ # Note: We'd like to use c.executescript() here, but it seems that
+ # executescript always commits, even if you set
+ # isolation_level = None, so if we want to properly handle
+ # exclusive locking and rollbacks between processes, we need
+ # to execute it line-by-line
+ for line in schema_content.split(';'):
+ if not line:
+ continue
+ c.execute(line)
+ # add extra fields
+ self._extra_schema_init(c)
+ # A unique identifier should be set for this replica. Implementations
+ # don't have to strictly use uuid here, but we do want the uid to be
+ # unique amongst all databases that will sync with each other.
+ # We might extend this to using something with hostname for easier
+ # debugging.
+ self._set_replica_uid_in_transaction(uuid.uuid4().hex)
+ c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)",
+ (self._index_storage_value,))
+
+ def _ensure_schema(self):
+ """Ensure that the database schema has been created."""
+ old_isolation_level = self._db_handle.isolation_level
+ c = self._db_handle.cursor()
+ if self._is_initialized(c):
+ return
+ try:
+ # autocommit/own mgmt of transactions
+ self._db_handle.isolation_level = None
+ with self._db_handle:
+ # only one execution path should initialize the db
+ c.execute("begin exclusive")
+ if self._is_initialized(c):
+ return
+ self._initialize(c)
+ finally:
+ self._db_handle.isolation_level = old_isolation_level
+
+ def _extra_schema_init(self, c):
+ """Add any extra fields, etc to the basic table definitions."""
+
+ def _parse_index_definition(self, index_field):
+ """Parse a field definition for an index, returning a Getter."""
+ # Note: We may want to keep a Parser object around, and cache the
+ # Getter objects for a greater length of time. Specifically, if
+ # you create a bunch of indexes, and then insert 50k docs, you'll
+ # re-parse the indexes between puts. The time to insert the docs
+ # is still likely to dominate put_doc time, though.
+ parser = query_parser.Parser()
+ getter = parser.parse(index_field)
+ return getter
+
+ def _update_indexes(self, doc_id, raw_doc, getters, db_cursor):
+ """Update document_fields for a single document.
+
+ :param doc_id: Identifier for this document
+ :param raw_doc: The python dict representation of the document.
+ :param getters: A list of [(field_name, Getter)]. Getter.get will be
+ called to evaluate the index definition for this document, and the
+ results will be inserted into the db.
+ :param db_cursor: An sqlite Cursor.
+ :return: None
+ """
+ values = []
+ for field_name, getter in getters:
+ for idx_value in getter.get(raw_doc):
+ values.append((doc_id, field_name, idx_value))
+ if values:
+ db_cursor.executemany(
+ "INSERT INTO document_fields VALUES (?, ?, ?)", values)
+
+ def _set_replica_uid(self, replica_uid):
+ """Force the replica_uid to be set."""
+ with self._db_handle:
+ self._set_replica_uid_in_transaction(replica_uid)
+
+ def _set_replica_uid_in_transaction(self, replica_uid):
+ """Set the replica_uid. A transaction should already be held."""
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO u1db_config"
+ " VALUES ('replica_uid', ?)",
+ (replica_uid,))
+ self._real_replica_uid = replica_uid
+
+ def _get_replica_uid(self):
+ if self._real_replica_uid is not None:
+ return self._real_replica_uid
+ c = self._db_handle.cursor()
+ c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'")
+ val = c.fetchone()
+ if val is None:
+ return None
+ self._real_replica_uid = val[0]
+ return self._real_replica_uid
+
+ _replica_uid = property(_get_replica_uid)
+
+ def _get_generation(self):
+ c = self._db_handle.cursor()
+ c.execute('SELECT max(generation) FROM transaction_log')
+ val = c.fetchone()[0]
+ if val is None:
+ return 0
+ return val
+
+ def _get_generation_info(self):
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT max(generation), transaction_id FROM transaction_log ')
+ val = c.fetchone()
+ if val[0] is None:
+ return(0, '')
+ return val
+
+ def _get_trans_id_for_gen(self, generation):
+ if generation == 0:
+ return ''
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT transaction_id FROM transaction_log WHERE generation = ?',
+ (generation,))
+ val = c.fetchone()
+ if val is None:
+ raise errors.InvalidGeneration
+ return val[0]
+
+ def _get_transaction_log(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, transaction_id FROM transaction_log"
+ " ORDER BY generation")
+ return c.fetchall()
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """Get just the document content, without fancy handling."""
+ c = self._db_handle.cursor()
+ if check_for_conflicts:
+ c.execute(
+ "SELECT document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN "
+ "conflicts ON conflicts.doc_id = document.doc_id WHERE "
+ "document.doc_id = ? GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;", (doc_id,))
+ else:
+ c.execute(
+ "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return None
+ doc_rev, content, conflicts = val
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ return doc
+
+ def _has_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return False
+ else:
+ return True
+
+ def get_doc(self, doc_id, include_deleted=False):
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc is None:
+ return None
+ if doc.is_tombstone() and not include_deleted:
+ return None
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Get all documents from the database."""
+ generation = self._get_generation()
+ results = []
+ c = self._db_handle.cursor()
+ c.execute(
+ "SELECT document.doc_id, document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts "
+ "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;")
+ rows = c.fetchall()
+ for doc_id, doc_rev, content, conflicts in rows:
+ if content is None and not include_deleted:
+ continue
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ results.append(doc)
+ return (generation, results)
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ self._check_doc_id(doc.doc_id)
+ self._check_doc_size(doc)
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc and old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ if old_doc and doc.rev is None and old_doc.is_tombstone():
+ new_rev = self._allocate_doc_rev(old_doc.rev)
+ else:
+ if old_doc is not None:
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ else:
+ if doc.rev is not None:
+ raise errors.RevisionConflict()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none):
+ """Convert a dict representation into named fields.
+
+ So something like: {'key1': 'val1', 'key2': 'val2'}
+ gets converted into: [(doc_id, 'key1', 'val1', 0)
+ (doc_id, 'key2', 'val2', 0)]
+ :param doc_id: Just added to every record.
+ :param base_field: if set, these are nested keys, so each field should
+ be appropriately prefixed.
+ :param raw_doc: The python dictionary.
+ """
+ # TODO: Handle lists
+ values = []
+ for field_name, value in raw_doc.iteritems():
+ if value is None and not save_none:
+ continue
+ if base_field:
+ full_name = base_field + '.' + field_name
+ else:
+ full_name = field_name
+ if value is None or isinstance(value, (int, float, basestring)):
+ values.append((doc_id, full_name, value, len(values)))
+ else:
+ subvalues = self._expand_to_fields(doc_id, full_name, value,
+ save_none)
+ for _, subfield_name, val, _ in subvalues:
+ values.append((doc_id, subfield_name, val, len(values)))
+ return values
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ """Actually insert a document into the database.
+
+ This both updates the existing documents content, and any indexes that
+ refer to this document.
+ """
+ raise NotImplementedError(self._put_and_update_indexes)
+
+ def whats_changed(self, old_generation=0):
+ c = self._db_handle.cursor()
+ c.execute("SELECT generation, doc_id, transaction_id"
+ " FROM transaction_log"
+ " WHERE generation > ? ORDER BY generation DESC",
+ (old_generation,))
+ results = c.fetchall()
+ cur_gen = old_generation
+ seen = set()
+ changes = []
+ newest_trans_id = ''
+ for generation, doc_id, trans_id in results:
+ if doc_id not in seen:
+ changes.append((doc_id, generation, trans_id))
+ seen.add(doc_id)
+ if changes:
+ cur_gen = changes[0][1] # max generation
+ newest_trans_id = changes[0][2]
+ changes.reverse()
+ else:
+ c.execute("SELECT generation, transaction_id"
+ " FROM transaction_log ORDER BY generation DESC LIMIT 1")
+ results = c.fetchone()
+ if not results:
+ cur_gen = 0
+ newest_trans_id = ''
+ else:
+ cur_gen, newest_trans_id = results
+
+ return cur_gen, newest_trans_id, changes
+
+ def delete_doc(self, doc):
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc is None:
+ raise errors.DocumentDoesNotExist
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ if old_doc.is_tombstone():
+ raise errors.DocumentAlreadyDeleted
+ if old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ doc.make_tombstone()
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _get_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?",
+ (doc_id,))
+ return [self._factory(doc_id, doc_rev, content)
+ for doc_rev, content in c.fetchall()]
+
+ def get_doc_conflicts(self, doc_id):
+ with self._db_handle:
+ conflict_docs = self._get_conflicts(doc_id)
+ if not conflict_docs:
+ return []
+ this_doc = self._get_doc(doc_id)
+ this_doc.has_conflicts = True
+ return [this_doc] + conflict_docs
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ c = self._db_handle.cursor()
+ c.execute("SELECT known_generation, known_transaction_id FROM sync_log"
+ " WHERE replica_uid = ?",
+ (other_replica_uid,))
+ val = c.fetchone()
+ if val is None:
+ other_gen = 0
+ trans_id = ''
+ else:
+ other_gen = val[0]
+ trans_id = val[1]
+ return other_gen, trans_id
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ with self._db_handle:
+ self._do_set_replica_gen_and_trans_id(
+ other_replica_uid, other_generation, other_transaction_id)
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)",
+ (other_replica_uid, other_generation,
+ other_transaction_id))
+
+ def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None,
+ replica_gen=None, replica_trans_id=None):
+ return super(SQLiteDatabase, self)._put_doc_if_newer(
+ doc,
+ save_conflict=save_conflict,
+ replica_uid=replica_uid, replica_gen=replica_gen,
+ replica_trans_id=replica_trans_id)
+
+ def _add_conflict(self, c, doc_id, my_doc_rev, my_content):
+ c.execute("INSERT INTO conflicts VALUES (?, ?, ?)",
+ (doc_id, my_doc_rev, my_content))
+
+ def _delete_conflicts(self, c, doc, conflict_revs):
+ deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs]
+ c.executemany("DELETE FROM conflicts"
+ " WHERE doc_id=? AND doc_rev=?", deleting)
+ doc.has_conflicts = self._has_conflicts(doc.doc_id)
+
+ def _prune_conflicts(self, doc, doc_vcr):
+ if self._has_conflicts(doc.doc_id):
+ autoresolved = False
+ c_revs_to_prune = []
+ for c_doc in self._get_conflicts(doc.doc_id):
+ c_vcr = vectorclock.VectorClockRev(c_doc.rev)
+ if doc_vcr.is_newer(c_vcr):
+ c_revs_to_prune.append(c_doc.rev)
+ elif doc.same_content_as(c_doc):
+ c_revs_to_prune.append(c_doc.rev)
+ doc_vcr.maximize(c_vcr)
+ autoresolved = True
+ if autoresolved:
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ c = self._db_handle.cursor()
+ self._delete_conflicts(c, doc, c_revs_to_prune)
+
+ def _force_doc_sync_conflict(self, doc):
+ my_doc = self._get_doc(doc.doc_id)
+ c = self._db_handle.cursor()
+ self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
+ self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json())
+ doc.has_conflicts = True
+ self._put_and_update_indexes(my_doc, doc)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ with self._db_handle:
+ cur_doc = self._get_doc(doc.doc_id)
+ # TODO: https://bugs.launchpad.net/u1db/+bug/928274
+ # I think we have a logic bug in resolve_doc
+ # Specifically, cur_doc.rev is always in the final vector
+ # clock of revisions that we supersede, even if it wasn't in
+ # conflicted_doc_revs. We still add it as a conflict, but the
+ # fact that _put_doc_if_newer propagates resolutions means I
+ # think that conflict could accidentally be resolved. We need
+ # to add a test for this case first. (create a rev, create a
+ # conflict, create another conflict, resolve the first rev
+ # and first conflict, then make sure that the resolved
+ # rev doesn't supersede the second conflict rev.) It *might*
+ # not matter, because the superseding rev is in as a
+ # conflict, but it does seem incorrect
+ new_rev = self._ensure_maximal_rev(cur_doc.rev,
+ conflicted_doc_revs)
+ superseded_revs = set(conflicted_doc_revs)
+ c = self._db_handle.cursor()
+ doc.rev = new_rev
+ if cur_doc.rev in superseded_revs:
+ self._put_and_update_indexes(cur_doc, doc)
+ else:
+ self._add_conflict(c, doc.doc_id, new_rev, doc.get_json())
+ # TODO: Is there some way that we could construct a rev that would
+ # end up in superseded_revs, such that we add a conflict, and
+ # then immediately delete it?
+ self._delete_conflicts(c, doc, superseded_revs)
+
+ def list_indexes(self):
+ """Return the list of indexes and their definitions."""
+ c = self._db_handle.cursor()
+ # TODO: How do we test the ordering?
+ c.execute("SELECT name, field FROM index_definitions"
+ " ORDER BY name, offset")
+ definitions = []
+ cur_name = None
+ for name, field in c.fetchall():
+ if cur_name != name:
+ definitions.append((name, []))
+ cur_name = name
+ definitions[-1][-1].append(field)
+ return definitions
+
+ def _get_index_definition(self, index_name):
+ """Return the stored definition for a given index_name."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions"
+ " WHERE name = ? ORDER BY offset", (index_name,))
+ fields = [x[0] for x in c.fetchall()]
+ if not fields:
+ raise errors.IndexDoesNotExist
+ return fields
+
+ @staticmethod
+ def _strip_glob(value):
+ """Remove the trailing * from a value."""
+ assert value[-1] == '*'
+ return value[:-1]
+
+ def _format_query(self, definition, key_values):
+ # First, build the definition. We join the document_fields table
+ # against itself, as many times as the 'width' of our definition.
+ # We then do a query for each key_value, one-at-a-time.
+ # Note: All of these strings are static, we could cache them, etc.
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = ["d.doc_id = d%d.doc_id"
+ " AND d%d.field_name = ?"
+ % (i, i) for i in range(len(definition))]
+ wildcard_where = [novalue_where[i] +
+ (" AND d%d.value NOT NULL" % (i,))
+ for i in range(len(definition))]
+ exact_where = [novalue_where[i] +
+ (" AND d%d.value = ?" % (i,))
+ for i in range(len(definition))]
+ like_where = [novalue_where[i] +
+ (" AND d%d.value GLOB ?" % (i,))
+ for i in range(len(definition))]
+ is_wildcard = False
+ # Merge the lists together, so that:
+ # [field1, field2, field3], [val1, val2, val3]
+ # Becomes:
+ # (field1, val1, field2, val2, field3, val3)
+ args = []
+ where = []
+ for idx, (field, value) in enumerate(zip(definition, key_values)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(exact_where[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_from_index(self, index_name, *key_values):
+ definition = self._get_index_definition(index_name)
+ if len(key_values) != len(definition):
+ raise errors.InvalidValueForIndex()
+ statement, args = self._format_query(definition, key_values)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError as e:
+ raise dbapi2.OperationalError(
+ str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def _format_range_query(self, definition, start_value, end_value):
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ wildcard_where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ like_where = [
+ novalue_where[i] + (
+ " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in
+ range(len(definition))]
+ range_where_lower = [
+ novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in
+ range(len(definition))]
+ range_where_upper = [
+ novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in
+ range(len(definition))]
+ args = []
+ where = []
+ if start_value:
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ if len(start_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, start_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(self._strip_glob(value))
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(value)
+ if end_value:
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ if len(end_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, end_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(self._strip_glob(value))
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_upper[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ """Return all documents with key values in the specified range."""
+ definition = self._get_index_definition(index_name)
+ statement, args = self._format_range_query(
+ definition, start_value, end_value)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError as e:
+ raise dbapi2.OperationalError(
+ str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def get_index_keys(self, index_name):
+ c = self._db_handle.cursor()
+ definition = self._get_index_definition(index_name)
+ value_fields = ', '.join([
+ 'd%d.value' % i for i in range(len(definition))])
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ statement = (
+ "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % (
+ value_fields, ', '.join(tables), ' AND '.join(where),
+ value_fields))
+ try:
+ c.execute(statement, tuple(definition))
+ except dbapi2.OperationalError as e:
+ raise dbapi2.OperationalError(
+ str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition)))
+ return c.fetchall()
+
+ def delete_index(self, index_name):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ c.execute("DELETE FROM index_definitions WHERE name = ?",
+ (index_name,))
+ c.execute(
+ "DELETE FROM document_fields WHERE document_fields.field_name "
+ " NOT IN (SELECT field from index_definitions)")
+
+
+class SQLiteSyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_replica_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_replica_transaction_id)
+
+
+class SQLitePartialExpandDatabase(SQLiteDatabase):
+ """An SQLite Backend that expands documents into a document_field table.
+
+ It stores the original document text in document.doc. For fields that are
+ indexed, the data goes into document_fields.
+ """
+
+ _index_storage_value = 'expand referenced'
+
+ def _get_indexed_fields(self):
+ """Determine what fields are indexed."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions")
+ return set([x[0] for x in c.fetchall()])
+
+ def _evaluate_index(self, raw_doc, field):
+ parser = query_parser.Parser()
+ getter = parser.parse(field)
+ return getter.get(raw_doc)
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ c = self._db_handle.cursor()
+ if doc and not doc.is_tombstone():
+ raw_doc = json.loads(doc.get_json())
+ else:
+ raw_doc = {}
+ if old_doc is not None:
+ c.execute("UPDATE document SET doc_rev=?, content=?"
+ " WHERE doc_id = ?",
+ (doc.rev, doc.get_json(), doc.doc_id))
+ c.execute("DELETE FROM document_fields WHERE doc_id = ?",
+ (doc.doc_id,))
+ else:
+ c.execute("INSERT INTO document (doc_id, doc_rev, content)"
+ " VALUES (?, ?, ?)",
+ (doc.doc_id, doc.rev, doc.get_json()))
+ indexed_fields = self._get_indexed_fields()
+ if indexed_fields:
+ # It is expected that len(indexed_fields) is shorter than
+ # len(raw_doc)
+ getters = [(field, self._parse_index_definition(field))
+ for field in indexed_fields]
+ self._update_indexes(doc.doc_id, raw_doc, getters, c)
+ trans_id = self._allocate_transaction_id()
+ c.execute("INSERT INTO transaction_log(doc_id, transaction_id)"
+ " VALUES (?, ?)", (doc.doc_id, trans_id))
+
+ def create_index(self, index_name, *index_expressions):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ cur_fields = self._get_indexed_fields()
+ definition = [(index_name, idx, field)
+ for idx, field in enumerate(index_expressions)]
+ try:
+ c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
+ definition)
+ except dbapi2.IntegrityError as e:
+ stored_def = self._get_index_definition(index_name)
+ if stored_def == [x[-1] for x in definition]:
+ return
+ raise errors.IndexNameTakenError(
+ str(e) +
+ str(sys.exc_info()[2])
+ )
+ new_fields = set(
+ [f for f in index_expressions if f not in cur_fields])
+ if new_fields:
+ self._update_all_indexes(new_fields)
+
+ def _iter_all_docs(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, content FROM document")
+ while True:
+ next_rows = c.fetchmany()
+ if not next_rows:
+ break
+ for row in next_rows:
+ yield row
+
+ def _update_all_indexes(self, new_fields):
+ """Iterate all the documents, and add content to document_fields.
+
+ :param new_fields: The index definitions that need to be added.
+ """
+ getters = [(field, self._parse_index_definition(field))
+ for field in new_fields]
+ c = self._db_handle.cursor()
+ for doc_id, doc in self._iter_all_docs():
+ if doc is None:
+ continue
+ raw_doc = json.loads(doc)
+ self._update_indexes(doc_id, raw_doc, getters, c)
+
+
+SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase)