diff options
Diffstat (limited to 'src/leap/soledad/common/l2db')
-rw-r--r-- | src/leap/soledad/common/l2db/__init__.py | 694 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/backends/__init__.py | 204 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/backends/inmemory.py | 466 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/errors.py | 194 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/query_parser.py | 371 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/__init__.py | 15 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/http_app.py | 660 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/http_client.py | 178 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/http_database.py | 158 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/http_errors.py | 48 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/http_target.py | 125 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/server_state.py | 68 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/ssl_match_hostname.py | 65 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/utils.py | 23 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/sync.py | 311 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/vectorclock.py | 89 |
16 files changed, 3669 insertions, 0 deletions
diff --git a/src/leap/soledad/common/l2db/__init__.py b/src/leap/soledad/common/l2db/__init__.py new file mode 100644 index 00000000..43d61b1d --- /dev/null +++ b/src/leap/soledad/common/l2db/__init__.py @@ -0,0 +1,694 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""L2DB""" + +import json + +from leap.soledad.common.l2db.errors import InvalidJSON, InvalidContent + +__version_info__ = (13, 9) +__version__ = '.'.join(map(lambda x: '%02d' % x, __version_info__)) + + +def open(path, create, document_factory=None): + """Open a database at the given location. + + Will raise u1db.errors.DatabaseDoesNotExist if create=False and the + database does not already exist. + + :param path: The filesystem path for the database to open. + :param create: True/False, should the database be created if it doesn't + already exist? + :param document_factory: A function that will be called with the same + parameters as Document.__init__. + :return: An instance of Database. + """ + from leap.soledad.client._db import sqlite + return sqlite.SQLiteDatabase.open_database( + path, create=create, document_factory=document_factory) + + +# constraints on database names (relevant for remote access, as regex) +DBNAME_CONSTRAINTS = r"[a-zA-Z0-9][a-zA-Z0-9.-]*" + +# constraints on doc ids (as regex) +# (no slashes, and no characters outside the ascii range) +DOC_ID_CONSTRAINTS = r"[a-zA-Z0-9.%_-]+" + + +class Database(object): + """A JSON Document data store. + + This data store can be synchronized with other u1db.Database instances. + """ + + def set_document_factory(self, factory): + """Set the document factory that will be used to create objects to be + returned as documents by the database. + + :param factory: A function that returns an object which at minimum must + satisfy the same interface as does the class DocumentBase. + Subclassing that class is the easiest way to create such + a function. + """ + raise NotImplementedError(self.set_document_factory) + + def set_document_size_limit(self, limit): + """Set the maximum allowed document size for this database. + + :param limit: Maximum allowed document size in bytes. + """ + raise NotImplementedError(self.set_document_size_limit) + + def whats_changed(self, old_generation=0): + """Return a list of documents that have changed since old_generation. + This allows APPS to only store a db generation before going + 'offline', and then when coming back online they can use this + data to update whatever extra data they are storing. + + :param old_generation: The generation of the database in the old + state. + :return: (generation, trans_id, [(doc_id, generation, trans_id),...]) + The current generation of the database, its associated transaction + id, and a list of of changed documents since old_generation, + represented by tuples with for each document its doc_id and the + generation and transaction id corresponding to the last intervening + change and sorted by generation (old changes first) + """ + raise NotImplementedError(self.whats_changed) + + def get_doc(self, doc_id, include_deleted=False): + """Get the JSON string for the given document. + + :param doc_id: The unique document identifier + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise asking for a deleted + document will return None. + :return: a Document object. + """ + raise NotImplementedError(self.get_doc) + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + """Get the JSON content for many documents. + + :param doc_ids: A list of document identifiers. + :param check_for_conflicts: If set to False, then the conflict check + will be skipped, and 'None' will be returned instead of True/False. + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise deleted documents will not + be included in the results. + :return: iterable giving the Document object for each document id + in matching doc_ids order. + """ + raise NotImplementedError(self.get_docs) + + def get_all_docs(self, include_deleted=False): + """Get the JSON content for all documents in the database. + + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise deleted documents will not + be included in the results. + :return: (generation, [Document]) + The current generation of the database, followed by a list of all + the documents in the database. + """ + raise NotImplementedError(self.get_all_docs) + + def create_doc(self, content, doc_id=None): + """Create a new document. + + You can optionally specify the document identifier, but the document + must not already exist. See 'put_doc' if you want to override an + existing document. + If the database specifies a maximum document size and the document + exceeds it, create will fail and raise a DocumentTooBig exception. + + :param content: A Python dictionary. + :param doc_id: An optional identifier specifying the document id. + :return: Document + """ + raise NotImplementedError(self.create_doc) + + def create_doc_from_json(self, json, doc_id=None): + """Create a new document. + + You can optionally specify the document identifier, but the document + must not already exist. See 'put_doc' if you want to override an + existing document. + If the database specifies a maximum document size and the document + exceeds it, create will fail and raise a DocumentTooBig exception. + + :param json: The JSON document string + :param doc_id: An optional identifier specifying the document id. + :return: Document + """ + raise NotImplementedError(self.create_doc_from_json) + + def put_doc(self, doc): + """Update a document. + If the document currently has conflicts, put will fail. + If the database specifies a maximum document size and the document + exceeds it, put will fail and raise a DocumentTooBig exception. + + :param doc: A Document with new content. + :return: new_doc_rev - The new revision identifier for the document. + The Document object will also be updated. + """ + raise NotImplementedError(self.put_doc) + + def delete_doc(self, doc): + """Mark a document as deleted. + Will abort if the current revision doesn't match doc.rev. + This will also set doc.content to None. + """ + raise NotImplementedError(self.delete_doc) + + def create_index(self, index_name, *index_expressions): + """Create an named index, which can then be queried for future lookups. + Creating an index which already exists is not an error, and is cheap. + Creating an index which does not match the index_expressions of the + existing index is an error. + Creating an index will block until the expressions have been evaluated + and the index generated. + + :param index_name: A unique name which can be used as a key prefix + :param index_expressions: index expressions defining the index + information. + + Examples: + + "fieldname", or "fieldname.subfieldname" to index alphabetically + sorted on the contents of a field. + + "number(fieldname, width)", "lower(fieldname)" + """ + raise NotImplementedError(self.create_index) + + def delete_index(self, index_name): + """Remove a named index. + + :param index_name: The name of the index we are removing + """ + raise NotImplementedError(self.delete_index) + + def list_indexes(self): + """List the definitions of all known indexes. + + :return: A list of [('index-name', ['field', 'field2'])] definitions. + """ + raise NotImplementedError(self.list_indexes) + + def get_from_index(self, index_name, *key_values): + """Return documents that match the keys supplied. + + You must supply exactly the same number of values as have been defined + in the index. It is possible to do a prefix match by using '*' to + indicate a wildcard match. You can only supply '*' to trailing entries, + (eg 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) + It is also possible to append a '*' to the last supplied value (eg + 'val*', '*', '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + + :param index_name: The index to query + :param key_values: values to match. eg, if you have + an index with 3 fields then you would have: + get_from_index(index_name, val1, val2, val3) + :return: List of [Document] + """ + raise NotImplementedError(self.get_from_index) + + def get_range_from_index(self, index_name, start_value, end_value): + """Return documents that fall within the specified range. + + Both ends of the range are inclusive. For both start_value and + end_value, one must supply exactly the same number of values as have + been defined in the index, or pass None. In case of a single column + index, a string is accepted as an alternative for a tuple with a single + value. It is possible to do a prefix match by using '*' to indicate + a wildcard match. You can only supply '*' to trailing entries, (eg + 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) It is also + possible to append a '*' to the last supplied value (eg 'val*', '*', + '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + + :param index_name: The index to query + :param start_values: tuples of values that define the lower bound of + the range. eg, if you have an index with 3 fields then you would + have: (val1, val2, val3) + :param end_values: tuples of values that define the upper bound of the + range. eg, if you have an index with 3 fields then you would have: + (val1, val2, val3) + :return: List of [Document] + """ + raise NotImplementedError(self.get_range_from_index) + + def get_index_keys(self, index_name): + """Return all keys under which documents are indexed in this index. + + :param index_name: The index to query + :return: [] A list of tuples of indexed keys. + """ + raise NotImplementedError(self.get_index_keys) + + def get_doc_conflicts(self, doc_id): + """Get the list of conflicts for the given document. + + The order of the conflicts is such that the first entry is the value + that would be returned by "get_doc". + + :return: [doc] A list of the Document entries that are conflicted. + """ + raise NotImplementedError(self.get_doc_conflicts) + + def resolve_doc(self, doc, conflicted_doc_revs): + """Mark a document as no longer conflicted. + + We take the list of revisions that the client knows about that it is + superseding. This may be a different list from the actual current + conflicts, in which case only those are removed as conflicted. This + may fail if the conflict list is significantly different from the + supplied information. (sync could have happened in the background from + the time you GET_DOC_CONFLICTS until the point where you RESOLVE) + + :param doc: A Document with the new content to be inserted. + :param conflicted_doc_revs: A list of revisions that the new content + supersedes. + """ + raise NotImplementedError(self.resolve_doc) + + def get_sync_target(self): + """Return a SyncTarget object, for another u1db to synchronize with. + + :return: An instance of SyncTarget. + """ + raise NotImplementedError(self.get_sync_target) + + def close(self): + """Release any resources associated with this database.""" + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + """Synchronize documents with remote replica exposed at url. + + :param url: the url of the target replica to sync with. + :param creds: optional dictionary giving credentials + to authorize the operation with the server. For using OAuth + the form of creds is: + {'oauth': { + 'consumer_key': ..., + 'consumer_secret': ..., + 'token_key': ..., + 'token_secret': ... + }} + :param autocreate: ask the target to create the db if non-existent. + :return: local_gen_before_sync The local generation before the + synchronisation was performed. This is useful to pass into + whatschanged, if an application wants to know which documents were + affected by a synchronisation. + """ + from u1db.sync import Synchronizer + from u1db.remote.http_target import HTTPSyncTarget + return Synchronizer(self, HTTPSyncTarget(url, creds=creds)).sync( + autocreate=autocreate) + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + """Return the last known generation and transaction id for the other db + replica. + + When you do a synchronization with another replica, the Database keeps + track of what generation the other database replica was at, and what + the associated transaction id was. This is used to determine what data + needs to be sent, and if two databases are claiming to be the same + replica. + + :param other_replica_uid: The identifier for the other replica. + :return: (gen, trans_id) The generation and transaction id we + encountered during synchronization. If we've never synchronized + with the replica, this is (0, ''). + """ + raise NotImplementedError(self._get_replica_gen_and_trans_id) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + """Set the last-known generation and transaction id for the other + database replica. + + We have just performed some synchronization, and we want to track what + generation the other replica was at. See also + _get_replica_gen_and_trans_id. + :param other_replica_uid: The U1DB identifier for the other replica. + :param other_generation: The generation number for the other replica. + :param other_transaction_id: The transaction id associated with the + generation. + """ + raise NotImplementedError(self._set_replica_gen_and_trans_id) + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, + replica_trans_id=''): + """Insert/update document into the database with a given revision. + + This api is used during synchronization operations. + + If a document would conflict and save_conflict is set to True, the + content will be selected as the 'current' content for doc.doc_id, + even though doc.rev doesn't supersede the currently stored revision. + The currently stored document will be added to the list of conflict + alternatives for the given doc_id. + + This forces the new content to be 'current' so that we get convergence + after synchronizing, even if people don't resolve conflicts. Users can + then notice that their content is out of date, update it, and + synchronize again. (The alternative is that users could synchronize and + think the data has propagated, but their local copy looks fine, and the + remote copy is never updated again.) + + :param doc: A Document object + :param save_conflict: If this document is a conflict, do you want to + save it as a conflict, or just ignore it. + :param replica_uid: A unique replica identifier. + :param replica_gen: The generation of the replica corresponding to the + this document. The replica arguments are optional, but are used + during synchronization. + :param replica_trans_id: The transaction_id associated with the + generation. + :return: (state, at_gen) - If we don't have doc_id already, + or if doc_rev supersedes the existing document revision, + then the content will be inserted, and state is 'inserted'. + If doc_rev is less than or equal to the existing revision, + then the put is ignored and state is respecitvely 'superseded' + or 'converged'. + If doc_rev is not strictly superseded or supersedes, then + state is 'conflicted'. The document will not be inserted if + save_conflict is False. + For 'inserted' or 'converged', at_gen is the insertion/current + generation. + """ + raise NotImplementedError(self._put_doc_if_newer) + + +class DocumentBase(object): + """Container for handling a single document. + + :ivar doc_id: Unique identifier for this document. + :ivar rev: The revision identifier of the document. + :ivar json_string: The JSON string for this document. + :ivar has_conflicts: Boolean indicating if this document has conflicts + """ + + def __init__(self, doc_id, rev, json_string, has_conflicts=False): + self.doc_id = doc_id + self.rev = rev + if json_string is not None: + try: + value = json.loads(json_string) + except ValueError: + raise InvalidJSON + if not isinstance(value, dict): + raise InvalidJSON + self._json = json_string + self.has_conflicts = has_conflicts + + def same_content_as(self, other): + """Compare the content of two documents.""" + if self._json: + c1 = json.loads(self._json) + else: + c1 = None + if other._json: + c2 = json.loads(other._json) + else: + c2 = None + return c1 == c2 + + def __repr__(self): + if self.has_conflicts: + extra = ', conflicted' + else: + extra = '' + return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id, + self.rev, extra, self.get_json()) + + def __hash__(self): + raise NotImplementedError(self.__hash__) + + def __eq__(self, other): + if not isinstance(other, Document): + return NotImplemented + return ( + self.doc_id == other.doc_id and self.rev == other.rev and + self.same_content_as(other) and self.has_conflicts == + other.has_conflicts) + + def __lt__(self, other): + """This is meant for testing, not part of the official api. + + It is implemented so that sorted([Document, Document]) can be used. + It doesn't imply that users would want their documents to be sorted in + this order. + """ + # Since this is just for testing, we don't worry about comparing + # against things that aren't a Document. + return ((self.doc_id, self.rev, self.get_json()) < + (other.doc_id, other.rev, other.get_json())) + + def get_json(self): + """Get the json serialization of this document.""" + if self._json is not None: + return self._json + return None + + def get_size(self): + """Calculate the total size of the document.""" + size = 0 + json = self.get_json() + if json: + size += len(json) + if self.rev: + size += len(self.rev) + if self.doc_id: + size += len(self.doc_id) + return size + + def set_json(self, json_string): + """Set the json serialization of this document.""" + if json_string is not None: + try: + value = json.loads(json_string) + except ValueError: + raise InvalidJSON + if not isinstance(value, dict): + raise InvalidJSON + self._json = json_string + + def make_tombstone(self): + """Make this document into a tombstone.""" + self._json = None + + def is_tombstone(self): + """Return True if the document is a tombstone, False otherwise.""" + if self._json is not None: + return False + return True + + +class Document(DocumentBase): + """Container for handling a single document. + + :ivar doc_id: Unique identifier for this document. + :ivar rev: The revision identifier of the document. + :ivar json: The JSON string for this document. + :ivar has_conflicts: Boolean indicating if this document has conflicts + """ + + # The following part of the API is optional: no implementation is forced to + # have it but if the language supports dictionaries/hashtables, it makes + # Documents a lot more user friendly. + + def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False): + # TODO: We convert the json in the superclass to check its validity so + # we might as well set _content here directly since the price is + # already being paid. + super(Document, self).__init__(doc_id, rev, json, has_conflicts) + self._content = None + + def same_content_as(self, other): + """Compare the content of two documents.""" + if self._json: + c1 = json.loads(self._json) + else: + c1 = self._content + if other._json: + c2 = json.loads(other._json) + else: + c2 = other._content + return c1 == c2 + + def get_json(self): + """Get the json serialization of this document.""" + json_string = super(Document, self).get_json() + if json_string is not None: + return json_string + if self._content is not None: + return json.dumps(self._content) + return None + + def set_json(self, json): + """Set the json serialization of this document.""" + self._content = None + super(Document, self).set_json(json) + + def make_tombstone(self): + """Make this document into a tombstone.""" + self._content = None + super(Document, self).make_tombstone() + + def is_tombstone(self): + """Return True if the document is a tombstone, False otherwise.""" + if self._content is not None: + return False + return super(Document, self).is_tombstone() + + def _get_content(self): + """Get the dictionary representing this document.""" + if self._json is not None: + self._content = json.loads(self._json) + self._json = None + if self._content is not None: + return self._content + return None + + def _set_content(self, content): + """Set the dictionary representing this document.""" + try: + tmp = json.dumps(content) + except TypeError: + raise InvalidContent( + "Can not be converted to JSON: %r" % (content,)) + if not tmp.startswith('{'): + raise InvalidContent( + "Can not be converted to a JSON object: %r." % (content,)) + # We might as well store the JSON at this point since we did the work + # of encoding it, and it doesn't lose any information. + self._json = tmp + self._content = None + + content = property( + _get_content, _set_content, doc="Content of the Document.") + + # End of optional part. + + +class SyncTarget(object): + """Functionality for using a Database as a synchronization target.""" + + def get_sync_info(self, source_replica_uid): + """Return information about known state. + + Return the replica_uid and the current database generation of this + database, and the last-seen database generation for source_replica_uid + + :param source_replica_uid: Another replica which we might have + synchronized with in the past. + :return: (target_replica_uid, target_replica_generation, + target_trans_id, source_replica_last_known_generation, + source_replica_last_known_transaction_id) + """ + raise NotImplementedError(self.get_sync_info) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + """Record tip information for another replica. + + After sync_exchange has been processed, the caller will have + received new content from this replica. This call allows the + source replica instigating the sync to inform us what their + generation became after applying the documents we returned. + + This is used to allow future sync operations to not need to repeat data + that we just talked about. It also means that if this is called at the + wrong time, there can be database records that will never be + synchronized. + + :param source_replica_uid: The identifier for the source replica. + :param source_replica_generation: + The database generation for the source replica. + :param source_replica_transaction_id: The transaction id associated + with the source replica generation. + """ + raise NotImplementedError(self.record_sync_info) + + def sync_exchange(self, docs_by_generation, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + """Incorporate the documents sent from the source replica. + + This is not meant to be called by client code directly, but is used as + part of sync(). + + This adds docs to the local store, and determines documents that need + to be returned to the source replica. + + Documents must be supplied in docs_by_generation paired with + the generation of their latest change in order from the oldest + change to the newest, that means from the oldest generation to + the newest. + + Documents are also returned paired with the generation of + their latest change in order from the oldest change to the + newest. + + :param docs_by_generation: A list of [(Document, generation, + transaction_id)] tuples indicating documents which should be + updated on this replica paired with the generation and transaction + id of their latest change. + :param source_replica_uid: The source replica's identifier + :param last_known_generation: The last generation that the source + replica knows about this target replica + :param last_known_trans_id: The last transaction id that the source + replica knows about this target replica + :param: return_doc_cb(doc, gen): is a callback + used to return documents to the source replica, it will + be invoked in turn with Documents that have changed since + last_known_generation together with the generation of + their last change. + :param: ensure_callback(replica_uid): if set the target may create + the target db if not yet existent, the callback can then + be used to inform of the created db replica uid. + :return: new_generation - After applying docs_by_generation, this is + the current generation for this replica + """ + raise NotImplementedError(self.sync_exchange) + + def _set_trace_hook(self, cb): + """Set a callback that will be invoked to trace database actions. + + The callback will be passed a string indicating the current state, and + the sync target object. Implementations do not have to implement this + api, it is used by the test suite. + + :param cb: A callable that takes cb(state) + """ + raise NotImplementedError(self._set_trace_hook) + + def _set_trace_hook_shallow(self, cb): + """Set a callback that will be invoked to trace database actions. + + Similar to _set_trace_hook, for implementations that don't offer + state changes from the inner working of sync_exchange(). + + :param cb: A callable that takes cb(state) + """ + self._set_trace_hook(cb) diff --git a/src/leap/soledad/common/l2db/backends/__init__.py b/src/leap/soledad/common/l2db/backends/__init__.py new file mode 100644 index 00000000..c731c3d3 --- /dev/null +++ b/src/leap/soledad/common/l2db/backends/__init__.py @@ -0,0 +1,204 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Abstract classes and common implementations for the backends.""" + +import re +import json +import uuid + +from leap.soledad.common import l2db +from leap.soledad.common.l2db import sync as l2db_sync +from leap.soledad.common.l2db import errors +from leap.soledad.common.l2db.vectorclock import VectorClockRev + + +check_doc_id_re = re.compile("^" + l2db.DOC_ID_CONSTRAINTS + "$", re.UNICODE) + + +class CommonSyncTarget(l2db_sync.LocalSyncTarget): + pass + + +class CommonBackend(l2db.Database): + + document_size_limit = 0 + + def _allocate_doc_id(self): + """Generate a unique identifier for this document.""" + return 'D-' + uuid.uuid4().hex # 'D-' stands for document + + def _allocate_transaction_id(self): + return 'T-' + uuid.uuid4().hex # 'T-' stands for transaction + + def _allocate_doc_rev(self, old_doc_rev): + vcr = VectorClockRev(old_doc_rev) + vcr.increment(self._replica_uid) + return vcr.as_str() + + def _check_doc_id(self, doc_id): + if not check_doc_id_re.match(doc_id): + raise errors.InvalidDocId() + + def _check_doc_size(self, doc): + if not self.document_size_limit: + return + if doc.get_size() > self.document_size_limit: + raise errors.DocumentTooBig + + def _get_generation(self): + """Return the current generation. + + """ + raise NotImplementedError(self._get_generation) + + def _get_generation_info(self): + """Return the current generation and transaction id. + + """ + raise NotImplementedError(self._get_generation_info) + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Extract the document from storage. + + This can return None if the document doesn't exist. + """ + raise NotImplementedError(self._get_doc) + + def _has_conflicts(self, doc_id): + """Return True if the doc has conflicts, False otherwise.""" + raise NotImplementedError(self._has_conflicts) + + def create_doc(self, content, doc_id=None): + if not isinstance(content, dict): + raise errors.InvalidContent + json_string = json.dumps(content) + return self.create_doc_from_json(json_string, doc_id) + + def create_doc_from_json(self, json, doc_id=None): + if doc_id is None: + doc_id = self._allocate_doc_id() + doc = self._factory(doc_id, None, json) + self.put_doc(doc) + return doc + + def _get_transaction_log(self): + """This is only for the test suite, it is not part of the api.""" + raise NotImplementedError(self._get_transaction_log) + + def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): + raise NotImplementedError(self._put_and_update_indexes) + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + for doc_id in doc_ids: + doc = self._get_doc( + doc_id, check_for_conflicts=check_for_conflicts) + if doc.is_tombstone() and not include_deleted: + continue + yield doc + + def _get_trans_id_for_gen(self, generation): + """Get the transaction id corresponding to a particular generation. + + Raises an InvalidGeneration when the generation does not exist. + + """ + raise NotImplementedError(self._get_trans_id_for_gen) + + def validate_gen_and_trans_id(self, generation, trans_id): + """Validate the generation and transaction id. + + Raises an InvalidGeneration when the generation does not exist, and an + InvalidTransactionId when it does but with a different transaction id. + + """ + if generation == 0: + return + known_trans_id = self._get_trans_id_for_gen(generation) + if known_trans_id != trans_id: + raise errors.InvalidTransactionId + + def _validate_source(self, other_replica_uid, other_generation, + other_transaction_id): + """Validate the new generation and transaction id. + + other_generation must be greater than what we have stored for this + replica, *or* it must be the same and the transaction_id must be the + same as well. + """ + (old_generation, + old_transaction_id) = self._get_replica_gen_and_trans_id( + other_replica_uid) + if other_generation < old_generation: + raise errors.InvalidGeneration + if other_generation > old_generation: + return + if other_transaction_id == old_transaction_id: + return + raise errors.InvalidTransactionId + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, + replica_trans_id=''): + cur_doc = self._get_doc(doc.doc_id) + doc_vcr = VectorClockRev(doc.rev) + if cur_doc is None: + cur_vcr = VectorClockRev(None) + else: + cur_vcr = VectorClockRev(cur_doc.rev) + self._validate_source(replica_uid, replica_gen, replica_trans_id) + if doc_vcr.is_newer(cur_vcr): + rev = doc.rev + self._prune_conflicts(doc, doc_vcr) + if doc.rev != rev: + # conflicts have been autoresolved + state = 'superseded' + else: + state = 'inserted' + self._put_and_update_indexes(cur_doc, doc) + elif doc.rev == cur_doc.rev: + # magical convergence + state = 'converged' + elif cur_vcr.is_newer(doc_vcr): + # Don't add this to seen_ids, because we have something newer, + # so we should send it back, and we should not generate a + # conflict + state = 'superseded' + elif cur_doc.same_content_as(doc): + # the documents have been edited to the same thing at both ends + doc_vcr.maximize(cur_vcr) + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + self._put_and_update_indexes(cur_doc, doc) + state = 'superseded' + else: + state = 'conflicted' + if save_conflict: + self._force_doc_sync_conflict(doc) + if replica_uid is not None and replica_gen is not None: + self._do_set_replica_gen_and_trans_id( + replica_uid, replica_gen, replica_trans_id) + return state, self._get_generation() + + def _ensure_maximal_rev(self, cur_rev, extra_revs): + vcr = VectorClockRev(cur_rev) + for rev in extra_revs: + vcr.maximize(VectorClockRev(rev)) + vcr.increment(self._replica_uid) + return vcr.as_str() + + def set_document_size_limit(self, limit): + self.document_size_limit = limit diff --git a/src/leap/soledad/common/l2db/backends/inmemory.py b/src/leap/soledad/common/l2db/backends/inmemory.py new file mode 100644 index 00000000..6fd251af --- /dev/null +++ b/src/leap/soledad/common/l2db/backends/inmemory.py @@ -0,0 +1,466 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""The in-memory Database class for U1DB.""" + +import json + +from leap.soledad.common.l2db import ( + Document, errors, + query_parser, vectorclock) +from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget + + +def get_prefix(value): + key_prefix = '\x01'.join(value) + return key_prefix.rstrip('*') + + +class InMemoryDatabase(CommonBackend): + """A database that only stores the data internally.""" + + def __init__(self, replica_uid, document_factory=None): + self._transaction_log = [] + self._docs = {} + # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner' + self._conflicts = {} + self._other_generations = {} + self._indexes = {} + self._replica_uid = replica_uid + self._factory = document_factory or Document + + def _set_replica_uid(self, replica_uid): + """Force the replica_uid to be set.""" + self._replica_uid = replica_uid + + def set_document_factory(self, factory): + self._factory = factory + + def close(self): + # This is a no-op, We don't want to free the data because one client + # may be closing it, while another wants to inspect the results. + pass + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + return self._other_generations.get(other_replica_uid, (0, '')) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + self._do_set_replica_gen_and_trans_id( + other_replica_uid, other_generation, other_transaction_id) + + def _do_set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, + other_transaction_id): + # TODO: to handle race conditions, we may want to check if the current + # value is greater than this new value. + self._other_generations[other_replica_uid] = (other_generation, + other_transaction_id) + + def get_sync_target(self): + return InMemorySyncTarget(self) + + def _get_transaction_log(self): + # snapshot! + return self._transaction_log[:] + + def _get_generation(self): + return len(self._transaction_log) + + def _get_generation_info(self): + if not self._transaction_log: + return 0, '' + return len(self._transaction_log), self._transaction_log[-1][1] + + def _get_trans_id_for_gen(self, generation): + if generation == 0: + return '' + if generation > len(self._transaction_log): + raise errors.InvalidGeneration + return self._transaction_log[generation - 1][1] + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _put_and_update_indexes(self, old_doc, doc): + for index in self._indexes.itervalues(): + if old_doc is not None and not old_doc.is_tombstone(): + index.remove_json(old_doc.doc_id, old_doc.get_json()) + if not doc.is_tombstone(): + index.add_json(doc.doc_id, doc.get_json()) + trans_id = self._allocate_transaction_id() + self._docs[doc.doc_id] = (doc.rev, doc.get_json()) + self._transaction_log.append((doc.doc_id, trans_id)) + + def _get_doc(self, doc_id, check_for_conflicts=False): + try: + doc_rev, content = self._docs[doc_id] + except KeyError: + return None + doc = self._factory(doc_id, doc_rev, content) + if check_for_conflicts: + doc.has_conflicts = (doc.doc_id in self._conflicts) + return doc + + def _has_conflicts(self, doc_id): + return doc_id in self._conflicts + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def get_all_docs(self, include_deleted=False): + """Return all documents in the database.""" + generation = self._get_generation() + results = [] + for doc_id, (doc_rev, content) in self._docs.items(): + if content is None and not include_deleted: + continue + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = self._has_conflicts(doc_id) + results.append(doc) + return (generation, results) + + def get_doc_conflicts(self, doc_id): + if doc_id not in self._conflicts: + return [] + result = [self._get_doc(doc_id)] + result[0].has_conflicts = True + result.extend([self._factory(doc_id, rev, content) + for rev, content in self._conflicts[doc_id]]) + return result + + def _replace_conflicts(self, doc, conflicts): + if not conflicts: + del self._conflicts[doc.doc_id] + else: + self._conflicts[doc.doc_id] = conflicts + doc.has_conflicts = bool(conflicts) + + def _prune_conflicts(self, doc, doc_vcr): + if self._has_conflicts(doc.doc_id): + autoresolved = False + remaining_conflicts = [] + cur_conflicts = self._conflicts[doc.doc_id] + for c_rev, c_doc in cur_conflicts: + c_vcr = vectorclock.VectorClockRev(c_rev) + if doc_vcr.is_newer(c_vcr): + continue + if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)): + doc_vcr.maximize(c_vcr) + autoresolved = True + continue + remaining_conflicts.append((c_rev, c_doc)) + if autoresolved: + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + self._replace_conflicts(doc, remaining_conflicts) + + def resolve_doc(self, doc, conflicted_doc_revs): + cur_doc = self._get_doc(doc.doc_id) + if cur_doc is None: + cur_rev = None + else: + cur_rev = cur_doc.rev + new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs) + superseded_revs = set(conflicted_doc_revs) + remaining_conflicts = [] + cur_conflicts = self._conflicts[doc.doc_id] + for c_rev, c_doc in cur_conflicts: + if c_rev in superseded_revs: + continue + remaining_conflicts.append((c_rev, c_doc)) + doc.rev = new_rev + if cur_rev in superseded_revs: + self._put_and_update_indexes(cur_doc, doc) + else: + remaining_conflicts.append((new_rev, doc.get_json())) + self._replace_conflicts(doc, remaining_conflicts) + + def delete_doc(self, doc): + if doc.doc_id not in self._docs: + raise errors.DocumentDoesNotExist + if self._docs[doc.doc_id][1] in ('null', None): + raise errors.DocumentAlreadyDeleted + doc.make_tombstone() + self.put_doc(doc) + + def create_index(self, index_name, *index_expressions): + if index_name in self._indexes: + if self._indexes[index_name]._definition == list( + index_expressions): + return + raise errors.IndexNameTakenError + index = InMemoryIndex(index_name, list(index_expressions)) + for doc_id, (doc_rev, doc) in self._docs.iteritems(): + if doc is not None: + index.add_json(doc_id, doc) + self._indexes[index_name] = index + + def delete_index(self, index_name): + try: + del self._indexes[index_name] + except KeyError: + pass + + def list_indexes(self): + definitions = [] + for idx in self._indexes.itervalues(): + definitions.append((idx._name, idx._definition)) + return definitions + + def get_from_index(self, index_name, *key_values): + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + doc_ids = index.lookup(key_values) + result = [] + for doc_id in doc_ids: + result.append(self._get_doc(doc_id, check_for_conflicts=True)) + return result + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + """Return all documents with key values in the specified range.""" + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + if isinstance(start_value, basestring): + start_value = (start_value,) + if isinstance(end_value, basestring): + end_value = (end_value,) + doc_ids = index.lookup_range(start_value, end_value) + result = [] + for doc_id in doc_ids: + result.append(self._get_doc(doc_id, check_for_conflicts=True)) + return result + + def get_index_keys(self, index_name): + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + keys = index.keys() + # XXX inefficiency warning + return list(set([tuple(key.split('\x01')) for key in keys])) + + def whats_changed(self, old_generation=0): + changes = [] + relevant_tail = self._transaction_log[old_generation:] + # We don't use len(self._transaction_log) because _transaction_log may + # get mutated by a concurrent operation. + cur_generation = old_generation + len(relevant_tail) + last_trans_id = '' + if relevant_tail: + last_trans_id = relevant_tail[-1][1] + elif self._transaction_log: + last_trans_id = self._transaction_log[-1][1] + seen = set() + generation = cur_generation + for doc_id, trans_id in reversed(relevant_tail): + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + generation -= 1 + changes.reverse() + return (cur_generation, last_trans_id, changes) + + def _force_doc_sync_conflict(self, doc): + my_doc = self._get_doc(doc.doc_id) + self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) + self._conflicts.setdefault(doc.doc_id, []).append( + (my_doc.rev, my_doc.get_json())) + doc.has_conflicts = True + self._put_and_update_indexes(my_doc, doc) + + +class InMemoryIndex(object): + """Interface for managing an Index.""" + + def __init__(self, index_name, index_definition): + self._name = index_name + self._definition = index_definition + self._values = {} + parser = query_parser.Parser() + self._getters = parser.parse_all(self._definition) + + def evaluate_json(self, doc): + """Determine the 'key' after applying this index to the doc.""" + raw = json.loads(doc) + return self.evaluate(raw) + + def evaluate(self, obj): + """Evaluate a dict object, applying this definition.""" + all_rows = [[]] + for getter in self._getters: + new_rows = [] + keys = getter.get(obj) + if not keys: + return [] + for key in keys: + new_rows.extend([row + [key] for row in all_rows]) + all_rows = new_rows + all_rows = ['\x01'.join(row) for row in all_rows] + return all_rows + + def add_json(self, doc_id, doc): + """Add this json doc to the index.""" + keys = self.evaluate_json(doc) + if not keys: + return + for key in keys: + self._values.setdefault(key, []).append(doc_id) + + def remove_json(self, doc_id, doc): + """Remove this json doc from the index.""" + keys = self.evaluate_json(doc) + if keys: + for key in keys: + doc_ids = self._values[key] + doc_ids.remove(doc_id) + if not doc_ids: + del self._values[key] + + def _find_non_wildcards(self, values): + """Check if this should be a wildcard match. + + Further, this will raise an exception if the syntax is improperly + defined. + + :return: The offset of the last value we need to match against. + """ + if len(values) != len(self._definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + last = 0 + for idx, val in enumerate(values): + if val.endswith('*'): + if val != '*': + # We have an 'x*' style wildcard + if is_wildcard: + # We were already in wildcard mode, so this is invalid + raise errors.InvalidGlobbing + last = idx + 1 + is_wildcard = True + else: + if is_wildcard: + # We were in wildcard mode, we can't follow that with + # non-wildcard + raise errors.InvalidGlobbing + last = idx + 1 + if not is_wildcard: + return -1 + return last + + def lookup(self, values): + """Find docs that match the values.""" + last = self._find_non_wildcards(values) + if last == -1: + return self._lookup_exact(values) + else: + return self._lookup_prefix(values[:last]) + + def lookup_range(self, start_values, end_values): + """Find docs within the range.""" + # TODO: Wildly inefficient, which is unlikely to be a problem for the + # inmemory implementation. + if start_values: + self._find_non_wildcards(start_values) + start_values = get_prefix(start_values) + if end_values: + if self._find_non_wildcards(end_values) == -1: + exact = True + else: + exact = False + end_values = get_prefix(end_values) + found = [] + for key, doc_ids in sorted(self._values.iteritems()): + if start_values and start_values > key: + continue + if end_values and end_values < key: + if exact: + break + else: + if not key.startswith(end_values): + break + found.extend(doc_ids) + return found + + def keys(self): + """Find the indexed keys.""" + return self._values.keys() + + def _lookup_prefix(self, value): + """Find docs that match the prefix string in values.""" + # TODO: We need a different data structure to make prefix style fast, + # some sort of sorted list would work, but a plain dict doesn't. + key_prefix = get_prefix(value) + all_doc_ids = [] + for key, doc_ids in sorted(self._values.iteritems()): + if key.startswith(key_prefix): + all_doc_ids.extend(doc_ids) + return all_doc_ids + + def _lookup_exact(self, value): + """Find docs that match exactly.""" + key = '\x01'.join(value) + if key in self._values: + return self._values[key] + return () + + +class InMemorySyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_transaction_id) diff --git a/src/leap/soledad/common/l2db/errors.py b/src/leap/soledad/common/l2db/errors.py new file mode 100644 index 00000000..b502fc2d --- /dev/null +++ b/src/leap/soledad/common/l2db/errors.py @@ -0,0 +1,194 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""A list of errors that u1db can raise.""" + + +class U1DBError(Exception): + """Generic base class for U1DB errors.""" + + # description/tag for identifying the error during transmission (http,...) + wire_description = "error" + + def __init__(self, message=None): + self.message = message + + +class RevisionConflict(U1DBError): + """The document revisions supplied does not match the current version.""" + + wire_description = "revision conflict" + + +class InvalidJSON(U1DBError): + """Content was not valid json.""" + + +class InvalidContent(U1DBError): + """Content was not a python dictionary.""" + + +class InvalidDocId(U1DBError): + """A document was requested with an invalid document identifier.""" + + wire_description = "invalid document id" + + +class MissingDocIds(U1DBError): + """Needs document ids.""" + + wire_description = "missing document ids" + + +class DocumentTooBig(U1DBError): + """Document exceeds the maximum document size for this database.""" + + wire_description = "document too big" + + +class UserQuotaExceeded(U1DBError): + """Document exceeds the maximum document size for this database.""" + + wire_description = "user quota exceeded" + + +class SubscriptionNeeded(U1DBError): + """User needs a subscription to be able to use this replica..""" + + wire_description = "user needs subscription" + + +class InvalidTransactionId(U1DBError): + """Invalid transaction for generation.""" + + wire_description = "invalid transaction id" + + +class InvalidGeneration(U1DBError): + """Generation was previously synced with a different transaction id.""" + + wire_description = "invalid generation" + + +class InvalidReplicaUID(U1DBError): + """Attempting to sync a database with itself.""" + + wire_description = "invalid replica uid" + + +class ConflictedDoc(U1DBError): + """The document is conflicted, you must call resolve before put()""" + + +class InvalidValueForIndex(U1DBError): + """The values supplied does not match the index definition.""" + + +class InvalidGlobbing(U1DBError): + """Raised if wildcard matches are not strictly at the tail of the request. + """ + + +class DocumentDoesNotExist(U1DBError): + """The document does not exist.""" + + wire_description = "document does not exist" + + +class DocumentAlreadyDeleted(U1DBError): + """The document was already deleted.""" + + wire_description = "document already deleted" + + +class DatabaseDoesNotExist(U1DBError): + """The database does not exist.""" + + wire_description = "database does not exist" + + +class IndexNameTakenError(U1DBError): + """The given index name is already taken.""" + + +class IndexDefinitionParseError(U1DBError): + """The index definition cannot be parsed.""" + + +class IndexDoesNotExist(U1DBError): + """No index of that name exists.""" + + +class Unauthorized(U1DBError): + """Request wasn't authorized properly.""" + + wire_description = "unauthorized" + + +class HTTPError(U1DBError): + """Unspecific HTTP errror.""" + + wire_description = None + + def __init__(self, status, message=None, headers={}): + self.status = status + self.message = message + self.headers = headers + + def __str__(self): + if not self.message: + return "HTTPError(%d)" % self.status + else: + return "HTTPError(%d, %r)" % (self.status, self.message) + + +class Unavailable(HTTPError): + """Server not available not serve request.""" + + wire_description = "unavailable" + + def __init__(self, message=None, headers={}): + super(Unavailable, self).__init__(503, message, headers) + + def __str__(self): + if not self.message: + return "Unavailable()" + else: + return "Unavailable(%r)" % self.message + + +class BrokenSyncStream(U1DBError): + """Unterminated or otherwise broken sync exchange stream.""" + + wire_description = None + + +class UnknownAuthMethod(U1DBError): + """Unknown auhorization method.""" + + wire_description = None + + +# mapping wire (transimission) descriptions/tags for errors to the exceptions +wire_description_to_exc = dict( + (x.wire_description, x) for x in globals().values() + if getattr(x, 'wire_description', None) not in (None, "error")) +wire_description_to_exc["error"] = U1DBError + + +# +# wire error descriptions not corresponding to an exception +DOCUMENT_DELETED = "document deleted" diff --git a/src/leap/soledad/common/l2db/query_parser.py b/src/leap/soledad/common/l2db/query_parser.py new file mode 100644 index 00000000..15a9ac80 --- /dev/null +++ b/src/leap/soledad/common/l2db/query_parser.py @@ -0,0 +1,371 @@ +# Copyright 2011 Canonical Ltd. +# Copyright 2016 LEAP Encryption Access Project +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. +""" +Code for parsing Index definitions. +""" + +import re + +from leap.soledad.common.l2db import errors + + +class Getter(object): + """Get values from a document based on a specification.""" + + def get(self, raw_doc): + """Get a value from the document. + + :param raw_doc: a python dictionary to get the value from. + :return: A list of values that match the description. + """ + raise NotImplementedError(self.get) + + +class StaticGetter(Getter): + """A getter that returns a defined value (independent of the doc).""" + + def __init__(self, value): + """Create a StaticGetter. + + :param value: the value to return when get is called. + """ + if value is None: + self.value = [] + elif isinstance(value, list): + self.value = value + else: + self.value = [value] + + def get(self, raw_doc): + return self.value + + +def extract_field(raw_doc, subfields, index=0): + if not isinstance(raw_doc, dict): + return [] + val = raw_doc.get(subfields[index]) + if val is None: + return [] + if index < len(subfields) - 1: + if isinstance(val, list): + results = [] + for item in val: + results.extend(extract_field(item, subfields, index + 1)) + return results + if isinstance(val, dict): + return extract_field(val, subfields, index + 1) + return [] + if isinstance(val, dict): + return [] + if isinstance(val, list): + # Strip anything in the list that isn't a simple type + return [v for v in val if not isinstance(v, (dict, list))] + return [val] + + +class ExtractField(Getter): + """Extract a field from the document.""" + + def __init__(self, field): + """Create an ExtractField object. + + When a document is passed to get() this will return a value + from the document based on the field specifier passed to + the constructor. + + None will be returned if the field is nonexistant, or refers to an + object, rather than a simple type or list of simple types. + + :param field: a specifier for the field to return. + This is either a field name, or a dotted field name. + """ + self.field = field.split('.') + + def get(self, raw_doc): + return extract_field(raw_doc, self.field) + + +class Transformation(Getter): + """A transformation on a value from another Getter.""" + + name = None + arity = 1 + args = ['expression'] + + def __init__(self, inner): + """Create a transformation. + + :param inner: the argument(s) to the transformation. + """ + self.inner = inner + + def get(self, raw_doc): + inner_values = self.inner.get(raw_doc) + assert isinstance(inner_values, list),\ + 'get() should always return a list' + return self.transform(inner_values) + + def transform(self, values): + """Transform the values. + + This should be implemented by subclasses to transform the + value when get() is called. + + :param values: the values from the other Getter + :return: the transformed values. + """ + raise NotImplementedError(self.transform) + + +class Lower(Transformation): + """Lowercase a string. + + This transformation will return None for non-string inputs. However, + it will lowercase any strings in a list, dropping any elements + that are not strings. + """ + + name = "lower" + + def _can_transform(self, val): + return isinstance(val, basestring) + + def transform(self, values): + if not values: + return [] + return [val.lower() for val in values if self._can_transform(val)] + + +class Number(Transformation): + """Convert an integer to a zero padded string. + + This transformation will return None for non-integer inputs. However, it + will transform any integers in a list, dropping any elements that are not + integers. + """ + + name = 'number' + arity = 2 + args = ['expression', int] + + def __init__(self, inner, number): + super(Number, self).__init__(inner) + self.padding = "%%0%sd" % number + + def _can_transform(self, val): + return isinstance(val, int) and not isinstance(val, bool) + + def transform(self, values): + """Transform any integers in values into zero padded strings.""" + if not values: + return [] + return [self.padding % (v,) for v in values if self._can_transform(v)] + + +class Bool(Transformation): + """Convert bool to string.""" + + name = "bool" + args = ['expression'] + + def _can_transform(self, val): + return isinstance(val, bool) + + def transform(self, values): + """Transform any booleans in values into strings.""" + if not values: + return [] + return [('1' if v else '0') for v in values if self._can_transform(v)] + + +class SplitWords(Transformation): + """Split a string on whitespace. + + This Getter will return [] for non-string inputs. It will however + split any strings in an input list, discarding any elements that + are not strings. + """ + + name = "split_words" + + def _can_transform(self, val): + return isinstance(val, basestring) + + def transform(self, values): + if not values: + return [] + result = set() + for value in values: + if self._can_transform(value): + for word in value.split(): + result.add(word) + return list(result) + + +class Combine(Transformation): + """Combine multiple expressions into a single index.""" + + name = "combine" + # variable number of args + arity = -1 + + def __init__(self, *inner): + super(Combine, self).__init__(inner) + + def get(self, raw_doc): + inner_values = [] + for inner in self.inner: + inner_values.extend(inner.get(raw_doc)) + return self.transform(inner_values) + + def transform(self, values): + return values + + +class IsNull(Transformation): + """Indicate whether the input is None. + + This Getter returns a bool indicating whether the input is nil. + """ + + name = "is_null" + + def transform(self, values): + return [len(values) == 0] + + +def check_fieldname(fieldname): + if fieldname.endswith('.'): + raise errors.IndexDefinitionParseError( + "Fieldname cannot end in '.':%s^" % (fieldname,)) + + +class Parser(object): + """Parse an index expression into a sequence of transformations.""" + + _transformations = {} + _delimiters = re.compile("\(|\)|,") + + def __init__(self): + self._tokens = [] + + def _set_expression(self, expression): + self._open_parens = 0 + self._tokens = [] + expression = expression.strip() + while expression: + delimiter = self._delimiters.search(expression) + if delimiter: + idx = delimiter.start() + if idx == 0: + result, expression = (expression[:1], expression[1:]) + self._tokens.append(result) + else: + result, expression = (expression[:idx], expression[idx:]) + result = result.strip() + if result: + self._tokens.append(result) + else: + expression = expression.strip() + if expression: + self._tokens.append(expression) + expression = None + + def _get_token(self): + if self._tokens: + return self._tokens.pop(0) + + def _peek_token(self): + if self._tokens: + return self._tokens[0] + + @staticmethod + def _to_getter(term): + if isinstance(term, Getter): + return term + check_fieldname(term) + return ExtractField(term) + + def _parse_op(self, op_name): + self._get_token() # '(' + op = self._transformations.get(op_name, None) + if op is None: + raise errors.IndexDefinitionParseError( + "Unknown operation: %s" % op_name) + args = [] + while True: + args.append(self._parse_term()) + sep = self._get_token() + if sep == ')': + break + if sep != ',': + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' in parentheses." % (sep,)) + parsed = [] + for i, arg in enumerate(args): + arg_type = op.args[i % len(op.args)] + if arg_type == 'expression': + inner = self._to_getter(arg) + else: + try: + inner = arg_type(arg) + except ValueError as e: + raise errors.IndexDefinitionParseError( + "Invalid value %r for argument type %r " + "(%r)." % (arg, arg_type, e)) + parsed.append(inner) + return op(*parsed) + + def _parse_term(self): + term = self._get_token() + if term is None: + raise errors.IndexDefinitionParseError( + "Unexpected end of index definition.") + if term in (',', ')', '('): + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' at start of expression." % (term,)) + next_token = self._peek_token() + if next_token == '(': + return self._parse_op(term) + return term + + def parse(self, expression): + self._set_expression(expression) + term = self._to_getter(self._parse_term()) + if self._peek_token(): + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' after end of expression." + % (self._peek_token(),)) + return term + + def parse_all(self, fields): + return [self.parse(field) for field in fields] + + @classmethod + def register_transormation(cls, transform): + assert transform.name not in cls._transformations, ( + "Transform %s already registered for %s" + % (transform.name, cls._transformations[transform.name])) + cls._transformations[transform.name] = transform + + +Parser.register_transormation(SplitWords) +Parser.register_transormation(Lower) +Parser.register_transormation(Number) +Parser.register_transormation(Bool) +Parser.register_transormation(IsNull) +Parser.register_transormation(Combine) diff --git a/src/leap/soledad/common/l2db/remote/__init__.py b/src/leap/soledad/common/l2db/remote/__init__.py new file mode 100644 index 00000000..3f32e381 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. diff --git a/src/leap/soledad/common/l2db/remote/http_app.py b/src/leap/soledad/common/l2db/remote/http_app.py new file mode 100644 index 00000000..a4eddb36 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_app.py @@ -0,0 +1,660 @@ +# Copyright 2011-2012 Canonical Ltd. +# Copyright 2016 LEAP Encryption Access Project +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +""" +HTTP Application exposing U1DB. +""" +# TODO -- deprecate, use twisted/txaio. + +import functools +import six.moves.http_client as httplib +import inspect +import json +import sys +import six.moves.urllib.parse as urlparse + +import routes.mapper + +from leap.soledad.common.l2db import ( + __version__ as _u1db_version, + DBNAME_CONSTRAINTS, Document, + errors, sync) +from leap.soledad.common.l2db.remote import http_errors, utils + + +def parse_bool(expression): + """Parse boolean querystring parameter.""" + if expression == 'true': + return True + return False + + +def parse_list(expression): + if not expression: + return [] + return [t.strip() for t in expression.split(',')] + + +def none_or_str(expression): + if expression is None: + return None + return str(expression) + + +class BadRequest(Exception): + """Bad request.""" + + +class _FencedReader(object): + """Read and get lines from a file but not past a given length.""" + + MAXCHUNK = 8192 + + def __init__(self, rfile, total, max_entry_size): + self.rfile = rfile + self.remaining = total + self.max_entry_size = max_entry_size + self._kept = None + + def read_chunk(self, atmost): + if self._kept is not None: + # ignore atmost, kept data should be a subchunk anyway + kept, self._kept = self._kept, None + return kept + if self.remaining == 0: + return '' + data = self.rfile.read(min(self.remaining, atmost)) + self.remaining -= len(data) + return data + + def getline(self): + line_parts = [] + size = 0 + while True: + chunk = self.read_chunk(self.MAXCHUNK) + if chunk == '': + break + nl = chunk.find("\n") + if nl != -1: + size += nl + 1 + if size > self.max_entry_size: + raise BadRequest + line_parts.append(chunk[:nl + 1]) + rest = chunk[nl + 1:] + self._kept = rest or None + break + else: + size += len(chunk) + if size > self.max_entry_size: + raise BadRequest + line_parts.append(chunk) + return ''.join(line_parts) + + +def http_method(**control): + """Decoration for handling of query arguments and content for a HTTP + method. + + args and content here are the query arguments and body of the incoming + HTTP requests. + + Match query arguments to python method arguments: + w = http_method()(f) + w(self, args, content) => args["content"]=content; + f(self, **args) + + JSON deserialize content to arguments: + w = http_method(content_as_args=True,...)(f) + w(self, args, content) => args.update(json.loads(content)); + f(self, **args) + + Support conversions (e.g int): + w = http_method(Arg=Conv,...)(f) + w(self, args, content) => args["Arg"]=Conv(args["Arg"]); + f(self, **args) + + Enforce no use of query arguments: + w = http_method(no_query=True,...)(f) + w(self, args, content) raises BadRequest if args is not empty + + Argument mismatches, deserialisation failures produce BadRequest. + """ + content_as_args = control.pop('content_as_args', False) + no_query = control.pop('no_query', False) + conversions = control.items() + + def wrap(f): + argspec = inspect.getargspec(f) + assert argspec.args[0] == "self" + nargs = len(argspec.args) + ndefaults = len(argspec.defaults or ()) + required_args = set(argspec.args[1:nargs - ndefaults]) + all_args = set(argspec.args) + + @functools.wraps(f) + def wrapper(self, args, content): + if no_query and args: + raise BadRequest() + if content is not None: + if content_as_args: + try: + args.update(json.loads(content)) + except ValueError: + raise BadRequest() + else: + args["content"] = content + if not (required_args <= set(args) <= all_args): + raise BadRequest("Missing required arguments.") + for name, conv in conversions: + if name not in args: + continue + try: + args[name] = conv(args[name]) + except ValueError: + raise BadRequest() + return f(self, **args) + + return wrapper + + return wrap + + +class URLToResource(object): + """Mappings from URLs to resources.""" + + def __init__(self): + self._map = routes.mapper.Mapper(controller_scan=None) + + def register(self, resource_cls): + # register + self._map.connect(None, resource_cls.url_pattern, + resource_cls=resource_cls, + requirements={"dbname": DBNAME_CONSTRAINTS}) + self._map.create_regs() + return resource_cls + + def match(self, path): + params = self._map.match(path) + if params is None: + return None, None + resource_cls = params.pop('resource_cls') + return resource_cls, params + + +url_to_resource = URLToResource() + + +@url_to_resource.register +class GlobalResource(object): + """Global (root) resource.""" + + url_pattern = "/" + + def __init__(self, state, responder): + self.state = state + self.responder = responder + + @http_method() + def get(self): + info = self.state.global_info() + info['version'] = _u1db_version + self.responder.send_response_json(**info) + + +@url_to_resource.register +class DatabaseResource(object): + """Database resource.""" + + url_pattern = "/{dbname}" + + def __init__(self, dbname, state, responder): + self.dbname = dbname + self.state = state + self.responder = responder + + @http_method() + def get(self): + self.state.check_database(self.dbname) + self.responder.send_response_json(200) + + @http_method(content_as_args=True) + def put(self): + self.state.ensure_database(self.dbname) + self.responder.send_response_json(200, ok=True) + + @http_method() + def delete(self): + self.state.delete_database(self.dbname) + self.responder.send_response_json(200, ok=True) + + +@url_to_resource.register +class DocsResource(object): + """Documents resource.""" + + url_pattern = "/{dbname}/docs" + + def __init__(self, dbname, state, responder): + self.responder = responder + self.db = state.open_database(dbname) + + @http_method(doc_ids=parse_list, check_for_conflicts=parse_bool, + include_deleted=parse_bool) + def get(self, doc_ids=None, check_for_conflicts=True, + include_deleted=False): + if doc_ids is None: + raise errors.MissingDocIds + docs = self.db.get_docs(doc_ids, include_deleted=include_deleted) + self.responder.content_type = 'application/json' + self.responder.start_response(200) + self.responder.start_stream(), + for doc in docs: + entry = dict( + doc_id=doc.doc_id, doc_rev=doc.rev, content=doc.get_json(), + has_conflicts=doc.has_conflicts) + self.responder.stream_entry(entry) + self.responder.end_stream() + self.responder.finish_response() + + +@url_to_resource.register +class AllDocsResource(object): + """All Documents resource.""" + + url_pattern = "/{dbname}/all-docs" + + def __init__(self, dbname, state, responder): + self.responder = responder + self.db = state.open_database(dbname) + + @http_method(include_deleted=parse_bool) + def get(self, include_deleted=False): + gen, docs = self.db.get_all_docs(include_deleted=include_deleted) + self.responder.content_type = 'application/json' + # returning a x-u1db-generation header is optional + # HTTPDatabase will fallback to return -1 if it's missing + self.responder.start_response(200, + headers={'x-u1db-generation': str(gen)}) + self.responder.start_stream(), + for doc in docs: + entry = dict( + doc_id=doc.doc_id, doc_rev=doc.rev, content=doc.get_json(), + has_conflicts=doc.has_conflicts) + self.responder.stream_entry(entry) + self.responder.end_stream() + self.responder.finish_response() + + +@url_to_resource.register +class DocResource(object): + """Document resource.""" + + url_pattern = "/{dbname}/doc/{id:.*}" + + def __init__(self, dbname, id, state, responder): + self.id = id + self.responder = responder + self.db = state.open_database(dbname) + + @http_method(old_rev=str) + def put(self, content, old_rev=None): + doc = Document(self.id, old_rev, content) + doc_rev = self.db.put_doc(doc) + if old_rev is None: + status = 201 # created + else: + status = 200 + self.responder.send_response_json(status, rev=doc_rev) + + @http_method(old_rev=str) + def delete(self, old_rev=None): + doc = Document(self.id, old_rev, None) + self.db.delete_doc(doc) + self.responder.send_response_json(200, rev=doc.rev) + + @http_method(include_deleted=parse_bool) + def get(self, include_deleted=False): + doc = self.db.get_doc(self.id, include_deleted=include_deleted) + if doc is None: + wire_descr = errors.DocumentDoesNotExist.wire_description + self.responder.send_response_json( + http_errors.wire_description_to_status[wire_descr], + error=wire_descr, + headers={ + 'x-u1db-rev': '', + 'x-u1db-has-conflicts': 'false' + }) + return + headers = { + 'x-u1db-rev': doc.rev, + 'x-u1db-has-conflicts': json.dumps(doc.has_conflicts) + } + if doc.is_tombstone(): + self.responder.send_response_json( + http_errors.wire_description_to_status[ + errors.DOCUMENT_DELETED], + error=errors.DOCUMENT_DELETED, + headers=headers) + else: + self.responder.send_response_content( + doc.get_json(), headers=headers) + + +@url_to_resource.register +class SyncResource(object): + """Sync endpoint resource.""" + + # maximum allowed request body size + max_request_size = 15 * 1024 * 1024 # 15Mb + # maximum allowed entry/line size in request body + max_entry_size = 10 * 1024 * 1024 # 10Mb + + url_pattern = "/{dbname}/sync-from/{source_replica_uid}" + + # pluggable + sync_exchange_class = sync.SyncExchange + + def __init__(self, dbname, source_replica_uid, state, responder): + self.source_replica_uid = source_replica_uid + self.responder = responder + self.state = state + self.dbname = dbname + self.replica_uid = None + + def get_target(self): + return self.state.open_database(self.dbname).get_sync_target() + + @http_method() + def get(self): + result = self.get_target().get_sync_info(self.source_replica_uid) + self.responder.send_response_json( + target_replica_uid=result[0], target_replica_generation=result[1], + target_replica_transaction_id=result[2], + source_replica_uid=self.source_replica_uid, + source_replica_generation=result[3], + source_transaction_id=result[4]) + + @http_method(generation=int, + content_as_args=True, no_query=True) + def put(self, generation, transaction_id): + self.get_target().record_sync_info(self.source_replica_uid, + generation, + transaction_id) + self.responder.send_response_json(ok=True) + + # Implements the same logic as LocalSyncTarget.sync_exchange + + @http_method(last_known_generation=int, last_known_trans_id=none_or_str, + content_as_args=True) + def post_args(self, last_known_generation, last_known_trans_id=None, + ensure=False): + if ensure: + db, self.replica_uid = self.state.ensure_database(self.dbname) + else: + db = self.state.open_database(self.dbname) + db.validate_gen_and_trans_id( + last_known_generation, last_known_trans_id) + self.sync_exch = self.sync_exchange_class( + db, self.source_replica_uid, last_known_generation) + + @http_method(content_as_args=True) + def post_stream_entry(self, id, rev, content, gen, trans_id): + doc = Document(id, rev, content) + self.sync_exch.insert_doc_from_source(doc, gen, trans_id) + + def post_end(self): + + def send_doc(doc, gen, trans_id): + entry = dict(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), + gen=gen, trans_id=trans_id) + self.responder.stream_entry(entry) + + new_gen = self.sync_exch.find_changes_to_return() + self.responder.content_type = 'application/x-u1db-sync-stream' + self.responder.start_response(200) + self.responder.start_stream(), + header = {"new_generation": new_gen, + "new_transaction_id": self.sync_exch.new_trans_id} + if self.replica_uid is not None: + header['replica_uid'] = self.replica_uid + self.responder.stream_entry(header) + self.sync_exch.return_docs(send_doc) + self.responder.end_stream() + self.responder.finish_response() + + +class HTTPResponder(object): + """Encode responses from the server back to the client.""" + + # a multi document response will put args and documents + # each on one line of the response body + + def __init__(self, start_response): + self._started = False + self._stream_state = -1 + self._no_initial_obj = True + self.sent_response = False + self._start_response = start_response + self._write = None + self.content_type = 'application/json' + self.content = [] + + def start_response(self, status, obj_dic=None, headers={}): + """start sending response with optional first json object.""" + if self._started: + return + self._started = True + status_text = httplib.responses[status] + self._write = self._start_response( + '%d %s' % (status, status_text), + [('content-type', self.content_type), + ('cache-control', 'no-cache')] + + headers.items()) + # xxx version in headers + if obj_dic is not None: + self._no_initial_obj = False + self._write(json.dumps(obj_dic) + "\r\n") + + def finish_response(self): + """finish sending response.""" + self.sent_response = True + + def send_response_json(self, status=200, headers={}, **kwargs): + """send and finish response with json object body from keyword args.""" + content = json.dumps(kwargs) + "\r\n" + self.send_response_content(content, headers=headers, status=status) + + def send_response_content(self, content, status=200, headers={}): + """send and finish response with content""" + headers['content-length'] = str(len(content)) + self.start_response(status, headers=headers) + if self._stream_state == 1: + self.content = [',\r\n', content] + else: + self.content = [content] + self.finish_response() + + def start_stream(self): + "start stream (array) as part of the response." + assert self._started and self._no_initial_obj + self._stream_state = 0 + self._write("[") + + def stream_entry(self, entry): + "send stream entry as part of the response." + assert self._stream_state != -1 + if self._stream_state == 0: + self._stream_state = 1 + self._write('\r\n') + else: + self._write(',\r\n') + if type(entry) == dict: + entry = json.dumps(entry) + self._write(entry) + + def end_stream(self): + "end stream (array)." + assert self._stream_state != -1 + self._write("\r\n]\r\n") + + +class HTTPInvocationByMethodWithBody(object): + """Invoke methods on a resource.""" + + def __init__(self, resource, environ, parameters): + self.resource = resource + self.environ = environ + self.max_request_size = getattr( + resource, 'max_request_size', parameters.max_request_size) + self.max_entry_size = getattr( + resource, 'max_entry_size', parameters.max_entry_size) + + def _lookup(self, method): + try: + return getattr(self.resource, method) + except AttributeError: + raise BadRequest() + + def __call__(self): + args = urlparse.parse_qsl(self.environ['QUERY_STRING'], + strict_parsing=False) + try: + args = dict( + (k.decode('utf-8'), v.decode('utf-8')) for k, v in args) + except ValueError: + raise BadRequest() + method = self.environ['REQUEST_METHOD'].lower() + if method in ('get', 'delete'): + meth = self._lookup(method) + return meth(args, None) + else: + # we expect content-length > 0, reconsider if we move + # to support chunked enconding + try: + content_length = int(self.environ['CONTENT_LENGTH']) + except (ValueError, KeyError): + raise BadRequest + if content_length <= 0: + raise BadRequest + if content_length > self.max_request_size: + raise BadRequest + reader = _FencedReader(self.environ['wsgi.input'], content_length, + self.max_entry_size) + content_type = self.environ.get('CONTENT_TYPE', '') + content_type = content_type.split(';', 1)[0].strip() + if content_type == 'application/json': + meth = self._lookup(method) + body = reader.read_chunk(sys.maxint) + return meth(args, body) + elif content_type == 'application/x-u1db-sync-stream': + meth_args = self._lookup('%s_args' % method) + meth_entry = self._lookup('%s_stream_entry' % method) + meth_end = self._lookup('%s_end' % method) + body_getline = reader.getline + if body_getline().strip() != '[': + raise BadRequest() + line = body_getline() + line, comma = utils.check_and_strip_comma(line.strip()) + meth_args(args, line) + while True: + line = body_getline() + entry = line.strip() + if entry == ']': + break + if not entry or not comma: # empty or no prec comma + raise BadRequest + entry, comma = utils.check_and_strip_comma(entry) + meth_entry({}, entry) + if comma or body_getline(): # extra comma or data + raise BadRequest + return meth_end() + else: + raise BadRequest() + + +class HTTPApp(object): + + # maximum allowed request body size + max_request_size = 15 * 1024 * 1024 # 15Mb + # maximum allowed entry/line size in request body + max_entry_size = 10 * 1024 * 1024 # 10Mb + + def __init__(self, state): + self.state = state + + def _lookup_resource(self, environ, responder): + resource_cls, params = url_to_resource.match(environ['PATH_INFO']) + if resource_cls is None: + raise BadRequest # 404 instead? + resource = resource_cls( + state=self.state, responder=responder, **params) + return resource + + def __call__(self, environ, start_response): + responder = HTTPResponder(start_response) + self.request_begin(environ) + try: + resource = self._lookup_resource(environ, responder) + HTTPInvocationByMethodWithBody(resource, environ, self)() + except errors.U1DBError as e: + self.request_u1db_error(environ, e) + status = http_errors.wire_description_to_status.get( + e.wire_description, 500) + responder.send_response_json(status, error=e.wire_description) + except BadRequest: + self.request_bad_request(environ) + responder.send_response_json(400, error="bad request") + except KeyboardInterrupt: + raise + except: + self.request_failed(environ) + raise + else: + self.request_done(environ) + return responder.content + + # hooks for tracing requests + + def request_begin(self, environ): + """Hook called at the beginning of processing a request.""" + pass + + def request_done(self, environ): + """Hook called when done processing a request.""" + pass + + def request_u1db_error(self, environ, exc): + """Hook called when processing a request resulted in a U1DBError. + + U1DBError passed as exc. + """ + pass + + def request_bad_request(self, environ): + """Hook called when processing a bad request. + + No actual processing was done. + """ + pass + + def request_failed(self, environ): + """Hook called when processing a request failed unexpectedly. + + Invoked from an except block, so there's interpreter exception + information available. + """ + pass diff --git a/src/leap/soledad/common/l2db/remote/http_client.py b/src/leap/soledad/common/l2db/remote/http_client.py new file mode 100644 index 00000000..1124b038 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_client.py @@ -0,0 +1,178 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Base class to make requests to a remote HTTP server.""" + +import json +import socket +import ssl +import sys +import urllib +import six.moves.urllib.parse as urlparse +import six.moves.http_client as httplib +from time import sleep +from leap.soledad.common.l2db import errors +from leap.soledad.common.l2db.remote import http_errors + +from leap.soledad.common.l2db.remote.ssl_match_hostname import match_hostname + +# Ubuntu/debian +# XXX other... +CA_CERTS = "/etc/ssl/certs/ca-certificates.crt" + + +def _encode_query_parameter(value): + """Encode query parameter.""" + if isinstance(value, bool): + if value: + value = 'true' + else: + value = 'false' + return unicode(value).encode('utf-8') + + +class _VerifiedHTTPSConnection(httplib.HTTPSConnection): + """HTTPSConnection verifying server side certificates.""" + # derived from httplib.py + + def connect(self): + "Connect to a host on a given (SSL) port." + + sock = socket.create_connection((self.host, self.port), + self.timeout, self.source_address) + if self._tunnel_host: + self.sock = sock + self._tunnel() + if sys.platform.startswith('linux'): + cert_opts = { + 'cert_reqs': ssl.CERT_REQUIRED, + 'ca_certs': CA_CERTS + } + else: + # XXX no cert verification implemented elsewhere for now + cert_opts = {} + self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, + ssl_version=ssl.PROTOCOL_SSLv3, + **cert_opts + ) + if cert_opts: + match_hostname(self.sock.getpeercert(), self.host) + + +class HTTPClientBase(object): + """Base class to make requests to a remote HTTP server.""" + + # Will use these delays to retry on 503 befor finally giving up. The final + # 0 is there to not wait after the final try fails. + _delays = (1, 1, 2, 4, 0) + + def __init__(self, url, creds=None): + self._url = urlparse.urlsplit(url) + self._conn = None + self._creds = {} + if creds is not None: + if len(creds) != 1: + raise errors.UnknownAuthMethod() + auth_meth, credentials = creds.items()[0] + try: + set_creds = getattr(self, 'set_%s_credentials' % auth_meth) + except AttributeError: + raise errors.UnknownAuthMethod(auth_meth) + set_creds(**credentials) + + def _ensure_connection(self): + if self._conn is not None: + return + if self._url.scheme == 'https': + connClass = _VerifiedHTTPSConnection + else: + connClass = httplib.HTTPConnection + self._conn = connClass(self._url.hostname, self._url.port) + + def close(self): + if self._conn: + self._conn.close() + self._conn = None + + # xxx retry mechanism? + + def _error(self, respdic): + descr = respdic.get("error") + exc_cls = errors.wire_description_to_exc.get(descr) + if exc_cls is not None: + message = respdic.get("message") + raise exc_cls(message) + + def _response(self): + resp = self._conn.getresponse() + body = resp.read() + headers = dict(resp.getheaders()) + if resp.status in (200, 201): + return body, headers + elif resp.status in http_errors.ERROR_STATUSES: + try: + respdic = json.loads(body) + except ValueError: + pass + else: + self._error(respdic) + # special case + if resp.status == 503: + raise errors.Unavailable(body, headers) + raise errors.HTTPError(resp.status, body, headers) + + def _sign_request(self, method, url_query, params): + raise NotImplementedError + + def _request(self, method, url_parts, params=None, body=None, + content_type=None): + self._ensure_connection() + unquoted_url = url_query = self._url.path + if url_parts: + if not url_query.endswith('/'): + url_query += '/' + unquoted_url = url_query + url_query += '/'.join(urllib.quote(part, safe='') + for part in url_parts) + # oauth performs its own quoting + unquoted_url += '/'.join(url_parts) + encoded_params = {} + if params: + for key, value in params.items(): + key = unicode(key).encode('utf-8') + encoded_params[key] = _encode_query_parameter(value) + url_query += ('?' + urllib.urlencode(encoded_params)) + if body is not None and not isinstance(body, basestring): + body = json.dumps(body) + content_type = 'application/json' + headers = {} + if content_type: + headers['content-type'] = content_type + headers.update( + self._sign_request(method, unquoted_url, encoded_params)) + for delay in self._delays: + try: + self._conn.request(method, url_query, body, headers) + return self._response() + except errors.Unavailable as e: + sleep(delay) + raise e + + def _request_json(self, method, url_parts, params=None, body=None, + content_type=None): + res, headers = self._request(method, url_parts, params, body, + content_type) + return json.loads(res), headers diff --git a/src/leap/soledad/common/l2db/remote/http_database.py b/src/leap/soledad/common/l2db/remote/http_database.py new file mode 100644 index 00000000..7e61e5a4 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_database.py @@ -0,0 +1,158 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""HTTPDatabase to access a remote db over the HTTP API.""" + +import json +import uuid + +from leap.soledad.common.l2db import ( + Database, + Document, + errors) +from leap.soledad.common.l2db.remote import ( + http_client, + http_errors, + http_target) + + +DOCUMENT_DELETED_STATUS = http_errors.wire_description_to_status[ + errors.DOCUMENT_DELETED] + + +class HTTPDatabase(http_client.HTTPClientBase, Database): + """Implement the Database API to a remote HTTP server.""" + + def __init__(self, url, document_factory=None, creds=None): + super(HTTPDatabase, self).__init__(url, creds=creds) + self._factory = document_factory or Document + + def set_document_factory(self, factory): + self._factory = factory + + @staticmethod + def open_database(url, create): + db = HTTPDatabase(url) + db.open(create) + return db + + @staticmethod + def delete_database(url): + db = HTTPDatabase(url) + db._delete() + db.close() + + def open(self, create): + if create: + self._ensure() + else: + self._check() + + def _check(self): + return self._request_json('GET', [])[0] + + def _ensure(self): + self._request_json('PUT', [], {}, {}) + + def _delete(self): + self._request_json('DELETE', [], {}, {}) + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + params = {} + if doc.rev is not None: + params['old_rev'] = doc.rev + res, headers = self._request_json('PUT', ['doc', doc.doc_id], params, + doc.get_json(), 'application/json') + doc.rev = res['rev'] + return res['rev'] + + def get_doc(self, doc_id, include_deleted=False): + try: + res, headers = self._request( + 'GET', ['doc', doc_id], {"include_deleted": include_deleted}) + except errors.DocumentDoesNotExist: + return None + except errors.HTTPError as e: + if (e.status == DOCUMENT_DELETED_STATUS and + 'x-u1db-rev' in e.headers): + res = None + headers = e.headers + else: + raise + doc_rev = headers['x-u1db-rev'] + has_conflicts = json.loads(headers['x-u1db-has-conflicts']) + doc = self._factory(doc_id, doc_rev, res) + doc.has_conflicts = has_conflicts + return doc + + def _build_docs(self, res): + for doc_dict in json.loads(res): + doc = self._factory( + doc_dict['doc_id'], doc_dict['doc_rev'], doc_dict['content']) + doc.has_conflicts = doc_dict['has_conflicts'] + yield doc + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + if not doc_ids: + return [] + doc_ids = ','.join(doc_ids) + res, headers = self._request( + 'GET', ['docs'], { + "doc_ids": doc_ids, "include_deleted": include_deleted, + "check_for_conflicts": check_for_conflicts}) + return self._build_docs(res) + + def get_all_docs(self, include_deleted=False): + res, headers = self._request( + 'GET', ['all-docs'], {"include_deleted": include_deleted}) + gen = -1 + if 'x-u1db-generation' in headers: + gen = int(headers['x-u1db-generation']) + return gen, list(self._build_docs(res)) + + def _allocate_doc_id(self): + return 'D-%s' % (uuid.uuid4().hex,) + + def create_doc(self, content, doc_id=None): + if not isinstance(content, dict): + raise errors.InvalidContent + json_string = json.dumps(content) + return self.create_doc_from_json(json_string, doc_id) + + def create_doc_from_json(self, content, doc_id=None): + if doc_id is None: + doc_id = self._allocate_doc_id() + res, headers = self._request_json('PUT', ['doc', doc_id], {}, + content, 'application/json') + new_doc = self._factory(doc_id, res['rev'], content) + return new_doc + + def delete_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + params = {'old_rev': doc.rev} + res, headers = self._request_json( + 'DELETE', ['doc', doc.doc_id], params) + doc.make_tombstone() + doc.rev = res['rev'] + + def get_sync_target(self): + st = http_target.HTTPSyncTarget(self._url.geturl()) + st._creds = self._creds + return st diff --git a/src/leap/soledad/common/l2db/remote/http_errors.py b/src/leap/soledad/common/l2db/remote/http_errors.py new file mode 100644 index 00000000..ee4cfefa --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_errors.py @@ -0,0 +1,48 @@ +# Copyright 2011-2012 Canonical Ltd. +# Copyright 2016 LEAP Encryption Access Project +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +""" +Information about the encoding of errors over HTTP. +""" + +from leap.soledad.common.l2db import errors + + +# error wire descriptions mapping to HTTP status codes +wire_description_to_status = dict([ + (errors.InvalidDocId.wire_description, 400), + (errors.MissingDocIds.wire_description, 400), + (errors.Unauthorized.wire_description, 401), + (errors.DocumentTooBig.wire_description, 403), + (errors.UserQuotaExceeded.wire_description, 403), + (errors.SubscriptionNeeded.wire_description, 403), + (errors.DatabaseDoesNotExist.wire_description, 404), + (errors.DocumentDoesNotExist.wire_description, 404), + (errors.DocumentAlreadyDeleted.wire_description, 404), + (errors.RevisionConflict.wire_description, 409), + (errors.InvalidGeneration.wire_description, 409), + (errors.InvalidReplicaUID.wire_description, 409), + (errors.InvalidTransactionId.wire_description, 409), + (errors.Unavailable.wire_description, 503), + # without matching exception + (errors.DOCUMENT_DELETED, 404) +]) + + +ERROR_STATUSES = set(wire_description_to_status.values()) +# 400 included explicitly for tests +ERROR_STATUSES.add(400) diff --git a/src/leap/soledad/common/l2db/remote/http_target.py b/src/leap/soledad/common/l2db/remote/http_target.py new file mode 100644 index 00000000..38804f01 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_target.py @@ -0,0 +1,125 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""SyncTarget API implementation to a remote HTTP server.""" + +import json + +from leap.soledad.common.l2db import Document, SyncTarget +from leap.soledad.common.l2db.errors import BrokenSyncStream +from leap.soledad.common.l2db.remote import ( + http_client, utils) + + +class HTTPSyncTarget(http_client.HTTPClientBase, SyncTarget): + """Implement the SyncTarget api to a remote HTTP server.""" + + @staticmethod + def connect(url): + return HTTPSyncTarget(url) + + def get_sync_info(self, source_replica_uid): + self._ensure_connection() + res, _ = self._request_json('GET', ['sync-from', source_replica_uid]) + return (res['target_replica_uid'], res['target_replica_generation'], + res['target_replica_transaction_id'], + res['source_replica_generation'], res['source_transaction_id']) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_transaction_id): + self._ensure_connection() + if self._trace_hook: # for tests + self._trace_hook('record_sync_info') + self._request_json('PUT', ['sync-from', source_replica_uid], {}, + {'generation': source_replica_generation, + 'transaction_id': source_transaction_id}) + + def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): + parts = data.splitlines() # one at a time + if not parts or parts[0] != '[': + raise BrokenSyncStream + data = parts[1:-1] + comma = False + if data: + line, comma = utils.check_and_strip_comma(data[0]) + res = json.loads(line) + if ensure_callback and 'replica_uid' in res: + ensure_callback(res['replica_uid']) + for entry in data[1:]: + if not comma: # missing in between comma + raise BrokenSyncStream + line, comma = utils.check_and_strip_comma(entry) + entry = json.loads(line) + doc = Document(entry['id'], entry['rev'], entry['content']) + return_doc_cb(doc, entry['gen'], entry['trans_id']) + if parts[-1] != ']': + try: + partdic = json.loads(parts[-1]) + except ValueError: + pass + else: + if isinstance(partdic, dict): + self._error(partdic) + raise BrokenSyncStream + if not data or comma: # no entries or bad extra comma + raise BrokenSyncStream + return res + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + self._ensure_connection() + if self._trace_hook: # for tests + self._trace_hook('sync_exchange') + url = '%s/sync-from/%s' % (self._url.path, source_replica_uid) + self._conn.putrequest('POST', url) + self._conn.putheader('content-type', 'application/x-u1db-sync-stream') + for header_name, header_value in self._sign_request('POST', url, {}): + self._conn.putheader(header_name, header_value) + entries = ['['] + size = 1 + + def prepare(**dic): + entry = comma + '\r\n' + json.dumps(dic) + entries.append(entry) + return len(entry) + + comma = '' + size += prepare( + last_known_generation=last_known_generation, + last_known_trans_id=last_known_trans_id, + ensure=ensure_callback is not None) + comma = ',' + for doc, gen, trans_id in docs_by_generations: + size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), + gen=gen, trans_id=trans_id) + entries.append('\r\n]') + size += len(entries[-1]) + self._conn.putheader('content-length', str(size)) + self._conn.endheaders() + for entry in entries: + self._conn.send(entry) + entries = None + data, _ = self._response() + res = self._parse_sync_stream(data, return_doc_cb, ensure_callback) + data = None + return res['new_generation'], res['new_transaction_id'] + + # for tests + _trace_hook = None + + def _set_trace_hook_shallow(self, cb): + self._trace_hook = cb diff --git a/src/leap/soledad/common/l2db/remote/server_state.py b/src/leap/soledad/common/l2db/remote/server_state.py new file mode 100644 index 00000000..d4c3c45f --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/server_state.py @@ -0,0 +1,68 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""State for servers exposing a set of U1DB databases.""" + + +class ServerState(object): + """Passed to a Request when it is instantiated. + + This is used to track server-side state, such as working-directory, open + databases, etc. + """ + + def __init__(self): + self._workingdir = None + + def set_workingdir(self, path): + self._workingdir = path + + def global_info(self): + """Return global information about the server.""" + return {} + + def _relpath(self, relpath): + # Note: We don't want to allow absolute paths here, because we + # don't want to expose the filesystem. We should also check that + # relpath doesn't have '..' in it, etc. + return self._workingdir + '/' + relpath + + def open_database(self, path): + """Open a database at the given location.""" + from leap.soledad.client._db import sqlite + full_path = self._relpath(path) + return sqlite.SQLiteDatabase.open_database(full_path, create=False) + + def check_database(self, path): + """Check if the database at the given location exists. + + Simply returns if it does or raises DatabaseDoesNotExist. + """ + db = self.open_database(path) + db.close() + + def ensure_database(self, path): + """Ensure database at the given location.""" + from leap.soledad.client._db import sqlite + full_path = self._relpath(path) + db = sqlite.SQLiteDatabase.open_database(full_path, create=True) + return db, db._replica_uid + + def delete_database(self, path): + """Delete database at the given location.""" + from leap.soledad.client._db import sqlite + full_path = self._relpath(path) + sqlite.SQLiteDatabase.delete_database(full_path) diff --git a/src/leap/soledad/common/l2db/remote/ssl_match_hostname.py b/src/leap/soledad/common/l2db/remote/ssl_match_hostname.py new file mode 100644 index 00000000..ce82f1b2 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/ssl_match_hostname.py @@ -0,0 +1,65 @@ +"""The match_hostname() function from Python 3.2, essential when using SSL.""" +# XXX put it here until it's packaged + +import re + +__version__ = '3.2a3' + + +class CertificateError(ValueError): + pass + + +def _dnsname_to_pat(dn): + pats = [] + for frag in dn.split(r'.'): + if frag == '*': + # When '*' is a fragment by itself, it matches a non-empty dotless + # fragment. + pats.append('[^.]+') + else: + # Otherwise, '*' matches any dotless fragment. + frag = re.escape(frag) + pats.append(frag.replace(r'\*', '[^.]*')) + return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + + +def match_hostname(cert, hostname): + """Verify that *cert* (in decoded format as returned by + SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules + are mostly followed, but IP addresses are not accepted for *hostname*. + + CertificateError is raised on failure. On success, the function + returns nothing. + """ + if not cert: + raise ValueError("empty or no certificate") + dnsnames = [] + san = cert.get('subjectAltName', ()) + for key, value in san: + if key == 'DNS': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if not san: + # The subject is only checked when subjectAltName is empty + for sub in cert.get('subject', ()): + for key, value in sub: + # XXX according to RFC 2818, the most specific Common Name + # must be used. + if key == 'commonName': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if len(dnsnames) > 1: + raise CertificateError( + "hostname %r doesn't match either of %s" + % (hostname, ', '.join(map(repr, dnsnames)))) + elif len(dnsnames) == 1: + raise CertificateError( + "hostname %r doesn't match %r" + % (hostname, dnsnames[0])) + else: + raise CertificateError( + "no appropriate commonName or " + "subjectAltName fields were found") diff --git a/src/leap/soledad/common/l2db/remote/utils.py b/src/leap/soledad/common/l2db/remote/utils.py new file mode 100644 index 00000000..14cedea9 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/utils.py @@ -0,0 +1,23 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Utilities for details of the procotol.""" + + +def check_and_strip_comma(line): + if line and line[-1] == ',': + return line[:-1], True + return line, False diff --git a/src/leap/soledad/common/l2db/sync.py b/src/leap/soledad/common/l2db/sync.py new file mode 100644 index 00000000..32281f30 --- /dev/null +++ b/src/leap/soledad/common/l2db/sync.py @@ -0,0 +1,311 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""The synchronization utilities for U1DB.""" +from six.moves import zip as izip + +from leap.soledad.common import l2db +from leap.soledad.common.l2db import errors + + +class Synchronizer(object): + """Collect the state around synchronizing 2 U1DB replicas. + + Synchronization is bi-directional, in that new items in the source are sent + to the target, and new items in the target are returned to the source. + However, it still recognizes that one side is initiating the request. Also, + at the moment, conflicts are only created in the source. + """ + + def __init__(self, source, sync_target): + """Create a new Synchronization object. + + :param source: A Database + :param sync_target: A SyncTarget + """ + self.source = source + self.sync_target = sync_target + self.target_replica_uid = None + self.num_inserted = 0 + + def _insert_doc_from_target(self, doc, replica_gen, trans_id): + """Try to insert synced document from target. + + Implements TAKE OTHER semantics: any document from the target + that is in conflict will be taken as the new official value, + while the current conflicting value will be stored alongside + as a conflict. In the process indexes will be updated etc. + + :return: None + """ + # Increases self.num_inserted depending whether the document + # was effectively inserted. + state, _ = self.source._put_doc_if_newer( + doc, save_conflict=True, + replica_uid=self.target_replica_uid, replica_gen=replica_gen, + replica_trans_id=trans_id) + if state == 'inserted': + self.num_inserted += 1 + elif state == 'converged': + # magical convergence + pass + elif state == 'superseded': + # we have something newer, will be taken care of at the next sync + pass + else: + assert state == 'conflicted' + # The doc was saved as a conflict, so the database was updated + self.num_inserted += 1 + + def _record_sync_info_with_the_target(self, start_generation): + """Record our new after sync generation with the target if gapless. + + Any documents received from the target will cause the local + database to increment its generation. We do not want to send + them back to the target in a future sync. However, there could + also be concurrent updates from another process doing eg + 'put_doc' while the sync was running. And we do want to + synchronize those documents. We can tell if there was a + concurrent update by comparing our new generation number + versus the generation we started, and how many documents we + inserted from the target. If it matches exactly, then we can + record with the target that they are fully up to date with our + new generation. + """ + cur_gen, trans_id = self.source._get_generation_info() + last_gen = start_generation + self.num_inserted + if (cur_gen == last_gen and self.num_inserted > 0): + self.sync_target.record_sync_info( + self.source._replica_uid, cur_gen, trans_id) + + def sync(self, callback=None, autocreate=False): + """Synchronize documents between source and target.""" + sync_target = self.sync_target + # get target identifier, its current generation, + # and its last-seen database generation for this source + try: + (self.target_replica_uid, target_gen, target_trans_id, + target_my_gen, target_my_trans_id) = sync_target.get_sync_info( + self.source._replica_uid) + except errors.DatabaseDoesNotExist: + if not autocreate: + raise + # will try to ask sync_exchange() to create the db + self.target_replica_uid = None + target_gen, target_trans_id = 0, '' + target_my_gen, target_my_trans_id = 0, '' + + def ensure_callback(replica_uid): + self.target_replica_uid = replica_uid + + else: + ensure_callback = None + if self.target_replica_uid == self.source._replica_uid: + raise errors.InvalidReplicaUID + # validate the generation and transaction id the target knows about us + self.source.validate_gen_and_trans_id( + target_my_gen, target_my_trans_id) + # what's changed since that generation and this current gen + my_gen, _, changes = self.source.whats_changed(target_my_gen) + + # this source last-seen database generation for the target + if self.target_replica_uid is None: + target_last_known_gen, target_last_known_trans_id = 0, '' + else: + target_last_known_gen, target_last_known_trans_id = ( + self.source._get_replica_gen_and_trans_id( # nopep8 + self.target_replica_uid)) + if not changes and target_last_known_gen == target_gen: + if target_trans_id != target_last_known_trans_id: + raise errors.InvalidTransactionId + return my_gen + changed_doc_ids = [doc_id for doc_id, _, _ in changes] + # prepare to send all the changed docs + docs_to_send = self.source.get_docs( + changed_doc_ids, + check_for_conflicts=False, include_deleted=True) + # TODO: there must be a way to not iterate twice + docs_by_generation = zip( + docs_to_send, (gen for _, gen, _ in changes), + (trans for _, _, trans in changes)) + + # exchange documents and try to insert the returned ones with + # the target, return target synced-up-to gen + new_gen, new_trans_id = sync_target.sync_exchange( + docs_by_generation, self.source._replica_uid, + target_last_known_gen, target_last_known_trans_id, + self._insert_doc_from_target, ensure_callback=ensure_callback) + # record target synced-up-to generation including applying what we sent + self.source._set_replica_gen_and_trans_id( + self.target_replica_uid, new_gen, new_trans_id) + + # if gapless record current reached generation with target + self._record_sync_info_with_the_target(my_gen) + + return my_gen + + +class SyncExchange(object): + """Steps and state for carrying through a sync exchange on a target.""" + + def __init__(self, db, source_replica_uid, last_known_generation): + self._db = db + self.source_replica_uid = source_replica_uid + self.source_last_known_generation = last_known_generation + self.seen_ids = {} # incoming ids not superseded + self.changes_to_return = None + self.new_gen = None + self.new_trans_id = None + # for tests + self._incoming_trace = [] + self._trace_hook = None + self._db._last_exchange_log = { + 'receive': {'docs': self._incoming_trace}, + 'return': None + } + + def _set_trace_hook(self, cb): + self._trace_hook = cb + + def _trace(self, state): + if not self._trace_hook: + return + self._trace_hook(state) + + def insert_doc_from_source(self, doc, source_gen, trans_id): + """Try to insert synced document from source. + + Conflicting documents are not inserted but will be sent over + to the sync source. + + It keeps track of progress by storing the document source + generation as well. + + The 1st step of a sync exchange is to call this repeatedly to + try insert all incoming documents from the source. + + :param doc: A Document object. + :param source_gen: The source generation of doc. + :return: None + """ + state, at_gen = self._db._put_doc_if_newer( + doc, save_conflict=False, + replica_uid=self.source_replica_uid, replica_gen=source_gen, + replica_trans_id=trans_id) + if state == 'inserted': + self.seen_ids[doc.doc_id] = at_gen + elif state == 'converged': + # magical convergence + self.seen_ids[doc.doc_id] = at_gen + elif state == 'superseded': + # we have something newer that we will return + pass + else: + # conflict that we will returne + assert state == 'conflicted' + # for tests + self._incoming_trace.append((doc.doc_id, doc.rev)) + self._db._last_exchange_log['receive'].update({ + 'source_uid': self.source_replica_uid, + 'source_gen': source_gen + }) + + def find_changes_to_return(self): + """Find changes to return. + + Find changes since last_known_generation in db generation + order using whats_changed. It excludes documents ids that have + already been considered (superseded by the sender, etc). + + :return: new_generation - the generation of this database + which the caller can consider themselves to be synchronized after + processing the returned documents. + """ + self._db._last_exchange_log['receive'].update({ # for tests + 'last_known_gen': self.source_last_known_generation + }) + self._trace('before whats_changed') + gen, trans_id, changes = self._db.whats_changed( + self.source_last_known_generation) + self._trace('after whats_changed') + self.new_gen = gen + self.new_trans_id = trans_id + seen_ids = self.seen_ids + # changed docs that weren't superseded by or converged with + self.changes_to_return = [ + (doc_id, gen, trans_id) for (doc_id, gen, trans_id) in changes if + # there was a subsequent update + doc_id not in seen_ids or seen_ids.get(doc_id) < gen] + return self.new_gen + + def return_docs(self, return_doc_cb): + """Return the changed documents and their last change generation + repeatedly invoking the callback return_doc_cb. + + The final step of a sync exchange. + + :param: return_doc_cb(doc, gen, trans_id): is a callback + used to return the documents with their last change generation + to the target replica. + :return: None + """ + changes_to_return = self.changes_to_return + # return docs, including conflicts + changed_doc_ids = [doc_id for doc_id, _, _ in changes_to_return] + self._trace('before get_docs') + docs = self._db.get_docs( + changed_doc_ids, check_for_conflicts=False, include_deleted=True) + + docs_by_gen = izip( + docs, (gen for _, gen, _ in changes_to_return), + (trans_id for _, _, trans_id in changes_to_return)) + _outgoing_trace = [] # for tests + for doc, gen, trans_id in docs_by_gen: + return_doc_cb(doc, gen, trans_id) + _outgoing_trace.append((doc.doc_id, doc.rev)) + # for tests + self._db._last_exchange_log['return'] = { + 'docs': _outgoing_trace, + 'last_gen': self.new_gen} + + +class LocalSyncTarget(l2db.SyncTarget): + """Common sync target implementation logic for all local sync targets.""" + + def __init__(self, db): + self._db = db + self._trace_hook = None + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + self._db.validate_gen_and_trans_id( + last_known_generation, last_known_trans_id) + sync_exch = SyncExchange( + self._db, source_replica_uid, last_known_generation) + if self._trace_hook: + sync_exch._set_trace_hook(self._trace_hook) + # 1st step: try to insert incoming docs and record progress + for doc, doc_gen, trans_id in docs_by_generations: + sync_exch.insert_doc_from_source(doc, doc_gen, trans_id) + # 2nd step: find changed documents (including conflicts) to return + new_gen = sync_exch.find_changes_to_return() + # final step: return docs and record source replica sync point + sync_exch.return_docs(return_doc_cb) + return new_gen, sync_exch.new_trans_id + + def _set_trace_hook(self, cb): + self._trace_hook = cb diff --git a/src/leap/soledad/common/l2db/vectorclock.py b/src/leap/soledad/common/l2db/vectorclock.py new file mode 100644 index 00000000..42bceaa8 --- /dev/null +++ b/src/leap/soledad/common/l2db/vectorclock.py @@ -0,0 +1,89 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""VectorClockRev helper class.""" + + +class VectorClockRev(object): + """Track vector clocks for multiple replica ids. + + This allows simple comparison to determine if one VectorClockRev is + newer/older/in-conflict-with another VectorClockRev without having to + examine history. Every replica has a strictly increasing revision. When + creating a new revision, they include all revisions for all other replicas + which the new revision dominates, and increment their own revision to + something greater than the current value. + """ + + def __init__(self, value): + self._values = self._expand(value) + + def __repr__(self): + s = self.as_str() + return '%s(%s)' % (self.__class__.__name__, s) + + def as_str(self): + s = '|'.join(['%s:%d' % (m, r) for m, r + in sorted(self._values.items())]) + return s + + def _expand(self, value): + result = {} + if value is None: + return result + for replica_info in value.split('|'): + replica_uid, counter = replica_info.split(':') + counter = int(counter) + result[replica_uid] = counter + return result + + def is_newer(self, other): + """Is this VectorClockRev strictly newer than other. + """ + if not self._values: + return False + if not other._values: + return True + this_is_newer = False + other_expand = dict(other._values) + for key, value in self._values.iteritems(): + if key in other_expand: + other_value = other_expand.pop(key) + if other_value > value: + return False + elif other_value < value: + this_is_newer = True + else: + this_is_newer = True + if other_expand: + return False + return this_is_newer + + def increment(self, replica_uid): + """Increase the 'replica_uid' section of this vector clock. + + :return: A string representing the new vector clock value + """ + self._values[replica_uid] = self._values.get(replica_uid, 0) + 1 + + def maximize(self, other_vcr): + for replica_uid, counter in other_vcr._values.iteritems(): + if replica_uid not in self._values: + self._values[replica_uid] = counter + else: + this_counter = self._values[replica_uid] + if this_counter < counter: + self._values[replica_uid] = counter |