summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
authordrebs <drebs@leap.se>2017-04-29 16:40:55 +0200
committerdrebs <drebs@leap.se>2017-05-01 21:10:18 +0200
commita2a41830f72ae014d3b4dba591c6bfc85ffeb062 (patch)
treeb853c88c3830df52c6f0a2b544ce65b02927816e /common
parenta9580838817c2fbc883253f6df0edc29fcd4c947 (diff)
[refactor] create client _database module
Diffstat (limited to 'common')
-rw-r--r--common/src/leap/soledad/common/l2db/__init__.py4
-rw-r--r--common/src/leap/soledad/common/l2db/backends/dbschema.sql42
-rw-r--r--common/src/leap/soledad/common/l2db/backends/sqlite_backend.py930
-rw-r--r--common/src/leap/soledad/common/l2db/remote/server_state.py14
4 files changed, 8 insertions, 982 deletions
diff --git a/common/src/leap/soledad/common/l2db/__init__.py b/common/src/leap/soledad/common/l2db/__init__.py
index 568897c4..ebbf546a 100644
--- a/common/src/leap/soledad/common/l2db/__init__.py
+++ b/common/src/leap/soledad/common/l2db/__init__.py
@@ -37,8 +37,8 @@ def open(path, create, document_factory=None):
parameters as Document.__init__.
:return: An instance of Database.
"""
- from leap.soledad.common.l2db.backends import sqlite_backend
- return sqlite_backend.SQLiteDatabase.open_database(
+ from leap.soledad.client._database import sqlite
+ return sqlite.SQLiteDatabase.open_database(
path, create=create, document_factory=document_factory)
diff --git a/common/src/leap/soledad/common/l2db/backends/dbschema.sql b/common/src/leap/soledad/common/l2db/backends/dbschema.sql
deleted file mode 100644
index ae027fc5..00000000
--- a/common/src/leap/soledad/common/l2db/backends/dbschema.sql
+++ /dev/null
@@ -1,42 +0,0 @@
--- Database schema
-CREATE TABLE transaction_log (
- generation INTEGER PRIMARY KEY AUTOINCREMENT,
- doc_id TEXT NOT NULL,
- transaction_id TEXT NOT NULL
-);
-CREATE TABLE document (
- doc_id TEXT PRIMARY KEY,
- doc_rev TEXT NOT NULL,
- content TEXT
-);
-CREATE TABLE document_fields (
- doc_id TEXT NOT NULL,
- field_name TEXT NOT NULL,
- value TEXT
-);
-CREATE INDEX document_fields_field_value_doc_idx
- ON document_fields(field_name, value, doc_id);
-
-CREATE TABLE sync_log (
- replica_uid TEXT PRIMARY KEY,
- known_generation INTEGER,
- known_transaction_id TEXT
-);
-CREATE TABLE conflicts (
- doc_id TEXT,
- doc_rev TEXT,
- content TEXT,
- CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev)
-);
-CREATE TABLE index_definitions (
- name TEXT,
- offset INT,
- field TEXT,
- CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset)
-);
-create index index_definitions_field on index_definitions(field);
-CREATE TABLE u1db_config (
- name TEXT PRIMARY KEY,
- value TEXT
-);
-INSERT INTO u1db_config VALUES ('sql_schema', '0');
diff --git a/common/src/leap/soledad/common/l2db/backends/sqlite_backend.py b/common/src/leap/soledad/common/l2db/backends/sqlite_backend.py
deleted file mode 100644
index 4f7b1259..00000000
--- a/common/src/leap/soledad/common/l2db/backends/sqlite_backend.py
+++ /dev/null
@@ -1,930 +0,0 @@
-# Copyright 2011 Canonical Ltd.
-# Copyright 2016 LEAP Encryption Access Project
-#
-# This file is part of u1db.
-#
-# u1db is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License version 3
-# as published by the Free Software Foundation.
-#
-# u1db is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Lesser General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with u1db. If not, see <http://www.gnu.org/licenses/>.
-
-"""
-A L2DB implementation that uses SQLite as its persistence layer.
-"""
-
-import errno
-import os
-import json
-import sys
-import time
-import uuid
-import pkg_resources
-
-from sqlite3 import dbapi2
-
-from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget
-from leap.soledad.common.l2db import (
- Document, errors,
- query_parser, vectorclock)
-
-
-class SQLiteDatabase(CommonBackend):
- """A U1DB implementation that uses SQLite as its persistence layer."""
-
- _sqlite_registry = {}
-
- def __init__(self, sqlite_file, document_factory=None):
- """Create a new sqlite file."""
- self._db_handle = dbapi2.connect(sqlite_file)
- self._real_replica_uid = None
- self._ensure_schema()
- self._factory = document_factory or Document
-
- def set_document_factory(self, factory):
- self._factory = factory
-
- def get_sync_target(self):
- return SQLiteSyncTarget(self)
-
- @classmethod
- def _which_index_storage(cls, c):
- try:
- c.execute("SELECT value FROM u1db_config"
- " WHERE name = 'index_storage'")
- except dbapi2.OperationalError as e:
- # The table does not exist yet
- return None, e
- else:
- return c.fetchone()[0], None
-
- WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5
-
- @classmethod
- def _open_database(cls, sqlite_file, document_factory=None):
- if not os.path.isfile(sqlite_file):
- raise errors.DatabaseDoesNotExist()
- tries = 2
- while True:
- # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6)
- # where without re-opening the database on Windows, it
- # doesn't see the transaction that was just committed
- db_handle = dbapi2.connect(sqlite_file)
- c = db_handle.cursor()
- v, err = cls._which_index_storage(c)
- db_handle.close()
- if v is not None:
- break
- # possibly another process is initializing it, wait for it to be
- # done
- if tries == 0:
- raise err # go for the richest error?
- tries -= 1
- time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL)
- return SQLiteDatabase._sqlite_registry[v](
- sqlite_file, document_factory=document_factory)
-
- @classmethod
- def open_database(cls, sqlite_file, create, backend_cls=None,
- document_factory=None):
- try:
- return cls._open_database(
- sqlite_file, document_factory=document_factory)
- except errors.DatabaseDoesNotExist:
- if not create:
- raise
- if backend_cls is None:
- # default is SQLitePartialExpandDatabase
- backend_cls = SQLitePartialExpandDatabase
- return backend_cls(sqlite_file, document_factory=document_factory)
-
- @staticmethod
- def delete_database(sqlite_file):
- try:
- os.unlink(sqlite_file)
- except OSError as ex:
- if ex.errno == errno.ENOENT:
- raise errors.DatabaseDoesNotExist()
- raise
-
- @staticmethod
- def register_implementation(klass):
- """Register that we implement an SQLiteDatabase.
-
- The attribute _index_storage_value will be used as the lookup key.
- """
- SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass
-
- def _get_sqlite_handle(self):
- """Get access to the underlying sqlite database.
-
- This should only be used by the test suite, etc, for examining the
- state of the underlying database.
- """
- return self._db_handle
-
- def _close_sqlite_handle(self):
- """Release access to the underlying sqlite database."""
- self._db_handle.close()
-
- def close(self):
- self._close_sqlite_handle()
-
- def _is_initialized(self, c):
- """Check if this database has been initialized."""
- c.execute("PRAGMA case_sensitive_like=ON")
- try:
- c.execute("SELECT value FROM u1db_config"
- " WHERE name = 'sql_schema'")
- except dbapi2.OperationalError:
- # The table does not exist yet
- val = None
- else:
- val = c.fetchone()
- if val is not None:
- return True
- return False
-
- def _initialize(self, c):
- """Create the schema in the database."""
- # read the script with sql commands
- # TODO: Change how we set up the dependency. Most likely use something
- # like lp:dirspec to grab the file from a common resource
- # directory. Doesn't specifically need to be handled until we get
- # to the point of packaging this.
- schema_content = pkg_resources.resource_string(
- __name__, 'dbschema.sql')
- # Note: We'd like to use c.executescript() here, but it seems that
- # executescript always commits, even if you set
- # isolation_level = None, so if we want to properly handle
- # exclusive locking and rollbacks between processes, we need
- # to execute it line-by-line
- for line in schema_content.split(';'):
- if not line:
- continue
- c.execute(line)
- # add extra fields
- self._extra_schema_init(c)
- # A unique identifier should be set for this replica. Implementations
- # don't have to strictly use uuid here, but we do want the uid to be
- # unique amongst all databases that will sync with each other.
- # We might extend this to using something with hostname for easier
- # debugging.
- self._set_replica_uid_in_transaction(uuid.uuid4().hex)
- c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)",
- (self._index_storage_value,))
-
- def _ensure_schema(self):
- """Ensure that the database schema has been created."""
- old_isolation_level = self._db_handle.isolation_level
- c = self._db_handle.cursor()
- if self._is_initialized(c):
- return
- try:
- # autocommit/own mgmt of transactions
- self._db_handle.isolation_level = None
- with self._db_handle:
- # only one execution path should initialize the db
- c.execute("begin exclusive")
- if self._is_initialized(c):
- return
- self._initialize(c)
- finally:
- self._db_handle.isolation_level = old_isolation_level
-
- def _extra_schema_init(self, c):
- """Add any extra fields, etc to the basic table definitions."""
-
- def _parse_index_definition(self, index_field):
- """Parse a field definition for an index, returning a Getter."""
- # Note: We may want to keep a Parser object around, and cache the
- # Getter objects for a greater length of time. Specifically, if
- # you create a bunch of indexes, and then insert 50k docs, you'll
- # re-parse the indexes between puts. The time to insert the docs
- # is still likely to dominate put_doc time, though.
- parser = query_parser.Parser()
- getter = parser.parse(index_field)
- return getter
-
- def _update_indexes(self, doc_id, raw_doc, getters, db_cursor):
- """Update document_fields for a single document.
-
- :param doc_id: Identifier for this document
- :param raw_doc: The python dict representation of the document.
- :param getters: A list of [(field_name, Getter)]. Getter.get will be
- called to evaluate the index definition for this document, and the
- results will be inserted into the db.
- :param db_cursor: An sqlite Cursor.
- :return: None
- """
- values = []
- for field_name, getter in getters:
- for idx_value in getter.get(raw_doc):
- values.append((doc_id, field_name, idx_value))
- if values:
- db_cursor.executemany(
- "INSERT INTO document_fields VALUES (?, ?, ?)", values)
-
- def _set_replica_uid(self, replica_uid):
- """Force the replica_uid to be set."""
- with self._db_handle:
- self._set_replica_uid_in_transaction(replica_uid)
-
- def _set_replica_uid_in_transaction(self, replica_uid):
- """Set the replica_uid. A transaction should already be held."""
- c = self._db_handle.cursor()
- c.execute("INSERT OR REPLACE INTO u1db_config"
- " VALUES ('replica_uid', ?)",
- (replica_uid,))
- self._real_replica_uid = replica_uid
-
- def _get_replica_uid(self):
- if self._real_replica_uid is not None:
- return self._real_replica_uid
- c = self._db_handle.cursor()
- c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'")
- val = c.fetchone()
- if val is None:
- return None
- self._real_replica_uid = val[0]
- return self._real_replica_uid
-
- _replica_uid = property(_get_replica_uid)
-
- def _get_generation(self):
- c = self._db_handle.cursor()
- c.execute('SELECT max(generation) FROM transaction_log')
- val = c.fetchone()[0]
- if val is None:
- return 0
- return val
-
- def _get_generation_info(self):
- c = self._db_handle.cursor()
- c.execute(
- 'SELECT max(generation), transaction_id FROM transaction_log ')
- val = c.fetchone()
- if val[0] is None:
- return(0, '')
- return val
-
- def _get_trans_id_for_gen(self, generation):
- if generation == 0:
- return ''
- c = self._db_handle.cursor()
- c.execute(
- 'SELECT transaction_id FROM transaction_log WHERE generation = ?',
- (generation,))
- val = c.fetchone()
- if val is None:
- raise errors.InvalidGeneration
- return val[0]
-
- def _get_transaction_log(self):
- c = self._db_handle.cursor()
- c.execute("SELECT doc_id, transaction_id FROM transaction_log"
- " ORDER BY generation")
- return c.fetchall()
-
- def _get_doc(self, doc_id, check_for_conflicts=False):
- """Get just the document content, without fancy handling."""
- c = self._db_handle.cursor()
- if check_for_conflicts:
- c.execute(
- "SELECT document.doc_rev, document.content, "
- "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN "
- "conflicts ON conflicts.doc_id = document.doc_id WHERE "
- "document.doc_id = ? GROUP BY document.doc_id, "
- "document.doc_rev, document.content;", (doc_id,))
- else:
- c.execute(
- "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?",
- (doc_id,))
- val = c.fetchone()
- if val is None:
- return None
- doc_rev, content, conflicts = val
- doc = self._factory(doc_id, doc_rev, content)
- doc.has_conflicts = conflicts > 0
- return doc
-
- def _has_conflicts(self, doc_id):
- c = self._db_handle.cursor()
- c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1",
- (doc_id,))
- val = c.fetchone()
- if val is None:
- return False
- else:
- return True
-
- def get_doc(self, doc_id, include_deleted=False):
- doc = self._get_doc(doc_id, check_for_conflicts=True)
- if doc is None:
- return None
- if doc.is_tombstone() and not include_deleted:
- return None
- return doc
-
- def get_all_docs(self, include_deleted=False):
- """Get all documents from the database."""
- generation = self._get_generation()
- results = []
- c = self._db_handle.cursor()
- c.execute(
- "SELECT document.doc_id, document.doc_rev, document.content, "
- "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts "
- "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, "
- "document.doc_rev, document.content;")
- rows = c.fetchall()
- for doc_id, doc_rev, content, conflicts in rows:
- if content is None and not include_deleted:
- continue
- doc = self._factory(doc_id, doc_rev, content)
- doc.has_conflicts = conflicts > 0
- results.append(doc)
- return (generation, results)
-
- def put_doc(self, doc):
- if doc.doc_id is None:
- raise errors.InvalidDocId()
- self._check_doc_id(doc.doc_id)
- self._check_doc_size(doc)
- with self._db_handle:
- old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
- if old_doc and old_doc.has_conflicts:
- raise errors.ConflictedDoc()
- if old_doc and doc.rev is None and old_doc.is_tombstone():
- new_rev = self._allocate_doc_rev(old_doc.rev)
- else:
- if old_doc is not None:
- if old_doc.rev != doc.rev:
- raise errors.RevisionConflict()
- else:
- if doc.rev is not None:
- raise errors.RevisionConflict()
- new_rev = self._allocate_doc_rev(doc.rev)
- doc.rev = new_rev
- self._put_and_update_indexes(old_doc, doc)
- return new_rev
-
- def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none):
- """Convert a dict representation into named fields.
-
- So something like: {'key1': 'val1', 'key2': 'val2'}
- gets converted into: [(doc_id, 'key1', 'val1', 0)
- (doc_id, 'key2', 'val2', 0)]
- :param doc_id: Just added to every record.
- :param base_field: if set, these are nested keys, so each field should
- be appropriately prefixed.
- :param raw_doc: The python dictionary.
- """
- # TODO: Handle lists
- values = []
- for field_name, value in raw_doc.iteritems():
- if value is None and not save_none:
- continue
- if base_field:
- full_name = base_field + '.' + field_name
- else:
- full_name = field_name
- if value is None or isinstance(value, (int, float, basestring)):
- values.append((doc_id, full_name, value, len(values)))
- else:
- subvalues = self._expand_to_fields(doc_id, full_name, value,
- save_none)
- for _, subfield_name, val, _ in subvalues:
- values.append((doc_id, subfield_name, val, len(values)))
- return values
-
- def _put_and_update_indexes(self, old_doc, doc):
- """Actually insert a document into the database.
-
- This both updates the existing documents content, and any indexes that
- refer to this document.
- """
- raise NotImplementedError(self._put_and_update_indexes)
-
- def whats_changed(self, old_generation=0):
- c = self._db_handle.cursor()
- c.execute("SELECT generation, doc_id, transaction_id"
- " FROM transaction_log"
- " WHERE generation > ? ORDER BY generation DESC",
- (old_generation,))
- results = c.fetchall()
- cur_gen = old_generation
- seen = set()
- changes = []
- newest_trans_id = ''
- for generation, doc_id, trans_id in results:
- if doc_id not in seen:
- changes.append((doc_id, generation, trans_id))
- seen.add(doc_id)
- if changes:
- cur_gen = changes[0][1] # max generation
- newest_trans_id = changes[0][2]
- changes.reverse()
- else:
- c.execute("SELECT generation, transaction_id"
- " FROM transaction_log ORDER BY generation DESC LIMIT 1")
- results = c.fetchone()
- if not results:
- cur_gen = 0
- newest_trans_id = ''
- else:
- cur_gen, newest_trans_id = results
-
- return cur_gen, newest_trans_id, changes
-
- def delete_doc(self, doc):
- with self._db_handle:
- old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
- if old_doc is None:
- raise errors.DocumentDoesNotExist
- if old_doc.rev != doc.rev:
- raise errors.RevisionConflict()
- if old_doc.is_tombstone():
- raise errors.DocumentAlreadyDeleted
- if old_doc.has_conflicts:
- raise errors.ConflictedDoc()
- new_rev = self._allocate_doc_rev(doc.rev)
- doc.rev = new_rev
- doc.make_tombstone()
- self._put_and_update_indexes(old_doc, doc)
- return new_rev
-
- def _get_conflicts(self, doc_id):
- c = self._db_handle.cursor()
- c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?",
- (doc_id,))
- return [self._factory(doc_id, doc_rev, content)
- for doc_rev, content in c.fetchall()]
-
- def get_doc_conflicts(self, doc_id):
- with self._db_handle:
- conflict_docs = self._get_conflicts(doc_id)
- if not conflict_docs:
- return []
- this_doc = self._get_doc(doc_id)
- this_doc.has_conflicts = True
- return [this_doc] + conflict_docs
-
- def _get_replica_gen_and_trans_id(self, other_replica_uid):
- c = self._db_handle.cursor()
- c.execute("SELECT known_generation, known_transaction_id FROM sync_log"
- " WHERE replica_uid = ?",
- (other_replica_uid,))
- val = c.fetchone()
- if val is None:
- other_gen = 0
- trans_id = ''
- else:
- other_gen = val[0]
- trans_id = val[1]
- return other_gen, trans_id
-
- def _set_replica_gen_and_trans_id(self, other_replica_uid,
- other_generation, other_transaction_id):
- with self._db_handle:
- self._do_set_replica_gen_and_trans_id(
- other_replica_uid, other_generation, other_transaction_id)
-
- def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
- other_generation,
- other_transaction_id):
- c = self._db_handle.cursor()
- c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)",
- (other_replica_uid, other_generation,
- other_transaction_id))
-
- def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None,
- replica_gen=None, replica_trans_id=None):
- return super(SQLiteDatabase, self)._put_doc_if_newer(
- doc,
- save_conflict=save_conflict,
- replica_uid=replica_uid, replica_gen=replica_gen,
- replica_trans_id=replica_trans_id)
-
- def _add_conflict(self, c, doc_id, my_doc_rev, my_content):
- c.execute("INSERT INTO conflicts VALUES (?, ?, ?)",
- (doc_id, my_doc_rev, my_content))
-
- def _delete_conflicts(self, c, doc, conflict_revs):
- deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs]
- c.executemany("DELETE FROM conflicts"
- " WHERE doc_id=? AND doc_rev=?", deleting)
- doc.has_conflicts = self._has_conflicts(doc.doc_id)
-
- def _prune_conflicts(self, doc, doc_vcr):
- if self._has_conflicts(doc.doc_id):
- autoresolved = False
- c_revs_to_prune = []
- for c_doc in self._get_conflicts(doc.doc_id):
- c_vcr = vectorclock.VectorClockRev(c_doc.rev)
- if doc_vcr.is_newer(c_vcr):
- c_revs_to_prune.append(c_doc.rev)
- elif doc.same_content_as(c_doc):
- c_revs_to_prune.append(c_doc.rev)
- doc_vcr.maximize(c_vcr)
- autoresolved = True
- if autoresolved:
- doc_vcr.increment(self._replica_uid)
- doc.rev = doc_vcr.as_str()
- c = self._db_handle.cursor()
- self._delete_conflicts(c, doc, c_revs_to_prune)
-
- def _force_doc_sync_conflict(self, doc):
- my_doc = self._get_doc(doc.doc_id)
- c = self._db_handle.cursor()
- self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
- self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json())
- doc.has_conflicts = True
- self._put_and_update_indexes(my_doc, doc)
-
- def resolve_doc(self, doc, conflicted_doc_revs):
- with self._db_handle:
- cur_doc = self._get_doc(doc.doc_id)
- # TODO: https://bugs.launchpad.net/u1db/+bug/928274
- # I think we have a logic bug in resolve_doc
- # Specifically, cur_doc.rev is always in the final vector
- # clock of revisions that we supersede, even if it wasn't in
- # conflicted_doc_revs. We still add it as a conflict, but the
- # fact that _put_doc_if_newer propagates resolutions means I
- # think that conflict could accidentally be resolved. We need
- # to add a test for this case first. (create a rev, create a
- # conflict, create another conflict, resolve the first rev
- # and first conflict, then make sure that the resolved
- # rev doesn't supersede the second conflict rev.) It *might*
- # not matter, because the superseding rev is in as a
- # conflict, but it does seem incorrect
- new_rev = self._ensure_maximal_rev(cur_doc.rev,
- conflicted_doc_revs)
- superseded_revs = set(conflicted_doc_revs)
- c = self._db_handle.cursor()
- doc.rev = new_rev
- if cur_doc.rev in superseded_revs:
- self._put_and_update_indexes(cur_doc, doc)
- else:
- self._add_conflict(c, doc.doc_id, new_rev, doc.get_json())
- # TODO: Is there some way that we could construct a rev that would
- # end up in superseded_revs, such that we add a conflict, and
- # then immediately delete it?
- self._delete_conflicts(c, doc, superseded_revs)
-
- def list_indexes(self):
- """Return the list of indexes and their definitions."""
- c = self._db_handle.cursor()
- # TODO: How do we test the ordering?
- c.execute("SELECT name, field FROM index_definitions"
- " ORDER BY name, offset")
- definitions = []
- cur_name = None
- for name, field in c.fetchall():
- if cur_name != name:
- definitions.append((name, []))
- cur_name = name
- definitions[-1][-1].append(field)
- return definitions
-
- def _get_index_definition(self, index_name):
- """Return the stored definition for a given index_name."""
- c = self._db_handle.cursor()
- c.execute("SELECT field FROM index_definitions"
- " WHERE name = ? ORDER BY offset", (index_name,))
- fields = [x[0] for x in c.fetchall()]
- if not fields:
- raise errors.IndexDoesNotExist
- return fields
-
- @staticmethod
- def _strip_glob(value):
- """Remove the trailing * from a value."""
- assert value[-1] == '*'
- return value[:-1]
-
- def _format_query(self, definition, key_values):
- # First, build the definition. We join the document_fields table
- # against itself, as many times as the 'width' of our definition.
- # We then do a query for each key_value, one-at-a-time.
- # Note: All of these strings are static, we could cache them, etc.
- tables = ["document_fields d%d" % i for i in range(len(definition))]
- novalue_where = ["d.doc_id = d%d.doc_id"
- " AND d%d.field_name = ?"
- % (i, i) for i in range(len(definition))]
- wildcard_where = [novalue_where[i] +
- (" AND d%d.value NOT NULL" % (i,))
- for i in range(len(definition))]
- exact_where = [novalue_where[i] +
- (" AND d%d.value = ?" % (i,))
- for i in range(len(definition))]
- like_where = [novalue_where[i] +
- (" AND d%d.value GLOB ?" % (i,))
- for i in range(len(definition))]
- is_wildcard = False
- # Merge the lists together, so that:
- # [field1, field2, field3], [val1, val2, val3]
- # Becomes:
- # (field1, val1, field2, val2, field3, val3)
- args = []
- where = []
- for idx, (field, value) in enumerate(zip(definition, key_values)):
- args.append(field)
- if value.endswith('*'):
- if value == '*':
- where.append(wildcard_where[idx])
- else:
- # This is a glob match
- if is_wildcard:
- # We can't have a partial wildcard following
- # another wildcard
- raise errors.InvalidGlobbing
- where.append(like_where[idx])
- args.append(value)
- is_wildcard = True
- else:
- if is_wildcard:
- raise errors.InvalidGlobbing
- where.append(exact_where[idx])
- args.append(value)
- statement = (
- "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
- "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
- "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
- "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
- ['d%d.value' % i for i in range(len(definition))])))
- return statement, args
-
- def get_from_index(self, index_name, *key_values):
- definition = self._get_index_definition(index_name)
- if len(key_values) != len(definition):
- raise errors.InvalidValueForIndex()
- statement, args = self._format_query(definition, key_values)
- c = self._db_handle.cursor()
- try:
- c.execute(statement, tuple(args))
- except dbapi2.OperationalError as e:
- raise dbapi2.OperationalError(
- str(e) +
- '\nstatement: %s\nargs: %s\n' % (statement, args))
- res = c.fetchall()
- results = []
- for row in res:
- doc = self._factory(row[0], row[1], row[2])
- doc.has_conflicts = row[3] > 0
- results.append(doc)
- return results
-
- def _format_range_query(self, definition, start_value, end_value):
- tables = ["document_fields d%d" % i for i in range(len(definition))]
- novalue_where = [
- "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
- range(len(definition))]
- wildcard_where = [
- novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
- range(len(definition))]
- like_where = [
- novalue_where[i] + (
- " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in
- range(len(definition))]
- range_where_lower = [
- novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in
- range(len(definition))]
- range_where_upper = [
- novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in
- range(len(definition))]
- args = []
- where = []
- if start_value:
- if isinstance(start_value, basestring):
- start_value = (start_value,)
- if len(start_value) != len(definition):
- raise errors.InvalidValueForIndex()
- is_wildcard = False
- for idx, (field, value) in enumerate(zip(definition, start_value)):
- args.append(field)
- if value.endswith('*'):
- if value == '*':
- where.append(wildcard_where[idx])
- else:
- # This is a glob match
- if is_wildcard:
- # We can't have a partial wildcard following
- # another wildcard
- raise errors.InvalidGlobbing
- where.append(range_where_lower[idx])
- args.append(self._strip_glob(value))
- is_wildcard = True
- else:
- if is_wildcard:
- raise errors.InvalidGlobbing
- where.append(range_where_lower[idx])
- args.append(value)
- if end_value:
- if isinstance(end_value, basestring):
- end_value = (end_value,)
- if len(end_value) != len(definition):
- raise errors.InvalidValueForIndex()
- is_wildcard = False
- for idx, (field, value) in enumerate(zip(definition, end_value)):
- args.append(field)
- if value.endswith('*'):
- if value == '*':
- where.append(wildcard_where[idx])
- else:
- # This is a glob match
- if is_wildcard:
- # We can't have a partial wildcard following
- # another wildcard
- raise errors.InvalidGlobbing
- where.append(like_where[idx])
- args.append(self._strip_glob(value))
- args.append(value)
- is_wildcard = True
- else:
- if is_wildcard:
- raise errors.InvalidGlobbing
- where.append(range_where_upper[idx])
- args.append(value)
- statement = (
- "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
- "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
- "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
- "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
- ['d%d.value' % i for i in range(len(definition))])))
- return statement, args
-
- def get_range_from_index(self, index_name, start_value=None,
- end_value=None):
- """Return all documents with key values in the specified range."""
- definition = self._get_index_definition(index_name)
- statement, args = self._format_range_query(
- definition, start_value, end_value)
- c = self._db_handle.cursor()
- try:
- c.execute(statement, tuple(args))
- except dbapi2.OperationalError as e:
- raise dbapi2.OperationalError(
- str(e) +
- '\nstatement: %s\nargs: %s\n' % (statement, args))
- res = c.fetchall()
- results = []
- for row in res:
- doc = self._factory(row[0], row[1], row[2])
- doc.has_conflicts = row[3] > 0
- results.append(doc)
- return results
-
- def get_index_keys(self, index_name):
- c = self._db_handle.cursor()
- definition = self._get_index_definition(index_name)
- value_fields = ', '.join([
- 'd%d.value' % i for i in range(len(definition))])
- tables = ["document_fields d%d" % i for i in range(len(definition))]
- novalue_where = [
- "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
- range(len(definition))]
- where = [
- novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
- range(len(definition))]
- statement = (
- "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % (
- value_fields, ', '.join(tables), ' AND '.join(where),
- value_fields))
- try:
- c.execute(statement, tuple(definition))
- except dbapi2.OperationalError as e:
- raise dbapi2.OperationalError(
- str(e) +
- '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition)))
- return c.fetchall()
-
- def delete_index(self, index_name):
- with self._db_handle:
- c = self._db_handle.cursor()
- c.execute("DELETE FROM index_definitions WHERE name = ?",
- (index_name,))
- c.execute(
- "DELETE FROM document_fields WHERE document_fields.field_name "
- " NOT IN (SELECT field from index_definitions)")
-
-
-class SQLiteSyncTarget(CommonSyncTarget):
-
- def get_sync_info(self, source_replica_uid):
- source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
- source_replica_uid)
- my_gen, my_trans_id = self._db._get_generation_info()
- return (
- self._db._replica_uid, my_gen, my_trans_id, source_gen,
- source_trans_id)
-
- def record_sync_info(self, source_replica_uid, source_replica_generation,
- source_replica_transaction_id):
- if self._trace_hook:
- self._trace_hook('record_sync_info')
- self._db._set_replica_gen_and_trans_id(
- source_replica_uid, source_replica_generation,
- source_replica_transaction_id)
-
-
-class SQLitePartialExpandDatabase(SQLiteDatabase):
- """An SQLite Backend that expands documents into a document_field table.
-
- It stores the original document text in document.doc. For fields that are
- indexed, the data goes into document_fields.
- """
-
- _index_storage_value = 'expand referenced'
-
- def _get_indexed_fields(self):
- """Determine what fields are indexed."""
- c = self._db_handle.cursor()
- c.execute("SELECT field FROM index_definitions")
- return set([x[0] for x in c.fetchall()])
-
- def _evaluate_index(self, raw_doc, field):
- parser = query_parser.Parser()
- getter = parser.parse(field)
- return getter.get(raw_doc)
-
- def _put_and_update_indexes(self, old_doc, doc):
- c = self._db_handle.cursor()
- if doc and not doc.is_tombstone():
- raw_doc = json.loads(doc.get_json())
- else:
- raw_doc = {}
- if old_doc is not None:
- c.execute("UPDATE document SET doc_rev=?, content=?"
- " WHERE doc_id = ?",
- (doc.rev, doc.get_json(), doc.doc_id))
- c.execute("DELETE FROM document_fields WHERE doc_id = ?",
- (doc.doc_id,))
- else:
- c.execute("INSERT INTO document (doc_id, doc_rev, content)"
- " VALUES (?, ?, ?)",
- (doc.doc_id, doc.rev, doc.get_json()))
- indexed_fields = self._get_indexed_fields()
- if indexed_fields:
- # It is expected that len(indexed_fields) is shorter than
- # len(raw_doc)
- getters = [(field, self._parse_index_definition(field))
- for field in indexed_fields]
- self._update_indexes(doc.doc_id, raw_doc, getters, c)
- trans_id = self._allocate_transaction_id()
- c.execute("INSERT INTO transaction_log(doc_id, transaction_id)"
- " VALUES (?, ?)", (doc.doc_id, trans_id))
-
- def create_index(self, index_name, *index_expressions):
- with self._db_handle:
- c = self._db_handle.cursor()
- cur_fields = self._get_indexed_fields()
- definition = [(index_name, idx, field)
- for idx, field in enumerate(index_expressions)]
- try:
- c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
- definition)
- except dbapi2.IntegrityError as e:
- stored_def = self._get_index_definition(index_name)
- if stored_def == [x[-1] for x in definition]:
- return
- raise errors.IndexNameTakenError(
- str(e) +
- str(sys.exc_info()[2])
- )
- new_fields = set(
- [f for f in index_expressions if f not in cur_fields])
- if new_fields:
- self._update_all_indexes(new_fields)
-
- def _iter_all_docs(self):
- c = self._db_handle.cursor()
- c.execute("SELECT doc_id, content FROM document")
- while True:
- next_rows = c.fetchmany()
- if not next_rows:
- break
- for row in next_rows:
- yield row
-
- def _update_all_indexes(self, new_fields):
- """Iterate all the documents, and add content to document_fields.
-
- :param new_fields: The index definitions that need to be added.
- """
- getters = [(field, self._parse_index_definition(field))
- for field in new_fields]
- c = self._db_handle.cursor()
- for doc_id, doc in self._iter_all_docs():
- if doc is None:
- continue
- raw_doc = json.loads(doc)
- self._update_indexes(doc_id, raw_doc, getters, c)
-
-
-SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase)
diff --git a/common/src/leap/soledad/common/l2db/remote/server_state.py b/common/src/leap/soledad/common/l2db/remote/server_state.py
index e20b4679..89ac5742 100644
--- a/common/src/leap/soledad/common/l2db/remote/server_state.py
+++ b/common/src/leap/soledad/common/l2db/remote/server_state.py
@@ -42,10 +42,9 @@ class ServerState(object):
def open_database(self, path):
"""Open a database at the given location."""
- from u1db.backends import sqlite_backend
+ from leap.soledad.client._database import sqlite
full_path = self._relpath(path)
- return sqlite_backend.SQLiteDatabase.open_database(full_path,
- create=False)
+ return sqlite.SQLiteDatabase.open_database(full_path, create=False)
def check_database(self, path):
"""Check if the database at the given location exists.
@@ -57,14 +56,13 @@ class ServerState(object):
def ensure_database(self, path):
"""Ensure database at the given location."""
- from u1db.backends import sqlite_backend
+ from leap.soledad.client._database import sqlite
full_path = self._relpath(path)
- db = sqlite_backend.SQLiteDatabase.open_database(full_path,
- create=True)
+ db = sqlite.SQLiteDatabase.open_database(full_path, create=True)
return db, db._replica_uid
def delete_database(self, path):
"""Delete database at the given location."""
- from u1db.backends import sqlite_backend
+ from leap.soledad.client._database import sqlite
full_path = self._relpath(path)
- sqlite_backend.SQLiteDatabase.delete_database(full_path)
+ sqlite.SQLiteDatabase.delete_database(full_path)