diff options
Diffstat (limited to 'client/src/leap/soledad')
18 files changed, 1330 insertions, 81 deletions
diff --git a/client/src/leap/soledad/client/__init__.py b/client/src/leap/soledad/client/__init__.py index 3a114021..bcad78db 100644 --- a/client/src/leap/soledad/client/__init__.py +++ b/client/src/leap/soledad/client/__init__.py @@ -17,12 +17,14 @@ """ Soledad - Synchronization Of Locally Encrypted Data Among Devices. """ -from leap.soledad.client.api import Soledad from leap.soledad.common import soledad_assert +from .api import Soledad +from ._document import Document, AttachmentStates from ._version import get_versions __version__ = get_versions()['version'] del get_versions -__all__ = ['soledad_assert', 'Soledad', '__version__'] +__all__ = ['soledad_assert', 'Soledad', 'Document', 'AttachmentStates', + '__version__'] diff --git a/client/src/leap/soledad/client/_crypto.py b/client/src/leap/soledad/client/_crypto.py index 7a1d508e..e66cc600 100644 --- a/client/src/leap/soledad/client/_crypto.py +++ b/client/src/leap/soledad/client/_crypto.py @@ -135,7 +135,7 @@ class SoledadCrypto(object): and wrapping the result as a simple JSON string with a "raw" key. :param doc: the document to be encrypted. - :type doc: SoledadDocument + :type doc: Document :return: A deferred whose callback will be invoked with a JSON string containing the ciphertext as the value of "raw" key. :rtype: twisted.internet.defer.Deferred @@ -159,7 +159,7 @@ class SoledadCrypto(object): the decrypted cleartext content from the encrypted document. :param doc: the document to be decrypted. - :type doc: SoledadDocument + :type doc: Document :return: The decrypted cleartext content of the document. :rtype: str """ @@ -225,7 +225,7 @@ def decrypt_sym(data, key, iv, method=ENC_METHOD.aes_256_gcm): class BlobEncryptor(object): """ Produces encrypted data from the cleartext data associated with a given - SoledadDocument using AES-256 cipher in GCM mode. + Document using AES-256 cipher in GCM mode. The production happens using a Twisted's FileBodyProducer, which uses a Cooperator to schedule calls and can be paused/resumed. Each call takes at diff --git a/client/src/leap/soledad/client/_database/__init__.py b/client/src/leap/soledad/client/_database/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/client/src/leap/soledad/client/_database/__init__.py diff --git a/client/src/leap/soledad/client/adbapi.py b/client/src/leap/soledad/client/_database/adbapi.py index b002055e..5c28d108 100644 --- a/client/src/leap/soledad/client/adbapi.py +++ b/client/src/leap/soledad/client/_database/adbapi.py @@ -30,8 +30,9 @@ from zope.proxy import ProxyBase, setProxiedObject from leap.soledad.common.log import getLogger from leap.soledad.common.errors import DatabaseAccessError -from leap.soledad.client import sqlcipher as soledad_sqlcipher -from leap.soledad.client.pragmas import set_init_pragmas + +from . import sqlcipher +from . import pragmas if sys.version_info[0] < 3: from pysqlcipher import dbapi2 @@ -73,7 +74,7 @@ def getConnectionPool(opts, openfun=None, driver="pysqlcipher"): :rtype: U1DBConnectionPool """ if openfun is None and driver == "pysqlcipher": - openfun = partial(set_init_pragmas, opts=opts) + openfun = partial(pragmas.set_init_pragmas, opts=opts) return U1DBConnectionPool( opts, # the following params are relayed "as is" to twisted's @@ -87,7 +88,7 @@ class U1DBConnection(adbapi.Connection): A wrapper for a U1DB connection instance. """ - u1db_wrapper = soledad_sqlcipher.SoledadSQLCipherWrapper + u1db_wrapper = sqlcipher.SoledadSQLCipherWrapper """ The U1DB wrapper to use. """ diff --git a/client/src/leap/soledad/client/_blobs.py b/client/src/leap/soledad/client/_database/blobs.py index 1475b302..79404bf3 100644 --- a/client/src/leap/soledad/client/_blobs.py +++ b/client/src/leap/soledad/client/_database/blobs.py @@ -20,8 +20,9 @@ Clientside BlobBackend Storage. from urlparse import urljoin +import binascii +import errno import os -import uuid import base64 from io import BytesIO @@ -34,13 +35,18 @@ from twisted.web.client import FileBodyProducer import treq -from leap.soledad.client.sqlcipher import SQLCipherOptions -from leap.soledad.client import pragmas -from leap.soledad.client._pipes import TruncatedTailPipe, PreamblePipe from leap.soledad.common.errors import SoledadError -from _crypto import DocInfo, BlobEncryptor, BlobDecryptor -from _http import HTTPClient +from .._document import BlobDoc +from .._crypto import DocInfo +from .._crypto import BlobEncryptor +from .._crypto import BlobDecryptor +from .._http import HTTPClient +from .._pipes import TruncatedTailPipe +from .._pipes import PreamblePipe + +from . import pragmas +from . import sqlcipher logger = Logger() @@ -129,6 +135,16 @@ class DecrypterBuffer(object): return self.decrypter._end_stream(), real_size +def mkdir_p(path): + try: + os.makedirs(path) + except OSError as exc: + if exc.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise + + class BlobManager(object): """ Ideally, the decrypting flow goes like this: @@ -149,6 +165,7 @@ class BlobManager(object): self, local_path, remote, key, secret, user, token=None, cert_file=None): if local_path: + mkdir_p(os.path.dirname(local_path)) self.local = SQLiteBlobBackend(local_path, key) self.remote = remote self.secret = secret @@ -280,10 +297,13 @@ class SQLiteBlobBackend(object): def __init__(self, path, key=None): self.path = os.path.abspath( os.path.join(path, 'soledad_blob.db')) + mkdir_p(os.path.dirname(self.path)) if not key: raise ValueError('key cannot be None') backend = 'pysqlcipher.dbapi2' - opts = SQLCipherOptions('/tmp/ignored', key) + opts = sqlcipher.SQLCipherOptions( + '/tmp/ignored', binascii.b2a_hex(key), + is_raw_key=True, create=True) pragmafun = partial(pragmas.set_init_pragmas, opts=opts) openfun = _sqlcipherInitFactory(pragmafun) @@ -292,7 +312,11 @@ class SQLiteBlobBackend(object): cp_openfun=openfun, cp_min=1, cp_max=2, cp_name='blob_pool') def close(self): - return self.dbpool.close() + from twisted._threads import AlreadyQuit + try: + self.dbpool.close() + except AlreadyQuit: + pass @defer.inlineCallbacks def put(self, blob_id, blob_fd, size=None): @@ -352,20 +376,6 @@ def _sqlcipherInitFactory(fun): return _initialize -class BlobDoc(object): - - # TODO probably not needed, but convenient for testing for now. - - def __init__(self, content, blob_id): - - self.blob_id = blob_id - self.is_blob = True - self.blob_fd = content - if blob_id is None: - blob_id = uuid.uuid4().get_hex() - self.blob_id = blob_id - - # # testing facilities # @@ -431,8 +441,7 @@ def testit(reactor): # TODO convert these into proper unittests def _manager(): - if not os.path.isdir(args.path): - os.makedirs(args.path) + mkdir_p(os.path.dirname(args.path)) manager = BlobManager( args.path, args.url, 'A' * 32, args.secret, diff --git a/client/src/leap/soledad/client/_database/dbschema.sql b/client/src/leap/soledad/client/_database/dbschema.sql new file mode 100644 index 00000000..ae027fc5 --- /dev/null +++ b/client/src/leap/soledad/client/_database/dbschema.sql @@ -0,0 +1,42 @@ +-- Database schema +CREATE TABLE transaction_log ( + generation INTEGER PRIMARY KEY AUTOINCREMENT, + doc_id TEXT NOT NULL, + transaction_id TEXT NOT NULL +); +CREATE TABLE document ( + doc_id TEXT PRIMARY KEY, + doc_rev TEXT NOT NULL, + content TEXT +); +CREATE TABLE document_fields ( + doc_id TEXT NOT NULL, + field_name TEXT NOT NULL, + value TEXT +); +CREATE INDEX document_fields_field_value_doc_idx + ON document_fields(field_name, value, doc_id); + +CREATE TABLE sync_log ( + replica_uid TEXT PRIMARY KEY, + known_generation INTEGER, + known_transaction_id TEXT +); +CREATE TABLE conflicts ( + doc_id TEXT, + doc_rev TEXT, + content TEXT, + CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev) +); +CREATE TABLE index_definitions ( + name TEXT, + offset INT, + field TEXT, + CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset) +); +create index index_definitions_field on index_definitions(field); +CREATE TABLE u1db_config ( + name TEXT PRIMARY KEY, + value TEXT +); +INSERT INTO u1db_config VALUES ('sql_schema', '0'); diff --git a/client/src/leap/soledad/client/pragmas.py b/client/src/leap/soledad/client/_database/pragmas.py index 870ed63e..870ed63e 100644 --- a/client/src/leap/soledad/client/pragmas.py +++ b/client/src/leap/soledad/client/_database/pragmas.py diff --git a/client/src/leap/soledad/client/sqlcipher.py b/client/src/leap/soledad/client/_database/sqlcipher.py index 2c995d5a..d22017bd 100644 --- a/client/src/leap/soledad/client/sqlcipher.py +++ b/client/src/leap/soledad/client/_database/sqlcipher.py @@ -50,16 +50,16 @@ from twisted.internet import reactor from twisted.internet import defer from twisted.enterprise import adbapi -from leap.soledad.common.document import SoledadDocument from leap.soledad.common.log import getLogger from leap.soledad.common.l2db import errors as u1db_errors -from leap.soledad.common.l2db import Document -from leap.soledad.common.l2db.backends import sqlite_backend from leap.soledad.common.errors import DatabaseAccessError from leap.soledad.client.http_target import SoledadHTTPSyncTarget from leap.soledad.client.sync import SoledadSynchronizer -from leap.soledad.client import pragmas + +from .._document import Document +from . import sqlite +from . import pragmas if sys.version_info[0] < 3: from pysqlcipher import dbapi2 as sqlcipher_dbapi2 @@ -69,8 +69,8 @@ else: logger = getLogger(__name__) -# Monkey-patch u1db.backends.sqlite_backend with pysqlcipher.dbapi2 -sqlite_backend.dbapi2 = sqlcipher_dbapi2 +# Monkey-patch u1db.backends.sqlite with pysqlcipher.dbapi2 +sqlite.dbapi2 = sqlcipher_dbapi2 # we may want to collect statistics from the sync process @@ -193,7 +193,7 @@ class SQLCipherOptions(object): # The SQLCipher database # -class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase): +class SQLCipherDatabase(sqlite.SQLitePartialExpandDatabase): """ A U1DB implementation that uses SQLCipher as its persistence layer. """ @@ -232,7 +232,7 @@ class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase): # --------------------------------------------------------- self._ensure_schema() - self.set_document_factory(soledad_doc_factory) + self.set_document_factory(doc_factory) self._prime_replica_uid() def _prime_replica_uid(self): @@ -341,7 +341,7 @@ class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase): :param doc: The new version of the document. :type doc: u1db.Document """ - sqlite_backend.SQLitePartialExpandDatabase._put_and_update_indexes( + sqlite.SQLitePartialExpandDatabase._put_and_update_indexes( self, old_doc, doc) c = self._db_handle.cursor() c.execute('UPDATE document SET syncable=? WHERE doc_id=?', @@ -361,7 +361,7 @@ class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase): :return: a Document object. :type: u1db.Document """ - doc = sqlite_backend.SQLitePartialExpandDatabase._get_doc( + doc = sqlite.SQLitePartialExpandDatabase._get_doc( self, doc_id, check_for_conflicts) if doc: c = self._db_handle.cursor() @@ -437,7 +437,7 @@ class SQLCipherU1DBSync(SQLCipherDatabase): self._opts, check_same_thread=False) self._real_replica_uid = None self._ensure_schema() - self.set_document_factory(soledad_doc_factory) + self.set_document_factory(doc_factory) except sqlcipher_dbapi2.DatabaseError as e: raise DatabaseAccessError(str(e)) @@ -505,7 +505,7 @@ class SQLCipherU1DBSync(SQLCipherDatabase): return self._get_generation() -class U1DBSQLiteBackend(sqlite_backend.SQLitePartialExpandDatabase): +class U1DBSQLiteBackend(sqlite.SQLitePartialExpandDatabase): """ A very simple wrapper for u1db around sqlcipher backend. @@ -537,7 +537,7 @@ class SoledadSQLCipherWrapper(SQLCipherDatabase): self._db_handle = conn self._real_replica_uid = None self._ensure_schema() - self.set_document_factory(soledad_doc_factory) + self.set_document_factory(doc_factory) self._prime_replica_uid() @@ -562,7 +562,7 @@ def _assert_db_is_encrypted(opts): # If the regular backend succeeds, then we need to stop because # the database was not properly initialized. try: - sqlite_backend.SQLitePartialExpandDatabase(opts.path) + sqlite.SQLitePartialExpandDatabase(opts.path) except sqlcipher_dbapi2.DatabaseError: # assert that we can access it using SQLCipher with the given # key @@ -583,17 +583,17 @@ class DatabaseIsNotEncrypted(Exception): pass -def soledad_doc_factory(doc_id=None, rev=None, json='{}', has_conflicts=False, - syncable=True): +def doc_factory(doc_id=None, rev=None, json='{}', has_conflicts=False, + syncable=True): """ Return a default Soledad Document. Used in the initialization for SQLCipherDatabase """ - return SoledadDocument(doc_id=doc_id, rev=rev, json=json, - has_conflicts=has_conflicts, syncable=syncable) + return Document(doc_id=doc_id, rev=rev, json=json, + has_conflicts=has_conflicts, syncable=syncable) -sqlite_backend.SQLiteDatabase.register_implementation(SQLCipherDatabase) +sqlite.SQLiteDatabase.register_implementation(SQLCipherDatabase) # diff --git a/client/src/leap/soledad/client/_database/sqlite.py b/client/src/leap/soledad/client/_database/sqlite.py new file mode 100644 index 00000000..4f7b1259 --- /dev/null +++ b/client/src/leap/soledad/client/_database/sqlite.py @@ -0,0 +1,930 @@ +# Copyright 2011 Canonical Ltd. +# Copyright 2016 LEAP Encryption Access Project +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +""" +A L2DB implementation that uses SQLite as its persistence layer. +""" + +import errno +import os +import json +import sys +import time +import uuid +import pkg_resources + +from sqlite3 import dbapi2 + +from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget +from leap.soledad.common.l2db import ( + Document, errors, + query_parser, vectorclock) + + +class SQLiteDatabase(CommonBackend): + """A U1DB implementation that uses SQLite as its persistence layer.""" + + _sqlite_registry = {} + + def __init__(self, sqlite_file, document_factory=None): + """Create a new sqlite file.""" + self._db_handle = dbapi2.connect(sqlite_file) + self._real_replica_uid = None + self._ensure_schema() + self._factory = document_factory or Document + + def set_document_factory(self, factory): + self._factory = factory + + def get_sync_target(self): + return SQLiteSyncTarget(self) + + @classmethod + def _which_index_storage(cls, c): + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'index_storage'") + except dbapi2.OperationalError as e: + # The table does not exist yet + return None, e + else: + return c.fetchone()[0], None + + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5 + + @classmethod + def _open_database(cls, sqlite_file, document_factory=None): + if not os.path.isfile(sqlite_file): + raise errors.DatabaseDoesNotExist() + tries = 2 + while True: + # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6) + # where without re-opening the database on Windows, it + # doesn't see the transaction that was just committed + db_handle = dbapi2.connect(sqlite_file) + c = db_handle.cursor() + v, err = cls._which_index_storage(c) + db_handle.close() + if v is not None: + break + # possibly another process is initializing it, wait for it to be + # done + if tries == 0: + raise err # go for the richest error? + tries -= 1 + time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL) + return SQLiteDatabase._sqlite_registry[v]( + sqlite_file, document_factory=document_factory) + + @classmethod + def open_database(cls, sqlite_file, create, backend_cls=None, + document_factory=None): + try: + return cls._open_database( + sqlite_file, document_factory=document_factory) + except errors.DatabaseDoesNotExist: + if not create: + raise + if backend_cls is None: + # default is SQLitePartialExpandDatabase + backend_cls = SQLitePartialExpandDatabase + return backend_cls(sqlite_file, document_factory=document_factory) + + @staticmethod + def delete_database(sqlite_file): + try: + os.unlink(sqlite_file) + except OSError as ex: + if ex.errno == errno.ENOENT: + raise errors.DatabaseDoesNotExist() + raise + + @staticmethod + def register_implementation(klass): + """Register that we implement an SQLiteDatabase. + + The attribute _index_storage_value will be used as the lookup key. + """ + SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass + + def _get_sqlite_handle(self): + """Get access to the underlying sqlite database. + + This should only be used by the test suite, etc, for examining the + state of the underlying database. + """ + return self._db_handle + + def _close_sqlite_handle(self): + """Release access to the underlying sqlite database.""" + self._db_handle.close() + + def close(self): + self._close_sqlite_handle() + + def _is_initialized(self, c): + """Check if this database has been initialized.""" + c.execute("PRAGMA case_sensitive_like=ON") + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'sql_schema'") + except dbapi2.OperationalError: + # The table does not exist yet + val = None + else: + val = c.fetchone() + if val is not None: + return True + return False + + def _initialize(self, c): + """Create the schema in the database.""" + # read the script with sql commands + # TODO: Change how we set up the dependency. Most likely use something + # like lp:dirspec to grab the file from a common resource + # directory. Doesn't specifically need to be handled until we get + # to the point of packaging this. + schema_content = pkg_resources.resource_string( + __name__, 'dbschema.sql') + # Note: We'd like to use c.executescript() here, but it seems that + # executescript always commits, even if you set + # isolation_level = None, so if we want to properly handle + # exclusive locking and rollbacks between processes, we need + # to execute it line-by-line + for line in schema_content.split(';'): + if not line: + continue + c.execute(line) + # add extra fields + self._extra_schema_init(c) + # A unique identifier should be set for this replica. Implementations + # don't have to strictly use uuid here, but we do want the uid to be + # unique amongst all databases that will sync with each other. + # We might extend this to using something with hostname for easier + # debugging. + self._set_replica_uid_in_transaction(uuid.uuid4().hex) + c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)", + (self._index_storage_value,)) + + def _ensure_schema(self): + """Ensure that the database schema has been created.""" + old_isolation_level = self._db_handle.isolation_level + c = self._db_handle.cursor() + if self._is_initialized(c): + return + try: + # autocommit/own mgmt of transactions + self._db_handle.isolation_level = None + with self._db_handle: + # only one execution path should initialize the db + c.execute("begin exclusive") + if self._is_initialized(c): + return + self._initialize(c) + finally: + self._db_handle.isolation_level = old_isolation_level + + def _extra_schema_init(self, c): + """Add any extra fields, etc to the basic table definitions.""" + + def _parse_index_definition(self, index_field): + """Parse a field definition for an index, returning a Getter.""" + # Note: We may want to keep a Parser object around, and cache the + # Getter objects for a greater length of time. Specifically, if + # you create a bunch of indexes, and then insert 50k docs, you'll + # re-parse the indexes between puts. The time to insert the docs + # is still likely to dominate put_doc time, though. + parser = query_parser.Parser() + getter = parser.parse(index_field) + return getter + + def _update_indexes(self, doc_id, raw_doc, getters, db_cursor): + """Update document_fields for a single document. + + :param doc_id: Identifier for this document + :param raw_doc: The python dict representation of the document. + :param getters: A list of [(field_name, Getter)]. Getter.get will be + called to evaluate the index definition for this document, and the + results will be inserted into the db. + :param db_cursor: An sqlite Cursor. + :return: None + """ + values = [] + for field_name, getter in getters: + for idx_value in getter.get(raw_doc): + values.append((doc_id, field_name, idx_value)) + if values: + db_cursor.executemany( + "INSERT INTO document_fields VALUES (?, ?, ?)", values) + + def _set_replica_uid(self, replica_uid): + """Force the replica_uid to be set.""" + with self._db_handle: + self._set_replica_uid_in_transaction(replica_uid) + + def _set_replica_uid_in_transaction(self, replica_uid): + """Set the replica_uid. A transaction should already be held.""" + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO u1db_config" + " VALUES ('replica_uid', ?)", + (replica_uid,)) + self._real_replica_uid = replica_uid + + def _get_replica_uid(self): + if self._real_replica_uid is not None: + return self._real_replica_uid + c = self._db_handle.cursor() + c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'") + val = c.fetchone() + if val is None: + return None + self._real_replica_uid = val[0] + return self._real_replica_uid + + _replica_uid = property(_get_replica_uid) + + def _get_generation(self): + c = self._db_handle.cursor() + c.execute('SELECT max(generation) FROM transaction_log') + val = c.fetchone()[0] + if val is None: + return 0 + return val + + def _get_generation_info(self): + c = self._db_handle.cursor() + c.execute( + 'SELECT max(generation), transaction_id FROM transaction_log ') + val = c.fetchone() + if val[0] is None: + return(0, '') + return val + + def _get_trans_id_for_gen(self, generation): + if generation == 0: + return '' + c = self._db_handle.cursor() + c.execute( + 'SELECT transaction_id FROM transaction_log WHERE generation = ?', + (generation,)) + val = c.fetchone() + if val is None: + raise errors.InvalidGeneration + return val[0] + + def _get_transaction_log(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, transaction_id FROM transaction_log" + " ORDER BY generation") + return c.fetchall() + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling.""" + c = self._db_handle.cursor() + if check_for_conflicts: + c.execute( + "SELECT document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN " + "conflicts ON conflicts.doc_id = document.doc_id WHERE " + "document.doc_id = ? GROUP BY document.doc_id, " + "document.doc_rev, document.content;", (doc_id,)) + else: + c.execute( + "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?", + (doc_id,)) + val = c.fetchone() + if val is None: + return None + doc_rev, content, conflicts = val + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + return doc + + def _has_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1", + (doc_id,)) + val = c.fetchone() + if val is None: + return False + else: + return True + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + generation = self._get_generation() + results = [] + c = self._db_handle.cursor() + c.execute( + "SELECT document.doc_id, document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts " + "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, " + "document.doc_rev, document.content;") + rows = c.fetchall() + for doc_id, doc_rev, content, conflicts in rows: + if content is None and not include_deleted: + continue + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + results.append(doc) + return (generation, results) + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none): + """Convert a dict representation into named fields. + + So something like: {'key1': 'val1', 'key2': 'val2'} + gets converted into: [(doc_id, 'key1', 'val1', 0) + (doc_id, 'key2', 'val2', 0)] + :param doc_id: Just added to every record. + :param base_field: if set, these are nested keys, so each field should + be appropriately prefixed. + :param raw_doc: The python dictionary. + """ + # TODO: Handle lists + values = [] + for field_name, value in raw_doc.iteritems(): + if value is None and not save_none: + continue + if base_field: + full_name = base_field + '.' + field_name + else: + full_name = field_name + if value is None or isinstance(value, (int, float, basestring)): + values.append((doc_id, full_name, value, len(values))) + else: + subvalues = self._expand_to_fields(doc_id, full_name, value, + save_none) + for _, subfield_name, val, _ in subvalues: + values.append((doc_id, subfield_name, val, len(values))) + return values + + def _put_and_update_indexes(self, old_doc, doc): + """Actually insert a document into the database. + + This both updates the existing documents content, and any indexes that + refer to this document. + """ + raise NotImplementedError(self._put_and_update_indexes) + + def whats_changed(self, old_generation=0): + c = self._db_handle.cursor() + c.execute("SELECT generation, doc_id, transaction_id" + " FROM transaction_log" + " WHERE generation > ? ORDER BY generation DESC", + (old_generation,)) + results = c.fetchall() + cur_gen = old_generation + seen = set() + changes = [] + newest_trans_id = '' + for generation, doc_id, trans_id in results: + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + if changes: + cur_gen = changes[0][1] # max generation + newest_trans_id = changes[0][2] + changes.reverse() + else: + c.execute("SELECT generation, transaction_id" + " FROM transaction_log ORDER BY generation DESC LIMIT 1") + results = c.fetchone() + if not results: + cur_gen = 0 + newest_trans_id = '' + else: + cur_gen, newest_trans_id = results + + return cur_gen, newest_trans_id, changes + + def delete_doc(self, doc): + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc is None: + raise errors.DocumentDoesNotExist + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + if old_doc.is_tombstone(): + raise errors.DocumentAlreadyDeleted + if old_doc.has_conflicts: + raise errors.ConflictedDoc() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + doc.make_tombstone() + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _get_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?", + (doc_id,)) + return [self._factory(doc_id, doc_rev, content) + for doc_rev, content in c.fetchall()] + + def get_doc_conflicts(self, doc_id): + with self._db_handle: + conflict_docs = self._get_conflicts(doc_id) + if not conflict_docs: + return [] + this_doc = self._get_doc(doc_id) + this_doc.has_conflicts = True + return [this_doc] + conflict_docs + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + c = self._db_handle.cursor() + c.execute("SELECT known_generation, known_transaction_id FROM sync_log" + " WHERE replica_uid = ?", + (other_replica_uid,)) + val = c.fetchone() + if val is None: + other_gen = 0 + trans_id = '' + else: + other_gen = val[0] + trans_id = val[1] + return other_gen, trans_id + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + with self._db_handle: + self._do_set_replica_gen_and_trans_id( + other_replica_uid, other_generation, other_transaction_id) + + def _do_set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, + other_transaction_id): + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)", + (other_replica_uid, other_generation, + other_transaction_id)) + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, + replica_gen=None, replica_trans_id=None): + return super(SQLiteDatabase, self)._put_doc_if_newer( + doc, + save_conflict=save_conflict, + replica_uid=replica_uid, replica_gen=replica_gen, + replica_trans_id=replica_trans_id) + + def _add_conflict(self, c, doc_id, my_doc_rev, my_content): + c.execute("INSERT INTO conflicts VALUES (?, ?, ?)", + (doc_id, my_doc_rev, my_content)) + + def _delete_conflicts(self, c, doc, conflict_revs): + deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs] + c.executemany("DELETE FROM conflicts" + " WHERE doc_id=? AND doc_rev=?", deleting) + doc.has_conflicts = self._has_conflicts(doc.doc_id) + + def _prune_conflicts(self, doc, doc_vcr): + if self._has_conflicts(doc.doc_id): + autoresolved = False + c_revs_to_prune = [] + for c_doc in self._get_conflicts(doc.doc_id): + c_vcr = vectorclock.VectorClockRev(c_doc.rev) + if doc_vcr.is_newer(c_vcr): + c_revs_to_prune.append(c_doc.rev) + elif doc.same_content_as(c_doc): + c_revs_to_prune.append(c_doc.rev) + doc_vcr.maximize(c_vcr) + autoresolved = True + if autoresolved: + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + c = self._db_handle.cursor() + self._delete_conflicts(c, doc, c_revs_to_prune) + + def _force_doc_sync_conflict(self, doc): + my_doc = self._get_doc(doc.doc_id) + c = self._db_handle.cursor() + self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) + self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json()) + doc.has_conflicts = True + self._put_and_update_indexes(my_doc, doc) + + def resolve_doc(self, doc, conflicted_doc_revs): + with self._db_handle: + cur_doc = self._get_doc(doc.doc_id) + # TODO: https://bugs.launchpad.net/u1db/+bug/928274 + # I think we have a logic bug in resolve_doc + # Specifically, cur_doc.rev is always in the final vector + # clock of revisions that we supersede, even if it wasn't in + # conflicted_doc_revs. We still add it as a conflict, but the + # fact that _put_doc_if_newer propagates resolutions means I + # think that conflict could accidentally be resolved. We need + # to add a test for this case first. (create a rev, create a + # conflict, create another conflict, resolve the first rev + # and first conflict, then make sure that the resolved + # rev doesn't supersede the second conflict rev.) It *might* + # not matter, because the superseding rev is in as a + # conflict, but it does seem incorrect + new_rev = self._ensure_maximal_rev(cur_doc.rev, + conflicted_doc_revs) + superseded_revs = set(conflicted_doc_revs) + c = self._db_handle.cursor() + doc.rev = new_rev + if cur_doc.rev in superseded_revs: + self._put_and_update_indexes(cur_doc, doc) + else: + self._add_conflict(c, doc.doc_id, new_rev, doc.get_json()) + # TODO: Is there some way that we could construct a rev that would + # end up in superseded_revs, such that we add a conflict, and + # then immediately delete it? + self._delete_conflicts(c, doc, superseded_revs) + + def list_indexes(self): + """Return the list of indexes and their definitions.""" + c = self._db_handle.cursor() + # TODO: How do we test the ordering? + c.execute("SELECT name, field FROM index_definitions" + " ORDER BY name, offset") + definitions = [] + cur_name = None + for name, field in c.fetchall(): + if cur_name != name: + definitions.append((name, [])) + cur_name = name + definitions[-1][-1].append(field) + return definitions + + def _get_index_definition(self, index_name): + """Return the stored definition for a given index_name.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions" + " WHERE name = ? ORDER BY offset", (index_name,)) + fields = [x[0] for x in c.fetchall()] + if not fields: + raise errors.IndexDoesNotExist + return fields + + @staticmethod + def _strip_glob(value): + """Remove the trailing * from a value.""" + assert value[-1] == '*' + return value[:-1] + + def _format_query(self, definition, key_values): + # First, build the definition. We join the document_fields table + # against itself, as many times as the 'width' of our definition. + # We then do a query for each key_value, one-at-a-time. + # Note: All of these strings are static, we could cache them, etc. + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = ["d.doc_id = d%d.doc_id" + " AND d%d.field_name = ?" + % (i, i) for i in range(len(definition))] + wildcard_where = [novalue_where[i] + + (" AND d%d.value NOT NULL" % (i,)) + for i in range(len(definition))] + exact_where = [novalue_where[i] + + (" AND d%d.value = ?" % (i,)) + for i in range(len(definition))] + like_where = [novalue_where[i] + + (" AND d%d.value GLOB ?" % (i,)) + for i in range(len(definition))] + is_wildcard = False + # Merge the lists together, so that: + # [field1, field2, field3], [val1, val2, val3] + # Becomes: + # (field1, val1, field2, val2, field3, val3) + args = [] + where = [] + for idx, (field, value) in enumerate(zip(definition, key_values)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(exact_where[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_from_index(self, index_name, *key_values): + definition = self._get_index_definition(index_name) + if len(key_values) != len(definition): + raise errors.InvalidValueForIndex() + statement, args = self._format_query(definition, key_values) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError as e: + raise dbapi2.OperationalError( + str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def _format_range_query(self, definition, start_value, end_value): + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in + range(len(definition))] + wildcard_where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + like_where = [ + novalue_where[i] + ( + " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in + range(len(definition))] + range_where_lower = [ + novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in + range(len(definition))] + range_where_upper = [ + novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in + range(len(definition))] + args = [] + where = [] + if start_value: + if isinstance(start_value, basestring): + start_value = (start_value,) + if len(start_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, start_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(self._strip_glob(value)) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(value) + if end_value: + if isinstance(end_value, basestring): + end_value = (end_value,) + if len(end_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, end_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(self._strip_glob(value)) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_upper[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + """Return all documents with key values in the specified range.""" + definition = self._get_index_definition(index_name) + statement, args = self._format_range_query( + definition, start_value, end_value) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError as e: + raise dbapi2.OperationalError( + str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def get_index_keys(self, index_name): + c = self._db_handle.cursor() + definition = self._get_index_definition(index_name) + value_fields = ', '.join([ + 'd%d.value' % i for i in range(len(definition))]) + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in + range(len(definition))] + where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + statement = ( + "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( + value_fields, ', '.join(tables), ' AND '.join(where), + value_fields)) + try: + c.execute(statement, tuple(definition)) + except dbapi2.OperationalError as e: + raise dbapi2.OperationalError( + str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) + return c.fetchall() + + def delete_index(self, index_name): + with self._db_handle: + c = self._db_handle.cursor() + c.execute("DELETE FROM index_definitions WHERE name = ?", + (index_name,)) + c.execute( + "DELETE FROM document_fields WHERE document_fields.field_name " + " NOT IN (SELECT field from index_definitions)") + + +class SQLiteSyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + +class SQLitePartialExpandDatabase(SQLiteDatabase): + """An SQLite Backend that expands documents into a document_field table. + + It stores the original document text in document.doc. For fields that are + indexed, the data goes into document_fields. + """ + + _index_storage_value = 'expand referenced' + + def _get_indexed_fields(self): + """Determine what fields are indexed.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions") + return set([x[0] for x in c.fetchall()]) + + def _evaluate_index(self, raw_doc, field): + parser = query_parser.Parser() + getter = parser.parse(field) + return getter.get(raw_doc) + + def _put_and_update_indexes(self, old_doc, doc): + c = self._db_handle.cursor() + if doc and not doc.is_tombstone(): + raw_doc = json.loads(doc.get_json()) + else: + raw_doc = {} + if old_doc is not None: + c.execute("UPDATE document SET doc_rev=?, content=?" + " WHERE doc_id = ?", + (doc.rev, doc.get_json(), doc.doc_id)) + c.execute("DELETE FROM document_fields WHERE doc_id = ?", + (doc.doc_id,)) + else: + c.execute("INSERT INTO document (doc_id, doc_rev, content)" + " VALUES (?, ?, ?)", + (doc.doc_id, doc.rev, doc.get_json())) + indexed_fields = self._get_indexed_fields() + if indexed_fields: + # It is expected that len(indexed_fields) is shorter than + # len(raw_doc) + getters = [(field, self._parse_index_definition(field)) + for field in indexed_fields] + self._update_indexes(doc.doc_id, raw_doc, getters, c) + trans_id = self._allocate_transaction_id() + c.execute("INSERT INTO transaction_log(doc_id, transaction_id)" + " VALUES (?, ?)", (doc.doc_id, trans_id)) + + def create_index(self, index_name, *index_expressions): + with self._db_handle: + c = self._db_handle.cursor() + cur_fields = self._get_indexed_fields() + definition = [(index_name, idx, field) + for idx, field in enumerate(index_expressions)] + try: + c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", + definition) + except dbapi2.IntegrityError as e: + stored_def = self._get_index_definition(index_name) + if stored_def == [x[-1] for x in definition]: + return + raise errors.IndexNameTakenError( + str(e) + + str(sys.exc_info()[2]) + ) + new_fields = set( + [f for f in index_expressions if f not in cur_fields]) + if new_fields: + self._update_all_indexes(new_fields) + + def _iter_all_docs(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, content FROM document") + while True: + next_rows = c.fetchmany() + if not next_rows: + break + for row in next_rows: + yield row + + def _update_all_indexes(self, new_fields): + """Iterate all the documents, and add content to document_fields. + + :param new_fields: The index definitions that need to be added. + """ + getters = [(field, self._parse_index_definition(field)) + for field in new_fields] + c = self._db_handle.cursor() + for doc_id, doc in self._iter_all_docs(): + if doc is None: + continue + raw_doc = json.loads(doc) + self._update_indexes(doc_id, raw_doc, getters, c) + + +SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase) diff --git a/client/src/leap/soledad/client/_document.py b/client/src/leap/soledad/client/_document.py new file mode 100644 index 00000000..9ba08e93 --- /dev/null +++ b/client/src/leap/soledad/client/_document.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- +# _document.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Everything related to documents. +""" + +import enum +import weakref +import uuid + +from twisted.internet import defer + +from zope.interface import Interface +from zope.interface import implementer + +from leap.soledad.common.document import SoledadDocument + + +class IDocumentWithAttachment(Interface): + """ + A document that can have an attachment. + """ + + def set_store(self, store): + """ + Set the store used by this file to manage attachments. + + :param store: The store used to manage attachments. + :type store: Soledad + """ + + def put_attachment(self, fd): + """ + Attach data to this document. + + Add the attachment to local storage, enqueue for upload. + + The document content will be updated with a pointer to the attachment, + but the document has to be manually put in the database to reflect + modifications. + + :param fd: A file-like object whose content will be attached to this + document. + :type fd: file-like + + :return: A deferred which fires when the attachment has been added to + local storage. + :rtype: Deferred + """ + + def get_attachment(self): + """ + Return the data attached to this document. + + If document content contains a pointer to the attachment, try to get + the attachment from local storage and, if not found, from remote + storage. + + :return: A deferred which fires with a file like-object whose content + is the attachment of this document, or None if nothing is + attached. + :rtype: Deferred + """ + + def delete_attachment(self): + """ + Delete the attachment of this document. + + The pointer to the attachment will be removed from the document + content, but the document has to be manually put in the database to + reflect modifications. + + :return: A deferred which fires when the attachment has been deleted + from local storage. + :rtype: Deferred + """ + + def attachment_state(self): + """ + Return the state of the attachment of this document. + + The state is a member of AttachmentStates and is of one of NONE, + LOCAL, REMOTE or SYNCED. + + :return: A deferred which fires with The state of the attachment of + this document. + :rtype: Deferred + """ + + def is_dirty(self): + """ + Return wether this document's content differs from the contents stored + in local database. + + :return: Whether this document is dirty or not. + :rtype: bool + """ + + def upload_attachment(self): + """ + Upload this document's attachment. + + :return: A deferred which fires with the state of the attachment after + it's been uploaded, or NONE if there's no attachment for this + document. + :rtype: Deferred + """ + + def download_attachment(self): + """ + Download this document's attachment. + + :return: A deferred which fires with the state of the attachment after + it's been downloaded, or NONE if there's no attachment for + this document. + :rtype: Deferred + """ + + +class BlobDoc(object): + + # TODO probably not needed, but convenient for testing for now. + + def __init__(self, content, blob_id): + + self.blob_id = blob_id + self.is_blob = True + self.blob_fd = content + if blob_id is None: + blob_id = uuid.uuid4().get_hex() + self.blob_id = blob_id + + +class AttachmentStates(enum.IntEnum): + NONE = 0 + LOCAL = 1 + REMOTE = 2 + SYNCED = 4 + + +@implementer(IDocumentWithAttachment) +class Document(SoledadDocument): + + def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, + syncable=True, store=None): + SoledadDocument.__init__(self, doc_id=doc_id, rev=rev, json=json, + has_conflicts=has_conflicts, + syncable=syncable) + self.set_store(store) + + # + # properties + # + + @property + def _manager(self): + if not self.store or not hasattr(self.store, 'blobmanager'): + raise Exception('No blob manager found to manage attachments.') + return self.store.blobmanager + + @property + def _blob_id(self): + if self.content and 'blob_id' in self.content: + return self.content['blob_id'] + return None + + def get_store(self): + return self._store() if self._store else None + + def set_store(self, store): + self._store = weakref.ref(store) if store else None + + store = property(get_store, set_store) + + # + # attachment api + # + + def put_attachment(self, fd): + # add pointer to content + blob_id = self._blob_id or str(uuid.uuid4()) + if not self.content: + self.content = {} + self.content['blob_id'] = blob_id + # put using manager + blob = BlobDoc(fd, blob_id) + fd.seek(0, 2) + size = fd.tell() + fd.seek(0) + return self._manager.put(blob, size) + + def get_attachment(self): + if not self._blob_id: + return defer.succeed(None) + return self._manager.get(self._blob_id) + + def delete_attachment(self): + raise NotImplementedError + + @defer.inlineCallbacks + def attachment_state(self): + state = AttachmentStates.NONE + + if not self._blob_id: + defer.returnValue(state) + + local_list = yield self._manager.local_list() + if self._blob_id in local_list: + state |= AttachmentStates.LOCAL + + remote_list = yield self._manager.remote_list() + if self._blob_id in remote_list: + state |= AttachmentStates.REMOTE + + defer.returnValue(state) + + @defer.inlineCallbacks + def is_dirty(self): + stored = yield self.store.get_doc(self.doc_id) + if stored.content != self.content: + defer.returnValue(True) + defer.returnValue(False) + + @defer.inlineCallbacks + def upload_attachment(self): + if not self._blob_id: + defer.returnValue(AttachmentStates.NONE) + + fd = yield self._manager.get_blob(self._blob_id) + # TODO: turn following method into a public one + yield self._manager._encrypt_and_upload(self._blob_id, fd) + defer.returnValue(self.attachment_state()) + + @defer.inlineCallbacks + def download_attachment(self): + if not self._blob_id: + defer.returnValue(None) + yield self.get_attachment() + defer.returnValue(self.attachment_state()) diff --git a/client/src/leap/soledad/client/_secrets/storage.py b/client/src/leap/soledad/client/_secrets/storage.py index 6ea89900..85713a48 100644 --- a/client/src/leap/soledad/client/_secrets/storage.py +++ b/client/src/leap/soledad/client/_secrets/storage.py @@ -23,8 +23,8 @@ from hashlib import sha256 from leap.soledad.common import SHARED_DB_NAME from leap.soledad.common.log import getLogger -from leap.soledad.common.document import SoledadDocument from leap.soledad.client.shared_db import SoledadSharedDatabase +from leap.soledad.client._document import Document from leap.soledad.client._secrets.util import emit, UserDataMixin @@ -111,7 +111,7 @@ class SecretsStorage(UserDataMixin): def save_remote(self, encrypted): doc = self._remote_doc if not doc: - doc = SoledadDocument(doc_id=self._remote_doc_id()) + doc = Document(doc_id=self._remote_doc_id()) doc.content = encrypted db = self._shared_db if not db: diff --git a/client/src/leap/soledad/client/api.py b/client/src/leap/soledad/client/api.py index 18fb35d7..3dd99227 100644 --- a/client/src/leap/soledad/client/api.py +++ b/client/src/leap/soledad/client/api.py @@ -51,13 +51,14 @@ from leap.soledad.common.l2db.remote import http_client from leap.soledad.common.l2db.remote.ssl_match_hostname import match_hostname from leap.soledad.common.errors import DatabaseAccessError -from leap.soledad.client import adbapi -from leap.soledad.client import events as soledad_events -from leap.soledad.client import interfaces as soledad_interfaces -from leap.soledad.client import sqlcipher -from leap.soledad.client._recovery_code import RecoveryCode -from leap.soledad.client._secrets import Secrets -from leap.soledad.client._crypto import SoledadCrypto +from . import events as soledad_events +from . import interfaces as soledad_interfaces +from ._crypto import SoledadCrypto +from ._database import adbapi +from ._database import blobs +from ._database import sqlcipher +from ._recovery_code import RecoveryCode +from ._secrets import Secrets logger = getLogger(__name__) @@ -171,6 +172,7 @@ class Soledad(object): self._recovery_code = RecoveryCode() self._secrets = Secrets(self) self._crypto = SoledadCrypto(self._secrets.remote_secret) + self._init_blobmanager() try: # initialize database access, trap any problems so we can shutdown @@ -262,6 +264,13 @@ class Soledad(object): sync_exchange_phase = _p return sync_phase, sync_exchange_phase + def _init_blobmanager(self): + path = os.path.join(os.path.dirname(self._local_db_path), 'blobs') + url = urlparse.urljoin(self.server_url, 'blobs/%s' % uuid) + key = self._secrets.local_key + self.blobmanager = blobs.BlobManager(path, url, key, self.uuid, + self.token, SOLEDAD_CERT) + # # Closing methods # @@ -272,6 +281,7 @@ class Soledad(object): """ logger.debug("closing soledad") self._dbpool.close() + self.blobmanager.close() if getattr(self, '_dbsyncer', None): self._dbsyncer.close() @@ -306,7 +316,7 @@ class Soledad(object): ============================== WARNING ============================== :param doc: A document with new content. - :type doc: leap.soledad.common.document.SoledadDocument + :type doc: leap.soledad.common.document.Document :return: A deferred whose callback will be invoked with the new revision identifier for the document. The document object will also be updated. @@ -323,7 +333,7 @@ class Soledad(object): This will also set doc.content to None. :param doc: A document to be deleted. - :type doc: leap.soledad.common.document.SoledadDocument + :type doc: leap.soledad.common.document.Document :return: A deferred. :rtype: twisted.internet.defer.Deferred """ @@ -387,6 +397,7 @@ class Soledad(object): """ return self._defer("get_all_docs", include_deleted) + @defer.inlineCallbacks def create_doc(self, content, doc_id=None): """ Create a new document. @@ -408,8 +419,9 @@ class Soledad(object): # create_doc (and probably to put_doc too). There are cases (mail # payloads for example) in which we already have the encoding in the # headers, so we don't need to guess it. - d = self._defer("create_doc", content, doc_id=doc_id) - return d + doc = yield self._defer("create_doc", content, doc_id=doc_id) + doc.set_store(self) + defer.returnValue(doc) def create_doc_from_json(self, json, doc_id=None): """ @@ -590,7 +602,7 @@ class Soledad(object): the time you GET_DOC_CONFLICTS until the point where you RESOLVE) :param doc: A Document with the new content to be inserted. - :type doc: SoledadDocument + :type doc: Document :param conflicted_doc_revs: A list of revisions that the new content supersedes. :type conflicted_doc_revs: list(str) diff --git a/client/src/leap/soledad/client/crypto.py b/client/src/leap/soledad/client/crypto.py index 09e90171..4795846c 100644 --- a/client/src/leap/soledad/client/crypto.py +++ b/client/src/leap/soledad/client/crypto.py @@ -167,7 +167,7 @@ class SoledadCrypto(object): Wrapper around encrypt_docstr that accepts the document as argument. :param doc: the document. - :type doc: SoledadDocument + :type doc: Document """ key = self.doc_passphrase(doc.doc_id) @@ -179,7 +179,7 @@ class SoledadCrypto(object): Wrapper around decrypt_doc_dict that accepts the document as argument. :param doc: the document. - :type doc: SoledadDocument + :type doc: Document :return: json string with the decrypted document :rtype: str @@ -194,7 +194,7 @@ class SoledadCrypto(object): # -# Crypto utilities for a SoledadDocument. +# Crypto utilities for a Document. # def mac_doc(doc_id, doc_rev, ciphertext, enc_scheme, enc_method, enc_iv, @@ -439,7 +439,7 @@ def is_symmetrically_encrypted(doc): Return True if the document was symmetrically encrypted. :param doc: The document to check. - :type doc: SoledadDocument + :type doc: Document :rtype: bool """ diff --git a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py index 92bc85d6..276b1200 100644 --- a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py +++ b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times.py @@ -28,7 +28,7 @@ from twisted.internet import defer, reactor from leap.soledad.common import l2db from leap.soledad.client import adbapi -from leap.soledad.client.sqlcipher import SQLCipherOptions +from leap.soledad.client._database.sqlcipher import SQLCipherOptions folder = os.environ.get("TMPDIR", "tmp") diff --git a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py index 429566c7..d288582b 100644 --- a/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py +++ b/client/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py @@ -27,7 +27,7 @@ import sys from twisted.internet import defer, reactor from leap.soledad.client import adbapi -from leap.soledad.client.sqlcipher import SQLCipherOptions +from leap.soledad.client._database.sqlcipher import SQLCipherOptions from leap.soledad.common import l2db diff --git a/client/src/leap/soledad/client/examples/use_adbapi.py b/client/src/leap/soledad/client/examples/use_adbapi.py index 39301b41..25b32307 100644 --- a/client/src/leap/soledad/client/examples/use_adbapi.py +++ b/client/src/leap/soledad/client/examples/use_adbapi.py @@ -24,7 +24,7 @@ import os from twisted.internet import defer, reactor from leap.soledad.client import adbapi -from leap.soledad.client.sqlcipher import SQLCipherOptions +from leap.soledad.client._database.sqlcipher import SQLCipherOptions from leap.soledad.common import l2db diff --git a/client/src/leap/soledad/client/http_target/fetch.py b/client/src/leap/soledad/client/http_target/fetch.py index cf4984d1..9d456830 100644 --- a/client/src/leap/soledad/client/http_target/fetch.py +++ b/client/src/leap/soledad/client/http_target/fetch.py @@ -23,10 +23,10 @@ from leap.soledad.client.events import emit_async from leap.soledad.client.http_target.support import RequestBody from leap.soledad.common.log import getLogger from leap.soledad.client._crypto import is_symmetrically_encrypted -from leap.soledad.common.document import SoledadDocument from leap.soledad.common.l2db import errors from leap.soledad.client import crypto as old_crypto +from .._document import Document from . import fetch_protocol logger = getLogger(__name__) @@ -113,7 +113,7 @@ class HTTPDocFetcher(object): @defer.inlineCallbacks def __atomic_doc_parse(self, doc_info, content, total): - doc = SoledadDocument(doc_info['id'], doc_info['rev'], content) + doc = Document(doc_info['id'], doc_info['rev'], content) if is_symmetrically_encrypted(content): content = (yield self._crypto.decrypt_doc(doc)).getvalue() elif old_crypto.is_symmetrically_encrypted(doc): diff --git a/client/src/leap/soledad/client/interfaces.py b/client/src/leap/soledad/client/interfaces.py index 1be47df7..0600449f 100644 --- a/client/src/leap/soledad/client/interfaces.py +++ b/client/src/leap/soledad/client/interfaces.py @@ -69,7 +69,7 @@ class ILocalStorage(Interface): Update a document in the local encrypted database. :param doc: the document to update - :type doc: SoledadDocument + :type doc: Document :return: a deferred that will fire with the new revision identifier for @@ -82,7 +82,7 @@ class ILocalStorage(Interface): Delete a document from the local encrypted database. :param doc: the document to delete - :type doc: SoledadDocument + :type doc: Document :return: a deferred that will fire with ... @@ -102,7 +102,7 @@ class ILocalStorage(Interface): :return: A deferred that will fire with the document object, containing a - SoledadDocument, or None if it could not be found + Document, or None if it could not be found :rtype: Deferred """ @@ -147,7 +147,7 @@ class ILocalStorage(Interface): :type doc_id: str :return: - A deferred tht will fire with the new document (SoledadDocument + A deferred tht will fire with the new document (Document instance). :rtype: Deferred """ @@ -167,7 +167,7 @@ class ILocalStorage(Interface): :param doc_id: An optional identifier specifying the document id. :type doc_id: :return: - A deferred that will fire with the new document (A SoledadDocument + A deferred that will fire with the new document (A Document instance) :rtype: Deferred """ @@ -304,7 +304,7 @@ class ILocalStorage(Interface): Mark a document as no longer conflicted. :param doc: a document with the new content to be inserted. - :type doc: SoledadDocument + :type doc: Document :param conflicted_doc_revs: A deferred that will fire with a list of revisions that the new content supersedes. |