summaryrefslogtreecommitdiff
path: root/src/leap/soledad/u1db/backends
diff options
context:
space:
mode:
authordrebs <drebs@leap.se>2012-12-06 11:07:53 -0200
committerdrebs <drebs@leap.se>2012-12-06 11:07:53 -0200
commit584696e4dbfc13b793208dc4c5c6cdc224db5a12 (patch)
treea52830511908c8a135acc16095d61acba177873b /src/leap/soledad/u1db/backends
parent2cf00360bce0193d8fa73194a148c28426172043 (diff)
Remove u1db and swiftclient dirs and refactor.
Diffstat (limited to 'src/leap/soledad/u1db/backends')
-rw-r--r--src/leap/soledad/u1db/backends/__init__.py211
-rw-r--r--src/leap/soledad/u1db/backends/dbschema.sql42
-rw-r--r--src/leap/soledad/u1db/backends/inmemory.py469
-rw-r--r--src/leap/soledad/u1db/backends/sqlite_backend.py926
4 files changed, 0 insertions, 1648 deletions
diff --git a/src/leap/soledad/u1db/backends/__init__.py b/src/leap/soledad/u1db/backends/__init__.py
deleted file mode 100644
index c8e5adc6..00000000
--- a/src/leap/soledad/u1db/backends/__init__.py
+++ /dev/null
@@ -1,211 +0,0 @@
-# Copyright 2011 Canonical Ltd.
-#
-# This file is part of u1db.
-#
-# u1db is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License version 3
-# as published by the Free Software Foundation.
-#
-# u1db is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Lesser General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with u1db. If not, see <http://www.gnu.org/licenses/>.
-
-"""Abstract classes and common implementations for the backends."""
-
-import re
-try:
- import simplejson as json
-except ImportError:
- import json # noqa
-import uuid
-
-import u1db
-from u1db import (
- errors,
-)
-import u1db.sync
-from u1db.vectorclock import VectorClockRev
-
-
-check_doc_id_re = re.compile("^" + u1db.DOC_ID_CONSTRAINTS + "$", re.UNICODE)
-
-
-class CommonSyncTarget(u1db.sync.LocalSyncTarget):
- pass
-
-
-class CommonBackend(u1db.Database):
-
- document_size_limit = 0
-
- def _allocate_doc_id(self):
- """Generate a unique identifier for this document."""
- return 'D-' + uuid.uuid4().hex # 'D-' stands for document
-
- def _allocate_transaction_id(self):
- return 'T-' + uuid.uuid4().hex # 'T-' stands for transaction
-
- def _allocate_doc_rev(self, old_doc_rev):
- vcr = VectorClockRev(old_doc_rev)
- vcr.increment(self._replica_uid)
- return vcr.as_str()
-
- def _check_doc_id(self, doc_id):
- if not check_doc_id_re.match(doc_id):
- raise errors.InvalidDocId()
-
- def _check_doc_size(self, doc):
- if not self.document_size_limit:
- return
- if doc.get_size() > self.document_size_limit:
- raise errors.DocumentTooBig
-
- def _get_generation(self):
- """Return the current generation.
-
- """
- raise NotImplementedError(self._get_generation)
-
- def _get_generation_info(self):
- """Return the current generation and transaction id.
-
- """
- raise NotImplementedError(self._get_generation_info)
-
- def _get_doc(self, doc_id, check_for_conflicts=False):
- """Extract the document from storage.
-
- This can return None if the document doesn't exist.
- """
- raise NotImplementedError(self._get_doc)
-
- def _has_conflicts(self, doc_id):
- """Return True if the doc has conflicts, False otherwise."""
- raise NotImplementedError(self._has_conflicts)
-
- def create_doc(self, content, doc_id=None):
- json_string = json.dumps(content)
- if doc_id is None:
- doc_id = self._allocate_doc_id()
- doc = self._factory(doc_id, None, json_string)
- self.put_doc(doc)
- return doc
-
- def create_doc_from_json(self, json, doc_id=None):
- if doc_id is None:
- doc_id = self._allocate_doc_id()
- doc = self._factory(doc_id, None, json)
- self.put_doc(doc)
- return doc
-
- def _get_transaction_log(self):
- """This is only for the test suite, it is not part of the api."""
- raise NotImplementedError(self._get_transaction_log)
-
- def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content):
- raise NotImplementedError(self._put_and_update_indexes)
-
- def get_docs(self, doc_ids, check_for_conflicts=True,
- include_deleted=False):
- for doc_id in doc_ids:
- doc = self._get_doc(
- doc_id, check_for_conflicts=check_for_conflicts)
- if doc.is_tombstone() and not include_deleted:
- continue
- yield doc
-
- def _get_trans_id_for_gen(self, generation):
- """Get the transaction id corresponding to a particular generation.
-
- Raises an InvalidGeneration when the generation does not exist.
-
- """
- raise NotImplementedError(self._get_trans_id_for_gen)
-
- def validate_gen_and_trans_id(self, generation, trans_id):
- """Validate the generation and transaction id.
-
- Raises an InvalidGeneration when the generation does not exist, and an
- InvalidTransactionId when it does but with a different transaction id.
-
- """
- if generation == 0:
- return
- known_trans_id = self._get_trans_id_for_gen(generation)
- if known_trans_id != trans_id:
- raise errors.InvalidTransactionId
-
- def _validate_source(self, other_replica_uid, other_generation,
- other_transaction_id):
- """Validate the new generation and transaction id.
-
- other_generation must be greater than what we have stored for this
- replica, *or* it must be the same and the transaction_id must be the
- same as well.
- """
- (old_generation,
- old_transaction_id) = self._get_replica_gen_and_trans_id(
- other_replica_uid)
- if other_generation < old_generation:
- raise errors.InvalidGeneration
- if other_generation > old_generation:
- return
- if other_transaction_id == old_transaction_id:
- return
- raise errors.InvalidTransactionId
-
- def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen,
- replica_trans_id=''):
- cur_doc = self._get_doc(doc.doc_id)
- doc_vcr = VectorClockRev(doc.rev)
- if cur_doc is None:
- cur_vcr = VectorClockRev(None)
- else:
- cur_vcr = VectorClockRev(cur_doc.rev)
- self._validate_source(replica_uid, replica_gen, replica_trans_id)
- if doc_vcr.is_newer(cur_vcr):
- rev = doc.rev
- self._prune_conflicts(doc, doc_vcr)
- if doc.rev != rev:
- # conflicts have been autoresolved
- state = 'superseded'
- else:
- state = 'inserted'
- self._put_and_update_indexes(cur_doc, doc)
- elif doc.rev == cur_doc.rev:
- # magical convergence
- state = 'converged'
- elif cur_vcr.is_newer(doc_vcr):
- # Don't add this to seen_ids, because we have something newer,
- # so we should send it back, and we should not generate a
- # conflict
- state = 'superseded'
- elif cur_doc.same_content_as(doc):
- # the documents have been edited to the same thing at both ends
- doc_vcr.maximize(cur_vcr)
- doc_vcr.increment(self._replica_uid)
- doc.rev = doc_vcr.as_str()
- self._put_and_update_indexes(cur_doc, doc)
- state = 'superseded'
- else:
- state = 'conflicted'
- if save_conflict:
- self._force_doc_sync_conflict(doc)
- if replica_uid is not None and replica_gen is not None:
- self._do_set_replica_gen_and_trans_id(
- replica_uid, replica_gen, replica_trans_id)
- return state, self._get_generation()
-
- def _ensure_maximal_rev(self, cur_rev, extra_revs):
- vcr = VectorClockRev(cur_rev)
- for rev in extra_revs:
- vcr.maximize(VectorClockRev(rev))
- vcr.increment(self._replica_uid)
- return vcr.as_str()
-
- def set_document_size_limit(self, limit):
- self.document_size_limit = limit
diff --git a/src/leap/soledad/u1db/backends/dbschema.sql b/src/leap/soledad/u1db/backends/dbschema.sql
deleted file mode 100644
index ae027fc5..00000000
--- a/src/leap/soledad/u1db/backends/dbschema.sql
+++ /dev/null
@@ -1,42 +0,0 @@
--- Database schema
-CREATE TABLE transaction_log (
- generation INTEGER PRIMARY KEY AUTOINCREMENT,
- doc_id TEXT NOT NULL,
- transaction_id TEXT NOT NULL
-);
-CREATE TABLE document (
- doc_id TEXT PRIMARY KEY,
- doc_rev TEXT NOT NULL,
- content TEXT
-);
-CREATE TABLE document_fields (
- doc_id TEXT NOT NULL,
- field_name TEXT NOT NULL,
- value TEXT
-);
-CREATE INDEX document_fields_field_value_doc_idx
- ON document_fields(field_name, value, doc_id);
-
-CREATE TABLE sync_log (
- replica_uid TEXT PRIMARY KEY,
- known_generation INTEGER,
- known_transaction_id TEXT
-);
-CREATE TABLE conflicts (
- doc_id TEXT,
- doc_rev TEXT,
- content TEXT,
- CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev)
-);
-CREATE TABLE index_definitions (
- name TEXT,
- offset INT,
- field TEXT,
- CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset)
-);
-create index index_definitions_field on index_definitions(field);
-CREATE TABLE u1db_config (
- name TEXT PRIMARY KEY,
- value TEXT
-);
-INSERT INTO u1db_config VALUES ('sql_schema', '0');
diff --git a/src/leap/soledad/u1db/backends/inmemory.py b/src/leap/soledad/u1db/backends/inmemory.py
deleted file mode 100644
index a271bb37..00000000
--- a/src/leap/soledad/u1db/backends/inmemory.py
+++ /dev/null
@@ -1,469 +0,0 @@
-# Copyright 2011 Canonical Ltd.
-#
-# This file is part of u1db.
-#
-# u1db is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License version 3
-# as published by the Free Software Foundation.
-#
-# u1db is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Lesser General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with u1db. If not, see <http://www.gnu.org/licenses/>.
-
-"""The in-memory Database class for U1DB."""
-
-try:
- import simplejson as json
-except ImportError:
- import json # noqa
-
-from u1db import (
- Document,
- errors,
- query_parser,
- vectorclock,
- )
-from u1db.backends import CommonBackend, CommonSyncTarget
-
-
-def get_prefix(value):
- key_prefix = '\x01'.join(value)
- return key_prefix.rstrip('*')
-
-
-class InMemoryDatabase(CommonBackend):
- """A database that only stores the data internally."""
-
- def __init__(self, replica_uid, document_factory=None):
- self._transaction_log = []
- self._docs = {}
- # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner'
- self._conflicts = {}
- self._other_generations = {}
- self._indexes = {}
- self._replica_uid = replica_uid
- self._factory = document_factory or Document
-
- def _set_replica_uid(self, replica_uid):
- """Force the replica_uid to be set."""
- self._replica_uid = replica_uid
-
- def set_document_factory(self, factory):
- self._factory = factory
-
- def close(self):
- # This is a no-op, We don't want to free the data because one client
- # may be closing it, while another wants to inspect the results.
- pass
-
- def _get_replica_gen_and_trans_id(self, other_replica_uid):
- return self._other_generations.get(other_replica_uid, (0, ''))
-
- def _set_replica_gen_and_trans_id(self, other_replica_uid,
- other_generation, other_transaction_id):
- self._do_set_replica_gen_and_trans_id(
- other_replica_uid, other_generation, other_transaction_id)
-
- def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
- other_generation,
- other_transaction_id):
- # TODO: to handle race conditions, we may want to check if the current
- # value is greater than this new value.
- self._other_generations[other_replica_uid] = (other_generation,
- other_transaction_id)
-
- def get_sync_target(self):
- return InMemorySyncTarget(self)
-
- def _get_transaction_log(self):
- # snapshot!
- return self._transaction_log[:]
-
- def _get_generation(self):
- return len(self._transaction_log)
-
- def _get_generation_info(self):
- if not self._transaction_log:
- return 0, ''
- return len(self._transaction_log), self._transaction_log[-1][1]
-
- def _get_trans_id_for_gen(self, generation):
- if generation == 0:
- return ''
- if generation > len(self._transaction_log):
- raise errors.InvalidGeneration
- return self._transaction_log[generation - 1][1]
-
- def put_doc(self, doc):
- if doc.doc_id is None:
- raise errors.InvalidDocId()
- self._check_doc_id(doc.doc_id)
- self._check_doc_size(doc)
- old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
- if old_doc and old_doc.has_conflicts:
- raise errors.ConflictedDoc()
- if old_doc and doc.rev is None and old_doc.is_tombstone():
- new_rev = self._allocate_doc_rev(old_doc.rev)
- else:
- if old_doc is not None:
- if old_doc.rev != doc.rev:
- raise errors.RevisionConflict()
- else:
- if doc.rev is not None:
- raise errors.RevisionConflict()
- new_rev = self._allocate_doc_rev(doc.rev)
- doc.rev = new_rev
- self._put_and_update_indexes(old_doc, doc)
- return new_rev
-
- def _put_and_update_indexes(self, old_doc, doc):
- for index in self._indexes.itervalues():
- if old_doc is not None and not old_doc.is_tombstone():
- index.remove_json(old_doc.doc_id, old_doc.get_json())
- if not doc.is_tombstone():
- index.add_json(doc.doc_id, doc.get_json())
- trans_id = self._allocate_transaction_id()
- self._docs[doc.doc_id] = (doc.rev, doc.get_json())
- self._transaction_log.append((doc.doc_id, trans_id))
-
- def _get_doc(self, doc_id, check_for_conflicts=False):
- try:
- doc_rev, content = self._docs[doc_id]
- except KeyError:
- return None
- doc = self._factory(doc_id, doc_rev, content)
- if check_for_conflicts:
- doc.has_conflicts = (doc.doc_id in self._conflicts)
- return doc
-
- def _has_conflicts(self, doc_id):
- return doc_id in self._conflicts
-
- def get_doc(self, doc_id, include_deleted=False):
- doc = self._get_doc(doc_id, check_for_conflicts=True)
- if doc is None:
- return None
- if doc.is_tombstone() and not include_deleted:
- return None
- return doc
-
- def get_all_docs(self, include_deleted=False):
- """Return all documents in the database."""
- generation = self._get_generation()
- results = []
- for doc_id, (doc_rev, content) in self._docs.items():
- if content is None and not include_deleted:
- continue
- doc = self._factory(doc_id, doc_rev, content)
- doc.has_conflicts = self._has_conflicts(doc_id)
- results.append(doc)
- return (generation, results)
-
- def get_doc_conflicts(self, doc_id):
- if doc_id not in self._conflicts:
- return []
- result = [self._get_doc(doc_id)]
- result[0].has_conflicts = True
- result.extend([self._factory(doc_id, rev, content)
- for rev, content in self._conflicts[doc_id]])
- return result
-
- def _replace_conflicts(self, doc, conflicts):
- if not conflicts:
- del self._conflicts[doc.doc_id]
- else:
- self._conflicts[doc.doc_id] = conflicts
- doc.has_conflicts = bool(conflicts)
-
- def _prune_conflicts(self, doc, doc_vcr):
- if self._has_conflicts(doc.doc_id):
- autoresolved = False
- remaining_conflicts = []
- cur_conflicts = self._conflicts[doc.doc_id]
- for c_rev, c_doc in cur_conflicts:
- c_vcr = vectorclock.VectorClockRev(c_rev)
- if doc_vcr.is_newer(c_vcr):
- continue
- if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)):
- doc_vcr.maximize(c_vcr)
- autoresolved = True
- continue
- remaining_conflicts.append((c_rev, c_doc))
- if autoresolved:
- doc_vcr.increment(self._replica_uid)
- doc.rev = doc_vcr.as_str()
- self._replace_conflicts(doc, remaining_conflicts)
-
- def resolve_doc(self, doc, conflicted_doc_revs):
- cur_doc = self._get_doc(doc.doc_id)
- if cur_doc is None:
- cur_rev = None
- else:
- cur_rev = cur_doc.rev
- new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs)
- superseded_revs = set(conflicted_doc_revs)
- remaining_conflicts = []
- cur_conflicts = self._conflicts[doc.doc_id]
- for c_rev, c_doc in cur_conflicts:
- if c_rev in superseded_revs:
- continue
- remaining_conflicts.append((c_rev, c_doc))
- doc.rev = new_rev
- if cur_rev in superseded_revs:
- self._put_and_update_indexes(cur_doc, doc)
- else:
- remaining_conflicts.append((new_rev, doc.get_json()))
- self._replace_conflicts(doc, remaining_conflicts)
-
- def delete_doc(self, doc):
- if doc.doc_id not in self._docs:
- raise errors.DocumentDoesNotExist
- if self._docs[doc.doc_id][1] in ('null', None):
- raise errors.DocumentAlreadyDeleted
- doc.make_tombstone()
- self.put_doc(doc)
-
- def create_index(self, index_name, *index_expressions):
- if index_name in self._indexes:
- if self._indexes[index_name]._definition == list(
- index_expressions):
- return
- raise errors.IndexNameTakenError
- index = InMemoryIndex(index_name, list(index_expressions))
- for doc_id, (doc_rev, doc) in self._docs.iteritems():
- if doc is not None:
- index.add_json(doc_id, doc)
- self._indexes[index_name] = index
-
- def delete_index(self, index_name):
- del self._indexes[index_name]
-
- def list_indexes(self):
- definitions = []
- for idx in self._indexes.itervalues():
- definitions.append((idx._name, idx._definition))
- return definitions
-
- def get_from_index(self, index_name, *key_values):
- try:
- index = self._indexes[index_name]
- except KeyError:
- raise errors.IndexDoesNotExist
- doc_ids = index.lookup(key_values)
- result = []
- for doc_id in doc_ids:
- result.append(self._get_doc(doc_id, check_for_conflicts=True))
- return result
-
- def get_range_from_index(self, index_name, start_value=None,
- end_value=None):
- """Return all documents with key values in the specified range."""
- try:
- index = self._indexes[index_name]
- except KeyError:
- raise errors.IndexDoesNotExist
- if isinstance(start_value, basestring):
- start_value = (start_value,)
- if isinstance(end_value, basestring):
- end_value = (end_value,)
- doc_ids = index.lookup_range(start_value, end_value)
- result = []
- for doc_id in doc_ids:
- result.append(self._get_doc(doc_id, check_for_conflicts=True))
- return result
-
- def get_index_keys(self, index_name):
- try:
- index = self._indexes[index_name]
- except KeyError:
- raise errors.IndexDoesNotExist
- keys = index.keys()
- # XXX inefficiency warning
- return list(set([tuple(key.split('\x01')) for key in keys]))
-
- def whats_changed(self, old_generation=0):
- changes = []
- relevant_tail = self._transaction_log[old_generation:]
- # We don't use len(self._transaction_log) because _transaction_log may
- # get mutated by a concurrent operation.
- cur_generation = old_generation + len(relevant_tail)
- last_trans_id = ''
- if relevant_tail:
- last_trans_id = relevant_tail[-1][1]
- elif self._transaction_log:
- last_trans_id = self._transaction_log[-1][1]
- seen = set()
- generation = cur_generation
- for doc_id, trans_id in reversed(relevant_tail):
- if doc_id not in seen:
- changes.append((doc_id, generation, trans_id))
- seen.add(doc_id)
- generation -= 1
- changes.reverse()
- return (cur_generation, last_trans_id, changes)
-
- def _force_doc_sync_conflict(self, doc):
- my_doc = self._get_doc(doc.doc_id)
- self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
- self._conflicts.setdefault(doc.doc_id, []).append(
- (my_doc.rev, my_doc.get_json()))
- doc.has_conflicts = True
- self._put_and_update_indexes(my_doc, doc)
-
-
-class InMemoryIndex(object):
- """Interface for managing an Index."""
-
- def __init__(self, index_name, index_definition):
- self._name = index_name
- self._definition = index_definition
- self._values = {}
- parser = query_parser.Parser()
- self._getters = parser.parse_all(self._definition)
-
- def evaluate_json(self, doc):
- """Determine the 'key' after applying this index to the doc."""
- raw = json.loads(doc)
- return self.evaluate(raw)
-
- def evaluate(self, obj):
- """Evaluate a dict object, applying this definition."""
- all_rows = [[]]
- for getter in self._getters:
- new_rows = []
- keys = getter.get(obj)
- if not keys:
- return []
- for key in keys:
- new_rows.extend([row + [key] for row in all_rows])
- all_rows = new_rows
- all_rows = ['\x01'.join(row) for row in all_rows]
- return all_rows
-
- def add_json(self, doc_id, doc):
- """Add this json doc to the index."""
- keys = self.evaluate_json(doc)
- if not keys:
- return
- for key in keys:
- self._values.setdefault(key, []).append(doc_id)
-
- def remove_json(self, doc_id, doc):
- """Remove this json doc from the index."""
- keys = self.evaluate_json(doc)
- if keys:
- for key in keys:
- doc_ids = self._values[key]
- doc_ids.remove(doc_id)
- if not doc_ids:
- del self._values[key]
-
- def _find_non_wildcards(self, values):
- """Check if this should be a wildcard match.
-
- Further, this will raise an exception if the syntax is improperly
- defined.
-
- :return: The offset of the last value we need to match against.
- """
- if len(values) != len(self._definition):
- raise errors.InvalidValueForIndex()
- is_wildcard = False
- last = 0
- for idx, val in enumerate(values):
- if val.endswith('*'):
- if val != '*':
- # We have an 'x*' style wildcard
- if is_wildcard:
- # We were already in wildcard mode, so this is invalid
- raise errors.InvalidGlobbing
- last = idx + 1
- is_wildcard = True
- else:
- if is_wildcard:
- # We were in wildcard mode, we can't follow that with
- # non-wildcard
- raise errors.InvalidGlobbing
- last = idx + 1
- if not is_wildcard:
- return -1
- return last
-
- def lookup(self, values):
- """Find docs that match the values."""
- last = self._find_non_wildcards(values)
- if last == -1:
- return self._lookup_exact(values)
- else:
- return self._lookup_prefix(values[:last])
-
- def lookup_range(self, start_values, end_values):
- """Find docs within the range."""
- # TODO: Wildly inefficient, which is unlikely to be a problem for the
- # inmemory implementation.
- if start_values:
- self._find_non_wildcards(start_values)
- start_values = get_prefix(start_values)
- if end_values:
- if self._find_non_wildcards(end_values) == -1:
- exact = True
- else:
- exact = False
- end_values = get_prefix(end_values)
- found = []
- for key, doc_ids in sorted(self._values.iteritems()):
- if start_values and start_values > key:
- continue
- if end_values and end_values < key:
- if exact:
- break
- else:
- if not key.startswith(end_values):
- break
- found.extend(doc_ids)
- return found
-
- def keys(self):
- """Find the indexed keys."""
- return self._values.keys()
-
- def _lookup_prefix(self, value):
- """Find docs that match the prefix string in values."""
- # TODO: We need a different data structure to make prefix style fast,
- # some sort of sorted list would work, but a plain dict doesn't.
- key_prefix = get_prefix(value)
- all_doc_ids = []
- for key, doc_ids in sorted(self._values.iteritems()):
- if key.startswith(key_prefix):
- all_doc_ids.extend(doc_ids)
- return all_doc_ids
-
- def _lookup_exact(self, value):
- """Find docs that match exactly."""
- key = '\x01'.join(value)
- if key in self._values:
- return self._values[key]
- return ()
-
-
-class InMemorySyncTarget(CommonSyncTarget):
-
- def get_sync_info(self, source_replica_uid):
- source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
- source_replica_uid)
- my_gen, my_trans_id = self._db._get_generation_info()
- return (
- self._db._replica_uid, my_gen, my_trans_id, source_gen,
- source_trans_id)
-
- def record_sync_info(self, source_replica_uid, source_replica_generation,
- source_transaction_id):
- if self._trace_hook:
- self._trace_hook('record_sync_info')
- self._db._set_replica_gen_and_trans_id(
- source_replica_uid, source_replica_generation,
- source_transaction_id)
diff --git a/src/leap/soledad/u1db/backends/sqlite_backend.py b/src/leap/soledad/u1db/backends/sqlite_backend.py
deleted file mode 100644
index 773213b5..00000000
--- a/src/leap/soledad/u1db/backends/sqlite_backend.py
+++ /dev/null
@@ -1,926 +0,0 @@
-# Copyright 2011 Canonical Ltd.
-#
-# This file is part of u1db.
-#
-# u1db is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License version 3
-# as published by the Free Software Foundation.
-#
-# u1db is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Lesser General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with u1db. If not, see <http://www.gnu.org/licenses/>.
-
-"""A U1DB implementation that uses SQLite as its persistence layer."""
-
-import errno
-import os
-try:
- import simplejson as json
-except ImportError:
- import json # noqa
-from sqlite3 import dbapi2
-import sys
-import time
-import uuid
-
-import pkg_resources
-
-from u1db.backends import CommonBackend, CommonSyncTarget
-from u1db import (
- Document,
- errors,
- query_parser,
- vectorclock,
- )
-
-
-class SQLiteDatabase(CommonBackend):
- """A U1DB implementation that uses SQLite as its persistence layer."""
-
- _sqlite_registry = {}
-
- def __init__(self, sqlite_file, document_factory=None):
- """Create a new sqlite file."""
- self._db_handle = dbapi2.connect(sqlite_file)
- self._real_replica_uid = None
- self._ensure_schema()
- self._factory = document_factory or Document
-
- def set_document_factory(self, factory):
- self._factory = factory
-
- def get_sync_target(self):
- return SQLiteSyncTarget(self)
-
- @classmethod
- def _which_index_storage(cls, c):
- try:
- c.execute("SELECT value FROM u1db_config"
- " WHERE name = 'index_storage'")
- except dbapi2.OperationalError, e:
- # The table does not exist yet
- return None, e
- else:
- return c.fetchone()[0], None
-
- WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5
-
- @classmethod
- def _open_database(cls, sqlite_file, document_factory=None):
- if not os.path.isfile(sqlite_file):
- raise errors.DatabaseDoesNotExist()
- tries = 2
- while True:
- # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6)
- # where without re-opening the database on Windows, it
- # doesn't see the transaction that was just committed
- db_handle = dbapi2.connect(sqlite_file)
- c = db_handle.cursor()
- v, err = cls._which_index_storage(c)
- db_handle.close()
- if v is not None:
- break
- # possibly another process is initializing it, wait for it to be
- # done
- if tries == 0:
- raise err # go for the richest error?
- tries -= 1
- time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL)
- return SQLiteDatabase._sqlite_registry[v](
- sqlite_file, document_factory=document_factory)
-
- @classmethod
- def open_database(cls, sqlite_file, create, backend_cls=None,
- document_factory=None):
- try:
- return cls._open_database(
- sqlite_file, document_factory=document_factory)
- except errors.DatabaseDoesNotExist:
- if not create:
- raise
- if backend_cls is None:
- # default is SQLitePartialExpandDatabase
- backend_cls = SQLitePartialExpandDatabase
- return backend_cls(sqlite_file, document_factory=document_factory)
-
- @staticmethod
- def delete_database(sqlite_file):
- try:
- os.unlink(sqlite_file)
- except OSError as ex:
- if ex.errno == errno.ENOENT:
- raise errors.DatabaseDoesNotExist()
- raise
-
- @staticmethod
- def register_implementation(klass):
- """Register that we implement an SQLiteDatabase.
-
- The attribute _index_storage_value will be used as the lookup key.
- """
- SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass
-
- def _get_sqlite_handle(self):
- """Get access to the underlying sqlite database.
-
- This should only be used by the test suite, etc, for examining the
- state of the underlying database.
- """
- return self._db_handle
-
- def _close_sqlite_handle(self):
- """Release access to the underlying sqlite database."""
- self._db_handle.close()
-
- def close(self):
- self._close_sqlite_handle()
-
- def _is_initialized(self, c):
- """Check if this database has been initialized."""
- c.execute("PRAGMA case_sensitive_like=ON")
- try:
- c.execute("SELECT value FROM u1db_config"
- " WHERE name = 'sql_schema'")
- except dbapi2.OperationalError:
- # The table does not exist yet
- val = None
- else:
- val = c.fetchone()
- if val is not None:
- return True
- return False
-
- def _initialize(self, c):
- """Create the schema in the database."""
- #read the script with sql commands
- # TODO: Change how we set up the dependency. Most likely use something
- # like lp:dirspec to grab the file from a common resource
- # directory. Doesn't specifically need to be handled until we get
- # to the point of packaging this.
- schema_content = pkg_resources.resource_string(
- __name__, 'dbschema.sql')
- # Note: We'd like to use c.executescript() here, but it seems that
- # executescript always commits, even if you set
- # isolation_level = None, so if we want to properly handle
- # exclusive locking and rollbacks between processes, we need
- # to execute it line-by-line
- for line in schema_content.split(';'):
- if not line:
- continue
- c.execute(line)
- #add extra fields
- self._extra_schema_init(c)
- # A unique identifier should be set for this replica. Implementations
- # don't have to strictly use uuid here, but we do want the uid to be
- # unique amongst all databases that will sync with each other.
- # We might extend this to using something with hostname for easier
- # debugging.
- self._set_replica_uid_in_transaction(uuid.uuid4().hex)
- c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)",
- (self._index_storage_value,))
-
- def _ensure_schema(self):
- """Ensure that the database schema has been created."""
- old_isolation_level = self._db_handle.isolation_level
- c = self._db_handle.cursor()
- if self._is_initialized(c):
- return
- try:
- # autocommit/own mgmt of transactions
- self._db_handle.isolation_level = None
- with self._db_handle:
- # only one execution path should initialize the db
- c.execute("begin exclusive")
- if self._is_initialized(c):
- return
- self._initialize(c)
- finally:
- self._db_handle.isolation_level = old_isolation_level
-
- def _extra_schema_init(self, c):
- """Add any extra fields, etc to the basic table definitions."""
-
- def _parse_index_definition(self, index_field):
- """Parse a field definition for an index, returning a Getter."""
- # Note: We may want to keep a Parser object around, and cache the
- # Getter objects for a greater length of time. Specifically, if
- # you create a bunch of indexes, and then insert 50k docs, you'll
- # re-parse the indexes between puts. The time to insert the docs
- # is still likely to dominate put_doc time, though.
- parser = query_parser.Parser()
- getter = parser.parse(index_field)
- return getter
-
- def _update_indexes(self, doc_id, raw_doc, getters, db_cursor):
- """Update document_fields for a single document.
-
- :param doc_id: Identifier for this document
- :param raw_doc: The python dict representation of the document.
- :param getters: A list of [(field_name, Getter)]. Getter.get will be
- called to evaluate the index definition for this document, and the
- results will be inserted into the db.
- :param db_cursor: An sqlite Cursor.
- :return: None
- """
- values = []
- for field_name, getter in getters:
- for idx_value in getter.get(raw_doc):
- values.append((doc_id, field_name, idx_value))
- if values:
- db_cursor.executemany(
- "INSERT INTO document_fields VALUES (?, ?, ?)", values)
-
- def _set_replica_uid(self, replica_uid):
- """Force the replica_uid to be set."""
- with self._db_handle:
- self._set_replica_uid_in_transaction(replica_uid)
-
- def _set_replica_uid_in_transaction(self, replica_uid):
- """Set the replica_uid. A transaction should already be held."""
- c = self._db_handle.cursor()
- c.execute("INSERT OR REPLACE INTO u1db_config"
- " VALUES ('replica_uid', ?)",
- (replica_uid,))
- self._real_replica_uid = replica_uid
-
- def _get_replica_uid(self):
- if self._real_replica_uid is not None:
- return self._real_replica_uid
- c = self._db_handle.cursor()
- c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'")
- val = c.fetchone()
- if val is None:
- return None
- self._real_replica_uid = val[0]
- return self._real_replica_uid
-
- _replica_uid = property(_get_replica_uid)
-
- def _get_generation(self):
- c = self._db_handle.cursor()
- c.execute('SELECT max(generation) FROM transaction_log')
- val = c.fetchone()[0]
- if val is None:
- return 0
- return val
-
- def _get_generation_info(self):
- c = self._db_handle.cursor()
- c.execute(
- 'SELECT max(generation), transaction_id FROM transaction_log ')
- val = c.fetchone()
- if val[0] is None:
- return(0, '')
- return val
-
- def _get_trans_id_for_gen(self, generation):
- if generation == 0:
- return ''
- c = self._db_handle.cursor()
- c.execute(
- 'SELECT transaction_id FROM transaction_log WHERE generation = ?',
- (generation,))
- val = c.fetchone()
- if val is None:
- raise errors.InvalidGeneration
- return val[0]
-
- def _get_transaction_log(self):
- c = self._db_handle.cursor()
- c.execute("SELECT doc_id, transaction_id FROM transaction_log"
- " ORDER BY generation")
- return c.fetchall()
-
- def _get_doc(self, doc_id, check_for_conflicts=False):
- """Get just the document content, without fancy handling."""
- c = self._db_handle.cursor()
- if check_for_conflicts:
- c.execute(
- "SELECT document.doc_rev, document.content, "
- "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN "
- "conflicts ON conflicts.doc_id = document.doc_id WHERE "
- "document.doc_id = ? GROUP BY document.doc_id, "
- "document.doc_rev, document.content;", (doc_id,))
- else:
- c.execute(
- "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?",
- (doc_id,))
- val = c.fetchone()
- if val is None:
- return None
- doc_rev, content, conflicts = val
- doc = self._factory(doc_id, doc_rev, content)
- doc.has_conflicts = conflicts > 0
- return doc
-
- def _has_conflicts(self, doc_id):
- c = self._db_handle.cursor()
- c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1",
- (doc_id,))
- val = c.fetchone()
- if val is None:
- return False
- else:
- return True
-
- def get_doc(self, doc_id, include_deleted=False):
- doc = self._get_doc(doc_id, check_for_conflicts=True)
- if doc is None:
- return None
- if doc.is_tombstone() and not include_deleted:
- return None
- return doc
-
- def get_all_docs(self, include_deleted=False):
- """Get all documents from the database."""
- generation = self._get_generation()
- results = []
- c = self._db_handle.cursor()
- c.execute(
- "SELECT document.doc_id, document.doc_rev, document.content, "
- "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts "
- "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, "
- "document.doc_rev, document.content;")
- rows = c.fetchall()
- for doc_id, doc_rev, content, conflicts in rows:
- if content is None and not include_deleted:
- continue
- doc = self._factory(doc_id, doc_rev, content)
- doc.has_conflicts = conflicts > 0
- results.append(doc)
- return (generation, results)
-
- def put_doc(self, doc):
- if doc.doc_id is None:
- raise errors.InvalidDocId()
- self._check_doc_id(doc.doc_id)
- self._check_doc_size(doc)
- with self._db_handle:
- old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
- if old_doc and old_doc.has_conflicts:
- raise errors.ConflictedDoc()
- if old_doc and doc.rev is None and old_doc.is_tombstone():
- new_rev = self._allocate_doc_rev(old_doc.rev)
- else:
- if old_doc is not None:
- if old_doc.rev != doc.rev:
- raise errors.RevisionConflict()
- else:
- if doc.rev is not None:
- raise errors.RevisionConflict()
- new_rev = self._allocate_doc_rev(doc.rev)
- doc.rev = new_rev
- self._put_and_update_indexes(old_doc, doc)
- return new_rev
-
- def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none):
- """Convert a dict representation into named fields.
-
- So something like: {'key1': 'val1', 'key2': 'val2'}
- gets converted into: [(doc_id, 'key1', 'val1', 0)
- (doc_id, 'key2', 'val2', 0)]
- :param doc_id: Just added to every record.
- :param base_field: if set, these are nested keys, so each field should
- be appropriately prefixed.
- :param raw_doc: The python dictionary.
- """
- # TODO: Handle lists
- values = []
- for field_name, value in raw_doc.iteritems():
- if value is None and not save_none:
- continue
- if base_field:
- full_name = base_field + '.' + field_name
- else:
- full_name = field_name
- if value is None or isinstance(value, (int, float, basestring)):
- values.append((doc_id, full_name, value, len(values)))
- else:
- subvalues = self._expand_to_fields(doc_id, full_name, value,
- save_none)
- for _, subfield_name, val, _ in subvalues:
- values.append((doc_id, subfield_name, val, len(values)))
- return values
-
- def _put_and_update_indexes(self, old_doc, doc):
- """Actually insert a document into the database.
-
- This both updates the existing documents content, and any indexes that
- refer to this document.
- """
- raise NotImplementedError(self._put_and_update_indexes)
-
- def whats_changed(self, old_generation=0):
- c = self._db_handle.cursor()
- c.execute("SELECT generation, doc_id, transaction_id"
- " FROM transaction_log"
- " WHERE generation > ? ORDER BY generation DESC",
- (old_generation,))
- results = c.fetchall()
- cur_gen = old_generation
- seen = set()
- changes = []
- newest_trans_id = ''
- for generation, doc_id, trans_id in results:
- if doc_id not in seen:
- changes.append((doc_id, generation, trans_id))
- seen.add(doc_id)
- if changes:
- cur_gen = changes[0][1] # max generation
- newest_trans_id = changes[0][2]
- changes.reverse()
- else:
- c.execute("SELECT generation, transaction_id"
- " FROM transaction_log ORDER BY generation DESC LIMIT 1")
- results = c.fetchone()
- if not results:
- cur_gen = 0
- newest_trans_id = ''
- else:
- cur_gen, newest_trans_id = results
-
- return cur_gen, newest_trans_id, changes
-
- def delete_doc(self, doc):
- with self._db_handle:
- old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
- if old_doc is None:
- raise errors.DocumentDoesNotExist
- if old_doc.rev != doc.rev:
- raise errors.RevisionConflict()
- if old_doc.is_tombstone():
- raise errors.DocumentAlreadyDeleted
- if old_doc.has_conflicts:
- raise errors.ConflictedDoc()
- new_rev = self._allocate_doc_rev(doc.rev)
- doc.rev = new_rev
- doc.make_tombstone()
- self._put_and_update_indexes(old_doc, doc)
- return new_rev
-
- def _get_conflicts(self, doc_id):
- c = self._db_handle.cursor()
- c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?",
- (doc_id,))
- return [self._factory(doc_id, doc_rev, content)
- for doc_rev, content in c.fetchall()]
-
- def get_doc_conflicts(self, doc_id):
- with self._db_handle:
- conflict_docs = self._get_conflicts(doc_id)
- if not conflict_docs:
- return []
- this_doc = self._get_doc(doc_id)
- this_doc.has_conflicts = True
- return [this_doc] + conflict_docs
-
- def _get_replica_gen_and_trans_id(self, other_replica_uid):
- c = self._db_handle.cursor()
- c.execute("SELECT known_generation, known_transaction_id FROM sync_log"
- " WHERE replica_uid = ?",
- (other_replica_uid,))
- val = c.fetchone()
- if val is None:
- other_gen = 0
- trans_id = ''
- else:
- other_gen = val[0]
- trans_id = val[1]
- return other_gen, trans_id
-
- def _set_replica_gen_and_trans_id(self, other_replica_uid,
- other_generation, other_transaction_id):
- with self._db_handle:
- self._do_set_replica_gen_and_trans_id(
- other_replica_uid, other_generation, other_transaction_id)
-
- def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
- other_generation,
- other_transaction_id):
- c = self._db_handle.cursor()
- c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)",
- (other_replica_uid, other_generation,
- other_transaction_id))
-
- def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None,
- replica_gen=None, replica_trans_id=None):
- with self._db_handle:
- return super(SQLiteDatabase, self)._put_doc_if_newer(doc,
- save_conflict=save_conflict,
- replica_uid=replica_uid, replica_gen=replica_gen,
- replica_trans_id=replica_trans_id)
-
- def _add_conflict(self, c, doc_id, my_doc_rev, my_content):
- c.execute("INSERT INTO conflicts VALUES (?, ?, ?)",
- (doc_id, my_doc_rev, my_content))
-
- def _delete_conflicts(self, c, doc, conflict_revs):
- deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs]
- c.executemany("DELETE FROM conflicts"
- " WHERE doc_id=? AND doc_rev=?", deleting)
- doc.has_conflicts = self._has_conflicts(doc.doc_id)
-
- def _prune_conflicts(self, doc, doc_vcr):
- if self._has_conflicts(doc.doc_id):
- autoresolved = False
- c_revs_to_prune = []
- for c_doc in self._get_conflicts(doc.doc_id):
- c_vcr = vectorclock.VectorClockRev(c_doc.rev)
- if doc_vcr.is_newer(c_vcr):
- c_revs_to_prune.append(c_doc.rev)
- elif doc.same_content_as(c_doc):
- c_revs_to_prune.append(c_doc.rev)
- doc_vcr.maximize(c_vcr)
- autoresolved = True
- if autoresolved:
- doc_vcr.increment(self._replica_uid)
- doc.rev = doc_vcr.as_str()
- c = self._db_handle.cursor()
- self._delete_conflicts(c, doc, c_revs_to_prune)
-
- def _force_doc_sync_conflict(self, doc):
- my_doc = self._get_doc(doc.doc_id)
- c = self._db_handle.cursor()
- self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
- self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json())
- doc.has_conflicts = True
- self._put_and_update_indexes(my_doc, doc)
-
- def resolve_doc(self, doc, conflicted_doc_revs):
- with self._db_handle:
- cur_doc = self._get_doc(doc.doc_id)
- # TODO: https://bugs.launchpad.net/u1db/+bug/928274
- # I think we have a logic bug in resolve_doc
- # Specifically, cur_doc.rev is always in the final vector
- # clock of revisions that we supersede, even if it wasn't in
- # conflicted_doc_revs. We still add it as a conflict, but the
- # fact that _put_doc_if_newer propagates resolutions means I
- # think that conflict could accidentally be resolved. We need
- # to add a test for this case first. (create a rev, create a
- # conflict, create another conflict, resolve the first rev
- # and first conflict, then make sure that the resolved
- # rev doesn't supersede the second conflict rev.) It *might*
- # not matter, because the superseding rev is in as a
- # conflict, but it does seem incorrect
- new_rev = self._ensure_maximal_rev(cur_doc.rev,
- conflicted_doc_revs)
- superseded_revs = set(conflicted_doc_revs)
- c = self._db_handle.cursor()
- doc.rev = new_rev
- if cur_doc.rev in superseded_revs:
- self._put_and_update_indexes(cur_doc, doc)
- else:
- self._add_conflict(c, doc.doc_id, new_rev, doc.get_json())
- # TODO: Is there some way that we could construct a rev that would
- # end up in superseded_revs, such that we add a conflict, and
- # then immediately delete it?
- self._delete_conflicts(c, doc, superseded_revs)
-
- def list_indexes(self):
- """Return the list of indexes and their definitions."""
- c = self._db_handle.cursor()
- # TODO: How do we test the ordering?
- c.execute("SELECT name, field FROM index_definitions"
- " ORDER BY name, offset")
- definitions = []
- cur_name = None
- for name, field in c.fetchall():
- if cur_name != name:
- definitions.append((name, []))
- cur_name = name
- definitions[-1][-1].append(field)
- return definitions
-
- def _get_index_definition(self, index_name):
- """Return the stored definition for a given index_name."""
- c = self._db_handle.cursor()
- c.execute("SELECT field FROM index_definitions"
- " WHERE name = ? ORDER BY offset", (index_name,))
- fields = [x[0] for x in c.fetchall()]
- if not fields:
- raise errors.IndexDoesNotExist
- return fields
-
- @staticmethod
- def _strip_glob(value):
- """Remove the trailing * from a value."""
- assert value[-1] == '*'
- return value[:-1]
-
- def _format_query(self, definition, key_values):
- # First, build the definition. We join the document_fields table
- # against itself, as many times as the 'width' of our definition.
- # We then do a query for each key_value, one-at-a-time.
- # Note: All of these strings are static, we could cache them, etc.
- tables = ["document_fields d%d" % i for i in range(len(definition))]
- novalue_where = ["d.doc_id = d%d.doc_id"
- " AND d%d.field_name = ?"
- % (i, i) for i in range(len(definition))]
- wildcard_where = [novalue_where[i]
- + (" AND d%d.value NOT NULL" % (i,))
- for i in range(len(definition))]
- exact_where = [novalue_where[i]
- + (" AND d%d.value = ?" % (i,))
- for i in range(len(definition))]
- like_where = [novalue_where[i]
- + (" AND d%d.value GLOB ?" % (i,))
- for i in range(len(definition))]
- is_wildcard = False
- # Merge the lists together, so that:
- # [field1, field2, field3], [val1, val2, val3]
- # Becomes:
- # (field1, val1, field2, val2, field3, val3)
- args = []
- where = []
- for idx, (field, value) in enumerate(zip(definition, key_values)):
- args.append(field)
- if value.endswith('*'):
- if value == '*':
- where.append(wildcard_where[idx])
- else:
- # This is a glob match
- if is_wildcard:
- # We can't have a partial wildcard following
- # another wildcard
- raise errors.InvalidGlobbing
- where.append(like_where[idx])
- args.append(value)
- is_wildcard = True
- else:
- if is_wildcard:
- raise errors.InvalidGlobbing
- where.append(exact_where[idx])
- args.append(value)
- statement = (
- "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
- "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
- "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
- "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
- ['d%d.value' % i for i in range(len(definition))])))
- return statement, args
-
- def get_from_index(self, index_name, *key_values):
- definition = self._get_index_definition(index_name)
- if len(key_values) != len(definition):
- raise errors.InvalidValueForIndex()
- statement, args = self._format_query(definition, key_values)
- c = self._db_handle.cursor()
- try:
- c.execute(statement, tuple(args))
- except dbapi2.OperationalError, e:
- raise dbapi2.OperationalError(str(e) +
- '\nstatement: %s\nargs: %s\n' % (statement, args))
- res = c.fetchall()
- results = []
- for row in res:
- doc = self._factory(row[0], row[1], row[2])
- doc.has_conflicts = row[3] > 0
- results.append(doc)
- return results
-
- def _format_range_query(self, definition, start_value, end_value):
- tables = ["document_fields d%d" % i for i in range(len(definition))]
- novalue_where = [
- "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
- range(len(definition))]
- wildcard_where = [
- novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
- range(len(definition))]
- like_where = [
- novalue_where[i] + (
- " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in
- range(len(definition))]
- range_where_lower = [
- novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in
- range(len(definition))]
- range_where_upper = [
- novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in
- range(len(definition))]
- args = []
- where = []
- if start_value:
- if isinstance(start_value, basestring):
- start_value = (start_value,)
- if len(start_value) != len(definition):
- raise errors.InvalidValueForIndex()
- is_wildcard = False
- for idx, (field, value) in enumerate(zip(definition, start_value)):
- args.append(field)
- if value.endswith('*'):
- if value == '*':
- where.append(wildcard_where[idx])
- else:
- # This is a glob match
- if is_wildcard:
- # We can't have a partial wildcard following
- # another wildcard
- raise errors.InvalidGlobbing
- where.append(range_where_lower[idx])
- args.append(self._strip_glob(value))
- is_wildcard = True
- else:
- if is_wildcard:
- raise errors.InvalidGlobbing
- where.append(range_where_lower[idx])
- args.append(value)
- if end_value:
- if isinstance(end_value, basestring):
- end_value = (end_value,)
- if len(end_value) != len(definition):
- raise errors.InvalidValueForIndex()
- is_wildcard = False
- for idx, (field, value) in enumerate(zip(definition, end_value)):
- args.append(field)
- if value.endswith('*'):
- if value == '*':
- where.append(wildcard_where[idx])
- else:
- # This is a glob match
- if is_wildcard:
- # We can't have a partial wildcard following
- # another wildcard
- raise errors.InvalidGlobbing
- where.append(like_where[idx])
- args.append(self._strip_glob(value))
- args.append(value)
- is_wildcard = True
- else:
- if is_wildcard:
- raise errors.InvalidGlobbing
- where.append(range_where_upper[idx])
- args.append(value)
- statement = (
- "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
- "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
- "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
- "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
- ['d%d.value' % i for i in range(len(definition))])))
- return statement, args
-
- def get_range_from_index(self, index_name, start_value=None,
- end_value=None):
- """Return all documents with key values in the specified range."""
- definition = self._get_index_definition(index_name)
- statement, args = self._format_range_query(
- definition, start_value, end_value)
- c = self._db_handle.cursor()
- try:
- c.execute(statement, tuple(args))
- except dbapi2.OperationalError, e:
- raise dbapi2.OperationalError(str(e) +
- '\nstatement: %s\nargs: %s\n' % (statement, args))
- res = c.fetchall()
- results = []
- for row in res:
- doc = self._factory(row[0], row[1], row[2])
- doc.has_conflicts = row[3] > 0
- results.append(doc)
- return results
-
- def get_index_keys(self, index_name):
- c = self._db_handle.cursor()
- definition = self._get_index_definition(index_name)
- value_fields = ', '.join([
- 'd%d.value' % i for i in range(len(definition))])
- tables = ["document_fields d%d" % i for i in range(len(definition))]
- novalue_where = [
- "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
- range(len(definition))]
- where = [
- novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
- range(len(definition))]
- statement = (
- "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % (
- value_fields, ', '.join(tables), ' AND '.join(where),
- value_fields))
- try:
- c.execute(statement, tuple(definition))
- except dbapi2.OperationalError, e:
- raise dbapi2.OperationalError(str(e) +
- '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition)))
- return c.fetchall()
-
- def delete_index(self, index_name):
- with self._db_handle:
- c = self._db_handle.cursor()
- c.execute("DELETE FROM index_definitions WHERE name = ?",
- (index_name,))
- c.execute(
- "DELETE FROM document_fields WHERE document_fields.field_name "
- " NOT IN (SELECT field from index_definitions)")
-
-
-class SQLiteSyncTarget(CommonSyncTarget):
-
- def get_sync_info(self, source_replica_uid):
- source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
- source_replica_uid)
- my_gen, my_trans_id = self._db._get_generation_info()
- return (
- self._db._replica_uid, my_gen, my_trans_id, source_gen,
- source_trans_id)
-
- def record_sync_info(self, source_replica_uid, source_replica_generation,
- source_replica_transaction_id):
- if self._trace_hook:
- self._trace_hook('record_sync_info')
- self._db._set_replica_gen_and_trans_id(
- source_replica_uid, source_replica_generation,
- source_replica_transaction_id)
-
-
-class SQLitePartialExpandDatabase(SQLiteDatabase):
- """An SQLite Backend that expands documents into a document_field table.
-
- It stores the original document text in document.doc. For fields that are
- indexed, the data goes into document_fields.
- """
-
- _index_storage_value = 'expand referenced'
-
- def _get_indexed_fields(self):
- """Determine what fields are indexed."""
- c = self._db_handle.cursor()
- c.execute("SELECT field FROM index_definitions")
- return set([x[0] for x in c.fetchall()])
-
- def _evaluate_index(self, raw_doc, field):
- parser = query_parser.Parser()
- getter = parser.parse(field)
- return getter.get(raw_doc)
-
- def _put_and_update_indexes(self, old_doc, doc):
- c = self._db_handle.cursor()
- if doc and not doc.is_tombstone():
- raw_doc = json.loads(doc.get_json())
- else:
- raw_doc = {}
- if old_doc is not None:
- c.execute("UPDATE document SET doc_rev=?, content=?"
- " WHERE doc_id = ?",
- (doc.rev, doc.get_json(), doc.doc_id))
- c.execute("DELETE FROM document_fields WHERE doc_id = ?",
- (doc.doc_id,))
- else:
- c.execute("INSERT INTO document (doc_id, doc_rev, content)"
- " VALUES (?, ?, ?)",
- (doc.doc_id, doc.rev, doc.get_json()))
- indexed_fields = self._get_indexed_fields()
- if indexed_fields:
- # It is expected that len(indexed_fields) is shorter than
- # len(raw_doc)
- getters = [(field, self._parse_index_definition(field))
- for field in indexed_fields]
- self._update_indexes(doc.doc_id, raw_doc, getters, c)
- trans_id = self._allocate_transaction_id()
- c.execute("INSERT INTO transaction_log(doc_id, transaction_id)"
- " VALUES (?, ?)", (doc.doc_id, trans_id))
-
- def create_index(self, index_name, *index_expressions):
- with self._db_handle:
- c = self._db_handle.cursor()
- cur_fields = self._get_indexed_fields()
- definition = [(index_name, idx, field)
- for idx, field in enumerate(index_expressions)]
- try:
- c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
- definition)
- except dbapi2.IntegrityError as e:
- stored_def = self._get_index_definition(index_name)
- if stored_def == [x[-1] for x in definition]:
- return
- raise errors.IndexNameTakenError, e, sys.exc_info()[2]
- new_fields = set(
- [f for f in index_expressions if f not in cur_fields])
- if new_fields:
- self._update_all_indexes(new_fields)
-
- def _iter_all_docs(self):
- c = self._db_handle.cursor()
- c.execute("SELECT doc_id, content FROM document")
- while True:
- next_rows = c.fetchmany()
- if not next_rows:
- break
- for row in next_rows:
- yield row
-
- def _update_all_indexes(self, new_fields):
- """Iterate all the documents, and add content to document_fields.
-
- :param new_fields: The index definitions that need to be added.
- """
- getters = [(field, self._parse_index_definition(field))
- for field in new_fields]
- c = self._db_handle.cursor()
- for doc_id, doc in self._iter_all_docs():
- if doc is None:
- continue
- raw_doc = json.loads(doc)
- self._update_indexes(doc_id, raw_doc, getters, c)
-
-SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase)