diff options
| -rw-r--r-- | backends/sqlcipher.py | 109 | 
1 files changed, 108 insertions, 1 deletions
| diff --git a/backends/sqlcipher.py b/backends/sqlcipher.py index 9a508dc2..a2ec9840 100644 --- a/backends/sqlcipher.py +++ b/backends/sqlcipher.py @@ -18,6 +18,7 @@  import os  from pysqlcipher import dbapi2 +from sqlite3 import dbapi2 as sqlite3_dbapi2  import time  from u1db.backends.sqlite_backend import ( @@ -92,7 +93,7 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase):                  # backend should raise a DatabaseError exception.                  SQLitePartialExpandDatabase(sqlite_file)                  raise DatabaseIsNotEncrypted() -            except dbapi2.DatabaseError: +            except sqlite3_dbapi2.DatabaseError:                  pass      @classmethod @@ -173,5 +174,111 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase):              doc.syncable = bool(c.fetchone()[0])          return doc +    # TODO: remove methods below after solving Exception handling problem. +    def _is_initialized(self, c): +        """Check if this database has been initialized.""" +        c.execute("PRAGMA case_sensitive_like=ON") +        try: +            c.execute("SELECT value FROM u1db_config" +                      " WHERE name = 'sql_schema'") +        except dbapi2.OperationalError: +            # The table does not exist yet +            val = None +        else: +            val = c.fetchone() +        if val is not None: +            return True +        return False + +    def get_from_index(self, index_name, *key_values): +        definition = self._get_index_definition(index_name) +        if len(key_values) != len(definition): +            raise errors.InvalidValueForIndex() +        statement, args = self._format_query(definition, key_values) +        c = self._db_handle.cursor() +        try: +            c.execute(statement, tuple(args)) +        except dbapi2.OperationalError, e: +            raise dbapi2.OperationalError( +                str(e) + +                '\nstatement: %s\nargs: %s\n' % (statement, args)) +        res = c.fetchall() +        results = [] +        for row in res: +            doc = self._factory(row[0], row[1], row[2]) +            doc.has_conflicts = row[3] > 0 +            results.append(doc) +        return results + +    def get_range_from_index(self, index_name, start_value=None, +                             end_value=None): +        """Return all documents with key values in the specified range.""" +        definition = self._get_index_definition(index_name) +        statement, args = self._format_range_query( +            definition, start_value, end_value) +        c = self._db_handle.cursor() +        try: +            c.execute(statement, tuple(args)) +        except dbapi2.OperationalError, e: +            raise dbapi2.OperationalError(str(e) + +                '\nstatement: %s\nargs: %s\n' % (statement, args)) +        res = c.fetchall() +        results = [] +        for row in res: +            doc = self._factory(row[0], row[1], row[2]) +            doc.has_conflicts = row[3] > 0 +            results.append(doc) +        return results + +    def get_index_keys(self, index_name): +        c = self._db_handle.cursor() +        definition = self._get_index_definition(index_name) +        value_fields = ', '.join([ +            'd%d.value' % i for i in range(len(definition))]) +        tables = ["document_fields d%d" % i for i in range(len(definition))] +        novalue_where = [ +            "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in +            range(len(definition))] +        where = [ +            novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in +            range(len(definition))] +        statement = ( +            "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( +                value_fields, ', '.join(tables), ' AND '.join(where), +                value_fields)) +        try: +            c.execute(statement, tuple(definition)) +        except dbapi2.OperationalError, e: +            raise dbapi2.OperationalError(str(e) + +                '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) +        return c.fetchall() + +    def delete_index(self, index_name): +        with self._db_handle: +            c = self._db_handle.cursor() +            c.execute("DELETE FROM index_definitions WHERE name = ?", +                      (index_name,)) +            c.execute( +                "DELETE FROM document_fields WHERE document_fields.field_name " +                " NOT IN (SELECT field from index_definitions)") + +    def create_index(self, index_name, *index_expressions): +        with self._db_handle: +            c = self._db_handle.cursor() +            cur_fields = self._get_indexed_fields() +            definition = [(index_name, idx, field) +                          for idx, field in enumerate(index_expressions)] +            try: +                c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", +                              definition) +            except dbapi2.IntegrityError as e: +                stored_def = self._get_index_definition(index_name) +                if stored_def == [x[-1] for x in definition]: +                    return +                raise errors.IndexNameTakenError, e, sys.exc_info()[2] +            new_fields = set( +                [f for f in index_expressions if f not in cur_fields]) +            if new_fields: +                self._update_all_indexes(new_fields)  SQLiteDatabase.register_implementation(SQLCipherDatabase) | 
