diff options
author | drebs <drebs@leap.se> | 2017-06-18 11:18:10 -0300 |
---|---|---|
committer | Kali Kaneko <kali@leap.se> | 2017-06-24 00:49:17 +0200 |
commit | 3e94cafa43d464d73815e21810b97a4faf54136d (patch) | |
tree | 04f5152e3dfc9f27b7dca2368c0eb8b3f094f5b5 /src/leap/soledad/common/l2db | |
parent | 7d8ee786b086e47264619df3efa73e74440fd068 (diff) |
[pkg] unify client and server into a single python package
We have been discussing about this merge for a while.
Its main goal is to simplify things: code navigation, but also
packaging.
The rationale is that the code is more cohesive in this way, and there's
only one source package to install.
Dependencies that are only for the server or the client will not be
installed by default, and they are expected to be provided by the
environment. There are setuptools extras defined for the client and the
server.
Debianization is still expected to split the single source package into
3 binaries.
Another avantage is that the documentation can now install a single
package with a single step, and therefore include the docstrings into
the generated docs.
- Resolves: #8896
Diffstat (limited to 'src/leap/soledad/common/l2db')
-rw-r--r-- | src/leap/soledad/common/l2db/__init__.py | 694 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/backends/__init__.py | 204 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/backends/inmemory.py | 466 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/errors.py | 194 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/query_parser.py | 371 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/__init__.py | 15 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/http_app.py | 660 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/http_client.py | 178 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/http_database.py | 158 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/http_errors.py | 48 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/http_target.py | 125 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/server_state.py | 68 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/ssl_match_hostname.py | 65 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/remote/utils.py | 23 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/sync.py | 311 | ||||
-rw-r--r-- | src/leap/soledad/common/l2db/vectorclock.py | 89 |
16 files changed, 3669 insertions, 0 deletions
diff --git a/src/leap/soledad/common/l2db/__init__.py b/src/leap/soledad/common/l2db/__init__.py new file mode 100644 index 00000000..43d61b1d --- /dev/null +++ b/src/leap/soledad/common/l2db/__init__.py @@ -0,0 +1,694 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""L2DB""" + +import json + +from leap.soledad.common.l2db.errors import InvalidJSON, InvalidContent + +__version_info__ = (13, 9) +__version__ = '.'.join(map(lambda x: '%02d' % x, __version_info__)) + + +def open(path, create, document_factory=None): + """Open a database at the given location. + + Will raise u1db.errors.DatabaseDoesNotExist if create=False and the + database does not already exist. + + :param path: The filesystem path for the database to open. + :param create: True/False, should the database be created if it doesn't + already exist? + :param document_factory: A function that will be called with the same + parameters as Document.__init__. + :return: An instance of Database. + """ + from leap.soledad.client._db import sqlite + return sqlite.SQLiteDatabase.open_database( + path, create=create, document_factory=document_factory) + + +# constraints on database names (relevant for remote access, as regex) +DBNAME_CONSTRAINTS = r"[a-zA-Z0-9][a-zA-Z0-9.-]*" + +# constraints on doc ids (as regex) +# (no slashes, and no characters outside the ascii range) +DOC_ID_CONSTRAINTS = r"[a-zA-Z0-9.%_-]+" + + +class Database(object): + """A JSON Document data store. + + This data store can be synchronized with other u1db.Database instances. + """ + + def set_document_factory(self, factory): + """Set the document factory that will be used to create objects to be + returned as documents by the database. + + :param factory: A function that returns an object which at minimum must + satisfy the same interface as does the class DocumentBase. + Subclassing that class is the easiest way to create such + a function. + """ + raise NotImplementedError(self.set_document_factory) + + def set_document_size_limit(self, limit): + """Set the maximum allowed document size for this database. + + :param limit: Maximum allowed document size in bytes. + """ + raise NotImplementedError(self.set_document_size_limit) + + def whats_changed(self, old_generation=0): + """Return a list of documents that have changed since old_generation. + This allows APPS to only store a db generation before going + 'offline', and then when coming back online they can use this + data to update whatever extra data they are storing. + + :param old_generation: The generation of the database in the old + state. + :return: (generation, trans_id, [(doc_id, generation, trans_id),...]) + The current generation of the database, its associated transaction + id, and a list of of changed documents since old_generation, + represented by tuples with for each document its doc_id and the + generation and transaction id corresponding to the last intervening + change and sorted by generation (old changes first) + """ + raise NotImplementedError(self.whats_changed) + + def get_doc(self, doc_id, include_deleted=False): + """Get the JSON string for the given document. + + :param doc_id: The unique document identifier + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise asking for a deleted + document will return None. + :return: a Document object. + """ + raise NotImplementedError(self.get_doc) + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + """Get the JSON content for many documents. + + :param doc_ids: A list of document identifiers. + :param check_for_conflicts: If set to False, then the conflict check + will be skipped, and 'None' will be returned instead of True/False. + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise deleted documents will not + be included in the results. + :return: iterable giving the Document object for each document id + in matching doc_ids order. + """ + raise NotImplementedError(self.get_docs) + + def get_all_docs(self, include_deleted=False): + """Get the JSON content for all documents in the database. + + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise deleted documents will not + be included in the results. + :return: (generation, [Document]) + The current generation of the database, followed by a list of all + the documents in the database. + """ + raise NotImplementedError(self.get_all_docs) + + def create_doc(self, content, doc_id=None): + """Create a new document. + + You can optionally specify the document identifier, but the document + must not already exist. See 'put_doc' if you want to override an + existing document. + If the database specifies a maximum document size and the document + exceeds it, create will fail and raise a DocumentTooBig exception. + + :param content: A Python dictionary. + :param doc_id: An optional identifier specifying the document id. + :return: Document + """ + raise NotImplementedError(self.create_doc) + + def create_doc_from_json(self, json, doc_id=None): + """Create a new document. + + You can optionally specify the document identifier, but the document + must not already exist. See 'put_doc' if you want to override an + existing document. + If the database specifies a maximum document size and the document + exceeds it, create will fail and raise a DocumentTooBig exception. + + :param json: The JSON document string + :param doc_id: An optional identifier specifying the document id. + :return: Document + """ + raise NotImplementedError(self.create_doc_from_json) + + def put_doc(self, doc): + """Update a document. + If the document currently has conflicts, put will fail. + If the database specifies a maximum document size and the document + exceeds it, put will fail and raise a DocumentTooBig exception. + + :param doc: A Document with new content. + :return: new_doc_rev - The new revision identifier for the document. + The Document object will also be updated. + """ + raise NotImplementedError(self.put_doc) + + def delete_doc(self, doc): + """Mark a document as deleted. + Will abort if the current revision doesn't match doc.rev. + This will also set doc.content to None. + """ + raise NotImplementedError(self.delete_doc) + + def create_index(self, index_name, *index_expressions): + """Create an named index, which can then be queried for future lookups. + Creating an index which already exists is not an error, and is cheap. + Creating an index which does not match the index_expressions of the + existing index is an error. + Creating an index will block until the expressions have been evaluated + and the index generated. + + :param index_name: A unique name which can be used as a key prefix + :param index_expressions: index expressions defining the index + information. + + Examples: + + "fieldname", or "fieldname.subfieldname" to index alphabetically + sorted on the contents of a field. + + "number(fieldname, width)", "lower(fieldname)" + """ + raise NotImplementedError(self.create_index) + + def delete_index(self, index_name): + """Remove a named index. + + :param index_name: The name of the index we are removing + """ + raise NotImplementedError(self.delete_index) + + def list_indexes(self): + """List the definitions of all known indexes. + + :return: A list of [('index-name', ['field', 'field2'])] definitions. + """ + raise NotImplementedError(self.list_indexes) + + def get_from_index(self, index_name, *key_values): + """Return documents that match the keys supplied. + + You must supply exactly the same number of values as have been defined + in the index. It is possible to do a prefix match by using '*' to + indicate a wildcard match. You can only supply '*' to trailing entries, + (eg 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) + It is also possible to append a '*' to the last supplied value (eg + 'val*', '*', '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + + :param index_name: The index to query + :param key_values: values to match. eg, if you have + an index with 3 fields then you would have: + get_from_index(index_name, val1, val2, val3) + :return: List of [Document] + """ + raise NotImplementedError(self.get_from_index) + + def get_range_from_index(self, index_name, start_value, end_value): + """Return documents that fall within the specified range. + + Both ends of the range are inclusive. For both start_value and + end_value, one must supply exactly the same number of values as have + been defined in the index, or pass None. In case of a single column + index, a string is accepted as an alternative for a tuple with a single + value. It is possible to do a prefix match by using '*' to indicate + a wildcard match. You can only supply '*' to trailing entries, (eg + 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) It is also + possible to append a '*' to the last supplied value (eg 'val*', '*', + '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + + :param index_name: The index to query + :param start_values: tuples of values that define the lower bound of + the range. eg, if you have an index with 3 fields then you would + have: (val1, val2, val3) + :param end_values: tuples of values that define the upper bound of the + range. eg, if you have an index with 3 fields then you would have: + (val1, val2, val3) + :return: List of [Document] + """ + raise NotImplementedError(self.get_range_from_index) + + def get_index_keys(self, index_name): + """Return all keys under which documents are indexed in this index. + + :param index_name: The index to query + :return: [] A list of tuples of indexed keys. + """ + raise NotImplementedError(self.get_index_keys) + + def get_doc_conflicts(self, doc_id): + """Get the list of conflicts for the given document. + + The order of the conflicts is such that the first entry is the value + that would be returned by "get_doc". + + :return: [doc] A list of the Document entries that are conflicted. + """ + raise NotImplementedError(self.get_doc_conflicts) + + def resolve_doc(self, doc, conflicted_doc_revs): + """Mark a document as no longer conflicted. + + We take the list of revisions that the client knows about that it is + superseding. This may be a different list from the actual current + conflicts, in which case only those are removed as conflicted. This + may fail if the conflict list is significantly different from the + supplied information. (sync could have happened in the background from + the time you GET_DOC_CONFLICTS until the point where you RESOLVE) + + :param doc: A Document with the new content to be inserted. + :param conflicted_doc_revs: A list of revisions that the new content + supersedes. + """ + raise NotImplementedError(self.resolve_doc) + + def get_sync_target(self): + """Return a SyncTarget object, for another u1db to synchronize with. + + :return: An instance of SyncTarget. + """ + raise NotImplementedError(self.get_sync_target) + + def close(self): + """Release any resources associated with this database.""" + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + """Synchronize documents with remote replica exposed at url. + + :param url: the url of the target replica to sync with. + :param creds: optional dictionary giving credentials + to authorize the operation with the server. For using OAuth + the form of creds is: + {'oauth': { + 'consumer_key': ..., + 'consumer_secret': ..., + 'token_key': ..., + 'token_secret': ... + }} + :param autocreate: ask the target to create the db if non-existent. + :return: local_gen_before_sync The local generation before the + synchronisation was performed. This is useful to pass into + whatschanged, if an application wants to know which documents were + affected by a synchronisation. + """ + from u1db.sync import Synchronizer + from u1db.remote.http_target import HTTPSyncTarget + return Synchronizer(self, HTTPSyncTarget(url, creds=creds)).sync( + autocreate=autocreate) + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + """Return the last known generation and transaction id for the other db + replica. + + When you do a synchronization with another replica, the Database keeps + track of what generation the other database replica was at, and what + the associated transaction id was. This is used to determine what data + needs to be sent, and if two databases are claiming to be the same + replica. + + :param other_replica_uid: The identifier for the other replica. + :return: (gen, trans_id) The generation and transaction id we + encountered during synchronization. If we've never synchronized + with the replica, this is (0, ''). + """ + raise NotImplementedError(self._get_replica_gen_and_trans_id) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + """Set the last-known generation and transaction id for the other + database replica. + + We have just performed some synchronization, and we want to track what + generation the other replica was at. See also + _get_replica_gen_and_trans_id. + :param other_replica_uid: The U1DB identifier for the other replica. + :param other_generation: The generation number for the other replica. + :param other_transaction_id: The transaction id associated with the + generation. + """ + raise NotImplementedError(self._set_replica_gen_and_trans_id) + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, + replica_trans_id=''): + """Insert/update document into the database with a given revision. + + This api is used during synchronization operations. + + If a document would conflict and save_conflict is set to True, the + content will be selected as the 'current' content for doc.doc_id, + even though doc.rev doesn't supersede the currently stored revision. + The currently stored document will be added to the list of conflict + alternatives for the given doc_id. + + This forces the new content to be 'current' so that we get convergence + after synchronizing, even if people don't resolve conflicts. Users can + then notice that their content is out of date, update it, and + synchronize again. (The alternative is that users could synchronize and + think the data has propagated, but their local copy looks fine, and the + remote copy is never updated again.) + + :param doc: A Document object + :param save_conflict: If this document is a conflict, do you want to + save it as a conflict, or just ignore it. + :param replica_uid: A unique replica identifier. + :param replica_gen: The generation of the replica corresponding to the + this document. The replica arguments are optional, but are used + during synchronization. + :param replica_trans_id: The transaction_id associated with the + generation. + :return: (state, at_gen) - If we don't have doc_id already, + or if doc_rev supersedes the existing document revision, + then the content will be inserted, and state is 'inserted'. + If doc_rev is less than or equal to the existing revision, + then the put is ignored and state is respecitvely 'superseded' + or 'converged'. + If doc_rev is not strictly superseded or supersedes, then + state is 'conflicted'. The document will not be inserted if + save_conflict is False. + For 'inserted' or 'converged', at_gen is the insertion/current + generation. + """ + raise NotImplementedError(self._put_doc_if_newer) + + +class DocumentBase(object): + """Container for handling a single document. + + :ivar doc_id: Unique identifier for this document. + :ivar rev: The revision identifier of the document. + :ivar json_string: The JSON string for this document. + :ivar has_conflicts: Boolean indicating if this document has conflicts + """ + + def __init__(self, doc_id, rev, json_string, has_conflicts=False): + self.doc_id = doc_id + self.rev = rev + if json_string is not None: + try: + value = json.loads(json_string) + except ValueError: + raise InvalidJSON + if not isinstance(value, dict): + raise InvalidJSON + self._json = json_string + self.has_conflicts = has_conflicts + + def same_content_as(self, other): + """Compare the content of two documents.""" + if self._json: + c1 = json.loads(self._json) + else: + c1 = None + if other._json: + c2 = json.loads(other._json) + else: + c2 = None + return c1 == c2 + + def __repr__(self): + if self.has_conflicts: + extra = ', conflicted' + else: + extra = '' + return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id, + self.rev, extra, self.get_json()) + + def __hash__(self): + raise NotImplementedError(self.__hash__) + + def __eq__(self, other): + if not isinstance(other, Document): + return NotImplemented + return ( + self.doc_id == other.doc_id and self.rev == other.rev and + self.same_content_as(other) and self.has_conflicts == + other.has_conflicts) + + def __lt__(self, other): + """This is meant for testing, not part of the official api. + + It is implemented so that sorted([Document, Document]) can be used. + It doesn't imply that users would want their documents to be sorted in + this order. + """ + # Since this is just for testing, we don't worry about comparing + # against things that aren't a Document. + return ((self.doc_id, self.rev, self.get_json()) < + (other.doc_id, other.rev, other.get_json())) + + def get_json(self): + """Get the json serialization of this document.""" + if self._json is not None: + return self._json + return None + + def get_size(self): + """Calculate the total size of the document.""" + size = 0 + json = self.get_json() + if json: + size += len(json) + if self.rev: + size += len(self.rev) + if self.doc_id: + size += len(self.doc_id) + return size + + def set_json(self, json_string): + """Set the json serialization of this document.""" + if json_string is not None: + try: + value = json.loads(json_string) + except ValueError: + raise InvalidJSON + if not isinstance(value, dict): + raise InvalidJSON + self._json = json_string + + def make_tombstone(self): + """Make this document into a tombstone.""" + self._json = None + + def is_tombstone(self): + """Return True if the document is a tombstone, False otherwise.""" + if self._json is not None: + return False + return True + + +class Document(DocumentBase): + """Container for handling a single document. + + :ivar doc_id: Unique identifier for this document. + :ivar rev: The revision identifier of the document. + :ivar json: The JSON string for this document. + :ivar has_conflicts: Boolean indicating if this document has conflicts + """ + + # The following part of the API is optional: no implementation is forced to + # have it but if the language supports dictionaries/hashtables, it makes + # Documents a lot more user friendly. + + def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False): + # TODO: We convert the json in the superclass to check its validity so + # we might as well set _content here directly since the price is + # already being paid. + super(Document, self).__init__(doc_id, rev, json, has_conflicts) + self._content = None + + def same_content_as(self, other): + """Compare the content of two documents.""" + if self._json: + c1 = json.loads(self._json) + else: + c1 = self._content + if other._json: + c2 = json.loads(other._json) + else: + c2 = other._content + return c1 == c2 + + def get_json(self): + """Get the json serialization of this document.""" + json_string = super(Document, self).get_json() + if json_string is not None: + return json_string + if self._content is not None: + return json.dumps(self._content) + return None + + def set_json(self, json): + """Set the json serialization of this document.""" + self._content = None + super(Document, self).set_json(json) + + def make_tombstone(self): + """Make this document into a tombstone.""" + self._content = None + super(Document, self).make_tombstone() + + def is_tombstone(self): + """Return True if the document is a tombstone, False otherwise.""" + if self._content is not None: + return False + return super(Document, self).is_tombstone() + + def _get_content(self): + """Get the dictionary representing this document.""" + if self._json is not None: + self._content = json.loads(self._json) + self._json = None + if self._content is not None: + return self._content + return None + + def _set_content(self, content): + """Set the dictionary representing this document.""" + try: + tmp = json.dumps(content) + except TypeError: + raise InvalidContent( + "Can not be converted to JSON: %r" % (content,)) + if not tmp.startswith('{'): + raise InvalidContent( + "Can not be converted to a JSON object: %r." % (content,)) + # We might as well store the JSON at this point since we did the work + # of encoding it, and it doesn't lose any information. + self._json = tmp + self._content = None + + content = property( + _get_content, _set_content, doc="Content of the Document.") + + # End of optional part. + + +class SyncTarget(object): + """Functionality for using a Database as a synchronization target.""" + + def get_sync_info(self, source_replica_uid): + """Return information about known state. + + Return the replica_uid and the current database generation of this + database, and the last-seen database generation for source_replica_uid + + :param source_replica_uid: Another replica which we might have + synchronized with in the past. + :return: (target_replica_uid, target_replica_generation, + target_trans_id, source_replica_last_known_generation, + source_replica_last_known_transaction_id) + """ + raise NotImplementedError(self.get_sync_info) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + """Record tip information for another replica. + + After sync_exchange has been processed, the caller will have + received new content from this replica. This call allows the + source replica instigating the sync to inform us what their + generation became after applying the documents we returned. + + This is used to allow future sync operations to not need to repeat data + that we just talked about. It also means that if this is called at the + wrong time, there can be database records that will never be + synchronized. + + :param source_replica_uid: The identifier for the source replica. + :param source_replica_generation: + The database generation for the source replica. + :param source_replica_transaction_id: The transaction id associated + with the source replica generation. + """ + raise NotImplementedError(self.record_sync_info) + + def sync_exchange(self, docs_by_generation, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + """Incorporate the documents sent from the source replica. + + This is not meant to be called by client code directly, but is used as + part of sync(). + + This adds docs to the local store, and determines documents that need + to be returned to the source replica. + + Documents must be supplied in docs_by_generation paired with + the generation of their latest change in order from the oldest + change to the newest, that means from the oldest generation to + the newest. + + Documents are also returned paired with the generation of + their latest change in order from the oldest change to the + newest. + + :param docs_by_generation: A list of [(Document, generation, + transaction_id)] tuples indicating documents which should be + updated on this replica paired with the generation and transaction + id of their latest change. + :param source_replica_uid: The source replica's identifier + :param last_known_generation: The last generation that the source + replica knows about this target replica + :param last_known_trans_id: The last transaction id that the source + replica knows about this target replica + :param: return_doc_cb(doc, gen): is a callback + used to return documents to the source replica, it will + be invoked in turn with Documents that have changed since + last_known_generation together with the generation of + their last change. + :param: ensure_callback(replica_uid): if set the target may create + the target db if not yet existent, the callback can then + be used to inform of the created db replica uid. + :return: new_generation - After applying docs_by_generation, this is + the current generation for this replica + """ + raise NotImplementedError(self.sync_exchange) + + def _set_trace_hook(self, cb): + """Set a callback that will be invoked to trace database actions. + + The callback will be passed a string indicating the current state, and + the sync target object. Implementations do not have to implement this + api, it is used by the test suite. + + :param cb: A callable that takes cb(state) + """ + raise NotImplementedError(self._set_trace_hook) + + def _set_trace_hook_shallow(self, cb): + """Set a callback that will be invoked to trace database actions. + + Similar to _set_trace_hook, for implementations that don't offer + state changes from the inner working of sync_exchange(). + + :param cb: A callable that takes cb(state) + """ + self._set_trace_hook(cb) diff --git a/src/leap/soledad/common/l2db/backends/__init__.py b/src/leap/soledad/common/l2db/backends/__init__.py new file mode 100644 index 00000000..c731c3d3 --- /dev/null +++ b/src/leap/soledad/common/l2db/backends/__init__.py @@ -0,0 +1,204 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Abstract classes and common implementations for the backends.""" + +import re +import json +import uuid + +from leap.soledad.common import l2db +from leap.soledad.common.l2db import sync as l2db_sync +from leap.soledad.common.l2db import errors +from leap.soledad.common.l2db.vectorclock import VectorClockRev + + +check_doc_id_re = re.compile("^" + l2db.DOC_ID_CONSTRAINTS + "$", re.UNICODE) + + +class CommonSyncTarget(l2db_sync.LocalSyncTarget): + pass + + +class CommonBackend(l2db.Database): + + document_size_limit = 0 + + def _allocate_doc_id(self): + """Generate a unique identifier for this document.""" + return 'D-' + uuid.uuid4().hex # 'D-' stands for document + + def _allocate_transaction_id(self): + return 'T-' + uuid.uuid4().hex # 'T-' stands for transaction + + def _allocate_doc_rev(self, old_doc_rev): + vcr = VectorClockRev(old_doc_rev) + vcr.increment(self._replica_uid) + return vcr.as_str() + + def _check_doc_id(self, doc_id): + if not check_doc_id_re.match(doc_id): + raise errors.InvalidDocId() + + def _check_doc_size(self, doc): + if not self.document_size_limit: + return + if doc.get_size() > self.document_size_limit: + raise errors.DocumentTooBig + + def _get_generation(self): + """Return the current generation. + + """ + raise NotImplementedError(self._get_generation) + + def _get_generation_info(self): + """Return the current generation and transaction id. + + """ + raise NotImplementedError(self._get_generation_info) + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Extract the document from storage. + + This can return None if the document doesn't exist. + """ + raise NotImplementedError(self._get_doc) + + def _has_conflicts(self, doc_id): + """Return True if the doc has conflicts, False otherwise.""" + raise NotImplementedError(self._has_conflicts) + + def create_doc(self, content, doc_id=None): + if not isinstance(content, dict): + raise errors.InvalidContent + json_string = json.dumps(content) + return self.create_doc_from_json(json_string, doc_id) + + def create_doc_from_json(self, json, doc_id=None): + if doc_id is None: + doc_id = self._allocate_doc_id() + doc = self._factory(doc_id, None, json) + self.put_doc(doc) + return doc + + def _get_transaction_log(self): + """This is only for the test suite, it is not part of the api.""" + raise NotImplementedError(self._get_transaction_log) + + def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): + raise NotImplementedError(self._put_and_update_indexes) + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + for doc_id in doc_ids: + doc = self._get_doc( + doc_id, check_for_conflicts=check_for_conflicts) + if doc.is_tombstone() and not include_deleted: + continue + yield doc + + def _get_trans_id_for_gen(self, generation): + """Get the transaction id corresponding to a particular generation. + + Raises an InvalidGeneration when the generation does not exist. + + """ + raise NotImplementedError(self._get_trans_id_for_gen) + + def validate_gen_and_trans_id(self, generation, trans_id): + """Validate the generation and transaction id. + + Raises an InvalidGeneration when the generation does not exist, and an + InvalidTransactionId when it does but with a different transaction id. + + """ + if generation == 0: + return + known_trans_id = self._get_trans_id_for_gen(generation) + if known_trans_id != trans_id: + raise errors.InvalidTransactionId + + def _validate_source(self, other_replica_uid, other_generation, + other_transaction_id): + """Validate the new generation and transaction id. + + other_generation must be greater than what we have stored for this + replica, *or* it must be the same and the transaction_id must be the + same as well. + """ + (old_generation, + old_transaction_id) = self._get_replica_gen_and_trans_id( + other_replica_uid) + if other_generation < old_generation: + raise errors.InvalidGeneration + if other_generation > old_generation: + return + if other_transaction_id == old_transaction_id: + return + raise errors.InvalidTransactionId + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, + replica_trans_id=''): + cur_doc = self._get_doc(doc.doc_id) + doc_vcr = VectorClockRev(doc.rev) + if cur_doc is None: + cur_vcr = VectorClockRev(None) + else: + cur_vcr = VectorClockRev(cur_doc.rev) + self._validate_source(replica_uid, replica_gen, replica_trans_id) + if doc_vcr.is_newer(cur_vcr): + rev = doc.rev + self._prune_conflicts(doc, doc_vcr) + if doc.rev != rev: + # conflicts have been autoresolved + state = 'superseded' + else: + state = 'inserted' + self._put_and_update_indexes(cur_doc, doc) + elif doc.rev == cur_doc.rev: + # magical convergence + state = 'converged' + elif cur_vcr.is_newer(doc_vcr): + # Don't add this to seen_ids, because we have something newer, + # so we should send it back, and we should not generate a + # conflict + state = 'superseded' + elif cur_doc.same_content_as(doc): + # the documents have been edited to the same thing at both ends + doc_vcr.maximize(cur_vcr) + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + self._put_and_update_indexes(cur_doc, doc) + state = 'superseded' + else: + state = 'conflicted' + if save_conflict: + self._force_doc_sync_conflict(doc) + if replica_uid is not None and replica_gen is not None: + self._do_set_replica_gen_and_trans_id( + replica_uid, replica_gen, replica_trans_id) + return state, self._get_generation() + + def _ensure_maximal_rev(self, cur_rev, extra_revs): + vcr = VectorClockRev(cur_rev) + for rev in extra_revs: + vcr.maximize(VectorClockRev(rev)) + vcr.increment(self._replica_uid) + return vcr.as_str() + + def set_document_size_limit(self, limit): + self.document_size_limit = limit diff --git a/src/leap/soledad/common/l2db/backends/inmemory.py b/src/leap/soledad/common/l2db/backends/inmemory.py new file mode 100644 index 00000000..6fd251af --- /dev/null +++ b/src/leap/soledad/common/l2db/backends/inmemory.py @@ -0,0 +1,466 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""The in-memory Database class for U1DB.""" + +import json + +from leap.soledad.common.l2db import ( + Document, errors, + query_parser, vectorclock) +from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget + + +def get_prefix(value): + key_prefix = '\x01'.join(value) + return key_prefix.rstrip('*') + + +class InMemoryDatabase(CommonBackend): + """A database that only stores the data internally.""" + + def __init__(self, replica_uid, document_factory=None): + self._transaction_log = [] + self._docs = {} + # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner' + self._conflicts = {} + self._other_generations = {} + self._indexes = {} + self._replica_uid = replica_uid + self._factory = document_factory or Document + + def _set_replica_uid(self, replica_uid): + """Force the replica_uid to be set.""" + self._replica_uid = replica_uid + + def set_document_factory(self, factory): + self._factory = factory + + def close(self): + # This is a no-op, We don't want to free the data because one client + # may be closing it, while another wants to inspect the results. + pass + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + return self._other_generations.get(other_replica_uid, (0, '')) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + self._do_set_replica_gen_and_trans_id( + other_replica_uid, other_generation, other_transaction_id) + + def _do_set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, + other_transaction_id): + # TODO: to handle race conditions, we may want to check if the current + # value is greater than this new value. + self._other_generations[other_replica_uid] = (other_generation, + other_transaction_id) + + def get_sync_target(self): + return InMemorySyncTarget(self) + + def _get_transaction_log(self): + # snapshot! + return self._transaction_log[:] + + def _get_generation(self): + return len(self._transaction_log) + + def _get_generation_info(self): + if not self._transaction_log: + return 0, '' + return len(self._transaction_log), self._transaction_log[-1][1] + + def _get_trans_id_for_gen(self, generation): + if generation == 0: + return '' + if generation > len(self._transaction_log): + raise errors.InvalidGeneration + return self._transaction_log[generation - 1][1] + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _put_and_update_indexes(self, old_doc, doc): + for index in self._indexes.itervalues(): + if old_doc is not None and not old_doc.is_tombstone(): + index.remove_json(old_doc.doc_id, old_doc.get_json()) + if not doc.is_tombstone(): + index.add_json(doc.doc_id, doc.get_json()) + trans_id = self._allocate_transaction_id() + self._docs[doc.doc_id] = (doc.rev, doc.get_json()) + self._transaction_log.append((doc.doc_id, trans_id)) + + def _get_doc(self, doc_id, check_for_conflicts=False): + try: + doc_rev, content = self._docs[doc_id] + except KeyError: + return None + doc = self._factory(doc_id, doc_rev, content) + if check_for_conflicts: + doc.has_conflicts = (doc.doc_id in self._conflicts) + return doc + + def _has_conflicts(self, doc_id): + return doc_id in self._conflicts + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def get_all_docs(self, include_deleted=False): + """Return all documents in the database.""" + generation = self._get_generation() + results = [] + for doc_id, (doc_rev, content) in self._docs.items(): + if content is None and not include_deleted: + continue + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = self._has_conflicts(doc_id) + results.append(doc) + return (generation, results) + + def get_doc_conflicts(self, doc_id): + if doc_id not in self._conflicts: + return [] + result = [self._get_doc(doc_id)] + result[0].has_conflicts = True + result.extend([self._factory(doc_id, rev, content) + for rev, content in self._conflicts[doc_id]]) + return result + + def _replace_conflicts(self, doc, conflicts): + if not conflicts: + del self._conflicts[doc.doc_id] + else: + self._conflicts[doc.doc_id] = conflicts + doc.has_conflicts = bool(conflicts) + + def _prune_conflicts(self, doc, doc_vcr): + if self._has_conflicts(doc.doc_id): + autoresolved = False + remaining_conflicts = [] + cur_conflicts = self._conflicts[doc.doc_id] + for c_rev, c_doc in cur_conflicts: + c_vcr = vectorclock.VectorClockRev(c_rev) + if doc_vcr.is_newer(c_vcr): + continue + if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)): + doc_vcr.maximize(c_vcr) + autoresolved = True + continue + remaining_conflicts.append((c_rev, c_doc)) + if autoresolved: + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + self._replace_conflicts(doc, remaining_conflicts) + + def resolve_doc(self, doc, conflicted_doc_revs): + cur_doc = self._get_doc(doc.doc_id) + if cur_doc is None: + cur_rev = None + else: + cur_rev = cur_doc.rev + new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs) + superseded_revs = set(conflicted_doc_revs) + remaining_conflicts = [] + cur_conflicts = self._conflicts[doc.doc_id] + for c_rev, c_doc in cur_conflicts: + if c_rev in superseded_revs: + continue + remaining_conflicts.append((c_rev, c_doc)) + doc.rev = new_rev + if cur_rev in superseded_revs: + self._put_and_update_indexes(cur_doc, doc) + else: + remaining_conflicts.append((new_rev, doc.get_json())) + self._replace_conflicts(doc, remaining_conflicts) + + def delete_doc(self, doc): + if doc.doc_id not in self._docs: + raise errors.DocumentDoesNotExist + if self._docs[doc.doc_id][1] in ('null', None): + raise errors.DocumentAlreadyDeleted + doc.make_tombstone() + self.put_doc(doc) + + def create_index(self, index_name, *index_expressions): + if index_name in self._indexes: + if self._indexes[index_name]._definition == list( + index_expressions): + return + raise errors.IndexNameTakenError + index = InMemoryIndex(index_name, list(index_expressions)) + for doc_id, (doc_rev, doc) in self._docs.iteritems(): + if doc is not None: + index.add_json(doc_id, doc) + self._indexes[index_name] = index + + def delete_index(self, index_name): + try: + del self._indexes[index_name] + except KeyError: + pass + + def list_indexes(self): + definitions = [] + for idx in self._indexes.itervalues(): + definitions.append((idx._name, idx._definition)) + return definitions + + def get_from_index(self, index_name, *key_values): + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + doc_ids = index.lookup(key_values) + result = [] + for doc_id in doc_ids: + result.append(self._get_doc(doc_id, check_for_conflicts=True)) + return result + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + """Return all documents with key values in the specified range.""" + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + if isinstance(start_value, basestring): + start_value = (start_value,) + if isinstance(end_value, basestring): + end_value = (end_value,) + doc_ids = index.lookup_range(start_value, end_value) + result = [] + for doc_id in doc_ids: + result.append(self._get_doc(doc_id, check_for_conflicts=True)) + return result + + def get_index_keys(self, index_name): + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + keys = index.keys() + # XXX inefficiency warning + return list(set([tuple(key.split('\x01')) for key in keys])) + + def whats_changed(self, old_generation=0): + changes = [] + relevant_tail = self._transaction_log[old_generation:] + # We don't use len(self._transaction_log) because _transaction_log may + # get mutated by a concurrent operation. + cur_generation = old_generation + len(relevant_tail) + last_trans_id = '' + if relevant_tail: + last_trans_id = relevant_tail[-1][1] + elif self._transaction_log: + last_trans_id = self._transaction_log[-1][1] + seen = set() + generation = cur_generation + for doc_id, trans_id in reversed(relevant_tail): + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + generation -= 1 + changes.reverse() + return (cur_generation, last_trans_id, changes) + + def _force_doc_sync_conflict(self, doc): + my_doc = self._get_doc(doc.doc_id) + self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) + self._conflicts.setdefault(doc.doc_id, []).append( + (my_doc.rev, my_doc.get_json())) + doc.has_conflicts = True + self._put_and_update_indexes(my_doc, doc) + + +class InMemoryIndex(object): + """Interface for managing an Index.""" + + def __init__(self, index_name, index_definition): + self._name = index_name + self._definition = index_definition + self._values = {} + parser = query_parser.Parser() + self._getters = parser.parse_all(self._definition) + + def evaluate_json(self, doc): + """Determine the 'key' after applying this index to the doc.""" + raw = json.loads(doc) + return self.evaluate(raw) + + def evaluate(self, obj): + """Evaluate a dict object, applying this definition.""" + all_rows = [[]] + for getter in self._getters: + new_rows = [] + keys = getter.get(obj) + if not keys: + return [] + for key in keys: + new_rows.extend([row + [key] for row in all_rows]) + all_rows = new_rows + all_rows = ['\x01'.join(row) for row in all_rows] + return all_rows + + def add_json(self, doc_id, doc): + """Add this json doc to the index.""" + keys = self.evaluate_json(doc) + if not keys: + return + for key in keys: + self._values.setdefault(key, []).append(doc_id) + + def remove_json(self, doc_id, doc): + """Remove this json doc from the index.""" + keys = self.evaluate_json(doc) + if keys: + for key in keys: + doc_ids = self._values[key] + doc_ids.remove(doc_id) + if not doc_ids: + del self._values[key] + + def _find_non_wildcards(self, values): + """Check if this should be a wildcard match. + + Further, this will raise an exception if the syntax is improperly + defined. + + :return: The offset of the last value we need to match against. + """ + if len(values) != len(self._definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + last = 0 + for idx, val in enumerate(values): + if val.endswith('*'): + if val != '*': + # We have an 'x*' style wildcard + if is_wildcard: + # We were already in wildcard mode, so this is invalid + raise errors.InvalidGlobbing + last = idx + 1 + is_wildcard = True + else: + if is_wildcard: + # We were in wildcard mode, we can't follow that with + # non-wildcard + raise errors.InvalidGlobbing + last = idx + 1 + if not is_wildcard: + return -1 + return last + + def lookup(self, values): + """Find docs that match the values.""" + last = self._find_non_wildcards(values) + if last == -1: + return self._lookup_exact(values) + else: + return self._lookup_prefix(values[:last]) + + def lookup_range(self, start_values, end_values): + """Find docs within the range.""" + # TODO: Wildly inefficient, which is unlikely to be a problem for the + # inmemory implementation. + if start_values: + self._find_non_wildcards(start_values) + start_values = get_prefix(start_values) + if end_values: + if self._find_non_wildcards(end_values) == -1: + exact = True + else: + exact = False + end_values = get_prefix(end_values) + found = [] + for key, doc_ids in sorted(self._values.iteritems()): + if start_values and start_values > key: + continue + if end_values and end_values < key: + if exact: + break + else: + if not key.startswith(end_values): + break + found.extend(doc_ids) + return found + + def keys(self): + """Find the indexed keys.""" + return self._values.keys() + + def _lookup_prefix(self, value): + """Find docs that match the prefix string in values.""" + # TODO: We need a different data structure to make prefix style fast, + # some sort of sorted list would work, but a plain dict doesn't. + key_prefix = get_prefix(value) + all_doc_ids = [] + for key, doc_ids in sorted(self._values.iteritems()): + if key.startswith(key_prefix): + all_doc_ids.extend(doc_ids) + return all_doc_ids + + def _lookup_exact(self, value): + """Find docs that match exactly.""" + key = '\x01'.join(value) + if key in self._values: + return self._values[key] + return () + + +class InMemorySyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_transaction_id) diff --git a/src/leap/soledad/common/l2db/errors.py b/src/leap/soledad/common/l2db/errors.py new file mode 100644 index 00000000..b502fc2d --- /dev/null +++ b/src/leap/soledad/common/l2db/errors.py @@ -0,0 +1,194 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""A list of errors that u1db can raise.""" + + +class U1DBError(Exception): + """Generic base class for U1DB errors.""" + + # description/tag for identifying the error during transmission (http,...) + wire_description = "error" + + def __init__(self, message=None): + self.message = message + + +class RevisionConflict(U1DBError): + """The document revisions supplied does not match the current version.""" + + wire_description = "revision conflict" + + +class InvalidJSON(U1DBError): + """Content was not valid json.""" + + +class InvalidContent(U1DBError): + """Content was not a python dictionary.""" + + +class InvalidDocId(U1DBError): + """A document was requested with an invalid document identifier.""" + + wire_description = "invalid document id" + + +class MissingDocIds(U1DBError): + """Needs document ids.""" + + wire_description = "missing document ids" + + +class DocumentTooBig(U1DBError): + """Document exceeds the maximum document size for this database.""" + + wire_description = "document too big" + + +class UserQuotaExceeded(U1DBError): + """Document exceeds the maximum document size for this database.""" + + wire_description = "user quota exceeded" + + +class SubscriptionNeeded(U1DBError): + """User needs a subscription to be able to use this replica..""" + + wire_description = "user needs subscription" + + +class InvalidTransactionId(U1DBError): + """Invalid transaction for generation.""" + + wire_description = "invalid transaction id" + + +class InvalidGeneration(U1DBError): + """Generation was previously synced with a different transaction id.""" + + wire_description = "invalid generation" + + +class InvalidReplicaUID(U1DBError): + """Attempting to sync a database with itself.""" + + wire_description = "invalid replica uid" + + +class ConflictedDoc(U1DBError): + """The document is conflicted, you must call resolve before put()""" + + +class InvalidValueForIndex(U1DBError): + """The values supplied does not match the index definition.""" + + +class InvalidGlobbing(U1DBError): + """Raised if wildcard matches are not strictly at the tail of the request. + """ + + +class DocumentDoesNotExist(U1DBError): + """The document does not exist.""" + + wire_description = "document does not exist" + + +class DocumentAlreadyDeleted(U1DBError): + """The document was already deleted.""" + + wire_description = "document already deleted" + + +class DatabaseDoesNotExist(U1DBError): + """The database does not exist.""" + + wire_description = "database does not exist" + + +class IndexNameTakenError(U1DBError): + """The given index name is already taken.""" + + +class IndexDefinitionParseError(U1DBError): + """The index definition cannot be parsed.""" + + +class IndexDoesNotExist(U1DBError): + """No index of that name exists.""" + + +class Unauthorized(U1DBError): + """Request wasn't authorized properly.""" + + wire_description = "unauthorized" + + +class HTTPError(U1DBError): + """Unspecific HTTP errror.""" + + wire_description = None + + def __init__(self, status, message=None, headers={}): + self.status = status + self.message = message + self.headers = headers + + def __str__(self): + if not self.message: + return "HTTPError(%d)" % self.status + else: + return "HTTPError(%d, %r)" % (self.status, self.message) + + +class Unavailable(HTTPError): + """Server not available not serve request.""" + + wire_description = "unavailable" + + def __init__(self, message=None, headers={}): + super(Unavailable, self).__init__(503, message, headers) + + def __str__(self): + if not self.message: + return "Unavailable()" + else: + return "Unavailable(%r)" % self.message + + +class BrokenSyncStream(U1DBError): + """Unterminated or otherwise broken sync exchange stream.""" + + wire_description = None + + +class UnknownAuthMethod(U1DBError): + """Unknown auhorization method.""" + + wire_description = None + + +# mapping wire (transimission) descriptions/tags for errors to the exceptions +wire_description_to_exc = dict( + (x.wire_description, x) for x in globals().values() + if getattr(x, 'wire_description', None) not in (None, "error")) +wire_description_to_exc["error"] = U1DBError + + +# +# wire error descriptions not corresponding to an exception +DOCUMENT_DELETED = "document deleted" diff --git a/src/leap/soledad/common/l2db/query_parser.py b/src/leap/soledad/common/l2db/query_parser.py new file mode 100644 index 00000000..15a9ac80 --- /dev/null +++ b/src/leap/soledad/common/l2db/query_parser.py @@ -0,0 +1,371 @@ +# Copyright 2011 Canonical Ltd. +# Copyright 2016 LEAP Encryption Access Project +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. +""" +Code for parsing Index definitions. +""" + +import re + +from leap.soledad.common.l2db import errors + + +class Getter(object): + """Get values from a document based on a specification.""" + + def get(self, raw_doc): + """Get a value from the document. + + :param raw_doc: a python dictionary to get the value from. + :return: A list of values that match the description. + """ + raise NotImplementedError(self.get) + + +class StaticGetter(Getter): + """A getter that returns a defined value (independent of the doc).""" + + def __init__(self, value): + """Create a StaticGetter. + + :param value: the value to return when get is called. + """ + if value is None: + self.value = [] + elif isinstance(value, list): + self.value = value + else: + self.value = [value] + + def get(self, raw_doc): + return self.value + + +def extract_field(raw_doc, subfields, index=0): + if not isinstance(raw_doc, dict): + return [] + val = raw_doc.get(subfields[index]) + if val is None: + return [] + if index < len(subfields) - 1: + if isinstance(val, list): + results = [] + for item in val: + results.extend(extract_field(item, subfields, index + 1)) + return results + if isinstance(val, dict): + return extract_field(val, subfields, index + 1) + return [] + if isinstance(val, dict): + return [] + if isinstance(val, list): + # Strip anything in the list that isn't a simple type + return [v for v in val if not isinstance(v, (dict, list))] + return [val] + + +class ExtractField(Getter): + """Extract a field from the document.""" + + def __init__(self, field): + """Create an ExtractField object. + + When a document is passed to get() this will return a value + from the document based on the field specifier passed to + the constructor. + + None will be returned if the field is nonexistant, or refers to an + object, rather than a simple type or list of simple types. + + :param field: a specifier for the field to return. + This is either a field name, or a dotted field name. + """ + self.field = field.split('.') + + def get(self, raw_doc): + return extract_field(raw_doc, self.field) + + +class Transformation(Getter): + """A transformation on a value from another Getter.""" + + name = None + arity = 1 + args = ['expression'] + + def __init__(self, inner): + """Create a transformation. + + :param inner: the argument(s) to the transformation. + """ + self.inner = inner + + def get(self, raw_doc): + inner_values = self.inner.get(raw_doc) + assert isinstance(inner_values, list),\ + 'get() should always return a list' + return self.transform(inner_values) + + def transform(self, values): + """Transform the values. + + This should be implemented by subclasses to transform the + value when get() is called. + + :param values: the values from the other Getter + :return: the transformed values. + """ + raise NotImplementedError(self.transform) + + +class Lower(Transformation): + """Lowercase a string. + + This transformation will return None for non-string inputs. However, + it will lowercase any strings in a list, dropping any elements + that are not strings. + """ + + name = "lower" + + def _can_transform(self, val): + return isinstance(val, basestring) + + def transform(self, values): + if not values: + return [] + return [val.lower() for val in values if self._can_transform(val)] + + +class Number(Transformation): + """Convert an integer to a zero padded string. + + This transformation will return None for non-integer inputs. However, it + will transform any integers in a list, dropping any elements that are not + integers. + """ + + name = 'number' + arity = 2 + args = ['expression', int] + + def __init__(self, inner, number): + super(Number, self).__init__(inner) + self.padding = "%%0%sd" % number + + def _can_transform(self, val): + return isinstance(val, int) and not isinstance(val, bool) + + def transform(self, values): + """Transform any integers in values into zero padded strings.""" + if not values: + return [] + return [self.padding % (v,) for v in values if self._can_transform(v)] + + +class Bool(Transformation): + """Convert bool to string.""" + + name = "bool" + args = ['expression'] + + def _can_transform(self, val): + return isinstance(val, bool) + + def transform(self, values): + """Transform any booleans in values into strings.""" + if not values: + return [] + return [('1' if v else '0') for v in values if self._can_transform(v)] + + +class SplitWords(Transformation): + """Split a string on whitespace. + + This Getter will return [] for non-string inputs. It will however + split any strings in an input list, discarding any elements that + are not strings. + """ + + name = "split_words" + + def _can_transform(self, val): + return isinstance(val, basestring) + + def transform(self, values): + if not values: + return [] + result = set() + for value in values: + if self._can_transform(value): + for word in value.split(): + result.add(word) + return list(result) + + +class Combine(Transformation): + """Combine multiple expressions into a single index.""" + + name = "combine" + # variable number of args + arity = -1 + + def __init__(self, *inner): + super(Combine, self).__init__(inner) + + def get(self, raw_doc): + inner_values = [] + for inner in self.inner: + inner_values.extend(inner.get(raw_doc)) + return self.transform(inner_values) + + def transform(self, values): + return values + + +class IsNull(Transformation): + """Indicate whether the input is None. + + This Getter returns a bool indicating whether the input is nil. + """ + + name = "is_null" + + def transform(self, values): + return [len(values) == 0] + + +def check_fieldname(fieldname): + if fieldname.endswith('.'): + raise errors.IndexDefinitionParseError( + "Fieldname cannot end in '.':%s^" % (fieldname,)) + + +class Parser(object): + """Parse an index expression into a sequence of transformations.""" + + _transformations = {} + _delimiters = re.compile("\(|\)|,") + + def __init__(self): + self._tokens = [] + + def _set_expression(self, expression): + self._open_parens = 0 + self._tokens = [] + expression = expression.strip() + while expression: + delimiter = self._delimiters.search(expression) + if delimiter: + idx = delimiter.start() + if idx == 0: + result, expression = (expression[:1], expression[1:]) + self._tokens.append(result) + else: + result, expression = (expression[:idx], expression[idx:]) + result = result.strip() + if result: + self._tokens.append(result) + else: + expression = expression.strip() + if expression: + self._tokens.append(expression) + expression = None + + def _get_token(self): + if self._tokens: + return self._tokens.pop(0) + + def _peek_token(self): + if self._tokens: + return self._tokens[0] + + @staticmethod + def _to_getter(term): + if isinstance(term, Getter): + return term + check_fieldname(term) + return ExtractField(term) + + def _parse_op(self, op_name): + self._get_token() # '(' + op = self._transformations.get(op_name, None) + if op is None: + raise errors.IndexDefinitionParseError( + "Unknown operation: %s" % op_name) + args = [] + while True: + args.append(self._parse_term()) + sep = self._get_token() + if sep == ')': + break + if sep != ',': + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' in parentheses." % (sep,)) + parsed = [] + for i, arg in enumerate(args): + arg_type = op.args[i % len(op.args)] + if arg_type == 'expression': + inner = self._to_getter(arg) + else: + try: + inner = arg_type(arg) + except ValueError as e: + raise errors.IndexDefinitionParseError( + "Invalid value %r for argument type %r " + "(%r)." % (arg, arg_type, e)) + parsed.append(inner) + return op(*parsed) + + def _parse_term(self): + term = self._get_token() + if term is None: + raise errors.IndexDefinitionParseError( + "Unexpected end of index definition.") + if term in (',', ')', '('): + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' at start of expression." % (term,)) + next_token = self._peek_token() + if next_token == '(': + return self._parse_op(term) + return term + + def parse(self, expression): + self._set_expression(expression) + term = self._to_getter(self._parse_term()) + if self._peek_token(): + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' after end of expression." + % (self._peek_token(),)) + return term + + def parse_all(self, fields): + return [self.parse(field) for field in fields] + + @classmethod + def register_transormation(cls, transform): + assert transform.name not in cls._transformations, ( + "Transform %s already registered for %s" + % (transform.name, cls._transformations[transform.name])) + cls._transformations[transform.name] = transform + + +Parser.register_transormation(SplitWords) +Parser.register_transormation(Lower) +Parser.register_transormation(Number) +Parser.register_transormation(Bool) +Parser.register_transormation(IsNull) +Parser.register_transormation(Combine) diff --git a/src/leap/soledad/common/l2db/remote/__init__.py b/src/leap/soledad/common/l2db/remote/__init__.py new file mode 100644 index 00000000..3f32e381 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. diff --git a/src/leap/soledad/common/l2db/remote/http_app.py b/src/leap/soledad/common/l2db/remote/http_app.py new file mode 100644 index 00000000..a4eddb36 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_app.py @@ -0,0 +1,660 @@ +# Copyright 2011-2012 Canonical Ltd. +# Copyright 2016 LEAP Encryption Access Project +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +""" +HTTP Application exposing U1DB. +""" +# TODO -- deprecate, use twisted/txaio. + +import functools +import six.moves.http_client as httplib +import inspect +import json +import sys +import six.moves.urllib.parse as urlparse + +import routes.mapper + +from leap.soledad.common.l2db import ( + __version__ as _u1db_version, + DBNAME_CONSTRAINTS, Document, + errors, sync) +from leap.soledad.common.l2db.remote import http_errors, utils + + +def parse_bool(expression): + """Parse boolean querystring parameter.""" + if expression == 'true': + return True + return False + + +def parse_list(expression): + if not expression: + return [] + return [t.strip() for t in expression.split(',')] + + +def none_or_str(expression): + if expression is None: + return None + return str(expression) + + +class BadRequest(Exception): + """Bad request.""" + + +class _FencedReader(object): + """Read and get lines from a file but not past a given length.""" + + MAXCHUNK = 8192 + + def __init__(self, rfile, total, max_entry_size): + self.rfile = rfile + self.remaining = total + self.max_entry_size = max_entry_size + self._kept = None + + def read_chunk(self, atmost): + if self._kept is not None: + # ignore atmost, kept data should be a subchunk anyway + kept, self._kept = self._kept, None + return kept + if self.remaining == 0: + return '' + data = self.rfile.read(min(self.remaining, atmost)) + self.remaining -= len(data) + return data + + def getline(self): + line_parts = [] + size = 0 + while True: + chunk = self.read_chunk(self.MAXCHUNK) + if chunk == '': + break + nl = chunk.find("\n") + if nl != -1: + size += nl + 1 + if size > self.max_entry_size: + raise BadRequest + line_parts.append(chunk[:nl + 1]) + rest = chunk[nl + 1:] + self._kept = rest or None + break + else: + size += len(chunk) + if size > self.max_entry_size: + raise BadRequest + line_parts.append(chunk) + return ''.join(line_parts) + + +def http_method(**control): + """Decoration for handling of query arguments and content for a HTTP + method. + + args and content here are the query arguments and body of the incoming + HTTP requests. + + Match query arguments to python method arguments: + w = http_method()(f) + w(self, args, content) => args["content"]=content; + f(self, **args) + + JSON deserialize content to arguments: + w = http_method(content_as_args=True,...)(f) + w(self, args, content) => args.update(json.loads(content)); + f(self, **args) + + Support conversions (e.g int): + w = http_method(Arg=Conv,...)(f) + w(self, args, content) => args["Arg"]=Conv(args["Arg"]); + f(self, **args) + + Enforce no use of query arguments: + w = http_method(no_query=True,...)(f) + w(self, args, content) raises BadRequest if args is not empty + + Argument mismatches, deserialisation failures produce BadRequest. + """ + content_as_args = control.pop('content_as_args', False) + no_query = control.pop('no_query', False) + conversions = control.items() + + def wrap(f): + argspec = inspect.getargspec(f) + assert argspec.args[0] == "self" + nargs = len(argspec.args) + ndefaults = len(argspec.defaults or ()) + required_args = set(argspec.args[1:nargs - ndefaults]) + all_args = set(argspec.args) + + @functools.wraps(f) + def wrapper(self, args, content): + if no_query and args: + raise BadRequest() + if content is not None: + if content_as_args: + try: + args.update(json.loads(content)) + except ValueError: + raise BadRequest() + else: + args["content"] = content + if not (required_args <= set(args) <= all_args): + raise BadRequest("Missing required arguments.") + for name, conv in conversions: + if name not in args: + continue + try: + args[name] = conv(args[name]) + except ValueError: + raise BadRequest() + return f(self, **args) + + return wrapper + + return wrap + + +class URLToResource(object): + """Mappings from URLs to resources.""" + + def __init__(self): + self._map = routes.mapper.Mapper(controller_scan=None) + + def register(self, resource_cls): + # register + self._map.connect(None, resource_cls.url_pattern, + resource_cls=resource_cls, + requirements={"dbname": DBNAME_CONSTRAINTS}) + self._map.create_regs() + return resource_cls + + def match(self, path): + params = self._map.match(path) + if params is None: + return None, None + resource_cls = params.pop('resource_cls') + return resource_cls, params + + +url_to_resource = URLToResource() + + +@url_to_resource.register +class GlobalResource(object): + """Global (root) resource.""" + + url_pattern = "/" + + def __init__(self, state, responder): + self.state = state + self.responder = responder + + @http_method() + def get(self): + info = self.state.global_info() + info['version'] = _u1db_version + self.responder.send_response_json(**info) + + +@url_to_resource.register +class DatabaseResource(object): + """Database resource.""" + + url_pattern = "/{dbname}" + + def __init__(self, dbname, state, responder): + self.dbname = dbname + self.state = state + self.responder = responder + + @http_method() + def get(self): + self.state.check_database(self.dbname) + self.responder.send_response_json(200) + + @http_method(content_as_args=True) + def put(self): + self.state.ensure_database(self.dbname) + self.responder.send_response_json(200, ok=True) + + @http_method() + def delete(self): + self.state.delete_database(self.dbname) + self.responder.send_response_json(200, ok=True) + + +@url_to_resource.register +class DocsResource(object): + """Documents resource.""" + + url_pattern = "/{dbname}/docs" + + def __init__(self, dbname, state, responder): + self.responder = responder + self.db = state.open_database(dbname) + + @http_method(doc_ids=parse_list, check_for_conflicts=parse_bool, + include_deleted=parse_bool) + def get(self, doc_ids=None, check_for_conflicts=True, + include_deleted=False): + if doc_ids is None: + raise errors.MissingDocIds + docs = self.db.get_docs(doc_ids, include_deleted=include_deleted) + self.responder.content_type = 'application/json' + self.responder.start_response(200) + self.responder.start_stream(), + for doc in docs: + entry = dict( + doc_id=doc.doc_id, doc_rev=doc.rev, content=doc.get_json(), + has_conflicts=doc.has_conflicts) + self.responder.stream_entry(entry) + self.responder.end_stream() + self.responder.finish_response() + + +@url_to_resource.register +class AllDocsResource(object): + """All Documents resource.""" + + url_pattern = "/{dbname}/all-docs" + + def __init__(self, dbname, state, responder): + self.responder = responder + self.db = state.open_database(dbname) + + @http_method(include_deleted=parse_bool) + def get(self, include_deleted=False): + gen, docs = self.db.get_all_docs(include_deleted=include_deleted) + self.responder.content_type = 'application/json' + # returning a x-u1db-generation header is optional + # HTTPDatabase will fallback to return -1 if it's missing + self.responder.start_response(200, + headers={'x-u1db-generation': str(gen)}) + self.responder.start_stream(), + for doc in docs: + entry = dict( + doc_id=doc.doc_id, doc_rev=doc.rev, content=doc.get_json(), + has_conflicts=doc.has_conflicts) + self.responder.stream_entry(entry) + self.responder.end_stream() + self.responder.finish_response() + + +@url_to_resource.register +class DocResource(object): + """Document resource.""" + + url_pattern = "/{dbname}/doc/{id:.*}" + + def __init__(self, dbname, id, state, responder): + self.id = id + self.responder = responder + self.db = state.open_database(dbname) + + @http_method(old_rev=str) + def put(self, content, old_rev=None): + doc = Document(self.id, old_rev, content) + doc_rev = self.db.put_doc(doc) + if old_rev is None: + status = 201 # created + else: + status = 200 + self.responder.send_response_json(status, rev=doc_rev) + + @http_method(old_rev=str) + def delete(self, old_rev=None): + doc = Document(self.id, old_rev, None) + self.db.delete_doc(doc) + self.responder.send_response_json(200, rev=doc.rev) + + @http_method(include_deleted=parse_bool) + def get(self, include_deleted=False): + doc = self.db.get_doc(self.id, include_deleted=include_deleted) + if doc is None: + wire_descr = errors.DocumentDoesNotExist.wire_description + self.responder.send_response_json( + http_errors.wire_description_to_status[wire_descr], + error=wire_descr, + headers={ + 'x-u1db-rev': '', + 'x-u1db-has-conflicts': 'false' + }) + return + headers = { + 'x-u1db-rev': doc.rev, + 'x-u1db-has-conflicts': json.dumps(doc.has_conflicts) + } + if doc.is_tombstone(): + self.responder.send_response_json( + http_errors.wire_description_to_status[ + errors.DOCUMENT_DELETED], + error=errors.DOCUMENT_DELETED, + headers=headers) + else: + self.responder.send_response_content( + doc.get_json(), headers=headers) + + +@url_to_resource.register +class SyncResource(object): + """Sync endpoint resource.""" + + # maximum allowed request body size + max_request_size = 15 * 1024 * 1024 # 15Mb + # maximum allowed entry/line size in request body + max_entry_size = 10 * 1024 * 1024 # 10Mb + + url_pattern = "/{dbname}/sync-from/{source_replica_uid}" + + # pluggable + sync_exchange_class = sync.SyncExchange + + def __init__(self, dbname, source_replica_uid, state, responder): + self.source_replica_uid = source_replica_uid + self.responder = responder + self.state = state + self.dbname = dbname + self.replica_uid = None + + def get_target(self): + return self.state.open_database(self.dbname).get_sync_target() + + @http_method() + def get(self): + result = self.get_target().get_sync_info(self.source_replica_uid) + self.responder.send_response_json( + target_replica_uid=result[0], target_replica_generation=result[1], + target_replica_transaction_id=result[2], + source_replica_uid=self.source_replica_uid, + source_replica_generation=result[3], + source_transaction_id=result[4]) + + @http_method(generation=int, + content_as_args=True, no_query=True) + def put(self, generation, transaction_id): + self.get_target().record_sync_info(self.source_replica_uid, + generation, + transaction_id) + self.responder.send_response_json(ok=True) + + # Implements the same logic as LocalSyncTarget.sync_exchange + + @http_method(last_known_generation=int, last_known_trans_id=none_or_str, + content_as_args=True) + def post_args(self, last_known_generation, last_known_trans_id=None, + ensure=False): + if ensure: + db, self.replica_uid = self.state.ensure_database(self.dbname) + else: + db = self.state.open_database(self.dbname) + db.validate_gen_and_trans_id( + last_known_generation, last_known_trans_id) + self.sync_exch = self.sync_exchange_class( + db, self.source_replica_uid, last_known_generation) + + @http_method(content_as_args=True) + def post_stream_entry(self, id, rev, content, gen, trans_id): + doc = Document(id, rev, content) + self.sync_exch.insert_doc_from_source(doc, gen, trans_id) + + def post_end(self): + + def send_doc(doc, gen, trans_id): + entry = dict(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), + gen=gen, trans_id=trans_id) + self.responder.stream_entry(entry) + + new_gen = self.sync_exch.find_changes_to_return() + self.responder.content_type = 'application/x-u1db-sync-stream' + self.responder.start_response(200) + self.responder.start_stream(), + header = {"new_generation": new_gen, + "new_transaction_id": self.sync_exch.new_trans_id} + if self.replica_uid is not None: + header['replica_uid'] = self.replica_uid + self.responder.stream_entry(header) + self.sync_exch.return_docs(send_doc) + self.responder.end_stream() + self.responder.finish_response() + + +class HTTPResponder(object): + """Encode responses from the server back to the client.""" + + # a multi document response will put args and documents + # each on one line of the response body + + def __init__(self, start_response): + self._started = False + self._stream_state = -1 + self._no_initial_obj = True + self.sent_response = False + self._start_response = start_response + self._write = None + self.content_type = 'application/json' + self.content = [] + + def start_response(self, status, obj_dic=None, headers={}): + """start sending response with optional first json object.""" + if self._started: + return + self._started = True + status_text = httplib.responses[status] + self._write = self._start_response( + '%d %s' % (status, status_text), + [('content-type', self.content_type), + ('cache-control', 'no-cache')] + + headers.items()) + # xxx version in headers + if obj_dic is not None: + self._no_initial_obj = False + self._write(json.dumps(obj_dic) + "\r\n") + + def finish_response(self): + """finish sending response.""" + self.sent_response = True + + def send_response_json(self, status=200, headers={}, **kwargs): + """send and finish response with json object body from keyword args.""" + content = json.dumps(kwargs) + "\r\n" + self.send_response_content(content, headers=headers, status=status) + + def send_response_content(self, content, status=200, headers={}): + """send and finish response with content""" + headers['content-length'] = str(len(content)) + self.start_response(status, headers=headers) + if self._stream_state == 1: + self.content = [',\r\n', content] + else: + self.content = [content] + self.finish_response() + + def start_stream(self): + "start stream (array) as part of the response." + assert self._started and self._no_initial_obj + self._stream_state = 0 + self._write("[") + + def stream_entry(self, entry): + "send stream entry as part of the response." + assert self._stream_state != -1 + if self._stream_state == 0: + self._stream_state = 1 + self._write('\r\n') + else: + self._write(',\r\n') + if type(entry) == dict: + entry = json.dumps(entry) + self._write(entry) + + def end_stream(self): + "end stream (array)." + assert self._stream_state != -1 + self._write("\r\n]\r\n") + + +class HTTPInvocationByMethodWithBody(object): + """Invoke methods on a resource.""" + + def __init__(self, resource, environ, parameters): + self.resource = resource + self.environ = environ + self.max_request_size = getattr( + resource, 'max_request_size', parameters.max_request_size) + self.max_entry_size = getattr( + resource, 'max_entry_size', parameters.max_entry_size) + + def _lookup(self, method): + try: + return getattr(self.resource, method) + except AttributeError: + raise BadRequest() + + def __call__(self): + args = urlparse.parse_qsl(self.environ['QUERY_STRING'], + strict_parsing=False) + try: + args = dict( + (k.decode('utf-8'), v.decode('utf-8')) for k, v in args) + except ValueError: + raise BadRequest() + method = self.environ['REQUEST_METHOD'].lower() + if method in ('get', 'delete'): + meth = self._lookup(method) + return meth(args, None) + else: + # we expect content-length > 0, reconsider if we move + # to support chunked enconding + try: + content_length = int(self.environ['CONTENT_LENGTH']) + except (ValueError, KeyError): + raise BadRequest + if content_length <= 0: + raise BadRequest + if content_length > self.max_request_size: + raise BadRequest + reader = _FencedReader(self.environ['wsgi.input'], content_length, + self.max_entry_size) + content_type = self.environ.get('CONTENT_TYPE', '') + content_type = content_type.split(';', 1)[0].strip() + if content_type == 'application/json': + meth = self._lookup(method) + body = reader.read_chunk(sys.maxint) + return meth(args, body) + elif content_type == 'application/x-u1db-sync-stream': + meth_args = self._lookup('%s_args' % method) + meth_entry = self._lookup('%s_stream_entry' % method) + meth_end = self._lookup('%s_end' % method) + body_getline = reader.getline + if body_getline().strip() != '[': + raise BadRequest() + line = body_getline() + line, comma = utils.check_and_strip_comma(line.strip()) + meth_args(args, line) + while True: + line = body_getline() + entry = line.strip() + if entry == ']': + break + if not entry or not comma: # empty or no prec comma + raise BadRequest + entry, comma = utils.check_and_strip_comma(entry) + meth_entry({}, entry) + if comma or body_getline(): # extra comma or data + raise BadRequest + return meth_end() + else: + raise BadRequest() + + +class HTTPApp(object): + + # maximum allowed request body size + max_request_size = 15 * 1024 * 1024 # 15Mb + # maximum allowed entry/line size in request body + max_entry_size = 10 * 1024 * 1024 # 10Mb + + def __init__(self, state): + self.state = state + + def _lookup_resource(self, environ, responder): + resource_cls, params = url_to_resource.match(environ['PATH_INFO']) + if resource_cls is None: + raise BadRequest # 404 instead? + resource = resource_cls( + state=self.state, responder=responder, **params) + return resource + + def __call__(self, environ, start_response): + responder = HTTPResponder(start_response) + self.request_begin(environ) + try: + resource = self._lookup_resource(environ, responder) + HTTPInvocationByMethodWithBody(resource, environ, self)() + except errors.U1DBError as e: + self.request_u1db_error(environ, e) + status = http_errors.wire_description_to_status.get( + e.wire_description, 500) + responder.send_response_json(status, error=e.wire_description) + except BadRequest: + self.request_bad_request(environ) + responder.send_response_json(400, error="bad request") + except KeyboardInterrupt: + raise + except: + self.request_failed(environ) + raise + else: + self.request_done(environ) + return responder.content + + # hooks for tracing requests + + def request_begin(self, environ): + """Hook called at the beginning of processing a request.""" + pass + + def request_done(self, environ): + """Hook called when done processing a request.""" + pass + + def request_u1db_error(self, environ, exc): + """Hook called when processing a request resulted in a U1DBError. + + U1DBError passed as exc. + """ + pass + + def request_bad_request(self, environ): + """Hook called when processing a bad request. + + No actual processing was done. + """ + pass + + def request_failed(self, environ): + """Hook called when processing a request failed unexpectedly. + + Invoked from an except block, so there's interpreter exception + information available. + """ + pass diff --git a/src/leap/soledad/common/l2db/remote/http_client.py b/src/leap/soledad/common/l2db/remote/http_client.py new file mode 100644 index 00000000..1124b038 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_client.py @@ -0,0 +1,178 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Base class to make requests to a remote HTTP server.""" + +import json +import socket +import ssl +import sys +import urllib +import six.moves.urllib.parse as urlparse +import six.moves.http_client as httplib +from time import sleep +from leap.soledad.common.l2db import errors +from leap.soledad.common.l2db.remote import http_errors + +from leap.soledad.common.l2db.remote.ssl_match_hostname import match_hostname + +# Ubuntu/debian +# XXX other... +CA_CERTS = "/etc/ssl/certs/ca-certificates.crt" + + +def _encode_query_parameter(value): + """Encode query parameter.""" + if isinstance(value, bool): + if value: + value = 'true' + else: + value = 'false' + return unicode(value).encode('utf-8') + + +class _VerifiedHTTPSConnection(httplib.HTTPSConnection): + """HTTPSConnection verifying server side certificates.""" + # derived from httplib.py + + def connect(self): + "Connect to a host on a given (SSL) port." + + sock = socket.create_connection((self.host, self.port), + self.timeout, self.source_address) + if self._tunnel_host: + self.sock = sock + self._tunnel() + if sys.platform.startswith('linux'): + cert_opts = { + 'cert_reqs': ssl.CERT_REQUIRED, + 'ca_certs': CA_CERTS + } + else: + # XXX no cert verification implemented elsewhere for now + cert_opts = {} + self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, + ssl_version=ssl.PROTOCOL_SSLv3, + **cert_opts + ) + if cert_opts: + match_hostname(self.sock.getpeercert(), self.host) + + +class HTTPClientBase(object): + """Base class to make requests to a remote HTTP server.""" + + # Will use these delays to retry on 503 befor finally giving up. The final + # 0 is there to not wait after the final try fails. + _delays = (1, 1, 2, 4, 0) + + def __init__(self, url, creds=None): + self._url = urlparse.urlsplit(url) + self._conn = None + self._creds = {} + if creds is not None: + if len(creds) != 1: + raise errors.UnknownAuthMethod() + auth_meth, credentials = creds.items()[0] + try: + set_creds = getattr(self, 'set_%s_credentials' % auth_meth) + except AttributeError: + raise errors.UnknownAuthMethod(auth_meth) + set_creds(**credentials) + + def _ensure_connection(self): + if self._conn is not None: + return + if self._url.scheme == 'https': + connClass = _VerifiedHTTPSConnection + else: + connClass = httplib.HTTPConnection + self._conn = connClass(self._url.hostname, self._url.port) + + def close(self): + if self._conn: + self._conn.close() + self._conn = None + + # xxx retry mechanism? + + def _error(self, respdic): + descr = respdic.get("error") + exc_cls = errors.wire_description_to_exc.get(descr) + if exc_cls is not None: + message = respdic.get("message") + raise exc_cls(message) + + def _response(self): + resp = self._conn.getresponse() + body = resp.read() + headers = dict(resp.getheaders()) + if resp.status in (200, 201): + return body, headers + elif resp.status in http_errors.ERROR_STATUSES: + try: + respdic = json.loads(body) + except ValueError: + pass + else: + self._error(respdic) + # special case + if resp.status == 503: + raise errors.Unavailable(body, headers) + raise errors.HTTPError(resp.status, body, headers) + + def _sign_request(self, method, url_query, params): + raise NotImplementedError + + def _request(self, method, url_parts, params=None, body=None, + content_type=None): + self._ensure_connection() + unquoted_url = url_query = self._url.path + if url_parts: + if not url_query.endswith('/'): + url_query += '/' + unquoted_url = url_query + url_query += '/'.join(urllib.quote(part, safe='') + for part in url_parts) + # oauth performs its own quoting + unquoted_url += '/'.join(url_parts) + encoded_params = {} + if params: + for key, value in params.items(): + key = unicode(key).encode('utf-8') + encoded_params[key] = _encode_query_parameter(value) + url_query += ('?' + urllib.urlencode(encoded_params)) + if body is not None and not isinstance(body, basestring): + body = json.dumps(body) + content_type = 'application/json' + headers = {} + if content_type: + headers['content-type'] = content_type + headers.update( + self._sign_request(method, unquoted_url, encoded_params)) + for delay in self._delays: + try: + self._conn.request(method, url_query, body, headers) + return self._response() + except errors.Unavailable as e: + sleep(delay) + raise e + + def _request_json(self, method, url_parts, params=None, body=None, + content_type=None): + res, headers = self._request(method, url_parts, params, body, + content_type) + return json.loads(res), headers diff --git a/src/leap/soledad/common/l2db/remote/http_database.py b/src/leap/soledad/common/l2db/remote/http_database.py new file mode 100644 index 00000000..7e61e5a4 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_database.py @@ -0,0 +1,158 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""HTTPDatabase to access a remote db over the HTTP API.""" + +import json +import uuid + +from leap.soledad.common.l2db import ( + Database, + Document, + errors) +from leap.soledad.common.l2db.remote import ( + http_client, + http_errors, + http_target) + + +DOCUMENT_DELETED_STATUS = http_errors.wire_description_to_status[ + errors.DOCUMENT_DELETED] + + +class HTTPDatabase(http_client.HTTPClientBase, Database): + """Implement the Database API to a remote HTTP server.""" + + def __init__(self, url, document_factory=None, creds=None): + super(HTTPDatabase, self).__init__(url, creds=creds) + self._factory = document_factory or Document + + def set_document_factory(self, factory): + self._factory = factory + + @staticmethod + def open_database(url, create): + db = HTTPDatabase(url) + db.open(create) + return db + + @staticmethod + def delete_database(url): + db = HTTPDatabase(url) + db._delete() + db.close() + + def open(self, create): + if create: + self._ensure() + else: + self._check() + + def _check(self): + return self._request_json('GET', [])[0] + + def _ensure(self): + self._request_json('PUT', [], {}, {}) + + def _delete(self): + self._request_json('DELETE', [], {}, {}) + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + params = {} + if doc.rev is not None: + params['old_rev'] = doc.rev + res, headers = self._request_json('PUT', ['doc', doc.doc_id], params, + doc.get_json(), 'application/json') + doc.rev = res['rev'] + return res['rev'] + + def get_doc(self, doc_id, include_deleted=False): + try: + res, headers = self._request( + 'GET', ['doc', doc_id], {"include_deleted": include_deleted}) + except errors.DocumentDoesNotExist: + return None + except errors.HTTPError as e: + if (e.status == DOCUMENT_DELETED_STATUS and + 'x-u1db-rev' in e.headers): + res = None + headers = e.headers + else: + raise + doc_rev = headers['x-u1db-rev'] + has_conflicts = json.loads(headers['x-u1db-has-conflicts']) + doc = self._factory(doc_id, doc_rev, res) + doc.has_conflicts = has_conflicts + return doc + + def _build_docs(self, res): + for doc_dict in json.loads(res): + doc = self._factory( + doc_dict['doc_id'], doc_dict['doc_rev'], doc_dict['content']) + doc.has_conflicts = doc_dict['has_conflicts'] + yield doc + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + if not doc_ids: + return [] + doc_ids = ','.join(doc_ids) + res, headers = self._request( + 'GET', ['docs'], { + "doc_ids": doc_ids, "include_deleted": include_deleted, + "check_for_conflicts": check_for_conflicts}) + return self._build_docs(res) + + def get_all_docs(self, include_deleted=False): + res, headers = self._request( + 'GET', ['all-docs'], {"include_deleted": include_deleted}) + gen = -1 + if 'x-u1db-generation' in headers: + gen = int(headers['x-u1db-generation']) + return gen, list(self._build_docs(res)) + + def _allocate_doc_id(self): + return 'D-%s' % (uuid.uuid4().hex,) + + def create_doc(self, content, doc_id=None): + if not isinstance(content, dict): + raise errors.InvalidContent + json_string = json.dumps(content) + return self.create_doc_from_json(json_string, doc_id) + + def create_doc_from_json(self, content, doc_id=None): + if doc_id is None: + doc_id = self._allocate_doc_id() + res, headers = self._request_json('PUT', ['doc', doc_id], {}, + content, 'application/json') + new_doc = self._factory(doc_id, res['rev'], content) + return new_doc + + def delete_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + params = {'old_rev': doc.rev} + res, headers = self._request_json( + 'DELETE', ['doc', doc.doc_id], params) + doc.make_tombstone() + doc.rev = res['rev'] + + def get_sync_target(self): + st = http_target.HTTPSyncTarget(self._url.geturl()) + st._creds = self._creds + return st diff --git a/src/leap/soledad/common/l2db/remote/http_errors.py b/src/leap/soledad/common/l2db/remote/http_errors.py new file mode 100644 index 00000000..ee4cfefa --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_errors.py @@ -0,0 +1,48 @@ +# Copyright 2011-2012 Canonical Ltd. +# Copyright 2016 LEAP Encryption Access Project +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +""" +Information about the encoding of errors over HTTP. +""" + +from leap.soledad.common.l2db import errors + + +# error wire descriptions mapping to HTTP status codes +wire_description_to_status = dict([ + (errors.InvalidDocId.wire_description, 400), + (errors.MissingDocIds.wire_description, 400), + (errors.Unauthorized.wire_description, 401), + (errors.DocumentTooBig.wire_description, 403), + (errors.UserQuotaExceeded.wire_description, 403), + (errors.SubscriptionNeeded.wire_description, 403), + (errors.DatabaseDoesNotExist.wire_description, 404), + (errors.DocumentDoesNotExist.wire_description, 404), + (errors.DocumentAlreadyDeleted.wire_description, 404), + (errors.RevisionConflict.wire_description, 409), + (errors.InvalidGeneration.wire_description, 409), + (errors.InvalidReplicaUID.wire_description, 409), + (errors.InvalidTransactionId.wire_description, 409), + (errors.Unavailable.wire_description, 503), + # without matching exception + (errors.DOCUMENT_DELETED, 404) +]) + + +ERROR_STATUSES = set(wire_description_to_status.values()) +# 400 included explicitly for tests +ERROR_STATUSES.add(400) diff --git a/src/leap/soledad/common/l2db/remote/http_target.py b/src/leap/soledad/common/l2db/remote/http_target.py new file mode 100644 index 00000000..38804f01 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_target.py @@ -0,0 +1,125 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""SyncTarget API implementation to a remote HTTP server.""" + +import json + +from leap.soledad.common.l2db import Document, SyncTarget +from leap.soledad.common.l2db.errors import BrokenSyncStream +from leap.soledad.common.l2db.remote import ( + http_client, utils) + + +class HTTPSyncTarget(http_client.HTTPClientBase, SyncTarget): + """Implement the SyncTarget api to a remote HTTP server.""" + + @staticmethod + def connect(url): + return HTTPSyncTarget(url) + + def get_sync_info(self, source_replica_uid): + self._ensure_connection() + res, _ = self._request_json('GET', ['sync-from', source_replica_uid]) + return (res['target_replica_uid'], res['target_replica_generation'], + res['target_replica_transaction_id'], + res['source_replica_generation'], res['source_transaction_id']) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_transaction_id): + self._ensure_connection() + if self._trace_hook: # for tests + self._trace_hook('record_sync_info') + self._request_json('PUT', ['sync-from', source_replica_uid], {}, + {'generation': source_replica_generation, + 'transaction_id': source_transaction_id}) + + def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): + parts = data.splitlines() # one at a time + if not parts or parts[0] != '[': + raise BrokenSyncStream + data = parts[1:-1] + comma = False + if data: + line, comma = utils.check_and_strip_comma(data[0]) + res = json.loads(line) + if ensure_callback and 'replica_uid' in res: + ensure_callback(res['replica_uid']) + for entry in data[1:]: + if not comma: # missing in between comma + raise BrokenSyncStream + line, comma = utils.check_and_strip_comma(entry) + entry = json.loads(line) + doc = Document(entry['id'], entry['rev'], entry['content']) + return_doc_cb(doc, entry['gen'], entry['trans_id']) + if parts[-1] != ']': + try: + partdic = json.loads(parts[-1]) + except ValueError: + pass + else: + if isinstance(partdic, dict): + self._error(partdic) + raise BrokenSyncStream + if not data or comma: # no entries or bad extra comma + raise BrokenSyncStream + return res + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + self._ensure_connection() + if self._trace_hook: # for tests + self._trace_hook('sync_exchange') + url = '%s/sync-from/%s' % (self._url.path, source_replica_uid) + self._conn.putrequest('POST', url) + self._conn.putheader('content-type', 'application/x-u1db-sync-stream') + for header_name, header_value in self._sign_request('POST', url, {}): + self._conn.putheader(header_name, header_value) + entries = ['['] + size = 1 + + def prepare(**dic): + entry = comma + '\r\n' + json.dumps(dic) + entries.append(entry) + return len(entry) + + comma = '' + size += prepare( + last_known_generation=last_known_generation, + last_known_trans_id=last_known_trans_id, + ensure=ensure_callback is not None) + comma = ',' + for doc, gen, trans_id in docs_by_generations: + size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), + gen=gen, trans_id=trans_id) + entries.append('\r\n]') + size += len(entries[-1]) + self._conn.putheader('content-length', str(size)) + self._conn.endheaders() + for entry in entries: + self._conn.send(entry) + entries = None + data, _ = self._response() + res = self._parse_sync_stream(data, return_doc_cb, ensure_callback) + data = None + return res['new_generation'], res['new_transaction_id'] + + # for tests + _trace_hook = None + + def _set_trace_hook_shallow(self, cb): + self._trace_hook = cb diff --git a/src/leap/soledad/common/l2db/remote/server_state.py b/src/leap/soledad/common/l2db/remote/server_state.py new file mode 100644 index 00000000..d4c3c45f --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/server_state.py @@ -0,0 +1,68 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""State for servers exposing a set of U1DB databases.""" + + +class ServerState(object): + """Passed to a Request when it is instantiated. + + This is used to track server-side state, such as working-directory, open + databases, etc. + """ + + def __init__(self): + self._workingdir = None + + def set_workingdir(self, path): + self._workingdir = path + + def global_info(self): + """Return global information about the server.""" + return {} + + def _relpath(self, relpath): + # Note: We don't want to allow absolute paths here, because we + # don't want to expose the filesystem. We should also check that + # relpath doesn't have '..' in it, etc. + return self._workingdir + '/' + relpath + + def open_database(self, path): + """Open a database at the given location.""" + from leap.soledad.client._db import sqlite + full_path = self._relpath(path) + return sqlite.SQLiteDatabase.open_database(full_path, create=False) + + def check_database(self, path): + """Check if the database at the given location exists. + + Simply returns if it does or raises DatabaseDoesNotExist. + """ + db = self.open_database(path) + db.close() + + def ensure_database(self, path): + """Ensure database at the given location.""" + from leap.soledad.client._db import sqlite + full_path = self._relpath(path) + db = sqlite.SQLiteDatabase.open_database(full_path, create=True) + return db, db._replica_uid + + def delete_database(self, path): + """Delete database at the given location.""" + from leap.soledad.client._db import sqlite + full_path = self._relpath(path) + sqlite.SQLiteDatabase.delete_database(full_path) diff --git a/src/leap/soledad/common/l2db/remote/ssl_match_hostname.py b/src/leap/soledad/common/l2db/remote/ssl_match_hostname.py new file mode 100644 index 00000000..ce82f1b2 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/ssl_match_hostname.py @@ -0,0 +1,65 @@ +"""The match_hostname() function from Python 3.2, essential when using SSL.""" +# XXX put it here until it's packaged + +import re + +__version__ = '3.2a3' + + +class CertificateError(ValueError): + pass + + +def _dnsname_to_pat(dn): + pats = [] + for frag in dn.split(r'.'): + if frag == '*': + # When '*' is a fragment by itself, it matches a non-empty dotless + # fragment. + pats.append('[^.]+') + else: + # Otherwise, '*' matches any dotless fragment. + frag = re.escape(frag) + pats.append(frag.replace(r'\*', '[^.]*')) + return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + + +def match_hostname(cert, hostname): + """Verify that *cert* (in decoded format as returned by + SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules + are mostly followed, but IP addresses are not accepted for *hostname*. + + CertificateError is raised on failure. On success, the function + returns nothing. + """ + if not cert: + raise ValueError("empty or no certificate") + dnsnames = [] + san = cert.get('subjectAltName', ()) + for key, value in san: + if key == 'DNS': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if not san: + # The subject is only checked when subjectAltName is empty + for sub in cert.get('subject', ()): + for key, value in sub: + # XXX according to RFC 2818, the most specific Common Name + # must be used. + if key == 'commonName': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if len(dnsnames) > 1: + raise CertificateError( + "hostname %r doesn't match either of %s" + % (hostname, ', '.join(map(repr, dnsnames)))) + elif len(dnsnames) == 1: + raise CertificateError( + "hostname %r doesn't match %r" + % (hostname, dnsnames[0])) + else: + raise CertificateError( + "no appropriate commonName or " + "subjectAltName fields were found") diff --git a/src/leap/soledad/common/l2db/remote/utils.py b/src/leap/soledad/common/l2db/remote/utils.py new file mode 100644 index 00000000..14cedea9 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/utils.py @@ -0,0 +1,23 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Utilities for details of the procotol.""" + + +def check_and_strip_comma(line): + if line and line[-1] == ',': + return line[:-1], True + return line, False diff --git a/src/leap/soledad/common/l2db/sync.py b/src/leap/soledad/common/l2db/sync.py new file mode 100644 index 00000000..32281f30 --- /dev/null +++ b/src/leap/soledad/common/l2db/sync.py @@ -0,0 +1,311 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""The synchronization utilities for U1DB.""" +from six.moves import zip as izip + +from leap.soledad.common import l2db +from leap.soledad.common.l2db import errors + + +class Synchronizer(object): + """Collect the state around synchronizing 2 U1DB replicas. + + Synchronization is bi-directional, in that new items in the source are sent + to the target, and new items in the target are returned to the source. + However, it still recognizes that one side is initiating the request. Also, + at the moment, conflicts are only created in the source. + """ + + def __init__(self, source, sync_target): + """Create a new Synchronization object. + + :param source: A Database + :param sync_target: A SyncTarget + """ + self.source = source + self.sync_target = sync_target + self.target_replica_uid = None + self.num_inserted = 0 + + def _insert_doc_from_target(self, doc, replica_gen, trans_id): + """Try to insert synced document from target. + + Implements TAKE OTHER semantics: any document from the target + that is in conflict will be taken as the new official value, + while the current conflicting value will be stored alongside + as a conflict. In the process indexes will be updated etc. + + :return: None + """ + # Increases self.num_inserted depending whether the document + # was effectively inserted. + state, _ = self.source._put_doc_if_newer( + doc, save_conflict=True, + replica_uid=self.target_replica_uid, replica_gen=replica_gen, + replica_trans_id=trans_id) + if state == 'inserted': + self.num_inserted += 1 + elif state == 'converged': + # magical convergence + pass + elif state == 'superseded': + # we have something newer, will be taken care of at the next sync + pass + else: + assert state == 'conflicted' + # The doc was saved as a conflict, so the database was updated + self.num_inserted += 1 + + def _record_sync_info_with_the_target(self, start_generation): + """Record our new after sync generation with the target if gapless. + + Any documents received from the target will cause the local + database to increment its generation. We do not want to send + them back to the target in a future sync. However, there could + also be concurrent updates from another process doing eg + 'put_doc' while the sync was running. And we do want to + synchronize those documents. We can tell if there was a + concurrent update by comparing our new generation number + versus the generation we started, and how many documents we + inserted from the target. If it matches exactly, then we can + record with the target that they are fully up to date with our + new generation. + """ + cur_gen, trans_id = self.source._get_generation_info() + last_gen = start_generation + self.num_inserted + if (cur_gen == last_gen and self.num_inserted > 0): + self.sync_target.record_sync_info( + self.source._replica_uid, cur_gen, trans_id) + + def sync(self, callback=None, autocreate=False): + """Synchronize documents between source and target.""" + sync_target = self.sync_target + # get target identifier, its current generation, + # and its last-seen database generation for this source + try: + (self.target_replica_uid, target_gen, target_trans_id, + target_my_gen, target_my_trans_id) = sync_target.get_sync_info( + self.source._replica_uid) + except errors.DatabaseDoesNotExist: + if not autocreate: + raise + # will try to ask sync_exchange() to create the db + self.target_replica_uid = None + target_gen, target_trans_id = 0, '' + target_my_gen, target_my_trans_id = 0, '' + + def ensure_callback(replica_uid): + self.target_replica_uid = replica_uid + + else: + ensure_callback = None + if self.target_replica_uid == self.source._replica_uid: + raise errors.InvalidReplicaUID + # validate the generation and transaction id the target knows about us + self.source.validate_gen_and_trans_id( + target_my_gen, target_my_trans_id) + # what's changed since that generation and this current gen + my_gen, _, changes = self.source.whats_changed(target_my_gen) + + # this source last-seen database generation for the target + if self.target_replica_uid is None: + target_last_known_gen, target_last_known_trans_id = 0, '' + else: + target_last_known_gen, target_last_known_trans_id = ( + self.source._get_replica_gen_and_trans_id( # nopep8 + self.target_replica_uid)) + if not changes and target_last_known_gen == target_gen: + if target_trans_id != target_last_known_trans_id: + raise errors.InvalidTransactionId + return my_gen + changed_doc_ids = [doc_id for doc_id, _, _ in changes] + # prepare to send all the changed docs + docs_to_send = self.source.get_docs( + changed_doc_ids, + check_for_conflicts=False, include_deleted=True) + # TODO: there must be a way to not iterate twice + docs_by_generation = zip( + docs_to_send, (gen for _, gen, _ in changes), + (trans for _, _, trans in changes)) + + # exchange documents and try to insert the returned ones with + # the target, return target synced-up-to gen + new_gen, new_trans_id = sync_target.sync_exchange( + docs_by_generation, self.source._replica_uid, + target_last_known_gen, target_last_known_trans_id, + self._insert_doc_from_target, ensure_callback=ensure_callback) + # record target synced-up-to generation including applying what we sent + self.source._set_replica_gen_and_trans_id( + self.target_replica_uid, new_gen, new_trans_id) + + # if gapless record current reached generation with target + self._record_sync_info_with_the_target(my_gen) + + return my_gen + + +class SyncExchange(object): + """Steps and state for carrying through a sync exchange on a target.""" + + def __init__(self, db, source_replica_uid, last_known_generation): + self._db = db + self.source_replica_uid = source_replica_uid + self.source_last_known_generation = last_known_generation + self.seen_ids = {} # incoming ids not superseded + self.changes_to_return = None + self.new_gen = None + self.new_trans_id = None + # for tests + self._incoming_trace = [] + self._trace_hook = None + self._db._last_exchange_log = { + 'receive': {'docs': self._incoming_trace}, + 'return': None + } + + def _set_trace_hook(self, cb): + self._trace_hook = cb + + def _trace(self, state): + if not self._trace_hook: + return + self._trace_hook(state) + + def insert_doc_from_source(self, doc, source_gen, trans_id): + """Try to insert synced document from source. + + Conflicting documents are not inserted but will be sent over + to the sync source. + + It keeps track of progress by storing the document source + generation as well. + + The 1st step of a sync exchange is to call this repeatedly to + try insert all incoming documents from the source. + + :param doc: A Document object. + :param source_gen: The source generation of doc. + :return: None + """ + state, at_gen = self._db._put_doc_if_newer( + doc, save_conflict=False, + replica_uid=self.source_replica_uid, replica_gen=source_gen, + replica_trans_id=trans_id) + if state == 'inserted': + self.seen_ids[doc.doc_id] = at_gen + elif state == 'converged': + # magical convergence + self.seen_ids[doc.doc_id] = at_gen + elif state == 'superseded': + # we have something newer that we will return + pass + else: + # conflict that we will returne + assert state == 'conflicted' + # for tests + self._incoming_trace.append((doc.doc_id, doc.rev)) + self._db._last_exchange_log['receive'].update({ + 'source_uid': self.source_replica_uid, + 'source_gen': source_gen + }) + + def find_changes_to_return(self): + """Find changes to return. + + Find changes since last_known_generation in db generation + order using whats_changed. It excludes documents ids that have + already been considered (superseded by the sender, etc). + + :return: new_generation - the generation of this database + which the caller can consider themselves to be synchronized after + processing the returned documents. + """ + self._db._last_exchange_log['receive'].update({ # for tests + 'last_known_gen': self.source_last_known_generation + }) + self._trace('before whats_changed') + gen, trans_id, changes = self._db.whats_changed( + self.source_last_known_generation) + self._trace('after whats_changed') + self.new_gen = gen + self.new_trans_id = trans_id + seen_ids = self.seen_ids + # changed docs that weren't superseded by or converged with + self.changes_to_return = [ + (doc_id, gen, trans_id) for (doc_id, gen, trans_id) in changes if + # there was a subsequent update + doc_id not in seen_ids or seen_ids.get(doc_id) < gen] + return self.new_gen + + def return_docs(self, return_doc_cb): + """Return the changed documents and their last change generation + repeatedly invoking the callback return_doc_cb. + + The final step of a sync exchange. + + :param: return_doc_cb(doc, gen, trans_id): is a callback + used to return the documents with their last change generation + to the target replica. + :return: None + """ + changes_to_return = self.changes_to_return + # return docs, including conflicts + changed_doc_ids = [doc_id for doc_id, _, _ in changes_to_return] + self._trace('before get_docs') + docs = self._db.get_docs( + changed_doc_ids, check_for_conflicts=False, include_deleted=True) + + docs_by_gen = izip( + docs, (gen for _, gen, _ in changes_to_return), + (trans_id for _, _, trans_id in changes_to_return)) + _outgoing_trace = [] # for tests + for doc, gen, trans_id in docs_by_gen: + return_doc_cb(doc, gen, trans_id) + _outgoing_trace.append((doc.doc_id, doc.rev)) + # for tests + self._db._last_exchange_log['return'] = { + 'docs': _outgoing_trace, + 'last_gen': self.new_gen} + + +class LocalSyncTarget(l2db.SyncTarget): + """Common sync target implementation logic for all local sync targets.""" + + def __init__(self, db): + self._db = db + self._trace_hook = None + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + self._db.validate_gen_and_trans_id( + last_known_generation, last_known_trans_id) + sync_exch = SyncExchange( + self._db, source_replica_uid, last_known_generation) + if self._trace_hook: + sync_exch._set_trace_hook(self._trace_hook) + # 1st step: try to insert incoming docs and record progress + for doc, doc_gen, trans_id in docs_by_generations: + sync_exch.insert_doc_from_source(doc, doc_gen, trans_id) + # 2nd step: find changed documents (including conflicts) to return + new_gen = sync_exch.find_changes_to_return() + # final step: return docs and record source replica sync point + sync_exch.return_docs(return_doc_cb) + return new_gen, sync_exch.new_trans_id + + def _set_trace_hook(self, cb): + self._trace_hook = cb diff --git a/src/leap/soledad/common/l2db/vectorclock.py b/src/leap/soledad/common/l2db/vectorclock.py new file mode 100644 index 00000000..42bceaa8 --- /dev/null +++ b/src/leap/soledad/common/l2db/vectorclock.py @@ -0,0 +1,89 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""VectorClockRev helper class.""" + + +class VectorClockRev(object): + """Track vector clocks for multiple replica ids. + + This allows simple comparison to determine if one VectorClockRev is + newer/older/in-conflict-with another VectorClockRev without having to + examine history. Every replica has a strictly increasing revision. When + creating a new revision, they include all revisions for all other replicas + which the new revision dominates, and increment their own revision to + something greater than the current value. + """ + + def __init__(self, value): + self._values = self._expand(value) + + def __repr__(self): + s = self.as_str() + return '%s(%s)' % (self.__class__.__name__, s) + + def as_str(self): + s = '|'.join(['%s:%d' % (m, r) for m, r + in sorted(self._values.items())]) + return s + + def _expand(self, value): + result = {} + if value is None: + return result + for replica_info in value.split('|'): + replica_uid, counter = replica_info.split(':') + counter = int(counter) + result[replica_uid] = counter + return result + + def is_newer(self, other): + """Is this VectorClockRev strictly newer than other. + """ + if not self._values: + return False + if not other._values: + return True + this_is_newer = False + other_expand = dict(other._values) + for key, value in self._values.iteritems(): + if key in other_expand: + other_value = other_expand.pop(key) + if other_value > value: + return False + elif other_value < value: + this_is_newer = True + else: + this_is_newer = True + if other_expand: + return False + return this_is_newer + + def increment(self, replica_uid): + """Increase the 'replica_uid' section of this vector clock. + + :return: A string representing the new vector clock value + """ + self._values[replica_uid] = self._values.get(replica_uid, 0) + 1 + + def maximize(self, other_vcr): + for replica_uid, counter in other_vcr._values.iteritems(): + if replica_uid not in self._values: + self._values[replica_uid] = counter + else: + this_counter = self._values[replica_uid] + if this_counter < counter: + self._values[replica_uid] = counter |