Diffstat (limited to 'client/src/leap')
-rw-r--r--  client/src/leap/soledad/client/__init__.py  6
-rw-r--r--  client/src/leap/soledad/client/_crypto.py  6
-rw-r--r--  client/src/leap/soledad/client/_database/__init__.py  0
-rw-r--r--  client/src/leap/soledad/client/_database/adbapi.py (renamed from client/src/leap/soledad/client/adbapi.py)  9
-rw-r--r--  client/src/leap/soledad/client/_database/blobs.py (renamed from client/src/leap/soledad/client/_blobs.py)  57
-rw-r--r--  client/src/leap/soledad/client/_database/dbschema.sql  42
-rw-r--r--  client/src/leap/soledad/client/_database/pragmas.py (renamed from client/src/leap/soledad/client/pragmas.py)  0
-rw-r--r--  client/src/leap/soledad/client/_database/sqlcipher.py (renamed from client/src/leap/soledad/client/sqlcipher.py)  38
-rw-r--r--  client/src/leap/soledad/client/_database/sqlite.py  930
-rw-r--r--  client/src/leap/soledad/client/_document.py  253
-rw-r--r--  client/src/leap/soledad/client/_secrets/storage.py  4
-rw-r--r--  client/src/leap/soledad/client/api.py  36
-rw-r--r--  client/src/leap/soledad/client/crypto.py  8
-rw-r--r--  client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py  2
-rw-r--r--  client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py  2
-rw-r--r--  client/src/leap/soledad/client/examples/use_adbapi.py  2
-rw-r--r--  client/src/leap/soledad/client/http_target/fetch.py  4
-rw-r--r--  client/src/leap/soledad/client/interfaces.py  12
18 files changed, 1330 insertions, 81 deletions
diff --git a/client/src/leap/soledad/client/__init__.py b/client/src/leap/soledad/client/__init__.py
index 3a114021..bcad78db 100644
--- a/client/src/leap/soledad/client/__init__.py
+++ b/client/src/leap/soledad/client/__init__.py
@@ -17,12 +17,14 @@
"""
Soledad - Synchronization Of Locally Encrypted Data Among Devices.
"""
-from leap.soledad.client.api import Soledad
from leap.soledad.common import soledad_assert
+from .api import Soledad
+from ._document import Document, AttachmentStates
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
-__all__ = ['soledad_assert', 'Soledad', '__version__']
+__all__ = ['soledad_assert', 'Soledad', 'Document', 'AttachmentStates',
+ '__version__']
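The package root now re-exports the client-side Document class and the AttachmentStates enum alongside Soledad itself. A minimal usage sketch of the new public names (content values below are illustrative):

from leap.soledad.client import AttachmentStates, Document

# Document can be built standalone; a store (a Soledad instance) can be
# bound later with set_store() so the attachment API can reach a BlobManager.
doc = Document(doc_id='example-doc', json='{"subject": "hello"}')
print('%s %d' % (doc.doc_id, AttachmentStates.NONE))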
diff --git a/client/src/leap/soledad/client/_crypto.py b/client/src/leap/soledad/client/_crypto.py
index 7a1d508e..e66cc600 100644
--- a/client/src/leap/soledad/client/_crypto.py
+++ b/client/src/leap/soledad/client/_crypto.py
@@ -135,7 +135,7 @@ class SoledadCrypto(object):
and wrapping the result as a simple JSON string with a "raw" key.
:param doc: the document to be encrypted.
- :type doc: SoledadDocument
+ :type doc: Document
:return: A deferred whose callback will be invoked with a JSON string
containing the ciphertext as the value of "raw" key.
:rtype: twisted.internet.defer.Deferred
@@ -159,7 +159,7 @@ class SoledadCrypto(object):
the decrypted cleartext content from the encrypted document.
:param doc: the document to be decrypted.
- :type doc: SoledadDocument
+ :type doc: Document
:return: The decrypted cleartext content of the document.
:rtype: str
"""
@@ -225,7 +225,7 @@ def decrypt_sym(data, key, iv, method=ENC_METHOD.aes_256_gcm):
class BlobEncryptor(object):
"""
Produces encrypted data from the cleartext data associated with a given
- SoledadDocument using AES-256 cipher in GCM mode.
+ Document using AES-256 cipher in GCM mode.
The production happens using a Twisted's FileBodyProducer, which uses a
Cooperator to schedule calls and can be paused/resumed. Each call takes at
diff --git a/client/src/leap/soledad/client/_database/__init__.py b/client/src/leap/soledad/client/_database/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/client/src/leap/soledad/client/_database/__init__.py
diff --git a/client/src/leap/soledad/client/adbapi.py b/client/src/leap/soledad/client/_database/adbapi.py
index b002055e..5c28d108 100644
--- a/client/src/leap/soledad/client/adbapi.py
+++ b/client/src/leap/soledad/client/_database/adbapi.py
@@ -30,8 +30,9 @@ from zope.proxy import ProxyBase, setProxiedObject
from leap.soledad.common.log import getLogger
from leap.soledad.common.errors import DatabaseAccessError
-from leap.soledad.client import sqlcipher as soledad_sqlcipher
-from leap.soledad.client.pragmas import set_init_pragmas
+
+from . import sqlcipher
+from . import pragmas
if sys.version_info[0] < 3:
from pysqlcipher import dbapi2
@@ -73,7 +74,7 @@ def getConnectionPool(opts, openfun=None, driver="pysqlcipher"):
:rtype: U1DBConnectionPool
"""
if openfun is None and driver == "pysqlcipher":
- openfun = partial(set_init_pragmas, opts=opts)
+ openfun = partial(pragmas.set_init_pragmas, opts=opts)
return U1DBConnectionPool(
opts,
# the following params are relayed "as is" to twisted's
@@ -87,7 +88,7 @@ class U1DBConnection(adbapi.Connection):
A wrapper for a U1DB connection instance.
"""
- u1db_wrapper = soledad_sqlcipher.SoledadSQLCipherWrapper
+ u1db_wrapper = sqlcipher.SoledadSQLCipherWrapper
"""
The U1DB wrapper to use.
"""
diff --git a/client/src/leap/soledad/client/_blobs.py b/client/src/leap/soledad/client/_database/blobs.py
index 1475b302..79404bf3 100644
--- a/client/src/leap/soledad/client/_blobs.py
+++ b/client/src/leap/soledad/client/_database/blobs.py
@@ -20,8 +20,9 @@ Clientside BlobBackend Storage.
from urlparse import urljoin
+import binascii
+import errno
import os
-import uuid
import base64
from io import BytesIO
@@ -34,13 +35,18 @@ from twisted.web.client import FileBodyProducer
import treq
-from leap.soledad.client.sqlcipher import SQLCipherOptions
-from leap.soledad.client import pragmas
-from leap.soledad.client._pipes import TruncatedTailPipe, PreamblePipe
from leap.soledad.common.errors import SoledadError
-from _crypto import DocInfo, BlobEncryptor, BlobDecryptor
-from _http import HTTPClient
+from .._document import BlobDoc
+from .._crypto import DocInfo
+from .._crypto import BlobEncryptor
+from .._crypto import BlobDecryptor
+from .._http import HTTPClient
+from .._pipes import TruncatedTailPipe
+from .._pipes import PreamblePipe
+
+from . import pragmas
+from . import sqlcipher
logger = Logger()
@@ -129,6 +135,16 @@ class DecrypterBuffer(object):
return self.decrypter._end_stream(), real_size
+def mkdir_p(path):
+ try:
+ os.makedirs(path)
+ except OSError as exc:
+ if exc.errno == errno.EEXIST and os.path.isdir(path):
+ pass
+ else:
+ raise
+
+
class BlobManager(object):
"""
Ideally, the decrypting flow goes like this:
@@ -149,6 +165,7 @@ class BlobManager(object):
self, local_path, remote, key, secret, user, token=None,
cert_file=None):
if local_path:
+ mkdir_p(os.path.dirname(local_path))
self.local = SQLiteBlobBackend(local_path, key)
self.remote = remote
self.secret = secret
@@ -280,10 +297,13 @@ class SQLiteBlobBackend(object):
def __init__(self, path, key=None):
self.path = os.path.abspath(
os.path.join(path, 'soledad_blob.db'))
+ mkdir_p(os.path.dirname(self.path))
if not key:
raise ValueError('key cannot be None')
backend = 'pysqlcipher.dbapi2'
- opts = SQLCipherOptions('/tmp/ignored', key)
+ opts = sqlcipher.SQLCipherOptions(
+ '/tmp/ignored', binascii.b2a_hex(key),
+ is_raw_key=True, create=True)
pragmafun = partial(pragmas.set_init_pragmas, opts=opts)
openfun = _sqlcipherInitFactory(pragmafun)
@@ -292,7 +312,11 @@ class SQLiteBlobBackend(object):
cp_openfun=openfun, cp_min=1, cp_max=2, cp_name='blob_pool')
def close(self):
- return self.dbpool.close()
+ from twisted._threads import AlreadyQuit
+ try:
+ self.dbpool.close()
+ except AlreadyQuit:
+ pass
@defer.inlineCallbacks
def put(self, blob_id, blob_fd, size=None):
@@ -352,20 +376,6 @@ def _sqlcipherInitFactory(fun):
return _initialize
-class BlobDoc(object):
-
- # TODO probably not needed, but convenient for testing for now.
-
- def __init__(self, content, blob_id):
-
- self.blob_id = blob_id
- self.is_blob = True
- self.blob_fd = content
- if blob_id is None:
- blob_id = uuid.uuid4().get_hex()
- self.blob_id = blob_id
-
-
#
# testing facilities
#
@@ -431,8 +441,7 @@ def testit(reactor):
# TODO convert these into proper unittests
def _manager():
- if not os.path.isdir(args.path):
- os.makedirs(args.path)
+ mkdir_p(os.path.dirname(args.path))
manager = BlobManager(
args.path, args.url,
'A' * 32, args.secret,
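Two behavioural changes stand out in blobs.py: directories are created with the new mkdir_p helper, and the SQLCipher key is now handed over as a raw hex key instead of a passphrase. A hedged sketch of the key conversion alone (the key bytes are placeholders):

import binascii

key = b'A' * 32  # placeholder 256-bit symmetric key
hex_key = binascii.b2a_hex(key)

# SQLCipherOptions receives the hex form together with is_raw_key=True, so
# SQLCipher uses it as a raw key rather than deriving one from a passphrase.
assert len(hex_key) == 64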
diff --git a/client/src/leap/soledad/client/_database/dbschema.sql b/client/src/leap/soledad/client/_database/dbschema.sql
new file mode 100644
index 00000000..ae027fc5
--- /dev/null
+++ b/client/src/leap/soledad/client/_database/dbschema.sql
@@ -0,0 +1,42 @@
+-- Database schema
+CREATE TABLE transaction_log (
+ generation INTEGER PRIMARY KEY AUTOINCREMENT,
+ doc_id TEXT NOT NULL,
+ transaction_id TEXT NOT NULL
+);
+CREATE TABLE document (
+ doc_id TEXT PRIMARY KEY,
+ doc_rev TEXT NOT NULL,
+ content TEXT
+);
+CREATE TABLE document_fields (
+ doc_id TEXT NOT NULL,
+ field_name TEXT NOT NULL,
+ value TEXT
+);
+CREATE INDEX document_fields_field_value_doc_idx
+ ON document_fields(field_name, value, doc_id);
+
+CREATE TABLE sync_log (
+ replica_uid TEXT PRIMARY KEY,
+ known_generation INTEGER,
+ known_transaction_id TEXT
+);
+CREATE TABLE conflicts (
+ doc_id TEXT,
+ doc_rev TEXT,
+ content TEXT,
+ CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev)
+);
+CREATE TABLE index_definitions (
+ name TEXT,
+ offset INT,
+ field TEXT,
+ CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset)
+);
+create index index_definitions_field on index_definitions(field);
+CREATE TABLE u1db_config (
+ name TEXT PRIMARY KEY,
+ value TEXT
+);
+INSERT INTO u1db_config VALUES ('sql_schema', '0');
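dbschema.sql ships as package data and is executed statement by statement by sqlite.py's _initialize() further down. A standalone check of the generation counter that backs whats_changed(), using an inline copy of the transaction_log table (the data is illustrative):

import sqlite3

conn = sqlite3.connect(':memory:')
c = conn.cursor()
# Inline copy of one table from dbschema.sql, to show the AUTOINCREMENT
# generation counter used by the backend below.
c.execute("CREATE TABLE transaction_log ("
          " generation INTEGER PRIMARY KEY AUTOINCREMENT,"
          " doc_id TEXT NOT NULL,"
          " transaction_id TEXT NOT NULL)")
c.execute("INSERT INTO transaction_log (doc_id, transaction_id)"
          " VALUES ('doc-1', 'T-1')")
c.execute("INSERT INTO transaction_log (doc_id, transaction_id)"
          " VALUES ('doc-1', 'T-2')")
c.execute("SELECT max(generation) FROM transaction_log")
print(c.fetchone()[0])  # 2: every write bumps the database generation
conn.close()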
diff --git a/client/src/leap/soledad/client/pragmas.py b/client/src/leap/soledad/client/_database/pragmas.py
index 870ed63e..870ed63e 100644
--- a/client/src/leap/soledad/client/pragmas.py
+++ b/client/src/leap/soledad/client/_database/pragmas.py
diff --git a/client/src/leap/soledad/client/sqlcipher.py b/client/src/leap/soledad/client/_database/sqlcipher.py
index 2c995d5a..d22017bd 100644
--- a/client/src/leap/soledad/client/sqlcipher.py
+++ b/client/src/leap/soledad/client/_database/sqlcipher.py
@@ -50,16 +50,16 @@ from twisted.internet import reactor
from twisted.internet import defer
from twisted.enterprise import adbapi
-from leap.soledad.common.document import SoledadDocument
from leap.soledad.common.log import getLogger
from leap.soledad.common.l2db import errors as u1db_errors
-from leap.soledad.common.l2db import Document
-from leap.soledad.common.l2db.backends import sqlite_backend
from leap.soledad.common.errors import DatabaseAccessError
from leap.soledad.client.http_target import SoledadHTTPSyncTarget
from leap.soledad.client.sync import SoledadSynchronizer
-from leap.soledad.client import pragmas
+
+from .._document import Document
+from . import sqlite
+from . import pragmas
if sys.version_info[0] < 3:
from pysqlcipher import dbapi2 as sqlcipher_dbapi2
@@ -69,8 +69,8 @@ else:
logger = getLogger(__name__)
-# Monkey-patch u1db.backends.sqlite_backend with pysqlcipher.dbapi2
-sqlite_backend.dbapi2 = sqlcipher_dbapi2
+# Monkey-patch u1db.backends.sqlite with pysqlcipher.dbapi2
+sqlite.dbapi2 = sqlcipher_dbapi2
# we may want to collect statistics from the sync process
@@ -193,7 +193,7 @@ class SQLCipherOptions(object):
# The SQLCipher database
#
-class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase):
+class SQLCipherDatabase(sqlite.SQLitePartialExpandDatabase):
"""
A U1DB implementation that uses SQLCipher as its persistence layer.
"""
@@ -232,7 +232,7 @@ class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase):
# ---------------------------------------------------------
self._ensure_schema()
- self.set_document_factory(soledad_doc_factory)
+ self.set_document_factory(doc_factory)
self._prime_replica_uid()
def _prime_replica_uid(self):
@@ -341,7 +341,7 @@ class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase):
:param doc: The new version of the document.
:type doc: u1db.Document
"""
- sqlite_backend.SQLitePartialExpandDatabase._put_and_update_indexes(
+ sqlite.SQLitePartialExpandDatabase._put_and_update_indexes(
self, old_doc, doc)
c = self._db_handle.cursor()
c.execute('UPDATE document SET syncable=? WHERE doc_id=?',
@@ -361,7 +361,7 @@ class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase):
:return: a Document object.
:type: u1db.Document
"""
- doc = sqlite_backend.SQLitePartialExpandDatabase._get_doc(
+ doc = sqlite.SQLitePartialExpandDatabase._get_doc(
self, doc_id, check_for_conflicts)
if doc:
c = self._db_handle.cursor()
@@ -437,7 +437,7 @@ class SQLCipherU1DBSync(SQLCipherDatabase):
self._opts, check_same_thread=False)
self._real_replica_uid = None
self._ensure_schema()
- self.set_document_factory(soledad_doc_factory)
+ self.set_document_factory(doc_factory)
except sqlcipher_dbapi2.DatabaseError as e:
raise DatabaseAccessError(str(e))
@@ -505,7 +505,7 @@ class SQLCipherU1DBSync(SQLCipherDatabase):
return self._get_generation()
-class U1DBSQLiteBackend(sqlite_backend.SQLitePartialExpandDatabase):
+class U1DBSQLiteBackend(sqlite.SQLitePartialExpandDatabase):
"""
A very simple wrapper for u1db around sqlcipher backend.
@@ -537,7 +537,7 @@ class SoledadSQLCipherWrapper(SQLCipherDatabase):
self._db_handle = conn
self._real_replica_uid = None
self._ensure_schema()
- self.set_document_factory(soledad_doc_factory)
+ self.set_document_factory(doc_factory)
self._prime_replica_uid()
@@ -562,7 +562,7 @@ def _assert_db_is_encrypted(opts):
# If the regular backend succeeds, then we need to stop because
# the database was not properly initialized.
try:
- sqlite_backend.SQLitePartialExpandDatabase(opts.path)
+ sqlite.SQLitePartialExpandDatabase(opts.path)
except sqlcipher_dbapi2.DatabaseError:
# assert that we can access it using SQLCipher with the given
# key
@@ -583,17 +583,17 @@ class DatabaseIsNotEncrypted(Exception):
pass
-def soledad_doc_factory(doc_id=None, rev=None, json='{}', has_conflicts=False,
- syncable=True):
+def doc_factory(doc_id=None, rev=None, json='{}', has_conflicts=False,
+ syncable=True):
"""
Return a default Soledad Document.
Used in the initialization for SQLCipherDatabase
"""
- return SoledadDocument(doc_id=doc_id, rev=rev, json=json,
- has_conflicts=has_conflicts, syncable=syncable)
+ return Document(doc_id=doc_id, rev=rev, json=json,
+ has_conflicts=has_conflicts, syncable=syncable)
-sqlite_backend.SQLiteDatabase.register_implementation(SQLCipherDatabase)
+sqlite.SQLiteDatabase.register_implementation(SQLCipherDatabase)
#
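sqlcipher.py now monkey-patches the in-tree sqlite module instead of l2db's sqlite_backend, so every connection opened by the vendored backend goes through pysqlcipher. A hedged sketch of that pattern (the Python 3 package name is an assumption; the commit only shows the Python 2 branch):

import sys

from leap.soledad.client._database import sqlite

if sys.version_info[0] < 3:
    from pysqlcipher import dbapi2 as sqlcipher_dbapi2
else:
    # assumed Python 3 import path
    from pysqlcipher3 import dbapi2 as sqlcipher_dbapi2

# After this assignment, sqlite.SQLiteDatabase and its subclasses connect
# through SQLCipher's DB-API module.
sqlite.dbapi2 = sqlcipher_dbapi2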
diff --git a/client/src/leap/soledad/client/_database/sqlite.py b/client/src/leap/soledad/client/_database/sqlite.py
new file mode 100644
index 00000000..4f7b1259
--- /dev/null
+++ b/client/src/leap/soledad/client/_database/sqlite.py
@@ -0,0 +1,930 @@
+# Copyright 2011 Canonical Ltd.
+# Copyright 2016 LEAP Encryption Access Project
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""
+A L2DB implementation that uses SQLite as its persistence layer.
+"""
+
+import errno
+import os
+import json
+import sys
+import time
+import uuid
+import pkg_resources
+
+from sqlite3 import dbapi2
+
+from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget
+from leap.soledad.common.l2db import (
+ Document, errors,
+ query_parser, vectorclock)
+
+
+class SQLiteDatabase(CommonBackend):
+ """A U1DB implementation that uses SQLite as its persistence layer."""
+
+ _sqlite_registry = {}
+
+ def __init__(self, sqlite_file, document_factory=None):
+ """Create a new sqlite file."""
+ self._db_handle = dbapi2.connect(sqlite_file)
+ self._real_replica_uid = None
+ self._ensure_schema()
+ self._factory = document_factory or Document
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def get_sync_target(self):
+ return SQLiteSyncTarget(self)
+
+ @classmethod
+ def _which_index_storage(cls, c):
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'index_storage'")
+ except dbapi2.OperationalError as e:
+ # The table does not exist yet
+ return None, e
+ else:
+ return c.fetchone()[0], None
+
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5
+
+ @classmethod
+ def _open_database(cls, sqlite_file, document_factory=None):
+ if not os.path.isfile(sqlite_file):
+ raise errors.DatabaseDoesNotExist()
+ tries = 2
+ while True:
+ # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6)
+ # where without re-opening the database on Windows, it
+ # doesn't see the transaction that was just committed
+ db_handle = dbapi2.connect(sqlite_file)
+ c = db_handle.cursor()
+ v, err = cls._which_index_storage(c)
+ db_handle.close()
+ if v is not None:
+ break
+ # possibly another process is initializing it, wait for it to be
+ # done
+ if tries == 0:
+ raise err # go for the richest error?
+ tries -= 1
+ time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL)
+ return SQLiteDatabase._sqlite_registry[v](
+ sqlite_file, document_factory=document_factory)
+
+ @classmethod
+ def open_database(cls, sqlite_file, create, backend_cls=None,
+ document_factory=None):
+ try:
+ return cls._open_database(
+ sqlite_file, document_factory=document_factory)
+ except errors.DatabaseDoesNotExist:
+ if not create:
+ raise
+ if backend_cls is None:
+ # default is SQLitePartialExpandDatabase
+ backend_cls = SQLitePartialExpandDatabase
+ return backend_cls(sqlite_file, document_factory=document_factory)
+
+ @staticmethod
+ def delete_database(sqlite_file):
+ try:
+ os.unlink(sqlite_file)
+ except OSError as ex:
+ if ex.errno == errno.ENOENT:
+ raise errors.DatabaseDoesNotExist()
+ raise
+
+ @staticmethod
+ def register_implementation(klass):
+ """Register that we implement an SQLiteDatabase.
+
+ The attribute _index_storage_value will be used as the lookup key.
+ """
+ SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass
+
+ def _get_sqlite_handle(self):
+ """Get access to the underlying sqlite database.
+
+ This should only be used by the test suite, etc, for examining the
+ state of the underlying database.
+ """
+ return self._db_handle
+
+ def _close_sqlite_handle(self):
+ """Release access to the underlying sqlite database."""
+ self._db_handle.close()
+
+ def close(self):
+ self._close_sqlite_handle()
+
+ def _is_initialized(self, c):
+ """Check if this database has been initialized."""
+ c.execute("PRAGMA case_sensitive_like=ON")
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'sql_schema'")
+ except dbapi2.OperationalError:
+ # The table does not exist yet
+ val = None
+ else:
+ val = c.fetchone()
+ if val is not None:
+ return True
+ return False
+
+ def _initialize(self, c):
+ """Create the schema in the database."""
+ # read the script with sql commands
+ # TODO: Change how we set up the dependency. Most likely use something
+ # like lp:dirspec to grab the file from a common resource
+ # directory. Doesn't specifically need to be handled until we get
+ # to the point of packaging this.
+ schema_content = pkg_resources.resource_string(
+ __name__, 'dbschema.sql')
+ # Note: We'd like to use c.executescript() here, but it seems that
+ # executescript always commits, even if you set
+ # isolation_level = None, so if we want to properly handle
+ # exclusive locking and rollbacks between processes, we need
+ # to execute it line-by-line
+ for line in schema_content.split(';'):
+ if not line:
+ continue
+ c.execute(line)
+ # add extra fields
+ self._extra_schema_init(c)
+ # A unique identifier should be set for this replica. Implementations
+ # don't have to strictly use uuid here, but we do want the uid to be
+ # unique amongst all databases that will sync with each other.
+ # We might extend this to using something with hostname for easier
+ # debugging.
+ self._set_replica_uid_in_transaction(uuid.uuid4().hex)
+ c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)",
+ (self._index_storage_value,))
+
+ def _ensure_schema(self):
+ """Ensure that the database schema has been created."""
+ old_isolation_level = self._db_handle.isolation_level
+ c = self._db_handle.cursor()
+ if self._is_initialized(c):
+ return
+ try:
+ # autocommit/own mgmt of transactions
+ self._db_handle.isolation_level = None
+ with self._db_handle:
+ # only one execution path should initialize the db
+ c.execute("begin exclusive")
+ if self._is_initialized(c):
+ return
+ self._initialize(c)
+ finally:
+ self._db_handle.isolation_level = old_isolation_level
+
+ def _extra_schema_init(self, c):
+ """Add any extra fields, etc to the basic table definitions."""
+
+ def _parse_index_definition(self, index_field):
+ """Parse a field definition for an index, returning a Getter."""
+ # Note: We may want to keep a Parser object around, and cache the
+ # Getter objects for a greater length of time. Specifically, if
+ # you create a bunch of indexes, and then insert 50k docs, you'll
+ # re-parse the indexes between puts. The time to insert the docs
+ # is still likely to dominate put_doc time, though.
+ parser = query_parser.Parser()
+ getter = parser.parse(index_field)
+ return getter
+
+ def _update_indexes(self, doc_id, raw_doc, getters, db_cursor):
+ """Update document_fields for a single document.
+
+ :param doc_id: Identifier for this document
+ :param raw_doc: The python dict representation of the document.
+ :param getters: A list of [(field_name, Getter)]. Getter.get will be
+ called to evaluate the index definition for this document, and the
+ results will be inserted into the db.
+ :param db_cursor: An sqlite Cursor.
+ :return: None
+ """
+ values = []
+ for field_name, getter in getters:
+ for idx_value in getter.get(raw_doc):
+ values.append((doc_id, field_name, idx_value))
+ if values:
+ db_cursor.executemany(
+ "INSERT INTO document_fields VALUES (?, ?, ?)", values)
+
+ def _set_replica_uid(self, replica_uid):
+ """Force the replica_uid to be set."""
+ with self._db_handle:
+ self._set_replica_uid_in_transaction(replica_uid)
+
+ def _set_replica_uid_in_transaction(self, replica_uid):
+ """Set the replica_uid. A transaction should already be held."""
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO u1db_config"
+ " VALUES ('replica_uid', ?)",
+ (replica_uid,))
+ self._real_replica_uid = replica_uid
+
+ def _get_replica_uid(self):
+ if self._real_replica_uid is not None:
+ return self._real_replica_uid
+ c = self._db_handle.cursor()
+ c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'")
+ val = c.fetchone()
+ if val is None:
+ return None
+ self._real_replica_uid = val[0]
+ return self._real_replica_uid
+
+ _replica_uid = property(_get_replica_uid)
+
+ def _get_generation(self):
+ c = self._db_handle.cursor()
+ c.execute('SELECT max(generation) FROM transaction_log')
+ val = c.fetchone()[0]
+ if val is None:
+ return 0
+ return val
+
+ def _get_generation_info(self):
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT max(generation), transaction_id FROM transaction_log ')
+ val = c.fetchone()
+ if val[0] is None:
+ return(0, '')
+ return val
+
+ def _get_trans_id_for_gen(self, generation):
+ if generation == 0:
+ return ''
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT transaction_id FROM transaction_log WHERE generation = ?',
+ (generation,))
+ val = c.fetchone()
+ if val is None:
+ raise errors.InvalidGeneration
+ return val[0]
+
+ def _get_transaction_log(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, transaction_id FROM transaction_log"
+ " ORDER BY generation")
+ return c.fetchall()
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """Get just the document content, without fancy handling."""
+ c = self._db_handle.cursor()
+ if check_for_conflicts:
+ c.execute(
+ "SELECT document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN "
+ "conflicts ON conflicts.doc_id = document.doc_id WHERE "
+ "document.doc_id = ? GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;", (doc_id,))
+ else:
+ c.execute(
+ "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return None
+ doc_rev, content, conflicts = val
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ return doc
+
+ def _has_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return False
+ else:
+ return True
+
+ def get_doc(self, doc_id, include_deleted=False):
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc is None:
+ return None
+ if doc.is_tombstone() and not include_deleted:
+ return None
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Get all documents from the database."""
+ generation = self._get_generation()
+ results = []
+ c = self._db_handle.cursor()
+ c.execute(
+ "SELECT document.doc_id, document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts "
+ "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;")
+ rows = c.fetchall()
+ for doc_id, doc_rev, content, conflicts in rows:
+ if content is None and not include_deleted:
+ continue
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ results.append(doc)
+ return (generation, results)
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ self._check_doc_id(doc.doc_id)
+ self._check_doc_size(doc)
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc and old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ if old_doc and doc.rev is None and old_doc.is_tombstone():
+ new_rev = self._allocate_doc_rev(old_doc.rev)
+ else:
+ if old_doc is not None:
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ else:
+ if doc.rev is not None:
+ raise errors.RevisionConflict()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none):
+ """Convert a dict representation into named fields.
+
+ So something like: {'key1': 'val1', 'key2': 'val2'}
+ gets converted into: [(doc_id, 'key1', 'val1', 0)
+ (doc_id, 'key2', 'val2', 0)]
+ :param doc_id: Just added to every record.
+ :param base_field: if set, these are nested keys, so each field should
+ be appropriately prefixed.
+ :param raw_doc: The python dictionary.
+ """
+ # TODO: Handle lists
+ values = []
+ for field_name, value in raw_doc.iteritems():
+ if value is None and not save_none:
+ continue
+ if base_field:
+ full_name = base_field + '.' + field_name
+ else:
+ full_name = field_name
+ if value is None or isinstance(value, (int, float, basestring)):
+ values.append((doc_id, full_name, value, len(values)))
+ else:
+ subvalues = self._expand_to_fields(doc_id, full_name, value,
+ save_none)
+ for _, subfield_name, val, _ in subvalues:
+ values.append((doc_id, subfield_name, val, len(values)))
+ return values
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ """Actually insert a document into the database.
+
+ This both updates the existing documents content, and any indexes that
+ refer to this document.
+ """
+ raise NotImplementedError(self._put_and_update_indexes)
+
+ def whats_changed(self, old_generation=0):
+ c = self._db_handle.cursor()
+ c.execute("SELECT generation, doc_id, transaction_id"
+ " FROM transaction_log"
+ " WHERE generation > ? ORDER BY generation DESC",
+ (old_generation,))
+ results = c.fetchall()
+ cur_gen = old_generation
+ seen = set()
+ changes = []
+ newest_trans_id = ''
+ for generation, doc_id, trans_id in results:
+ if doc_id not in seen:
+ changes.append((doc_id, generation, trans_id))
+ seen.add(doc_id)
+ if changes:
+ cur_gen = changes[0][1] # max generation
+ newest_trans_id = changes[0][2]
+ changes.reverse()
+ else:
+ c.execute("SELECT generation, transaction_id"
+ " FROM transaction_log ORDER BY generation DESC LIMIT 1")
+ results = c.fetchone()
+ if not results:
+ cur_gen = 0
+ newest_trans_id = ''
+ else:
+ cur_gen, newest_trans_id = results
+
+ return cur_gen, newest_trans_id, changes
+
+ def delete_doc(self, doc):
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc is None:
+ raise errors.DocumentDoesNotExist
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ if old_doc.is_tombstone():
+ raise errors.DocumentAlreadyDeleted
+ if old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ doc.make_tombstone()
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _get_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?",
+ (doc_id,))
+ return [self._factory(doc_id, doc_rev, content)
+ for doc_rev, content in c.fetchall()]
+
+ def get_doc_conflicts(self, doc_id):
+ with self._db_handle:
+ conflict_docs = self._get_conflicts(doc_id)
+ if not conflict_docs:
+ return []
+ this_doc = self._get_doc(doc_id)
+ this_doc.has_conflicts = True
+ return [this_doc] + conflict_docs
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ c = self._db_handle.cursor()
+ c.execute("SELECT known_generation, known_transaction_id FROM sync_log"
+ " WHERE replica_uid = ?",
+ (other_replica_uid,))
+ val = c.fetchone()
+ if val is None:
+ other_gen = 0
+ trans_id = ''
+ else:
+ other_gen = val[0]
+ trans_id = val[1]
+ return other_gen, trans_id
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ with self._db_handle:
+ self._do_set_replica_gen_and_trans_id(
+ other_replica_uid, other_generation, other_transaction_id)
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)",
+ (other_replica_uid, other_generation,
+ other_transaction_id))
+
+ def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None,
+ replica_gen=None, replica_trans_id=None):
+ return super(SQLiteDatabase, self)._put_doc_if_newer(
+ doc,
+ save_conflict=save_conflict,
+ replica_uid=replica_uid, replica_gen=replica_gen,
+ replica_trans_id=replica_trans_id)
+
+ def _add_conflict(self, c, doc_id, my_doc_rev, my_content):
+ c.execute("INSERT INTO conflicts VALUES (?, ?, ?)",
+ (doc_id, my_doc_rev, my_content))
+
+ def _delete_conflicts(self, c, doc, conflict_revs):
+ deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs]
+ c.executemany("DELETE FROM conflicts"
+ " WHERE doc_id=? AND doc_rev=?", deleting)
+ doc.has_conflicts = self._has_conflicts(doc.doc_id)
+
+ def _prune_conflicts(self, doc, doc_vcr):
+ if self._has_conflicts(doc.doc_id):
+ autoresolved = False
+ c_revs_to_prune = []
+ for c_doc in self._get_conflicts(doc.doc_id):
+ c_vcr = vectorclock.VectorClockRev(c_doc.rev)
+ if doc_vcr.is_newer(c_vcr):
+ c_revs_to_prune.append(c_doc.rev)
+ elif doc.same_content_as(c_doc):
+ c_revs_to_prune.append(c_doc.rev)
+ doc_vcr.maximize(c_vcr)
+ autoresolved = True
+ if autoresolved:
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ c = self._db_handle.cursor()
+ self._delete_conflicts(c, doc, c_revs_to_prune)
+
+ def _force_doc_sync_conflict(self, doc):
+ my_doc = self._get_doc(doc.doc_id)
+ c = self._db_handle.cursor()
+ self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
+ self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json())
+ doc.has_conflicts = True
+ self._put_and_update_indexes(my_doc, doc)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ with self._db_handle:
+ cur_doc = self._get_doc(doc.doc_id)
+ # TODO: https://bugs.launchpad.net/u1db/+bug/928274
+ # I think we have a logic bug in resolve_doc
+ # Specifically, cur_doc.rev is always in the final vector
+ # clock of revisions that we supersede, even if it wasn't in
+ # conflicted_doc_revs. We still add it as a conflict, but the
+ # fact that _put_doc_if_newer propagates resolutions means I
+ # think that conflict could accidentally be resolved. We need
+ # to add a test for this case first. (create a rev, create a
+ # conflict, create another conflict, resolve the first rev
+ # and first conflict, then make sure that the resolved
+ # rev doesn't supersede the second conflict rev.) It *might*
+ # not matter, because the superseding rev is in as a
+ # conflict, but it does seem incorrect
+ new_rev = self._ensure_maximal_rev(cur_doc.rev,
+ conflicted_doc_revs)
+ superseded_revs = set(conflicted_doc_revs)
+ c = self._db_handle.cursor()
+ doc.rev = new_rev
+ if cur_doc.rev in superseded_revs:
+ self._put_and_update_indexes(cur_doc, doc)
+ else:
+ self._add_conflict(c, doc.doc_id, new_rev, doc.get_json())
+ # TODO: Is there some way that we could construct a rev that would
+ # end up in superseded_revs, such that we add a conflict, and
+ # then immediately delete it?
+ self._delete_conflicts(c, doc, superseded_revs)
+
+ def list_indexes(self):
+ """Return the list of indexes and their definitions."""
+ c = self._db_handle.cursor()
+ # TODO: How do we test the ordering?
+ c.execute("SELECT name, field FROM index_definitions"
+ " ORDER BY name, offset")
+ definitions = []
+ cur_name = None
+ for name, field in c.fetchall():
+ if cur_name != name:
+ definitions.append((name, []))
+ cur_name = name
+ definitions[-1][-1].append(field)
+ return definitions
+
+ def _get_index_definition(self, index_name):
+ """Return the stored definition for a given index_name."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions"
+ " WHERE name = ? ORDER BY offset", (index_name,))
+ fields = [x[0] for x in c.fetchall()]
+ if not fields:
+ raise errors.IndexDoesNotExist
+ return fields
+
+ @staticmethod
+ def _strip_glob(value):
+ """Remove the trailing * from a value."""
+ assert value[-1] == '*'
+ return value[:-1]
+
+ def _format_query(self, definition, key_values):
+ # First, build the definition. We join the document_fields table
+ # against itself, as many times as the 'width' of our definition.
+ # We then do a query for each key_value, one-at-a-time.
+ # Note: All of these strings are static, we could cache them, etc.
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = ["d.doc_id = d%d.doc_id"
+ " AND d%d.field_name = ?"
+ % (i, i) for i in range(len(definition))]
+ wildcard_where = [novalue_where[i] +
+ (" AND d%d.value NOT NULL" % (i,))
+ for i in range(len(definition))]
+ exact_where = [novalue_where[i] +
+ (" AND d%d.value = ?" % (i,))
+ for i in range(len(definition))]
+ like_where = [novalue_where[i] +
+ (" AND d%d.value GLOB ?" % (i,))
+ for i in range(len(definition))]
+ is_wildcard = False
+ # Merge the lists together, so that:
+ # [field1, field2, field3], [val1, val2, val3]
+ # Becomes:
+ # (field1, val1, field2, val2, field3, val3)
+ args = []
+ where = []
+ for idx, (field, value) in enumerate(zip(definition, key_values)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(exact_where[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_from_index(self, index_name, *key_values):
+ definition = self._get_index_definition(index_name)
+ if len(key_values) != len(definition):
+ raise errors.InvalidValueForIndex()
+ statement, args = self._format_query(definition, key_values)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError as e:
+ raise dbapi2.OperationalError(
+ str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def _format_range_query(self, definition, start_value, end_value):
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ wildcard_where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ like_where = [
+ novalue_where[i] + (
+ " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in
+ range(len(definition))]
+ range_where_lower = [
+ novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in
+ range(len(definition))]
+ range_where_upper = [
+ novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in
+ range(len(definition))]
+ args = []
+ where = []
+ if start_value:
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ if len(start_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, start_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(self._strip_glob(value))
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(value)
+ if end_value:
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ if len(end_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, end_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(self._strip_glob(value))
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_upper[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ """Return all documents with key values in the specified range."""
+ definition = self._get_index_definition(index_name)
+ statement, args = self._format_range_query(
+ definition, start_value, end_value)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError as e:
+ raise dbapi2.OperationalError(
+ str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def get_index_keys(self, index_name):
+ c = self._db_handle.cursor()
+ definition = self._get_index_definition(index_name)
+ value_fields = ', '.join([
+ 'd%d.value' % i for i in range(len(definition))])
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ statement = (
+ "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % (
+ value_fields, ', '.join(tables), ' AND '.join(where),
+ value_fields))
+ try:
+ c.execute(statement, tuple(definition))
+ except dbapi2.OperationalError as e:
+ raise dbapi2.OperationalError(
+ str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition)))
+ return c.fetchall()
+
+ def delete_index(self, index_name):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ c.execute("DELETE FROM index_definitions WHERE name = ?",
+ (index_name,))
+ c.execute(
+ "DELETE FROM document_fields WHERE document_fields.field_name "
+ " NOT IN (SELECT field from index_definitions)")
+
+
+class SQLiteSyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_replica_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_replica_transaction_id)
+
+
+class SQLitePartialExpandDatabase(SQLiteDatabase):
+ """An SQLite Backend that expands documents into a document_field table.
+
+ It stores the original document text in document.doc. For fields that are
+ indexed, the data goes into document_fields.
+ """
+
+ _index_storage_value = 'expand referenced'
+
+ def _get_indexed_fields(self):
+ """Determine what fields are indexed."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions")
+ return set([x[0] for x in c.fetchall()])
+
+ def _evaluate_index(self, raw_doc, field):
+ parser = query_parser.Parser()
+ getter = parser.parse(field)
+ return getter.get(raw_doc)
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ c = self._db_handle.cursor()
+ if doc and not doc.is_tombstone():
+ raw_doc = json.loads(doc.get_json())
+ else:
+ raw_doc = {}
+ if old_doc is not None:
+ c.execute("UPDATE document SET doc_rev=?, content=?"
+ " WHERE doc_id = ?",
+ (doc.rev, doc.get_json(), doc.doc_id))
+ c.execute("DELETE FROM document_fields WHERE doc_id = ?",
+ (doc.doc_id,))
+ else:
+ c.execute("INSERT INTO document (doc_id, doc_rev, content)"
+ " VALUES (?, ?, ?)",
+ (doc.doc_id, doc.rev, doc.get_json()))
+ indexed_fields = self._get_indexed_fields()
+ if indexed_fields:
+ # It is expected that len(indexed_fields) is shorter than
+ # len(raw_doc)
+ getters = [(field, self._parse_index_definition(field))
+ for field in indexed_fields]
+ self._update_indexes(doc.doc_id, raw_doc, getters, c)
+ trans_id = self._allocate_transaction_id()
+ c.execute("INSERT INTO transaction_log(doc_id, transaction_id)"
+ " VALUES (?, ?)", (doc.doc_id, trans_id))
+
+ def create_index(self, index_name, *index_expressions):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ cur_fields = self._get_indexed_fields()
+ definition = [(index_name, idx, field)
+ for idx, field in enumerate(index_expressions)]
+ try:
+ c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
+ definition)
+ except dbapi2.IntegrityError as e:
+ stored_def = self._get_index_definition(index_name)
+ if stored_def == [x[-1] for x in definition]:
+ return
+ raise errors.IndexNameTakenError(
+ str(e) +
+ str(sys.exc_info()[2])
+ )
+ new_fields = set(
+ [f for f in index_expressions if f not in cur_fields])
+ if new_fields:
+ self._update_all_indexes(new_fields)
+
+ def _iter_all_docs(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, content FROM document")
+ while True:
+ next_rows = c.fetchmany()
+ if not next_rows:
+ break
+ for row in next_rows:
+ yield row
+
+ def _update_all_indexes(self, new_fields):
+ """Iterate all the documents, and add content to document_fields.
+
+ :param new_fields: The index definitions that need to be added.
+ """
+ getters = [(field, self._parse_index_definition(field))
+ for field in new_fields]
+ c = self._db_handle.cursor()
+ for doc_id, doc in self._iter_all_docs():
+ if doc is None:
+ continue
+ raw_doc = json.loads(doc)
+ self._update_indexes(doc_id, raw_doc, getters, c)
+
+
+SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase)
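The vendored sqlite.py keeps l2db's registry pattern: each backend registers itself under its _index_storage_value, and open_database() reads that value back from u1db_config to pick the right class on later opens. A hedged round-trip sketch with the plain, unencrypted backend (the temporary path is illustrative):

import os
import tempfile

from leap.soledad.client._database.sqlite import SQLiteDatabase

path = os.path.join(tempfile.mkdtemp(), 'example.db')

# create=True falls back to SQLitePartialExpandDatabase when the file does
# not exist; reopening finds 'expand referenced' in u1db_config and returns
# the registered class.
db = SQLiteDatabase.open_database(path, create=True)
db.create_doc({'subject': 'hello'})
reopened = SQLiteDatabase.open_database(path, create=False)
print(reopened._get_generation())  # 1: one document was put
db.close()
reopened.close()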
diff --git a/client/src/leap/soledad/client/_document.py b/client/src/leap/soledad/client/_document.py
new file mode 100644
index 00000000..9ba08e93
--- /dev/null
+++ b/client/src/leap/soledad/client/_document.py
@@ -0,0 +1,253 @@
+# -*- coding: utf-8 -*-
+# _document.py
+# Copyright (C) 2017 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Everything related to documents.
+"""
+
+import enum
+import weakref
+import uuid
+
+from twisted.internet import defer
+
+from zope.interface import Interface
+from zope.interface import implementer
+
+from leap.soledad.common.document import SoledadDocument
+
+
+class IDocumentWithAttachment(Interface):
+ """
+ A document that can have an attachment.
+ """
+
+ def set_store(self, store):
+ """
+ Set the store used by this file to manage attachments.
+
+ :param store: The store used to manage attachments.
+ :type store: Soledad
+ """
+
+ def put_attachment(self, fd):
+ """
+ Attach data to this document.
+
+ Add the attachment to local storage, enqueue for upload.
+
+ The document content will be updated with a pointer to the attachment,
+ but the document has to be manually put in the database to reflect
+ modifications.
+
+ :param fd: A file-like object whose content will be attached to this
+ document.
+ :type fd: file-like
+
+ :return: A deferred which fires when the attachment has been added to
+ local storage.
+ :rtype: Deferred
+ """
+
+ def get_attachment(self):
+ """
+ Return the data attached to this document.
+
+ If document content contains a pointer to the attachment, try to get
+ the attachment from local storage and, if not found, from remote
+ storage.
+
+ :return: A deferred which fires with a file-like object whose content
+ is the attachment of this document, or None if nothing is
+ attached.
+ :rtype: Deferred
+ """
+
+ def delete_attachment(self):
+ """
+ Delete the attachment of this document.
+
+ The pointer to the attachment will be removed from the document
+ content, but the document has to be manually put in the database to
+ reflect modifications.
+
+ :return: A deferred which fires when the attachment has been deleted
+ from local storage.
+ :rtype: Deferred
+ """
+
+ def attachment_state(self):
+ """
+ Return the state of the attachment of this document.
+
+ The state is a member of AttachmentStates and is one of NONE,
+ LOCAL, REMOTE or SYNCED.
+
+ :return: A deferred which fires with the state of the attachment of
+ this document.
+ :rtype: Deferred
+ """
+
+ def is_dirty(self):
+ """
+ Return whether this document's content differs from the contents stored
+ in local database.
+
+ :return: Whether this document is dirty or not.
+ :rtype: bool
+ """
+
+ def upload_attachment(self):
+ """
+ Upload this document's attachment.
+
+ :return: A deferred which fires with the state of the attachment after
+ it's been uploaded, or NONE if there's no attachment for this
+ document.
+ :rtype: Deferred
+ """
+
+ def download_attachment(self):
+ """
+ Download this document's attachment.
+
+ :return: A deferred which fires with the state of the attachment after
+ it's been downloaded, or NONE if there's no attachment for
+ this document.
+ :rtype: Deferred
+ """
+
+
+class BlobDoc(object):
+
+ # TODO probably not needed, but convenient for testing for now.
+
+ def __init__(self, content, blob_id):
+
+ self.blob_id = blob_id
+ self.is_blob = True
+ self.blob_fd = content
+ if blob_id is None:
+ blob_id = uuid.uuid4().get_hex()
+ self.blob_id = blob_id
+
+
+class AttachmentStates(enum.IntEnum):
+ NONE = 0
+ LOCAL = 1
+ REMOTE = 2
+ SYNCED = 4
+
+
+@implementer(IDocumentWithAttachment)
+class Document(SoledadDocument):
+
+ def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False,
+ syncable=True, store=None):
+ SoledadDocument.__init__(self, doc_id=doc_id, rev=rev, json=json,
+ has_conflicts=has_conflicts,
+ syncable=syncable)
+ self.set_store(store)
+
+ #
+ # properties
+ #
+
+ @property
+ def _manager(self):
+ if not self.store or not hasattr(self.store, 'blobmanager'):
+ raise Exception('No blob manager found to manage attachments.')
+ return self.store.blobmanager
+
+ @property
+ def _blob_id(self):
+ if self.content and 'blob_id' in self.content:
+ return self.content['blob_id']
+ return None
+
+ def get_store(self):
+ return self._store() if self._store else None
+
+ def set_store(self, store):
+ self._store = weakref.ref(store) if store else None
+
+ store = property(get_store, set_store)
+
+ #
+ # attachment api
+ #
+
+ def put_attachment(self, fd):
+ # add pointer to content
+ blob_id = self._blob_id or str(uuid.uuid4())
+ if not self.content:
+ self.content = {}
+ self.content['blob_id'] = blob_id
+ # put using manager
+ blob = BlobDoc(fd, blob_id)
+ fd.seek(0, 2)
+ size = fd.tell()
+ fd.seek(0)
+ return self._manager.put(blob, size)
+
+ def get_attachment(self):
+ if not self._blob_id:
+ return defer.succeed(None)
+ return self._manager.get(self._blob_id)
+
+ def delete_attachment(self):
+ raise NotImplementedError
+
+ @defer.inlineCallbacks
+ def attachment_state(self):
+ state = AttachmentStates.NONE
+
+ if not self._blob_id:
+ defer.returnValue(state)
+
+ local_list = yield self._manager.local_list()
+ if self._blob_id in local_list:
+ state |= AttachmentStates.LOCAL
+
+ remote_list = yield self._manager.remote_list()
+ if self._blob_id in remote_list:
+ state |= AttachmentStates.REMOTE
+
+ defer.returnValue(state)
+
+ @defer.inlineCallbacks
+ def is_dirty(self):
+ stored = yield self.store.get_doc(self.doc_id)
+ if stored.content != self.content:
+ defer.returnValue(True)
+ defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def upload_attachment(self):
+ if not self._blob_id:
+ defer.returnValue(AttachmentStates.NONE)
+
+ fd = yield self._manager.get_blob(self._blob_id)
+ # TODO: turn following method into a public one
+ yield self._manager._encrypt_and_upload(self._blob_id, fd)
+ defer.returnValue(self.attachment_state())
+
+ @defer.inlineCallbacks
+ def download_attachment(self):
+ if not self._blob_id:
+ defer.returnValue(None)
+ yield self.get_attachment()
+ defer.returnValue(self.attachment_state())
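Document implements IDocumentWithAttachment by delegating to the store's BlobManager, and attachment_state() combines the LOCAL and REMOTE flags. A hedged usage sketch (it assumes an already-initialized Soledad instance and a running reactor):

from io import BytesIO

from twisted.internet import defer

from leap.soledad.client import AttachmentStates

@defer.inlineCallbacks
def attach_payload(soledad):
    # create_doc() now returns a Document bound to its store, so the
    # attachment methods can reach soledad.blobmanager.
    doc = yield soledad.create_doc({'subject': 'with attachment'})
    yield doc.put_attachment(BytesIO(b'attached payload'))
    # The blob_id pointer lives in doc.content; put_doc persists it.
    yield soledad.put_doc(doc)
    state = yield doc.attachment_state()
    defer.returnValue(bool(state & AttachmentStates.LOCAL))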
diff --git a/client/src/leap/soledad/client/_secrets/storage.py b/client/src/leap/soledad/client/_secrets/storage.py
index 6ea89900..85713a48 100644
--- a/client/src/leap/soledad/client/_secrets/storage.py
+++ b/client/src/leap/soledad/client/_secrets/storage.py
@@ -23,8 +23,8 @@ from hashlib import sha256
from leap.soledad.common import SHARED_DB_NAME
from leap.soledad.common.log import getLogger
-from leap.soledad.common.document import SoledadDocument
from leap.soledad.client.shared_db import SoledadSharedDatabase
+from leap.soledad.client._document import Document
from leap.soledad.client._secrets.util import emit, UserDataMixin
@@ -111,7 +111,7 @@ class SecretsStorage(UserDataMixin):
def save_remote(self, encrypted):
doc = self._remote_doc
if not doc:
- doc = SoledadDocument(doc_id=self._remote_doc_id())
+ doc = Document(doc_id=self._remote_doc_id())
doc.content = encrypted
db = self._shared_db
if not db:
diff --git a/client/src/leap/soledad/client/api.py b/client/src/leap/soledad/client/api.py
index 18fb35d7..3dd99227 100644
--- a/client/src/leap/soledad/client/api.py
+++ b/client/src/leap/soledad/client/api.py
@@ -51,13 +51,14 @@ from leap.soledad.common.l2db.remote import http_client
from leap.soledad.common.l2db.remote.ssl_match_hostname import match_hostname
from leap.soledad.common.errors import DatabaseAccessError
-from leap.soledad.client import adbapi
-from leap.soledad.client import events as soledad_events
-from leap.soledad.client import interfaces as soledad_interfaces
-from leap.soledad.client import sqlcipher
-from leap.soledad.client._recovery_code import RecoveryCode
-from leap.soledad.client._secrets import Secrets
-from leap.soledad.client._crypto import SoledadCrypto
+from . import events as soledad_events
+from . import interfaces as soledad_interfaces
+from ._crypto import SoledadCrypto
+from ._database import adbapi
+from ._database import blobs
+from ._database import sqlcipher
+from ._recovery_code import RecoveryCode
+from ._secrets import Secrets
logger = getLogger(__name__)
@@ -171,6 +172,7 @@ class Soledad(object):
self._recovery_code = RecoveryCode()
self._secrets = Secrets(self)
self._crypto = SoledadCrypto(self._secrets.remote_secret)
+ self._init_blobmanager()
try:
# initialize database access, trap any problems so we can shutdown
@@ -262,6 +264,13 @@ class Soledad(object):
sync_exchange_phase = _p
return sync_phase, sync_exchange_phase
+ def _init_blobmanager(self):
+ path = os.path.join(os.path.dirname(self._local_db_path), 'blobs')
+ url = urlparse.urljoin(self.server_url, 'blobs/%s' % uuid)
+ key = self._secrets.local_key
+ self.blobmanager = blobs.BlobManager(path, url, key, self.uuid,
+ self.token, SOLEDAD_CERT)
+
#
# Closing methods
#
@@ -272,6 +281,7 @@ class Soledad(object):
"""
logger.debug("closing soledad")
self._dbpool.close()
+ self.blobmanager.close()
if getattr(self, '_dbsyncer', None):
self._dbsyncer.close()
@@ -306,7 +316,7 @@ class Soledad(object):
============================== WARNING ==============================
:param doc: A document with new content.
- :type doc: leap.soledad.common.document.SoledadDocument
+ :type doc: leap.soledad.common.document.Document
:return: A deferred whose callback will be invoked with the new
revision identifier for the document. The document object will
also be updated.
@@ -323,7 +333,7 @@ class Soledad(object):
This will also set doc.content to None.
:param doc: A document to be deleted.
- :type doc: leap.soledad.common.document.SoledadDocument
+ :type doc: leap.soledad.common.document.Document
:return: A deferred.
:rtype: twisted.internet.defer.Deferred
"""
@@ -387,6 +397,7 @@ class Soledad(object):
"""
return self._defer("get_all_docs", include_deleted)
+ @defer.inlineCallbacks
def create_doc(self, content, doc_id=None):
"""
Create a new document.
@@ -408,8 +419,9 @@ class Soledad(object):
# create_doc (and probably to put_doc too). There are cases (mail
# payloads for example) in which we already have the encoding in the
# headers, so we don't need to guess it.
- d = self._defer("create_doc", content, doc_id=doc_id)
- return d
+ doc = yield self._defer("create_doc", content, doc_id=doc_id)
+ doc.set_store(self)
+ defer.returnValue(doc)
def create_doc_from_json(self, json, doc_id=None):
"""
@@ -590,7 +602,7 @@ class Soledad(object):
the time you GET_DOC_CONFLICTS until the point where you RESOLVE)
:param doc: A Document with the new content to be inserted.
- :type doc: SoledadDocument
+ :type doc: Document
:param conflicted_doc_revs: A list of revisions that the new content
supersedes.
:type conflicted_doc_revs: list(str)
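In api.py, _init_blobmanager derives the local blobs directory from the database path and the remote endpoint from the server URL before handing both to BlobManager. A hedged sketch of just that derivation (values are placeholders):

import os
import urlparse  # Python 2, matching the client code

local_db_path = '/home/user/.config/leap/soledad/soledad.db'  # placeholder
server_url = 'https://soledad.example.org:2323/'              # placeholder
user_uuid = 'some-user-uuid'                                  # placeholder

blobs_path = os.path.join(os.path.dirname(local_db_path), 'blobs')
blobs_url = urlparse.urljoin(server_url, 'blobs/%s' % user_uuid)

print(blobs_path)  # /home/user/.config/leap/soledad/blobs
print(blobs_url)   # https://soledad.example.org:2323/blobs/some-user-uuid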
diff --git a/client/src/leap/soledad/client/crypto.py b/client/src/leap/soledad/client/crypto.py
index 09e90171..4795846c 100644
--- a/client/src/leap/soledad/client/crypto.py
+++ b/client/src/leap/soledad/client/crypto.py
@@ -167,7 +167,7 @@ class SoledadCrypto(object):
Wrapper around encrypt_docstr that accepts the document as argument.
:param doc: the document.
- :type doc: SoledadDocument
+ :type doc: Document
"""
key = self.doc_passphrase(doc.doc_id)
@@ -179,7 +179,7 @@ class SoledadCrypto(object):
Wrapper around decrypt_doc_dict that accepts the document as argument.
:param doc: the document.
- :type doc: SoledadDocument
+ :type doc: Document
:return: json string with the decrypted document
:rtype: str
@@ -194,7 +194,7 @@ class SoledadCrypto(object):
#
-# Crypto utilities for a SoledadDocument.
+# Crypto utilities for a Document.
#
def mac_doc(doc_id, doc_rev, ciphertext, enc_scheme, enc_method, enc_iv,
@@ -439,7 +439,7 @@ def is_symmetrically_encrypted(doc):
Return True if the document was symmetrically encrypted.
:param doc: The document to check.
- :type doc: SoledadDocument
+ :type doc: Document
:rtype: bool
"""
diff --git a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py
index 92bc85d6..276b1200 100644
--- a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py
+++ b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py
@@ -28,7 +28,7 @@ from twisted.internet import defer, reactor
from leap.soledad.common import l2db
from leap.soledad.client import adbapi
-from leap.soledad.client.sqlcipher import SQLCipherOptions
+from leap.soledad.client._database.sqlcipher import SQLCipherOptions
folder = os.environ.get("TMPDIR", "tmp")
diff --git a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py
index 429566c7..d288582b 100644
--- a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py
+++ b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py
@@ -27,7 +27,7 @@ import sys
from twisted.internet import defer, reactor
from leap.soledad.client import adbapi
-from leap.soledad.client.sqlcipher import SQLCipherOptions
+from leap.soledad.client._database.sqlcipher import SQLCipherOptions
from leap.soledad.common import l2db
diff --git a/client/src/leap/soledad/client/examples/use_adbapi.py b/client/src/leap/soledad/client/examples/use_adbapi.py
index 39301b41..25b32307 100644
--- a/client/src/leap/soledad/client/examples/use_adbapi.py
+++ b/client/src/leap/soledad/client/examples/use_adbapi.py
@@ -24,7 +24,7 @@ import os
from twisted.internet import defer, reactor
from leap.soledad.client import adbapi
-from leap.soledad.client.sqlcipher import SQLCipherOptions
+from leap.soledad.client._database.sqlcipher import SQLCipherOptions
from leap.soledad.common import l2db
diff --git a/client/src/leap/soledad/client/http_target/fetch.py b/client/src/leap/soledad/client/http_target/fetch.py
index cf4984d1..9d456830 100644
--- a/client/src/leap/soledad/client/http_target/fetch.py
+++ b/client/src/leap/soledad/client/http_target/fetch.py
@@ -23,10 +23,10 @@ from leap.soledad.client.events import emit_async
from leap.soledad.client.http_target.support import RequestBody
from leap.soledad.common.log import getLogger
from leap.soledad.client._crypto import is_symmetrically_encrypted
-from leap.soledad.common.document import SoledadDocument
from leap.soledad.common.l2db import errors
from leap.soledad.client import crypto as old_crypto
+from .._document import Document
from . import fetch_protocol
logger = getLogger(__name__)
@@ -113,7 +113,7 @@ class HTTPDocFetcher(object):
@defer.inlineCallbacks
def __atomic_doc_parse(self, doc_info, content, total):
- doc = SoledadDocument(doc_info['id'], doc_info['rev'], content)
+ doc = Document(doc_info['id'], doc_info['rev'], content)
if is_symmetrically_encrypted(content):
content = (yield self._crypto.decrypt_doc(doc)).getvalue()
elif old_crypto.is_symmetrically_encrypted(doc):
diff --git a/client/src/leap/soledad/client/interfaces.py b/client/src/leap/soledad/client/interfaces.py
index 1be47df7..0600449f 100644
--- a/client/src/leap/soledad/client/interfaces.py
+++ b/client/src/leap/soledad/client/interfaces.py
@@ -69,7 +69,7 @@ class ILocalStorage(Interface):
Update a document in the local encrypted database.
:param doc: the document to update
- :type doc: SoledadDocument
+ :type doc: Document
:return:
a deferred that will fire with the new revision identifier for
@@ -82,7 +82,7 @@ class ILocalStorage(Interface):
Delete a document from the local encrypted database.
:param doc: the document to delete
- :type doc: SoledadDocument
+ :type doc: Document
:return:
a deferred that will fire with ...
@@ -102,7 +102,7 @@ class ILocalStorage(Interface):
:return:
A deferred that will fire with the document object, containing a
- SoledadDocument, or None if it could not be found
+ Document, or None if it could not be found
:rtype: Deferred
"""
@@ -147,7 +147,7 @@ class ILocalStorage(Interface):
:type doc_id: str
:return:
- A deferred tht will fire with the new document (SoledadDocument
+ A deferred that will fire with the new document (Document
instance).
:rtype: Deferred
"""
@@ -167,7 +167,7 @@ class ILocalStorage(Interface):
:param doc_id: An optional identifier specifying the document id.
:type doc_id:
:return:
- A deferred that will fire with the new document (A SoledadDocument
+ A deferred that will fire with the new document (A Document
instance)
:rtype: Deferred
"""
@@ -304,7 +304,7 @@ class ILocalStorage(Interface):
Mark a document as no longer conflicted.
:param doc: a document with the new content to be inserted.
- :type doc: SoledadDocument
+ :type doc: Document
:param conflicted_doc_revs:
A deferred that will fire with a list of revisions that the new
content supersedes.