path: root/common/src/leap/soledad/common/l2db/backends/inmemory.py
Diffstat (limited to 'common/src/leap/soledad/common/l2db/backends/inmemory.py')
-rw-r--r--  common/src/leap/soledad/common/l2db/backends/inmemory.py  466
1 file changed, 0 insertions(+), 466 deletions(-)
diff --git a/common/src/leap/soledad/common/l2db/backends/inmemory.py b/common/src/leap/soledad/common/l2db/backends/inmemory.py
deleted file mode 100644
index 6fd251af..00000000
--- a/common/src/leap/soledad/common/l2db/backends/inmemory.py
+++ /dev/null
@@ -1,466 +0,0 @@
-# Copyright 2011 Canonical Ltd.
-#
-# This file is part of u1db.
-#
-# u1db is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License version 3
-# as published by the Free Software Foundation.
-#
-# u1db is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Lesser General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with u1db. If not, see <http://www.gnu.org/licenses/>.
-
-"""The in-memory Database class for U1DB."""
-
-import json
-
-from leap.soledad.common.l2db import (
-    Document, errors,
-    query_parser, vectorclock)
-from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget
-
-
-def get_prefix(value):
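-    # Index keys are column values joined with the '\x01' separator; a
-    # trailing '*' wildcard is stripped so the result can be used as a
-    # startswith() prefix in the lookups below.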
-    key_prefix = '\x01'.join(value)
-    return key_prefix.rstrip('*')
-
-
-class InMemoryDatabase(CommonBackend):
- """A database that only stores the data internally."""
-
- def __init__(self, replica_uid, document_factory=None):
- self._transaction_log = []
- self._docs = {}
- # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner'
- self._conflicts = {}
- self._other_generations = {}
- self._indexes = {}
- self._replica_uid = replica_uid
- self._factory = document_factory or Document
-
- def _set_replica_uid(self, replica_uid):
- """Force the replica_uid to be set."""
- self._replica_uid = replica_uid
-
- def set_document_factory(self, factory):
- self._factory = factory
-
- def close(self):
-        # This is a no-op. We don't want to free the data, because one
-        # client may be closing it while another still wants to inspect
-        # the results.
-        pass
-
-    def _get_replica_gen_and_trans_id(self, other_replica_uid):
-        return self._other_generations.get(other_replica_uid, (0, ''))
-
-    def _set_replica_gen_and_trans_id(self, other_replica_uid,
-                                      other_generation, other_transaction_id):
-        self._do_set_replica_gen_and_trans_id(
-            other_replica_uid, other_generation, other_transaction_id)
-
-    def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
-                                         other_generation,
-                                         other_transaction_id):
-        # TODO: to handle race conditions, we may want to check if the current
-        # value is greater than this new value.
-        self._other_generations[other_replica_uid] = (other_generation,
-                                                      other_transaction_id)
-
-    def get_sync_target(self):
-        return InMemorySyncTarget(self)
-
-    def _get_transaction_log(self):
-        # snapshot!
-        return self._transaction_log[:]
-
-    def _get_generation(self):
-        return len(self._transaction_log)
-
-    def _get_generation_info(self):
-        if not self._transaction_log:
-            return 0, ''
-        return len(self._transaction_log), self._transaction_log[-1][1]
-
-    def _get_trans_id_for_gen(self, generation):
-        if generation == 0:
-            return ''
-        if generation > len(self._transaction_log):
-            raise errors.InvalidGeneration
-        return self._transaction_log[generation - 1][1]
-
-    def put_doc(self, doc):
-        if doc.doc_id is None:
-            raise errors.InvalidDocId()
-        self._check_doc_id(doc.doc_id)
-        self._check_doc_size(doc)
-        old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
-        if old_doc and old_doc.has_conflicts:
-            raise errors.ConflictedDoc()
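-        # A rev-less put on top of a tombstone resurrects the document,
-        # continuing the revision chain of the tombstone.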
-        if old_doc and doc.rev is None and old_doc.is_tombstone():
-            new_rev = self._allocate_doc_rev(old_doc.rev)
-        else:
-            if old_doc is not None:
-                if old_doc.rev != doc.rev:
-                    raise errors.RevisionConflict()
-            else:
-                if doc.rev is not None:
-                    raise errors.RevisionConflict()
-            new_rev = self._allocate_doc_rev(doc.rev)
-        doc.rev = new_rev
-        self._put_and_update_indexes(old_doc, doc)
-        return new_rev
-
-    def _put_and_update_indexes(self, old_doc, doc):
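-        # Update every index to reflect the change, then store the new
-        # content and record the change in the transaction log.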
-        for index in self._indexes.itervalues():
-            if old_doc is not None and not old_doc.is_tombstone():
-                index.remove_json(old_doc.doc_id, old_doc.get_json())
-            if not doc.is_tombstone():
-                index.add_json(doc.doc_id, doc.get_json())
-        trans_id = self._allocate_transaction_id()
-        self._docs[doc.doc_id] = (doc.rev, doc.get_json())
-        self._transaction_log.append((doc.doc_id, trans_id))
-
-    def _get_doc(self, doc_id, check_for_conflicts=False):
-        try:
-            doc_rev, content = self._docs[doc_id]
-        except KeyError:
-            return None
-        doc = self._factory(doc_id, doc_rev, content)
-        if check_for_conflicts:
-            doc.has_conflicts = (doc.doc_id in self._conflicts)
-        return doc
-
-    def _has_conflicts(self, doc_id):
-        return doc_id in self._conflicts
-
-    def get_doc(self, doc_id, include_deleted=False):
-        doc = self._get_doc(doc_id, check_for_conflicts=True)
-        if doc is None:
-            return None
-        if doc.is_tombstone() and not include_deleted:
-            return None
-        return doc
-
-    def get_all_docs(self, include_deleted=False):
-        """Return all documents in the database."""
-        generation = self._get_generation()
-        results = []
-        for doc_id, (doc_rev, content) in self._docs.items():
-            if content is None and not include_deleted:
-                continue
-            doc = self._factory(doc_id, doc_rev, content)
-            doc.has_conflicts = self._has_conflicts(doc_id)
-            results.append(doc)
-        return (generation, results)
-
-    def get_doc_conflicts(self, doc_id):
-        if doc_id not in self._conflicts:
-            return []
-        result = [self._get_doc(doc_id)]
-        result[0].has_conflicts = True
-        result.extend([self._factory(doc_id, rev, content)
-                       for rev, content in self._conflicts[doc_id]])
-        return result
-
-    def _replace_conflicts(self, doc, conflicts):
-        if not conflicts:
-            del self._conflicts[doc.doc_id]
-        else:
-            self._conflicts[doc.doc_id] = conflicts
-        doc.has_conflicts = bool(conflicts)
-
-    def _prune_conflicts(self, doc, doc_vcr):
-        if self._has_conflicts(doc.doc_id):
-            autoresolved = False
-            remaining_conflicts = []
-            cur_conflicts = self._conflicts[doc.doc_id]
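-            # Keep only conflicts the incoming doc does not supersede; a
-            # conflict with identical content is folded into the incoming
-            # rev's vector clock instead (autoresolution).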
-            for c_rev, c_doc in cur_conflicts:
-                c_vcr = vectorclock.VectorClockRev(c_rev)
-                if doc_vcr.is_newer(c_vcr):
-                    continue
-                if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)):
-                    doc_vcr.maximize(c_vcr)
-                    autoresolved = True
-                    continue
-                remaining_conflicts.append((c_rev, c_doc))
-            if autoresolved:
-                doc_vcr.increment(self._replica_uid)
-                doc.rev = doc_vcr.as_str()
-            self._replace_conflicts(doc, remaining_conflicts)
-
-    def resolve_doc(self, doc, conflicted_doc_revs):
-        cur_doc = self._get_doc(doc.doc_id)
-        if cur_doc is None:
-            cur_rev = None
-        else:
-            cur_rev = cur_doc.rev
-        new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs)
-        superseded_revs = set(conflicted_doc_revs)
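-        # Conflict revs named by the caller are superseded by the
-        # resolution; all others are kept. If the current winner was not
-        # superseded, the resolved doc is stored as a conflict instead.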
-        remaining_conflicts = []
-        cur_conflicts = self._conflicts[doc.doc_id]
-        for c_rev, c_doc in cur_conflicts:
-            if c_rev in superseded_revs:
-                continue
-            remaining_conflicts.append((c_rev, c_doc))
-        doc.rev = new_rev
-        if cur_rev in superseded_revs:
-            self._put_and_update_indexes(cur_doc, doc)
-        else:
-            remaining_conflicts.append((new_rev, doc.get_json()))
-        self._replace_conflicts(doc, remaining_conflicts)
-
-    def delete_doc(self, doc):
-        if doc.doc_id not in self._docs:
-            raise errors.DocumentDoesNotExist
-        if self._docs[doc.doc_id][1] in ('null', None):
-            raise errors.DocumentAlreadyDeleted
-        doc.make_tombstone()
-        self.put_doc(doc)
-
-    def create_index(self, index_name, *index_expressions):
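-        # Re-creating an index with the same definition is a no-op;
-        # reusing the name with a different definition is an error.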
-        if index_name in self._indexes:
-            if self._indexes[index_name]._definition == list(
-                    index_expressions):
-                return
-            raise errors.IndexNameTakenError
-        index = InMemoryIndex(index_name, list(index_expressions))
-        for doc_id, (doc_rev, doc) in self._docs.iteritems():
-            if doc is not None:
-                index.add_json(doc_id, doc)
-        self._indexes[index_name] = index
-
-    def delete_index(self, index_name):
-        try:
-            del self._indexes[index_name]
-        except KeyError:
-            pass
-
-    def list_indexes(self):
-        definitions = []
-        for idx in self._indexes.itervalues():
-            definitions.append((idx._name, idx._definition))
-        return definitions
-
-    def get_from_index(self, index_name, *key_values):
-        try:
-            index = self._indexes[index_name]
-        except KeyError:
-            raise errors.IndexDoesNotExist
-        doc_ids = index.lookup(key_values)
-        result = []
-        for doc_id in doc_ids:
-            result.append(self._get_doc(doc_id, check_for_conflicts=True))
-        return result
-
-    def get_range_from_index(self, index_name, start_value=None,
-                             end_value=None):
-        """Return all documents with key values in the specified range."""
-        try:
-            index = self._indexes[index_name]
-        except KeyError:
-            raise errors.IndexDoesNotExist
-        if isinstance(start_value, basestring):
-            start_value = (start_value,)
-        if isinstance(end_value, basestring):
-            end_value = (end_value,)
-        doc_ids = index.lookup_range(start_value, end_value)
-        result = []
-        for doc_id in doc_ids:
-            result.append(self._get_doc(doc_id, check_for_conflicts=True))
-        return result
-
-    def get_index_keys(self, index_name):
-        try:
-            index = self._indexes[index_name]
-        except KeyError:
-            raise errors.IndexDoesNotExist
-        keys = index.keys()
-        # XXX inefficiency warning
-        return list(set([tuple(key.split('\x01')) for key in keys]))
-
-    def whats_changed(self, old_generation=0):
-        changes = []
-        relevant_tail = self._transaction_log[old_generation:]
-        # We don't use len(self._transaction_log) because _transaction_log may
-        # get mutated by a concurrent operation.
-        cur_generation = old_generation + len(relevant_tail)
-        last_trans_id = ''
-        if relevant_tail:
-            last_trans_id = relevant_tail[-1][1]
-        elif self._transaction_log:
-            last_trans_id = self._transaction_log[-1][1]
-        seen = set()
-        generation = cur_generation
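-        # Walk the tail newest-first so each doc_id is reported only at
-        # its most recent generation, then restore chronological order.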
-        for doc_id, trans_id in reversed(relevant_tail):
-            if doc_id not in seen:
-                changes.append((doc_id, generation, trans_id))
-                seen.add(doc_id)
-            generation -= 1
-        changes.reverse()
-        return (cur_generation, last_trans_id, changes)
-
-    def _force_doc_sync_conflict(self, doc):
-        my_doc = self._get_doc(doc.doc_id)
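-        # The incoming doc becomes the winner; the current local version
-        # is preserved as a conflict after pruning anything superseded.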
-        self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
-        self._conflicts.setdefault(doc.doc_id, []).append(
-            (my_doc.rev, my_doc.get_json()))
-        doc.has_conflicts = True
-        self._put_and_update_indexes(my_doc, doc)
-
-
-class InMemoryIndex(object):
-    """Interface for managing an Index."""
-
-    def __init__(self, index_name, index_definition):
-        self._name = index_name
-        self._definition = index_definition
-        self._values = {}
-        parser = query_parser.Parser()
-        self._getters = parser.parse_all(self._definition)
-
-    def evaluate_json(self, doc):
-        """Determine the 'key' after applying this index to the doc."""
-        raw = json.loads(doc)
-        return self.evaluate(raw)
-
-    def evaluate(self, obj):
-        """Evaluate a dict object, applying this definition."""
-        all_rows = [[]]
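-        # Build the cartesian product of the values produced by each
-        # getter, one column per index expression.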
-        for getter in self._getters:
-            new_rows = []
-            keys = getter.get(obj)
-            if not keys:
-                return []
-            for key in keys:
-                new_rows.extend([row + [key] for row in all_rows])
-            all_rows = new_rows
-        all_rows = ['\x01'.join(row) for row in all_rows]
-        return all_rows
-
-    def add_json(self, doc_id, doc):
-        """Add this json doc to the index."""
-        keys = self.evaluate_json(doc)
-        if not keys:
-            return
-        for key in keys:
-            self._values.setdefault(key, []).append(doc_id)
-
-    def remove_json(self, doc_id, doc):
-        """Remove this json doc from the index."""
-        keys = self.evaluate_json(doc)
-        if keys:
-            for key in keys:
-                doc_ids = self._values[key]
-                doc_ids.remove(doc_id)
-                if not doc_ids:
-                    del self._values[key]
-
-    def _find_non_wildcards(self, values):
-        """Check if this should be a wildcard match.
-
-        Further, this will raise an exception if the syntax is improperly
-        defined.
-
-        :return: The offset of the last value we need to match against.
-        """
-        if len(values) != len(self._definition):
-            raise errors.InvalidValueForIndex()
-        is_wildcard = False
-        last = 0
-        for idx, val in enumerate(values):
-            if val.endswith('*'):
-                if val != '*':
-                    # We have an 'x*' style wildcard
-                    if is_wildcard:
-                        # We were already in wildcard mode, so this is invalid
-                        raise errors.InvalidGlobbing
-                    last = idx + 1
-                is_wildcard = True
-            else:
-                if is_wildcard:
-                    # We were in wildcard mode, we can't follow that with
-                    # non-wildcard
-                    raise errors.InvalidGlobbing
-                last = idx + 1
-        if not is_wildcard:
-            return -1
-        return last
-
-    def lookup(self, values):
-        """Find docs that match the values."""
-        last = self._find_non_wildcards(values)
-        if last == -1:
-            return self._lookup_exact(values)
-        else:
-            return self._lookup_prefix(values[:last])
-
-    def lookup_range(self, start_values, end_values):
-        """Find docs within the range."""
-        # TODO: Wildly inefficient, which is unlikely to be a problem for the
-        # inmemory implementation.
-        if start_values:
-            self._find_non_wildcards(start_values)
-            start_values = get_prefix(start_values)
-        if end_values:
-            if self._find_non_wildcards(end_values) == -1:
-                exact = True
-            else:
-                exact = False
-            end_values = get_prefix(end_values)
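-        # Linear scan of all keys in sorted order. With an exact end
-        # bound any larger key stops the scan; with a wildcard end bound
-        # the scan continues while keys still match the end prefix.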
-        found = []
-        for key, doc_ids in sorted(self._values.iteritems()):
-            if start_values and start_values > key:
-                continue
-            if end_values and end_values < key:
-                if exact:
-                    break
-                else:
-                    if not key.startswith(end_values):
-                        break
-            found.extend(doc_ids)
-        return found
-
-    def keys(self):
-        """Find the indexed keys."""
-        return self._values.keys()
-
-    def _lookup_prefix(self, value):
-        """Find docs that match the prefix string in values."""
-        # TODO: We need a different data structure to make prefix style fast,
-        # some sort of sorted list would work, but a plain dict doesn't.
-        key_prefix = get_prefix(value)
-        all_doc_ids = []
-        for key, doc_ids in sorted(self._values.iteritems()):
-            if key.startswith(key_prefix):
-                all_doc_ids.extend(doc_ids)
-        return all_doc_ids
-
-    def _lookup_exact(self, value):
-        """Find docs that match exactly."""
-        key = '\x01'.join(value)
-        if key in self._values:
-            return self._values[key]
-        return ()
-
-
-class InMemorySyncTarget(CommonSyncTarget):
-
-    def get_sync_info(self, source_replica_uid):
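-        # Report this replica's generation info together with what was
-        # last recorded from the source replica.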
-        source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
-            source_replica_uid)
-        my_gen, my_trans_id = self._db._get_generation_info()
-        return (
-            self._db._replica_uid, my_gen, my_trans_id, source_gen,
-            source_trans_id)
-
-    def record_sync_info(self, source_replica_uid, source_replica_generation,
-                         source_transaction_id):
-        if self._trace_hook:
-            self._trace_hook('record_sync_info')
-        self._db._set_replica_gen_and_trans_id(
-            source_replica_uid, source_replica_generation,
-            source_transaction_id)