summaryrefslogtreecommitdiff
path: root/src/leap/soledad/u1db/backends/sqlite_backend.py
diff options
context:
space:
mode:
authordrebs <drebs@leap.se>2012-11-29 10:56:49 -0200
committerdrebs <drebs@leap.se>2012-11-29 10:56:49 -0200
commit17ccbcb831044c29f521b529f5aa96dc2a3cd18f (patch)
treef8d1144f83f3b1d33fe246887b316029b75195ec /src/leap/soledad/u1db/backends/sqlite_backend.py
parent0f1f9474e7ea6b52dc3ae18444cfaaca56ff3070 (diff)
add u1db code (not as submodule)
Diffstat (limited to 'src/leap/soledad/u1db/backends/sqlite_backend.py')
-rw-r--r--src/leap/soledad/u1db/backends/sqlite_backend.py926
1 files changed, 926 insertions, 0 deletions
diff --git a/src/leap/soledad/u1db/backends/sqlite_backend.py b/src/leap/soledad/u1db/backends/sqlite_backend.py
new file mode 100644
index 00000000..773213b5
--- /dev/null
+++ b/src/leap/soledad/u1db/backends/sqlite_backend.py
@@ -0,0 +1,926 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""A U1DB implementation that uses SQLite as its persistence layer."""
+
+import errno
+import os
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from sqlite3 import dbapi2
+import sys
+import time
+import uuid
+
+import pkg_resources
+
+from u1db.backends import CommonBackend, CommonSyncTarget
+from u1db import (
+ Document,
+ errors,
+ query_parser,
+ vectorclock,
+ )
+
+
+class SQLiteDatabase(CommonBackend):
+ """A U1DB implementation that uses SQLite as its persistence layer."""
+
+ _sqlite_registry = {}
+
+ def __init__(self, sqlite_file, document_factory=None):
+ """Create a new sqlite file."""
+ self._db_handle = dbapi2.connect(sqlite_file)
+ self._real_replica_uid = None
+ self._ensure_schema()
+ self._factory = document_factory or Document
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def get_sync_target(self):
+ return SQLiteSyncTarget(self)
+
+ @classmethod
+ def _which_index_storage(cls, c):
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'index_storage'")
+ except dbapi2.OperationalError, e:
+ # The table does not exist yet
+ return None, e
+ else:
+ return c.fetchone()[0], None
+
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5
+
+ @classmethod
+ def _open_database(cls, sqlite_file, document_factory=None):
+ if not os.path.isfile(sqlite_file):
+ raise errors.DatabaseDoesNotExist()
+ tries = 2
+ while True:
+ # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6)
+ # where without re-opening the database on Windows, it
+ # doesn't see the transaction that was just committed
+ db_handle = dbapi2.connect(sqlite_file)
+ c = db_handle.cursor()
+ v, err = cls._which_index_storage(c)
+ db_handle.close()
+ if v is not None:
+ break
+ # possibly another process is initializing it, wait for it to be
+ # done
+ if tries == 0:
+ raise err # go for the richest error?
+ tries -= 1
+ time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL)
+ return SQLiteDatabase._sqlite_registry[v](
+ sqlite_file, document_factory=document_factory)
+
+ @classmethod
+ def open_database(cls, sqlite_file, create, backend_cls=None,
+ document_factory=None):
+ try:
+ return cls._open_database(
+ sqlite_file, document_factory=document_factory)
+ except errors.DatabaseDoesNotExist:
+ if not create:
+ raise
+ if backend_cls is None:
+ # default is SQLitePartialExpandDatabase
+ backend_cls = SQLitePartialExpandDatabase
+ return backend_cls(sqlite_file, document_factory=document_factory)
+
+ @staticmethod
+ def delete_database(sqlite_file):
+ try:
+ os.unlink(sqlite_file)
+ except OSError as ex:
+ if ex.errno == errno.ENOENT:
+ raise errors.DatabaseDoesNotExist()
+ raise
+
+ @staticmethod
+ def register_implementation(klass):
+ """Register that we implement an SQLiteDatabase.
+
+ The attribute _index_storage_value will be used as the lookup key.
+ """
+ SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass
+
+ def _get_sqlite_handle(self):
+ """Get access to the underlying sqlite database.
+
+ This should only be used by the test suite, etc, for examining the
+ state of the underlying database.
+ """
+ return self._db_handle
+
+ def _close_sqlite_handle(self):
+ """Release access to the underlying sqlite database."""
+ self._db_handle.close()
+
+ def close(self):
+ self._close_sqlite_handle()
+
+ def _is_initialized(self, c):
+ """Check if this database has been initialized."""
+ c.execute("PRAGMA case_sensitive_like=ON")
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'sql_schema'")
+ except dbapi2.OperationalError:
+ # The table does not exist yet
+ val = None
+ else:
+ val = c.fetchone()
+ if val is not None:
+ return True
+ return False
+
+ def _initialize(self, c):
+ """Create the schema in the database."""
+ #read the script with sql commands
+ # TODO: Change how we set up the dependency. Most likely use something
+ # like lp:dirspec to grab the file from a common resource
+ # directory. Doesn't specifically need to be handled until we get
+ # to the point of packaging this.
+ schema_content = pkg_resources.resource_string(
+ __name__, 'dbschema.sql')
+ # Note: We'd like to use c.executescript() here, but it seems that
+ # executescript always commits, even if you set
+ # isolation_level = None, so if we want to properly handle
+ # exclusive locking and rollbacks between processes, we need
+ # to execute it line-by-line
+ for line in schema_content.split(';'):
+ if not line:
+ continue
+ c.execute(line)
+ #add extra fields
+ self._extra_schema_init(c)
+ # A unique identifier should be set for this replica. Implementations
+ # don't have to strictly use uuid here, but we do want the uid to be
+ # unique amongst all databases that will sync with each other.
+ # We might extend this to using something with hostname for easier
+ # debugging.
+ self._set_replica_uid_in_transaction(uuid.uuid4().hex)
+ c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)",
+ (self._index_storage_value,))
+
+ def _ensure_schema(self):
+ """Ensure that the database schema has been created."""
+ old_isolation_level = self._db_handle.isolation_level
+ c = self._db_handle.cursor()
+ if self._is_initialized(c):
+ return
+ try:
+ # autocommit/own mgmt of transactions
+ self._db_handle.isolation_level = None
+ with self._db_handle:
+ # only one execution path should initialize the db
+ c.execute("begin exclusive")
+ if self._is_initialized(c):
+ return
+ self._initialize(c)
+ finally:
+ self._db_handle.isolation_level = old_isolation_level
+
+ def _extra_schema_init(self, c):
+ """Add any extra fields, etc to the basic table definitions."""
+
+ def _parse_index_definition(self, index_field):
+ """Parse a field definition for an index, returning a Getter."""
+ # Note: We may want to keep a Parser object around, and cache the
+ # Getter objects for a greater length of time. Specifically, if
+ # you create a bunch of indexes, and then insert 50k docs, you'll
+ # re-parse the indexes between puts. The time to insert the docs
+ # is still likely to dominate put_doc time, though.
+ parser = query_parser.Parser()
+ getter = parser.parse(index_field)
+ return getter
+
+ def _update_indexes(self, doc_id, raw_doc, getters, db_cursor):
+ """Update document_fields for a single document.
+
+ :param doc_id: Identifier for this document
+ :param raw_doc: The python dict representation of the document.
+ :param getters: A list of [(field_name, Getter)]. Getter.get will be
+ called to evaluate the index definition for this document, and the
+ results will be inserted into the db.
+ :param db_cursor: An sqlite Cursor.
+ :return: None
+ """
+ values = []
+ for field_name, getter in getters:
+ for idx_value in getter.get(raw_doc):
+ values.append((doc_id, field_name, idx_value))
+ if values:
+ db_cursor.executemany(
+ "INSERT INTO document_fields VALUES (?, ?, ?)", values)
+
+ def _set_replica_uid(self, replica_uid):
+ """Force the replica_uid to be set."""
+ with self._db_handle:
+ self._set_replica_uid_in_transaction(replica_uid)
+
+ def _set_replica_uid_in_transaction(self, replica_uid):
+ """Set the replica_uid. A transaction should already be held."""
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO u1db_config"
+ " VALUES ('replica_uid', ?)",
+ (replica_uid,))
+ self._real_replica_uid = replica_uid
+
+ def _get_replica_uid(self):
+ if self._real_replica_uid is not None:
+ return self._real_replica_uid
+ c = self._db_handle.cursor()
+ c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'")
+ val = c.fetchone()
+ if val is None:
+ return None
+ self._real_replica_uid = val[0]
+ return self._real_replica_uid
+
+ _replica_uid = property(_get_replica_uid)
+
+ def _get_generation(self):
+ c = self._db_handle.cursor()
+ c.execute('SELECT max(generation) FROM transaction_log')
+ val = c.fetchone()[0]
+ if val is None:
+ return 0
+ return val
+
+ def _get_generation_info(self):
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT max(generation), transaction_id FROM transaction_log ')
+ val = c.fetchone()
+ if val[0] is None:
+ return(0, '')
+ return val
+
+ def _get_trans_id_for_gen(self, generation):
+ if generation == 0:
+ return ''
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT transaction_id FROM transaction_log WHERE generation = ?',
+ (generation,))
+ val = c.fetchone()
+ if val is None:
+ raise errors.InvalidGeneration
+ return val[0]
+
+ def _get_transaction_log(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, transaction_id FROM transaction_log"
+ " ORDER BY generation")
+ return c.fetchall()
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """Get just the document content, without fancy handling."""
+ c = self._db_handle.cursor()
+ if check_for_conflicts:
+ c.execute(
+ "SELECT document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN "
+ "conflicts ON conflicts.doc_id = document.doc_id WHERE "
+ "document.doc_id = ? GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;", (doc_id,))
+ else:
+ c.execute(
+ "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return None
+ doc_rev, content, conflicts = val
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ return doc
+
+ def _has_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return False
+ else:
+ return True
+
+ def get_doc(self, doc_id, include_deleted=False):
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc is None:
+ return None
+ if doc.is_tombstone() and not include_deleted:
+ return None
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Get all documents from the database."""
+ generation = self._get_generation()
+ results = []
+ c = self._db_handle.cursor()
+ c.execute(
+ "SELECT document.doc_id, document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts "
+ "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;")
+ rows = c.fetchall()
+ for doc_id, doc_rev, content, conflicts in rows:
+ if content is None and not include_deleted:
+ continue
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ results.append(doc)
+ return (generation, results)
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ self._check_doc_id(doc.doc_id)
+ self._check_doc_size(doc)
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc and old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ if old_doc and doc.rev is None and old_doc.is_tombstone():
+ new_rev = self._allocate_doc_rev(old_doc.rev)
+ else:
+ if old_doc is not None:
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ else:
+ if doc.rev is not None:
+ raise errors.RevisionConflict()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none):
+ """Convert a dict representation into named fields.
+
+ So something like: {'key1': 'val1', 'key2': 'val2'}
+ gets converted into: [(doc_id, 'key1', 'val1', 0)
+ (doc_id, 'key2', 'val2', 0)]
+ :param doc_id: Just added to every record.
+ :param base_field: if set, these are nested keys, so each field should
+ be appropriately prefixed.
+ :param raw_doc: The python dictionary.
+ """
+ # TODO: Handle lists
+ values = []
+ for field_name, value in raw_doc.iteritems():
+ if value is None and not save_none:
+ continue
+ if base_field:
+ full_name = base_field + '.' + field_name
+ else:
+ full_name = field_name
+ if value is None or isinstance(value, (int, float, basestring)):
+ values.append((doc_id, full_name, value, len(values)))
+ else:
+ subvalues = self._expand_to_fields(doc_id, full_name, value,
+ save_none)
+ for _, subfield_name, val, _ in subvalues:
+ values.append((doc_id, subfield_name, val, len(values)))
+ return values
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ """Actually insert a document into the database.
+
+ This both updates the existing documents content, and any indexes that
+ refer to this document.
+ """
+ raise NotImplementedError(self._put_and_update_indexes)
+
+ def whats_changed(self, old_generation=0):
+ c = self._db_handle.cursor()
+ c.execute("SELECT generation, doc_id, transaction_id"
+ " FROM transaction_log"
+ " WHERE generation > ? ORDER BY generation DESC",
+ (old_generation,))
+ results = c.fetchall()
+ cur_gen = old_generation
+ seen = set()
+ changes = []
+ newest_trans_id = ''
+ for generation, doc_id, trans_id in results:
+ if doc_id not in seen:
+ changes.append((doc_id, generation, trans_id))
+ seen.add(doc_id)
+ if changes:
+ cur_gen = changes[0][1] # max generation
+ newest_trans_id = changes[0][2]
+ changes.reverse()
+ else:
+ c.execute("SELECT generation, transaction_id"
+ " FROM transaction_log ORDER BY generation DESC LIMIT 1")
+ results = c.fetchone()
+ if not results:
+ cur_gen = 0
+ newest_trans_id = ''
+ else:
+ cur_gen, newest_trans_id = results
+
+ return cur_gen, newest_trans_id, changes
+
+ def delete_doc(self, doc):
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc is None:
+ raise errors.DocumentDoesNotExist
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ if old_doc.is_tombstone():
+ raise errors.DocumentAlreadyDeleted
+ if old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ doc.make_tombstone()
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _get_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?",
+ (doc_id,))
+ return [self._factory(doc_id, doc_rev, content)
+ for doc_rev, content in c.fetchall()]
+
+ def get_doc_conflicts(self, doc_id):
+ with self._db_handle:
+ conflict_docs = self._get_conflicts(doc_id)
+ if not conflict_docs:
+ return []
+ this_doc = self._get_doc(doc_id)
+ this_doc.has_conflicts = True
+ return [this_doc] + conflict_docs
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ c = self._db_handle.cursor()
+ c.execute("SELECT known_generation, known_transaction_id FROM sync_log"
+ " WHERE replica_uid = ?",
+ (other_replica_uid,))
+ val = c.fetchone()
+ if val is None:
+ other_gen = 0
+ trans_id = ''
+ else:
+ other_gen = val[0]
+ trans_id = val[1]
+ return other_gen, trans_id
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ with self._db_handle:
+ self._do_set_replica_gen_and_trans_id(
+ other_replica_uid, other_generation, other_transaction_id)
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)",
+ (other_replica_uid, other_generation,
+ other_transaction_id))
+
+ def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None,
+ replica_gen=None, replica_trans_id=None):
+ with self._db_handle:
+ return super(SQLiteDatabase, self)._put_doc_if_newer(doc,
+ save_conflict=save_conflict,
+ replica_uid=replica_uid, replica_gen=replica_gen,
+ replica_trans_id=replica_trans_id)
+
+ def _add_conflict(self, c, doc_id, my_doc_rev, my_content):
+ c.execute("INSERT INTO conflicts VALUES (?, ?, ?)",
+ (doc_id, my_doc_rev, my_content))
+
+ def _delete_conflicts(self, c, doc, conflict_revs):
+ deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs]
+ c.executemany("DELETE FROM conflicts"
+ " WHERE doc_id=? AND doc_rev=?", deleting)
+ doc.has_conflicts = self._has_conflicts(doc.doc_id)
+
+ def _prune_conflicts(self, doc, doc_vcr):
+ if self._has_conflicts(doc.doc_id):
+ autoresolved = False
+ c_revs_to_prune = []
+ for c_doc in self._get_conflicts(doc.doc_id):
+ c_vcr = vectorclock.VectorClockRev(c_doc.rev)
+ if doc_vcr.is_newer(c_vcr):
+ c_revs_to_prune.append(c_doc.rev)
+ elif doc.same_content_as(c_doc):
+ c_revs_to_prune.append(c_doc.rev)
+ doc_vcr.maximize(c_vcr)
+ autoresolved = True
+ if autoresolved:
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ c = self._db_handle.cursor()
+ self._delete_conflicts(c, doc, c_revs_to_prune)
+
+ def _force_doc_sync_conflict(self, doc):
+ my_doc = self._get_doc(doc.doc_id)
+ c = self._db_handle.cursor()
+ self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
+ self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json())
+ doc.has_conflicts = True
+ self._put_and_update_indexes(my_doc, doc)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ with self._db_handle:
+ cur_doc = self._get_doc(doc.doc_id)
+ # TODO: https://bugs.launchpad.net/u1db/+bug/928274
+ # I think we have a logic bug in resolve_doc
+ # Specifically, cur_doc.rev is always in the final vector
+ # clock of revisions that we supersede, even if it wasn't in
+ # conflicted_doc_revs. We still add it as a conflict, but the
+ # fact that _put_doc_if_newer propagates resolutions means I
+ # think that conflict could accidentally be resolved. We need
+ # to add a test for this case first. (create a rev, create a
+ # conflict, create another conflict, resolve the first rev
+ # and first conflict, then make sure that the resolved
+ # rev doesn't supersede the second conflict rev.) It *might*
+ # not matter, because the superseding rev is in as a
+ # conflict, but it does seem incorrect
+ new_rev = self._ensure_maximal_rev(cur_doc.rev,
+ conflicted_doc_revs)
+ superseded_revs = set(conflicted_doc_revs)
+ c = self._db_handle.cursor()
+ doc.rev = new_rev
+ if cur_doc.rev in superseded_revs:
+ self._put_and_update_indexes(cur_doc, doc)
+ else:
+ self._add_conflict(c, doc.doc_id, new_rev, doc.get_json())
+ # TODO: Is there some way that we could construct a rev that would
+ # end up in superseded_revs, such that we add a conflict, and
+ # then immediately delete it?
+ self._delete_conflicts(c, doc, superseded_revs)
+
+ def list_indexes(self):
+ """Return the list of indexes and their definitions."""
+ c = self._db_handle.cursor()
+ # TODO: How do we test the ordering?
+ c.execute("SELECT name, field FROM index_definitions"
+ " ORDER BY name, offset")
+ definitions = []
+ cur_name = None
+ for name, field in c.fetchall():
+ if cur_name != name:
+ definitions.append((name, []))
+ cur_name = name
+ definitions[-1][-1].append(field)
+ return definitions
+
+ def _get_index_definition(self, index_name):
+ """Return the stored definition for a given index_name."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions"
+ " WHERE name = ? ORDER BY offset", (index_name,))
+ fields = [x[0] for x in c.fetchall()]
+ if not fields:
+ raise errors.IndexDoesNotExist
+ return fields
+
+ @staticmethod
+ def _strip_glob(value):
+ """Remove the trailing * from a value."""
+ assert value[-1] == '*'
+ return value[:-1]
+
+ def _format_query(self, definition, key_values):
+ # First, build the definition. We join the document_fields table
+ # against itself, as many times as the 'width' of our definition.
+ # We then do a query for each key_value, one-at-a-time.
+ # Note: All of these strings are static, we could cache them, etc.
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = ["d.doc_id = d%d.doc_id"
+ " AND d%d.field_name = ?"
+ % (i, i) for i in range(len(definition))]
+ wildcard_where = [novalue_where[i]
+ + (" AND d%d.value NOT NULL" % (i,))
+ for i in range(len(definition))]
+ exact_where = [novalue_where[i]
+ + (" AND d%d.value = ?" % (i,))
+ for i in range(len(definition))]
+ like_where = [novalue_where[i]
+ + (" AND d%d.value GLOB ?" % (i,))
+ for i in range(len(definition))]
+ is_wildcard = False
+ # Merge the lists together, so that:
+ # [field1, field2, field3], [val1, val2, val3]
+ # Becomes:
+ # (field1, val1, field2, val2, field3, val3)
+ args = []
+ where = []
+ for idx, (field, value) in enumerate(zip(definition, key_values)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(exact_where[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_from_index(self, index_name, *key_values):
+ definition = self._get_index_definition(index_name)
+ if len(key_values) != len(definition):
+ raise errors.InvalidValueForIndex()
+ statement, args = self._format_query(definition, key_values)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def _format_range_query(self, definition, start_value, end_value):
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ wildcard_where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ like_where = [
+ novalue_where[i] + (
+ " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in
+ range(len(definition))]
+ range_where_lower = [
+ novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in
+ range(len(definition))]
+ range_where_upper = [
+ novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in
+ range(len(definition))]
+ args = []
+ where = []
+ if start_value:
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ if len(start_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, start_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(self._strip_glob(value))
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(value)
+ if end_value:
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ if len(end_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, end_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(self._strip_glob(value))
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_upper[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ """Return all documents with key values in the specified range."""
+ definition = self._get_index_definition(index_name)
+ statement, args = self._format_range_query(
+ definition, start_value, end_value)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def get_index_keys(self, index_name):
+ c = self._db_handle.cursor()
+ definition = self._get_index_definition(index_name)
+ value_fields = ', '.join([
+ 'd%d.value' % i for i in range(len(definition))])
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ statement = (
+ "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % (
+ value_fields, ', '.join(tables), ' AND '.join(where),
+ value_fields))
+ try:
+ c.execute(statement, tuple(definition))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition)))
+ return c.fetchall()
+
+ def delete_index(self, index_name):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ c.execute("DELETE FROM index_definitions WHERE name = ?",
+ (index_name,))
+ c.execute(
+ "DELETE FROM document_fields WHERE document_fields.field_name "
+ " NOT IN (SELECT field from index_definitions)")
+
+
+class SQLiteSyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_replica_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_replica_transaction_id)
+
+
+class SQLitePartialExpandDatabase(SQLiteDatabase):
+ """An SQLite Backend that expands documents into a document_field table.
+
+ It stores the original document text in document.doc. For fields that are
+ indexed, the data goes into document_fields.
+ """
+
+ _index_storage_value = 'expand referenced'
+
+ def _get_indexed_fields(self):
+ """Determine what fields are indexed."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions")
+ return set([x[0] for x in c.fetchall()])
+
+ def _evaluate_index(self, raw_doc, field):
+ parser = query_parser.Parser()
+ getter = parser.parse(field)
+ return getter.get(raw_doc)
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ c = self._db_handle.cursor()
+ if doc and not doc.is_tombstone():
+ raw_doc = json.loads(doc.get_json())
+ else:
+ raw_doc = {}
+ if old_doc is not None:
+ c.execute("UPDATE document SET doc_rev=?, content=?"
+ " WHERE doc_id = ?",
+ (doc.rev, doc.get_json(), doc.doc_id))
+ c.execute("DELETE FROM document_fields WHERE doc_id = ?",
+ (doc.doc_id,))
+ else:
+ c.execute("INSERT INTO document (doc_id, doc_rev, content)"
+ " VALUES (?, ?, ?)",
+ (doc.doc_id, doc.rev, doc.get_json()))
+ indexed_fields = self._get_indexed_fields()
+ if indexed_fields:
+ # It is expected that len(indexed_fields) is shorter than
+ # len(raw_doc)
+ getters = [(field, self._parse_index_definition(field))
+ for field in indexed_fields]
+ self._update_indexes(doc.doc_id, raw_doc, getters, c)
+ trans_id = self._allocate_transaction_id()
+ c.execute("INSERT INTO transaction_log(doc_id, transaction_id)"
+ " VALUES (?, ?)", (doc.doc_id, trans_id))
+
+ def create_index(self, index_name, *index_expressions):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ cur_fields = self._get_indexed_fields()
+ definition = [(index_name, idx, field)
+ for idx, field in enumerate(index_expressions)]
+ try:
+ c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
+ definition)
+ except dbapi2.IntegrityError as e:
+ stored_def = self._get_index_definition(index_name)
+ if stored_def == [x[-1] for x in definition]:
+ return
+ raise errors.IndexNameTakenError, e, sys.exc_info()[2]
+ new_fields = set(
+ [f for f in index_expressions if f not in cur_fields])
+ if new_fields:
+ self._update_all_indexes(new_fields)
+
+ def _iter_all_docs(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, content FROM document")
+ while True:
+ next_rows = c.fetchmany()
+ if not next_rows:
+ break
+ for row in next_rows:
+ yield row
+
+ def _update_all_indexes(self, new_fields):
+ """Iterate all the documents, and add content to document_fields.
+
+ :param new_fields: The index definitions that need to be added.
+ """
+ getters = [(field, self._parse_index_definition(field))
+ for field in new_fields]
+ c = self._db_handle.cursor()
+ for doc_id, doc in self._iter_all_docs():
+ if doc is None:
+ continue
+ raw_doc = json.loads(doc)
+ self._update_indexes(doc_id, raw_doc, getters, c)
+
+SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase)