From 584696e4dbfc13b793208dc4c5c6cdc224db5a12 Mon Sep 17 00:00:00 2001
From: drebs
Date: Thu, 6 Dec 2012 11:07:53 -0200
Subject: Remove u1db and swiftclient dirs and refactor.

---
 src/leap/soledad/backends/__init__.py  |   0
 src/leap/soledad/backends/leap.py      | 157 ++++++++++++++
 src/leap/soledad/backends/openstack.py | 369 +++++++++++++++++++++++++++++++++
 3 files changed, 526 insertions(+)
 create mode 100644 src/leap/soledad/backends/__init__.py
 create mode 100644 src/leap/soledad/backends/leap.py
 create mode 100644 src/leap/soledad/backends/openstack.py
(limited to 'src/leap/soledad/backends')

diff --git a/src/leap/soledad/backends/__init__.py b/src/leap/soledad/backends/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py
new file mode 100644
index 00000000..2c815632
--- /dev/null
+++ b/src/leap/soledad/backends/leap.py
@@ -0,0 +1,157 @@
+try:
+    import simplejson as json
+except ImportError:
+    import json  # noqa
+
+from u1db import Document
+from u1db.remote.http_target import HTTPSyncTarget
+from u1db.remote.http_database import HTTPDatabase
+import base64
+
+
+class NoDefaultKey(Exception):
+    pass
+
+
+class LeapDocument(Document):
+    """
+    LEAP Documents are standard u1db documents with the capability of
+    returning an encrypted version of the document's JSON string, as well as
+    setting document content based on an encrypted version of a JSON string.
+    """
+
+    def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False,
+                 encrypted_json=None, default_key=None, gpg_wrapper=None):
+        super(LeapDocument, self).__init__(doc_id, rev, json, has_conflicts)
+        # we might want to get already initialized wrappers for testing.
+        if gpg_wrapper is None:
+            self._gpg = GPGWrapper()
+        else:
+            self._gpg = gpg_wrapper
+        if encrypted_json:
+            self.set_encrypted_json(encrypted_json)
+        self._default_key = default_key
+
+    def get_encrypted_json(self):
+        """
+        Return the document's JSON serialization encrypted with the user's
+        public key.
+        """
+        if self._default_key is None:
+            raise NoDefaultKey()
+        cyphertext = self._gpg.encrypt(self.get_json(),
+                                       self._default_key,
+                                       always_trust=True)
+        # TODO: always trust?
+        return json.dumps({'cyphertext': str(cyphertext)})
+
+    def set_encrypted_json(self, encrypted_json):
+        """
+        Set the document's content based on an encrypted version of its JSON
+        string.
+        """
+        cyphertext = json.loads(encrypted_json)['cyphertext']
+        plaintext = str(self._gpg.decrypt(cyphertext))
+        return self.set_json(plaintext)
+
+
+class LeapDatabase(HTTPDatabase):
+    """Implement the HTTP remote database API to a Leap server."""
+
+    @staticmethod
+    def open_database(url, create):
+        db = LeapDatabase(url)
+        db.open(create)
+        return db
+
+    @staticmethod
+    def delete_database(url):
+        db = LeapDatabase(url)
+        db._delete()
+        db.close()
+
+    def get_sync_target(self):
+        st = LeapSyncTarget(self._url.geturl())
+        st._creds = self._creds
+        return st
+
+
+class LeapSyncTarget(HTTPSyncTarget):
+
+    def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None):
+        """
+        Does the same as the parent's method but ensures incoming content will
+        be decrypted.
+        """
+        parts = data.splitlines()  # one at a time
+        if not parts or parts[0] != '[':
+            raise BrokenSyncStream
+        data = parts[1:-1]
+        comma = False
+        if data:
+            line, comma = utils.check_and_strip_comma(data[0])
+            res = json.loads(line)
+            if ensure_callback and 'replica_uid' in res:
+                ensure_callback(res['replica_uid'])
+            for entry in data[1:]:
+                if not comma:  # missing in between comma
+                    raise BrokenSyncStream
+                line, comma = utils.check_and_strip_comma(entry)
+                entry = json.loads(line)
+                doc = LeapDocument(entry['id'], entry['rev'],
+                                   encrypted_json=entry['content'])
+                return_doc_cb(doc, entry['gen'], entry['trans_id'])
+        if parts[-1] != ']':
+            try:
+                partdic = json.loads(parts[-1])
+            except ValueError:
+                pass
+            else:
+                if isinstance(partdic, dict):
+                    self._error(partdic)
+            raise BrokenSyncStream
+        if not data or comma:  # no entries or bad extra comma
+            raise BrokenSyncStream
+        return res
+
+    def sync_exchange(self, docs_by_generations, source_replica_uid,
+                      last_known_generation, last_known_trans_id,
+                      return_doc_cb, ensure_callback=None):
+        """
+        Does the same as parent's method but encrypts content before syncing.
+        """
+        self._ensure_connection()
+        if self._trace_hook:  # for tests
+            self._trace_hook('sync_exchange')
+        url = '%s/sync-from/%s' % (self._url.path, source_replica_uid)
+        self._conn.putrequest('POST', url)
+        self._conn.putheader('content-type', 'application/x-u1db-sync-stream')
+        for header_name, header_value in self._sign_request('POST', url, {}):
+            self._conn.putheader(header_name, header_value)
+        entries = ['[']
+        size = 1
+
+        def prepare(**dic):
+            entry = comma + '\r\n' + json.dumps(dic)
+            entries.append(entry)
+            return len(entry)
+
+        comma = ''
+        size += prepare(
+            last_known_generation=last_known_generation,
+            last_known_trans_id=last_known_trans_id,
+            ensure=ensure_callback is not None)
+        comma = ','
+        for doc, gen, trans_id in docs_by_generations:
+            size += prepare(id=doc.doc_id, rev=doc.rev,
+                            content=doc.get_encrypted_json(),
+                            gen=gen, trans_id=trans_id)
+        entries.append('\r\n]')
+        size += len(entries[-1])
+        self._conn.putheader('content-length', str(size))
+        self._conn.endheaders()
+        for entry in entries:
+            self._conn.send(entry)
+        entries = None
+        data, _ = self._response()
+        res = self._parse_sync_stream(data, return_doc_cb, ensure_callback)
+        data = None
+        return res['new_generation'], res['new_transaction_id']
diff --git a/src/leap/soledad/backends/openstack.py b/src/leap/soledad/backends/openstack.py
new file mode 100644
index 00000000..ec4609b4
--- /dev/null
+++ b/src/leap/soledad/backends/openstack.py
@@ -0,0 +1,369 @@
+from leap import *
+from u1db import errors
+from u1db.backends import CommonBackend
+from u1db.remote.http_target import HTTPSyncTarget
+from swiftclient import client
+
+
+class OpenStackDatabase(CommonBackend):
+    """A U1DB implementation that uses OpenStack as its persistence layer."""
+
+    def __init__(self, auth_url, user, auth_key, container):
+        """Create a new OpenStack data container."""
+        self._auth_url = auth_url
+        self._user = user
+        self._auth_key = auth_key
+        self._container = container
+        self.set_document_factory(LeapDocument)
+        self._connection = swiftclient.Connection(self._auth_url, self._user,
+                                                  self._auth_key)
+        self._get_auth()
+        self._ensure_u1db_data()
+
+    #-------------------------------------------------------------------------
+    # implemented methods from Database
+    #-------------------------------------------------------------------------
+
+    def set_document_factory(self, factory):
+        self._factory = factory
+
+    def set_document_size_limit(self, limit):
+        raise NotImplementedError(self.set_document_size_limit)
+
+    def whats_changed(self, old_generation=0):
+        self._get_u1db_data()
+        # This method is implemented in TransactionLog because testing is
+        # easier like this for now, but it can be moved to here afterwards.
+        return self._transaction_log.whats_changed(old_generation)
+
+    def _get_doc(self, doc_id, check_for_conflicts=False):
+        """Get just the document content, without fancy handling.
+
+        Conflicts do not happen on server side, so there's no need to check
+        for them.
+        """
+        try:
+            response, contents = self._connection.get_object(self._container, doc_id)
+            rev = response['x-object-meta-rev']
+            return self._factory(doc_id, rev, contents)
+        except swiftclient.ClientException:
+            return None
+
+    def get_doc(self, doc_id, include_deleted=False):
+        doc = self._get_doc(doc_id, check_for_conflicts=True)
+        if doc is None:
+            return None
+        if doc.is_tombstone() and not include_deleted:
+            return None
+        return doc
+
+    def get_all_docs(self, include_deleted=False):
+        """Get all documents from the database."""
+        generation = self._get_generation()
+        results = []
+        _, doc_ids = self._connection.get_container(self._container,
+                                                    full_listing=True)
+        for doc_id in doc_ids:
+            doc = self._get_doc(doc_id)
+            if doc.content is None and not include_deleted:
+                continue
+            results.append(doc)
+        return (generation, results)
+
+    def put_doc(self, doc):
+        if doc.doc_id is None:
+            raise errors.InvalidDocId()
+        self._check_doc_id(doc.doc_id)
+        self._check_doc_size(doc)
+        # TODO: check for conflicts?
+        new_rev = self._allocate_doc_rev(doc.rev)
+        headers = { 'X-Object-Meta-Rev' : new_rev }
+        self._connection.put_object(self._container, doc_id, doc.get_json(),
+                                    headers=headers)
+        new_gen = self._get_generation() + 1
+        trans_id = self._allocate_transaction_id()
+        self._transaction_log.append((new_gen, doc.doc_id, trans_id))
+        self._set_u1db_data()
+        return new_rev
+
+    def delete_doc(self, doc):
+        old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+        if old_doc is None:
+            raise errors.DocumentDoesNotExist
+        if old_doc.rev != doc.rev:
+            raise errors.RevisionConflict()
+        if old_doc.is_tombstone():
+            raise errors.DocumentAlreadyDeleted
+        if old_doc.has_conflicts:
+            raise errors.ConflictedDoc()
+        new_rev = self._allocate_doc_rev(doc.rev)
+        doc.rev = new_rev
+        doc.make_tombstone()
+        self._put_doc(olddoc)
+        return new_rev
+
+    # start of index-related methods: these are not supported by this backend.
+
+    def create_index(self, index_name, *index_expressions):
+        return False
+
+    def delete_index(self, index_name):
+        return False
+
+    def list_indexes(self):
+        return []
+
+    def get_from_index(self, index_name, *key_values):
+        return []
+
+    def get_range_from_index(self, index_name, start_value=None,
+                             end_value=None):
+        return []
+
+    def get_index_keys(self, index_name):
+        return []
+
+    # end of index-related methods: these are not supported by this backend.
+ + def get_doc_conflicts(self, doc_id): + return [] + + def resolve_doc(self, doc, conflicted_doc_revs): + raise NotImplementedError(self.resolve_doc) + + def get_sync_target(self): + return OpenStackSyncTarget(self) + + def close(self): + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + from u1db.sync import Synchronizer + from u1db.remote.http_target import OpenStackSyncTarget + return Synchronizer(self, OpenStackSyncTarget(url, creds=creds)).sync( + autocreate=autocreate) + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + self._get_u1db_data() + return self._sync_log.get_replica_gen_and_trans_id(other_replica_uid) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + self._get_u1db_data() + self._sync_log.set_replica_gen_and_trans_id(other_replica_uid, + other_generation, + other_transaction_id) + self._set_u1db_data() + + #------------------------------------------------------------------------- + # implemented methods from CommonBackend + #------------------------------------------------------------------------- + + def _get_generation(self): + self._get_u1db_data() + return self._transaction_log.get_generation() + + def _get_generation_info(self): + self._get_u1db_data() + return self._transaction_log.get_generation_info() + + def _has_conflicts(self, doc_id): + # Documents never have conflicts on server. + return False + + def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): + raise NotImplementedError(self._put_and_update_indexes) + + + def _get_trans_id_for_gen(self, generation): + self._get_u1db_data() + trans_id = self._transaction_log.get_trans_id_for_gen(generation) + if trans_id is None: + raise errors.InvalidGeneration + return trans_id + + #------------------------------------------------------------------------- + # OpenStack specific methods + #------------------------------------------------------------------------- + + def _ensure_u1db_data(self): + """ + Guarantee that u1db data exists in store. + """ + if self._is_initialized(): + return + self._initialize() + + def _is_initialized(self): + """ + Verify if u1db data exists in store. + """ + if not self._get_doc('u1db_data'): + return False + return True + + def _initialize(self): + """ + Create u1db data object in store. 
+ """ + content = { 'transaction_log' : [], + 'sync_log' : [] } + doc = self.create_doc('u1db_data', content) + + def _get_auth(self): + self._url, self._auth_token = self._connection.get_auth() + return self._url, self.auth_token + + def _get_u1db_data(self): + data = self.get_doc('u1db_data').content + self._transaction_log = data['transaction_log'] + self._sync_log = data['sync_log'] + + def _set_u1db_data(self): + doc = self._factory('u1db_data') + doc.content = { 'transaction_log' : self._transaction_log, + 'sync_log' : self._sync_log } + self.put_doc(doc) + + +class OpenStackSyncTarget(HTTPSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + +class SimpleLog(object): + def __init__(self): + self._log = [] + + def _set_log(self, log): + self._log = log + + def _get_log(self): + return self._log + + log = property( + _get_log, _set_log, doc="Log contents.") + + def append(self, msg): + self._log.append(msg) + + def reduce(self, func, initializer=None): + return reduce(func, self.log, initializer) + + def map(self, func): + return map(func, self.log) + + def filter(self, func): + return filter(func, self.log) + + +class TransactionLog(SimpleLog): + """ + A list of (generation, doc_id, transaction_id) tuples. + """ + + def _set_log(self, log): + self._log = log + + def _get_log(self): + return sorted(self._log, reverse=True) + + log = property( + _get_log, _set_log, doc="Log contents.") + + def get_generation(self): + """ + Return the current generation. + """ + gens = self.map(lambda x: x[0]) + if not gens: + return 0 + return max(gens) + + def get_generation_info(self): + """ + Return the current generation and transaction id. + """ + if not self._log: + return(0, '') + info = self.map(lambda x: (x[0], x[2])) + return reduce(lambda x, y: x if (x[0] > y[0]) else y, info) + + def get_trans_id_for_gen(self, gen): + """ + Get the transaction id corresponding to a particular generation. + """ + log = self.reduce(lambda x, y: y if y[0] == gen else x) + if log is None: + return None + return log[2] + + def whats_changed(self, old_generation): + results = self.filter(lambda x: x[0] > old_generation) + seen = set() + changes = [] + newest_trans_id = '' + for generation, doc_id, trans_id in results: + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + if changes: + cur_gen = changes[0][1] # max generation + newest_trans_id = changes[0][2] + changes.reverse() + else: + results = self.log + if not results: + cur_gen = 0 + newest_trans_id = '' + else: + cur_gen, _, newest_trans_id = results[0] + + return cur_gen, newest_trans_id, changes + + + +class SyncLog(SimpleLog): + """ + A list of (replica_id, generation, transaction_id) tuples. 
+ """ + + def find_by_replica_uid(self, replica_uid): + if not self.log: + return () + return self.reduce(lambda x, y: y if y[0] == replica_uid else x) + + def get_replica_gen_and_trans_id(self, other_replica_uid): + """ + Return the last known generation and transaction id for the other db + replica. + """ + info = self.find_by_replica_uid(other_replica_uid) + if not info: + return (0, '') + return (info[1], info[2]) + + def set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + """ + Set the last-known generation and transaction id for the other + database replica. + """ + self.log = self.filter(lambda x: x[0] != other_replica_uid) + self.append((other_replica_uid, other_generation, + other_transaction_id)) + -- cgit v1.2.3 From b925c880a7d604e6f3ce437d17fdd8b1bb6cbae7 Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 6 Dec 2012 11:08:11 -0200 Subject: Add sqlcipher backend. --- src/leap/soledad/backends/sqlcipher.py | 954 +++++++++++++++++++++++++++++++++ 1 file changed, 954 insertions(+) create mode 100644 src/leap/soledad/backends/sqlcipher.py (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/sqlcipher.py b/src/leap/soledad/backends/sqlcipher.py new file mode 100644 index 00000000..24f47eed --- /dev/null +++ b/src/leap/soledad/backends/sqlcipher.py @@ -0,0 +1,954 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""A U1DB implementation that uses SQLCipher as its persistence layer.""" + +import errno +import os +try: + import simplejson as json +except ImportError: + import json # noqa +from sqlite3 import dbapi2 +import sys +import time +import uuid + +import pkg_resources + +from u1db.backends import CommonBackend, CommonSyncTarget +from u1db import ( + Document, + errors, + query_parser, + vectorclock, + ) + + +def open(path, create, document_factory=None, password=None): + """Open a database at the given location. + + Will raise u1db.errors.DatabaseDoesNotExist if create=False and the + database does not already exist. + + :param path: The filesystem path for the database to open. + :param create: True/False, should the database be created if it doesn't + already exist? + :param document_factory: A function that will be called with the same + parameters as Document.__init__. + :return: An instance of Database. 
+ """ + from u1db.backends import sqlite_backend + return sqlite_backend.SQLCipherDatabase.open_database( + path, create=create, document_factory=document_factory, password=password) + + +class SQLCipherDatabase(CommonBackend): + """A U1DB implementation that uses SQLCipher as its persistence layer.""" + + _sqlite_registry = {} + + @classmethod + def set_pragma_key(cls, db_handle, key): + db_handle.cursor().execute("PRAGMA key = '%s'" % key) + + def __init__(self, sqlite_file, document_factory=None, password=None): + """Create a new sqlite file.""" + self._db_handle = dbapi2.connect(sqlite_file) + if password: + SQLiteDatabase.set_pragma_key(self._db_handle, password) + self._real_replica_uid = None + self._ensure_schema() + self._factory = document_factory or Document + + def set_document_factory(self, factory): + self._factory = factory + + def get_sync_target(self): + return SQLCipherSyncTarget(self) + + @classmethod + def _which_index_storage(cls, c): + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'index_storage'") + except dbapi2.OperationalError, e: + # The table does not exist yet + return None, e + else: + return c.fetchone()[0], None + + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5 + + @classmethod + def _open_database(cls, sqlite_file, document_factory=None, password=None): + if not os.path.isfile(sqlite_file): + raise errors.DatabaseDoesNotExist() + tries = 2 + while True: + # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6) + # where without re-opening the database on Windows, it + # doesn't see the transaction that was just committed + db_handle = dbapi2.connect(sqlite_file) + if password: + SQLiteDatabase.set_pragma_key(db_handle, password) + c = db_handle.cursor() + v, err = cls._which_index_storage(c) + db_handle.close() + if v is not None: + break + # possibly another process is initializing it, wait for it to be + # done + if tries == 0: + raise err # go for the richest error? + tries -= 1 + time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL) + return SQLCipherDatabase._sqlite_registry[v]( + sqlite_file, document_factory=document_factory) + + @classmethod + def open_database(cls, sqlite_file, create, backend_cls=None, + document_factory=None, password=None): + try: + return cls._open_database(sqlite_file, + document_factory=document_factory, + password=password) + except errors.DatabaseDoesNotExist: + if not create: + raise + if backend_cls is None: + # default is SQLCipherPartialExpandDatabase + backend_cls = SQLCipherPartialExpandDatabase + return backend_cls(sqlite_file, document_factory=document_factory, + password=password) + + @staticmethod + def delete_database(sqlite_file): + try: + os.unlink(sqlite_file) + except OSError as ex: + if ex.errno == errno.ENOENT: + raise errors.DatabaseDoesNotExist() + raise + + @staticmethod + def register_implementation(klass): + """Register that we implement an SQLCipherDatabase. + + The attribute _index_storage_value will be used as the lookup key. + """ + SQLCipherDatabase._sqlite_registry[klass._index_storage_value] = klass + + def _get_sqlite_handle(self): + """Get access to the underlying sqlite database. + + This should only be used by the test suite, etc, for examining the + state of the underlying database. 
+ """ + return self._db_handle + + def _close_sqlite_handle(self): + """Release access to the underlying sqlite database.""" + self._db_handle.close() + + def close(self): + self._close_sqlite_handle() + + def _is_initialized(self, c): + """Check if this database has been initialized.""" + c.execute("PRAGMA case_sensitive_like=ON") + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'sql_schema'") + except dbapi2.OperationalError: + # The table does not exist yet + val = None + else: + val = c.fetchone() + if val is not None: + return True + return False + + def _initialize(self, c): + """Create the schema in the database.""" + #read the script with sql commands + # TODO: Change how we set up the dependency. Most likely use something + # like lp:dirspec to grab the file from a common resource + # directory. Doesn't specifically need to be handled until we get + # to the point of packaging this. + schema_content = pkg_resources.resource_string( + __name__, 'dbschema.sql') + # Note: We'd like to use c.executescript() here, but it seems that + # executescript always commits, even if you set + # isolation_level = None, so if we want to properly handle + # exclusive locking and rollbacks between processes, we need + # to execute it line-by-line + for line in schema_content.split(';'): + if not line: + continue + c.execute(line) + #add extra fields + self._extra_schema_init(c) + # A unique identifier should be set for this replica. Implementations + # don't have to strictly use uuid here, but we do want the uid to be + # unique amongst all databases that will sync with each other. + # We might extend this to using something with hostname for easier + # debugging. + self._set_replica_uid_in_transaction(uuid.uuid4().hex) + c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)", + (self._index_storage_value,)) + + def _ensure_schema(self): + """Ensure that the database schema has been created.""" + old_isolation_level = self._db_handle.isolation_level + c = self._db_handle.cursor() + if self._is_initialized(c): + return + try: + # autocommit/own mgmt of transactions + self._db_handle.isolation_level = None + with self._db_handle: + # only one execution path should initialize the db + c.execute("begin exclusive") + if self._is_initialized(c): + return + self._initialize(c) + finally: + self._db_handle.isolation_level = old_isolation_level + + def _extra_schema_init(self, c): + """Add any extra fields, etc to the basic table definitions.""" + + def _parse_index_definition(self, index_field): + """Parse a field definition for an index, returning a Getter.""" + # Note: We may want to keep a Parser object around, and cache the + # Getter objects for a greater length of time. Specifically, if + # you create a bunch of indexes, and then insert 50k docs, you'll + # re-parse the indexes between puts. The time to insert the docs + # is still likely to dominate put_doc time, though. + parser = query_parser.Parser() + getter = parser.parse(index_field) + return getter + + def _update_indexes(self, doc_id, raw_doc, getters, db_cursor): + """Update document_fields for a single document. + + :param doc_id: Identifier for this document + :param raw_doc: The python dict representation of the document. + :param getters: A list of [(field_name, Getter)]. Getter.get will be + called to evaluate the index definition for this document, and the + results will be inserted into the db. + :param db_cursor: An sqlite Cursor. 
+ :return: None + """ + values = [] + for field_name, getter in getters: + for idx_value in getter.get(raw_doc): + values.append((doc_id, field_name, idx_value)) + if values: + db_cursor.executemany( + "INSERT INTO document_fields VALUES (?, ?, ?)", values) + + def _set_replica_uid(self, replica_uid): + """Force the replica_uid to be set.""" + with self._db_handle: + self._set_replica_uid_in_transaction(replica_uid) + + def _set_replica_uid_in_transaction(self, replica_uid): + """Set the replica_uid. A transaction should already be held.""" + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO u1db_config" + " VALUES ('replica_uid', ?)", + (replica_uid,)) + self._real_replica_uid = replica_uid + + def _get_replica_uid(self): + if self._real_replica_uid is not None: + return self._real_replica_uid + c = self._db_handle.cursor() + c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'") + val = c.fetchone() + if val is None: + return None + self._real_replica_uid = val[0] + return self._real_replica_uid + + _replica_uid = property(_get_replica_uid) + + def _get_generation(self): + c = self._db_handle.cursor() + c.execute('SELECT max(generation) FROM transaction_log') + val = c.fetchone()[0] + if val is None: + return 0 + return val + + def _get_generation_info(self): + c = self._db_handle.cursor() + c.execute( + 'SELECT max(generation), transaction_id FROM transaction_log ') + val = c.fetchone() + if val[0] is None: + return(0, '') + return val + + def _get_trans_id_for_gen(self, generation): + if generation == 0: + return '' + c = self._db_handle.cursor() + c.execute( + 'SELECT transaction_id FROM transaction_log WHERE generation = ?', + (generation,)) + val = c.fetchone() + if val is None: + raise errors.InvalidGeneration + return val[0] + + def _get_transaction_log(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, transaction_id FROM transaction_log" + " ORDER BY generation") + return c.fetchall() + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling.""" + c = self._db_handle.cursor() + if check_for_conflicts: + c.execute( + "SELECT document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN " + "conflicts ON conflicts.doc_id = document.doc_id WHERE " + "document.doc_id = ? GROUP BY document.doc_id, " + "document.doc_rev, document.content;", (doc_id,)) + else: + c.execute( + "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?", + (doc_id,)) + val = c.fetchone() + if val is None: + return None + doc_rev, content, conflicts = val + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + return doc + + def _has_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? 
LIMIT 1", + (doc_id,)) + val = c.fetchone() + if val is None: + return False + else: + return True + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + generation = self._get_generation() + results = [] + c = self._db_handle.cursor() + c.execute( + "SELECT document.doc_id, document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts " + "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, " + "document.doc_rev, document.content;") + rows = c.fetchall() + for doc_id, doc_rev, content, conflicts in rows: + if content is None and not include_deleted: + continue + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + results.append(doc) + return (generation, results) + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none): + """Convert a dict representation into named fields. + + So something like: {'key1': 'val1', 'key2': 'val2'} + gets converted into: [(doc_id, 'key1', 'val1', 0) + (doc_id, 'key2', 'val2', 0)] + :param doc_id: Just added to every record. + :param base_field: if set, these are nested keys, so each field should + be appropriately prefixed. + :param raw_doc: The python dictionary. + """ + # TODO: Handle lists + values = [] + for field_name, value in raw_doc.iteritems(): + if value is None and not save_none: + continue + if base_field: + full_name = base_field + '.' + field_name + else: + full_name = field_name + if value is None or isinstance(value, (int, float, basestring)): + values.append((doc_id, full_name, value, len(values))) + else: + subvalues = self._expand_to_fields(doc_id, full_name, value, + save_none) + for _, subfield_name, val, _ in subvalues: + values.append((doc_id, subfield_name, val, len(values))) + return values + + def _put_and_update_indexes(self, old_doc, doc): + """Actually insert a document into the database. + + This both updates the existing documents content, and any indexes that + refer to this document. + """ + raise NotImplementedError(self._put_and_update_indexes) + + def whats_changed(self, old_generation=0): + c = self._db_handle.cursor() + c.execute("SELECT generation, doc_id, transaction_id" + " FROM transaction_log" + " WHERE generation > ? 
ORDER BY generation DESC", + (old_generation,)) + results = c.fetchall() + cur_gen = old_generation + seen = set() + changes = [] + newest_trans_id = '' + for generation, doc_id, trans_id in results: + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + if changes: + cur_gen = changes[0][1] # max generation + newest_trans_id = changes[0][2] + changes.reverse() + else: + c.execute("SELECT generation, transaction_id" + " FROM transaction_log ORDER BY generation DESC LIMIT 1") + results = c.fetchone() + if not results: + cur_gen = 0 + newest_trans_id = '' + else: + cur_gen, newest_trans_id = results + + return cur_gen, newest_trans_id, changes + + def delete_doc(self, doc): + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc is None: + raise errors.DocumentDoesNotExist + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + if old_doc.is_tombstone(): + raise errors.DocumentAlreadyDeleted + if old_doc.has_conflicts: + raise errors.ConflictedDoc() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + doc.make_tombstone() + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _get_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?", + (doc_id,)) + return [self._factory(doc_id, doc_rev, content) + for doc_rev, content in c.fetchall()] + + def get_doc_conflicts(self, doc_id): + with self._db_handle: + conflict_docs = self._get_conflicts(doc_id) + if not conflict_docs: + return [] + this_doc = self._get_doc(doc_id) + this_doc.has_conflicts = True + return [this_doc] + conflict_docs + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + c = self._db_handle.cursor() + c.execute("SELECT known_generation, known_transaction_id FROM sync_log" + " WHERE replica_uid = ?", + (other_replica_uid,)) + val = c.fetchone() + if val is None: + other_gen = 0 + trans_id = '' + else: + other_gen = val[0] + trans_id = val[1] + return other_gen, trans_id + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + with self._db_handle: + self._do_set_replica_gen_and_trans_id( + other_replica_uid, other_generation, other_transaction_id) + + def _do_set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, + other_transaction_id): + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)", + (other_replica_uid, other_generation, + other_transaction_id)) + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, + replica_gen=None, replica_trans_id=None): + with self._db_handle: + return super(SQLCipherDatabase, self)._put_doc_if_newer(doc, + save_conflict=save_conflict, + replica_uid=replica_uid, replica_gen=replica_gen, + replica_trans_id=replica_trans_id) + + def _add_conflict(self, c, doc_id, my_doc_rev, my_content): + c.execute("INSERT INTO conflicts VALUES (?, ?, ?)", + (doc_id, my_doc_rev, my_content)) + + def _delete_conflicts(self, c, doc, conflict_revs): + deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs] + c.executemany("DELETE FROM conflicts" + " WHERE doc_id=? 
AND doc_rev=?", deleting) + doc.has_conflicts = self._has_conflicts(doc.doc_id) + + def _prune_conflicts(self, doc, doc_vcr): + if self._has_conflicts(doc.doc_id): + autoresolved = False + c_revs_to_prune = [] + for c_doc in self._get_conflicts(doc.doc_id): + c_vcr = vectorclock.VectorClockRev(c_doc.rev) + if doc_vcr.is_newer(c_vcr): + c_revs_to_prune.append(c_doc.rev) + elif doc.same_content_as(c_doc): + c_revs_to_prune.append(c_doc.rev) + doc_vcr.maximize(c_vcr) + autoresolved = True + if autoresolved: + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + c = self._db_handle.cursor() + self._delete_conflicts(c, doc, c_revs_to_prune) + + def _force_doc_sync_conflict(self, doc): + my_doc = self._get_doc(doc.doc_id) + c = self._db_handle.cursor() + self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) + self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json()) + doc.has_conflicts = True + self._put_and_update_indexes(my_doc, doc) + + def resolve_doc(self, doc, conflicted_doc_revs): + with self._db_handle: + cur_doc = self._get_doc(doc.doc_id) + # TODO: https://bugs.launchpad.net/u1db/+bug/928274 + # I think we have a logic bug in resolve_doc + # Specifically, cur_doc.rev is always in the final vector + # clock of revisions that we supersede, even if it wasn't in + # conflicted_doc_revs. We still add it as a conflict, but the + # fact that _put_doc_if_newer propagates resolutions means I + # think that conflict could accidentally be resolved. We need + # to add a test for this case first. (create a rev, create a + # conflict, create another conflict, resolve the first rev + # and first conflict, then make sure that the resolved + # rev doesn't supersede the second conflict rev.) It *might* + # not matter, because the superseding rev is in as a + # conflict, but it does seem incorrect + new_rev = self._ensure_maximal_rev(cur_doc.rev, + conflicted_doc_revs) + superseded_revs = set(conflicted_doc_revs) + c = self._db_handle.cursor() + doc.rev = new_rev + if cur_doc.rev in superseded_revs: + self._put_and_update_indexes(cur_doc, doc) + else: + self._add_conflict(c, doc.doc_id, new_rev, doc.get_json()) + # TODO: Is there some way that we could construct a rev that would + # end up in superseded_revs, such that we add a conflict, and + # then immediately delete it? + self._delete_conflicts(c, doc, superseded_revs) + + def list_indexes(self): + """Return the list of indexes and their definitions.""" + c = self._db_handle.cursor() + # TODO: How do we test the ordering? + c.execute("SELECT name, field FROM index_definitions" + " ORDER BY name, offset") + definitions = [] + cur_name = None + for name, field in c.fetchall(): + if cur_name != name: + definitions.append((name, [])) + cur_name = name + definitions[-1][-1].append(field) + return definitions + + def _get_index_definition(self, index_name): + """Return the stored definition for a given index_name.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions" + " WHERE name = ? ORDER BY offset", (index_name,)) + fields = [x[0] for x in c.fetchall()] + if not fields: + raise errors.IndexDoesNotExist + return fields + + @staticmethod + def _strip_glob(value): + """Remove the trailing * from a value.""" + assert value[-1] == '*' + return value[:-1] + + def _format_query(self, definition, key_values): + # First, build the definition. We join the document_fields table + # against itself, as many times as the 'width' of our definition. + # We then do a query for each key_value, one-at-a-time. 
+ # Note: All of these strings are static, we could cache them, etc. + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = ["d.doc_id = d%d.doc_id" + " AND d%d.field_name = ?" + % (i, i) for i in range(len(definition))] + wildcard_where = [novalue_where[i] + + (" AND d%d.value NOT NULL" % (i,)) + for i in range(len(definition))] + exact_where = [novalue_where[i] + + (" AND d%d.value = ?" % (i,)) + for i in range(len(definition))] + like_where = [novalue_where[i] + + (" AND d%d.value GLOB ?" % (i,)) + for i in range(len(definition))] + is_wildcard = False + # Merge the lists together, so that: + # [field1, field2, field3], [val1, val2, val3] + # Becomes: + # (field1, val1, field2, val2, field3, val3) + args = [] + where = [] + for idx, (field, value) in enumerate(zip(definition, key_values)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(exact_where[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_from_index(self, index_name, *key_values): + definition = self._get_index_definition(index_name) + if len(key_values) != len(definition): + raise errors.InvalidValueForIndex() + statement, args = self._format_query(definition, key_values) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def _format_range_query(self, definition, start_value, end_value): + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in + range(len(definition))] + wildcard_where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + like_where = [ + novalue_where[i] + ( + " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in + range(len(definition))] + range_where_lower = [ + novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in + range(len(definition))] + range_where_upper = [ + novalue_where[i] + (" AND d%d.value <= ?" 
% (i,)) for i in + range(len(definition))] + args = [] + where = [] + if start_value: + if isinstance(start_value, basestring): + start_value = (start_value,) + if len(start_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, start_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(self._strip_glob(value)) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(value) + if end_value: + if isinstance(end_value, basestring): + end_value = (end_value,) + if len(end_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, end_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(self._strip_glob(value)) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_upper[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + """Return all documents with key values in the specified range.""" + definition = self._get_index_definition(index_name) + statement, args = self._format_range_query( + definition, start_value, end_value) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def get_index_keys(self, index_name): + c = self._db_handle.cursor() + definition = self._get_index_definition(index_name) + value_fields = ', '.join([ + 'd%d.value' % i for i in range(len(definition))]) + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" 
% (i, i) for i in + range(len(definition))] + where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + statement = ( + "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( + value_fields, ', '.join(tables), ' AND '.join(where), + value_fields)) + try: + c.execute(statement, tuple(definition)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) + return c.fetchall() + + def delete_index(self, index_name): + with self._db_handle: + c = self._db_handle.cursor() + c.execute("DELETE FROM index_definitions WHERE name = ?", + (index_name,)) + c.execute( + "DELETE FROM document_fields WHERE document_fields.field_name " + " NOT IN (SELECT field from index_definitions)") + + +class SQLCipherSyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + +class SQLCipherPartialExpandDatabase(SQLCipherDatabase): + """An SQLCipher Backend that expands documents into a document_field table. + + It stores the original document text in document.doc. For fields that are + indexed, the data goes into document_fields. + """ + + _index_storage_value = 'expand referenced' + + def _get_indexed_fields(self): + """Determine what fields are indexed.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions") + return set([x[0] for x in c.fetchall()]) + + def _evaluate_index(self, raw_doc, field): + parser = query_parser.Parser() + getter = parser.parse(field) + return getter.get(raw_doc) + + def _put_and_update_indexes(self, old_doc, doc): + c = self._db_handle.cursor() + if doc and not doc.is_tombstone(): + raw_doc = json.loads(doc.get_json()) + else: + raw_doc = {} + if old_doc is not None: + c.execute("UPDATE document SET doc_rev=?, content=?" 
+ " WHERE doc_id = ?", + (doc.rev, doc.get_json(), doc.doc_id)) + c.execute("DELETE FROM document_fields WHERE doc_id = ?", + (doc.doc_id,)) + else: + c.execute("INSERT INTO document (doc_id, doc_rev, content)" + " VALUES (?, ?, ?)", + (doc.doc_id, doc.rev, doc.get_json())) + indexed_fields = self._get_indexed_fields() + if indexed_fields: + # It is expected that len(indexed_fields) is shorter than + # len(raw_doc) + getters = [(field, self._parse_index_definition(field)) + for field in indexed_fields] + self._update_indexes(doc.doc_id, raw_doc, getters, c) + trans_id = self._allocate_transaction_id() + c.execute("INSERT INTO transaction_log(doc_id, transaction_id)" + " VALUES (?, ?)", (doc.doc_id, trans_id)) + + def create_index(self, index_name, *index_expressions): + with self._db_handle: + c = self._db_handle.cursor() + cur_fields = self._get_indexed_fields() + definition = [(index_name, idx, field) + for idx, field in enumerate(index_expressions)] + try: + c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", + definition) + except dbapi2.IntegrityError as e: + stored_def = self._get_index_definition(index_name) + if stored_def == [x[-1] for x in definition]: + return + raise errors.IndexNameTakenError, e, sys.exc_info()[2] + new_fields = set( + [f for f in index_expressions if f not in cur_fields]) + if new_fields: + self._update_all_indexes(new_fields) + + def _iter_all_docs(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, content FROM document") + while True: + next_rows = c.fetchmany() + if not next_rows: + break + for row in next_rows: + yield row + + def _update_all_indexes(self, new_fields): + """Iterate all the documents, and add content to document_fields. + + :param new_fields: The index definitions that need to be added. + """ + getters = [(field, self._parse_index_definition(field)) + for field in new_fields] + c = self._db_handle.cursor() + for doc_id, doc in self._iter_all_docs(): + if doc is None: + continue + raw_doc = json.loads(doc) + self._update_indexes(doc_id, raw_doc, getters, c) + +SQLCipherDatabase.register_implementation(SQLCipherPartialExpandDatabase) -- cgit v1.2.3 From 7cc7aee73fbf82b604988585e051da32b99dc70e Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 6 Dec 2012 11:15:42 -0200 Subject: Move log classes so all backends can use them. --- src/leap/soledad/backends/openstack.py | 124 --------------------------------- 1 file changed, 124 deletions(-) (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/openstack.py b/src/leap/soledad/backends/openstack.py index ec4609b4..6c971485 100644 --- a/src/leap/soledad/backends/openstack.py +++ b/src/leap/soledad/backends/openstack.py @@ -32,8 +32,6 @@ class OpenStackDatabase(CommonBackend): def whats_changed(self, old_generation=0): self._get_u1db_data() - # This method is implemented in TransactionLog because testing is - # easier like this for now, but it can be moved to here afterwards. 
return self._transaction_log.whats_changed(old_generation) def _get_doc(self, doc_id, check_for_conflicts=False): @@ -245,125 +243,3 @@ class OpenStackSyncTarget(HTTPSyncTarget): source_replica_transaction_id) -class SimpleLog(object): - def __init__(self): - self._log = [] - - def _set_log(self, log): - self._log = log - - def _get_log(self): - return self._log - - log = property( - _get_log, _set_log, doc="Log contents.") - - def append(self, msg): - self._log.append(msg) - - def reduce(self, func, initializer=None): - return reduce(func, self.log, initializer) - - def map(self, func): - return map(func, self.log) - - def filter(self, func): - return filter(func, self.log) - - -class TransactionLog(SimpleLog): - """ - A list of (generation, doc_id, transaction_id) tuples. - """ - - def _set_log(self, log): - self._log = log - - def _get_log(self): - return sorted(self._log, reverse=True) - - log = property( - _get_log, _set_log, doc="Log contents.") - - def get_generation(self): - """ - Return the current generation. - """ - gens = self.map(lambda x: x[0]) - if not gens: - return 0 - return max(gens) - - def get_generation_info(self): - """ - Return the current generation and transaction id. - """ - if not self._log: - return(0, '') - info = self.map(lambda x: (x[0], x[2])) - return reduce(lambda x, y: x if (x[0] > y[0]) else y, info) - - def get_trans_id_for_gen(self, gen): - """ - Get the transaction id corresponding to a particular generation. - """ - log = self.reduce(lambda x, y: y if y[0] == gen else x) - if log is None: - return None - return log[2] - - def whats_changed(self, old_generation): - results = self.filter(lambda x: x[0] > old_generation) - seen = set() - changes = [] - newest_trans_id = '' - for generation, doc_id, trans_id in results: - if doc_id not in seen: - changes.append((doc_id, generation, trans_id)) - seen.add(doc_id) - if changes: - cur_gen = changes[0][1] # max generation - newest_trans_id = changes[0][2] - changes.reverse() - else: - results = self.log - if not results: - cur_gen = 0 - newest_trans_id = '' - else: - cur_gen, _, newest_trans_id = results[0] - - return cur_gen, newest_trans_id, changes - - - -class SyncLog(SimpleLog): - """ - A list of (replica_id, generation, transaction_id) tuples. - """ - - def find_by_replica_uid(self, replica_uid): - if not self.log: - return () - return self.reduce(lambda x, y: y if y[0] == replica_uid else x) - - def get_replica_gen_and_trans_id(self, other_replica_uid): - """ - Return the last known generation and transaction id for the other db - replica. - """ - info = self.find_by_replica_uid(other_replica_uid) - if not info: - return (0, '') - return (info[1], info[2]) - - def set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - """ - Set the last-known generation and transaction id for the other - database replica. - """ - self.log = self.filter(lambda x: x[0] != other_replica_uid) - self.append((other_replica_uid, other_generation, - other_transaction_id)) - -- cgit v1.2.3 From f89f2e0fe490899ecc4baf3395f3441111da328f Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 10 Dec 2012 11:00:10 -0200 Subject: Refactor to add ObjectStore class. 
--- src/leap/soledad/backends/objectstore.py | 153 +++++++++++++++++++++++++++++++ src/leap/soledad/backends/openstack.py | 143 +---------------------------- 2 files changed, 156 insertions(+), 140 deletions(-) create mode 100644 src/leap/soledad/backends/objectstore.py (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py new file mode 100644 index 00000000..e36df72d --- /dev/null +++ b/src/leap/soledad/backends/objectstore.py @@ -0,0 +1,153 @@ +from u1db.backends import CommonBackend + + +class ObjectStore(CommonBackend): + + def __init__(self): + self._sync_log = SyncLog() + self._transaction_log = TransactionLog() + + #------------------------------------------------------------------------- + # implemented methods from Database + #------------------------------------------------------------------------- + + def set_document_factory(self, factory): + self._factory = factory + + def set_document_size_limit(self, limit): + raise NotImplementedError(self.set_document_size_limit) + + def whats_changed(self, old_generation=0): + self._get_u1db_data() + return self._transaction_log.whats_changed(old_generation) + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def delete_doc(self, doc): + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc is None: + raise errors.DocumentDoesNotExist + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + if old_doc.is_tombstone(): + raise errors.DocumentAlreadyDeleted + if old_doc.has_conflicts: + raise errors.ConflictedDoc() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + doc.make_tombstone() + self._put_doc(olddoc) + return new_rev + + # start of index-related methods: these are not supported by this backend. + + def create_index(self, index_name, *index_expressions): + return False + + def delete_index(self, index_name): + return False + + def list_indexes(self): + return [] + + def get_from_index(self, index_name, *key_values): + return [] + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + return [] + + def get_index_keys(self, index_name): + return [] + + # end of index-related methods: these are not supported by this backend. + + def get_doc_conflicts(self, doc_id): + return [] + + def resolve_doc(self, doc, conflicted_doc_revs): + raise NotImplementedError(self.resolve_doc) + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + self._get_u1db_data() + return self._sync_log.get_replica_gen_and_trans_id(other_replica_uid) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + self._get_u1db_data() + self._sync_log.set_replica_gen_and_trans_id(other_replica_uid, + other_generation, + other_transaction_id) + self._set_u1db_data() + + #------------------------------------------------------------------------- + # implemented methods from CommonBackend + #------------------------------------------------------------------------- + + def _get_generation(self): + self._get_u1db_data() + return self._transaction_log.get_generation() + + def _get_generation_info(self): + self._get_u1db_data() + return self._transaction_log.get_generation_info() + + def _has_conflicts(self, doc_id): + # Documents never have conflicts on server. 
+ return False + + def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): + raise NotImplementedError(self._put_and_update_indexes) + + + def _get_trans_id_for_gen(self, generation): + self._get_u1db_data() + trans_id = self._transaction_log.get_trans_id_for_gen(generation) + if trans_id is None: + raise errors.InvalidGeneration + return trans_id + + def _ensure_u1db_data(self): + """ + Guarantee that u1db data exists in store. + """ + if not self._is_initialized(): + self._initialize() + u1db_data = self._get_doc('u1db_data') + self._sync_log.log = u1db_data.content['sync_log'] + self._transaction_log.log = u1db_data.content['transaction_log'] + + def _is_initialized(self): + """ + Verify if u1db data exists in store. + """ + if not self._get_doc('u1db_data'): + return False + return True + + def _initialize(self): + """ + Create u1db data object in store. + """ + content = { 'transaction_log' : [], + 'sync_log' : [] } + doc = self.create_doc('u1db_data', content) + + def _get_u1db_data(self): + data = self.get_doc('u1db_data').content + self._transaction_log = data['transaction_log'] + self._sync_log = data['sync_log'] + + def _set_u1db_data(self): + doc = self._factory('u1db_data') + doc.content = { 'transaction_log' : self._transaction_log, + 'sync_log' : self._sync_log } + self.put_doc(doc) + + diff --git a/src/leap/soledad/backends/openstack.py b/src/leap/soledad/backends/openstack.py index 6c971485..f8563d81 100644 --- a/src/leap/soledad/backends/openstack.py +++ b/src/leap/soledad/backends/openstack.py @@ -1,15 +1,16 @@ -from leap import * from u1db import errors from u1db.backends import CommonBackend from u1db.remote.http_target import HTTPSyncTarget from swiftclient import client +from soledad.backends.objectstore import ObjectStore -class OpenStackDatabase(CommonBackend): +class OpenStackDatabase(ObjectStore): """A U1DB implementation that uses OpenStack as its persistence layer.""" def __init__(self, auth_url, user, auth_key, container): """Create a new OpenStack data container.""" + super(OpenStackDatabase, self) self._auth_url = auth_url self._user = user self._auth_key = auth_key @@ -24,16 +25,6 @@ class OpenStackDatabase(CommonBackend): # implemented methods from Database #------------------------------------------------------------------------- - def set_document_factory(self, factory): - self._factory = factory - - def set_document_size_limit(self, limit): - raise NotImplementedError(self.set_document_size_limit) - - def whats_changed(self, old_generation=0): - self._get_u1db_data() - return self._transaction_log.whats_changed(old_generation) - def _get_doc(self, doc_id, check_for_conflicts=False): """Get just the document content, without fancy handling. 
@@ -47,14 +38,6 @@ class OpenStackDatabase(CommonBackend): except swiftclient.ClientException: return None - def get_doc(self, doc_id, include_deleted=False): - doc = self._get_doc(doc_id, check_for_conflicts=True) - if doc is None: - return None - if doc.is_tombstone() and not include_deleted: - return None - return doc - def get_all_docs(self, include_deleted=False): """Get all documents from the database.""" generation = self._get_generation() @@ -84,51 +67,6 @@ class OpenStackDatabase(CommonBackend): self._set_u1db_data() return new_rev - def delete_doc(self, doc): - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc is None: - raise errors.DocumentDoesNotExist - if old_doc.rev != doc.rev: - raise errors.RevisionConflict() - if old_doc.is_tombstone(): - raise errors.DocumentAlreadyDeleted - if old_doc.has_conflicts: - raise errors.ConflictedDoc() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - doc.make_tombstone() - self._put_doc(olddoc) - return new_rev - - # start of index-related methods: these are not supported by this backend. - - def create_index(self, index_name, *index_expressions): - return False - - def delete_index(self, index_name): - return False - - def list_indexes(self): - return [] - - def get_from_index(self, index_name, *key_values): - return [] - - def get_range_from_index(self, index_name, start_value=None, - end_value=None): - return [] - - def get_index_keys(self, index_name): - return [] - - # end of index-related methods: these are not supported by this backend. - - def get_doc_conflicts(self, doc_id): - return [] - - def resolve_doc(self, doc, conflicted_doc_revs): - raise NotImplementedError(self.resolve_doc) - def get_sync_target(self): return OpenStackSyncTarget(self) @@ -141,89 +79,14 @@ class OpenStackDatabase(CommonBackend): return Synchronizer(self, OpenStackSyncTarget(url, creds=creds)).sync( autocreate=autocreate) - def _get_replica_gen_and_trans_id(self, other_replica_uid): - self._get_u1db_data() - return self._sync_log.get_replica_gen_and_trans_id(other_replica_uid) - - def _set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - self._get_u1db_data() - self._sync_log.set_replica_gen_and_trans_id(other_replica_uid, - other_generation, - other_transaction_id) - self._set_u1db_data() - - #------------------------------------------------------------------------- - # implemented methods from CommonBackend - #------------------------------------------------------------------------- - - def _get_generation(self): - self._get_u1db_data() - return self._transaction_log.get_generation() - - def _get_generation_info(self): - self._get_u1db_data() - return self._transaction_log.get_generation_info() - - def _has_conflicts(self, doc_id): - # Documents never have conflicts on server. - return False - - def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): - raise NotImplementedError(self._put_and_update_indexes) - - - def _get_trans_id_for_gen(self, generation): - self._get_u1db_data() - trans_id = self._transaction_log.get_trans_id_for_gen(generation) - if trans_id is None: - raise errors.InvalidGeneration - return trans_id - #------------------------------------------------------------------------- # OpenStack specific methods #------------------------------------------------------------------------- - def _ensure_u1db_data(self): - """ - Guarantee that u1db data exists in store. 
- """ - if self._is_initialized(): - return - self._initialize() - - def _is_initialized(self): - """ - Verify if u1db data exists in store. - """ - if not self._get_doc('u1db_data'): - return False - return True - - def _initialize(self): - """ - Create u1db data object in store. - """ - content = { 'transaction_log' : [], - 'sync_log' : [] } - doc = self.create_doc('u1db_data', content) - def _get_auth(self): self._url, self._auth_token = self._connection.get_auth() return self._url, self.auth_token - def _get_u1db_data(self): - data = self.get_doc('u1db_data').content - self._transaction_log = data['transaction_log'] - self._sync_log = data['sync_log'] - - def _set_u1db_data(self): - doc = self._factory('u1db_data') - doc.content = { 'transaction_log' : self._transaction_log, - 'sync_log' : self._sync_log } - self.put_doc(doc) - - class OpenStackSyncTarget(HTTPSyncTarget): def get_sync_info(self, source_replica_uid): -- cgit v1.2.3 From b3090f710e3777bad2a9f996444e5099883c9f03 Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 10 Dec 2012 12:05:31 -0200 Subject: Add CouchDB u1db backend. --- src/leap/soledad/backends/couchdb.py | 97 ++++++++++++++++++++++++++++++++ src/leap/soledad/backends/objectstore.py | 26 +++++++++ src/leap/soledad/backends/openstack.py | 20 ++----- 3 files changed, 128 insertions(+), 15 deletions(-) create mode 100644 src/leap/soledad/backends/couchdb.py (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/couchdb.py b/src/leap/soledad/backends/couchdb.py new file mode 100644 index 00000000..89b713f9 --- /dev/null +++ b/src/leap/soledad/backends/couchdb.py @@ -0,0 +1,97 @@ +from u1db import errors +from u1db.remote.http_target import HTTPSyncTarget +from couchdb import * +from soledad.backends.objectstore import ObjectStore + + +class CouchDatabase(ObjectStore): + """A U1DB implementation that uses Couch as its persistence layer.""" + + def __init__(self, url, database, full_commit=True, session=None): + """Create a new Couch data container.""" + self._url = url + self._full_commit = full_commit + self._session = session + self._server = couchdb.Server(url=self._url, + full_commit=self._full_commit, + session=self._session) + # this will ensure that transaction and sync logs exist and are + # up-to-date. + super(CouchDatabase, self) + self._database = self._server[database] + + #------------------------------------------------------------------------- + # implemented methods from Database + #------------------------------------------------------------------------- + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling. + + Conflicts do not happen on server side, so there's no need to check + for them. 
+ """ + cdoc = self._database.get(doc_id) + if cdoc is not None: + content = {} + for key, value in content: + if not key in ['_id', '_rev', '_u1db_rev']: + content[key] = value + doc = self._factory(doc_id=doc_id, rev=cdoc['_u1db_rev']) + doc.content = content + return doc + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + generation = self._get_generation() + results = [] + for doc_id in self._database: + doc = self._get_doc(doc_id) + if doc.content is None and not include_deleted: + continue + results.append(doc) + return (generation, results) + + def _put_doc(self, doc, new_rev): + # map u1db metadata to couch + content = doc.content + content['_id'] = doc.doc_id + content['_u1db_rev'] = new_rev + self._database.save(doc.content) + + def get_sync_target(self): + return CouchSyncTarget(self) + + def close(self): + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + from u1db.sync import Synchronizer + from u1db.remote.http_target import CouchSyncTarget + return Synchronizer(self, CouchSyncTarget(url, creds=creds)).sync( + autocreate=autocreate) + + #------------------------------------------------------------------------- + # Couch specific methods + #------------------------------------------------------------------------- + + # no specific methods so far. + +class CouchSyncTarget(HTTPSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index e36df72d..456892b3 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -1,11 +1,17 @@ from u1db.backends import CommonBackend +from soledad import SyncLog, TransactionLog class ObjectStore(CommonBackend): def __init__(self): + # This initialization method should be called after the connection + # with the database is established, so it can ensure that u1db data is + # configured and up-to-date. 
+ self.set_document_factory(LeapDocument) self._sync_log = SyncLog() self._transaction_log = TransactionLog() + self._ensure_u1db_data() #------------------------------------------------------------------------- # implemented methods from Database @@ -29,6 +35,26 @@ class ObjectStore(CommonBackend): return None return doc + def _put_doc(self, doc) + raise NotImplementedError(self._put_doc) + + def put_doc(self, doc) + # consistency check + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + # put the document + new_rev = self._allocate_doc_rev(doc.rev) + self._put_doc(doc, new_rev) + doc.rev = new_rev + # update u1db generation and logs + new_gen = self._get_generation() + 1 + trans_id = self._allocate_transaction_id() + self._transaction_log.append((new_gen, doc.doc_id, trans_id)) + self._set_u1db_data() + return new_rev + def delete_doc(self, doc): old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) if old_doc is None: diff --git a/src/leap/soledad/backends/openstack.py b/src/leap/soledad/backends/openstack.py index f8563d81..5f2a2771 100644 --- a/src/leap/soledad/backends/openstack.py +++ b/src/leap/soledad/backends/openstack.py @@ -1,5 +1,4 @@ from u1db import errors -from u1db.backends import CommonBackend from u1db.remote.http_target import HTTPSyncTarget from swiftclient import client from soledad.backends.objectstore import ObjectStore @@ -10,16 +9,15 @@ class OpenStackDatabase(ObjectStore): def __init__(self, auth_url, user, auth_key, container): """Create a new OpenStack data container.""" - super(OpenStackDatabase, self) self._auth_url = auth_url self._user = user self._auth_key = auth_key self._container = container - self.set_document_factory(LeapDocument) self._connection = swiftclient.Connection(self._auth_url, self._user, self._auth_key) self._get_auth() - self._ensure_u1db_data() + # this will ensure transaction and sync logs exist and are up-to-date. + super(OpenStackDatabase, self) #------------------------------------------------------------------------- # implemented methods from Database @@ -33,6 +31,7 @@ class OpenStackDatabase(ObjectStore): """ try: response, contents = self._connection.get_object(self._container, doc_id) + # TODO: change revision to be a dictionary element? rev = response['x-object-meta-rev'] return self._factory(doc_id, rev, contents) except swiftclient.ClientException: @@ -51,21 +50,12 @@ class OpenStackDatabase(ObjectStore): results.append(doc) return (generation, results) - def put_doc(self, doc): - if doc.doc_id is None: - raise errors.InvalidDocId() - self._check_doc_id(doc.doc_id) - self._check_doc_size(doc) - # TODO: check for conflicts? + def _put_doc(self, doc, new_rev): new_rev = self._allocate_doc_rev(doc.rev) + # TODO: change revision to be a dictionary element? headers = { 'X-Object-Meta-Rev' : new_rev } self._connection.put_object(self._container, doc_id, doc.get_json(), headers=headers) - new_gen = self._get_generation() + 1 - trans_id = self._allocate_transaction_id() - self._transaction_log.append((new_gen, doc.doc_id, trans_id)) - self._set_u1db_data() - return new_rev def get_sync_target(self): return OpenStackSyncTarget(self) -- cgit v1.2.3 From 817d4a1dab5cfce6228593ad61951e1593777eeb Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 10 Dec 2012 14:43:08 -0200 Subject: Fix lack of collons on some methods. 
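[Editor's note: not part of the patch series.] The objectstore.py hunk in the previous commit added "def _put_doc(self, doc)" and "def put_doc(self, doc)" without their trailing colons, which makes the module fail to import with a SyntaxError; the small diff below just adds the missing colons. For readers following the ObjectStore design, the sketch that follows is a toy, self-contained illustration of the bookkeeping that put_doc() wraps around the backend-specific _put_doc(): check the old revision, store the document, then bump the generation and append to the transaction log. The TinyStore class, revision format and example values are hypothetical stand-ins for the real u1db machinery (CommonBackend._allocate_doc_rev, SyncLog, TransactionLog).

    # Editor's sketch only, not LEAP code: mimics the put_doc()/_put_doc()
    # split and the generation/transaction-log bookkeeping added to
    # ObjectStore, with a plain counter standing in for u1db revisions.
    class TinyStore(object):
        def __init__(self):
            self._docs = {}              # doc_id -> (rev, content)
            self._transaction_log = []   # [(generation, doc_id, trans_id)]

        def _put_doc(self, doc_id, rev, content):
            # backend-specific storage; ObjectStore subclasses override this
            self._docs[doc_id] = (rev, content)

        def put_doc(self, doc_id, content, rev=None):
            old = self._docs.get(doc_id)
            if old is not None and old[0] != rev:
                raise ValueError('revision conflict')
            new_gen = len(self._transaction_log) + 1
            new_rev = 'replica:%d' % new_gen
            self._put_doc(doc_id, new_rev, content)
            self._transaction_log.append((new_gen, doc_id, 'T-%d' % new_gen))
            return new_rev

    store = TinyStore()
    rev1 = store.put_doc('doc-1', {'a': 1})
    rev2 = store.put_doc('doc-1', {'a': 2}, rev=rev1)  # supersedes rev1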
--- src/leap/soledad/backends/objectstore.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index 456892b3..d9ab7cbd 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -35,10 +35,10 @@ class ObjectStore(CommonBackend): return None return doc - def _put_doc(self, doc) + def _put_doc(self, doc): raise NotImplementedError(self._put_doc) - def put_doc(self, doc) + def put_doc(self, doc): # consistency check if doc.doc_id is None: raise errors.InvalidDocId() -- cgit v1.2.3 From 002d2bfdbc4ca62733478524ec588cf0aa9f9383 Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 10 Dec 2012 18:39:56 -0200 Subject: CouchDB backend can put and get objects. --- src/leap/soledad/backends/couch.py | 115 +++++++++++++++++++++++++++++++ src/leap/soledad/backends/couchdb.py | 97 -------------------------- src/leap/soledad/backends/leap.py | 1 + src/leap/soledad/backends/objectstore.py | 43 +++++++----- src/leap/soledad/backends/openstack.py | 2 +- 5 files changed, 141 insertions(+), 117 deletions(-) create mode 100644 src/leap/soledad/backends/couch.py delete mode 100644 src/leap/soledad/backends/couchdb.py (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/couch.py b/src/leap/soledad/backends/couch.py new file mode 100644 index 00000000..5586ea9c --- /dev/null +++ b/src/leap/soledad/backends/couch.py @@ -0,0 +1,115 @@ +from u1db import errors +from u1db.remote.http_target import HTTPSyncTarget +from couchdb.client import Server, Document +from couchdb.http import ResourceNotFound +from soledad.backends.objectstore import ObjectStore +from soledad.backends.leap import LeapDocument + + +class CouchDatabase(ObjectStore): + """A U1DB implementation that uses Couch as its persistence layer.""" + + def __init__(self, url, database, full_commit=True, session=None): + """Create a new Couch data container.""" + self._url = url + self._full_commit = full_commit + self._session = session + self._server = Server(url=self._url, + full_commit=self._full_commit, + session=self._session) + # this will ensure that transaction and sync logs exist and are + # up-to-date. + self.set_document_factory(LeapDocument) + try: + self._database = self._server[database] + except ResourceNotFound: + self._server.create(database) + self._database = self._server[database] + super(CouchDatabase, self).__init__() + + #------------------------------------------------------------------------- + # implemented methods from Database + #------------------------------------------------------------------------- + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling. + + Conflicts do not happen on server side, so there's no need to check + for them. 
+ """ + cdoc = self._database.get(doc_id) + if cdoc is None: + return None + content = {} + for (key, value) in cdoc.items(): + if key not in ['_id', '_rev', 'u1db_rev']: + content[key] = value + doc = self._factory(doc_id=doc_id, rev=cdoc['u1db_rev']) + doc.content = content + return doc + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + generation = self._get_generation() + results = [] + for doc_id in self._database: + doc = self._get_doc(doc_id) + if doc.content is None and not include_deleted: + continue + results.append(doc) + return (generation, results) + + def _put_doc(self, doc): + # map u1db metadata to couch + content = doc.content + cdoc = Document() + cdoc['_id'] = doc.doc_id + cdoc['u1db_rev'] = doc.rev + for (key, value) in content.items(): + cdoc[key] = value + self._database.save(cdoc) + + def get_sync_target(self): + return CouchSyncTarget(self) + + def close(self): + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + from u1db.sync import Synchronizer + from u1db.remote.http_target import CouchSyncTarget + return Synchronizer(self, CouchSyncTarget(url, creds=creds)).sync( + autocreate=autocreate) + + def _get_u1db_data(self): + cdoc = self._database.get(self.U1DB_DATA_DOC_ID) + self._sync_log.log = cdoc['sync_log'] + self._transaction_log.log = cdoc['transaction_log'] + self._replica_uid = cdoc['replica_uid'] + self._couch_rev = cdoc['_rev'] + + #------------------------------------------------------------------------- + # Couch specific methods + #------------------------------------------------------------------------- + + # no specific methods so far. + +class CouchSyncTarget(HTTPSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + diff --git a/src/leap/soledad/backends/couchdb.py b/src/leap/soledad/backends/couchdb.py deleted file mode 100644 index 89b713f9..00000000 --- a/src/leap/soledad/backends/couchdb.py +++ /dev/null @@ -1,97 +0,0 @@ -from u1db import errors -from u1db.remote.http_target import HTTPSyncTarget -from couchdb import * -from soledad.backends.objectstore import ObjectStore - - -class CouchDatabase(ObjectStore): - """A U1DB implementation that uses Couch as its persistence layer.""" - - def __init__(self, url, database, full_commit=True, session=None): - """Create a new Couch data container.""" - self._url = url - self._full_commit = full_commit - self._session = session - self._server = couchdb.Server(url=self._url, - full_commit=self._full_commit, - session=self._session) - # this will ensure that transaction and sync logs exist and are - # up-to-date. 
- super(CouchDatabase, self) - self._database = self._server[database] - - #------------------------------------------------------------------------- - # implemented methods from Database - #------------------------------------------------------------------------- - - def _get_doc(self, doc_id, check_for_conflicts=False): - """Get just the document content, without fancy handling. - - Conflicts do not happen on server side, so there's no need to check - for them. - """ - cdoc = self._database.get(doc_id) - if cdoc is not None: - content = {} - for key, value in content: - if not key in ['_id', '_rev', '_u1db_rev']: - content[key] = value - doc = self._factory(doc_id=doc_id, rev=cdoc['_u1db_rev']) - doc.content = content - return doc - - def get_all_docs(self, include_deleted=False): - """Get all documents from the database.""" - generation = self._get_generation() - results = [] - for doc_id in self._database: - doc = self._get_doc(doc_id) - if doc.content is None and not include_deleted: - continue - results.append(doc) - return (generation, results) - - def _put_doc(self, doc, new_rev): - # map u1db metadata to couch - content = doc.content - content['_id'] = doc.doc_id - content['_u1db_rev'] = new_rev - self._database.save(doc.content) - - def get_sync_target(self): - return CouchSyncTarget(self) - - def close(self): - raise NotImplementedError(self.close) - - def sync(self, url, creds=None, autocreate=True): - from u1db.sync import Synchronizer - from u1db.remote.http_target import CouchSyncTarget - return Synchronizer(self, CouchSyncTarget(url, creds=creds)).sync( - autocreate=autocreate) - - #------------------------------------------------------------------------- - # Couch specific methods - #------------------------------------------------------------------------- - - # no specific methods so far. 
- -class CouchSyncTarget(HTTPSyncTarget): - - def get_sync_info(self, source_replica_uid): - source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( - source_replica_uid) - my_gen, my_trans_id = self._db._get_generation_info() - return ( - self._db._replica_uid, my_gen, my_trans_id, source_gen, - source_trans_id) - - def record_sync_info(self, source_replica_uid, source_replica_generation, - source_replica_transaction_id): - if self._trace_hook: - self._trace_hook('record_sync_info') - self._db._set_replica_gen_and_trans_id( - source_replica_uid, source_replica_generation, - source_replica_transaction_id) - - diff --git a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py index 2c815632..ce00c8f3 100644 --- a/src/leap/soledad/backends/leap.py +++ b/src/leap/soledad/backends/leap.py @@ -7,6 +7,7 @@ from u1db import Document from u1db.remote.http_target import HTTPSyncTarget from u1db.remote.http_database import HTTPDatabase import base64 +from soledad import GPGWrapper class NoDefaultKey(Exception): diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index d9ab7cbd..5bd864c8 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -1,5 +1,7 @@ +import uuid from u1db.backends import CommonBackend from soledad import SyncLog, TransactionLog +from soledad.backends.leap import LeapDocument class ObjectStore(CommonBackend): @@ -45,15 +47,14 @@ class ObjectStore(CommonBackend): self._check_doc_id(doc.doc_id) self._check_doc_size(doc) # put the document - new_rev = self._allocate_doc_rev(doc.rev) - self._put_doc(doc, new_rev) - doc.rev = new_rev + doc.rev = self._allocate_doc_rev(doc.rev) + self._put_doc(doc) # update u1db generation and logs new_gen = self._get_generation() + 1 trans_id = self._allocate_transaction_id() self._transaction_log.append((new_gen, doc.doc_id, trans_id)) self._set_u1db_data() - return new_rev + return doc.rev def delete_doc(self, doc): old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) @@ -145,15 +146,16 @@ class ObjectStore(CommonBackend): """ if not self._is_initialized(): self._initialize() - u1db_data = self._get_doc('u1db_data') - self._sync_log.log = u1db_data.content['sync_log'] - self._transaction_log.log = u1db_data.content['transaction_log'] + self._get_u1db_data() + + U1DB_DATA_DOC_ID = 'u1db_data' def _is_initialized(self): """ Verify if u1db data exists in store. """ - if not self._get_doc('u1db_data'): + doc = self._get_doc(self.U1DB_DATA_DOC_ID) + if not self._get_doc(self.U1DB_DATA_DOC_ID): return False return True @@ -161,19 +163,22 @@ class ObjectStore(CommonBackend): """ Create u1db data object in store. 
""" - content = { 'transaction_log' : [], - 'sync_log' : [] } - doc = self.create_doc('u1db_data', content) + self._replica_uid = uuid.uuid4().hex + doc = self._factory(doc_id=self.U1DB_DATA_DOC_ID) + doc.content = { 'transaction_log' : [], + 'sync_log' : [], + 'replica_uid' : self._replica_uid } + self._put_doc(doc) - def _get_u1db_data(self): - data = self.get_doc('u1db_data').content - self._transaction_log = data['transaction_log'] - self._sync_log = data['sync_log'] + def _get_u1db_data(self, u1db_data_doc_id): + NotImplementedError(self._get_u1db_data) def _set_u1db_data(self): - doc = self._factory('u1db_data') - doc.content = { 'transaction_log' : self._transaction_log, - 'sync_log' : self._sync_log } - self.put_doc(doc) + doc = self._factory(doc_id=self.U1DB_DATA_DOC_ID) + doc.content = { 'transaction_log' : self._transaction_log.log, + 'sync_log' : self._sync_log.log, + 'replica_uid' : self._replica_uid, + '_rev' : self._couch_rev} + self._put_doc(doc) diff --git a/src/leap/soledad/backends/openstack.py b/src/leap/soledad/backends/openstack.py index 5f2a2771..c027231c 100644 --- a/src/leap/soledad/backends/openstack.py +++ b/src/leap/soledad/backends/openstack.py @@ -17,7 +17,7 @@ class OpenStackDatabase(ObjectStore): self._auth_key) self._get_auth() # this will ensure transaction and sync logs exist and are up-to-date. - super(OpenStackDatabase, self) + super(OpenStackDatabase, self).__init__() #------------------------------------------------------------------------- # implemented methods from Database -- cgit v1.2.3 From d5816c05136c9c018b8984b5f8a104c164676e9f Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 11:47:16 -0200 Subject: Fix ObjectStore's put_doc. --- src/leap/soledad/backends/objectstore.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index 5bd864c8..298bdda3 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -1,5 +1,6 @@ import uuid from u1db.backends import CommonBackend +from u1db import errors from soledad import SyncLog, TransactionLog from soledad.backends.leap import LeapDocument @@ -46,8 +47,21 @@ class ObjectStore(CommonBackend): raise errors.InvalidDocId() self._check_doc_id(doc.doc_id) self._check_doc_size(doc) - # put the document - doc.rev = self._allocate_doc_rev(doc.rev) + # check if document exists + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev self._put_doc(doc) # update u1db generation and logs new_gen = self._get_generation() + 1 @@ -69,7 +83,7 @@ class ObjectStore(CommonBackend): new_rev = self._allocate_doc_rev(doc.rev) doc.rev = new_rev doc.make_tombstone() - self._put_doc(olddoc) + self._put_doc(doc) return new_rev # start of index-related methods: these are not supported by this backend. @@ -171,9 +185,15 @@ class ObjectStore(CommonBackend): self._put_doc(doc) def _get_u1db_data(self, u1db_data_doc_id): + """ + Fetch u1db configuration data from backend storage. 
+ """ NotImplementedError(self._get_u1db_data) def _set_u1db_data(self): + """ + Save u1db configuration data on backend storage. + """ doc = self._factory(doc_id=self.U1DB_DATA_DOC_ID) doc.content = { 'transaction_log' : self._transaction_log.log, 'sync_log' : self._sync_log.log, -- cgit v1.2.3 From 703224c26e868546d37e9850db75747df1f92348 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 11:47:38 -0200 Subject: Store u1db contents in couch as json string. --- src/leap/soledad/backends/couch.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/couch.py b/src/leap/soledad/backends/couch.py index 5586ea9c..ed356fdd 100644 --- a/src/leap/soledad/backends/couch.py +++ b/src/leap/soledad/backends/couch.py @@ -5,6 +5,11 @@ from couchdb.http import ResourceNotFound from soledad.backends.objectstore import ObjectStore from soledad.backends.leap import LeapDocument +try: + import simplejson as json +except ImportError: + import json # noqa + class CouchDatabase(ObjectStore): """A U1DB implementation that uses Couch as its persistence layer.""" @@ -40,12 +45,11 @@ class CouchDatabase(ObjectStore): cdoc = self._database.get(doc_id) if cdoc is None: return None - content = {} - for (key, value) in cdoc.items(): - if key not in ['_id', '_rev', 'u1db_rev']: - content[key] = value doc = self._factory(doc_id=doc_id, rev=cdoc['u1db_rev']) - doc.content = content + if cdoc['u1db_json'] is not None: + doc.content = json.loads(cdoc['u1db_json']) + else: + doc.make_tombstone() return doc def get_all_docs(self, include_deleted=False): @@ -60,13 +64,20 @@ class CouchDatabase(ObjectStore): return (generation, results) def _put_doc(self, doc): - # map u1db metadata to couch - content = doc.content + # prepare couch's Document cdoc = Document() cdoc['_id'] = doc.doc_id + # we have to guarantee that couch's _rev is cosistent + old_cdoc = self._database.get(doc.doc_id) + if old_cdoc is not None: + cdoc['_rev'] = old_cdoc['_rev'] + # store u1db's rev cdoc['u1db_rev'] = doc.rev - for (key, value) in content.items(): - cdoc[key] = value + # store u1db's content as json string + if not doc.is_tombstone(): + cdoc['u1db_json'] = doc.get_json() + else: + cdoc['u1db_json'] = None self._database.save(cdoc) def get_sync_target(self): @@ -83,9 +94,10 @@ class CouchDatabase(ObjectStore): def _get_u1db_data(self): cdoc = self._database.get(self.U1DB_DATA_DOC_ID) - self._sync_log.log = cdoc['sync_log'] - self._transaction_log.log = cdoc['transaction_log'] - self._replica_uid = cdoc['replica_uid'] + content = json.loads(cdoc['u1db_json']) + self._sync_log.log = content['sync_log'] + self._transaction_log.log = content['transaction_log'] + self._replica_uid = content['replica_uid'] self._couch_rev = cdoc['_rev'] #------------------------------------------------------------------------- -- cgit v1.2.3 From 4417d89bb9bdd59d717501c6db3f2215cdeb87fb Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 12:07:28 -0200 Subject: SQLCipherDatabase now extends SQLitePartialExpandDatabase. 
--- src/leap/soledad/backends/sqlcipher.py | 831 +-------------------------------- 1 file changed, 3 insertions(+), 828 deletions(-) (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/sqlcipher.py b/src/leap/soledad/backends/sqlcipher.py index 24f47eed..fcdab251 100644 --- a/src/leap/soledad/backends/sqlcipher.py +++ b/src/leap/soledad/backends/sqlcipher.py @@ -30,6 +30,7 @@ import uuid import pkg_resources from u1db.backends import CommonBackend, CommonSyncTarget +from u1db.backends.sqlite_backend import SQLitePartialExpandDatabase from u1db import ( Document, errors, @@ -56,7 +57,7 @@ def open(path, create, document_factory=None, password=None): path, create=create, document_factory=document_factory, password=password) -class SQLCipherDatabase(CommonBackend): +class SQLCipherDatabase(SQLitePartialExpandDatabase): """A U1DB implementation that uses SQLCipher as its persistence layer.""" _sqlite_registry = {} @@ -74,25 +75,6 @@ class SQLCipherDatabase(CommonBackend): self._ensure_schema() self._factory = document_factory or Document - def set_document_factory(self, factory): - self._factory = factory - - def get_sync_target(self): - return SQLCipherSyncTarget(self) - - @classmethod - def _which_index_storage(cls, c): - try: - c.execute("SELECT value FROM u1db_config" - " WHERE name = 'index_storage'") - except dbapi2.OperationalError, e: - # The table does not exist yet - return None, e - else: - return c.fetchone()[0], None - - WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5 - @classmethod def _open_database(cls, sqlite_file, document_factory=None, password=None): if not os.path.isfile(sqlite_file): @@ -135,15 +117,6 @@ class SQLCipherDatabase(CommonBackend): return backend_cls(sqlite_file, document_factory=document_factory, password=password) - @staticmethod - def delete_database(sqlite_file): - try: - os.unlink(sqlite_file) - except OSError as ex: - if ex.errno == errno.ENOENT: - raise errors.DatabaseDoesNotExist() - raise - @staticmethod def register_implementation(klass): """Register that we implement an SQLCipherDatabase. @@ -152,803 +125,5 @@ class SQLCipherDatabase(CommonBackend): """ SQLCipherDatabase._sqlite_registry[klass._index_storage_value] = klass - def _get_sqlite_handle(self): - """Get access to the underlying sqlite database. - - This should only be used by the test suite, etc, for examining the - state of the underlying database. - """ - return self._db_handle - - def _close_sqlite_handle(self): - """Release access to the underlying sqlite database.""" - self._db_handle.close() - - def close(self): - self._close_sqlite_handle() - - def _is_initialized(self, c): - """Check if this database has been initialized.""" - c.execute("PRAGMA case_sensitive_like=ON") - try: - c.execute("SELECT value FROM u1db_config" - " WHERE name = 'sql_schema'") - except dbapi2.OperationalError: - # The table does not exist yet - val = None - else: - val = c.fetchone() - if val is not None: - return True - return False - - def _initialize(self, c): - """Create the schema in the database.""" - #read the script with sql commands - # TODO: Change how we set up the dependency. Most likely use something - # like lp:dirspec to grab the file from a common resource - # directory. Doesn't specifically need to be handled until we get - # to the point of packaging this. 
- schema_content = pkg_resources.resource_string( - __name__, 'dbschema.sql') - # Note: We'd like to use c.executescript() here, but it seems that - # executescript always commits, even if you set - # isolation_level = None, so if we want to properly handle - # exclusive locking and rollbacks between processes, we need - # to execute it line-by-line - for line in schema_content.split(';'): - if not line: - continue - c.execute(line) - #add extra fields - self._extra_schema_init(c) - # A unique identifier should be set for this replica. Implementations - # don't have to strictly use uuid here, but we do want the uid to be - # unique amongst all databases that will sync with each other. - # We might extend this to using something with hostname for easier - # debugging. - self._set_replica_uid_in_transaction(uuid.uuid4().hex) - c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)", - (self._index_storage_value,)) - - def _ensure_schema(self): - """Ensure that the database schema has been created.""" - old_isolation_level = self._db_handle.isolation_level - c = self._db_handle.cursor() - if self._is_initialized(c): - return - try: - # autocommit/own mgmt of transactions - self._db_handle.isolation_level = None - with self._db_handle: - # only one execution path should initialize the db - c.execute("begin exclusive") - if self._is_initialized(c): - return - self._initialize(c) - finally: - self._db_handle.isolation_level = old_isolation_level - - def _extra_schema_init(self, c): - """Add any extra fields, etc to the basic table definitions.""" - - def _parse_index_definition(self, index_field): - """Parse a field definition for an index, returning a Getter.""" - # Note: We may want to keep a Parser object around, and cache the - # Getter objects for a greater length of time. Specifically, if - # you create a bunch of indexes, and then insert 50k docs, you'll - # re-parse the indexes between puts. The time to insert the docs - # is still likely to dominate put_doc time, though. - parser = query_parser.Parser() - getter = parser.parse(index_field) - return getter - - def _update_indexes(self, doc_id, raw_doc, getters, db_cursor): - """Update document_fields for a single document. - - :param doc_id: Identifier for this document - :param raw_doc: The python dict representation of the document. - :param getters: A list of [(field_name, Getter)]. Getter.get will be - called to evaluate the index definition for this document, and the - results will be inserted into the db. - :param db_cursor: An sqlite Cursor. - :return: None - """ - values = [] - for field_name, getter in getters: - for idx_value in getter.get(raw_doc): - values.append((doc_id, field_name, idx_value)) - if values: - db_cursor.executemany( - "INSERT INTO document_fields VALUES (?, ?, ?)", values) - - def _set_replica_uid(self, replica_uid): - """Force the replica_uid to be set.""" - with self._db_handle: - self._set_replica_uid_in_transaction(replica_uid) - - def _set_replica_uid_in_transaction(self, replica_uid): - """Set the replica_uid. 
A transaction should already be held.""" - c = self._db_handle.cursor() - c.execute("INSERT OR REPLACE INTO u1db_config" - " VALUES ('replica_uid', ?)", - (replica_uid,)) - self._real_replica_uid = replica_uid - - def _get_replica_uid(self): - if self._real_replica_uid is not None: - return self._real_replica_uid - c = self._db_handle.cursor() - c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'") - val = c.fetchone() - if val is None: - return None - self._real_replica_uid = val[0] - return self._real_replica_uid - - _replica_uid = property(_get_replica_uid) - - def _get_generation(self): - c = self._db_handle.cursor() - c.execute('SELECT max(generation) FROM transaction_log') - val = c.fetchone()[0] - if val is None: - return 0 - return val - - def _get_generation_info(self): - c = self._db_handle.cursor() - c.execute( - 'SELECT max(generation), transaction_id FROM transaction_log ') - val = c.fetchone() - if val[0] is None: - return(0, '') - return val - - def _get_trans_id_for_gen(self, generation): - if generation == 0: - return '' - c = self._db_handle.cursor() - c.execute( - 'SELECT transaction_id FROM transaction_log WHERE generation = ?', - (generation,)) - val = c.fetchone() - if val is None: - raise errors.InvalidGeneration - return val[0] - - def _get_transaction_log(self): - c = self._db_handle.cursor() - c.execute("SELECT doc_id, transaction_id FROM transaction_log" - " ORDER BY generation") - return c.fetchall() - - def _get_doc(self, doc_id, check_for_conflicts=False): - """Get just the document content, without fancy handling.""" - c = self._db_handle.cursor() - if check_for_conflicts: - c.execute( - "SELECT document.doc_rev, document.content, " - "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN " - "conflicts ON conflicts.doc_id = document.doc_id WHERE " - "document.doc_id = ? GROUP BY document.doc_id, " - "document.doc_rev, document.content;", (doc_id,)) - else: - c.execute( - "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?", - (doc_id,)) - val = c.fetchone() - if val is None: - return None - doc_rev, content, conflicts = val - doc = self._factory(doc_id, doc_rev, content) - doc.has_conflicts = conflicts > 0 - return doc - - def _has_conflicts(self, doc_id): - c = self._db_handle.cursor() - c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? 
LIMIT 1", - (doc_id,)) - val = c.fetchone() - if val is None: - return False - else: - return True - - def get_doc(self, doc_id, include_deleted=False): - doc = self._get_doc(doc_id, check_for_conflicts=True) - if doc is None: - return None - if doc.is_tombstone() and not include_deleted: - return None - return doc - - def get_all_docs(self, include_deleted=False): - """Get all documents from the database.""" - generation = self._get_generation() - results = [] - c = self._db_handle.cursor() - c.execute( - "SELECT document.doc_id, document.doc_rev, document.content, " - "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts " - "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, " - "document.doc_rev, document.content;") - rows = c.fetchall() - for doc_id, doc_rev, content, conflicts in rows: - if content is None and not include_deleted: - continue - doc = self._factory(doc_id, doc_rev, content) - doc.has_conflicts = conflicts > 0 - results.append(doc) - return (generation, results) - - def put_doc(self, doc): - if doc.doc_id is None: - raise errors.InvalidDocId() - self._check_doc_id(doc.doc_id) - self._check_doc_size(doc) - with self._db_handle: - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc and old_doc.has_conflicts: - raise errors.ConflictedDoc() - if old_doc and doc.rev is None and old_doc.is_tombstone(): - new_rev = self._allocate_doc_rev(old_doc.rev) - else: - if old_doc is not None: - if old_doc.rev != doc.rev: - raise errors.RevisionConflict() - else: - if doc.rev is not None: - raise errors.RevisionConflict() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - self._put_and_update_indexes(old_doc, doc) - return new_rev - - def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none): - """Convert a dict representation into named fields. - - So something like: {'key1': 'val1', 'key2': 'val2'} - gets converted into: [(doc_id, 'key1', 'val1', 0) - (doc_id, 'key2', 'val2', 0)] - :param doc_id: Just added to every record. - :param base_field: if set, these are nested keys, so each field should - be appropriately prefixed. - :param raw_doc: The python dictionary. - """ - # TODO: Handle lists - values = [] - for field_name, value in raw_doc.iteritems(): - if value is None and not save_none: - continue - if base_field: - full_name = base_field + '.' + field_name - else: - full_name = field_name - if value is None or isinstance(value, (int, float, basestring)): - values.append((doc_id, full_name, value, len(values))) - else: - subvalues = self._expand_to_fields(doc_id, full_name, value, - save_none) - for _, subfield_name, val, _ in subvalues: - values.append((doc_id, subfield_name, val, len(values))) - return values - - def _put_and_update_indexes(self, old_doc, doc): - """Actually insert a document into the database. - - This both updates the existing documents content, and any indexes that - refer to this document. - """ - raise NotImplementedError(self._put_and_update_indexes) - - def whats_changed(self, old_generation=0): - c = self._db_handle.cursor() - c.execute("SELECT generation, doc_id, transaction_id" - " FROM transaction_log" - " WHERE generation > ? 
ORDER BY generation DESC", - (old_generation,)) - results = c.fetchall() - cur_gen = old_generation - seen = set() - changes = [] - newest_trans_id = '' - for generation, doc_id, trans_id in results: - if doc_id not in seen: - changes.append((doc_id, generation, trans_id)) - seen.add(doc_id) - if changes: - cur_gen = changes[0][1] # max generation - newest_trans_id = changes[0][2] - changes.reverse() - else: - c.execute("SELECT generation, transaction_id" - " FROM transaction_log ORDER BY generation DESC LIMIT 1") - results = c.fetchone() - if not results: - cur_gen = 0 - newest_trans_id = '' - else: - cur_gen, newest_trans_id = results - - return cur_gen, newest_trans_id, changes - - def delete_doc(self, doc): - with self._db_handle: - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc is None: - raise errors.DocumentDoesNotExist - if old_doc.rev != doc.rev: - raise errors.RevisionConflict() - if old_doc.is_tombstone(): - raise errors.DocumentAlreadyDeleted - if old_doc.has_conflicts: - raise errors.ConflictedDoc() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - doc.make_tombstone() - self._put_and_update_indexes(old_doc, doc) - return new_rev - - def _get_conflicts(self, doc_id): - c = self._db_handle.cursor() - c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?", - (doc_id,)) - return [self._factory(doc_id, doc_rev, content) - for doc_rev, content in c.fetchall()] - - def get_doc_conflicts(self, doc_id): - with self._db_handle: - conflict_docs = self._get_conflicts(doc_id) - if not conflict_docs: - return [] - this_doc = self._get_doc(doc_id) - this_doc.has_conflicts = True - return [this_doc] + conflict_docs - - def _get_replica_gen_and_trans_id(self, other_replica_uid): - c = self._db_handle.cursor() - c.execute("SELECT known_generation, known_transaction_id FROM sync_log" - " WHERE replica_uid = ?", - (other_replica_uid,)) - val = c.fetchone() - if val is None: - other_gen = 0 - trans_id = '' - else: - other_gen = val[0] - trans_id = val[1] - return other_gen, trans_id - - def _set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - with self._db_handle: - self._do_set_replica_gen_and_trans_id( - other_replica_uid, other_generation, other_transaction_id) - - def _do_set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, - other_transaction_id): - c = self._db_handle.cursor() - c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)", - (other_replica_uid, other_generation, - other_transaction_id)) - - def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, - replica_gen=None, replica_trans_id=None): - with self._db_handle: - return super(SQLCipherDatabase, self)._put_doc_if_newer(doc, - save_conflict=save_conflict, - replica_uid=replica_uid, replica_gen=replica_gen, - replica_trans_id=replica_trans_id) - - def _add_conflict(self, c, doc_id, my_doc_rev, my_content): - c.execute("INSERT INTO conflicts VALUES (?, ?, ?)", - (doc_id, my_doc_rev, my_content)) - - def _delete_conflicts(self, c, doc, conflict_revs): - deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs] - c.executemany("DELETE FROM conflicts" - " WHERE doc_id=? 
AND doc_rev=?", deleting) - doc.has_conflicts = self._has_conflicts(doc.doc_id) - - def _prune_conflicts(self, doc, doc_vcr): - if self._has_conflicts(doc.doc_id): - autoresolved = False - c_revs_to_prune = [] - for c_doc in self._get_conflicts(doc.doc_id): - c_vcr = vectorclock.VectorClockRev(c_doc.rev) - if doc_vcr.is_newer(c_vcr): - c_revs_to_prune.append(c_doc.rev) - elif doc.same_content_as(c_doc): - c_revs_to_prune.append(c_doc.rev) - doc_vcr.maximize(c_vcr) - autoresolved = True - if autoresolved: - doc_vcr.increment(self._replica_uid) - doc.rev = doc_vcr.as_str() - c = self._db_handle.cursor() - self._delete_conflicts(c, doc, c_revs_to_prune) - - def _force_doc_sync_conflict(self, doc): - my_doc = self._get_doc(doc.doc_id) - c = self._db_handle.cursor() - self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) - self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json()) - doc.has_conflicts = True - self._put_and_update_indexes(my_doc, doc) - - def resolve_doc(self, doc, conflicted_doc_revs): - with self._db_handle: - cur_doc = self._get_doc(doc.doc_id) - # TODO: https://bugs.launchpad.net/u1db/+bug/928274 - # I think we have a logic bug in resolve_doc - # Specifically, cur_doc.rev is always in the final vector - # clock of revisions that we supersede, even if it wasn't in - # conflicted_doc_revs. We still add it as a conflict, but the - # fact that _put_doc_if_newer propagates resolutions means I - # think that conflict could accidentally be resolved. We need - # to add a test for this case first. (create a rev, create a - # conflict, create another conflict, resolve the first rev - # and first conflict, then make sure that the resolved - # rev doesn't supersede the second conflict rev.) It *might* - # not matter, because the superseding rev is in as a - # conflict, but it does seem incorrect - new_rev = self._ensure_maximal_rev(cur_doc.rev, - conflicted_doc_revs) - superseded_revs = set(conflicted_doc_revs) - c = self._db_handle.cursor() - doc.rev = new_rev - if cur_doc.rev in superseded_revs: - self._put_and_update_indexes(cur_doc, doc) - else: - self._add_conflict(c, doc.doc_id, new_rev, doc.get_json()) - # TODO: Is there some way that we could construct a rev that would - # end up in superseded_revs, such that we add a conflict, and - # then immediately delete it? - self._delete_conflicts(c, doc, superseded_revs) - - def list_indexes(self): - """Return the list of indexes and their definitions.""" - c = self._db_handle.cursor() - # TODO: How do we test the ordering? - c.execute("SELECT name, field FROM index_definitions" - " ORDER BY name, offset") - definitions = [] - cur_name = None - for name, field in c.fetchall(): - if cur_name != name: - definitions.append((name, [])) - cur_name = name - definitions[-1][-1].append(field) - return definitions - - def _get_index_definition(self, index_name): - """Return the stored definition for a given index_name.""" - c = self._db_handle.cursor() - c.execute("SELECT field FROM index_definitions" - " WHERE name = ? ORDER BY offset", (index_name,)) - fields = [x[0] for x in c.fetchall()] - if not fields: - raise errors.IndexDoesNotExist - return fields - - @staticmethod - def _strip_glob(value): - """Remove the trailing * from a value.""" - assert value[-1] == '*' - return value[:-1] - - def _format_query(self, definition, key_values): - # First, build the definition. We join the document_fields table - # against itself, as many times as the 'width' of our definition. - # We then do a query for each key_value, one-at-a-time. 
- # Note: All of these strings are static, we could cache them, etc. - tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = ["d.doc_id = d%d.doc_id" - " AND d%d.field_name = ?" - % (i, i) for i in range(len(definition))] - wildcard_where = [novalue_where[i] - + (" AND d%d.value NOT NULL" % (i,)) - for i in range(len(definition))] - exact_where = [novalue_where[i] - + (" AND d%d.value = ?" % (i,)) - for i in range(len(definition))] - like_where = [novalue_where[i] - + (" AND d%d.value GLOB ?" % (i,)) - for i in range(len(definition))] - is_wildcard = False - # Merge the lists together, so that: - # [field1, field2, field3], [val1, val2, val3] - # Becomes: - # (field1, val1, field2, val2, field3, val3) - args = [] - where = [] - for idx, (field, value) in enumerate(zip(definition, key_values)): - args.append(field) - if value.endswith('*'): - if value == '*': - where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(like_where[idx]) - args.append(value) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(exact_where[idx]) - args.append(value) - statement = ( - "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " - "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " - "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " - "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( - ['d%d.value' % i for i in range(len(definition))]))) - return statement, args - - def get_from_index(self, index_name, *key_values): - definition = self._get_index_definition(index_name) - if len(key_values) != len(definition): - raise errors.InvalidValueForIndex() - statement, args = self._format_query(definition, key_values) - c = self._db_handle.cursor() - try: - c.execute(statement, tuple(args)) - except dbapi2.OperationalError, e: - raise dbapi2.OperationalError(str(e) + - '\nstatement: %s\nargs: %s\n' % (statement, args)) - res = c.fetchall() - results = [] - for row in res: - doc = self._factory(row[0], row[1], row[2]) - doc.has_conflicts = row[3] > 0 - results.append(doc) - return results - - def _format_range_query(self, definition, start_value, end_value): - tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = [ - "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in - range(len(definition))] - wildcard_where = [ - novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in - range(len(definition))] - like_where = [ - novalue_where[i] + ( - " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in - range(len(definition))] - range_where_lower = [ - novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in - range(len(definition))] - range_where_upper = [ - novalue_where[i] + (" AND d%d.value <= ?" 
% (i,)) for i in - range(len(definition))] - args = [] - where = [] - if start_value: - if isinstance(start_value, basestring): - start_value = (start_value,) - if len(start_value) != len(definition): - raise errors.InvalidValueForIndex() - is_wildcard = False - for idx, (field, value) in enumerate(zip(definition, start_value)): - args.append(field) - if value.endswith('*'): - if value == '*': - where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(range_where_lower[idx]) - args.append(self._strip_glob(value)) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(range_where_lower[idx]) - args.append(value) - if end_value: - if isinstance(end_value, basestring): - end_value = (end_value,) - if len(end_value) != len(definition): - raise errors.InvalidValueForIndex() - is_wildcard = False - for idx, (field, value) in enumerate(zip(definition, end_value)): - args.append(field) - if value.endswith('*'): - if value == '*': - where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(like_where[idx]) - args.append(self._strip_glob(value)) - args.append(value) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(range_where_upper[idx]) - args.append(value) - statement = ( - "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " - "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " - "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " - "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( - ['d%d.value' % i for i in range(len(definition))]))) - return statement, args - - def get_range_from_index(self, index_name, start_value=None, - end_value=None): - """Return all documents with key values in the specified range.""" - definition = self._get_index_definition(index_name) - statement, args = self._format_range_query( - definition, start_value, end_value) - c = self._db_handle.cursor() - try: - c.execute(statement, tuple(args)) - except dbapi2.OperationalError, e: - raise dbapi2.OperationalError(str(e) + - '\nstatement: %s\nargs: %s\n' % (statement, args)) - res = c.fetchall() - results = [] - for row in res: - doc = self._factory(row[0], row[1], row[2]) - doc.has_conflicts = row[3] > 0 - results.append(doc) - return results - - def get_index_keys(self, index_name): - c = self._db_handle.cursor() - definition = self._get_index_definition(index_name) - value_fields = ', '.join([ - 'd%d.value' % i for i in range(len(definition))]) - tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = [ - "d.doc_id = d%d.doc_id AND d%d.field_name = ?" 
% (i, i) for i in - range(len(definition))] - where = [ - novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in - range(len(definition))] - statement = ( - "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( - value_fields, ', '.join(tables), ' AND '.join(where), - value_fields)) - try: - c.execute(statement, tuple(definition)) - except dbapi2.OperationalError, e: - raise dbapi2.OperationalError(str(e) + - '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) - return c.fetchall() - - def delete_index(self, index_name): - with self._db_handle: - c = self._db_handle.cursor() - c.execute("DELETE FROM index_definitions WHERE name = ?", - (index_name,)) - c.execute( - "DELETE FROM document_fields WHERE document_fields.field_name " - " NOT IN (SELECT field from index_definitions)") - - -class SQLCipherSyncTarget(CommonSyncTarget): - - def get_sync_info(self, source_replica_uid): - source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( - source_replica_uid) - my_gen, my_trans_id = self._db._get_generation_info() - return ( - self._db._replica_uid, my_gen, my_trans_id, source_gen, - source_trans_id) - - def record_sync_info(self, source_replica_uid, source_replica_generation, - source_replica_transaction_id): - if self._trace_hook: - self._trace_hook('record_sync_info') - self._db._set_replica_gen_and_trans_id( - source_replica_uid, source_replica_generation, - source_replica_transaction_id) - - -class SQLCipherPartialExpandDatabase(SQLCipherDatabase): - """An SQLCipher Backend that expands documents into a document_field table. - - It stores the original document text in document.doc. For fields that are - indexed, the data goes into document_fields. - """ - - _index_storage_value = 'expand referenced' - - def _get_indexed_fields(self): - """Determine what fields are indexed.""" - c = self._db_handle.cursor() - c.execute("SELECT field FROM index_definitions") - return set([x[0] for x in c.fetchall()]) - - def _evaluate_index(self, raw_doc, field): - parser = query_parser.Parser() - getter = parser.parse(field) - return getter.get(raw_doc) - - def _put_and_update_indexes(self, old_doc, doc): - c = self._db_handle.cursor() - if doc and not doc.is_tombstone(): - raw_doc = json.loads(doc.get_json()) - else: - raw_doc = {} - if old_doc is not None: - c.execute("UPDATE document SET doc_rev=?, content=?" 
- " WHERE doc_id = ?", - (doc.rev, doc.get_json(), doc.doc_id)) - c.execute("DELETE FROM document_fields WHERE doc_id = ?", - (doc.doc_id,)) - else: - c.execute("INSERT INTO document (doc_id, doc_rev, content)" - " VALUES (?, ?, ?)", - (doc.doc_id, doc.rev, doc.get_json())) - indexed_fields = self._get_indexed_fields() - if indexed_fields: - # It is expected that len(indexed_fields) is shorter than - # len(raw_doc) - getters = [(field, self._parse_index_definition(field)) - for field in indexed_fields] - self._update_indexes(doc.doc_id, raw_doc, getters, c) - trans_id = self._allocate_transaction_id() - c.execute("INSERT INTO transaction_log(doc_id, transaction_id)" - " VALUES (?, ?)", (doc.doc_id, trans_id)) - - def create_index(self, index_name, *index_expressions): - with self._db_handle: - c = self._db_handle.cursor() - cur_fields = self._get_indexed_fields() - definition = [(index_name, idx, field) - for idx, field in enumerate(index_expressions)] - try: - c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", - definition) - except dbapi2.IntegrityError as e: - stored_def = self._get_index_definition(index_name) - if stored_def == [x[-1] for x in definition]: - return - raise errors.IndexNameTakenError, e, sys.exc_info()[2] - new_fields = set( - [f for f in index_expressions if f not in cur_fields]) - if new_fields: - self._update_all_indexes(new_fields) - - def _iter_all_docs(self): - c = self._db_handle.cursor() - c.execute("SELECT doc_id, content FROM document") - while True: - next_rows = c.fetchmany() - if not next_rows: - break - for row in next_rows: - yield row - - def _update_all_indexes(self, new_fields): - """Iterate all the documents, and add content to document_fields. - - :param new_fields: The index definitions that need to be added. - """ - getters = [(field, self._parse_index_definition(field)) - for field in new_fields] - c = self._db_handle.cursor() - for doc_id, doc in self._iter_all_docs(): - if doc is None: - continue - raw_doc = json.loads(doc) - self._update_indexes(doc_id, raw_doc, getters, c) -SQLCipherDatabase.register_implementation(SQLCipherPartialExpandDatabase) +SQLCipherDatabase.register_implementation(SQLCipherDatabase) -- cgit v1.2.3 From a12b80b23695dd1db8ac5edeb4b79e6ff8e527c2 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 15:03:12 -0200 Subject: Fix SQLCipherDatabase and add tests. 
--- src/leap/soledad/backends/sqlcipher.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/sqlcipher.py b/src/leap/soledad/backends/sqlcipher.py index fcdab251..301d4a7f 100644 --- a/src/leap/soledad/backends/sqlcipher.py +++ b/src/leap/soledad/backends/sqlcipher.py @@ -60,7 +60,8 @@ def open(path, create, document_factory=None, password=None): class SQLCipherDatabase(SQLitePartialExpandDatabase): """A U1DB implementation that uses SQLCipher as its persistence layer.""" - _sqlite_registry = {} + _index_storage_value = 'expand referenced encrypted' + @classmethod def set_pragma_key(cls, db_handle, key): @@ -113,7 +114,7 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase): raise if backend_cls is None: # default is SQLCipherPartialExpandDatabase - backend_cls = SQLCipherPartialExpandDatabase + backend_cls = SQLCipherDatabase return backend_cls(sqlite_file, document_factory=document_factory, password=password) -- cgit v1.2.3 From 19ee861b5c5dca236800ffcb944b4299561d841d Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 13 Dec 2012 13:29:17 -0200 Subject: Change name of cyphertext field to something more meaningful. --- src/leap/soledad/backends/leap.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py index ce00c8f3..c113f5c2 100644 --- a/src/leap/soledad/backends/leap.py +++ b/src/leap/soledad/backends/leap.py @@ -43,13 +43,13 @@ class LeapDocument(Document): self._default_key, always_trust = True) # TODO: always trust? - return json.dumps({'cyphertext' : str(cyphertext)}) + return json.dumps({'_encrypted_json' : str(cyphertext)}) def set_encrypted_json(self, encrypted_json): """ Set document's content based on encrypted version of json string. """ - cyphertext = json.loads(encrypted_json)['cyphertext'] + cyphertext = json.loads(encrypted_json)['_encrypted_json'] plaintext = str(self._gpg.decrypt(cyphertext)) return self.set_json(plaintext) @@ -97,6 +97,7 @@ class LeapSyncTarget(HTTPSyncTarget): raise BrokenSyncStream line, comma = utils.check_and_strip_comma(entry) entry = json.loads(line) + # decrypt after receiving from server. doc = LeapDocument(entry['id'], entry['rev'], encrypted_json=entry['content']) return_doc_cb(doc, entry['gen'], entry['trans_id']) @@ -142,6 +143,7 @@ class LeapSyncTarget(HTTPSyncTarget): ensure=ensure_callback is not None) comma = ',' for doc, gen, trans_id in docs_by_generations: + # encrypt before sending to server. size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_encrypted_json(), gen=gen, trans_id=trans_id) -- cgit v1.2.3 From ece9f7c2116fa961cafabcc6a5790206412c95ae Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 13 Dec 2012 13:46:27 -0200 Subject: Enforce password on SQLCipher backend. 
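
Rationale for the change below: SQLCipher only decrypts database pages after
PRAGMA key has been issued, so the key has to be the first statement on a
fresh connection; making the password a mandatory positional argument removes
the code path that silently skipped it. A minimal sketch of the pattern being
enforced (the pysqlcipher import is an assumption, not part of this diff):

    from pysqlcipher import dbapi2

    def open_encrypted(sqlite_file, password):
        db_handle = dbapi2.connect(sqlite_file)
        # Run before anything else touches the file; without the key,
        # reads fail with "file is encrypted or is not a database".
        # The raw interpolation mirrors set_pragma_key() and assumes the
        # password contains no single quotes.
        db_handle.cursor().execute("PRAGMA key = '%s'" % password)
        return db_handle
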
--- src/leap/soledad/backends/sqlcipher.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/sqlcipher.py b/src/leap/soledad/backends/sqlcipher.py index 301d4a7f..6fd6e619 100644 --- a/src/leap/soledad/backends/sqlcipher.py +++ b/src/leap/soledad/backends/sqlcipher.py @@ -54,7 +54,7 @@ def open(path, create, document_factory=None, password=None): """ from u1db.backends import sqlite_backend return sqlite_backend.SQLCipherDatabase.open_database( - path, create=create, document_factory=document_factory, password=password) + path, password, create=create, document_factory=document_factory) class SQLCipherDatabase(SQLitePartialExpandDatabase): @@ -67,17 +67,16 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase): def set_pragma_key(cls, db_handle, key): db_handle.cursor().execute("PRAGMA key = '%s'" % key) - def __init__(self, sqlite_file, document_factory=None, password=None): + def __init__(self, sqlite_file, password, document_factory=None): """Create a new sqlite file.""" self._db_handle = dbapi2.connect(sqlite_file) - if password: - SQLiteDatabase.set_pragma_key(self._db_handle, password) + SQLCipherDatabase.set_pragma_key(self._db_handle, password) self._real_replica_uid = None self._ensure_schema() self._factory = document_factory or Document @classmethod - def _open_database(cls, sqlite_file, document_factory=None, password=None): + def _open_database(cls, sqlite_file, password, document_factory=None): if not os.path.isfile(sqlite_file): raise errors.DatabaseDoesNotExist() tries = 2 @@ -86,8 +85,7 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase): # where without re-opening the database on Windows, it # doesn't see the transaction that was just committed db_handle = dbapi2.connect(sqlite_file) - if password: - SQLiteDatabase.set_pragma_key(db_handle, password) + SQLCipherDatabase.set_pragma_key(db_handle, password) c = db_handle.cursor() v, err = cls._which_index_storage(c) db_handle.close() @@ -100,23 +98,22 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase): tries -= 1 time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL) return SQLCipherDatabase._sqlite_registry[v]( - sqlite_file, document_factory=document_factory) + sqlite_file, password, document_factory=document_factory) @classmethod - def open_database(cls, sqlite_file, create, backend_cls=None, - document_factory=None, password=None): + def open_database(cls, sqlite_file, password, create, backend_cls=None, + document_factory=None): try: - return cls._open_database(sqlite_file, - document_factory=document_factory, - password=password) + return cls._open_database(sqlite_file, password, + document_factory=document_factory) except errors.DatabaseDoesNotExist: if not create: raise if backend_cls is None: # default is SQLCipherPartialExpandDatabase backend_cls = SQLCipherDatabase - return backend_cls(sqlite_file, document_factory=document_factory, - password=password) + return backend_cls(sqlite_file, password, + document_factory=document_factory) @staticmethod def register_implementation(klass): -- cgit v1.2.3 From 7a67c36efd95d86dea04ab0741c68f5307a95c09 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 18 Dec 2012 18:51:01 -0200 Subject: Refactor and symmetric encryption --- src/leap/soledad/backends/leap.py | 53 +++++++++++++++++++++----------- src/leap/soledad/backends/objectstore.py | 7 ++--- 2 files changed, 38 insertions(+), 22 deletions(-) (limited to 'src/leap/soledad/backends') diff --git 
a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py index ce00c8f3..4a496d3e 100644 --- a/src/leap/soledad/backends/leap.py +++ b/src/leap/soledad/backends/leap.py @@ -7,12 +7,15 @@ from u1db import Document from u1db.remote.http_target import HTTPSyncTarget from u1db.remote.http_database import HTTPDatabase import base64 -from soledad import GPGWrapper +from soledad.util import GPGWrapper class NoDefaultKey(Exception): pass +class NoSoledadInstance(Exception): + pass + class LeapDocument(Document): """ @@ -22,41 +25,40 @@ class LeapDocument(Document): """ def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, - encrypted_json=None, default_key=None, gpg_wrapper=None): + encrypted_json=None, soledad=None): super(LeapDocument, self).__init__(doc_id, rev, json, has_conflicts) - # we might want to get already initialized wrappers for testing. - if gpg_wrapper is None: - self._gpg = GPGWrapper() - else: - self._gpg = gpg_wrapper + self._soledad = soledad if encrypted_json: self.set_encrypted_json(encrypted_json) - self._default_key = default_key def get_encrypted_json(self): """ Returns document's json serialization encrypted with user's public key. """ - if self._default_key is None: - raise NoDefaultKey() - cyphertext = self._gpg.encrypt(self.get_json(), - self._default_key, - always_trust = True) - # TODO: always trust? - return json.dumps({'cyphertext' : str(cyphertext)}) + if not self._soledad: + raise NoSoledadInstance() + cyphertext = self._soledad.encrypt_symmetric(self.get_json()) + return json.dumps({'_encrypted_json' : cyphertext}) def set_encrypted_json(self, encrypted_json): """ Set document's content based on encrypted version of json string. """ - cyphertext = json.loads(encrypted_json)['cyphertext'] - plaintext = str(self._gpg.decrypt(cyphertext)) + if not self._soledad: + raise NoSoledadInstance() + cyphertext = json.loads(encrypted_json)['_encrypted_json'] + plaintext = self._soledad.decrypt_symmetric(cyphertext) return self.set_json(plaintext) class LeapDatabase(HTTPDatabase): """Implement the HTTP remote database API to a Leap server.""" + def __init__(self, url, document_factory=None, creds=None, soledad=None): + super(LeapDatabase, self).__init__(url, creds=creds) + self._soledad = soledad + self._factory = LeapDocument + @staticmethod def open_database(url, create): db = LeapDatabase(url) @@ -74,9 +76,21 @@ class LeapDatabase(HTTPDatabase): st._creds = self._creds return st + def create_doc_from_json(self, content, doc_id=None): + if doc_id is None: + doc_id = self._allocate_doc_id() + res, headers = self._request_json('PUT', ['doc', doc_id], {}, + content, 'application/json') + new_doc = self._factory(doc_id, res['rev'], content, soledad=self._soledad) + return new_doc + class LeapSyncTarget(HTTPSyncTarget): + def __init__(self, url, creds=None, soledad=None): + super(LeapSyncTarget, self).__init__(url, creds) + self._soledad = soledad + def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): """ Does the same as parent's method but ensures incoming content will be @@ -97,8 +111,10 @@ class LeapSyncTarget(HTTPSyncTarget): raise BrokenSyncStream line, comma = utils.check_and_strip_comma(entry) entry = json.loads(line) + # decrypt after receiving from server. 
doc = LeapDocument(entry['id'], entry['rev'], - encrypted_json=entry['content']) + encrypted_json=entry['content'], + soledad=self._soledad) return_doc_cb(doc, entry['gen'], entry['trans_id']) if parts[-1] != ']': try: @@ -142,6 +158,7 @@ class LeapSyncTarget(HTTPSyncTarget): ensure=ensure_callback is not None) comma = ',' for doc, gen, trans_id in docs_by_generations: + # encrypt before sending to server. size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_encrypted_json(), gen=gen, trans_id=trans_id) diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index 298bdda3..a8e139f7 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -1,8 +1,7 @@ import uuid from u1db.backends import CommonBackend -from u1db import errors -from soledad import SyncLog, TransactionLog -from soledad.backends.leap import LeapDocument +from u1db import errors, Document +from soledad.util import SyncLog, TransactionLog class ObjectStore(CommonBackend): @@ -11,7 +10,7 @@ class ObjectStore(CommonBackend): # This initialization method should be called after the connection # with the database is established, so it can ensure that u1db data is # configured and up-to-date. - self.set_document_factory(LeapDocument) + self.set_document_factory(Document) self._sync_log = SyncLog() self._transaction_log = TransactionLog() self._ensure_u1db_data() -- cgit v1.2.3 From 4cd81148ec25cd6f1a9498345c7405a4d37a4012 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 18 Dec 2012 18:57:01 -0200 Subject: Correct typ0 --- src/leap/soledad/backends/leap.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py index 4a496d3e..c019ed3f 100644 --- a/src/leap/soledad/backends/leap.py +++ b/src/leap/soledad/backends/leap.py @@ -37,8 +37,8 @@ class LeapDocument(Document): """ if not self._soledad: raise NoSoledadInstance() - cyphertext = self._soledad.encrypt_symmetric(self.get_json()) - return json.dumps({'_encrypted_json' : cyphertext}) + ciphertext = self._soledad.encrypt_symmetric(self.get_json()) + return json.dumps({'_encrypted_json' : ciphertext}) def set_encrypted_json(self, encrypted_json): """ @@ -46,8 +46,8 @@ class LeapDocument(Document): """ if not self._soledad: raise NoSoledadInstance() - cyphertext = json.loads(encrypted_json)['_encrypted_json'] - plaintext = self._soledad.decrypt_symmetric(cyphertext) + ciphertext = json.loads(encrypted_json)['_encrypted_json'] + plaintext = self._soledad.decrypt_symmetric(ciphertext) return self.set_json(plaintext) -- cgit v1.2.3 From 7161784fc65698e2603cf53e797dbd13711689e0 Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 20 Dec 2012 11:35:19 -0200 Subject: Use doc_id with HMAC for symmetric encryption --- src/leap/soledad/backends/leap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py index c019ed3f..9fbd49fe 100644 --- a/src/leap/soledad/backends/leap.py +++ b/src/leap/soledad/backends/leap.py @@ -37,7 +37,7 @@ class LeapDocument(Document): """ if not self._soledad: raise NoSoledadInstance() - ciphertext = self._soledad.encrypt_symmetric(self.get_json()) + ciphertext = self._soledad.encrypt_symmetric(self.doc_id, self.get_json()) return json.dumps({'_encrypted_json' : ciphertext}) def set_encrypted_json(self, encrypted_json): @@ -47,7 
+47,7 @@ class LeapDocument(Document): if not self._soledad: raise NoSoledadInstance() ciphertext = json.loads(encrypted_json)['_encrypted_json'] - plaintext = self._soledad.decrypt_symmetric(ciphertext) + plaintext = self._soledad.decrypt_symmetric(self.doc_id, ciphertext) return self.set_json(plaintext) -- cgit v1.2.3 From 277f17aa7b7bbcc48583149a3d72d8621f83c0ff Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 24 Dec 2012 10:13:12 -0200 Subject: Document ObjectStore --- src/leap/soledad/backends/objectstore.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'src/leap/soledad/backends') diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index a8e139f7..61445a1f 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -5,6 +5,9 @@ from soledad.util import SyncLog, TransactionLog class ObjectStore(CommonBackend): + """ + A backend for storing u1db data in an object store. + """ def __init__(self): # This initialization method should be called after the connection @@ -153,9 +156,13 @@ class ObjectStore(CommonBackend): raise errors.InvalidGeneration return trans_id + #------------------------------------------------------------------------- + # methods specific for object stores + #------------------------------------------------------------------------- + def _ensure_u1db_data(self): """ - Guarantee that u1db data exists in store. + Guarantee that u1db data (logs and replica info) exists in store. """ if not self._is_initialized(): self._initialize() -- cgit v1.2.3
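
Note on the doc_id/HMAC commit above: passing the doc_id into
encrypt_symmetric()/decrypt_symmetric() lets Soledad derive a distinct
symmetric passphrase for every document from a single stored secret, tying
each ciphertext to the document it belongs to. The helpers themselves live
outside this directory; the sketch below is hypothetical (function name,
secret handling and hash choice are assumptions, not taken from this tree):

    import hmac
    import hashlib

    def doc_passphrase(secret, doc_id):
        # One passphrase per document: a keyed hash of the doc_id under
        # the user's stored secret.  The same doc_id always re-derives the
        # same passphrase, so decrypt_symmetric() needs no extra state.
        return hmac.new(secret, doc_id, hashlib.sha256).hexdigest()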