| | | |
|---|---|---|
| author | drebs <drebs@leap.se> | 2012-12-11 12:07:28 -0200 |
| committer | drebs <drebs@leap.se> | 2012-12-11 12:07:28 -0200 |
| commit | 4417d89bb9bdd59d717501c6db3f2215cdeb87fb (patch) | |
| tree | 8118bdbac78c1a723f18a5bc7f7292cd0d9b72db | |
| parent | 45908d847d09336d685dd38b698441a92570861e (diff) | |
SQLCipherDatabase now extends SQLitePartialExpandDatabase.
| -rw-r--r-- | src/leap/soledad/backends/sqlcipher.py | 831 |

1 file changed, 3 insertions(+), 828 deletions(-)
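In short, the patch makes SQLCipherDatabase inherit the whole document/index/conflict/sync machinery from u1db's SQLitePartialExpandDatabase instead of reimplementing it alongside CommonBackend, so only the SQLCipher-specific pieces stay in this module. A minimal sketch of the resulting shape (illustrative only, not the verbatim post-commit file; it assumes pysqlcipher exposes a dbapi2 module and that the passphrase is applied through SQLCipher's PRAGMA key, with `set_pragma_key` as a hypothetical helper name):

```python
# Illustrative sketch -- not the verbatim post-commit sqlcipher.py.
from pysqlcipher import dbapi2  # assumed SQLCipher-enabled DB-API module

from u1db import Document
from u1db.backends.sqlite_backend import SQLitePartialExpandDatabase


class SQLCipherDatabase(SQLitePartialExpandDatabase):
    """A U1DB backend persisting to a SQLCipher-encrypted file.

    Documents, indexes, conflicts and the transaction log are all
    inherited from the plain SQLite backend.
    """

    @classmethod
    def set_pragma_key(cls, db_handle, key):
        # SQLCipher derives its page-encryption key from this PRAGMA;
        # it must run before any other statement on the connection.
        db_handle.cursor().execute("PRAGMA key = '%s'" % key)

    def __init__(self, sqlite_file, document_factory=None, password=None):
        self._db_handle = dbapi2.connect(sqlite_file)
        if password is not None:
            self.set_pragma_key(self._db_handle, password)
        self._real_replica_uid = None
        self._ensure_schema()
        self._factory = document_factory or Document
```

Everything the diff deletes below is exactly what the superclass already provides.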
diff --git a/src/leap/soledad/backends/sqlcipher.py b/src/leap/soledad/backends/sqlcipher.py
index 24f47eed..fcdab251 100644
--- a/src/leap/soledad/backends/sqlcipher.py
+++ b/src/leap/soledad/backends/sqlcipher.py
@@ -30,6 +30,7 @@ import uuid
 import pkg_resources
 
 from u1db.backends import CommonBackend, CommonSyncTarget
+from u1db.backends.sqlite_backend import SQLitePartialExpandDatabase
 from u1db import (
     Document,
     errors,
@@ -56,7 +57,7 @@ def open(path, create, document_factory=None, password=None):
         path, create=create, document_factory=document_factory, password=password)
 
 
-class SQLCipherDatabase(CommonBackend):
+class SQLCipherDatabase(SQLitePartialExpandDatabase):
     """A U1DB implementation that uses SQLCipher as its persistence layer."""
 
     _sqlite_registry = {}
@@ -74,25 +75,6 @@ class SQLCipherDatabase(CommonBackend):
         self._ensure_schema()
         self._factory = document_factory or Document
 
-    def set_document_factory(self, factory):
-        self._factory = factory
-
-    def get_sync_target(self):
-        return SQLCipherSyncTarget(self)
-
-    @classmethod
-    def _which_index_storage(cls, c):
-        try:
-            c.execute("SELECT value FROM u1db_config"
-                      " WHERE name = 'index_storage'")
-        except dbapi2.OperationalError, e:
-            # The table does not exist yet
-            return None, e
-        else:
-            return c.fetchone()[0], None
-
-    WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5
-
     @classmethod
     def _open_database(cls, sqlite_file, document_factory=None, password=None):
         if not os.path.isfile(sqlite_file):
@@ -136,15 +118,6 @@
                                password=password)
 
     @staticmethod
-    def delete_database(sqlite_file):
-        try:
-            os.unlink(sqlite_file)
-        except OSError as ex:
-            if ex.errno == errno.ENOENT:
-                raise errors.DatabaseDoesNotExist()
-            raise
-
-    @staticmethod
     def register_implementation(klass):
         """Register that we implement an SQLCipherDatabase.
 
@@ -152,803 +125,5 @@
         """
         SQLCipherDatabase._sqlite_registry[klass._index_storage_value] = klass
 
-    def _get_sqlite_handle(self):
-        """Get access to the underlying sqlite database.
-
-        This should only be used by the test suite, etc, for examining the
-        state of the underlying database.
-        """
-        return self._db_handle
-
-    def _close_sqlite_handle(self):
-        """Release access to the underlying sqlite database."""
-        self._db_handle.close()
-
-    def close(self):
-        self._close_sqlite_handle()
-
-    def _is_initialized(self, c):
-        """Check if this database has been initialized."""
-        c.execute("PRAGMA case_sensitive_like=ON")
-        try:
-            c.execute("SELECT value FROM u1db_config"
-                      " WHERE name = 'sql_schema'")
-        except dbapi2.OperationalError:
-            # The table does not exist yet
-            val = None
-        else:
-            val = c.fetchone()
-        if val is not None:
-            return True
-        return False
-
-    def _initialize(self, c):
-        """Create the schema in the database."""
-        #read the script with sql commands
-        # TODO: Change how we set up the dependency. Most likely use something
-        #   like lp:dirspec to grab the file from a common resource
-        #   directory. Doesn't specifically need to be handled until we get
-        #   to the point of packaging this.
-        schema_content = pkg_resources.resource_string(
-            __name__, 'dbschema.sql')
-        # Note: We'd like to use c.executescript() here, but it seems that
-        #       executescript always commits, even if you set
-        #       isolation_level = None, so if we want to properly handle
-        #       exclusive locking and rollbacks between processes, we need
-        #       to execute it line-by-line
-        for line in schema_content.split(';'):
-            if not line:
-                continue
-            c.execute(line)
-        #add extra fields
-        self._extra_schema_init(c)
-        # A unique identifier should be set for this replica. Implementations
-        # don't have to strictly use uuid here, but we do want the uid to be
-        # unique amongst all databases that will sync with each other.
-        # We might extend this to using something with hostname for easier
-        # debugging.
-        self._set_replica_uid_in_transaction(uuid.uuid4().hex)
-        c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)",
-                  (self._index_storage_value,))
-
-    def _ensure_schema(self):
-        """Ensure that the database schema has been created."""
-        old_isolation_level = self._db_handle.isolation_level
-        c = self._db_handle.cursor()
-        if self._is_initialized(c):
-            return
-        try:
-            # autocommit/own mgmt of transactions
-            self._db_handle.isolation_level = None
-            with self._db_handle:
-                # only one execution path should initialize the db
-                c.execute("begin exclusive")
-                if self._is_initialized(c):
-                    return
-                self._initialize(c)
-        finally:
-            self._db_handle.isolation_level = old_isolation_level
-
-    def _extra_schema_init(self, c):
-        """Add any extra fields, etc to the basic table definitions."""
-
-    def _parse_index_definition(self, index_field):
-        """Parse a field definition for an index, returning a Getter."""
-        # Note: We may want to keep a Parser object around, and cache the
-        #       Getter objects for a greater length of time. Specifically, if
-        #       you create a bunch of indexes, and then insert 50k docs, you'll
-        #       re-parse the indexes between puts. The time to insert the docs
-        #       is still likely to dominate put_doc time, though.
-        parser = query_parser.Parser()
-        getter = parser.parse(index_field)
-        return getter
-
-    def _update_indexes(self, doc_id, raw_doc, getters, db_cursor):
-        """Update document_fields for a single document.
-
-        :param doc_id: Identifier for this document
-        :param raw_doc: The python dict representation of the document.
-        :param getters: A list of [(field_name, Getter)]. Getter.get will be
-            called to evaluate the index definition for this document, and the
-            results will be inserted into the db.
-        :param db_cursor: An sqlite Cursor.
-        :return: None
-        """
-        values = []
-        for field_name, getter in getters:
-            for idx_value in getter.get(raw_doc):
-                values.append((doc_id, field_name, idx_value))
-        if values:
-            db_cursor.executemany(
-                "INSERT INTO document_fields VALUES (?, ?, ?)", values)
-
-    def _set_replica_uid(self, replica_uid):
-        """Force the replica_uid to be set."""
-        with self._db_handle:
-            self._set_replica_uid_in_transaction(replica_uid)
-
-    def _set_replica_uid_in_transaction(self, replica_uid):
-        """Set the replica_uid. A transaction should already be held."""
-        c = self._db_handle.cursor()
-        c.execute("INSERT OR REPLACE INTO u1db_config"
-                  " VALUES ('replica_uid', ?)",
-                  (replica_uid,))
-        self._real_replica_uid = replica_uid
-
-    def _get_replica_uid(self):
-        if self._real_replica_uid is not None:
-            return self._real_replica_uid
-        c = self._db_handle.cursor()
-        c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'")
-        val = c.fetchone()
-        if val is None:
-            return None
-        self._real_replica_uid = val[0]
-        return self._real_replica_uid
-
-    _replica_uid = property(_get_replica_uid)
-
-    def _get_generation(self):
-        c = self._db_handle.cursor()
-        c.execute('SELECT max(generation) FROM transaction_log')
-        val = c.fetchone()[0]
-        if val is None:
-            return 0
-        return val
-
-    def _get_generation_info(self):
-        c = self._db_handle.cursor()
-        c.execute(
-            'SELECT max(generation), transaction_id FROM transaction_log ')
-        val = c.fetchone()
-        if val[0] is None:
-            return(0, '')
-        return val
-
-    def _get_trans_id_for_gen(self, generation):
-        if generation == 0:
-            return ''
-        c = self._db_handle.cursor()
-        c.execute(
-            'SELECT transaction_id FROM transaction_log WHERE generation = ?',
-            (generation,))
-        val = c.fetchone()
-        if val is None:
-            raise errors.InvalidGeneration
-        return val[0]
-
-    def _get_transaction_log(self):
-        c = self._db_handle.cursor()
-        c.execute("SELECT doc_id, transaction_id FROM transaction_log"
-                  " ORDER BY generation")
-        return c.fetchall()
-
-    def _get_doc(self, doc_id, check_for_conflicts=False):
-        """Get just the document content, without fancy handling."""
-        c = self._db_handle.cursor()
-        if check_for_conflicts:
-            c.execute(
-                "SELECT document.doc_rev, document.content, "
-                "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN "
-                "conflicts ON conflicts.doc_id = document.doc_id WHERE "
-                "document.doc_id = ? GROUP BY document.doc_id, "
-                "document.doc_rev, document.content;", (doc_id,))
-        else:
-            c.execute(
-                "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?",
-                (doc_id,))
-        val = c.fetchone()
-        if val is None:
-            return None
-        doc_rev, content, conflicts = val
-        doc = self._factory(doc_id, doc_rev, content)
-        doc.has_conflicts = conflicts > 0
-        return doc
-
-    def _has_conflicts(self, doc_id):
-        c = self._db_handle.cursor()
-        c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1",
-                  (doc_id,))
-        val = c.fetchone()
-        if val is None:
-            return False
-        else:
-            return True
-
-    def get_doc(self, doc_id, include_deleted=False):
-        doc = self._get_doc(doc_id, check_for_conflicts=True)
-        if doc is None:
-            return None
-        if doc.is_tombstone() and not include_deleted:
-            return None
-        return doc
-
-    def get_all_docs(self, include_deleted=False):
-        """Get all documents from the database."""
-        generation = self._get_generation()
-        results = []
-        c = self._db_handle.cursor()
-        c.execute(
-            "SELECT document.doc_id, document.doc_rev, document.content, "
-            "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts "
-            "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, "
-            "document.doc_rev, document.content;")
-        rows = c.fetchall()
-        for doc_id, doc_rev, content, conflicts in rows:
-            if content is None and not include_deleted:
-                continue
-            doc = self._factory(doc_id, doc_rev, content)
-            doc.has_conflicts = conflicts > 0
-            results.append(doc)
-        return (generation, results)
-
-    def put_doc(self, doc):
-        if doc.doc_id is None:
-            raise errors.InvalidDocId()
-        self._check_doc_id(doc.doc_id)
-        self._check_doc_size(doc)
-        with self._db_handle:
-            old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
-            if old_doc and old_doc.has_conflicts:
-                raise errors.ConflictedDoc()
-            if old_doc and doc.rev is None and old_doc.is_tombstone():
-                new_rev = self._allocate_doc_rev(old_doc.rev)
-            else:
-                if old_doc is not None:
-                        if old_doc.rev != doc.rev:
-                            raise errors.RevisionConflict()
-                else:
-                    if doc.rev is not None:
-                        raise errors.RevisionConflict()
-                new_rev = self._allocate_doc_rev(doc.rev)
-            doc.rev = new_rev
-            self._put_and_update_indexes(old_doc, doc)
-        return new_rev
-
-    def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none):
-        """Convert a dict representation into named fields.
-
-        So something like: {'key1': 'val1', 'key2': 'val2'}
-        gets converted into: [(doc_id, 'key1', 'val1', 0)
-                              (doc_id, 'key2', 'val2', 0)]
-        :param doc_id: Just added to every record.
-        :param base_field: if set, these are nested keys, so each field should
-            be appropriately prefixed.
-        :param raw_doc: The python dictionary.
-        """
-        # TODO: Handle lists
-        values = []
-        for field_name, value in raw_doc.iteritems():
-            if value is None and not save_none:
-                continue
-            if base_field:
-                full_name = base_field + '.' + field_name
-            else:
-                full_name = field_name
-            if value is None or isinstance(value, (int, float, basestring)):
-                values.append((doc_id, full_name, value, len(values)))
-            else:
-                subvalues = self._expand_to_fields(doc_id, full_name, value,
-                                                   save_none)
-                for _, subfield_name, val, _ in subvalues:
-                    values.append((doc_id, subfield_name, val, len(values)))
-        return values
-
-    def _put_and_update_indexes(self, old_doc, doc):
-        """Actually insert a document into the database.
-
-        This both updates the existing documents content, and any indexes that
-        refer to this document.
-        """
-        raise NotImplementedError(self._put_and_update_indexes)
-
-    def whats_changed(self, old_generation=0):
-        c = self._db_handle.cursor()
-        c.execute("SELECT generation, doc_id, transaction_id"
-                  " FROM transaction_log"
-                  " WHERE generation > ? ORDER BY generation DESC",
-                  (old_generation,))
-        results = c.fetchall()
-        cur_gen = old_generation
-        seen = set()
-        changes = []
-        newest_trans_id = ''
-        for generation, doc_id, trans_id in results:
-            if doc_id not in seen:
-                changes.append((doc_id, generation, trans_id))
-                seen.add(doc_id)
-        if changes:
-            cur_gen = changes[0][1]  # max generation
-            newest_trans_id = changes[0][2]
-            changes.reverse()
-        else:
-            c.execute("SELECT generation, transaction_id"
-                      " FROM transaction_log ORDER BY generation DESC LIMIT 1")
-            results = c.fetchone()
-            if not results:
-                cur_gen = 0
-                newest_trans_id = ''
-            else:
-                cur_gen, newest_trans_id = results
-
-        return cur_gen, newest_trans_id, changes
-
-    def delete_doc(self, doc):
-        with self._db_handle:
-            old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
-            if old_doc is None:
-                raise errors.DocumentDoesNotExist
-            if old_doc.rev != doc.rev:
-                raise errors.RevisionConflict()
-            if old_doc.is_tombstone():
-                raise errors.DocumentAlreadyDeleted
-            if old_doc.has_conflicts:
-                raise errors.ConflictedDoc()
-            new_rev = self._allocate_doc_rev(doc.rev)
-            doc.rev = new_rev
-            doc.make_tombstone()
-            self._put_and_update_indexes(old_doc, doc)
-        return new_rev
-
-    def _get_conflicts(self, doc_id):
-        c = self._db_handle.cursor()
-        c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?",
-                  (doc_id,))
-        return [self._factory(doc_id, doc_rev, content)
-                for doc_rev, content in c.fetchall()]
-
-    def get_doc_conflicts(self, doc_id):
-        with self._db_handle:
-            conflict_docs = self._get_conflicts(doc_id)
-            if not conflict_docs:
-                return []
-            this_doc = self._get_doc(doc_id)
-            this_doc.has_conflicts = True
-            return [this_doc] + conflict_docs
-
-    def _get_replica_gen_and_trans_id(self, other_replica_uid):
-        c = self._db_handle.cursor()
-        c.execute("SELECT known_generation, known_transaction_id FROM sync_log"
-                  " WHERE replica_uid = ?",
-                  (other_replica_uid,))
-        val = c.fetchone()
-        if val is None:
-            other_gen = 0
-            trans_id = ''
-        else:
-            other_gen = val[0]
-            trans_id = val[1]
-        return other_gen, trans_id
-
-    def _set_replica_gen_and_trans_id(self, other_replica_uid,
-                                      other_generation, other_transaction_id):
-        with self._db_handle:
-            self._do_set_replica_gen_and_trans_id(
-                other_replica_uid, other_generation, other_transaction_id)
-
-    def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
-                                         other_generation,
-                                         other_transaction_id):
-            c = self._db_handle.cursor()
-            c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)",
-                      (other_replica_uid, other_generation,
-                       other_transaction_id))
-
-    def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None,
-                          replica_gen=None, replica_trans_id=None):
-        with self._db_handle:
-            return super(SQLCipherDatabase, self)._put_doc_if_newer(doc,
-                save_conflict=save_conflict,
-                replica_uid=replica_uid, replica_gen=replica_gen,
-                replica_trans_id=replica_trans_id)
-
-    def _add_conflict(self, c, doc_id, my_doc_rev, my_content):
-        c.execute("INSERT INTO conflicts VALUES (?, ?, ?)",
-                  (doc_id, my_doc_rev, my_content))
-
-    def _delete_conflicts(self, c, doc, conflict_revs):
-        deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs]
-        c.executemany("DELETE FROM conflicts"
-                      " WHERE doc_id=? AND doc_rev=?", deleting)
-        doc.has_conflicts = self._has_conflicts(doc.doc_id)
-
-    def _prune_conflicts(self, doc, doc_vcr):
-        if self._has_conflicts(doc.doc_id):
-            autoresolved = False
-            c_revs_to_prune = []
-            for c_doc in self._get_conflicts(doc.doc_id):
-                c_vcr = vectorclock.VectorClockRev(c_doc.rev)
-                if doc_vcr.is_newer(c_vcr):
-                    c_revs_to_prune.append(c_doc.rev)
-                elif doc.same_content_as(c_doc):
-                    c_revs_to_prune.append(c_doc.rev)
-                    doc_vcr.maximize(c_vcr)
-                    autoresolved = True
-            if autoresolved:
-                doc_vcr.increment(self._replica_uid)
-                doc.rev = doc_vcr.as_str()
-            c = self._db_handle.cursor()
-            self._delete_conflicts(c, doc, c_revs_to_prune)
-
-    def _force_doc_sync_conflict(self, doc):
-        my_doc = self._get_doc(doc.doc_id)
-        c = self._db_handle.cursor()
-        self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
-        self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json())
-        doc.has_conflicts = True
-        self._put_and_update_indexes(my_doc, doc)
-
-    def resolve_doc(self, doc, conflicted_doc_revs):
-        with self._db_handle:
-            cur_doc = self._get_doc(doc.doc_id)
-            # TODO: https://bugs.launchpad.net/u1db/+bug/928274
-            #       I think we have a logic bug in resolve_doc
-            #       Specifically, cur_doc.rev is always in the final vector
-            #       clock of revisions that we supersede, even if it wasn't in
-            #       conflicted_doc_revs. We still add it as a conflict, but the
-            #       fact that _put_doc_if_newer propagates resolutions means I
-            #       think that conflict could accidentally be resolved. We need
-            #       to add a test for this case first. (create a rev, create a
-            #       conflict, create another conflict, resolve the first rev
-            #       and first conflict, then make sure that the resolved
-            #       rev doesn't supersede the second conflict rev.) It *might*
-            #       not matter, because the superseding rev is in as a
-            #       conflict, but it does seem incorrect
-            new_rev = self._ensure_maximal_rev(cur_doc.rev,
-                                               conflicted_doc_revs)
-            superseded_revs = set(conflicted_doc_revs)
-            c = self._db_handle.cursor()
-            doc.rev = new_rev
-            if cur_doc.rev in superseded_revs:
-                self._put_and_update_indexes(cur_doc, doc)
-            else:
-                self._add_conflict(c, doc.doc_id, new_rev, doc.get_json())
-            # TODO: Is there some way that we could construct a rev that would
-            #       end up in superseded_revs, such that we add a conflict, and
-            #       then immediately delete it?
-            self._delete_conflicts(c, doc, superseded_revs)
-
-    def list_indexes(self):
-        """Return the list of indexes and their definitions."""
-        c = self._db_handle.cursor()
-        # TODO: How do we test the ordering?
-        c.execute("SELECT name, field FROM index_definitions"
-                  " ORDER BY name, offset")
-        definitions = []
-        cur_name = None
-        for name, field in c.fetchall():
-            if cur_name != name:
-                definitions.append((name, []))
-                cur_name = name
-            definitions[-1][-1].append(field)
-        return definitions
-
-    def _get_index_definition(self, index_name):
-        """Return the stored definition for a given index_name."""
-        c = self._db_handle.cursor()
-        c.execute("SELECT field FROM index_definitions"
-                  " WHERE name = ? ORDER BY offset", (index_name,))
-        fields = [x[0] for x in c.fetchall()]
-        if not fields:
-            raise errors.IndexDoesNotExist
-        return fields
-
-    @staticmethod
-    def _strip_glob(value):
-        """Remove the trailing * from a value."""
-        assert value[-1] == '*'
-        return value[:-1]
-
-    def _format_query(self, definition, key_values):
-        # First, build the definition. We join the document_fields table
-        # against itself, as many times as the 'width' of our definition.
-        # We then do a query for each key_value, one-at-a-time.
-        # Note: All of these strings are static, we could cache them, etc.
-        tables = ["document_fields d%d" % i for i in range(len(definition))]
-        novalue_where = ["d.doc_id = d%d.doc_id"
-                         " AND d%d.field_name = ?"
-                         % (i, i) for i in range(len(definition))]
-        wildcard_where = [novalue_where[i]
-                          + (" AND d%d.value NOT NULL" % (i,))
-                          for i in range(len(definition))]
-        exact_where = [novalue_where[i]
-                       + (" AND d%d.value = ?" % (i,))
-                       for i in range(len(definition))]
-        like_where = [novalue_where[i]
-                      + (" AND d%d.value GLOB ?" % (i,))
-                      for i in range(len(definition))]
-        is_wildcard = False
-        # Merge the lists together, so that:
-        # [field1, field2, field3], [val1, val2, val3]
-        # Becomes:
-        # (field1, val1, field2, val2, field3, val3)
-        args = []
-        where = []
-        for idx, (field, value) in enumerate(zip(definition, key_values)):
-            args.append(field)
-            if value.endswith('*'):
-                if value == '*':
-                    where.append(wildcard_where[idx])
-                else:
-                    # This is a glob match
-                    if is_wildcard:
-                        # We can't have a partial wildcard following
-                        # another wildcard
-                        raise errors.InvalidGlobbing
-                    where.append(like_where[idx])
-                    args.append(value)
-                is_wildcard = True
-            else:
-                if is_wildcard:
-                    raise errors.InvalidGlobbing
-                where.append(exact_where[idx])
-                args.append(value)
-        statement = (
-            "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
-            "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
-            "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
-            "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
-                ['d%d.value' % i for i in range(len(definition))])))
-        return statement, args
-
-    def get_from_index(self, index_name, *key_values):
-        definition = self._get_index_definition(index_name)
-        if len(key_values) != len(definition):
-            raise errors.InvalidValueForIndex()
-        statement, args = self._format_query(definition, key_values)
-        c = self._db_handle.cursor()
-        try:
-            c.execute(statement, tuple(args))
-        except dbapi2.OperationalError, e:
-            raise dbapi2.OperationalError(str(e) +
-                '\nstatement: %s\nargs: %s\n' % (statement, args))
-        res = c.fetchall()
-        results = []
-        for row in res:
-            doc = self._factory(row[0], row[1], row[2])
-            doc.has_conflicts = row[3] > 0
-            results.append(doc)
-        return results
-
-    def _format_range_query(self, definition, start_value, end_value):
-        tables = ["document_fields d%d" % i for i in range(len(definition))]
-        novalue_where = [
-            "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
-            range(len(definition))]
-        wildcard_where = [
-            novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
-            range(len(definition))]
-        like_where = [
-            novalue_where[i] + (
-                " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in
-            range(len(definition))]
-        range_where_lower = [
-            novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in
-            range(len(definition))]
-        range_where_upper = [
-            novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in
-            range(len(definition))]
-        args = []
-        where = []
-        if start_value:
-            if isinstance(start_value, basestring):
-                start_value = (start_value,)
-            if len(start_value) != len(definition):
-                raise errors.InvalidValueForIndex()
-            is_wildcard = False
-            for idx, (field, value) in enumerate(zip(definition, start_value)):
-                args.append(field)
-                if value.endswith('*'):
-                    if value == '*':
-                        where.append(wildcard_where[idx])
-                    else:
-                        # This is a glob match
-                        if is_wildcard:
-                            # We can't have a partial wildcard following
-                            # another wildcard
-                            raise errors.InvalidGlobbing
-                        where.append(range_where_lower[idx])
-                        args.append(self._strip_glob(value))
-                    is_wildcard = True
-                else:
-                    if is_wildcard:
-                        raise errors.InvalidGlobbing
-                    where.append(range_where_lower[idx])
-                    args.append(value)
-        if end_value:
-            if isinstance(end_value, basestring):
-                end_value = (end_value,)
-            if len(end_value) != len(definition):
-                raise errors.InvalidValueForIndex()
-            is_wildcard = False
-            for idx, (field, value) in enumerate(zip(definition, end_value)):
-                args.append(field)
-                if value.endswith('*'):
-                    if value == '*':
-                        where.append(wildcard_where[idx])
-                    else:
-                        # This is a glob match
-                        if is_wildcard:
-                            # We can't have a partial wildcard following
-                            # another wildcard
-                            raise errors.InvalidGlobbing
-                        where.append(like_where[idx])
-                        args.append(self._strip_glob(value))
-                        args.append(value)
-                    is_wildcard = True
-                else:
-                    if is_wildcard:
-                        raise errors.InvalidGlobbing
-                    where.append(range_where_upper[idx])
-                    args.append(value)
-        statement = (
-            "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
-            "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
-            "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
-            "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
-                ['d%d.value' % i for i in range(len(definition))])))
-        return statement, args
-
-    def get_range_from_index(self, index_name, start_value=None,
-                             end_value=None):
-        """Return all documents with key values in the specified range."""
-        definition = self._get_index_definition(index_name)
-        statement, args = self._format_range_query(
-            definition, start_value, end_value)
-        c = self._db_handle.cursor()
-        try:
-            c.execute(statement, tuple(args))
-        except dbapi2.OperationalError, e:
-            raise dbapi2.OperationalError(str(e) +
-                '\nstatement: %s\nargs: %s\n' % (statement, args))
-        res = c.fetchall()
-        results = []
-        for row in res:
-            doc = self._factory(row[0], row[1], row[2])
-            doc.has_conflicts = row[3] > 0
-            results.append(doc)
-        return results
-
-    def get_index_keys(self, index_name):
-        c = self._db_handle.cursor()
-        definition = self._get_index_definition(index_name)
-        value_fields = ', '.join([
-            'd%d.value' % i for i in range(len(definition))])
-        tables = ["document_fields d%d" % i for i in range(len(definition))]
-        novalue_where = [
-            "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
-            range(len(definition))]
-        where = [
-            novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
-            range(len(definition))]
-        statement = (
-            "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % (
-                value_fields, ', '.join(tables), ' AND '.join(where),
-                value_fields))
-        try:
-            c.execute(statement, tuple(definition))
-        except dbapi2.OperationalError, e:
-            raise dbapi2.OperationalError(str(e) +
-                '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition)))
-        return c.fetchall()
-
-    def delete_index(self, index_name):
-        with self._db_handle:
-            c = self._db_handle.cursor()
-            c.execute("DELETE FROM index_definitions WHERE name = ?",
-                      (index_name,))
-            c.execute(
-                "DELETE FROM document_fields WHERE document_fields.field_name "
-                " NOT IN (SELECT field from index_definitions)")
-
-
-class SQLCipherSyncTarget(CommonSyncTarget):
-
-    def get_sync_info(self, source_replica_uid):
-        source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
-            source_replica_uid)
-        my_gen, my_trans_id = self._db._get_generation_info()
-        return (
-            self._db._replica_uid, my_gen, my_trans_id, source_gen,
-            source_trans_id)
-
-    def record_sync_info(self, source_replica_uid, source_replica_generation,
-                         source_replica_transaction_id):
-        if self._trace_hook:
-            self._trace_hook('record_sync_info')
-        self._db._set_replica_gen_and_trans_id(
-            source_replica_uid, source_replica_generation,
-            source_replica_transaction_id)
-
-
-class SQLCipherPartialExpandDatabase(SQLCipherDatabase):
-    """An SQLCipher Backend that expands documents into a document_field table.
-
-    It stores the original document text in document.doc. For fields that are
-    indexed, the data goes into document_fields.
-    """
-
-    _index_storage_value = 'expand referenced'
-
-    def _get_indexed_fields(self):
-        """Determine what fields are indexed."""
-        c = self._db_handle.cursor()
-        c.execute("SELECT field FROM index_definitions")
-        return set([x[0] for x in c.fetchall()])
-
-    def _evaluate_index(self, raw_doc, field):
-        parser = query_parser.Parser()
-        getter = parser.parse(field)
-        return getter.get(raw_doc)
-
-    def _put_and_update_indexes(self, old_doc, doc):
-        c = self._db_handle.cursor()
-        if doc and not doc.is_tombstone():
-            raw_doc = json.loads(doc.get_json())
-        else:
-            raw_doc = {}
-        if old_doc is not None:
-            c.execute("UPDATE document SET doc_rev=?, content=?"
-                      " WHERE doc_id = ?",
-                      (doc.rev, doc.get_json(), doc.doc_id))
-            c.execute("DELETE FROM document_fields WHERE doc_id = ?",
-                      (doc.doc_id,))
-        else:
-            c.execute("INSERT INTO document (doc_id, doc_rev, content)"
-                      " VALUES (?, ?, ?)",
-                      (doc.doc_id, doc.rev, doc.get_json()))
-        indexed_fields = self._get_indexed_fields()
-        if indexed_fields:
-            # It is expected that len(indexed_fields) is shorter than
-            # len(raw_doc)
-            getters = [(field, self._parse_index_definition(field))
-                       for field in indexed_fields]
-            self._update_indexes(doc.doc_id, raw_doc, getters, c)
-        trans_id = self._allocate_transaction_id()
-        c.execute("INSERT INTO transaction_log(doc_id, transaction_id)"
-                  " VALUES (?, ?)", (doc.doc_id, trans_id))
-
-    def create_index(self, index_name, *index_expressions):
-        with self._db_handle:
-            c = self._db_handle.cursor()
-            cur_fields = self._get_indexed_fields()
-            definition = [(index_name, idx, field)
-                          for idx, field in enumerate(index_expressions)]
-            try:
-                c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
-                              definition)
-            except dbapi2.IntegrityError as e:
-                stored_def = self._get_index_definition(index_name)
-                if stored_def == [x[-1] for x in definition]:
-                    return
-                raise errors.IndexNameTakenError, e, sys.exc_info()[2]
-            new_fields = set(
-                [f for f in index_expressions if f not in cur_fields])
-            if new_fields:
-                self._update_all_indexes(new_fields)
-
-    def _iter_all_docs(self):
-        c = self._db_handle.cursor()
-        c.execute("SELECT doc_id, content FROM document")
-        while True:
-            next_rows = c.fetchmany()
-            if not next_rows:
-                break
-            for row in next_rows:
-                yield row
-
-    def _update_all_indexes(self, new_fields):
-        """Iterate all the documents, and add content to document_fields.
-
-        :param new_fields: The index definitions that need to be added.
-        """
-        getters = [(field, self._parse_index_definition(field))
-                   for field in new_fields]
-        c = self._db_handle.cursor()
-        for doc_id, doc in self._iter_all_docs():
-            if doc is None:
-                continue
-            raw_doc = json.loads(doc)
-            self._update_indexes(doc_id, raw_doc, getters, c)
 
 
-SQLCipherDatabase.register_implementation(SQLCipherPartialExpandDatabase)
+SQLCipherDatabase.register_implementation(SQLCipherDatabase)
