summaryrefslogtreecommitdiff
path: root/src/leap/soledad/u1db/backends/inmemory.py
diff options
context:
space:
mode:
authordrebs <drebs@leap.se>2012-11-29 10:56:49 -0200
committerdrebs <drebs@leap.se>2012-11-29 10:56:49 -0200
commit17ccbcb831044c29f521b529f5aa96dc2a3cd18f (patch)
treef8d1144f83f3b1d33fe246887b316029b75195ec /src/leap/soledad/u1db/backends/inmemory.py
parent0f1f9474e7ea6b52dc3ae18444cfaaca56ff3070 (diff)
add u1db code (not as submodule)
Diffstat (limited to 'src/leap/soledad/u1db/backends/inmemory.py')
-rw-r--r--src/leap/soledad/u1db/backends/inmemory.py469
1 files changed, 469 insertions, 0 deletions
diff --git a/src/leap/soledad/u1db/backends/inmemory.py b/src/leap/soledad/u1db/backends/inmemory.py
new file mode 100644
index 00000000..a271bb37
--- /dev/null
+++ b/src/leap/soledad/u1db/backends/inmemory.py
@@ -0,0 +1,469 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The in-memory Database class for U1DB."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ Document,
+ errors,
+ query_parser,
+ vectorclock,
+ )
+from u1db.backends import CommonBackend, CommonSyncTarget
+
+
+def get_prefix(value):
+ key_prefix = '\x01'.join(value)
+ return key_prefix.rstrip('*')
+
+
+class InMemoryDatabase(CommonBackend):
+ """A database that only stores the data internally."""
+
+ def __init__(self, replica_uid, document_factory=None):
+ self._transaction_log = []
+ self._docs = {}
+ # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner'
+ self._conflicts = {}
+ self._other_generations = {}
+ self._indexes = {}
+ self._replica_uid = replica_uid
+ self._factory = document_factory or Document
+
+ def _set_replica_uid(self, replica_uid):
+ """Force the replica_uid to be set."""
+ self._replica_uid = replica_uid
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def close(self):
+ # This is a no-op, We don't want to free the data because one client
+ # may be closing it, while another wants to inspect the results.
+ pass
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ return self._other_generations.get(other_replica_uid, (0, ''))
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ self._do_set_replica_gen_and_trans_id(
+ other_replica_uid, other_generation, other_transaction_id)
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ # TODO: to handle race conditions, we may want to check if the current
+ # value is greater than this new value.
+ self._other_generations[other_replica_uid] = (other_generation,
+ other_transaction_id)
+
+ def get_sync_target(self):
+ return InMemorySyncTarget(self)
+
+ def _get_transaction_log(self):
+ # snapshot!
+ return self._transaction_log[:]
+
+ def _get_generation(self):
+ return len(self._transaction_log)
+
+ def _get_generation_info(self):
+ if not self._transaction_log:
+ return 0, ''
+ return len(self._transaction_log), self._transaction_log[-1][1]
+
+ def _get_trans_id_for_gen(self, generation):
+ if generation == 0:
+ return ''
+ if generation > len(self._transaction_log):
+ raise errors.InvalidGeneration
+ return self._transaction_log[generation - 1][1]
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ self._check_doc_id(doc.doc_id)
+ self._check_doc_size(doc)
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc and old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ if old_doc and doc.rev is None and old_doc.is_tombstone():
+ new_rev = self._allocate_doc_rev(old_doc.rev)
+ else:
+ if old_doc is not None:
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ else:
+ if doc.rev is not None:
+ raise errors.RevisionConflict()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ for index in self._indexes.itervalues():
+ if old_doc is not None and not old_doc.is_tombstone():
+ index.remove_json(old_doc.doc_id, old_doc.get_json())
+ if not doc.is_tombstone():
+ index.add_json(doc.doc_id, doc.get_json())
+ trans_id = self._allocate_transaction_id()
+ self._docs[doc.doc_id] = (doc.rev, doc.get_json())
+ self._transaction_log.append((doc.doc_id, trans_id))
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ try:
+ doc_rev, content = self._docs[doc_id]
+ except KeyError:
+ return None
+ doc = self._factory(doc_id, doc_rev, content)
+ if check_for_conflicts:
+ doc.has_conflicts = (doc.doc_id in self._conflicts)
+ return doc
+
+ def _has_conflicts(self, doc_id):
+ return doc_id in self._conflicts
+
+ def get_doc(self, doc_id, include_deleted=False):
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc is None:
+ return None
+ if doc.is_tombstone() and not include_deleted:
+ return None
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Return all documents in the database."""
+ generation = self._get_generation()
+ results = []
+ for doc_id, (doc_rev, content) in self._docs.items():
+ if content is None and not include_deleted:
+ continue
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = self._has_conflicts(doc_id)
+ results.append(doc)
+ return (generation, results)
+
+ def get_doc_conflicts(self, doc_id):
+ if doc_id not in self._conflicts:
+ return []
+ result = [self._get_doc(doc_id)]
+ result[0].has_conflicts = True
+ result.extend([self._factory(doc_id, rev, content)
+ for rev, content in self._conflicts[doc_id]])
+ return result
+
+ def _replace_conflicts(self, doc, conflicts):
+ if not conflicts:
+ del self._conflicts[doc.doc_id]
+ else:
+ self._conflicts[doc.doc_id] = conflicts
+ doc.has_conflicts = bool(conflicts)
+
+ def _prune_conflicts(self, doc, doc_vcr):
+ if self._has_conflicts(doc.doc_id):
+ autoresolved = False
+ remaining_conflicts = []
+ cur_conflicts = self._conflicts[doc.doc_id]
+ for c_rev, c_doc in cur_conflicts:
+ c_vcr = vectorclock.VectorClockRev(c_rev)
+ if doc_vcr.is_newer(c_vcr):
+ continue
+ if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)):
+ doc_vcr.maximize(c_vcr)
+ autoresolved = True
+ continue
+ remaining_conflicts.append((c_rev, c_doc))
+ if autoresolved:
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ self._replace_conflicts(doc, remaining_conflicts)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ cur_doc = self._get_doc(doc.doc_id)
+ if cur_doc is None:
+ cur_rev = None
+ else:
+ cur_rev = cur_doc.rev
+ new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs)
+ superseded_revs = set(conflicted_doc_revs)
+ remaining_conflicts = []
+ cur_conflicts = self._conflicts[doc.doc_id]
+ for c_rev, c_doc in cur_conflicts:
+ if c_rev in superseded_revs:
+ continue
+ remaining_conflicts.append((c_rev, c_doc))
+ doc.rev = new_rev
+ if cur_rev in superseded_revs:
+ self._put_and_update_indexes(cur_doc, doc)
+ else:
+ remaining_conflicts.append((new_rev, doc.get_json()))
+ self._replace_conflicts(doc, remaining_conflicts)
+
+ def delete_doc(self, doc):
+ if doc.doc_id not in self._docs:
+ raise errors.DocumentDoesNotExist
+ if self._docs[doc.doc_id][1] in ('null', None):
+ raise errors.DocumentAlreadyDeleted
+ doc.make_tombstone()
+ self.put_doc(doc)
+
+ def create_index(self, index_name, *index_expressions):
+ if index_name in self._indexes:
+ if self._indexes[index_name]._definition == list(
+ index_expressions):
+ return
+ raise errors.IndexNameTakenError
+ index = InMemoryIndex(index_name, list(index_expressions))
+ for doc_id, (doc_rev, doc) in self._docs.iteritems():
+ if doc is not None:
+ index.add_json(doc_id, doc)
+ self._indexes[index_name] = index
+
+ def delete_index(self, index_name):
+ del self._indexes[index_name]
+
+ def list_indexes(self):
+ definitions = []
+ for idx in self._indexes.itervalues():
+ definitions.append((idx._name, idx._definition))
+ return definitions
+
+ def get_from_index(self, index_name, *key_values):
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ doc_ids = index.lookup(key_values)
+ result = []
+ for doc_id in doc_ids:
+ result.append(self._get_doc(doc_id, check_for_conflicts=True))
+ return result
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ """Return all documents with key values in the specified range."""
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ doc_ids = index.lookup_range(start_value, end_value)
+ result = []
+ for doc_id in doc_ids:
+ result.append(self._get_doc(doc_id, check_for_conflicts=True))
+ return result
+
+ def get_index_keys(self, index_name):
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ keys = index.keys()
+ # XXX inefficiency warning
+ return list(set([tuple(key.split('\x01')) for key in keys]))
+
+ def whats_changed(self, old_generation=0):
+ changes = []
+ relevant_tail = self._transaction_log[old_generation:]
+ # We don't use len(self._transaction_log) because _transaction_log may
+ # get mutated by a concurrent operation.
+ cur_generation = old_generation + len(relevant_tail)
+ last_trans_id = ''
+ if relevant_tail:
+ last_trans_id = relevant_tail[-1][1]
+ elif self._transaction_log:
+ last_trans_id = self._transaction_log[-1][1]
+ seen = set()
+ generation = cur_generation
+ for doc_id, trans_id in reversed(relevant_tail):
+ if doc_id not in seen:
+ changes.append((doc_id, generation, trans_id))
+ seen.add(doc_id)
+ generation -= 1
+ changes.reverse()
+ return (cur_generation, last_trans_id, changes)
+
+ def _force_doc_sync_conflict(self, doc):
+ my_doc = self._get_doc(doc.doc_id)
+ self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
+ self._conflicts.setdefault(doc.doc_id, []).append(
+ (my_doc.rev, my_doc.get_json()))
+ doc.has_conflicts = True
+ self._put_and_update_indexes(my_doc, doc)
+
+
+class InMemoryIndex(object):
+ """Interface for managing an Index."""
+
+ def __init__(self, index_name, index_definition):
+ self._name = index_name
+ self._definition = index_definition
+ self._values = {}
+ parser = query_parser.Parser()
+ self._getters = parser.parse_all(self._definition)
+
+ def evaluate_json(self, doc):
+ """Determine the 'key' after applying this index to the doc."""
+ raw = json.loads(doc)
+ return self.evaluate(raw)
+
+ def evaluate(self, obj):
+ """Evaluate a dict object, applying this definition."""
+ all_rows = [[]]
+ for getter in self._getters:
+ new_rows = []
+ keys = getter.get(obj)
+ if not keys:
+ return []
+ for key in keys:
+ new_rows.extend([row + [key] for row in all_rows])
+ all_rows = new_rows
+ all_rows = ['\x01'.join(row) for row in all_rows]
+ return all_rows
+
+ def add_json(self, doc_id, doc):
+ """Add this json doc to the index."""
+ keys = self.evaluate_json(doc)
+ if not keys:
+ return
+ for key in keys:
+ self._values.setdefault(key, []).append(doc_id)
+
+ def remove_json(self, doc_id, doc):
+ """Remove this json doc from the index."""
+ keys = self.evaluate_json(doc)
+ if keys:
+ for key in keys:
+ doc_ids = self._values[key]
+ doc_ids.remove(doc_id)
+ if not doc_ids:
+ del self._values[key]
+
+ def _find_non_wildcards(self, values):
+ """Check if this should be a wildcard match.
+
+ Further, this will raise an exception if the syntax is improperly
+ defined.
+
+ :return: The offset of the last value we need to match against.
+ """
+ if len(values) != len(self._definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ last = 0
+ for idx, val in enumerate(values):
+ if val.endswith('*'):
+ if val != '*':
+ # We have an 'x*' style wildcard
+ if is_wildcard:
+ # We were already in wildcard mode, so this is invalid
+ raise errors.InvalidGlobbing
+ last = idx + 1
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ # We were in wildcard mode, we can't follow that with
+ # non-wildcard
+ raise errors.InvalidGlobbing
+ last = idx + 1
+ if not is_wildcard:
+ return -1
+ return last
+
+ def lookup(self, values):
+ """Find docs that match the values."""
+ last = self._find_non_wildcards(values)
+ if last == -1:
+ return self._lookup_exact(values)
+ else:
+ return self._lookup_prefix(values[:last])
+
+ def lookup_range(self, start_values, end_values):
+ """Find docs within the range."""
+ # TODO: Wildly inefficient, which is unlikely to be a problem for the
+ # inmemory implementation.
+ if start_values:
+ self._find_non_wildcards(start_values)
+ start_values = get_prefix(start_values)
+ if end_values:
+ if self._find_non_wildcards(end_values) == -1:
+ exact = True
+ else:
+ exact = False
+ end_values = get_prefix(end_values)
+ found = []
+ for key, doc_ids in sorted(self._values.iteritems()):
+ if start_values and start_values > key:
+ continue
+ if end_values and end_values < key:
+ if exact:
+ break
+ else:
+ if not key.startswith(end_values):
+ break
+ found.extend(doc_ids)
+ return found
+
+ def keys(self):
+ """Find the indexed keys."""
+ return self._values.keys()
+
+ def _lookup_prefix(self, value):
+ """Find docs that match the prefix string in values."""
+ # TODO: We need a different data structure to make prefix style fast,
+ # some sort of sorted list would work, but a plain dict doesn't.
+ key_prefix = get_prefix(value)
+ all_doc_ids = []
+ for key, doc_ids in sorted(self._values.iteritems()):
+ if key.startswith(key_prefix):
+ all_doc_ids.extend(doc_ids)
+ return all_doc_ids
+
+ def _lookup_exact(self, value):
+ """Find docs that match exactly."""
+ key = '\x01'.join(value)
+ if key in self._values:
+ return self._values[key]
+ return ()
+
+
+class InMemorySyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_transaction_id)