summaryrefslogtreecommitdiff
path: root/src/leap/soledad/common/l2db/backends/inmemory.py
diff options
context:
space:
mode:
authordrebs <drebs@leap.se>2017-06-18 11:18:10 -0300
committerKali Kaneko <kali@leap.se>2017-06-24 00:49:17 +0200
commit3e94cafa43d464d73815e21810b97a4faf54136d (patch)
tree04f5152e3dfc9f27b7dca2368c0eb8b3f094f5b5 /src/leap/soledad/common/l2db/backends/inmemory.py
parent7d8ee786b086e47264619df3efa73e74440fd068 (diff)
[pkg] unify client and server into a single python package
We have been discussing about this merge for a while. Its main goal is to simplify things: code navigation, but also packaging. The rationale is that the code is more cohesive in this way, and there's only one source package to install. Dependencies that are only for the server or the client will not be installed by default, and they are expected to be provided by the environment. There are setuptools extras defined for the client and the server. Debianization is still expected to split the single source package into 3 binaries. Another avantage is that the documentation can now install a single package with a single step, and therefore include the docstrings into the generated docs. - Resolves: #8896
Diffstat (limited to 'src/leap/soledad/common/l2db/backends/inmemory.py')
-rw-r--r--src/leap/soledad/common/l2db/backends/inmemory.py466
1 files changed, 466 insertions, 0 deletions
diff --git a/src/leap/soledad/common/l2db/backends/inmemory.py b/src/leap/soledad/common/l2db/backends/inmemory.py
new file mode 100644
index 00000000..6fd251af
--- /dev/null
+++ b/src/leap/soledad/common/l2db/backends/inmemory.py
@@ -0,0 +1,466 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The in-memory Database class for U1DB."""
+
+import json
+
+from leap.soledad.common.l2db import (
+ Document, errors,
+ query_parser, vectorclock)
+from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget
+
+
+def get_prefix(value):
+ key_prefix = '\x01'.join(value)
+ return key_prefix.rstrip('*')
+
+
+class InMemoryDatabase(CommonBackend):
+ """A database that only stores the data internally."""
+
+ def __init__(self, replica_uid, document_factory=None):
+ self._transaction_log = []
+ self._docs = {}
+ # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner'
+ self._conflicts = {}
+ self._other_generations = {}
+ self._indexes = {}
+ self._replica_uid = replica_uid
+ self._factory = document_factory or Document
+
+ def _set_replica_uid(self, replica_uid):
+ """Force the replica_uid to be set."""
+ self._replica_uid = replica_uid
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def close(self):
+ # This is a no-op, We don't want to free the data because one client
+ # may be closing it, while another wants to inspect the results.
+ pass
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ return self._other_generations.get(other_replica_uid, (0, ''))
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ self._do_set_replica_gen_and_trans_id(
+ other_replica_uid, other_generation, other_transaction_id)
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ # TODO: to handle race conditions, we may want to check if the current
+ # value is greater than this new value.
+ self._other_generations[other_replica_uid] = (other_generation,
+ other_transaction_id)
+
+ def get_sync_target(self):
+ return InMemorySyncTarget(self)
+
+ def _get_transaction_log(self):
+ # snapshot!
+ return self._transaction_log[:]
+
+ def _get_generation(self):
+ return len(self._transaction_log)
+
+ def _get_generation_info(self):
+ if not self._transaction_log:
+ return 0, ''
+ return len(self._transaction_log), self._transaction_log[-1][1]
+
+ def _get_trans_id_for_gen(self, generation):
+ if generation == 0:
+ return ''
+ if generation > len(self._transaction_log):
+ raise errors.InvalidGeneration
+ return self._transaction_log[generation - 1][1]
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ self._check_doc_id(doc.doc_id)
+ self._check_doc_size(doc)
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc and old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ if old_doc and doc.rev is None and old_doc.is_tombstone():
+ new_rev = self._allocate_doc_rev(old_doc.rev)
+ else:
+ if old_doc is not None:
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ else:
+ if doc.rev is not None:
+ raise errors.RevisionConflict()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ for index in self._indexes.itervalues():
+ if old_doc is not None and not old_doc.is_tombstone():
+ index.remove_json(old_doc.doc_id, old_doc.get_json())
+ if not doc.is_tombstone():
+ index.add_json(doc.doc_id, doc.get_json())
+ trans_id = self._allocate_transaction_id()
+ self._docs[doc.doc_id] = (doc.rev, doc.get_json())
+ self._transaction_log.append((doc.doc_id, trans_id))
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ try:
+ doc_rev, content = self._docs[doc_id]
+ except KeyError:
+ return None
+ doc = self._factory(doc_id, doc_rev, content)
+ if check_for_conflicts:
+ doc.has_conflicts = (doc.doc_id in self._conflicts)
+ return doc
+
+ def _has_conflicts(self, doc_id):
+ return doc_id in self._conflicts
+
+ def get_doc(self, doc_id, include_deleted=False):
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc is None:
+ return None
+ if doc.is_tombstone() and not include_deleted:
+ return None
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Return all documents in the database."""
+ generation = self._get_generation()
+ results = []
+ for doc_id, (doc_rev, content) in self._docs.items():
+ if content is None and not include_deleted:
+ continue
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = self._has_conflicts(doc_id)
+ results.append(doc)
+ return (generation, results)
+
+ def get_doc_conflicts(self, doc_id):
+ if doc_id not in self._conflicts:
+ return []
+ result = [self._get_doc(doc_id)]
+ result[0].has_conflicts = True
+ result.extend([self._factory(doc_id, rev, content)
+ for rev, content in self._conflicts[doc_id]])
+ return result
+
+ def _replace_conflicts(self, doc, conflicts):
+ if not conflicts:
+ del self._conflicts[doc.doc_id]
+ else:
+ self._conflicts[doc.doc_id] = conflicts
+ doc.has_conflicts = bool(conflicts)
+
+ def _prune_conflicts(self, doc, doc_vcr):
+ if self._has_conflicts(doc.doc_id):
+ autoresolved = False
+ remaining_conflicts = []
+ cur_conflicts = self._conflicts[doc.doc_id]
+ for c_rev, c_doc in cur_conflicts:
+ c_vcr = vectorclock.VectorClockRev(c_rev)
+ if doc_vcr.is_newer(c_vcr):
+ continue
+ if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)):
+ doc_vcr.maximize(c_vcr)
+ autoresolved = True
+ continue
+ remaining_conflicts.append((c_rev, c_doc))
+ if autoresolved:
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ self._replace_conflicts(doc, remaining_conflicts)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ cur_doc = self._get_doc(doc.doc_id)
+ if cur_doc is None:
+ cur_rev = None
+ else:
+ cur_rev = cur_doc.rev
+ new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs)
+ superseded_revs = set(conflicted_doc_revs)
+ remaining_conflicts = []
+ cur_conflicts = self._conflicts[doc.doc_id]
+ for c_rev, c_doc in cur_conflicts:
+ if c_rev in superseded_revs:
+ continue
+ remaining_conflicts.append((c_rev, c_doc))
+ doc.rev = new_rev
+ if cur_rev in superseded_revs:
+ self._put_and_update_indexes(cur_doc, doc)
+ else:
+ remaining_conflicts.append((new_rev, doc.get_json()))
+ self._replace_conflicts(doc, remaining_conflicts)
+
+ def delete_doc(self, doc):
+ if doc.doc_id not in self._docs:
+ raise errors.DocumentDoesNotExist
+ if self._docs[doc.doc_id][1] in ('null', None):
+ raise errors.DocumentAlreadyDeleted
+ doc.make_tombstone()
+ self.put_doc(doc)
+
+ def create_index(self, index_name, *index_expressions):
+ if index_name in self._indexes:
+ if self._indexes[index_name]._definition == list(
+ index_expressions):
+ return
+ raise errors.IndexNameTakenError
+ index = InMemoryIndex(index_name, list(index_expressions))
+ for doc_id, (doc_rev, doc) in self._docs.iteritems():
+ if doc is not None:
+ index.add_json(doc_id, doc)
+ self._indexes[index_name] = index
+
+ def delete_index(self, index_name):
+ try:
+ del self._indexes[index_name]
+ except KeyError:
+ pass
+
+ def list_indexes(self):
+ definitions = []
+ for idx in self._indexes.itervalues():
+ definitions.append((idx._name, idx._definition))
+ return definitions
+
+ def get_from_index(self, index_name, *key_values):
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ doc_ids = index.lookup(key_values)
+ result = []
+ for doc_id in doc_ids:
+ result.append(self._get_doc(doc_id, check_for_conflicts=True))
+ return result
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ """Return all documents with key values in the specified range."""
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ doc_ids = index.lookup_range(start_value, end_value)
+ result = []
+ for doc_id in doc_ids:
+ result.append(self._get_doc(doc_id, check_for_conflicts=True))
+ return result
+
+ def get_index_keys(self, index_name):
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ keys = index.keys()
+ # XXX inefficiency warning
+ return list(set([tuple(key.split('\x01')) for key in keys]))
+
+ def whats_changed(self, old_generation=0):
+ changes = []
+ relevant_tail = self._transaction_log[old_generation:]
+ # We don't use len(self._transaction_log) because _transaction_log may
+ # get mutated by a concurrent operation.
+ cur_generation = old_generation + len(relevant_tail)
+ last_trans_id = ''
+ if relevant_tail:
+ last_trans_id = relevant_tail[-1][1]
+ elif self._transaction_log:
+ last_trans_id = self._transaction_log[-1][1]
+ seen = set()
+ generation = cur_generation
+ for doc_id, trans_id in reversed(relevant_tail):
+ if doc_id not in seen:
+ changes.append((doc_id, generation, trans_id))
+ seen.add(doc_id)
+ generation -= 1
+ changes.reverse()
+ return (cur_generation, last_trans_id, changes)
+
+ def _force_doc_sync_conflict(self, doc):
+ my_doc = self._get_doc(doc.doc_id)
+ self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
+ self._conflicts.setdefault(doc.doc_id, []).append(
+ (my_doc.rev, my_doc.get_json()))
+ doc.has_conflicts = True
+ self._put_and_update_indexes(my_doc, doc)
+
+
+class InMemoryIndex(object):
+ """Interface for managing an Index."""
+
+ def __init__(self, index_name, index_definition):
+ self._name = index_name
+ self._definition = index_definition
+ self._values = {}
+ parser = query_parser.Parser()
+ self._getters = parser.parse_all(self._definition)
+
+ def evaluate_json(self, doc):
+ """Determine the 'key' after applying this index to the doc."""
+ raw = json.loads(doc)
+ return self.evaluate(raw)
+
+ def evaluate(self, obj):
+ """Evaluate a dict object, applying this definition."""
+ all_rows = [[]]
+ for getter in self._getters:
+ new_rows = []
+ keys = getter.get(obj)
+ if not keys:
+ return []
+ for key in keys:
+ new_rows.extend([row + [key] for row in all_rows])
+ all_rows = new_rows
+ all_rows = ['\x01'.join(row) for row in all_rows]
+ return all_rows
+
+ def add_json(self, doc_id, doc):
+ """Add this json doc to the index."""
+ keys = self.evaluate_json(doc)
+ if not keys:
+ return
+ for key in keys:
+ self._values.setdefault(key, []).append(doc_id)
+
+ def remove_json(self, doc_id, doc):
+ """Remove this json doc from the index."""
+ keys = self.evaluate_json(doc)
+ if keys:
+ for key in keys:
+ doc_ids = self._values[key]
+ doc_ids.remove(doc_id)
+ if not doc_ids:
+ del self._values[key]
+
+ def _find_non_wildcards(self, values):
+ """Check if this should be a wildcard match.
+
+ Further, this will raise an exception if the syntax is improperly
+ defined.
+
+ :return: The offset of the last value we need to match against.
+ """
+ if len(values) != len(self._definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ last = 0
+ for idx, val in enumerate(values):
+ if val.endswith('*'):
+ if val != '*':
+ # We have an 'x*' style wildcard
+ if is_wildcard:
+ # We were already in wildcard mode, so this is invalid
+ raise errors.InvalidGlobbing
+ last = idx + 1
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ # We were in wildcard mode, we can't follow that with
+ # non-wildcard
+ raise errors.InvalidGlobbing
+ last = idx + 1
+ if not is_wildcard:
+ return -1
+ return last
+
+ def lookup(self, values):
+ """Find docs that match the values."""
+ last = self._find_non_wildcards(values)
+ if last == -1:
+ return self._lookup_exact(values)
+ else:
+ return self._lookup_prefix(values[:last])
+
+ def lookup_range(self, start_values, end_values):
+ """Find docs within the range."""
+ # TODO: Wildly inefficient, which is unlikely to be a problem for the
+ # inmemory implementation.
+ if start_values:
+ self._find_non_wildcards(start_values)
+ start_values = get_prefix(start_values)
+ if end_values:
+ if self._find_non_wildcards(end_values) == -1:
+ exact = True
+ else:
+ exact = False
+ end_values = get_prefix(end_values)
+ found = []
+ for key, doc_ids in sorted(self._values.iteritems()):
+ if start_values and start_values > key:
+ continue
+ if end_values and end_values < key:
+ if exact:
+ break
+ else:
+ if not key.startswith(end_values):
+ break
+ found.extend(doc_ids)
+ return found
+
+ def keys(self):
+ """Find the indexed keys."""
+ return self._values.keys()
+
+ def _lookup_prefix(self, value):
+ """Find docs that match the prefix string in values."""
+ # TODO: We need a different data structure to make prefix style fast,
+ # some sort of sorted list would work, but a plain dict doesn't.
+ key_prefix = get_prefix(value)
+ all_doc_ids = []
+ for key, doc_ids in sorted(self._values.iteritems()):
+ if key.startswith(key_prefix):
+ all_doc_ids.extend(doc_ids)
+ return all_doc_ids
+
+ def _lookup_exact(self, value):
+ """Find docs that match exactly."""
+ key = '\x01'.join(value)
+ if key in self._values:
+ return self._values[key]
+ return ()
+
+
+class InMemorySyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_transaction_id)