author     drebs <drebs@leap.se>  2012-11-29 10:56:49 -0200
committer  drebs <drebs@leap.se>  2012-11-29 10:56:49 -0200
commit     17ccbcb831044c29f521b529f5aa96dc2a3cd18f (patch)
tree       f8d1144f83f3b1d33fe246887b316029b75195ec /src/leap/soledad/u1db/backends
parent     0f1f9474e7ea6b52dc3ae18444cfaaca56ff3070 (diff)
add u1db code (not as submodule)
Diffstat (limited to 'src/leap/soledad/u1db/backends')
-rw-r--r--  src/leap/soledad/u1db/backends/__init__.py        211
-rw-r--r--  src/leap/soledad/u1db/backends/dbschema.sql        42
-rw-r--r--  src/leap/soledad/u1db/backends/inmemory.py         469
-rw-r--r--  src/leap/soledad/u1db/backends/sqlite_backend.py   926
4 files changed, 1648 insertions, 0 deletions
diff --git a/src/leap/soledad/u1db/backends/__init__.py b/src/leap/soledad/u1db/backends/__init__.py
new file mode 100644
index 00000000..c8e5adc6
--- /dev/null
+++ b/src/leap/soledad/u1db/backends/__init__.py
@@ -0,0 +1,211 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Abstract classes and common implementations for the backends."""
+
+import re
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import uuid
+
+import u1db
+from u1db import (
+ errors,
+)
+import u1db.sync
+from u1db.vectorclock import VectorClockRev
+
+
+check_doc_id_re = re.compile("^" + u1db.DOC_ID_CONSTRAINTS + "$", re.UNICODE)
+
+
+class CommonSyncTarget(u1db.sync.LocalSyncTarget):
+ pass
+
+
+class CommonBackend(u1db.Database):
+
+ document_size_limit = 0
+
+ def _allocate_doc_id(self):
+ """Generate a unique identifier for this document."""
+ return 'D-' + uuid.uuid4().hex # 'D-' stands for document
+
+ def _allocate_transaction_id(self):
+ return 'T-' + uuid.uuid4().hex # 'T-' stands for transaction
+
+ def _allocate_doc_rev(self, old_doc_rev):
+ vcr = VectorClockRev(old_doc_rev)
+ vcr.increment(self._replica_uid)
+ return vcr.as_str()
+
+ def _check_doc_id(self, doc_id):
+ if not check_doc_id_re.match(doc_id):
+ raise errors.InvalidDocId()
+
+ def _check_doc_size(self, doc):
+ if not self.document_size_limit:
+ return
+ if doc.get_size() > self.document_size_limit:
+ raise errors.DocumentTooBig
+
+ def _get_generation(self):
+ """Return the current generation.
+
+ """
+ raise NotImplementedError(self._get_generation)
+
+ def _get_generation_info(self):
+ """Return the current generation and transaction id.
+
+ """
+ raise NotImplementedError(self._get_generation_info)
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """Extract the document from storage.
+
+ This can return None if the document doesn't exist.
+ """
+ raise NotImplementedError(self._get_doc)
+
+ def _has_conflicts(self, doc_id):
+ """Return True if the doc has conflicts, False otherwise."""
+ raise NotImplementedError(self._has_conflicts)
+
+ def create_doc(self, content, doc_id=None):
+ json_string = json.dumps(content)
+ if doc_id is None:
+ doc_id = self._allocate_doc_id()
+ doc = self._factory(doc_id, None, json_string)
+ self.put_doc(doc)
+ return doc
+
+ def create_doc_from_json(self, json, doc_id=None):
+ if doc_id is None:
+ doc_id = self._allocate_doc_id()
+ doc = self._factory(doc_id, None, json)
+ self.put_doc(doc)
+ return doc
+
+ def _get_transaction_log(self):
+ """This is only for the test suite, it is not part of the api."""
+ raise NotImplementedError(self._get_transaction_log)
+
+ def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content):
+ raise NotImplementedError(self._put_and_update_indexes)
+
+ def get_docs(self, doc_ids, check_for_conflicts=True,
+ include_deleted=False):
+ for doc_id in doc_ids:
+ doc = self._get_doc(
+ doc_id, check_for_conflicts=check_for_conflicts)
+ if doc.is_tombstone() and not include_deleted:
+ continue
+ yield doc
+
+ def _get_trans_id_for_gen(self, generation):
+ """Get the transaction id corresponding to a particular generation.
+
+ Raises an InvalidGeneration when the generation does not exist.
+
+ """
+ raise NotImplementedError(self._get_trans_id_for_gen)
+
+ def validate_gen_and_trans_id(self, generation, trans_id):
+ """Validate the generation and transaction id.
+
+ Raises an InvalidGeneration when the generation does not exist, and an
+ InvalidTransactionId when it does but with a different transaction id.
+
+ """
+ if generation == 0:
+ return
+ known_trans_id = self._get_trans_id_for_gen(generation)
+ if known_trans_id != trans_id:
+ raise errors.InvalidTransactionId
+
+ def _validate_source(self, other_replica_uid, other_generation,
+ other_transaction_id):
+ """Validate the new generation and transaction id.
+
+ other_generation must be greater than what we have stored for this
+ replica, *or* it must be the same and the transaction_id must be the
+ same as well.
+ """
+ (old_generation,
+ old_transaction_id) = self._get_replica_gen_and_trans_id(
+ other_replica_uid)
+ if other_generation < old_generation:
+ raise errors.InvalidGeneration
+ if other_generation > old_generation:
+ return
+ if other_transaction_id == old_transaction_id:
+ return
+ raise errors.InvalidTransactionId
+
+ def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen,
+ replica_trans_id=''):
+ cur_doc = self._get_doc(doc.doc_id)
+ doc_vcr = VectorClockRev(doc.rev)
+ if cur_doc is None:
+ cur_vcr = VectorClockRev(None)
+ else:
+ cur_vcr = VectorClockRev(cur_doc.rev)
+ self._validate_source(replica_uid, replica_gen, replica_trans_id)
+ if doc_vcr.is_newer(cur_vcr):
+ rev = doc.rev
+ self._prune_conflicts(doc, doc_vcr)
+ if doc.rev != rev:
+ # conflicts have been autoresolved
+ state = 'superseded'
+ else:
+ state = 'inserted'
+ self._put_and_update_indexes(cur_doc, doc)
+ elif doc.rev == cur_doc.rev:
+ # magical convergence
+ state = 'converged'
+ elif cur_vcr.is_newer(doc_vcr):
+ # Don't add this to seen_ids, because we have something newer,
+ # so we should send it back, and we should not generate a
+ # conflict
+ state = 'superseded'
+ elif cur_doc.same_content_as(doc):
+ # the documents have been edited to the same thing at both ends
+ doc_vcr.maximize(cur_vcr)
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ self._put_and_update_indexes(cur_doc, doc)
+ state = 'superseded'
+ else:
+ state = 'conflicted'
+ if save_conflict:
+ self._force_doc_sync_conflict(doc)
+ if replica_uid is not None and replica_gen is not None:
+ self._do_set_replica_gen_and_trans_id(
+ replica_uid, replica_gen, replica_trans_id)
+ return state, self._get_generation()
+
+ def _ensure_maximal_rev(self, cur_rev, extra_revs):
+ vcr = VectorClockRev(cur_rev)
+ for rev in extra_revs:
+ vcr.maximize(VectorClockRev(rev))
+ vcr.increment(self._replica_uid)
+ return vcr.as_str()
+
+ def set_document_size_limit(self, limit):
+ self.document_size_limit = limit
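The heart of CommonBackend is _put_doc_if_newer, which classifies an incoming revision against the locally stored one by comparing vector clocks. Below is a rough sketch of just that decision, reusing the VectorClockRev API already imported above and leaving out conflict pruning, index updates and generation bookkeeping; the example revs in the trailing comments assume u1db's 'replica_uid:counter' rev-string format.

    from u1db.vectorclock import VectorClockRev

    def sync_state(incoming_rev, current_rev, same_content):
        # Mirrors the branches of CommonBackend._put_doc_if_newer above.
        new_vcr = VectorClockRev(incoming_rev)
        cur_vcr = VectorClockRev(current_rev)   # current_rev may be None
        if new_vcr.is_newer(cur_vcr):
            return 'inserted'                   # incoming strictly dominates
        if incoming_rev == current_rev:
            return 'converged'                  # both replicas already agree
        if cur_vcr.is_newer(new_vcr):
            return 'superseded'                 # we already hold something newer
        if same_content:
            return 'superseded'                 # same edit made independently
        return 'conflicted'                     # divergent clocks and content

    # sync_state('A:2', 'A:1', False)      -> 'inserted'
    # sync_state('A:1|B:1', 'A:2', False)  -> 'conflicted'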
diff --git a/src/leap/soledad/u1db/backends/dbschema.sql b/src/leap/soledad/u1db/backends/dbschema.sql
new file mode 100644
index 00000000..ae027fc5
--- /dev/null
+++ b/src/leap/soledad/u1db/backends/dbschema.sql
@@ -0,0 +1,42 @@
+-- Database schema
+CREATE TABLE transaction_log (
+ generation INTEGER PRIMARY KEY AUTOINCREMENT,
+ doc_id TEXT NOT NULL,
+ transaction_id TEXT NOT NULL
+);
+CREATE TABLE document (
+ doc_id TEXT PRIMARY KEY,
+ doc_rev TEXT NOT NULL,
+ content TEXT
+);
+CREATE TABLE document_fields (
+ doc_id TEXT NOT NULL,
+ field_name TEXT NOT NULL,
+ value TEXT
+);
+CREATE INDEX document_fields_field_value_doc_idx
+ ON document_fields(field_name, value, doc_id);
+
+CREATE TABLE sync_log (
+ replica_uid TEXT PRIMARY KEY,
+ known_generation INTEGER,
+ known_transaction_id TEXT
+);
+CREATE TABLE conflicts (
+ doc_id TEXT,
+ doc_rev TEXT,
+ content TEXT,
+ CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev)
+);
+CREATE TABLE index_definitions (
+ name TEXT,
+ offset INT,
+ field TEXT,
+ CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset)
+);
+CREATE INDEX index_definitions_field ON index_definitions(field);
+CREATE TABLE u1db_config (
+ name TEXT PRIMARY KEY,
+ value TEXT
+);
+INSERT INTO u1db_config VALUES ('sql_schema', '0');
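To make the table layout above concrete, here is a throwaway sqlite3 session; the relative path to dbschema.sql and the sample values are made up, and executescript is used only for brevity. One stored document is a row in document plus an entry in transaction_log, and the current generation is simply the highest AUTOINCREMENT value.

    import sqlite3

    conn = sqlite3.connect(':memory:')
    c = conn.cursor()
    c.executescript(open('dbschema.sql').read())  # the schema defined above

    c.execute("INSERT INTO document VALUES (?, ?, ?)",
              ('D-1', 'replica-a:1', '{"title": "foo"}'))
    c.execute("INSERT INTO transaction_log(doc_id, transaction_id)"
              " VALUES (?, ?)", ('D-1', 'T-1'))
    c.execute("SELECT max(generation) FROM transaction_log")
    print(c.fetchone()[0])  # -> 1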
diff --git a/src/leap/soledad/u1db/backends/inmemory.py b/src/leap/soledad/u1db/backends/inmemory.py
new file mode 100644
index 00000000..a271bb37
--- /dev/null
+++ b/src/leap/soledad/u1db/backends/inmemory.py
@@ -0,0 +1,469 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The in-memory Database class for U1DB."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ Document,
+ errors,
+ query_parser,
+ vectorclock,
+ )
+from u1db.backends import CommonBackend, CommonSyncTarget
+
+
+def get_prefix(value):
+ key_prefix = '\x01'.join(value)
+ return key_prefix.rstrip('*')
+
+
+class InMemoryDatabase(CommonBackend):
+ """A database that only stores the data internally."""
+
+ def __init__(self, replica_uid, document_factory=None):
+ self._transaction_log = []
+ self._docs = {}
+ # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner'
+ self._conflicts = {}
+ self._other_generations = {}
+ self._indexes = {}
+ self._replica_uid = replica_uid
+ self._factory = document_factory or Document
+
+ def _set_replica_uid(self, replica_uid):
+ """Force the replica_uid to be set."""
+ self._replica_uid = replica_uid
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def close(self):
+ # This is a no-op. We don't want to free the data, because one client
+ # may be closing it while another wants to inspect the results.
+ pass
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ return self._other_generations.get(other_replica_uid, (0, ''))
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ self._do_set_replica_gen_and_trans_id(
+ other_replica_uid, other_generation, other_transaction_id)
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ # TODO: to handle race conditions, we may want to check if the current
+ # value is greater than this new value.
+ self._other_generations[other_replica_uid] = (other_generation,
+ other_transaction_id)
+
+ def get_sync_target(self):
+ return InMemorySyncTarget(self)
+
+ def _get_transaction_log(self):
+ # snapshot!
+ return self._transaction_log[:]
+
+ def _get_generation(self):
+ return len(self._transaction_log)
+
+ def _get_generation_info(self):
+ if not self._transaction_log:
+ return 0, ''
+ return len(self._transaction_log), self._transaction_log[-1][1]
+
+ def _get_trans_id_for_gen(self, generation):
+ if generation == 0:
+ return ''
+ if generation > len(self._transaction_log):
+ raise errors.InvalidGeneration
+ return self._transaction_log[generation - 1][1]
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ self._check_doc_id(doc.doc_id)
+ self._check_doc_size(doc)
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc and old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ if old_doc and doc.rev is None and old_doc.is_tombstone():
+ new_rev = self._allocate_doc_rev(old_doc.rev)
+ else:
+ if old_doc is not None:
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ else:
+ if doc.rev is not None:
+ raise errors.RevisionConflict()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ for index in self._indexes.itervalues():
+ if old_doc is not None and not old_doc.is_tombstone():
+ index.remove_json(old_doc.doc_id, old_doc.get_json())
+ if not doc.is_tombstone():
+ index.add_json(doc.doc_id, doc.get_json())
+ trans_id = self._allocate_transaction_id()
+ self._docs[doc.doc_id] = (doc.rev, doc.get_json())
+ self._transaction_log.append((doc.doc_id, trans_id))
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ try:
+ doc_rev, content = self._docs[doc_id]
+ except KeyError:
+ return None
+ doc = self._factory(doc_id, doc_rev, content)
+ if check_for_conflicts:
+ doc.has_conflicts = (doc.doc_id in self._conflicts)
+ return doc
+
+ def _has_conflicts(self, doc_id):
+ return doc_id in self._conflicts
+
+ def get_doc(self, doc_id, include_deleted=False):
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc is None:
+ return None
+ if doc.is_tombstone() and not include_deleted:
+ return None
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Return all documents in the database."""
+ generation = self._get_generation()
+ results = []
+ for doc_id, (doc_rev, content) in self._docs.items():
+ if content is None and not include_deleted:
+ continue
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = self._has_conflicts(doc_id)
+ results.append(doc)
+ return (generation, results)
+
+ def get_doc_conflicts(self, doc_id):
+ if doc_id not in self._conflicts:
+ return []
+ result = [self._get_doc(doc_id)]
+ result[0].has_conflicts = True
+ result.extend([self._factory(doc_id, rev, content)
+ for rev, content in self._conflicts[doc_id]])
+ return result
+
+ def _replace_conflicts(self, doc, conflicts):
+ if not conflicts:
+ del self._conflicts[doc.doc_id]
+ else:
+ self._conflicts[doc.doc_id] = conflicts
+ doc.has_conflicts = bool(conflicts)
+
+ def _prune_conflicts(self, doc, doc_vcr):
+ if self._has_conflicts(doc.doc_id):
+ autoresolved = False
+ remaining_conflicts = []
+ cur_conflicts = self._conflicts[doc.doc_id]
+ for c_rev, c_doc in cur_conflicts:
+ c_vcr = vectorclock.VectorClockRev(c_rev)
+ if doc_vcr.is_newer(c_vcr):
+ continue
+ if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)):
+ doc_vcr.maximize(c_vcr)
+ autoresolved = True
+ continue
+ remaining_conflicts.append((c_rev, c_doc))
+ if autoresolved:
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ self._replace_conflicts(doc, remaining_conflicts)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ cur_doc = self._get_doc(doc.doc_id)
+ if cur_doc is None:
+ cur_rev = None
+ else:
+ cur_rev = cur_doc.rev
+ new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs)
+ superseded_revs = set(conflicted_doc_revs)
+ remaining_conflicts = []
+ cur_conflicts = self._conflicts[doc.doc_id]
+ for c_rev, c_doc in cur_conflicts:
+ if c_rev in superseded_revs:
+ continue
+ remaining_conflicts.append((c_rev, c_doc))
+ doc.rev = new_rev
+ if cur_rev in superseded_revs:
+ self._put_and_update_indexes(cur_doc, doc)
+ else:
+ remaining_conflicts.append((new_rev, doc.get_json()))
+ self._replace_conflicts(doc, remaining_conflicts)
+
+ def delete_doc(self, doc):
+ if doc.doc_id not in self._docs:
+ raise errors.DocumentDoesNotExist
+ if self._docs[doc.doc_id][1] in ('null', None):
+ raise errors.DocumentAlreadyDeleted
+ doc.make_tombstone()
+ self.put_doc(doc)
+
+ def create_index(self, index_name, *index_expressions):
+ if index_name in self._indexes:
+ if self._indexes[index_name]._definition == list(
+ index_expressions):
+ return
+ raise errors.IndexNameTakenError
+ index = InMemoryIndex(index_name, list(index_expressions))
+ for doc_id, (doc_rev, doc) in self._docs.iteritems():
+ if doc is not None:
+ index.add_json(doc_id, doc)
+ self._indexes[index_name] = index
+
+ def delete_index(self, index_name):
+ del self._indexes[index_name]
+
+ def list_indexes(self):
+ definitions = []
+ for idx in self._indexes.itervalues():
+ definitions.append((idx._name, idx._definition))
+ return definitions
+
+ def get_from_index(self, index_name, *key_values):
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ doc_ids = index.lookup(key_values)
+ result = []
+ for doc_id in doc_ids:
+ result.append(self._get_doc(doc_id, check_for_conflicts=True))
+ return result
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ """Return all documents with key values in the specified range."""
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ doc_ids = index.lookup_range(start_value, end_value)
+ result = []
+ for doc_id in doc_ids:
+ result.append(self._get_doc(doc_id, check_for_conflicts=True))
+ return result
+
+ def get_index_keys(self, index_name):
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ keys = index.keys()
+ # XXX inefficiency warning
+ return list(set([tuple(key.split('\x01')) for key in keys]))
+
+ def whats_changed(self, old_generation=0):
+ changes = []
+ relevant_tail = self._transaction_log[old_generation:]
+ # We don't use len(self._transaction_log) because _transaction_log may
+ # get mutated by a concurrent operation.
+ cur_generation = old_generation + len(relevant_tail)
+ last_trans_id = ''
+ if relevant_tail:
+ last_trans_id = relevant_tail[-1][1]
+ elif self._transaction_log:
+ last_trans_id = self._transaction_log[-1][1]
+ seen = set()
+ generation = cur_generation
+ for doc_id, trans_id in reversed(relevant_tail):
+ if doc_id not in seen:
+ changes.append((doc_id, generation, trans_id))
+ seen.add(doc_id)
+ generation -= 1
+ changes.reverse()
+ return (cur_generation, last_trans_id, changes)
+
+ def _force_doc_sync_conflict(self, doc):
+ my_doc = self._get_doc(doc.doc_id)
+ self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
+ self._conflicts.setdefault(doc.doc_id, []).append(
+ (my_doc.rev, my_doc.get_json()))
+ doc.has_conflicts = True
+ self._put_and_update_indexes(my_doc, doc)
+
+
+class InMemoryIndex(object):
+ """Interface for managing an Index."""
+
+ def __init__(self, index_name, index_definition):
+ self._name = index_name
+ self._definition = index_definition
+ self._values = {}
+ parser = query_parser.Parser()
+ self._getters = parser.parse_all(self._definition)
+
+ def evaluate_json(self, doc):
+ """Determine the 'key' after applying this index to the doc."""
+ raw = json.loads(doc)
+ return self.evaluate(raw)
+
+ def evaluate(self, obj):
+ """Evaluate a dict object, applying this definition."""
+ all_rows = [[]]
+ for getter in self._getters:
+ new_rows = []
+ keys = getter.get(obj)
+ if not keys:
+ return []
+ for key in keys:
+ new_rows.extend([row + [key] for row in all_rows])
+ all_rows = new_rows
+ all_rows = ['\x01'.join(row) for row in all_rows]
+ return all_rows
+
+ def add_json(self, doc_id, doc):
+ """Add this json doc to the index."""
+ keys = self.evaluate_json(doc)
+ if not keys:
+ return
+ for key in keys:
+ self._values.setdefault(key, []).append(doc_id)
+
+ def remove_json(self, doc_id, doc):
+ """Remove this json doc from the index."""
+ keys = self.evaluate_json(doc)
+ if keys:
+ for key in keys:
+ doc_ids = self._values[key]
+ doc_ids.remove(doc_id)
+ if not doc_ids:
+ del self._values[key]
+
+ def _find_non_wildcards(self, values):
+ """Check if this should be a wildcard match.
+
+ Further, this will raise an exception if the syntax is improperly
+ defined.
+
+ :return: The offset of the last value we need to match against.
+ """
+ if len(values) != len(self._definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ last = 0
+ for idx, val in enumerate(values):
+ if val.endswith('*'):
+ if val != '*':
+ # We have an 'x*' style wildcard
+ if is_wildcard:
+ # We were already in wildcard mode, so this is invalid
+ raise errors.InvalidGlobbing
+ last = idx + 1
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ # We were in wildcard mode, we can't follow that with
+ # non-wildcard
+ raise errors.InvalidGlobbing
+ last = idx + 1
+ if not is_wildcard:
+ return -1
+ return last
+
+ def lookup(self, values):
+ """Find docs that match the values."""
+ last = self._find_non_wildcards(values)
+ if last == -1:
+ return self._lookup_exact(values)
+ else:
+ return self._lookup_prefix(values[:last])
+
+ def lookup_range(self, start_values, end_values):
+ """Find docs within the range."""
+ # TODO: Wildly inefficient, which is unlikely to be a problem for the
+ # inmemory implementation.
+ if start_values:
+ self._find_non_wildcards(start_values)
+ start_values = get_prefix(start_values)
+ if end_values:
+ if self._find_non_wildcards(end_values) == -1:
+ exact = True
+ else:
+ exact = False
+ end_values = get_prefix(end_values)
+ found = []
+ for key, doc_ids in sorted(self._values.iteritems()):
+ if start_values and start_values > key:
+ continue
+ if end_values and end_values < key:
+ if exact:
+ break
+ else:
+ if not key.startswith(end_values):
+ break
+ found.extend(doc_ids)
+ return found
+
+ def keys(self):
+ """Find the indexed keys."""
+ return self._values.keys()
+
+ def _lookup_prefix(self, value):
+ """Find docs that match the prefix string in values."""
+ # TODO: We need a different data structure to make prefix style fast,
+ # some sort of sorted list would work, but a plain dict doesn't.
+ key_prefix = get_prefix(value)
+ all_doc_ids = []
+ for key, doc_ids in sorted(self._values.iteritems()):
+ if key.startswith(key_prefix):
+ all_doc_ids.extend(doc_ids)
+ return all_doc_ids
+
+ def _lookup_exact(self, value):
+ """Find docs that match exactly."""
+ key = '\x01'.join(value)
+ if key in self._values:
+ return self._values[key]
+ return ()
+
+
+class InMemorySyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_transaction_id)
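A minimal usage sketch of the in-memory backend above, assuming the package is importable as u1db (as the module's own imports expect); the replica uid, index name and document content are made up. The 'f*' query exercises InMemoryIndex._lookup_prefix, while the plain value takes the exact-lookup path.

    from u1db.backends.inmemory import InMemoryDatabase

    db = InMemoryDatabase('replica-a')           # replica_uid
    doc = db.create_doc({'title': 'foo'})
    db.create_index('by-title', 'title')
    print([d.doc_id for d in db.get_from_index('by-title', 'foo')])  # exact match
    print([d.doc_id for d in db.get_from_index('by-title', 'f*')])   # prefix glob
    gen, trans_id, changes = db.whats_changed(0)
    print(changes)                               # -> [(doc.doc_id, 1, trans_id)]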
diff --git a/src/leap/soledad/u1db/backends/sqlite_backend.py b/src/leap/soledad/u1db/backends/sqlite_backend.py
new file mode 100644
index 00000000..773213b5
--- /dev/null
+++ b/src/leap/soledad/u1db/backends/sqlite_backend.py
@@ -0,0 +1,926 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""A U1DB implementation that uses SQLite as its persistence layer."""
+
+import errno
+import os
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from sqlite3 import dbapi2
+import sys
+import time
+import uuid
+
+import pkg_resources
+
+from u1db.backends import CommonBackend, CommonSyncTarget
+from u1db import (
+ Document,
+ errors,
+ query_parser,
+ vectorclock,
+ )
+
+
+class SQLiteDatabase(CommonBackend):
+ """A U1DB implementation that uses SQLite as its persistence layer."""
+
+ _sqlite_registry = {}
+
+ def __init__(self, sqlite_file, document_factory=None):
+ """Create a new sqlite file."""
+ self._db_handle = dbapi2.connect(sqlite_file)
+ self._real_replica_uid = None
+ self._ensure_schema()
+ self._factory = document_factory or Document
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def get_sync_target(self):
+ return SQLiteSyncTarget(self)
+
+ @classmethod
+ def _which_index_storage(cls, c):
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'index_storage'")
+ except dbapi2.OperationalError, e:
+ # The table does not exist yet
+ return None, e
+ else:
+ return c.fetchone()[0], None
+
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5
+
+ @classmethod
+ def _open_database(cls, sqlite_file, document_factory=None):
+ if not os.path.isfile(sqlite_file):
+ raise errors.DatabaseDoesNotExist()
+ tries = 2
+ while True:
+ # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6)
+ # where without re-opening the database on Windows, it
+ # doesn't see the transaction that was just committed
+ db_handle = dbapi2.connect(sqlite_file)
+ c = db_handle.cursor()
+ v, err = cls._which_index_storage(c)
+ db_handle.close()
+ if v is not None:
+ break
+ # possibly another process is initializing it, wait for it to be
+ # done
+ if tries == 0:
+ raise err # go for the richest error?
+ tries -= 1
+ time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL)
+ return SQLiteDatabase._sqlite_registry[v](
+ sqlite_file, document_factory=document_factory)
+
+ @classmethod
+ def open_database(cls, sqlite_file, create, backend_cls=None,
+ document_factory=None):
+ try:
+ return cls._open_database(
+ sqlite_file, document_factory=document_factory)
+ except errors.DatabaseDoesNotExist:
+ if not create:
+ raise
+ if backend_cls is None:
+ # default is SQLitePartialExpandDatabase
+ backend_cls = SQLitePartialExpandDatabase
+ return backend_cls(sqlite_file, document_factory=document_factory)
+
+ @staticmethod
+ def delete_database(sqlite_file):
+ try:
+ os.unlink(sqlite_file)
+ except OSError as ex:
+ if ex.errno == errno.ENOENT:
+ raise errors.DatabaseDoesNotExist()
+ raise
+
+ @staticmethod
+ def register_implementation(klass):
+ """Register that we implement an SQLiteDatabase.
+
+ The attribute _index_storage_value will be used as the lookup key.
+ """
+ SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass
+
+ def _get_sqlite_handle(self):
+ """Get access to the underlying sqlite database.
+
+ This should only be used by the test suite, etc, for examining the
+ state of the underlying database.
+ """
+ return self._db_handle
+
+ def _close_sqlite_handle(self):
+ """Release access to the underlying sqlite database."""
+ self._db_handle.close()
+
+ def close(self):
+ self._close_sqlite_handle()
+
+ def _is_initialized(self, c):
+ """Check if this database has been initialized."""
+ c.execute("PRAGMA case_sensitive_like=ON")
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'sql_schema'")
+ except dbapi2.OperationalError:
+ # The table does not exist yet
+ val = None
+ else:
+ val = c.fetchone()
+ if val is not None:
+ return True
+ return False
+
+ def _initialize(self, c):
+ """Create the schema in the database."""
+ # Read the script with the SQL commands
+ # TODO: Change how we set up the dependency. Most likely use something
+ # like lp:dirspec to grab the file from a common resource
+ # directory. Doesn't specifically need to be handled until we get
+ # to the point of packaging this.
+ schema_content = pkg_resources.resource_string(
+ __name__, 'dbschema.sql')
+ # Note: We'd like to use c.executescript() here, but it seems that
+ # executescript always commits, even if you set
+ # isolation_level = None, so if we want to properly handle
+ # exclusive locking and rollbacks between processes, we need
+ # to execute it line-by-line
+ for line in schema_content.split(';'):
+ if not line:
+ continue
+ c.execute(line)
+ # Add extra fields
+ self._extra_schema_init(c)
+ # A unique identifier should be set for this replica. Implementations
+ # don't have to strictly use uuid here, but we do want the uid to be
+ # unique amongst all databases that will sync with each other.
+ # We might extend this to using something with hostname for easier
+ # debugging.
+ self._set_replica_uid_in_transaction(uuid.uuid4().hex)
+ c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)",
+ (self._index_storage_value,))
+
+ def _ensure_schema(self):
+ """Ensure that the database schema has been created."""
+ old_isolation_level = self._db_handle.isolation_level
+ c = self._db_handle.cursor()
+ if self._is_initialized(c):
+ return
+ try:
+ # autocommit/own mgmt of transactions
+ self._db_handle.isolation_level = None
+ with self._db_handle:
+ # only one execution path should initialize the db
+ c.execute("begin exclusive")
+ if self._is_initialized(c):
+ return
+ self._initialize(c)
+ finally:
+ self._db_handle.isolation_level = old_isolation_level
+
+ def _extra_schema_init(self, c):
+ """Add any extra fields, etc to the basic table definitions."""
+
+ def _parse_index_definition(self, index_field):
+ """Parse a field definition for an index, returning a Getter."""
+ # Note: We may want to keep a Parser object around, and cache the
+ # Getter objects for a greater length of time. Specifically, if
+ # you create a bunch of indexes, and then insert 50k docs, you'll
+ # re-parse the indexes between puts. The time to insert the docs
+ # is still likely to dominate put_doc time, though.
+ parser = query_parser.Parser()
+ getter = parser.parse(index_field)
+ return getter
+
+ def _update_indexes(self, doc_id, raw_doc, getters, db_cursor):
+ """Update document_fields for a single document.
+
+ :param doc_id: Identifier for this document
+ :param raw_doc: The python dict representation of the document.
+ :param getters: A list of [(field_name, Getter)]. Getter.get will be
+ called to evaluate the index definition for this document, and the
+ results will be inserted into the db.
+ :param db_cursor: An sqlite Cursor.
+ :return: None
+ """
+ values = []
+ for field_name, getter in getters:
+ for idx_value in getter.get(raw_doc):
+ values.append((doc_id, field_name, idx_value))
+ if values:
+ db_cursor.executemany(
+ "INSERT INTO document_fields VALUES (?, ?, ?)", values)
+
+ def _set_replica_uid(self, replica_uid):
+ """Force the replica_uid to be set."""
+ with self._db_handle:
+ self._set_replica_uid_in_transaction(replica_uid)
+
+ def _set_replica_uid_in_transaction(self, replica_uid):
+ """Set the replica_uid. A transaction should already be held."""
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO u1db_config"
+ " VALUES ('replica_uid', ?)",
+ (replica_uid,))
+ self._real_replica_uid = replica_uid
+
+ def _get_replica_uid(self):
+ if self._real_replica_uid is not None:
+ return self._real_replica_uid
+ c = self._db_handle.cursor()
+ c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'")
+ val = c.fetchone()
+ if val is None:
+ return None
+ self._real_replica_uid = val[0]
+ return self._real_replica_uid
+
+ _replica_uid = property(_get_replica_uid)
+
+ def _get_generation(self):
+ c = self._db_handle.cursor()
+ c.execute('SELECT max(generation) FROM transaction_log')
+ val = c.fetchone()[0]
+ if val is None:
+ return 0
+ return val
+
+ def _get_generation_info(self):
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT max(generation), transaction_id FROM transaction_log ')
+ val = c.fetchone()
+ if val[0] is None:
+ return(0, '')
+ return val
+
+ def _get_trans_id_for_gen(self, generation):
+ if generation == 0:
+ return ''
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT transaction_id FROM transaction_log WHERE generation = ?',
+ (generation,))
+ val = c.fetchone()
+ if val is None:
+ raise errors.InvalidGeneration
+ return val[0]
+
+ def _get_transaction_log(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, transaction_id FROM transaction_log"
+ " ORDER BY generation")
+ return c.fetchall()
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """Get just the document content, without fancy handling."""
+ c = self._db_handle.cursor()
+ if check_for_conflicts:
+ c.execute(
+ "SELECT document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN "
+ "conflicts ON conflicts.doc_id = document.doc_id WHERE "
+ "document.doc_id = ? GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;", (doc_id,))
+ else:
+ c.execute(
+ "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return None
+ doc_rev, content, conflicts = val
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ return doc
+
+ def _has_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return False
+ else:
+ return True
+
+ def get_doc(self, doc_id, include_deleted=False):
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc is None:
+ return None
+ if doc.is_tombstone() and not include_deleted:
+ return None
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Get all documents from the database."""
+ generation = self._get_generation()
+ results = []
+ c = self._db_handle.cursor()
+ c.execute(
+ "SELECT document.doc_id, document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts "
+ "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;")
+ rows = c.fetchall()
+ for doc_id, doc_rev, content, conflicts in rows:
+ if content is None and not include_deleted:
+ continue
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ results.append(doc)
+ return (generation, results)
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ self._check_doc_id(doc.doc_id)
+ self._check_doc_size(doc)
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc and old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ if old_doc and doc.rev is None and old_doc.is_tombstone():
+ new_rev = self._allocate_doc_rev(old_doc.rev)
+ else:
+ if old_doc is not None:
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ else:
+ if doc.rev is not None:
+ raise errors.RevisionConflict()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none):
+ """Convert a dict representation into named fields.
+
+ So something like: {'key1': 'val1', 'key2': 'val2'}
+ gets converted into: [(doc_id, 'key1', 'val1', 0),
+ (doc_id, 'key2', 'val2', 0)]
+ :param doc_id: Just added to every record.
+ :param base_field: if set, these are nested keys, so each field should
+ be appropriately prefixed.
+ :param raw_doc: The python dictionary.
+ """
+ # TODO: Handle lists
+ values = []
+ for field_name, value in raw_doc.iteritems():
+ if value is None and not save_none:
+ continue
+ if base_field:
+ full_name = base_field + '.' + field_name
+ else:
+ full_name = field_name
+ if value is None or isinstance(value, (int, float, basestring)):
+ values.append((doc_id, full_name, value, len(values)))
+ else:
+ subvalues = self._expand_to_fields(doc_id, full_name, value,
+ save_none)
+ for _, subfield_name, val, _ in subvalues:
+ values.append((doc_id, subfield_name, val, len(values)))
+ return values
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ """Actually insert a document into the database.
+
+ This both updates the existing documents content, and any indexes that
+ refer to this document.
+ """
+ raise NotImplementedError(self._put_and_update_indexes)
+
+ def whats_changed(self, old_generation=0):
+ c = self._db_handle.cursor()
+ c.execute("SELECT generation, doc_id, transaction_id"
+ " FROM transaction_log"
+ " WHERE generation > ? ORDER BY generation DESC",
+ (old_generation,))
+ results = c.fetchall()
+ cur_gen = old_generation
+ seen = set()
+ changes = []
+ newest_trans_id = ''
+ for generation, doc_id, trans_id in results:
+ if doc_id not in seen:
+ changes.append((doc_id, generation, trans_id))
+ seen.add(doc_id)
+ if changes:
+ cur_gen = changes[0][1] # max generation
+ newest_trans_id = changes[0][2]
+ changes.reverse()
+ else:
+ c.execute("SELECT generation, transaction_id"
+ " FROM transaction_log ORDER BY generation DESC LIMIT 1")
+ results = c.fetchone()
+ if not results:
+ cur_gen = 0
+ newest_trans_id = ''
+ else:
+ cur_gen, newest_trans_id = results
+
+ return cur_gen, newest_trans_id, changes
+
+ def delete_doc(self, doc):
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc is None:
+ raise errors.DocumentDoesNotExist
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ if old_doc.is_tombstone():
+ raise errors.DocumentAlreadyDeleted
+ if old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ doc.make_tombstone()
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _get_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?",
+ (doc_id,))
+ return [self._factory(doc_id, doc_rev, content)
+ for doc_rev, content in c.fetchall()]
+
+ def get_doc_conflicts(self, doc_id):
+ with self._db_handle:
+ conflict_docs = self._get_conflicts(doc_id)
+ if not conflict_docs:
+ return []
+ this_doc = self._get_doc(doc_id)
+ this_doc.has_conflicts = True
+ return [this_doc] + conflict_docs
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ c = self._db_handle.cursor()
+ c.execute("SELECT known_generation, known_transaction_id FROM sync_log"
+ " WHERE replica_uid = ?",
+ (other_replica_uid,))
+ val = c.fetchone()
+ if val is None:
+ other_gen = 0
+ trans_id = ''
+ else:
+ other_gen = val[0]
+ trans_id = val[1]
+ return other_gen, trans_id
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ with self._db_handle:
+ self._do_set_replica_gen_and_trans_id(
+ other_replica_uid, other_generation, other_transaction_id)
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)",
+ (other_replica_uid, other_generation,
+ other_transaction_id))
+
+ def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None,
+ replica_gen=None, replica_trans_id=None):
+ with self._db_handle:
+ return super(SQLiteDatabase, self)._put_doc_if_newer(doc,
+ save_conflict=save_conflict,
+ replica_uid=replica_uid, replica_gen=replica_gen,
+ replica_trans_id=replica_trans_id)
+
+ def _add_conflict(self, c, doc_id, my_doc_rev, my_content):
+ c.execute("INSERT INTO conflicts VALUES (?, ?, ?)",
+ (doc_id, my_doc_rev, my_content))
+
+ def _delete_conflicts(self, c, doc, conflict_revs):
+ deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs]
+ c.executemany("DELETE FROM conflicts"
+ " WHERE doc_id=? AND doc_rev=?", deleting)
+ doc.has_conflicts = self._has_conflicts(doc.doc_id)
+
+ def _prune_conflicts(self, doc, doc_vcr):
+ if self._has_conflicts(doc.doc_id):
+ autoresolved = False
+ c_revs_to_prune = []
+ for c_doc in self._get_conflicts(doc.doc_id):
+ c_vcr = vectorclock.VectorClockRev(c_doc.rev)
+ if doc_vcr.is_newer(c_vcr):
+ c_revs_to_prune.append(c_doc.rev)
+ elif doc.same_content_as(c_doc):
+ c_revs_to_prune.append(c_doc.rev)
+ doc_vcr.maximize(c_vcr)
+ autoresolved = True
+ if autoresolved:
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ c = self._db_handle.cursor()
+ self._delete_conflicts(c, doc, c_revs_to_prune)
+
+ def _force_doc_sync_conflict(self, doc):
+ my_doc = self._get_doc(doc.doc_id)
+ c = self._db_handle.cursor()
+ self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
+ self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json())
+ doc.has_conflicts = True
+ self._put_and_update_indexes(my_doc, doc)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ with self._db_handle:
+ cur_doc = self._get_doc(doc.doc_id)
+ # TODO: https://bugs.launchpad.net/u1db/+bug/928274
+ # I think we have a logic bug in resolve_doc
+ # Specifically, cur_doc.rev is always in the final vector
+ # clock of revisions that we supersede, even if it wasn't in
+ # conflicted_doc_revs. We still add it as a conflict, but the
+ # fact that _put_doc_if_newer propagates resolutions means I
+ # think that conflict could accidentally be resolved. We need
+ # to add a test for this case first. (create a rev, create a
+ # conflict, create another conflict, resolve the first rev
+ # and first conflict, then make sure that the resolved
+ # rev doesn't supersede the second conflict rev.) It *might*
+ # not matter, because the superseding rev is in as a
+ # conflict, but it does seem incorrect
+ new_rev = self._ensure_maximal_rev(cur_doc.rev,
+ conflicted_doc_revs)
+ superseded_revs = set(conflicted_doc_revs)
+ c = self._db_handle.cursor()
+ doc.rev = new_rev
+ if cur_doc.rev in superseded_revs:
+ self._put_and_update_indexes(cur_doc, doc)
+ else:
+ self._add_conflict(c, doc.doc_id, new_rev, doc.get_json())
+ # TODO: Is there some way that we could construct a rev that would
+ # end up in superseded_revs, such that we add a conflict, and
+ # then immediately delete it?
+ self._delete_conflicts(c, doc, superseded_revs)
+
+ def list_indexes(self):
+ """Return the list of indexes and their definitions."""
+ c = self._db_handle.cursor()
+ # TODO: How do we test the ordering?
+ c.execute("SELECT name, field FROM index_definitions"
+ " ORDER BY name, offset")
+ definitions = []
+ cur_name = None
+ for name, field in c.fetchall():
+ if cur_name != name:
+ definitions.append((name, []))
+ cur_name = name
+ definitions[-1][-1].append(field)
+ return definitions
+
+ def _get_index_definition(self, index_name):
+ """Return the stored definition for a given index_name."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions"
+ " WHERE name = ? ORDER BY offset", (index_name,))
+ fields = [x[0] for x in c.fetchall()]
+ if not fields:
+ raise errors.IndexDoesNotExist
+ return fields
+
+ @staticmethod
+ def _strip_glob(value):
+ """Remove the trailing * from a value."""
+ assert value[-1] == '*'
+ return value[:-1]
+
+ def _format_query(self, definition, key_values):
+ # First, build the definition. We join the document_fields table
+ # against itself, as many times as the 'width' of our definition.
+ # We then do a query for each key_value, one-at-a-time.
+ # Note: All of these strings are static, we could cache them, etc.
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = ["d.doc_id = d%d.doc_id"
+ " AND d%d.field_name = ?"
+ % (i, i) for i in range(len(definition))]
+ wildcard_where = [novalue_where[i]
+ + (" AND d%d.value NOT NULL" % (i,))
+ for i in range(len(definition))]
+ exact_where = [novalue_where[i]
+ + (" AND d%d.value = ?" % (i,))
+ for i in range(len(definition))]
+ like_where = [novalue_where[i]
+ + (" AND d%d.value GLOB ?" % (i,))
+ for i in range(len(definition))]
+ is_wildcard = False
+ # Merge the lists together, so that:
+ # [field1, field2, field3], [val1, val2, val3]
+ # Becomes:
+ # (field1, val1, field2, val2, field3, val3)
+ args = []
+ where = []
+ for idx, (field, value) in enumerate(zip(definition, key_values)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(exact_where[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_from_index(self, index_name, *key_values):
+ definition = self._get_index_definition(index_name)
+ if len(key_values) != len(definition):
+ raise errors.InvalidValueForIndex()
+ statement, args = self._format_query(definition, key_values)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def _format_range_query(self, definition, start_value, end_value):
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ wildcard_where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ like_where = [
+ novalue_where[i] + (
+ " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in
+ range(len(definition))]
+ range_where_lower = [
+ novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in
+ range(len(definition))]
+ range_where_upper = [
+ novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in
+ range(len(definition))]
+ args = []
+ where = []
+ if start_value:
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ if len(start_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, start_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(self._strip_glob(value))
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(value)
+ if end_value:
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ if len(end_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, end_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(self._strip_glob(value))
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_upper[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ """Return all documents with key values in the specified range."""
+ definition = self._get_index_definition(index_name)
+ statement, args = self._format_range_query(
+ definition, start_value, end_value)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def get_index_keys(self, index_name):
+ c = self._db_handle.cursor()
+ definition = self._get_index_definition(index_name)
+ value_fields = ', '.join([
+ 'd%d.value' % i for i in range(len(definition))])
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ statement = (
+ "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % (
+ value_fields, ', '.join(tables), ' AND '.join(where),
+ value_fields))
+ try:
+ c.execute(statement, tuple(definition))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition)))
+ return c.fetchall()
+
+ def delete_index(self, index_name):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ c.execute("DELETE FROM index_definitions WHERE name = ?",
+ (index_name,))
+ c.execute(
+ "DELETE FROM document_fields WHERE document_fields.field_name "
+ " NOT IN (SELECT field from index_definitions)")
+
+
+class SQLiteSyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_replica_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_replica_transaction_id)
+
+
+class SQLitePartialExpandDatabase(SQLiteDatabase):
+ """An SQLite Backend that expands documents into a document_field table.
+
+ It stores the original document text in document.content. For fields that are
+ indexed, the data goes into document_fields.
+ """
+
+ _index_storage_value = 'expand referenced'
+
+ def _get_indexed_fields(self):
+ """Determine what fields are indexed."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions")
+ return set([x[0] for x in c.fetchall()])
+
+ def _evaluate_index(self, raw_doc, field):
+ parser = query_parser.Parser()
+ getter = parser.parse(field)
+ return getter.get(raw_doc)
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ c = self._db_handle.cursor()
+ if doc and not doc.is_tombstone():
+ raw_doc = json.loads(doc.get_json())
+ else:
+ raw_doc = {}
+ if old_doc is not None:
+ c.execute("UPDATE document SET doc_rev=?, content=?"
+ " WHERE doc_id = ?",
+ (doc.rev, doc.get_json(), doc.doc_id))
+ c.execute("DELETE FROM document_fields WHERE doc_id = ?",
+ (doc.doc_id,))
+ else:
+ c.execute("INSERT INTO document (doc_id, doc_rev, content)"
+ " VALUES (?, ?, ?)",
+ (doc.doc_id, doc.rev, doc.get_json()))
+ indexed_fields = self._get_indexed_fields()
+ if indexed_fields:
+ # It is expected that len(indexed_fields) is shorter than
+ # len(raw_doc)
+ getters = [(field, self._parse_index_definition(field))
+ for field in indexed_fields]
+ self._update_indexes(doc.doc_id, raw_doc, getters, c)
+ trans_id = self._allocate_transaction_id()
+ c.execute("INSERT INTO transaction_log(doc_id, transaction_id)"
+ " VALUES (?, ?)", (doc.doc_id, trans_id))
+
+ def create_index(self, index_name, *index_expressions):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ cur_fields = self._get_indexed_fields()
+ definition = [(index_name, idx, field)
+ for idx, field in enumerate(index_expressions)]
+ try:
+ c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
+ definition)
+ except dbapi2.IntegrityError as e:
+ stored_def = self._get_index_definition(index_name)
+ if stored_def == [x[-1] for x in definition]:
+ return
+ raise errors.IndexNameTakenError, e, sys.exc_info()[2]
+ new_fields = set(
+ [f for f in index_expressions if f not in cur_fields])
+ if new_fields:
+ self._update_all_indexes(new_fields)
+
+ def _iter_all_docs(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, content FROM document")
+ while True:
+ next_rows = c.fetchmany()
+ if not next_rows:
+ break
+ for row in next_rows:
+ yield row
+
+ def _update_all_indexes(self, new_fields):
+ """Iterate all the documents, and add content to document_fields.
+
+ :param new_fields: The index definitions that need to be added.
+ """
+ getters = [(field, self._parse_index_definition(field))
+ for field in new_fields]
+ c = self._db_handle.cursor()
+ for doc_id, doc in self._iter_all_docs():
+ if doc is None:
+ continue
+ raw_doc = json.loads(doc)
+ self._update_indexes(doc_id, raw_doc, getters, c)
+
+SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase)
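Finally, a short end-to-end sketch of the SQLite backend, again assuming the package is importable as u1db; the database path and index name are hypothetical. open_database with create=True falls back to SQLitePartialExpandDatabase, so the indexed field ends up in document_fields as described above.

    from u1db.backends import sqlite_backend

    db = sqlite_backend.SQLiteDatabase.open_database(
        '/tmp/example.u1db', create=True)        # creates the schema on first use
    doc = db.create_doc({'title': 'foo', 'author': 'bar'})
    db.create_index('by-author', 'author')       # populates document_fields
    found = db.get_from_index('by-author', 'bar')
    assert [d.doc_id for d in found] == [doc.doc_id]
    db.close()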