From ea2f5e31c5754b71b2cb5aea9d9b36f4d2b2ac31 Mon Sep 17 00:00:00 2001 From: drebs Date: Wed, 28 Nov 2012 20:05:27 -0200 Subject: add u1db openstack backend base files --- src/leap/soledad/README | 6 ++ src/leap/soledad/__init__.py | 164 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+) create mode 100644 src/leap/soledad/README create mode 100644 src/leap/soledad/__init__.py (limited to 'src/leap') diff --git a/src/leap/soledad/README b/src/leap/soledad/README new file mode 100644 index 00000000..91263d50 --- /dev/null +++ b/src/leap/soledad/README @@ -0,0 +1,6 @@ +Soledad -- Synchronization Of Locally Encrypted Data Among Devices +================================================================== + +This code is based on: + +* u1db 0.14 diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py new file mode 100644 index 00000000..b4284c84 --- /dev/null +++ b/src/leap/soledad/__init__.py @@ -0,0 +1,164 @@ +# License? + +"""A U1DB implementation that uses OpenStack Swift as its persistence layer.""" + +import errno +import os +try: + import simplejson as json +except ImportError: + import json # noqa +import sys +import time +import uuid + +from u1db.backends import CommonBackend, CommonSyncTarget +from u1db import ( + Document, + errors, + query_parser, + vectorclock, + ) + + +class OpenStackDatabase(CommonBackend): + """A U1DB implementation that uses OpenStack as its persistence layer.""" + + def __init__(self, sqlite_file, document_factory=None): + """Create a new OpenStack data container.""" + raise NotImplementedError(self.__init__) + + def set_document_factory(self, factory): + self._factory = factory + + def get_sync_target(self): + return OpenStackSyncTarget(self) + + @classmethod + def open_database(cls, sqlite_file, create, backend_cls=None, + document_factory=None): + raise NotImplementedError(open_database) + + @staticmethod + def delete_database(sqlite_file): + raise NotImplementedError(delete_database) + + 
+ def close(self): + raise NotImplementedError(self.close) + + def _is_initialized(self, c): + raise NotImplementedError(self._is_initialized) + + def _initialize(self, c): + raise NotImplementedError(self._initialize) + + def _ensure_schema(self): + raise NotImplementedError(self._ensure_schema) + + def _set_replica_uid(self, replica_uid): + """Force the replica_uid to be set.""" + raise NotImplementedError(self._set_replica_uid) + + def _set_replica_uid_in_transaction(self, replica_uid): + """Set the replica_uid. A transaction should already be held.""" + raise NotImplementedError(self._set_replica_uid_in_transaction) + + def _get_replica_uid(self): + raise NotImplementedError(self._get_replica_uid) + + _replica_uid = property(_get_replica_uid) + + def _get_generation(self): + raise NotImplementedError(self._get_generation) + + def _get_generation_info(self): + raise NotImplementedError(self._get_generation_info) + + def _get_trans_id_for_gen(self, generation): + raise NotImplementedError(self._get_trans_id_for_gen) + + def _get_transaction_log(self): + raise NotImplementedError(self._get_transaction_log) + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling.""" + raise NotImplementedError(self._get_doc) + + def _has_conflicts(self, doc_id): + raise NotImplementedError(self._has_conflicts) + + def get_doc(self, doc_id, include_deleted=False): + raise NotImplementedError(self.get_doc) + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + raise NotImplementedError(self.get_all_docs) + + def put_doc(self, doc): + raise NotImplementedError(self.put_doc) + + def whats_changed(self, old_generation=0): + raise NotImplementedError(self.whats_changed) + + def delete_doc(self, doc): + raise NotImplementedError(self.delete_doc) + + def _get_conflicts(self, doc_id): + return [] + + def get_doc_conflicts(self, doc_id): + return [] + + def 
_get_replica_gen_and_trans_id(self, other_replica_uid): + raise NotImplementedError(self._get_replica_gen_and_trans_id) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + raise NotImplementedError(self._set_replica_gen_and_trans_id) + + def _do_set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, + other_transaction_id): + raise NotImplementedError(self._do_set_replica_gen_and_trans_id) + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, + replica_gen=None, replica_trans_id=None): + raise NotImplementedError(self._put_doc_if_newer) + + def resolve_doc(self, doc, conflicted_doc_revs): + raise NotImplementedError(self.resolve_doc) + + def list_indexes(self): + return [] + + def get_from_index(self, index_name, *key_values): + return [] + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + return [] + + def get_index_keys(self, index_name): + return [] + + def delete_index(self, index_name): + return False + +class LeapDocument(Document): + + def get_content_encrypted(self): + raise NotImplementedError(self.get_content_encrypted) + + def set_content_encrypted(self): + raise NotImplementedError(self.set_content_encrypted) + + +class OpenStackSyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + raise NotImplementedError(self.get_sync_info) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + raise NotImplementedError(self.record_sync_info) -- cgit v1.2.3 From ea8d5c9d587d7089637ff8cd4076029505f3aca0 Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 29 Nov 2012 10:55:46 -0200 Subject: add swiftclient version to readme --- src/leap/soledad/README | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/README b/src/leap/soledad/README index 91263d50..dc448374 100644 --- a/src/leap/soledad/README +++ 
b/src/leap/soledad/README @@ -1,6 +1,13 @@ Soledad -- Synchronization Of Locally Encrypted Data Among Devices ================================================================== -This code is based on: +Dependencies +------------ -* u1db 0.14 +Soledad uses the following python libraries: + + * u1db 0.1.4 [1] + * python-swiftclient 1.1.1 [2] + +[1] http://pypi.python.org/pypi/u1db/0.1.4 +[2] https://launchpad.net/python-swiftclient -- cgit v1.2.3 From 0f1f9474e7ea6b52dc3ae18444cfaaca56ff3070 Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 29 Nov 2012 10:56:06 -0200 Subject: organize methods for openstack backend --- src/leap/soledad/__init__.py | 172 +++++++++++++++++++++---------------------- 1 file changed, 86 insertions(+), 86 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index b4284c84..3d685635 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -2,15 +2,10 @@ """A U1DB implementation that uses OpenStack Swift as its persistence layer.""" -import errno -import os try: import simplejson as json except ImportError: import json # noqa -import sys -import time -import uuid from u1db.backends import CommonBackend, CommonSyncTarget from u1db import ( @@ -20,73 +15,33 @@ from u1db import ( vectorclock, ) +from swiftclient import client + class OpenStackDatabase(CommonBackend): """A U1DB implementation that uses OpenStack as its persistence layer.""" - def __init__(self, sqlite_file, document_factory=None): + def __init__(self, auth_url, user, auth_key): """Create a new OpenStack data container.""" - raise NotImplementedError(self.__init__) + self._auth_url = auth_url + self._user = user + self._auth_key = auth_key + self.set_document_factory(LeapDocument) + self._connection = swiftclient.Connection(self._auth_url, self._user, + self._auth_key) + + #------------------------------------------------------------------------- + # implemented methods from Database + 
#------------------------------------------------------------------------- def set_document_factory(self, factory): self._factory = factory - def get_sync_target(self): - return OpenStackSyncTarget(self) - - @classmethod - def open_database(cls, sqlite_file, create, backend_cls=None, - document_factory=None): - raise NotImplementedError(open_database) - - @staticmethod - def delete_database(sqlite_file): - raise NotImplementedError(delete_database) - - - def close(self): - raise NotImplementedError(self.close) - - def _is_initialized(self, c): - raise NotImplementedError(self._is_initialized) - - def _initialize(self, c): - raise NotImplementedError(self._initialize) - - def _ensure_schema(self): - raise NotImplementedError(self._ensure_schema) - - def _set_replica_uid(self, replica_uid): - """Force the replica_uid to be set.""" - raise NotImplementedError(self._set_replica_uid) - - def _set_replica_uid_in_transaction(self, replica_uid): - """Set the replica_uid. A transaction should already be held.""" - raise NotImplementedError(self._set_replica_uid_in_transaction) - - def _get_replica_uid(self): - raise NotImplementedError(self._get_replica_uid) - - _replica_uid = property(_get_replica_uid) + def set_document_size_limit(self, limit): + raise NotImplementedError(self.set_document_size_limit) - def _get_generation(self): - raise NotImplementedError(self._get_generation) - - def _get_generation_info(self): - raise NotImplementedError(self._get_generation_info) - - def _get_trans_id_for_gen(self, generation): - raise NotImplementedError(self._get_trans_id_for_gen) - - def _get_transaction_log(self): - raise NotImplementedError(self._get_transaction_log) - - def _get_doc(self, doc_id, check_for_conflicts=False): - """Get just the document content, without fancy handling.""" - raise NotImplementedError(self._get_doc) - - def _has_conflicts(self, doc_id): - raise NotImplementedError(self._has_conflicts) + def whats_changed(self, old_generation=0): + raise 
NotImplementedError(self.whats_changed) def get_doc(self, doc_id, include_deleted=False): raise NotImplementedError(self.get_doc) @@ -98,18 +53,47 @@ class OpenStackDatabase(CommonBackend): def put_doc(self, doc): raise NotImplementedError(self.put_doc) - def whats_changed(self, old_generation=0): - raise NotImplementedError(self.whats_changed) - def delete_doc(self, doc): raise NotImplementedError(self.delete_doc) - def _get_conflicts(self, doc_id): + # start of index-related methods: these are not supported by this backend. + + def create_index(self, index_name, *index_expressions): + return False + + def delete_index(self, index_name): + return False + + def list_indexes(self): + return [] + + def get_from_index(self, index_name, *key_values): return [] + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + return [] + + def get_index_keys(self, index_name): + return [] + + # end of index-related methods: these are not supported by this backend. + def get_doc_conflicts(self, doc_id): return [] + def resolve_doc(self, doc, conflicted_doc_revs): + raise NotImplementedError(self.resolve_doc) + + def get_sync_target(self): + return OpenStackSyncTarget(self) + + def close(self): + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + raise NotImplementedError(self.close) + def _get_replica_gen_and_trans_id(self, other_replica_uid): raise NotImplementedError(self._get_replica_gen_and_trans_id) @@ -117,33 +101,49 @@ class OpenStackDatabase(CommonBackend): other_generation, other_transaction_id): raise NotImplementedError(self._set_replica_gen_and_trans_id) - def _do_set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, - other_transaction_id): - raise NotImplementedError(self._do_set_replica_gen_and_trans_id) + #------------------------------------------------------------------------- + # implemented methods from CommonBackend + 
#------------------------------------------------------------------------- - def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, - replica_gen=None, replica_trans_id=None): - raise NotImplementedError(self._put_doc_if_newer) + def _get_generation(self): + raise NotImplementedError(self._get_generation) - def resolve_doc(self, doc, conflicted_doc_revs): - raise NotImplementedError(self.resolve_doc) + def _get_generation_info(self): + raise NotImplementedError(self._get_generation_info) - def list_indexes(self): - return [] + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling.""" + raise NotImplementedError(self._get_doc) - def get_from_index(self, index_name, *key_values): - return [] + def _has_conflicts(self, doc_id): + raise NotImplementedError(self._has_conflicts) - def get_range_from_index(self, index_name, start_value=None, - end_value=None): - return [] + def _get_transaction_log(self): + raise NotImplementedError(self._get_transaction_log) - def get_index_keys(self, index_name): - return [] + def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): + raise NotImplementedError(self._put_and_update_indexes) + + + def _get_trans_id_for_gen(self, generation): + raise NotImplementedError(self._get_trans_id_for_gen) + + #------------------------------------------------------------------------- + # OpenStack specific methods + #------------------------------------------------------------------------- + + def _is_initialized(self, c): + raise NotImplementedError(self._is_initialized) + + def _initialize(self, c): + raise NotImplementedError(self._initialize) + + def _get_auth(self): + self._url, self._auth_token = self._connection.get_auth(self._auth_url, + self._user, + self._auth_key) + return self._url, self.auth_token - def delete_index(self, index_name): - return False class LeapDocument(Document): -- cgit v1.2.3 From 17ccbcb831044c29f521b529f5aa96dc2a3cd18f Mon Sep 17 
00:00:00 2001 From: drebs Date: Thu, 29 Nov 2012 10:56:49 -0200 Subject: add u1db code (not as submodule) --- src/leap/soledad/u1db/__init__.py | 697 +++++++ src/leap/soledad/u1db/backends/__init__.py | 211 +++ src/leap/soledad/u1db/backends/dbschema.sql | 42 + src/leap/soledad/u1db/backends/inmemory.py | 469 +++++ src/leap/soledad/u1db/backends/sqlite_backend.py | 926 ++++++++++ src/leap/soledad/u1db/commandline/__init__.py | 15 + src/leap/soledad/u1db/commandline/client.py | 497 +++++ src/leap/soledad/u1db/commandline/command.py | 80 + src/leap/soledad/u1db/commandline/serve.py | 34 + src/leap/soledad/u1db/errors.py | 189 ++ src/leap/soledad/u1db/query_parser.py | 370 ++++ src/leap/soledad/u1db/remote/__init__.py | 15 + .../soledad/u1db/remote/basic_auth_middleware.py | 68 + src/leap/soledad/u1db/remote/http_app.py | 629 +++++++ src/leap/soledad/u1db/remote/http_client.py | 218 +++ src/leap/soledad/u1db/remote/http_database.py | 143 ++ src/leap/soledad/u1db/remote/http_errors.py | 46 + src/leap/soledad/u1db/remote/http_target.py | 135 ++ src/leap/soledad/u1db/remote/oauth_middleware.py | 89 + src/leap/soledad/u1db/remote/server_state.py | 67 + src/leap/soledad/u1db/remote/ssl_match_hostname.py | 64 + src/leap/soledad/u1db/remote/utils.py | 23 + src/leap/soledad/u1db/sync.py | 304 ++++ src/leap/soledad/u1db/tests/__init__.py | 463 +++++ src/leap/soledad/u1db/tests/c_backend_wrapper.pyx | 1541 ++++++++++++++++ .../soledad/u1db/tests/commandline/__init__.py | 47 + .../soledad/u1db/tests/commandline/test_client.py | 916 ++++++++++ .../soledad/u1db/tests/commandline/test_command.py | 105 ++ .../soledad/u1db/tests/commandline/test_serve.py | 101 ++ .../soledad/u1db/tests/test_auth_middleware.py | 309 ++++ src/leap/soledad/u1db/tests/test_backends.py | 1895 ++++++++++++++++++++ src/leap/soledad/u1db/tests/test_c_backend.py | 634 +++++++ src/leap/soledad/u1db/tests/test_common_backend.py | 33 + src/leap/soledad/u1db/tests/test_document.py | 148 ++ 
src/leap/soledad/u1db/tests/test_errors.py | 61 + src/leap/soledad/u1db/tests/test_http_app.py | 1133 ++++++++++++ src/leap/soledad/u1db/tests/test_http_client.py | 361 ++++ src/leap/soledad/u1db/tests/test_http_database.py | 256 +++ src/leap/soledad/u1db/tests/test_https.py | 117 ++ src/leap/soledad/u1db/tests/test_inmemory.py | 128 ++ src/leap/soledad/u1db/tests/test_open.py | 69 + src/leap/soledad/u1db/tests/test_query_parser.py | 443 +++++ .../soledad/u1db/tests/test_remote_sync_target.py | 314 ++++ src/leap/soledad/u1db/tests/test_remote_utils.py | 36 + src/leap/soledad/u1db/tests/test_server_state.py | 93 + src/leap/soledad/u1db/tests/test_sqlite_backend.py | 493 +++++ src/leap/soledad/u1db/tests/test_sync.py | 1285 +++++++++++++ .../soledad/u1db/tests/test_test_infrastructure.py | 41 + src/leap/soledad/u1db/tests/test_vectorclock.py | 121 ++ src/leap/soledad/u1db/tests/testing-certs/Makefile | 35 + .../soledad/u1db/tests/testing-certs/cacert.pem | 58 + .../soledad/u1db/tests/testing-certs/testing.cert | 61 + .../soledad/u1db/tests/testing-certs/testing.key | 16 + src/leap/soledad/u1db/vectorclock.py | 89 + 54 files changed, 16733 insertions(+) create mode 100644 src/leap/soledad/u1db/__init__.py create mode 100644 src/leap/soledad/u1db/backends/__init__.py create mode 100644 src/leap/soledad/u1db/backends/dbschema.sql create mode 100644 src/leap/soledad/u1db/backends/inmemory.py create mode 100644 src/leap/soledad/u1db/backends/sqlite_backend.py create mode 100644 src/leap/soledad/u1db/commandline/__init__.py create mode 100644 src/leap/soledad/u1db/commandline/client.py create mode 100644 src/leap/soledad/u1db/commandline/command.py create mode 100644 src/leap/soledad/u1db/commandline/serve.py create mode 100644 src/leap/soledad/u1db/errors.py create mode 100644 src/leap/soledad/u1db/query_parser.py create mode 100644 src/leap/soledad/u1db/remote/__init__.py create mode 100644 src/leap/soledad/u1db/remote/basic_auth_middleware.py create mode 100644 
src/leap/soledad/u1db/remote/http_app.py create mode 100644 src/leap/soledad/u1db/remote/http_client.py create mode 100644 src/leap/soledad/u1db/remote/http_database.py create mode 100644 src/leap/soledad/u1db/remote/http_errors.py create mode 100644 src/leap/soledad/u1db/remote/http_target.py create mode 100644 src/leap/soledad/u1db/remote/oauth_middleware.py create mode 100644 src/leap/soledad/u1db/remote/server_state.py create mode 100644 src/leap/soledad/u1db/remote/ssl_match_hostname.py create mode 100644 src/leap/soledad/u1db/remote/utils.py create mode 100644 src/leap/soledad/u1db/sync.py create mode 100644 src/leap/soledad/u1db/tests/__init__.py create mode 100644 src/leap/soledad/u1db/tests/c_backend_wrapper.pyx create mode 100644 src/leap/soledad/u1db/tests/commandline/__init__.py create mode 100644 src/leap/soledad/u1db/tests/commandline/test_client.py create mode 100644 src/leap/soledad/u1db/tests/commandline/test_command.py create mode 100644 src/leap/soledad/u1db/tests/commandline/test_serve.py create mode 100644 src/leap/soledad/u1db/tests/test_auth_middleware.py create mode 100644 src/leap/soledad/u1db/tests/test_backends.py create mode 100644 src/leap/soledad/u1db/tests/test_c_backend.py create mode 100644 src/leap/soledad/u1db/tests/test_common_backend.py create mode 100644 src/leap/soledad/u1db/tests/test_document.py create mode 100644 src/leap/soledad/u1db/tests/test_errors.py create mode 100644 src/leap/soledad/u1db/tests/test_http_app.py create mode 100644 src/leap/soledad/u1db/tests/test_http_client.py create mode 100644 src/leap/soledad/u1db/tests/test_http_database.py create mode 100644 src/leap/soledad/u1db/tests/test_https.py create mode 100644 src/leap/soledad/u1db/tests/test_inmemory.py create mode 100644 src/leap/soledad/u1db/tests/test_open.py create mode 100644 src/leap/soledad/u1db/tests/test_query_parser.py create mode 100644 src/leap/soledad/u1db/tests/test_remote_sync_target.py create mode 100644 
src/leap/soledad/u1db/tests/test_remote_utils.py create mode 100644 src/leap/soledad/u1db/tests/test_server_state.py create mode 100644 src/leap/soledad/u1db/tests/test_sqlite_backend.py create mode 100644 src/leap/soledad/u1db/tests/test_sync.py create mode 100644 src/leap/soledad/u1db/tests/test_test_infrastructure.py create mode 100644 src/leap/soledad/u1db/tests/test_vectorclock.py create mode 100644 src/leap/soledad/u1db/tests/testing-certs/Makefile create mode 100644 src/leap/soledad/u1db/tests/testing-certs/cacert.pem create mode 100644 src/leap/soledad/u1db/tests/testing-certs/testing.cert create mode 100644 src/leap/soledad/u1db/tests/testing-certs/testing.key create mode 100644 src/leap/soledad/u1db/vectorclock.py (limited to 'src/leap') diff --git a/src/leap/soledad/u1db/__init__.py b/src/leap/soledad/u1db/__init__.py new file mode 100644 index 00000000..ed41bb03 --- /dev/null +++ b/src/leap/soledad/u1db/__init__.py @@ -0,0 +1,697 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""U1DB""" + +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db.errors import InvalidJSON, InvalidContent + +__version_info__ = (0, 1, 4) +__version__ = '.'.join(map(str, __version_info__)) + + +def open(path, create, document_factory=None): + """Open a database at the given location. 
+ + Will raise u1db.errors.DatabaseDoesNotExist if create=False and the + database does not already exist. + + :param path: The filesystem path for the database to open. + :param create: True/False, should the database be created if it doesn't + already exist? + :param document_factory: A function that will be called with the same + parameters as Document.__init__. + :return: An instance of Database. + """ + from u1db.backends import sqlite_backend + return sqlite_backend.SQLiteDatabase.open_database( + path, create=create, document_factory=document_factory) + + +# constraints on database names (relevant for remote access, as regex) +DBNAME_CONSTRAINTS = r"[a-zA-Z0-9][a-zA-Z0-9.-]*" + +# constraints on doc ids (as regex) +# (no slashes, and no characters outside the ascii range) +DOC_ID_CONSTRAINTS = r"[a-zA-Z0-9.%_-]+" + + +class Database(object): + """A JSON Document data store. + + This data store can be synchronized with other u1db.Database instances. + """ + + def set_document_factory(self, factory): + """Set the document factory that will be used to create objects to be + returned as documents by the database. + + :param factory: A function that returns an object which at minimum must + satisfy the same interface as does the class DocumentBase. + Subclassing that class is the easiest way to create such + a function. + """ + raise NotImplementedError(self.set_document_factory) + + def set_document_size_limit(self, limit): + """Set the maximum allowed document size for this database. + + :param limit: Maximum allowed document size in bytes. + """ + raise NotImplementedError(self.set_document_size_limit) + + def whats_changed(self, old_generation=0): + """Return a list of documents that have changed since old_generation. + This allows APPS to only store a db generation before going + 'offline', and then when coming back online they can use this + data to update whatever extra data they are storing. 
+ + :param old_generation: The generation of the database in the old + state. + :return: (generation, trans_id, [(doc_id, generation, trans_id),...]) + The current generation of the database, its associated transaction + id, and a list of changed documents since old_generation, + represented by tuples with for each document its doc_id and the + generation and transaction id corresponding to the last intervening + change and sorted by generation (old changes first) + """ + raise NotImplementedError(self.whats_changed) + + def get_doc(self, doc_id, include_deleted=False): + """Get the JSON string for the given document. + + :param doc_id: The unique document identifier + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise asking for a deleted + document will return None. + :return: a Document object. + """ + raise NotImplementedError(self.get_doc) + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + """Get the JSON content for many documents. + + :param doc_ids: A list of document identifiers. + :param check_for_conflicts: If set to False, then the conflict check + will be skipped, and 'None' will be returned instead of True/False. + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise deleted documents will not + be included in the results. + :return: iterable giving the Document object for each document id + in matching doc_ids order. + """ + raise NotImplementedError(self.get_docs) + + def get_all_docs(self, include_deleted=False): + """Get the JSON content for all documents in the database. + + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise deleted documents will not + be included in the results. + :return: (generation, [Document]) + The current generation of the database, followed by a list of all + the documents in the database. 
+ """ + raise NotImplementedError(self.get_all_docs) + + def create_doc(self, content, doc_id=None): + """Create a new document. + + You can optionally specify the document identifier, but the document + must not already exist. See 'put_doc' if you want to override an + existing document. + If the database specifies a maximum document size and the document + exceeds it, create will fail and raise a DocumentTooBig exception. + + :param content: A Python dictionary. + :param doc_id: An optional identifier specifying the document id. + :return: Document + """ + raise NotImplementedError(self.create_doc) + + def create_doc_from_json(self, json, doc_id=None): + """Create a new document. + + You can optionally specify the document identifier, but the document + must not already exist. See 'put_doc' if you want to override an + existing document. + If the database specifies a maximum document size and the document + exceeds it, create will fail and raise a DocumentTooBig exception. + + :param json: The JSON document string + :param doc_id: An optional identifier specifying the document id. + :return: Document + """ + raise NotImplementedError(self.create_doc_from_json) + + def put_doc(self, doc): + """Update a document. + If the document currently has conflicts, put will fail. + If the database specifies a maximum document size and the document + exceeds it, put will fail and raise a DocumentTooBig exception. + + :param doc: A Document with new content. + :return: new_doc_rev - The new revision identifier for the document. + The Document object will also be updated. + """ + raise NotImplementedError(self.put_doc) + + def delete_doc(self, doc): + """Mark a document as deleted. + Will abort if the current revision doesn't match doc.rev. + This will also set doc.content to None. + """ + raise NotImplementedError(self.delete_doc) + + def create_index(self, index_name, *index_expressions): + """Create an named index, which can then be queried for future lookups. 
+ Creating an index which already exists is not an error, and is cheap. + Creating an index which does not match the index_expressions of the + existing index is an error. + Creating an index will block until the expressions have been evaluated + and the index generated. + + :param index_name: A unique name which can be used as a key prefix + :param index_expressions: index expressions defining the index + information. + + Examples: + + "fieldname", or "fieldname.subfieldname" to index alphabetically + sorted on the contents of a field. + + "number(fieldname, width)", "lower(fieldname)" + """ + raise NotImplementedError(self.create_index) + + def delete_index(self, index_name): + """Remove a named index. + + :param index_name: The name of the index we are removing + """ + raise NotImplementedError(self.delete_index) + + def list_indexes(self): + """List the definitions of all known indexes. + + :return: A list of [('index-name', ['field', 'field2'])] definitions. + """ + raise NotImplementedError(self.list_indexes) + + def get_from_index(self, index_name, *key_values): + """Return documents that match the keys supplied. + + You must supply exactly the same number of values as have been defined + in the index. It is possible to do a prefix match by using '*' to + indicate a wildcard match. You can only supply '*' to trailing entries, + (eg 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) + It is also possible to append a '*' to the last supplied value (eg + 'val*', '*', '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + + :param index_name: The index to query + :param key_values: values to match. eg, if you have + an index with 3 fields then you would have: + get_from_index(index_name, val1, val2, val3) + :return: List of [Document] + """ + raise NotImplementedError(self.get_from_index) + + def get_range_from_index(self, index_name, start_value, end_value): + """Return documents that fall within the specified range. 
+ + Both ends of the range are inclusive. For both start_value and + end_value, one must supply exactly the same number of values as have + been defined in the index, or pass None. In case of a single column + index, a string is accepted as an alternative for a tuple with a single + value. It is possible to do a prefix match by using '*' to indicate + a wildcard match. You can only supply '*' to trailing entries, (eg + 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) It is also + possible to append a '*' to the last supplied value (eg 'val*', '*', + '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + + :param index_name: The index to query + :param start_values: tuples of values that define the lower bound of + the range. eg, if you have an index with 3 fields then you would + have: (val1, val2, val3) + :param end_values: tuples of values that define the upper bound of the + range. eg, if you have an index with 3 fields then you would have: + (val1, val2, val3) + :return: List of [Document] + """ + raise NotImplementedError(self.get_range_from_index) + + def get_index_keys(self, index_name): + """Return all keys under which documents are indexed in this index. + + :param index_name: The index to query + :return: [] A list of tuples of indexed keys. + """ + raise NotImplementedError(self.get_index_keys) + + def get_doc_conflicts(self, doc_id): + """Get the list of conflicts for the given document. + + The order of the conflicts is such that the first entry is the value + that would be returned by "get_doc". + + :return: [doc] A list of the Document entries that are conflicted. + """ + raise NotImplementedError(self.get_doc_conflicts) + + def resolve_doc(self, doc, conflicted_doc_revs): + """Mark a document as no longer conflicted. + + We take the list of revisions that the client knows about that it is + superseding. This may be a different list from the actual current + conflicts, in which case only those are removed as conflicted. 
This + may fail if the conflict list is significantly different from the + supplied information. (sync could have happened in the background from + the time you GET_DOC_CONFLICTS until the point where you RESOLVE) + + :param doc: A Document with the new content to be inserted. + :param conflicted_doc_revs: A list of revisions that the new content + supersedes. + """ + raise NotImplementedError(self.resolve_doc) + + def get_sync_target(self): + """Return a SyncTarget object, for another u1db to synchronize with. + + :return: An instance of SyncTarget. + """ + raise NotImplementedError(self.get_sync_target) + + def close(self): + """Release any resources associated with this database.""" + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + """Synchronize documents with remote replica exposed at url. + + :param url: the url of the target replica to sync with. + :param creds: optional dictionary giving credentials + to authorize the operation with the server. For using OAuth + the form of creds is: + {'oauth': { + 'consumer_key': ..., + 'consumer_secret': ..., + 'token_key': ..., + 'token_secret': ... + }} + :param autocreate: ask the target to create the db if non-existent. + :return: local_gen_before_sync The local generation before the + synchronisation was performed. This is useful to pass into + whatschanged, if an application wants to know which documents were + affected by a synchronisation. + """ + from u1db.sync import Synchronizer + from u1db.remote.http_target import HTTPSyncTarget + return Synchronizer(self, HTTPSyncTarget(url, creds=creds)).sync( + autocreate=autocreate) + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + """Return the last known generation and transaction id for the other db + replica. + + When you do a synchronization with another replica, the Database keeps + track of what generation the other database replica was at, and what + the associated transaction id was. 
This is used to determine what data + needs to be sent, and if two databases are claiming to be the same + replica. + + :param other_replica_uid: The identifier for the other replica. + :return: (gen, trans_id) The generation and transaction id we + encountered during synchronization. If we've never synchronized + with the replica, this is (0, ''). + """ + raise NotImplementedError(self._get_replica_gen_and_trans_id) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + """Set the last-known generation and transaction id for the other + database replica. + + We have just performed some synchronization, and we want to track what + generation the other replica was at. See also + _get_replica_gen_and_trans_id. + :param other_replica_uid: The U1DB identifier for the other replica. + :param other_generation: The generation number for the other replica. + :param other_transaction_id: The transaction id associated with the + generation. + """ + raise NotImplementedError(self._set_replica_gen_and_trans_id) + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, + replica_trans_id=''): + """Insert/update document into the database with a given revision. + + This api is used during synchronization operations. + + If a document would conflict and save_conflict is set to True, the + content will be selected as the 'current' content for doc.doc_id, + even though doc.rev doesn't supersede the currently stored revision. + The currently stored document will be added to the list of conflict + alternatives for the given doc_id. + + This forces the new content to be 'current' so that we get convergence + after synchronizing, even if people don't resolve conflicts. Users can + then notice that their content is out of date, update it, and + synchronize again. 
(The alternative is that users could synchronize and + think the data has propagated, but their local copy looks fine, and the + remote copy is never updated again.) + + :param doc: A Document object + :param save_conflict: If this document is a conflict, do you want to + save it as a conflict, or just ignore it. + :param replica_uid: A unique replica identifier. + :param replica_gen: The generation of the replica corresponding to + this document. The replica arguments are optional, but are used + during synchronization. + :param replica_trans_id: The transaction_id associated with the + generation. + :return: (state, at_gen) - If we don't have doc_id already, + or if doc_rev supersedes the existing document revision, + then the content will be inserted, and state is 'inserted'. + If doc_rev is less than or equal to the existing revision, + then the put is ignored and state is respectively 'superseded' + or 'converged'. + If doc_rev is not strictly superseded or supersedes, then + state is 'conflicted'. The document will not be inserted if + save_conflict is False. + For 'inserted' or 'converged', at_gen is the insertion/current + generation. + """ + raise NotImplementedError(self._put_doc_if_newer) + + +class DocumentBase(object): + """Container for handling a single document. + + :ivar doc_id: Unique identifier for this document. + :ivar rev: The revision identifier of the document. + :ivar json_string: The JSON string for this document. 
+ :ivar has_conflicts: Boolean indicating if this document has conflicts + """ + + def __init__(self, doc_id, rev, json_string, has_conflicts=False): + self.doc_id = doc_id + self.rev = rev + if json_string is not None: + try: + value = json.loads(json_string) + except json.JSONDecodeError: + raise InvalidJSON + if not isinstance(value, dict): + raise InvalidJSON + self._json = json_string + self.has_conflicts = has_conflicts + + def same_content_as(self, other): + """Compare the content of two documents.""" + if self._json: + c1 = json.loads(self._json) + else: + c1 = None + if other._json: + c2 = json.loads(other._json) + else: + c2 = None + return c1 == c2 + + def __repr__(self): + if self.has_conflicts: + extra = ', conflicted' + else: + extra = '' + return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id, + self.rev, extra, self.get_json()) + + def __hash__(self): + raise NotImplementedError(self.__hash__) + + def __eq__(self, other): + if not isinstance(other, Document): + return NotImplemented + return ( + self.doc_id == other.doc_id and self.rev == other.rev and + self.same_content_as(other) and self.has_conflicts == + other.has_conflicts) + + def __lt__(self, other): + """This is meant for testing, not part of the official api. + + It is implemented so that sorted([Document, Document]) can be used. + It doesn't imply that users would want their documents to be sorted in + this order. + """ + # Since this is just for testing, we don't worry about comparing + # against things that aren't a Document. 
+ return ((self.doc_id, self.rev, self.get_json()) + < (other.doc_id, other.rev, other.get_json())) + + def get_json(self): + """Get the json serialization of this document.""" + if self._json is not None: + return self._json + return None + + def get_size(self): + """Calculate the total size of the document.""" + size = 0 + json = self.get_json() + if json: + size += len(json) + if self.rev: + size += len(self.rev) + if self.doc_id: + size += len(self.doc_id) + return size + + def set_json(self, json_string): + """Set the json serialization of this document.""" + if json_string is not None: + try: + value = json.loads(json_string) + except json.JSONDecodeError: + raise InvalidJSON + if not isinstance(value, dict): + raise InvalidJSON + self._json = json_string + + def make_tombstone(self): + """Make this document into a tombstone.""" + self._json = None + + def is_tombstone(self): + """Return True if the document is a tombstone, False otherwise.""" + if self._json is not None: + return False + return True + + +class Document(DocumentBase): + """Container for handling a single document. + + :ivar doc_id: Unique identifier for this document. + :ivar rev: The revision identifier of the document. + :ivar json: The JSON string for this document. + :ivar has_conflicts: Boolean indicating if this document has conflicts + """ + + # The following part of the API is optional: no implementation is forced to + # have it but if the language supports dictionaries/hashtables, it makes + # Documents a lot more user friendly. + + def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False): + # TODO: We convert the json in the superclass to check its validity so + # we might as well set _content here directly since the price is + # already being paid. 
+ super(Document, self).__init__(doc_id, rev, json, has_conflicts) + self._content = None + + def same_content_as(self, other): + """Compare the content of two documents.""" + if self._json: + c1 = json.loads(self._json) + else: + c1 = self._content + if other._json: + c2 = json.loads(other._json) + else: + c2 = other._content + return c1 == c2 + + def get_json(self): + """Get the json serialization of this document.""" + json_string = super(Document, self).get_json() + if json_string is not None: + return json_string + if self._content is not None: + return json.dumps(self._content) + return None + + def set_json(self, json): + """Set the json serialization of this document.""" + self._content = None + super(Document, self).set_json(json) + + def make_tombstone(self): + """Make this document into a tombstone.""" + self._content = None + super(Document, self).make_tombstone() + + def is_tombstone(self): + """Return True if the document is a tombstone, False otherwise.""" + if self._content is not None: + return False + return super(Document, self).is_tombstone() + + def _get_content(self): + """Get the dictionary representing this document.""" + if self._json is not None: + self._content = json.loads(self._json) + self._json = None + if self._content is not None: + return self._content + return None + + def _set_content(self, content): + """Set the dictionary representing this document.""" + try: + tmp = json.dumps(content) + except TypeError: + raise InvalidContent( + "Can not be converted to JSON: %r" % (content,)) + if not tmp.startswith('{'): + raise InvalidContent( + "Can not be converted to a JSON object: %r." % (content,)) + # We might as well store the JSON at this point since we did the work + # of encoding it, and it doesn't lose any information. + self._json = tmp + self._content = None + + content = property( + _get_content, _set_content, doc="Content of the Document.") + + # End of optional part. 
class SyncTarget(object):
    """Functionality for using a Database as a synchronization target.

    This is an abstract interface: every public method raises
    NotImplementedError and is expected to be overridden by a concrete
    target.
    """

    def get_sync_info(self, source_replica_uid):
        """Return information about known state.

        Return the replica_uid and the current database generation of this
        database, and the last-seen database generation for
        source_replica_uid.

        :param source_replica_uid: Another replica which we might have
            synchronized with in the past.
        :return: (target_replica_uid, target_replica_generation,
            target_trans_id, source_replica_last_known_generation,
            source_replica_last_known_transaction_id)
        :raise NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(self.get_sync_info)

    def record_sync_info(self, source_replica_uid, source_replica_generation,
                         source_replica_transaction_id):
        """Record tip information for another replica.

        After sync_exchange has been processed, the caller will have
        received new content from this replica. This call allows the
        source replica instigating the sync to inform us what their
        generation became after applying the documents we returned.

        This is used to allow future sync operations to not need to repeat
        data that we just talked about. It also means that if this is called
        at the wrong time, there can be database records that will never be
        synchronized.

        :param source_replica_uid: The identifier for the source replica.
        :param source_replica_generation: The database generation for the
            source replica.
        :param source_replica_transaction_id: The transaction id associated
            with the source replica generation.
        :raise NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(self.record_sync_info)

    def sync_exchange(self, docs_by_generation, source_replica_uid,
                      last_known_generation, last_known_trans_id,
                      return_doc_cb, ensure_callback=None):
        """Incorporate the documents sent from the source replica.

        This is not meant to be called by client code directly, but is used
        as part of sync().

        This adds docs to the local store, and determines documents that
        need to be returned to the source replica.

        Documents must be supplied in docs_by_generation paired with
        the generation of their latest change in order from the oldest
        change to the newest, that means from the oldest generation to
        the newest.

        Documents are also returned paired with the generation of
        their latest change in order from the oldest change to the
        newest.

        :param docs_by_generation: A list of [(Document, generation,
            transaction_id)] tuples indicating documents which should be
            updated on this replica paired with the generation and
            transaction id of their latest change.
        :param source_replica_uid: The source replica's identifier.
        :param last_known_generation: The last generation that the source
            replica knows about this target replica.
        :param last_known_trans_id: The last transaction id that the source
            replica knows about this target replica.
        :param return_doc_cb: A callback, invoked as
            return_doc_cb(doc, gen), used to return documents to the source
            replica. It will be invoked in turn with Documents that have
            changed since last_known_generation together with the
            generation of their last change.
        :param ensure_callback: Optional callback, invoked as
            ensure_callback(replica_uid). If set, the target may create the
            target db if not yet existent; the callback can then be used to
            inform of the created db replica uid.
        :return: new_generation - After applying docs_by_generation, this is
            the current generation for this replica.
        :raise NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(self.sync_exchange)

    def _set_trace_hook(self, cb):
        """Set a callback that will be invoked to trace database actions.

        The callback will be passed a string indicating the current state,
        and the sync target object. Implementations do not have to
        implement this api, it is used by the test suite.

        NOTE(review): the sentence above mentions two arguments (the state
        and the sync target) while the parameter description below shows a
        single-argument call; confirm against the implementations.

        :param cb: A callable that takes cb(state)
        :raise NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(self._set_trace_hook)

    def _set_trace_hook_shallow(self, cb):
        """Set a callback that will be invoked to trace database actions.

        Similar to _set_trace_hook, for implementations that don't offer
        state changes from the inner working of sync_exchange(). The base
        implementation simply forwards to _set_trace_hook.

        :param cb: A callable that takes cb(state)
        """
        self._set_trace_hook(cb)


# diff --git a/src/leap/soledad/u1db/backends/__init__.py b/src/leap/soledad/u1db/backends/__init__.py
# new file mode 100644
# index 00000000..c8e5adc6
# --- /dev/null
# +++ b/src/leap/soledad/u1db/backends/__init__.py
# @@ -0,0 +1,211 @@
# Copyright 2011 Canonical Ltd.
#
# This file is part of u1db.
#
# u1db is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# u1db is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with u1db.  If not, see http://www.gnu.org/licenses/.
+ +"""Abstract classes and common implementations for the backends.""" + +import re +try: + import simplejson as json +except ImportError: + import json # noqa +import uuid + +import u1db +from u1db import ( + errors, +) +import u1db.sync +from u1db.vectorclock import VectorClockRev + + +check_doc_id_re = re.compile("^" + u1db.DOC_ID_CONSTRAINTS + "$", re.UNICODE) + + +class CommonSyncTarget(u1db.sync.LocalSyncTarget): + pass + + +class CommonBackend(u1db.Database): + + document_size_limit = 0 + + def _allocate_doc_id(self): + """Generate a unique identifier for this document.""" + return 'D-' + uuid.uuid4().hex # 'D-' stands for document + + def _allocate_transaction_id(self): + return 'T-' + uuid.uuid4().hex # 'T-' stands for transaction + + def _allocate_doc_rev(self, old_doc_rev): + vcr = VectorClockRev(old_doc_rev) + vcr.increment(self._replica_uid) + return vcr.as_str() + + def _check_doc_id(self, doc_id): + if not check_doc_id_re.match(doc_id): + raise errors.InvalidDocId() + + def _check_doc_size(self, doc): + if not self.document_size_limit: + return + if doc.get_size() > self.document_size_limit: + raise errors.DocumentTooBig + + def _get_generation(self): + """Return the current generation. + + """ + raise NotImplementedError(self._get_generation) + + def _get_generation_info(self): + """Return the current generation and transaction id. + + """ + raise NotImplementedError(self._get_generation_info) + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Extract the document from storage. + + This can return None if the document doesn't exist. 
+ """ + raise NotImplementedError(self._get_doc) + + def _has_conflicts(self, doc_id): + """Return True if the doc has conflicts, False otherwise.""" + raise NotImplementedError(self._has_conflicts) + + def create_doc(self, content, doc_id=None): + json_string = json.dumps(content) + if doc_id is None: + doc_id = self._allocate_doc_id() + doc = self._factory(doc_id, None, json_string) + self.put_doc(doc) + return doc + + def create_doc_from_json(self, json, doc_id=None): + if doc_id is None: + doc_id = self._allocate_doc_id() + doc = self._factory(doc_id, None, json) + self.put_doc(doc) + return doc + + def _get_transaction_log(self): + """This is only for the test suite, it is not part of the api.""" + raise NotImplementedError(self._get_transaction_log) + + def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): + raise NotImplementedError(self._put_and_update_indexes) + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + for doc_id in doc_ids: + doc = self._get_doc( + doc_id, check_for_conflicts=check_for_conflicts) + if doc.is_tombstone() and not include_deleted: + continue + yield doc + + def _get_trans_id_for_gen(self, generation): + """Get the transaction id corresponding to a particular generation. + + Raises an InvalidGeneration when the generation does not exist. + + """ + raise NotImplementedError(self._get_trans_id_for_gen) + + def validate_gen_and_trans_id(self, generation, trans_id): + """Validate the generation and transaction id. + + Raises an InvalidGeneration when the generation does not exist, and an + InvalidTransactionId when it does but with a different transaction id. + + """ + if generation == 0: + return + known_trans_id = self._get_trans_id_for_gen(generation) + if known_trans_id != trans_id: + raise errors.InvalidTransactionId + + def _validate_source(self, other_replica_uid, other_generation, + other_transaction_id): + """Validate the new generation and transaction id. 
+ + other_generation must be greater than what we have stored for this + replica, *or* it must be the same and the transaction_id must be the + same as well. + """ + (old_generation, + old_transaction_id) = self._get_replica_gen_and_trans_id( + other_replica_uid) + if other_generation < old_generation: + raise errors.InvalidGeneration + if other_generation > old_generation: + return + if other_transaction_id == old_transaction_id: + return + raise errors.InvalidTransactionId + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, + replica_trans_id=''): + cur_doc = self._get_doc(doc.doc_id) + doc_vcr = VectorClockRev(doc.rev) + if cur_doc is None: + cur_vcr = VectorClockRev(None) + else: + cur_vcr = VectorClockRev(cur_doc.rev) + self._validate_source(replica_uid, replica_gen, replica_trans_id) + if doc_vcr.is_newer(cur_vcr): + rev = doc.rev + self._prune_conflicts(doc, doc_vcr) + if doc.rev != rev: + # conflicts have been autoresolved + state = 'superseded' + else: + state = 'inserted' + self._put_and_update_indexes(cur_doc, doc) + elif doc.rev == cur_doc.rev: + # magical convergence + state = 'converged' + elif cur_vcr.is_newer(doc_vcr): + # Don't add this to seen_ids, because we have something newer, + # so we should send it back, and we should not generate a + # conflict + state = 'superseded' + elif cur_doc.same_content_as(doc): + # the documents have been edited to the same thing at both ends + doc_vcr.maximize(cur_vcr) + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + self._put_and_update_indexes(cur_doc, doc) + state = 'superseded' + else: + state = 'conflicted' + if save_conflict: + self._force_doc_sync_conflict(doc) + if replica_uid is not None and replica_gen is not None: + self._do_set_replica_gen_and_trans_id( + replica_uid, replica_gen, replica_trans_id) + return state, self._get_generation() + + def _ensure_maximal_rev(self, cur_rev, extra_revs): + vcr = VectorClockRev(cur_rev) + for rev in extra_revs: + 
vcr.maximize(VectorClockRev(rev)) + vcr.increment(self._replica_uid) + return vcr.as_str() + + def set_document_size_limit(self, limit): + self.document_size_limit = limit diff --git a/src/leap/soledad/u1db/backends/dbschema.sql b/src/leap/soledad/u1db/backends/dbschema.sql new file mode 100644 index 00000000..ae027fc5 --- /dev/null +++ b/src/leap/soledad/u1db/backends/dbschema.sql @@ -0,0 +1,42 @@ +-- Database schema +CREATE TABLE transaction_log ( + generation INTEGER PRIMARY KEY AUTOINCREMENT, + doc_id TEXT NOT NULL, + transaction_id TEXT NOT NULL +); +CREATE TABLE document ( + doc_id TEXT PRIMARY KEY, + doc_rev TEXT NOT NULL, + content TEXT +); +CREATE TABLE document_fields ( + doc_id TEXT NOT NULL, + field_name TEXT NOT NULL, + value TEXT +); +CREATE INDEX document_fields_field_value_doc_idx + ON document_fields(field_name, value, doc_id); + +CREATE TABLE sync_log ( + replica_uid TEXT PRIMARY KEY, + known_generation INTEGER, + known_transaction_id TEXT +); +CREATE TABLE conflicts ( + doc_id TEXT, + doc_rev TEXT, + content TEXT, + CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev) +); +CREATE TABLE index_definitions ( + name TEXT, + offset INT, + field TEXT, + CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset) +); +create index index_definitions_field on index_definitions(field); +CREATE TABLE u1db_config ( + name TEXT PRIMARY KEY, + value TEXT +); +INSERT INTO u1db_config VALUES ('sql_schema', '0'); diff --git a/src/leap/soledad/u1db/backends/inmemory.py b/src/leap/soledad/u1db/backends/inmemory.py new file mode 100644 index 00000000..a271bb37 --- /dev/null +++ b/src/leap/soledad/u1db/backends/inmemory.py @@ -0,0 +1,469 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. 
+# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see http://www.gnu.org/licenses/. + +"""The in-memory Database class for U1DB.""" + +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db import ( + Document, + errors, + query_parser, + vectorclock, + ) +from u1db.backends import CommonBackend, CommonSyncTarget + + +def get_prefix(value): + key_prefix = '\x01'.join(value) + return key_prefix.rstrip('*') + + +class InMemoryDatabase(CommonBackend): + """A database that only stores the data internally.""" + + def __init__(self, replica_uid, document_factory=None): + self._transaction_log = [] + self._docs = {} + # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner' + self._conflicts = {} + self._other_generations = {} + self._indexes = {} + self._replica_uid = replica_uid + self._factory = document_factory or Document + + def _set_replica_uid(self, replica_uid): + """Force the replica_uid to be set.""" + self._replica_uid = replica_uid + + def set_document_factory(self, factory): + self._factory = factory + + def close(self): + # This is a no-op, We don't want to free the data because one client + # may be closing it, while another wants to inspect the results. 
+ pass + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + return self._other_generations.get(other_replica_uid, (0, '')) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + self._do_set_replica_gen_and_trans_id( + other_replica_uid, other_generation, other_transaction_id) + + def _do_set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, + other_transaction_id): + # TODO: to handle race conditions, we may want to check if the current + # value is greater than this new value. + self._other_generations[other_replica_uid] = (other_generation, + other_transaction_id) + + def get_sync_target(self): + return InMemorySyncTarget(self) + + def _get_transaction_log(self): + # snapshot! + return self._transaction_log[:] + + def _get_generation(self): + return len(self._transaction_log) + + def _get_generation_info(self): + if not self._transaction_log: + return 0, '' + return len(self._transaction_log), self._transaction_log[-1][1] + + def _get_trans_id_for_gen(self, generation): + if generation == 0: + return '' + if generation > len(self._transaction_log): + raise errors.InvalidGeneration + return self._transaction_log[generation - 1][1] + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _put_and_update_indexes(self, old_doc, doc): + for index in 
self._indexes.itervalues(): + if old_doc is not None and not old_doc.is_tombstone(): + index.remove_json(old_doc.doc_id, old_doc.get_json()) + if not doc.is_tombstone(): + index.add_json(doc.doc_id, doc.get_json()) + trans_id = self._allocate_transaction_id() + self._docs[doc.doc_id] = (doc.rev, doc.get_json()) + self._transaction_log.append((doc.doc_id, trans_id)) + + def _get_doc(self, doc_id, check_for_conflicts=False): + try: + doc_rev, content = self._docs[doc_id] + except KeyError: + return None + doc = self._factory(doc_id, doc_rev, content) + if check_for_conflicts: + doc.has_conflicts = (doc.doc_id in self._conflicts) + return doc + + def _has_conflicts(self, doc_id): + return doc_id in self._conflicts + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def get_all_docs(self, include_deleted=False): + """Return all documents in the database.""" + generation = self._get_generation() + results = [] + for doc_id, (doc_rev, content) in self._docs.items(): + if content is None and not include_deleted: + continue + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = self._has_conflicts(doc_id) + results.append(doc) + return (generation, results) + + def get_doc_conflicts(self, doc_id): + if doc_id not in self._conflicts: + return [] + result = [self._get_doc(doc_id)] + result[0].has_conflicts = True + result.extend([self._factory(doc_id, rev, content) + for rev, content in self._conflicts[doc_id]]) + return result + + def _replace_conflicts(self, doc, conflicts): + if not conflicts: + del self._conflicts[doc.doc_id] + else: + self._conflicts[doc.doc_id] = conflicts + doc.has_conflicts = bool(conflicts) + + def _prune_conflicts(self, doc, doc_vcr): + if self._has_conflicts(doc.doc_id): + autoresolved = False + remaining_conflicts = [] + cur_conflicts = self._conflicts[doc.doc_id] 
+ for c_rev, c_doc in cur_conflicts: + c_vcr = vectorclock.VectorClockRev(c_rev) + if doc_vcr.is_newer(c_vcr): + continue + if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)): + doc_vcr.maximize(c_vcr) + autoresolved = True + continue + remaining_conflicts.append((c_rev, c_doc)) + if autoresolved: + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + self._replace_conflicts(doc, remaining_conflicts) + + def resolve_doc(self, doc, conflicted_doc_revs): + cur_doc = self._get_doc(doc.doc_id) + if cur_doc is None: + cur_rev = None + else: + cur_rev = cur_doc.rev + new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs) + superseded_revs = set(conflicted_doc_revs) + remaining_conflicts = [] + cur_conflicts = self._conflicts[doc.doc_id] + for c_rev, c_doc in cur_conflicts: + if c_rev in superseded_revs: + continue + remaining_conflicts.append((c_rev, c_doc)) + doc.rev = new_rev + if cur_rev in superseded_revs: + self._put_and_update_indexes(cur_doc, doc) + else: + remaining_conflicts.append((new_rev, doc.get_json())) + self._replace_conflicts(doc, remaining_conflicts) + + def delete_doc(self, doc): + if doc.doc_id not in self._docs: + raise errors.DocumentDoesNotExist + if self._docs[doc.doc_id][1] in ('null', None): + raise errors.DocumentAlreadyDeleted + doc.make_tombstone() + self.put_doc(doc) + + def create_index(self, index_name, *index_expressions): + if index_name in self._indexes: + if self._indexes[index_name]._definition == list( + index_expressions): + return + raise errors.IndexNameTakenError + index = InMemoryIndex(index_name, list(index_expressions)) + for doc_id, (doc_rev, doc) in self._docs.iteritems(): + if doc is not None: + index.add_json(doc_id, doc) + self._indexes[index_name] = index + + def delete_index(self, index_name): + del self._indexes[index_name] + + def list_indexes(self): + definitions = [] + for idx in self._indexes.itervalues(): + definitions.append((idx._name, idx._definition)) + return definitions + + 
def get_from_index(self, index_name, *key_values): + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + doc_ids = index.lookup(key_values) + result = [] + for doc_id in doc_ids: + result.append(self._get_doc(doc_id, check_for_conflicts=True)) + return result + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + """Return all documents with key values in the specified range.""" + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + if isinstance(start_value, basestring): + start_value = (start_value,) + if isinstance(end_value, basestring): + end_value = (end_value,) + doc_ids = index.lookup_range(start_value, end_value) + result = [] + for doc_id in doc_ids: + result.append(self._get_doc(doc_id, check_for_conflicts=True)) + return result + + def get_index_keys(self, index_name): + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + keys = index.keys() + # XXX inefficiency warning + return list(set([tuple(key.split('\x01')) for key in keys])) + + def whats_changed(self, old_generation=0): + changes = [] + relevant_tail = self._transaction_log[old_generation:] + # We don't use len(self._transaction_log) because _transaction_log may + # get mutated by a concurrent operation. 
+ cur_generation = old_generation + len(relevant_tail) + last_trans_id = '' + if relevant_tail: + last_trans_id = relevant_tail[-1][1] + elif self._transaction_log: + last_trans_id = self._transaction_log[-1][1] + seen = set() + generation = cur_generation + for doc_id, trans_id in reversed(relevant_tail): + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + generation -= 1 + changes.reverse() + return (cur_generation, last_trans_id, changes) + + def _force_doc_sync_conflict(self, doc): + my_doc = self._get_doc(doc.doc_id) + self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) + self._conflicts.setdefault(doc.doc_id, []).append( + (my_doc.rev, my_doc.get_json())) + doc.has_conflicts = True + self._put_and_update_indexes(my_doc, doc) + + +class InMemoryIndex(object): + """Interface for managing an Index.""" + + def __init__(self, index_name, index_definition): + self._name = index_name + self._definition = index_definition + self._values = {} + parser = query_parser.Parser() + self._getters = parser.parse_all(self._definition) + + def evaluate_json(self, doc): + """Determine the 'key' after applying this index to the doc.""" + raw = json.loads(doc) + return self.evaluate(raw) + + def evaluate(self, obj): + """Evaluate a dict object, applying this definition.""" + all_rows = [[]] + for getter in self._getters: + new_rows = [] + keys = getter.get(obj) + if not keys: + return [] + for key in keys: + new_rows.extend([row + [key] for row in all_rows]) + all_rows = new_rows + all_rows = ['\x01'.join(row) for row in all_rows] + return all_rows + + def add_json(self, doc_id, doc): + """Add this json doc to the index.""" + keys = self.evaluate_json(doc) + if not keys: + return + for key in keys: + self._values.setdefault(key, []).append(doc_id) + + def remove_json(self, doc_id, doc): + """Remove this json doc from the index.""" + keys = self.evaluate_json(doc) + if keys: + for key in keys: + doc_ids = self._values[key] + 
doc_ids.remove(doc_id) + if not doc_ids: + del self._values[key] + + def _find_non_wildcards(self, values): + """Check if this should be a wildcard match. + + Further, this will raise an exception if the syntax is improperly + defined. + + :return: The offset of the last value we need to match against. + """ + if len(values) != len(self._definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + last = 0 + for idx, val in enumerate(values): + if val.endswith('*'): + if val != '*': + # We have an 'x*' style wildcard + if is_wildcard: + # We were already in wildcard mode, so this is invalid + raise errors.InvalidGlobbing + last = idx + 1 + is_wildcard = True + else: + if is_wildcard: + # We were in wildcard mode, we can't follow that with + # non-wildcard + raise errors.InvalidGlobbing + last = idx + 1 + if not is_wildcard: + return -1 + return last + + def lookup(self, values): + """Find docs that match the values.""" + last = self._find_non_wildcards(values) + if last == -1: + return self._lookup_exact(values) + else: + return self._lookup_prefix(values[:last]) + + def lookup_range(self, start_values, end_values): + """Find docs within the range.""" + # TODO: Wildly inefficient, which is unlikely to be a problem for the + # inmemory implementation. 
+ if start_values: + self._find_non_wildcards(start_values) + start_values = get_prefix(start_values) + if end_values: + if self._find_non_wildcards(end_values) == -1: + exact = True + else: + exact = False + end_values = get_prefix(end_values) + found = [] + for key, doc_ids in sorted(self._values.iteritems()): + if start_values and start_values > key: + continue + if end_values and end_values < key: + if exact: + break + else: + if not key.startswith(end_values): + break + found.extend(doc_ids) + return found + + def keys(self): + """Find the indexed keys.""" + return self._values.keys() + + def _lookup_prefix(self, value): + """Find docs that match the prefix string in values.""" + # TODO: We need a different data structure to make prefix style fast, + # some sort of sorted list would work, but a plain dict doesn't. + key_prefix = get_prefix(value) + all_doc_ids = [] + for key, doc_ids in sorted(self._values.iteritems()): + if key.startswith(key_prefix): + all_doc_ids.extend(doc_ids) + return all_doc_ids + + def _lookup_exact(self, value): + """Find docs that match exactly.""" + key = '\x01'.join(value) + if key in self._values: + return self._values[key] + return () + + +class InMemorySyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_transaction_id) diff --git a/src/leap/soledad/u1db/backends/sqlite_backend.py b/src/leap/soledad/u1db/backends/sqlite_backend.py new file mode 100644 index 00000000..773213b5 --- /dev/null +++ 
b/src/leap/soledad/u1db/backends/sqlite_backend.py @@ -0,0 +1,926 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""A U1DB implementation that uses SQLite as its persistence layer.""" + +import errno +import os +try: + import simplejson as json +except ImportError: + import json # noqa +from sqlite3 import dbapi2 +import sys +import time +import uuid + +import pkg_resources + +from u1db.backends import CommonBackend, CommonSyncTarget +from u1db import ( + Document, + errors, + query_parser, + vectorclock, + ) + + +class SQLiteDatabase(CommonBackend): + """A U1DB implementation that uses SQLite as its persistence layer.""" + + _sqlite_registry = {} + + def __init__(self, sqlite_file, document_factory=None): + """Create a new sqlite file.""" + self._db_handle = dbapi2.connect(sqlite_file) + self._real_replica_uid = None + self._ensure_schema() + self._factory = document_factory or Document + + def set_document_factory(self, factory): + self._factory = factory + + def get_sync_target(self): + return SQLiteSyncTarget(self) + + @classmethod + def _which_index_storage(cls, c): + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'index_storage'") + except dbapi2.OperationalError, e: + # The table does not exist yet + return None, e + else: + return c.fetchone()[0], None + + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5 + + @classmethod + def _open_database(cls, sqlite_file, 
document_factory=None): + if not os.path.isfile(sqlite_file): + raise errors.DatabaseDoesNotExist() + tries = 2 + while True: + # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6) + # where without re-opening the database on Windows, it + # doesn't see the transaction that was just committed + db_handle = dbapi2.connect(sqlite_file) + c = db_handle.cursor() + v, err = cls._which_index_storage(c) + db_handle.close() + if v is not None: + break + # possibly another process is initializing it, wait for it to be + # done + if tries == 0: + raise err # go for the richest error? + tries -= 1 + time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL) + return SQLiteDatabase._sqlite_registry[v]( + sqlite_file, document_factory=document_factory) + + @classmethod + def open_database(cls, sqlite_file, create, backend_cls=None, + document_factory=None): + try: + return cls._open_database( + sqlite_file, document_factory=document_factory) + except errors.DatabaseDoesNotExist: + if not create: + raise + if backend_cls is None: + # default is SQLitePartialExpandDatabase + backend_cls = SQLitePartialExpandDatabase + return backend_cls(sqlite_file, document_factory=document_factory) + + @staticmethod + def delete_database(sqlite_file): + try: + os.unlink(sqlite_file) + except OSError as ex: + if ex.errno == errno.ENOENT: + raise errors.DatabaseDoesNotExist() + raise + + @staticmethod + def register_implementation(klass): + """Register that we implement an SQLiteDatabase. + + The attribute _index_storage_value will be used as the lookup key. + """ + SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass + + def _get_sqlite_handle(self): + """Get access to the underlying sqlite database. + + This should only be used by the test suite, etc, for examining the + state of the underlying database. 
+ """ + return self._db_handle + + def _close_sqlite_handle(self): + """Release access to the underlying sqlite database.""" + self._db_handle.close() + + def close(self): + self._close_sqlite_handle() + + def _is_initialized(self, c): + """Check if this database has been initialized.""" + c.execute("PRAGMA case_sensitive_like=ON") + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'sql_schema'") + except dbapi2.OperationalError: + # The table does not exist yet + val = None + else: + val = c.fetchone() + if val is not None: + return True + return False + + def _initialize(self, c): + """Create the schema in the database.""" + #read the script with sql commands + # TODO: Change how we set up the dependency. Most likely use something + # like lp:dirspec to grab the file from a common resource + # directory. Doesn't specifically need to be handled until we get + # to the point of packaging this. + schema_content = pkg_resources.resource_string( + __name__, 'dbschema.sql') + # Note: We'd like to use c.executescript() here, but it seems that + # executescript always commits, even if you set + # isolation_level = None, so if we want to properly handle + # exclusive locking and rollbacks between processes, we need + # to execute it line-by-line + for line in schema_content.split(';'): + if not line: + continue + c.execute(line) + #add extra fields + self._extra_schema_init(c) + # A unique identifier should be set for this replica. Implementations + # don't have to strictly use uuid here, but we do want the uid to be + # unique amongst all databases that will sync with each other. + # We might extend this to using something with hostname for easier + # debugging. 
+ self._set_replica_uid_in_transaction(uuid.uuid4().hex) + c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)", + (self._index_storage_value,)) + + def _ensure_schema(self): + """Ensure that the database schema has been created.""" + old_isolation_level = self._db_handle.isolation_level + c = self._db_handle.cursor() + if self._is_initialized(c): + return + try: + # autocommit/own mgmt of transactions + self._db_handle.isolation_level = None + with self._db_handle: + # only one execution path should initialize the db + c.execute("begin exclusive") + if self._is_initialized(c): + return + self._initialize(c) + finally: + self._db_handle.isolation_level = old_isolation_level + + def _extra_schema_init(self, c): + """Add any extra fields, etc to the basic table definitions.""" + + def _parse_index_definition(self, index_field): + """Parse a field definition for an index, returning a Getter.""" + # Note: We may want to keep a Parser object around, and cache the + # Getter objects for a greater length of time. Specifically, if + # you create a bunch of indexes, and then insert 50k docs, you'll + # re-parse the indexes between puts. The time to insert the docs + # is still likely to dominate put_doc time, though. + parser = query_parser.Parser() + getter = parser.parse(index_field) + return getter + + def _update_indexes(self, doc_id, raw_doc, getters, db_cursor): + """Update document_fields for a single document. + + :param doc_id: Identifier for this document + :param raw_doc: The python dict representation of the document. + :param getters: A list of [(field_name, Getter)]. Getter.get will be + called to evaluate the index definition for this document, and the + results will be inserted into the db. + :param db_cursor: An sqlite Cursor. 
+ :return: None + """ + values = [] + for field_name, getter in getters: + for idx_value in getter.get(raw_doc): + values.append((doc_id, field_name, idx_value)) + if values: + db_cursor.executemany( + "INSERT INTO document_fields VALUES (?, ?, ?)", values) + + def _set_replica_uid(self, replica_uid): + """Force the replica_uid to be set.""" + with self._db_handle: + self._set_replica_uid_in_transaction(replica_uid) + + def _set_replica_uid_in_transaction(self, replica_uid): + """Set the replica_uid. A transaction should already be held.""" + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO u1db_config" + " VALUES ('replica_uid', ?)", + (replica_uid,)) + self._real_replica_uid = replica_uid + + def _get_replica_uid(self): + if self._real_replica_uid is not None: + return self._real_replica_uid + c = self._db_handle.cursor() + c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'") + val = c.fetchone() + if val is None: + return None + self._real_replica_uid = val[0] + return self._real_replica_uid + + _replica_uid = property(_get_replica_uid) + + def _get_generation(self): + c = self._db_handle.cursor() + c.execute('SELECT max(generation) FROM transaction_log') + val = c.fetchone()[0] + if val is None: + return 0 + return val + + def _get_generation_info(self): + c = self._db_handle.cursor() + c.execute( + 'SELECT max(generation), transaction_id FROM transaction_log ') + val = c.fetchone() + if val[0] is None: + return(0, '') + return val + + def _get_trans_id_for_gen(self, generation): + if generation == 0: + return '' + c = self._db_handle.cursor() + c.execute( + 'SELECT transaction_id FROM transaction_log WHERE generation = ?', + (generation,)) + val = c.fetchone() + if val is None: + raise errors.InvalidGeneration + return val[0] + + def _get_transaction_log(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, transaction_id FROM transaction_log" + " ORDER BY generation") + return c.fetchall() + + def _get_doc(self, 
doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling.""" + c = self._db_handle.cursor() + if check_for_conflicts: + c.execute( + "SELECT document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN " + "conflicts ON conflicts.doc_id = document.doc_id WHERE " + "document.doc_id = ? GROUP BY document.doc_id, " + "document.doc_rev, document.content;", (doc_id,)) + else: + c.execute( + "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?", + (doc_id,)) + val = c.fetchone() + if val is None: + return None + doc_rev, content, conflicts = val + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + return doc + + def _has_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1", + (doc_id,)) + val = c.fetchone() + if val is None: + return False + else: + return True + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + generation = self._get_generation() + results = [] + c = self._db_handle.cursor() + c.execute( + "SELECT document.doc_id, document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts " + "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, " + "document.doc_rev, document.content;") + rows = c.fetchall() + for doc_id, doc_rev, content, conflicts in rows: + if content is None and not include_deleted: + continue + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + results.append(doc) + return (generation, results) + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + 
self._check_doc_size(doc) + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none): + """Convert a dict representation into named fields. + + So something like: {'key1': 'val1', 'key2': 'val2'} + gets converted into: [(doc_id, 'key1', 'val1', 0) + (doc_id, 'key2', 'val2', 0)] + :param doc_id: Just added to every record. + :param base_field: if set, these are nested keys, so each field should + be appropriately prefixed. + :param raw_doc: The python dictionary. + """ + # TODO: Handle lists + values = [] + for field_name, value in raw_doc.iteritems(): + if value is None and not save_none: + continue + if base_field: + full_name = base_field + '.' + field_name + else: + full_name = field_name + if value is None or isinstance(value, (int, float, basestring)): + values.append((doc_id, full_name, value, len(values))) + else: + subvalues = self._expand_to_fields(doc_id, full_name, value, + save_none) + for _, subfield_name, val, _ in subvalues: + values.append((doc_id, subfield_name, val, len(values))) + return values + + def _put_and_update_indexes(self, old_doc, doc): + """Actually insert a document into the database. + + This both updates the existing documents content, and any indexes that + refer to this document. 
+ """ + raise NotImplementedError(self._put_and_update_indexes) + + def whats_changed(self, old_generation=0): + c = self._db_handle.cursor() + c.execute("SELECT generation, doc_id, transaction_id" + " FROM transaction_log" + " WHERE generation > ? ORDER BY generation DESC", + (old_generation,)) + results = c.fetchall() + cur_gen = old_generation + seen = set() + changes = [] + newest_trans_id = '' + for generation, doc_id, trans_id in results: + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + if changes: + cur_gen = changes[0][1] # max generation + newest_trans_id = changes[0][2] + changes.reverse() + else: + c.execute("SELECT generation, transaction_id" + " FROM transaction_log ORDER BY generation DESC LIMIT 1") + results = c.fetchone() + if not results: + cur_gen = 0 + newest_trans_id = '' + else: + cur_gen, newest_trans_id = results + + return cur_gen, newest_trans_id, changes + + def delete_doc(self, doc): + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc is None: + raise errors.DocumentDoesNotExist + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + if old_doc.is_tombstone(): + raise errors.DocumentAlreadyDeleted + if old_doc.has_conflicts: + raise errors.ConflictedDoc() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + doc.make_tombstone() + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _get_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?", + (doc_id,)) + return [self._factory(doc_id, doc_rev, content) + for doc_rev, content in c.fetchall()] + + def get_doc_conflicts(self, doc_id): + with self._db_handle: + conflict_docs = self._get_conflicts(doc_id) + if not conflict_docs: + return [] + this_doc = self._get_doc(doc_id) + this_doc.has_conflicts = True + return [this_doc] + conflict_docs + + def _get_replica_gen_and_trans_id(self, 
other_replica_uid): + c = self._db_handle.cursor() + c.execute("SELECT known_generation, known_transaction_id FROM sync_log" + " WHERE replica_uid = ?", + (other_replica_uid,)) + val = c.fetchone() + if val is None: + other_gen = 0 + trans_id = '' + else: + other_gen = val[0] + trans_id = val[1] + return other_gen, trans_id + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + with self._db_handle: + self._do_set_replica_gen_and_trans_id( + other_replica_uid, other_generation, other_transaction_id) + + def _do_set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, + other_transaction_id): + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)", + (other_replica_uid, other_generation, + other_transaction_id)) + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, + replica_gen=None, replica_trans_id=None): + with self._db_handle: + return super(SQLiteDatabase, self)._put_doc_if_newer(doc, + save_conflict=save_conflict, + replica_uid=replica_uid, replica_gen=replica_gen, + replica_trans_id=replica_trans_id) + + def _add_conflict(self, c, doc_id, my_doc_rev, my_content): + c.execute("INSERT INTO conflicts VALUES (?, ?, ?)", + (doc_id, my_doc_rev, my_content)) + + def _delete_conflicts(self, c, doc, conflict_revs): + deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs] + c.executemany("DELETE FROM conflicts" + " WHERE doc_id=? 
AND doc_rev=?", deleting) + doc.has_conflicts = self._has_conflicts(doc.doc_id) + + def _prune_conflicts(self, doc, doc_vcr): + if self._has_conflicts(doc.doc_id): + autoresolved = False + c_revs_to_prune = [] + for c_doc in self._get_conflicts(doc.doc_id): + c_vcr = vectorclock.VectorClockRev(c_doc.rev) + if doc_vcr.is_newer(c_vcr): + c_revs_to_prune.append(c_doc.rev) + elif doc.same_content_as(c_doc): + c_revs_to_prune.append(c_doc.rev) + doc_vcr.maximize(c_vcr) + autoresolved = True + if autoresolved: + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + c = self._db_handle.cursor() + self._delete_conflicts(c, doc, c_revs_to_prune) + + def _force_doc_sync_conflict(self, doc): + my_doc = self._get_doc(doc.doc_id) + c = self._db_handle.cursor() + self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) + self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json()) + doc.has_conflicts = True + self._put_and_update_indexes(my_doc, doc) + + def resolve_doc(self, doc, conflicted_doc_revs): + with self._db_handle: + cur_doc = self._get_doc(doc.doc_id) + # TODO: https://bugs.launchpad.net/u1db/+bug/928274 + # I think we have a logic bug in resolve_doc + # Specifically, cur_doc.rev is always in the final vector + # clock of revisions that we supersede, even if it wasn't in + # conflicted_doc_revs. We still add it as a conflict, but the + # fact that _put_doc_if_newer propagates resolutions means I + # think that conflict could accidentally be resolved. We need + # to add a test for this case first. (create a rev, create a + # conflict, create another conflict, resolve the first rev + # and first conflict, then make sure that the resolved + # rev doesn't supersede the second conflict rev.) 
It *might* + # not matter, because the superseding rev is in as a + # conflict, but it does seem incorrect + new_rev = self._ensure_maximal_rev(cur_doc.rev, + conflicted_doc_revs) + superseded_revs = set(conflicted_doc_revs) + c = self._db_handle.cursor() + doc.rev = new_rev + if cur_doc.rev in superseded_revs: + self._put_and_update_indexes(cur_doc, doc) + else: + self._add_conflict(c, doc.doc_id, new_rev, doc.get_json()) + # TODO: Is there some way that we could construct a rev that would + # end up in superseded_revs, such that we add a conflict, and + # then immediately delete it? + self._delete_conflicts(c, doc, superseded_revs) + + def list_indexes(self): + """Return the list of indexes and their definitions.""" + c = self._db_handle.cursor() + # TODO: How do we test the ordering? + c.execute("SELECT name, field FROM index_definitions" + " ORDER BY name, offset") + definitions = [] + cur_name = None + for name, field in c.fetchall(): + if cur_name != name: + definitions.append((name, [])) + cur_name = name + definitions[-1][-1].append(field) + return definitions + + def _get_index_definition(self, index_name): + """Return the stored definition for a given index_name.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions" + " WHERE name = ? ORDER BY offset", (index_name,)) + fields = [x[0] for x in c.fetchall()] + if not fields: + raise errors.IndexDoesNotExist + return fields + + @staticmethod + def _strip_glob(value): + """Remove the trailing * from a value.""" + assert value[-1] == '*' + return value[:-1] + + def _format_query(self, definition, key_values): + # First, build the definition. We join the document_fields table + # against itself, as many times as the 'width' of our definition. + # We then do a query for each key_value, one-at-a-time. + # Note: All of these strings are static, we could cache them, etc. 
+ tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = ["d.doc_id = d%d.doc_id" + " AND d%d.field_name = ?" + % (i, i) for i in range(len(definition))] + wildcard_where = [novalue_where[i] + + (" AND d%d.value NOT NULL" % (i,)) + for i in range(len(definition))] + exact_where = [novalue_where[i] + + (" AND d%d.value = ?" % (i,)) + for i in range(len(definition))] + like_where = [novalue_where[i] + + (" AND d%d.value GLOB ?" % (i,)) + for i in range(len(definition))] + is_wildcard = False + # Merge the lists together, so that: + # [field1, field2, field3], [val1, val2, val3] + # Becomes: + # (field1, val1, field2, val2, field3, val3) + args = [] + where = [] + for idx, (field, value) in enumerate(zip(definition, key_values)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(exact_where[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_from_index(self, index_name, *key_values): + definition = self._get_index_definition(index_name) + if len(key_values) != len(definition): + raise errors.InvalidValueForIndex() + statement, args = self._format_query(definition, key_values) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' 
% (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def _format_range_query(self, definition, start_value, end_value): + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in + range(len(definition))] + wildcard_where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + like_where = [ + novalue_where[i] + ( + " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in + range(len(definition))] + range_where_lower = [ + novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in + range(len(definition))] + range_where_upper = [ + novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in + range(len(definition))] + args = [] + where = [] + if start_value: + if isinstance(start_value, basestring): + start_value = (start_value,) + if len(start_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, start_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(self._strip_glob(value)) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(value) + if end_value: + if isinstance(end_value, basestring): + end_value = (end_value,) + if len(end_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, end_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + 
where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(self._strip_glob(value)) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_upper[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + """Return all documents with key values in the specified range.""" + definition = self._get_index_definition(index_name) + statement, args = self._format_range_query( + definition, start_value, end_value) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def get_index_keys(self, index_name): + c = self._db_handle.cursor() + definition = self._get_index_definition(index_name) + value_fields = ', '.join([ + 'd%d.value' % i for i in range(len(definition))]) + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" 
% (i, i) for i in + range(len(definition))] + where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + statement = ( + "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( + value_fields, ', '.join(tables), ' AND '.join(where), + value_fields)) + try: + c.execute(statement, tuple(definition)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) + return c.fetchall() + + def delete_index(self, index_name): + with self._db_handle: + c = self._db_handle.cursor() + c.execute("DELETE FROM index_definitions WHERE name = ?", + (index_name,)) + c.execute( + "DELETE FROM document_fields WHERE document_fields.field_name " + " NOT IN (SELECT field from index_definitions)") + + +class SQLiteSyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + +class SQLitePartialExpandDatabase(SQLiteDatabase): + """An SQLite Backend that expands documents into a document_field table. + + It stores the original document text in document.doc. For fields that are + indexed, the data goes into document_fields. 
+ """ + + _index_storage_value = 'expand referenced' + + def _get_indexed_fields(self): + """Determine what fields are indexed.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions") + return set([x[0] for x in c.fetchall()]) + + def _evaluate_index(self, raw_doc, field): + parser = query_parser.Parser() + getter = parser.parse(field) + return getter.get(raw_doc) + + def _put_and_update_indexes(self, old_doc, doc): + c = self._db_handle.cursor() + if doc and not doc.is_tombstone(): + raw_doc = json.loads(doc.get_json()) + else: + raw_doc = {} + if old_doc is not None: + c.execute("UPDATE document SET doc_rev=?, content=?" + " WHERE doc_id = ?", + (doc.rev, doc.get_json(), doc.doc_id)) + c.execute("DELETE FROM document_fields WHERE doc_id = ?", + (doc.doc_id,)) + else: + c.execute("INSERT INTO document (doc_id, doc_rev, content)" + " VALUES (?, ?, ?)", + (doc.doc_id, doc.rev, doc.get_json())) + indexed_fields = self._get_indexed_fields() + if indexed_fields: + # It is expected that len(indexed_fields) is shorter than + # len(raw_doc) + getters = [(field, self._parse_index_definition(field)) + for field in indexed_fields] + self._update_indexes(doc.doc_id, raw_doc, getters, c) + trans_id = self._allocate_transaction_id() + c.execute("INSERT INTO transaction_log(doc_id, transaction_id)" + " VALUES (?, ?)", (doc.doc_id, trans_id)) + + def create_index(self, index_name, *index_expressions): + with self._db_handle: + c = self._db_handle.cursor() + cur_fields = self._get_indexed_fields() + definition = [(index_name, idx, field) + for idx, field in enumerate(index_expressions)] + try: + c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", + definition) + except dbapi2.IntegrityError as e: + stored_def = self._get_index_definition(index_name) + if stored_def == [x[-1] for x in definition]: + return + raise errors.IndexNameTakenError, e, sys.exc_info()[2] + new_fields = set( + [f for f in index_expressions if f not in cur_fields]) + 
if new_fields: + self._update_all_indexes(new_fields) + + def _iter_all_docs(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, content FROM document") + while True: + next_rows = c.fetchmany() + if not next_rows: + break + for row in next_rows: + yield row + + def _update_all_indexes(self, new_fields): + """Iterate all the documents, and add content to document_fields. + + :param new_fields: The index definitions that need to be added. + """ + getters = [(field, self._parse_index_definition(field)) + for field in new_fields] + c = self._db_handle.cursor() + for doc_id, doc in self._iter_all_docs(): + if doc is None: + continue + raw_doc = json.loads(doc) + self._update_indexes(doc_id, raw_doc, getters, c) + +SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase) diff --git a/src/leap/soledad/u1db/commandline/__init__.py b/src/leap/soledad/u1db/commandline/__init__.py new file mode 100644 index 00000000..3f32e381 --- /dev/null +++ b/src/leap/soledad/u1db/commandline/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . diff --git a/src/leap/soledad/u1db/commandline/client.py b/src/leap/soledad/u1db/commandline/client.py new file mode 100644 index 00000000..15bf8561 --- /dev/null +++ b/src/leap/soledad/u1db/commandline/client.py @@ -0,0 +1,497 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. 
+# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Commandline bindings for the u1db-client program.""" + +import argparse +import os +try: + import simplejson as json +except ImportError: + import json # noqa +import sys + +from u1db import ( + Document, + open as u1db_open, + sync, + errors, + ) +from u1db.commandline import command +from u1db.remote import ( + http_database, + http_target, + ) + + +client_commands = command.CommandGroup() + + +def set_oauth_credentials(client): + keys = os.environ.get('OAUTH_CREDENTIALS', None) + if keys is not None: + consumer_key, consumer_secret, \ + token_key, token_secret = keys.split(":") + client.set_oauth_credentials(consumer_key, consumer_secret, + token_key, token_secret) + + +class OneDbCmd(command.Command): + """Base class for commands operating on one local or remote database.""" + + def _open(self, database, create): + if database.startswith(('http://', 'https://')): + db = http_database.HTTPDatabase(database) + set_oauth_credentials(db) + db.open(create) + return db + else: + return u1db_open(database, create) + + +class CmdCreate(OneDbCmd): + """Create a new document from scratch""" + + name = 'create' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to update', + metavar='database-path-or-url') + parser.add_argument('infile', nargs='?', default=None, + help='The file to read content from.') + parser.add_argument('--id', dest='doc_id', 
default=None, + help='Set the document identifier') + + def run(self, database, infile, doc_id): + if infile is None: + infile = self.stdin + db = self._open(database, create=False) + doc = db.create_doc_from_json(infile.read(), doc_id=doc_id) + self.stderr.write('id: %s\nrev: %s\n' % (doc.doc_id, doc.rev)) + +client_commands.register(CmdCreate) + + +class CmdDelete(OneDbCmd): + """Delete a document from the database""" + + name = 'delete' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to update', + metavar='database-path-or-url') + parser.add_argument('doc_id', help='The document id to retrieve') + parser.add_argument('doc_rev', + help='The revision of the document (which is being superseded.)') + + def run(self, database, doc_id, doc_rev): + db = self._open(database, create=False) + doc = Document(doc_id, doc_rev, None) + db.delete_doc(doc) + self.stderr.write('rev: %s\n' % (doc.rev,)) + +client_commands.register(CmdDelete) + + +class CmdGet(OneDbCmd): + """Extract a document from the database""" + + name = 'get' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to query', + metavar='database-path-or-url') + parser.add_argument('doc_id', help='The document id to retrieve.') + parser.add_argument('outfile', nargs='?', default=None, + help='The file to write the document to', + type=argparse.FileType('wb')) + + def run(self, database, doc_id, outfile): + if outfile is None: + outfile = self.stdout + try: + db = self._open(database, create=False) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + doc = db.get_doc(doc_id) + if doc is None: + self.stderr.write('Document not found (id: %s)\n' % (doc_id,)) + return 1 # failed + if doc.is_tombstone(): + outfile.write('[document deleted]\n') + else: + outfile.write(doc.get_json() + '\n') + self.stderr.write('rev: 
%s\n' % (doc.rev,)) + if doc.has_conflicts: + self.stderr.write("Document has conflicts.\n") + +client_commands.register(CmdGet) + + +class CmdGetDocConflicts(OneDbCmd): + """Get the conflicts from a document""" + + name = 'get-doc-conflicts' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local database to query', + metavar='database-path') + parser.add_argument('doc_id', help='The document id to retrieve.') + + def run(self, database, doc_id): + try: + db = self._open(database, False) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + conflicts = db.get_doc_conflicts(doc_id) + if not conflicts: + if db.get_doc(doc_id) is None: + self.stderr.write("Document does not exist.\n") + return 1 + self.stdout.write("[") + for i, doc in enumerate(conflicts): + if i: + self.stdout.write(",") + self.stdout.write( + json.dumps(dict(rev=doc.rev, content=doc.content), indent=4)) + self.stdout.write("]\n") + +client_commands.register(CmdGetDocConflicts) + + +class CmdInitDB(OneDbCmd): + """Create a new database""" + + name = 'init-db' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to create', + metavar='database-path-or-url') + parser.add_argument('--replica-uid', default=None, + help='The unique identifier for this database (not for remote)') + + def run(self, database, replica_uid): + db = self._open(database, create=True) + if replica_uid is not None: + db._set_replica_uid(replica_uid) + +client_commands.register(CmdInitDB) + + +class CmdPut(OneDbCmd): + """Add a document to the database""" + + name = 'put' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to update', + metavar='database-path-or-url'), + parser.add_argument('doc_id', help='The document id to retrieve') + parser.add_argument('doc_rev', + help='The 
revision of the document (which is being superseded.)') + parser.add_argument('infile', nargs='?', default=None, + help='The filename of the document that will be used for content', + type=argparse.FileType('rb')) + + def run(self, database, doc_id, doc_rev, infile): + if infile is None: + infile = self.stdin + try: + db = self._open(database, create=False) + doc = Document(doc_id, doc_rev, infile.read()) + doc_rev = db.put_doc(doc) + self.stderr.write('rev: %s\n' % (doc_rev,)) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + except errors.RevisionConflict: + if db.get_doc(doc_id) is None: + self.stderr.write("Document does not exist.\n") + else: + self.stderr.write("Given revision is not current.\n") + except errors.ConflictedDoc: + self.stderr.write( + "Document has conflicts.\n" + "Inspect with get-doc-conflicts, then resolve.\n") + else: + return + return 1 + +client_commands.register(CmdPut) + + +class CmdResolve(OneDbCmd): + """Resolve a conflicted document""" + + name = 'resolve-doc' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to update', + metavar='database-path-or-url'), + parser.add_argument('doc_id', help='The conflicted document id') + parser.add_argument('doc_revs', metavar="doc-rev", nargs="+", + help='The revisions that the new content supersedes') + parser.add_argument('--infile', nargs='?', default=None, + help='The filename of the document that will be used for content', + type=argparse.FileType('rb')) + + def run(self, database, doc_id, doc_revs, infile): + if infile is None: + infile = self.stdin + try: + db = self._open(database, create=False) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + doc = db.get_doc(doc_id) + if doc is None: + self.stderr.write("Document does not exist.\n") + return 1 + doc.set_json(infile.read()) + db.resolve_doc(doc, doc_revs) + 
self.stderr.write("rev: %s\n" % db.get_doc(doc_id).rev) + if doc.has_conflicts: + self.stderr.write("Document still has conflicts.\n") + +client_commands.register(CmdResolve) + + +class CmdSync(command.Command): + """Synchronize two databases""" + + name = 'sync' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('source', help='database to sync from') + parser.add_argument('target', help='database to sync to') + + def _open_target(self, target): + if target.startswith(('http://', 'https://')): + st = http_target.HTTPSyncTarget.connect(target) + set_oauth_credentials(st) + else: + db = u1db_open(target, create=True) + st = db.get_sync_target() + return st + + def run(self, source, target): + """Start a Sync request.""" + source_db = u1db_open(source, create=False) + st = self._open_target(target) + syncer = sync.Synchronizer(source_db, st) + syncer.sync() + source_db.close() + +client_commands.register(CmdSync) + + +class CmdCreateIndex(OneDbCmd): + """Create an index""" + + name = "create-index" + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', help='The local database to update', + metavar='database-path') + parser.add_argument('index', help='the name of the index') + parser.add_argument('expression', help='an index expression', + nargs='+') + + def run(self, database, index, expression): + try: + db = self._open(database, create=False) + db.create_index(index, *expression) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + except errors.IndexNameTakenError: + self.stderr.write("There is already a different index named %r.\n" + % (index,)) + return 1 + except errors.IndexDefinitionParseError: + self.stderr.write("Bad index expression.\n") + return 1 + +client_commands.register(CmdCreateIndex) + + +class CmdListIndexes(OneDbCmd): + """List existing indexes""" + + name = "list-indexes" + + @classmethod + def _populate_subparser(cls, parser): + 
parser.add_argument('database', help='The local database to query', + metavar='database-path') + + def run(self, database): + try: + db = self._open(database, create=False) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + for (index, expression) in db.list_indexes(): + self.stdout.write("%s: %s\n" % (index, ", ".join(expression))) + +client_commands.register(CmdListIndexes) + + +class CmdDeleteIndex(OneDbCmd): + """Delete an index""" + + name = "delete-index" + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', help='The local database to update', + metavar='database-path') + parser.add_argument('index', help='the name of the index') + + def run(self, database, index): + try: + db = self._open(database, create=False) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + db.delete_index(index) + +client_commands.register(CmdDeleteIndex) + + +class CmdGetIndexKeys(OneDbCmd): + """Get the index's keys""" + + name = "get-index-keys" + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', help='The local database to query', + metavar='database-path') + parser.add_argument('index', help='the name of the index') + + def run(self, database, index): + try: + db = self._open(database, create=False) + for key in db.get_index_keys(index): + self.stdout.write("%s\n" % (", ".join( + [i.encode('utf-8') for i in key],))) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + except errors.IndexDoesNotExist: + self.stderr.write("Index does not exist.\n") + else: + return + return 1 + +client_commands.register(CmdGetIndexKeys) + + +class CmdGetFromIndex(OneDbCmd): + """Find documents by searching an index""" + + name = "get-from-index" + argv = None + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', help='The local database to query', + 
metavar='database-path') + parser.add_argument('index', help='the name of the index') + parser.add_argument('values', metavar="value", + help='the value to look up (one per index column)', + nargs="+") + + def run(self, database, index, values): + try: + db = self._open(database, create=False) + docs = db.get_from_index(index, *values) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + except errors.IndexDoesNotExist: + self.stderr.write("Index does not exist.\n") + except errors.InvalidValueForIndex: + index_def = db._get_index_definition(index) + len_diff = len(index_def) - len(values) + if len_diff == 0: + # can't happen (HAH) + raise + argv = self.argv if self.argv is not None else sys.argv + self.stderr.write( + "Invalid query: " + "index %r requires %d query expression%s%s.\n" + "For example, the following would be valid:\n" + " %s %s %r %r %s\n" + % (index, + len(index_def), + "s" if len(index_def) > 1 else "", + ", not %d" % len(values) if len(values) else "", + argv[0], argv[1], database, index, + " ".join(map(repr, + values[:len(index_def)] + + ["*" for i in range(len_diff)])), + )) + except errors.InvalidGlobbing: + argv = self.argv if self.argv is not None else sys.argv + fixed = [] + for (i, v) in enumerate(values): + fixed.append(v) + if v.endswith('*'): + break + # values has at least one element, so i is defined + fixed.extend('*' * (len(values) - i - 1)) + self.stderr.write( + "Invalid query: a star can only be followed by stars.\n" + "For example, the following would be valid:\n" + " %s %s %r %r %s\n" + % (argv[0], argv[1], database, index, + " ".join(map(repr, fixed)))) + + else: + self.stdout.write("[") + for i, doc in enumerate(docs): + if i: + self.stdout.write(",") + self.stdout.write( + json.dumps( + dict(id=doc.doc_id, rev=doc.rev, content=doc.content), + indent=4)) + self.stdout.write("]\n") + return + return 1 + +client_commands.register(CmdGetFromIndex) + + +def main(args): + return 
client_commands.run_argv(args, sys.stdin, sys.stdout, sys.stderr) diff --git a/src/leap/soledad/u1db/commandline/command.py b/src/leap/soledad/u1db/commandline/command.py new file mode 100644 index 00000000..eace0560 --- /dev/null +++ b/src/leap/soledad/u1db/commandline/command.py @@ -0,0 +1,80 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Command infrastructure for u1db""" + +import argparse +import inspect + + +class CommandGroup(object): + """A collection of commands.""" + + def __init__(self, description=None): + self.commands = {} + self.description = description + + def register(self, cmd): + """Register a new command to be incorporated with this group.""" + self.commands[cmd.name] = cmd + + def make_argparser(self): + """Create an argparse.ArgumentParser""" + parser = argparse.ArgumentParser(description=self.description) + subs = parser.add_subparsers(title='commands') + for name, cmd in sorted(self.commands.iteritems()): + sub = subs.add_parser(name, help=cmd.__doc__) + sub.set_defaults(subcommand=cmd) + cmd._populate_subparser(sub) + return parser + + def run_argv(self, argv, stdin, stdout, stderr): + """Run a command, from a sys.argv[1:] style input.""" + parser = self.make_argparser() + args = parser.parse_args(argv) + cmd = args.subcommand(stdin, stdout, stderr) + params, _, _, _ = inspect.getargspec(cmd.run) + vals = [] + for param in params[1:]: + vals.append(getattr(args, param)) + 
return cmd.run(*vals) + + +class Command(object): + """Definition of a Command that can be run. + + :cvar name: The name of the command, so that you can run + 'u1db-client '. + """ + + name = None + + def __init__(self, stdin, stdout, stderr): + self.stdin = stdin + self.stdout = stdout + self.stderr = stderr + + @classmethod + def _populate_subparser(cls, parser): + """Child classes should override this to provide their arguments.""" + raise NotImplementedError(cls._populate_subparser) + + def run(self, *args): + """This is where the magic happens. + + Subclasses should implement this, requesting their specific arguments. + """ + raise NotImplementedError(self.run) diff --git a/src/leap/soledad/u1db/commandline/serve.py b/src/leap/soledad/u1db/commandline/serve.py new file mode 100644 index 00000000..0bb0e641 --- /dev/null +++ b/src/leap/soledad/u1db/commandline/serve.py @@ -0,0 +1,34 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
+ +"""Build server for u1db-serve.""" + +from paste import httpserver + +from u1db.remote import ( + http_app, + server_state, + ) + + +def make_server(host, port, working_dir): + """Make a server on host and port exposing dbs living in working_dir.""" + state = server_state.ServerState() + state.set_workingdir(working_dir) + application = http_app.HTTPApp(state) + server = httpserver.WSGIServer(application, (host, port), + httpserver.WSGIHandler) + return server diff --git a/src/leap/soledad/u1db/errors.py b/src/leap/soledad/u1db/errors.py new file mode 100644 index 00000000..967c7c38 --- /dev/null +++ b/src/leap/soledad/u1db/errors.py @@ -0,0 +1,189 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""A list of errors that u1db can raise.""" + + +class U1DBError(Exception): + """Generic base class for U1DB errors.""" + + # description/tag for identifying the error during transmission (http,...) 
+ wire_description = "error" + + def __init__(self, message=None): + self.message = message + + +class RevisionConflict(U1DBError): + """The document revisions supplied does not match the current version.""" + + wire_description = "revision conflict" + + +class InvalidJSON(U1DBError): + """Content was not valid json.""" + + +class InvalidContent(U1DBError): + """Content was not a python dictionary.""" + + +class InvalidDocId(U1DBError): + """A document was requested with an invalid document identifier.""" + + wire_description = "invalid document id" + + +class MissingDocIds(U1DBError): + """Needs document ids.""" + + wire_description = "missing document ids" + + +class DocumentTooBig(U1DBError): + """Document exceeds the maximum document size for this database.""" + + wire_description = "document too big" + + +class UserQuotaExceeded(U1DBError): + """Document exceeds the maximum document size for this database.""" + + wire_description = "user quota exceeded" + + +class SubscriptionNeeded(U1DBError): + """User needs a subscription to be able to use this replica..""" + + wire_description = "user needs subscription" + + +class InvalidTransactionId(U1DBError): + """Invalid transaction for generation.""" + + wire_description = "invalid transaction id" + + +class InvalidGeneration(U1DBError): + """Generation was previously synced with a different transaction id.""" + + wire_description = "invalid generation" + + +class ConflictedDoc(U1DBError): + """The document is conflicted, you must call resolve before put()""" + + +class InvalidValueForIndex(U1DBError): + """The values supplied does not match the index definition.""" + + +class InvalidGlobbing(U1DBError): + """Raised if wildcard matches are not strictly at the tail of the request. 
+ """ + + +class DocumentDoesNotExist(U1DBError): + """The document does not exist.""" + + wire_description = "document does not exist" + + +class DocumentAlreadyDeleted(U1DBError): + """The document was already deleted.""" + + wire_description = "document already deleted" + + +class DatabaseDoesNotExist(U1DBError): + """The database does not exist.""" + + wire_description = "database does not exist" + + +class IndexNameTakenError(U1DBError): + """The given index name is already taken.""" + + +class IndexDefinitionParseError(U1DBError): + """The index definition cannot be parsed.""" + + +class IndexDoesNotExist(U1DBError): + """No index of that name exists.""" + + +class Unauthorized(U1DBError): + """Request wasn't authorized properly.""" + + wire_description = "unauthorized" + + +class HTTPError(U1DBError): + """Unspecific HTTP errror.""" + + wire_description = None + + def __init__(self, status, message=None, headers={}): + self.status = status + self.message = message + self.headers = headers + + def __str__(self): + if not self.message: + return "HTTPError(%d)" % self.status + else: + return "HTTPError(%d, %r)" % (self.status, self.message) + + +class Unavailable(HTTPError): + """Server not available not serve request.""" + + wire_description = "unavailable" + + def __init__(self, message=None, headers={}): + super(Unavailable, self).__init__(503, message, headers) + + def __str__(self): + if not self.message: + return "Unavailable()" + else: + return "Unavailable(%r)" % self.message + + +class BrokenSyncStream(U1DBError): + """Unterminated or otherwise broken sync exchange stream.""" + + wire_description = None + + +class UnknownAuthMethod(U1DBError): + """Unknown auhorization method.""" + + wire_description = None + + +# mapping wire (transimission) descriptions/tags for errors to the exceptions +wire_description_to_exc = dict( + (x.wire_description, x) for x in globals().values() + if getattr(x, 'wire_description', None) not in (None, "error") +) 
+wire_description_to_exc["error"] = U1DBError + + +# +# wire error descriptions not corresponding to an exception +DOCUMENT_DELETED = "document deleted" diff --git a/src/leap/soledad/u1db/query_parser.py b/src/leap/soledad/u1db/query_parser.py new file mode 100644 index 00000000..f564821f --- /dev/null +++ b/src/leap/soledad/u1db/query_parser.py @@ -0,0 +1,370 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Code for parsing Index definitions.""" + +import re +from u1db import ( + errors, + ) + + +class Getter(object): + """Get values from a document based on a specification.""" + + def get(self, raw_doc): + """Get a value from the document. + + :param raw_doc: a python dictionary to get the value from. + :return: A list of values that match the description. + """ + raise NotImplementedError(self.get) + + +class StaticGetter(Getter): + """A getter that returns a defined value (independent of the doc).""" + + def __init__(self, value): + """Create a StaticGetter. + + :param value: the value to return when get is called. 
+ """ + if value is None: + self.value = [] + elif isinstance(value, list): + self.value = value + else: + self.value = [value] + + def get(self, raw_doc): + return self.value + + +def extract_field(raw_doc, subfields, index=0): + if not isinstance(raw_doc, dict): + return [] + val = raw_doc.get(subfields[index]) + if val is None: + return [] + if index < len(subfields) - 1: + if isinstance(val, list): + results = [] + for item in val: + results.extend(extract_field(item, subfields, index + 1)) + return results + if isinstance(val, dict): + return extract_field(val, subfields, index + 1) + return [] + if isinstance(val, dict): + return [] + if isinstance(val, list): + # Strip anything in the list that isn't a simple type + return [v for v in val if not isinstance(v, (dict, list))] + return [val] + + +class ExtractField(Getter): + """Extract a field from the document.""" + + def __init__(self, field): + """Create an ExtractField object. + + When a document is passed to get() this will return a value + from the document based on the field specifier passed to + the constructor. + + None will be returned if the field is nonexistant, or refers to an + object, rather than a simple type or list of simple types. + + :param field: a specifier for the field to return. + This is either a field name, or a dotted field name. + """ + self.field = field.split('.') + + def get(self, raw_doc): + return extract_field(raw_doc, self.field) + + +class Transformation(Getter): + """A transformation on a value from another Getter.""" + + name = None + arity = 1 + args = ['expression'] + + def __init__(self, inner): + """Create a transformation. + + :param inner: the argument(s) to the transformation. + """ + self.inner = inner + + def get(self, raw_doc): + inner_values = self.inner.get(raw_doc) + assert isinstance(inner_values, list),\ + 'get() should always return a list' + return self.transform(inner_values) + + def transform(self, values): + """Transform the values. 
+ + This should be implemented by subclasses to transform the + value when get() is called. + + :param values: the values from the other Getter + :return: the transformed values. + """ + raise NotImplementedError(self.transform) + + +class Lower(Transformation): + """Lowercase a string. + + This transformation will return None for non-string inputs. However, + it will lowercase any strings in a list, dropping any elements + that are not strings. + """ + + name = "lower" + + def _can_transform(self, val): + return isinstance(val, basestring) + + def transform(self, values): + if not values: + return [] + return [val.lower() for val in values if self._can_transform(val)] + + +class Number(Transformation): + """Convert an integer to a zero padded string. + + This transformation will return None for non-integer inputs. However, it + will transform any integers in a list, dropping any elements that are not + integers. + """ + + name = 'number' + arity = 2 + args = ['expression', int] + + def __init__(self, inner, number): + super(Number, self).__init__(inner) + self.padding = "%%0%sd" % number + + def _can_transform(self, val): + return isinstance(val, int) and not isinstance(val, bool) + + def transform(self, values): + """Transform any integers in values into zero padded strings.""" + if not values: + return [] + return [self.padding % (v,) for v in values if self._can_transform(v)] + + +class Bool(Transformation): + """Convert bool to string.""" + + name = "bool" + args = ['expression'] + + def _can_transform(self, val): + return isinstance(val, bool) + + def transform(self, values): + """Transform any booleans in values into strings.""" + if not values: + return [] + return [('1' if v else '0') for v in values if self._can_transform(v)] + + +class SplitWords(Transformation): + """Split a string on whitespace. + + This Getter will return [] for non-string inputs. It will however + split any strings in an input list, discarding any elements that + are not strings. 
+ """ + + name = "split_words" + + def _can_transform(self, val): + return isinstance(val, basestring) + + def transform(self, values): + if not values: + return [] + result = set() + for value in values: + if self._can_transform(value): + for word in value.split(): + result.add(word) + return list(result) + + +class Combine(Transformation): + """Combine multiple expressions into a single index.""" + + name = "combine" + # variable number of args + arity = -1 + + def __init__(self, *inner): + super(Combine, self).__init__(inner) + + def get(self, raw_doc): + inner_values = [] + for inner in self.inner: + inner_values.extend(inner.get(raw_doc)) + return self.transform(inner_values) + + def transform(self, values): + return values + + +class IsNull(Transformation): + """Indicate whether the input is None. + + This Getter returns a bool indicating whether the input is nil. + """ + + name = "is_null" + + def transform(self, values): + return [len(values) == 0] + + +def check_fieldname(fieldname): + if fieldname.endswith('.'): + raise errors.IndexDefinitionParseError( + "Fieldname cannot end in '.':%s^" % (fieldname,)) + + +class Parser(object): + """Parse an index expression into a sequence of transformations.""" + + _transformations = {} + _delimiters = re.compile("\(|\)|,") + + def __init__(self): + self._tokens = [] + + def _set_expression(self, expression): + self._open_parens = 0 + self._tokens = [] + expression = expression.strip() + while expression: + delimiter = self._delimiters.search(expression) + if delimiter: + idx = delimiter.start() + if idx == 0: + result, expression = (expression[:1], expression[1:]) + self._tokens.append(result) + else: + result, expression = (expression[:idx], expression[idx:]) + result = result.strip() + if result: + self._tokens.append(result) + else: + expression = expression.strip() + if expression: + self._tokens.append(expression) + expression = None + + def _get_token(self): + if self._tokens: + return self._tokens.pop(0) + + 
def _peek_token(self): + if self._tokens: + return self._tokens[0] + + @staticmethod + def _to_getter(term): + if isinstance(term, Getter): + return term + check_fieldname(term) + return ExtractField(term) + + def _parse_op(self, op_name): + self._get_token() # '(' + op = self._transformations.get(op_name, None) + if op is None: + raise errors.IndexDefinitionParseError( + "Unknown operation: %s" % op_name) + args = [] + while True: + args.append(self._parse_term()) + sep = self._get_token() + if sep == ')': + break + if sep != ',': + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' in parentheses." % (sep,)) + parsed = [] + for i, arg in enumerate(args): + arg_type = op.args[i % len(op.args)] + if arg_type == 'expression': + inner = self._to_getter(arg) + else: + try: + inner = arg_type(arg) + except ValueError, e: + raise errors.IndexDefinitionParseError( + "Invalid value %r for argument type %r " + "(%r)." % (arg, arg_type, e)) + parsed.append(inner) + return op(*parsed) + + def _parse_term(self): + term = self._get_token() + if term is None: + raise errors.IndexDefinitionParseError( + "Unexpected end of index definition.") + if term in (',', ')', '('): + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' at start of expression." % (term,)) + next_token = self._peek_token() + if next_token == '(': + return self._parse_op(term) + return term + + def parse(self, expression): + self._set_expression(expression) + term = self._to_getter(self._parse_term()) + if self._peek_token(): + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' after end of expression." 
+ % (self._peek_token(),)) + return term + + def parse_all(self, fields): + return [self.parse(field) for field in fields] + + @classmethod + def register_transormation(cls, transform): + assert transform.name not in cls._transformations, ( + "Transform %s already registered for %s" + % (transform.name, cls._transformations[transform.name])) + cls._transformations[transform.name] = transform + + +Parser.register_transormation(SplitWords) +Parser.register_transormation(Lower) +Parser.register_transormation(Number) +Parser.register_transormation(Bool) +Parser.register_transormation(IsNull) +Parser.register_transormation(Combine) diff --git a/src/leap/soledad/u1db/remote/__init__.py b/src/leap/soledad/u1db/remote/__init__.py new file mode 100644 index 00000000..3f32e381 --- /dev/null +++ b/src/leap/soledad/u1db/remote/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . diff --git a/src/leap/soledad/u1db/remote/basic_auth_middleware.py b/src/leap/soledad/u1db/remote/basic_auth_middleware.py new file mode 100644 index 00000000..a2cbff62 --- /dev/null +++ b/src/leap/soledad/u1db/remote/basic_auth_middleware.py @@ -0,0 +1,68 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. 
#
# u1db is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with u1db. If not, see <http://www.gnu.org/licenses/>.
"""U1DB Basic Auth authorisation WSGI middleware."""
import base64
try:
    import httplib
except ImportError:
    # Same module under its Python 3 name; only ``responses`` is used here.
    import http.client as httplib  # noqa
try:
    import simplejson as json
except ImportError:
    import json  # noqa
from wsgiref.util import shift_path_info


class Unauthorized(Exception):
    """User authorization failed."""


class BasicAuthMiddleware(object):
    """U1DB Basic Auth Authorisation WSGI middleware.

    Wraps a WSGI application and only forwards requests that carry valid
    HTTP Basic credentials.  Subclasses supply the actual credential check
    by overriding verify_user().
    """

    def __init__(self, app, prefix):
        """Remember the wrapped WSGI `app` and the required path `prefix`.

        A falsy `prefix` disables the path check in __call__.
        """
        self.app = app
        self.prefix = prefix

    def _error(self, start_response, status, description, message=None):
        """Start an error response and return its JSON body.

        :param start_response: the WSGI start_response callable.
        :param status: numeric HTTP status code.
        :param description: value for the "error" key of the JSON body.
        :param message: optional value for the "message" key.
        :return: a one-element list holding the serialized JSON body.
        """
        start_response("%d %s" % (status, httplib.responses[status]),
                       [('content-type', 'application/json')])
        err = {"error": description}
        if message:
            err['message'] = message
        return [json.dumps(err)]

    @staticmethod
    def _decode_credentials(encoded):
        """Return (user, password) from a base64 ``user:password`` token.

        :raise ValueError: if the token is not valid base64 or the decoded
            text contains no ':' separator.
        """
        try:
            decoded = base64.b64decode(encoded)
        except TypeError as e:
            # Python 2 reports bad padding as TypeError; normalise it so
            # the caller only has to deal with ValueError.
            raise ValueError(str(e))
        if not isinstance(decoded, str):
            # Python 3 yields bytes; hand verify_user() native strings.
            decoded = decoded.decode('utf-8')
        user, password = decoded.split(':', 1)
        return user, password

    def __call__(self, environ, start_response):
        """Authenticate the request, then delegate to the wrapped app.

        Responds with 400 when the path does not start with the configured
        prefix, and with 401 when the Authorization header is missing, is
        not a well-formed Basic header, or is rejected by verify_user().
        On success the Authorization header is removed from `environ`, the
        first path segment is shifted into SCRIPT_NAME and the wrapped
        application is called.
        """
        if self.prefix and not environ['PATH_INFO'].startswith(self.prefix):
            return self._error(start_response, 400, "bad request")
        auth = environ.get('HTTP_AUTHORIZATION')
        if not auth:
            return self._error(start_response, 401, "unauthorized",
                               "Missing Basic Authentication.")
        # A header with fewer than two tokens must not blow up on
        # unpacking, so split first and inspect the pieces afterwards.
        parts = auth.split(None, 1)
        if not parts or parts[0].lower() != 'basic':
            return self._error(
                start_response, 401, "unauthorized",
                "Missing Basic Authentication")
        try:
            if len(parts) != 2:
                raise ValueError("no credentials")
            user, password = self._decode_credentials(parts[1])
        except ValueError:
            # Client-supplied garbage is an authentication failure, not a
            # server error.
            return self._error(
                start_response, 401, "unauthorized",
                "Malformed Basic Authentication.")
        try:
            self.verify_user(environ, user, password)
        except Unauthorized:
            return self._error(
                start_response, 401, "unauthorized",
                "Incorrect password or login.")
        del environ['HTTP_AUTHORIZATION']
        shift_path_info(environ)
        return self.app(environ, start_response)

    def verify_user(self, environ, username, password):
        """Check the credentials; subclasses must override this.

        Implementations raise Unauthorized to reject the request; returning
        normally accepts it.
        """
        raise NotImplementedError(self.verify_user)
diff --git a/src/leap/soledad/u1db/remote/http_app.py b/src/leap/soledad/u1db/remote/http_app.py new file mode 100644 index 00000000..3d7d4248 --- /dev/null +++ b/src/leap/soledad/u1db/remote/http_app.py @@ -0,0 +1,629 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""HTTP Application exposing U1DB.""" + +import functools +import httplib +import inspect +try: + import simplejson as json +except ImportError: + import json # noqa +import sys +import urlparse + +import routes.mapper + +from u1db import ( + __version__ as _u1db_version, + DBNAME_CONSTRAINTS, + Document, + errors, + sync, + ) +from u1db.remote import ( + http_errors, + utils, + ) + + +def parse_bool(expression): + """Parse boolean querystring parameter.""" + if expression == 'true': + return True + return False + + +def parse_list(expression): + if expression is None: + return [] + return [t.strip() for t in expression.split(',')] + + +def none_or_str(expression): + if expression is None: + return None + return str(expression) + + +class BadRequest(Exception): + """Bad request.""" + + +class _FencedReader(object): + """Read and get lines from a file but not past a given length.""" + + MAXCHUNK = 8192 + + def __init__(self, rfile, total, max_entry_size): + self.rfile = rfile + self.remaining = total + self.max_entry_size = max_entry_size + self._kept = None + + def read_chunk(self, atmost): + if self._kept is not None: + # ignore 
atmost, kept data should be a subchunk anyway + kept, self._kept = self._kept, None + return kept + if self.remaining == 0: + return '' + data = self.rfile.read(min(self.remaining, atmost)) + self.remaining -= len(data) + return data + + def getline(self): + line_parts = [] + size = 0 + while True: + chunk = self.read_chunk(self.MAXCHUNK) + if chunk == '': + break + nl = chunk.find("\n") + if nl != -1: + size += nl + 1 + if size > self.max_entry_size: + raise BadRequest + line_parts.append(chunk[:nl + 1]) + rest = chunk[nl + 1:] + self._kept = rest or None + break + else: + size += len(chunk) + if size > self.max_entry_size: + raise BadRequest + line_parts.append(chunk) + return ''.join(line_parts) + + +def http_method(**control): + """Decoration for handling of query arguments and content for a HTTP + method. + + args and content here are the query arguments and body of the incoming + HTTP requests. + + Match query arguments to python method arguments: + w = http_method()(f) + w(self, args, content) => args["content"]=content; + f(self, **args) + + JSON deserialize content to arguments: + w = http_method(content_as_args=True,...)(f) + w(self, args, content) => args.update(json.loads(content)); + f(self, **args) + + Support conversions (e.g int): + w = http_method(Arg=Conv,...)(f) + w(self, args, content) => args["Arg"]=Conv(args["Arg"]); + f(self, **args) + + Enforce no use of query arguments: + w = http_method(no_query=True,...)(f) + w(self, args, content) raises BadRequest if args is not empty + + Argument mismatches, deserialisation failures produce BadRequest. 
+ """ + content_as_args = control.pop('content_as_args', False) + no_query = control.pop('no_query', False) + conversions = control.items() + + def wrap(f): + argspec = inspect.getargspec(f) + assert argspec.args[0] == "self" + nargs = len(argspec.args) + ndefaults = len(argspec.defaults or ()) + required_args = set(argspec.args[1:nargs - ndefaults]) + all_args = set(argspec.args) + + @functools.wraps(f) + def wrapper(self, args, content): + if no_query and args: + raise BadRequest() + if content is not None: + if content_as_args: + try: + args.update(json.loads(content)) + except ValueError: + raise BadRequest() + else: + args["content"] = content + if not (required_args <= set(args) <= all_args): + raise BadRequest("Missing required arguments.") + for name, conv in conversions: + if name not in args: + continue + try: + args[name] = conv(args[name]) + except ValueError: + raise BadRequest() + return f(self, **args) + + return wrapper + + return wrap + + +class URLToResource(object): + """Mappings from URLs to resources.""" + + def __init__(self): + self._map = routes.mapper.Mapper(controller_scan=None) + + def register(self, resource_cls): + # register + self._map.connect(None, resource_cls.url_pattern, + resource_cls=resource_cls, + requirements={"dbname": DBNAME_CONSTRAINTS}) + self._map.create_regs() + return resource_cls + + def match(self, path): + params = self._map.match(path) + if params is None: + return None, None + resource_cls = params.pop('resource_cls') + return resource_cls, params + +url_to_resource = URLToResource() + + +@url_to_resource.register +class GlobalResource(object): + """Global (root) resource.""" + + url_pattern = "/" + + def __init__(self, state, responder): + self.responder = responder + + @http_method() + def get(self): + self.responder.send_response_json(version=_u1db_version) + + +@url_to_resource.register +class DatabaseResource(object): + """Database resource.""" + + url_pattern = "/{dbname}" + + def __init__(self, dbname, 
state, responder): + self.dbname = dbname + self.state = state + self.responder = responder + + @http_method() + def get(self): + self.state.check_database(self.dbname) + self.responder.send_response_json(200) + + @http_method(content_as_args=True) + def put(self): + self.state.ensure_database(self.dbname) + self.responder.send_response_json(200, ok=True) + + @http_method() + def delete(self): + self.state.delete_database(self.dbname) + self.responder.send_response_json(200, ok=True) + + +@url_to_resource.register +class DocsResource(object): + """Documents resource.""" + + url_pattern = "/{dbname}/docs" + + def __init__(self, dbname, state, responder): + self.responder = responder + self.db = state.open_database(dbname) + + @http_method(doc_ids=parse_list, check_for_conflicts=parse_bool, + include_deleted=parse_bool) + def get(self, doc_ids=None, check_for_conflicts=True, + include_deleted=False): + if doc_ids is None: + raise errors.MissingDocIds + docs = self.db.get_docs(doc_ids, include_deleted=include_deleted) + self.responder.content_type = 'application/json' + self.responder.start_response(200) + self.responder.start_stream(), + for doc in docs: + entry = dict( + doc_id=doc.doc_id, doc_rev=doc.rev, content=doc.get_json(), + has_conflicts=doc.has_conflicts) + self.responder.stream_entry(entry) + self.responder.end_stream() + self.responder.finish_response() + + +@url_to_resource.register +class DocResource(object): + """Document resource.""" + + url_pattern = "/{dbname}/doc/{id:.*}" + + def __init__(self, dbname, id, state, responder): + self.id = id + self.responder = responder + self.db = state.open_database(dbname) + + @http_method(old_rev=str) + def put(self, content, old_rev=None): + doc = Document(self.id, old_rev, content) + doc_rev = self.db.put_doc(doc) + if old_rev is None: + status = 201 # created + else: + status = 200 + self.responder.send_response_json(status, rev=doc_rev) + + @http_method(old_rev=str) + def delete(self, old_rev=None): + doc = 
Document(self.id, old_rev, None) + self.db.delete_doc(doc) + self.responder.send_response_json(200, rev=doc.rev) + + @http_method(include_deleted=parse_bool) + def get(self, include_deleted=False): + doc = self.db.get_doc(self.id, include_deleted=include_deleted) + if doc is None: + wire_descr = errors.DocumentDoesNotExist.wire_description + self.responder.send_response_json( + http_errors.wire_description_to_status[wire_descr], + error=wire_descr, + headers={ + 'x-u1db-rev': '', + 'x-u1db-has-conflicts': 'false' + }) + return + headers = { + 'x-u1db-rev': doc.rev, + 'x-u1db-has-conflicts': json.dumps(doc.has_conflicts) + } + if doc.is_tombstone(): + self.responder.send_response_json( + http_errors.wire_description_to_status[ + errors.DOCUMENT_DELETED], + error=errors.DOCUMENT_DELETED, + headers=headers) + else: + self.responder.send_response_content( + doc.get_json(), headers=headers) + + +@url_to_resource.register +class SyncResource(object): + """Sync endpoint resource.""" + + # maximum allowed request body size + max_request_size = 15 * 1024 * 1024 # 15Mb + # maximum allowed entry/line size in request body + max_entry_size = 10 * 1024 * 1024 # 10Mb + + url_pattern = "/{dbname}/sync-from/{source_replica_uid}" + + # pluggable + sync_exchange_class = sync.SyncExchange + + def __init__(self, dbname, source_replica_uid, state, responder): + self.source_replica_uid = source_replica_uid + self.responder = responder + self.state = state + self.dbname = dbname + self.replica_uid = None + + def get_target(self): + return self.state.open_database(self.dbname).get_sync_target() + + @http_method() + def get(self): + result = self.get_target().get_sync_info(self.source_replica_uid) + self.responder.send_response_json( + target_replica_uid=result[0], target_replica_generation=result[1], + target_replica_transaction_id=result[2], + source_replica_uid=self.source_replica_uid, + source_replica_generation=result[3], + source_transaction_id=result[4]) + + 
@http_method(generation=int, + content_as_args=True, no_query=True) + def put(self, generation, transaction_id): + self.get_target().record_sync_info(self.source_replica_uid, + generation, + transaction_id) + self.responder.send_response_json(ok=True) + + # Implements the same logic as LocalSyncTarget.sync_exchange + + @http_method(last_known_generation=int, last_known_trans_id=none_or_str, + content_as_args=True) + def post_args(self, last_known_generation, last_known_trans_id=None, + ensure=False): + if ensure: + db, self.replica_uid = self.state.ensure_database(self.dbname) + else: + db = self.state.open_database(self.dbname) + db.validate_gen_and_trans_id( + last_known_generation, last_known_trans_id) + self.sync_exch = self.sync_exchange_class( + db, self.source_replica_uid, last_known_generation) + + @http_method(content_as_args=True) + def post_stream_entry(self, id, rev, content, gen, trans_id): + doc = Document(id, rev, content) + self.sync_exch.insert_doc_from_source(doc, gen, trans_id) + + def post_end(self): + + def send_doc(doc, gen, trans_id): + entry = dict(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), + gen=gen, trans_id=trans_id) + self.responder.stream_entry(entry) + + new_gen = self.sync_exch.find_changes_to_return() + self.responder.content_type = 'application/x-u1db-sync-stream' + self.responder.start_response(200) + self.responder.start_stream(), + header = {"new_generation": new_gen, + "new_transaction_id": self.sync_exch.new_trans_id} + if self.replica_uid is not None: + header['replica_uid'] = self.replica_uid + self.responder.stream_entry(header) + self.sync_exch.return_docs(send_doc) + self.responder.end_stream() + self.responder.finish_response() + + +class HTTPResponder(object): + """Encode responses from the server back to the client.""" + + # a multi document response will put args and documents + # each on one line of the response body + + def __init__(self, start_response): + self._started = False + self._stream_state = -1 + 
self._no_initial_obj = True + self.sent_response = False + self._start_response = start_response + self._write = None + self.content_type = 'application/json' + self.content = [] + + def start_response(self, status, obj_dic=None, headers={}): + """start sending response with optional first json object.""" + if self._started: + return + self._started = True + status_text = httplib.responses[status] + self._write = self._start_response('%d %s' % (status, status_text), + [('content-type', self.content_type), + ('cache-control', 'no-cache')] + + headers.items()) + # xxx version in headers + if obj_dic is not None: + self._no_initial_obj = False + self._write(json.dumps(obj_dic) + "\r\n") + + def finish_response(self): + """finish sending response.""" + self.sent_response = True + + def send_response_json(self, status=200, headers={}, **kwargs): + """send and finish response with json object body from keyword args.""" + content = json.dumps(kwargs) + "\r\n" + self.send_response_content(content, headers=headers, status=status) + + def send_response_content(self, content, status=200, headers={}): + """send and finish response with content""" + headers['content-length'] = str(len(content)) + self.start_response(status, headers=headers) + if self._stream_state == 1: + self.content = [',\r\n', content] + else: + self.content = [content] + self.finish_response() + + def start_stream(self): + "start stream (array) as part of the response." + assert self._started and self._no_initial_obj + self._stream_state = 0 + self._write("[") + + def stream_entry(self, entry): + "send stream entry as part of the response." + assert self._stream_state != -1 + if self._stream_state == 0: + self._stream_state = 1 + self._write('\r\n') + else: + self._write(',\r\n') + self._write(json.dumps(entry)) + + def end_stream(self): + "end stream (array)." 
+ assert self._stream_state != -1 + self._write("\r\n]\r\n") + + +class HTTPInvocationByMethodWithBody(object): + """Invoke methods on a resource.""" + + def __init__(self, resource, environ, parameters): + self.resource = resource + self.environ = environ + self.max_request_size = getattr( + resource, 'max_request_size', parameters.max_request_size) + self.max_entry_size = getattr( + resource, 'max_entry_size', parameters.max_entry_size) + + def _lookup(self, method): + try: + return getattr(self.resource, method) + except AttributeError: + raise BadRequest() + + def __call__(self): + args = urlparse.parse_qsl(self.environ['QUERY_STRING'], + strict_parsing=False) + try: + args = dict( + (k.decode('utf-8'), v.decode('utf-8')) for k, v in args) + except ValueError: + raise BadRequest() + method = self.environ['REQUEST_METHOD'].lower() + if method in ('get', 'delete'): + meth = self._lookup(method) + return meth(args, None) + else: + # we expect content-length > 0, reconsider if we move + # to support chunked enconding + try: + content_length = int(self.environ['CONTENT_LENGTH']) + except (ValueError, KeyError): + raise BadRequest + if content_length <= 0: + raise BadRequest + if content_length > self.max_request_size: + raise BadRequest + reader = _FencedReader(self.environ['wsgi.input'], content_length, + self.max_entry_size) + content_type = self.environ.get('CONTENT_TYPE') + if content_type == 'application/json': + meth = self._lookup(method) + body = reader.read_chunk(sys.maxint) + return meth(args, body) + elif content_type == 'application/x-u1db-sync-stream': + meth_args = self._lookup('%s_args' % method) + meth_entry = self._lookup('%s_stream_entry' % method) + meth_end = self._lookup('%s_end' % method) + body_getline = reader.getline + if body_getline().strip() != '[': + raise BadRequest() + line = body_getline() + line, comma = utils.check_and_strip_comma(line.strip()) + meth_args(args, line) + while True: + line = body_getline() + entry = line.strip() + if 
entry == ']': + break + if not entry or not comma: # empty or no prec comma + raise BadRequest + entry, comma = utils.check_and_strip_comma(entry) + meth_entry({}, entry) + if comma or body_getline(): # extra comma or data + raise BadRequest + return meth_end() + else: + raise BadRequest() + + +class HTTPApp(object): + + # maximum allowed request body size + max_request_size = 15 * 1024 * 1024 # 15Mb + # maximum allowed entry/line size in request body + max_entry_size = 10 * 1024 * 1024 # 10Mb + + def __init__(self, state): + self.state = state + + def _lookup_resource(self, environ, responder): + resource_cls, params = url_to_resource.match(environ['PATH_INFO']) + if resource_cls is None: + raise BadRequest # 404 instead? + resource = resource_cls( + state=self.state, responder=responder, **params) + return resource + + def __call__(self, environ, start_response): + responder = HTTPResponder(start_response) + self.request_begin(environ) + try: + resource = self._lookup_resource(environ, responder) + HTTPInvocationByMethodWithBody(resource, environ, self)() + except errors.U1DBError, e: + self.request_u1db_error(environ, e) + status = http_errors.wire_description_to_status.get( + e.wire_description, 500) + responder.send_response_json(status, error=e.wire_description) + except BadRequest: + self.request_bad_request(environ) + responder.send_response_json(400, error="bad request") + except KeyboardInterrupt: + raise + except: + self.request_failed(environ) + raise + else: + self.request_done(environ) + return responder.content + + # hooks for tracing requests + + def request_begin(self, environ): + """Hook called at the beginning of processing a request.""" + pass + + def request_done(self, environ): + """Hook called when done processing a request.""" + pass + + def request_u1db_error(self, environ, exc): + """Hook called when processing a request resulted in a U1DBError. + + U1DBError passed as exc. 
+ """ + pass + + def request_bad_request(self, environ): + """Hook called when processing a bad request. + + No actual processing was done. + """ + pass + + def request_failed(self, environ): + """Hook called when processing a request failed unexpectedly. + + Invoked from an except block, so there's interpreter exception + information available. + """ + pass diff --git a/src/leap/soledad/u1db/remote/http_client.py b/src/leap/soledad/u1db/remote/http_client.py new file mode 100644 index 00000000..decddda3 --- /dev/null +++ b/src/leap/soledad/u1db/remote/http_client.py @@ -0,0 +1,218 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Base class to make requests to a remote HTTP server.""" + +import httplib +from oauth import oauth +try: + import simplejson as json +except ImportError: + import json # noqa +import socket +import ssl +import sys +import urlparse +import urllib + +from time import sleep +from u1db import ( + errors, + ) +from u1db.remote import ( + http_errors, + ) + +from u1db.remote.ssl_match_hostname import ( # noqa + CertificateError, + match_hostname, + ) + +# Ubuntu/debian +# XXX other... 
CA_CERTS = "/etc/ssl/certs/ca-certificates.crt"


def _encode_query_parameter(value):
    """Encode a query parameter value as a UTF-8 byte string.

    Booleans are sent as the lowercase literals 'true' / 'false'; any other
    value is converted with unicode() before being encoded.
    """
    if isinstance(value, bool):
        if value:
            value = 'true'
        else:
            value = 'false'
    return unicode(value).encode('utf-8')


class _VerifiedHTTPSConnection(httplib.HTTPSConnection):
    """HTTPSConnection verifying server side certificates."""
    # derived from httplib.py

    def connect(self):
        """Connect to a host on a given (SSL) port.

        On Linux the peer certificate is validated against CA_CERTS and its
        hostname is checked with match_hostname(); on every other platform
        the socket is wrapped without any certificate verification.
        """
        sock = socket.create_connection((self.host, self.port),
                                        self.timeout, self.source_address)
        if self._tunnel_host:
            self.sock = sock
            self._tunnel()
        if sys.platform.startswith('linux'):
            cert_opts = {
                'cert_reqs': ssl.CERT_REQUIRED,
                'ca_certs': CA_CERTS
                }
        else:
            # XXX no cert verification implemented elsewhere for now
            cert_opts = {}
        # Let both ends negotiate the best protocol they share instead of
        # pinning SSLv3: SSLv3 is broken (POODLE) and ssl.PROTOCOL_SSLv3 is
        # absent from builds compiled without it, which made connect() fail
        # on the attribute lookup alone.
        self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file,
                                    ssl_version=ssl.PROTOCOL_SSLv23,
                                    **cert_opts
                                    )
        if cert_opts:
            match_hostname(self.sock.getpeercert(), self.host)


class HTTPClientBase(object):
    """Base class to make requests to a remote HTTP server."""

    # by default use HMAC-SHA1 OAuth signature method to not disclose
    # tokens
    # NB: given that the content bodies are not covered by the
    # signatures though, to achieve security (against man-in-the-middle
    # attacks for example) one would need HTTPS
    oauth_signature_method = oauth.OAuthSignatureMethod_HMAC_SHA1()

    # Will use these delays to retry on 503 before finally giving up. The
    # final 0 is there to not wait after the final try fails.
    _delays = (1, 1, 2, 4, 0)

    def __init__(self, url, creds=None):
        """Parse `url` and optionally install credentials.

        :param url: base URL of the remote server.
        :param creds: None, or a dict with exactly one entry mapping an
            authentication method name to the keyword arguments of the
            matching ``set_<method>_credentials`` method.
        :raise errors.UnknownAuthMethod: if `creds` does not hold exactly
            one entry or no matching setter exists on this object.
        """
        self._url = urlparse.urlsplit(url)
        self._conn = None
        self._creds = {}
        if creds is not None:
            if len(creds) != 1:
                raise errors.UnknownAuthMethod()
            auth_meth, credentials = creds.items()[0]
            try:
                set_creds = getattr(self, 'set_%s_credentials' % auth_meth)
            except AttributeError:
                raise errors.UnknownAuthMethod(auth_meth)
            set_creds(**credentials)

    def set_oauth_credentials(self, consumer_key, consumer_secret,
                              token_key, token_secret):
        """Store an OAuth consumer/token pair under the 'oauth' key."""
        self._creds = {'oauth': (
            oauth.OAuthConsumer(consumer_key, consumer_secret),
            oauth.OAuthToken(token_key, token_secret))}

    def _ensure_connection(self):
        """Create the connection object on first use.

        An 'https' URL scheme selects the certificate-verifying connection
        class; anything else uses a plain HTTPConnection.
        """
        if self._conn is not None:
            return
        if self._url.scheme == 'https':
            connClass = _VerifiedHTTPSConnection
        else:
            connClass = httplib.HTTPConnection
        self._conn = connClass(self._url.hostname, self._url.port)

    def close(self):
        """Close and forget the current connection, if there is one."""
        if self._conn:
            self._conn.close()
            self._conn = None

    # xxx retry mechanism?
+ + def _error(self, respdic): + descr = respdic.get("error") + exc_cls = errors.wire_description_to_exc.get(descr) + if exc_cls is not None: + message = respdic.get("message") + raise exc_cls(message) + + def _response(self): + resp = self._conn.getresponse() + body = resp.read() + headers = dict(resp.getheaders()) + if resp.status in (200, 201): + return body, headers + elif resp.status in http_errors.ERROR_STATUSES: + try: + respdic = json.loads(body) + except ValueError: + pass + else: + self._error(respdic) + # special case + if resp.status == 503: + raise errors.Unavailable(body, headers) + raise errors.HTTPError(resp.status, body, headers) + + def _sign_request(self, method, url_query, params): + if 'oauth' in self._creds: + consumer, token = self._creds['oauth'] + full_url = "%s://%s%s" % (self._url.scheme, self._url.netloc, + url_query) + oauth_req = oauth.OAuthRequest.from_consumer_and_token( + consumer, token, + http_method=method, + parameters=params, + http_url=full_url + ) + oauth_req.sign_request( + self.oauth_signature_method, consumer, token) + # Authorization: OAuth ... + return oauth_req.to_header().items() + else: + return [] + + def _request(self, method, url_parts, params=None, body=None, + content_type=None): + self._ensure_connection() + unquoted_url = url_query = self._url.path + if url_parts: + if not url_query.endswith('/'): + url_query += '/' + unquoted_url = url_query + url_query += '/'.join(urllib.quote(part, safe='') + for part in url_parts) + # oauth performs its own quoting + unquoted_url += '/'.join(url_parts) + encoded_params = {} + if params: + for key, value in params.items(): + key = unicode(key).encode('utf-8') + encoded_params[key] = _encode_query_parameter(value) + url_query += ('?' 
+ urllib.urlencode(encoded_params)) + if body is not None and not isinstance(body, basestring): + body = json.dumps(body) + content_type = 'application/json' + headers = {} + if content_type: + headers['content-type'] = content_type + headers.update( + self._sign_request(method, unquoted_url, encoded_params)) + for delay in self._delays: + try: + self._conn.request(method, url_query, body, headers) + return self._response() + except errors.Unavailable, e: + sleep(delay) + raise e + + def _request_json(self, method, url_parts, params=None, body=None, + content_type=None): + res, headers = self._request(method, url_parts, params, body, + content_type) + return json.loads(res), headers diff --git a/src/leap/soledad/u1db/remote/http_database.py b/src/leap/soledad/u1db/remote/http_database.py new file mode 100644 index 00000000..6901baad --- /dev/null +++ b/src/leap/soledad/u1db/remote/http_database.py @@ -0,0 +1,143 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
+ +"""HTTPDatabase to access a remote db over the HTTP API.""" + +try: + import simplejson as json +except ImportError: + import json # noqa +import uuid + +from u1db import ( + Database, + Document, + errors, + ) +from u1db.remote import ( + http_client, + http_errors, + http_target, + ) + + +DOCUMENT_DELETED_STATUS = http_errors.wire_description_to_status[ + errors.DOCUMENT_DELETED] + + +class HTTPDatabase(http_client.HTTPClientBase, Database): + """Implement the Database API to a remote HTTP server.""" + + def __init__(self, url, document_factory=None, creds=None): + super(HTTPDatabase, self).__init__(url, creds=creds) + self._factory = document_factory or Document + + def set_document_factory(self, factory): + self._factory = factory + + @staticmethod + def open_database(url, create): + db = HTTPDatabase(url) + db.open(create) + return db + + @staticmethod + def delete_database(url): + db = HTTPDatabase(url) + db._delete() + db.close() + + def open(self, create): + if create: + self._ensure() + else: + self._check() + + def _check(self): + return self._request_json('GET', [])[0] + + def _ensure(self): + self._request_json('PUT', [], {}, {}) + + def _delete(self): + self._request_json('DELETE', [], {}, {}) + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + params = {} + if doc.rev is not None: + params['old_rev'] = doc.rev + res, headers = self._request_json('PUT', ['doc', doc.doc_id], params, + doc.get_json(), 'application/json') + doc.rev = res['rev'] + return res['rev'] + + def get_doc(self, doc_id, include_deleted=False): + try: + res, headers = self._request( + 'GET', ['doc', doc_id], {"include_deleted": include_deleted}) + except errors.DocumentDoesNotExist: + return None + except errors.HTTPError, e: + if (e.status == DOCUMENT_DELETED_STATUS and + 'x-u1db-rev' in e.headers): + res = None + headers = e.headers + else: + raise + doc_rev = headers['x-u1db-rev'] + has_conflicts = json.loads(headers['x-u1db-has-conflicts']) + 
doc = self._factory(doc_id, doc_rev, res) + doc.has_conflicts = has_conflicts + return doc + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + if not doc_ids: + return + doc_ids = ','.join(doc_ids) + res, headers = self._request( + 'GET', ['docs'], { + "doc_ids": doc_ids, "include_deleted": include_deleted, + "check_for_conflicts": check_for_conflicts}) + for doc_dict in json.loads(res): + doc = self._factory( + doc_dict['doc_id'], doc_dict['doc_rev'], doc_dict['content']) + doc.has_conflicts = doc_dict['has_conflicts'] + yield doc + + def create_doc_from_json(self, content, doc_id=None): + if doc_id is None: + doc_id = 'D-%s' % (uuid.uuid4().hex,) + res, headers = self._request_json('PUT', ['doc', doc_id], {}, + content, 'application/json') + new_doc = self._factory(doc_id, res['rev'], content) + return new_doc + + def delete_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + params = {'old_rev': doc.rev} + res, headers = self._request_json('DELETE', + ['doc', doc.doc_id], params) + doc.make_tombstone() + doc.rev = res['rev'] + + def get_sync_target(self): + st = http_target.HTTPSyncTarget(self._url.geturl()) + st._creds = self._creds + return st diff --git a/src/leap/soledad/u1db/remote/http_errors.py b/src/leap/soledad/u1db/remote/http_errors.py new file mode 100644 index 00000000..2039c5b2 --- /dev/null +++ b/src/leap/soledad/u1db/remote/http_errors.py @@ -0,0 +1,46 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. 
+# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Information about the encoding of errors over HTTP.""" + +from u1db import ( + errors, + ) + + +# error wire descriptions mapping to HTTP status codes +wire_description_to_status = dict([ + (errors.InvalidDocId.wire_description, 400), + (errors.MissingDocIds.wire_description, 400), + (errors.Unauthorized.wire_description, 401), + (errors.DocumentTooBig.wire_description, 403), + (errors.UserQuotaExceeded.wire_description, 403), + (errors.SubscriptionNeeded.wire_description, 403), + (errors.DatabaseDoesNotExist.wire_description, 404), + (errors.DocumentDoesNotExist.wire_description, 404), + (errors.DocumentAlreadyDeleted.wire_description, 404), + (errors.RevisionConflict.wire_description, 409), + (errors.InvalidGeneration.wire_description, 409), + (errors.InvalidTransactionId.wire_description, 409), + (errors.Unavailable.wire_description, 503), +# without matching exception + (errors.DOCUMENT_DELETED, 404) +]) + + +ERROR_STATUSES = set(wire_description_to_status.values()) +# 400 included explicitly for tests +ERROR_STATUSES.add(400) diff --git a/src/leap/soledad/u1db/remote/http_target.py b/src/leap/soledad/u1db/remote/http_target.py new file mode 100644 index 00000000..1028963e --- /dev/null +++ b/src/leap/soledad/u1db/remote/http_target.py @@ -0,0 +1,135 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. 
+# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""SyncTarget API implementation to a remote HTTP server.""" + +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db import ( + Document, + SyncTarget, + ) +from u1db.errors import ( + BrokenSyncStream, + ) +from u1db.remote import ( + http_client, + utils, + ) + + +class HTTPSyncTarget(http_client.HTTPClientBase, SyncTarget): + """Implement the SyncTarget api to a remote HTTP server.""" + + @staticmethod + def connect(url): + return HTTPSyncTarget(url) + + def get_sync_info(self, source_replica_uid): + self._ensure_connection() + res, _ = self._request_json('GET', ['sync-from', source_replica_uid]) + return (res['target_replica_uid'], res['target_replica_generation'], + res['target_replica_transaction_id'], + res['source_replica_generation'], res['source_transaction_id']) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_transaction_id): + self._ensure_connection() + if self._trace_hook: # for tests + self._trace_hook('record_sync_info') + self._request_json('PUT', ['sync-from', source_replica_uid], {}, + {'generation': source_replica_generation, + 'transaction_id': source_transaction_id}) + + def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): + parts = data.splitlines() # one at a time + if not parts or parts[0] != '[': + raise BrokenSyncStream + data = parts[1:-1] + comma = False + if data: + line, comma = utils.check_and_strip_comma(data[0]) + res = json.loads(line) + if ensure_callback and 'replica_uid' in res: + ensure_callback(res['replica_uid']) + for entry in data[1:]: + if not comma: # missing in between comma + raise BrokenSyncStream + line, comma = utils.check_and_strip_comma(entry) + entry = json.loads(line) + doc = Document(entry['id'], entry['rev'], entry['content']) + return_doc_cb(doc, entry['gen'], entry['trans_id']) + if parts[-1] != ']': + 
try: + partdic = json.loads(parts[-1]) + except ValueError: + pass + else: + if isinstance(partdic, dict): + self._error(partdic) + raise BrokenSyncStream + if not data or comma: # no entries or bad extra comma + raise BrokenSyncStream + return res + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + self._ensure_connection() + if self._trace_hook: # for tests + self._trace_hook('sync_exchange') + url = '%s/sync-from/%s' % (self._url.path, source_replica_uid) + self._conn.putrequest('POST', url) + self._conn.putheader('content-type', 'application/x-u1db-sync-stream') + for header_name, header_value in self._sign_request('POST', url, {}): + self._conn.putheader(header_name, header_value) + entries = ['['] + size = 1 + + def prepare(**dic): + entry = comma + '\r\n' + json.dumps(dic) + entries.append(entry) + return len(entry) + + comma = '' + size += prepare( + last_known_generation=last_known_generation, + last_known_trans_id=last_known_trans_id, + ensure=ensure_callback is not None) + comma = ',' + for doc, gen, trans_id in docs_by_generations: + size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), + gen=gen, trans_id=trans_id) + entries.append('\r\n]') + size += len(entries[-1]) + self._conn.putheader('content-length', str(size)) + self._conn.endheaders() + for entry in entries: + self._conn.send(entry) + entries = None + data, _ = self._response() + res = self._parse_sync_stream(data, return_doc_cb, ensure_callback) + data = None + return res['new_generation'], res['new_transaction_id'] + + # for tests + _trace_hook = None + + def _set_trace_hook_shallow(self, cb): + self._trace_hook = cb diff --git a/src/leap/soledad/u1db/remote/oauth_middleware.py b/src/leap/soledad/u1db/remote/oauth_middleware.py new file mode 100644 index 00000000..5772580a --- /dev/null +++ b/src/leap/soledad/u1db/remote/oauth_middleware.py @@ -0,0 +1,89 @@ +# Copyright 2012 
Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . +"""U1DB OAuth authorisation WSGI middleware.""" +import httplib +from oauth import oauth +try: + import simplejson as json +except ImportError: + import json # noqa +from urllib import quote +from wsgiref.util import shift_path_info + + +sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1() +sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT() + + +class OAuthMiddleware(object): + """U1DB OAuth Authorisation WSGI middleware.""" + + # max seconds the request timestamp is allowed to be shifted + # from arrival time + timestamp_threshold = 300 + + def __init__(self, app, base_url, prefix='/~/'): + self.app = app + self.base_url = base_url + self.prefix = prefix + + def get_oauth_data_store(self): + """Provide a oauth.OAuthDataStore.""" + raise NotImplementedError(self.get_oauth_data_store) + + def _error(self, start_response, status, description, message=None): + start_response("%d %s" % (status, httplib.responses[status]), + [('content-type', 'application/json')]) + err = {"error": description} + if message: + err['message'] = message + return [json.dumps(err)] + + def __call__(self, environ, start_response): + if self.prefix and not environ['PATH_INFO'].startswith(self.prefix): + return self._error(start_response, 400, "bad request") + headers = {} + if 'HTTP_AUTHORIZATION' in environ: + headers['Authorization'] = environ['HTTP_AUTHORIZATION'] + oauth_req = 
oauth.OAuthRequest.from_request( + http_method=environ['REQUEST_METHOD'], + http_url=self.base_url + environ['PATH_INFO'], + headers=headers, + query_string=environ['QUERY_STRING'] + ) + if oauth_req is None: + return self._error(start_response, 401, "unauthorized", + "Missing OAuth.") + try: + self.verify(environ, oauth_req) + except oauth.OAuthError, e: + return self._error(start_response, 401, "unauthorized", + e.message) + shift_path_info(environ) + return self.app(environ, start_response) + + def verify(self, environ, oauth_req): + """Verify OAuth request, put user_id in the environ.""" + oauth_server = oauth.OAuthServer(self.get_oauth_data_store()) + oauth_server.timestamp_threshold = self.timestamp_threshold + oauth_server.add_signature_method(sign_meth_HMAC_SHA1) + oauth_server.add_signature_method(sign_meth_PLAINTEXT) + consumer, token, parameters = oauth_server.verify_request(oauth_req) + # filter out oauth bits + environ['QUERY_STRING'] = '&'.join("%s=%s" % (quote(k, safe=''), + quote(v, safe='')) + for k, v in parameters.iteritems()) + return consumer, token diff --git a/src/leap/soledad/u1db/remote/server_state.py b/src/leap/soledad/u1db/remote/server_state.py new file mode 100644 index 00000000..96581359 --- /dev/null +++ b/src/leap/soledad/u1db/remote/server_state.py @@ -0,0 +1,67 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
+ +"""State for servers exposing a set of U1DB databases.""" +import os +import errno + +class ServerState(object): + """Passed to a Request when it is instantiated. + + This is used to track server-side state, such as working-directory, open + databases, etc. + """ + + def __init__(self): + self._workingdir = None + + def set_workingdir(self, path): + self._workingdir = path + + def _relpath(self, relpath): + # Note: We don't want to allow absolute paths here, because we + # don't want to expose the filesystem. We should also check that + # relpath doesn't have '..' in it, etc. + return self._workingdir + '/' + relpath + + def open_database(self, path): + """Open a database at the given location.""" + from u1db.backends import sqlite_backend + full_path = self._relpath(path) + return sqlite_backend.SQLiteDatabase.open_database(full_path, + create=False) + + def check_database(self, path): + """Check if the database at the given location exists. + + Simply returns if it does or raises DatabaseDoesNotExist. 
+ """ + db = self.open_database(path) + db.close() + + def ensure_database(self, path): + """Ensure database at the given location.""" + from u1db.backends import sqlite_backend + full_path = self._relpath(path) + db = sqlite_backend.SQLiteDatabase.open_database(full_path, + create=True) + return db, db._replica_uid + + def delete_database(self, path): + """Delete database at the given location.""" + from u1db.backends import sqlite_backend + full_path = self._relpath(path) + sqlite_backend.SQLiteDatabase.delete_database(full_path) diff --git a/src/leap/soledad/u1db/remote/ssl_match_hostname.py b/src/leap/soledad/u1db/remote/ssl_match_hostname.py new file mode 100644 index 00000000..fbabc177 --- /dev/null +++ b/src/leap/soledad/u1db/remote/ssl_match_hostname.py @@ -0,0 +1,64 @@ +"""The match_hostname() function from Python 3.2, essential when using SSL.""" +# XXX put it here until it's packaged + +import re + +__version__ = '3.2a3' + + +class CertificateError(ValueError): + pass + + +def _dnsname_to_pat(dn): + pats = [] + for frag in dn.split(r'.'): + if frag == '*': + # When '*' is a fragment by itself, it matches a non-empty dotless + # fragment. + pats.append('[^.]+') + else: + # Otherwise, '*' matches any dotless fragment. + frag = re.escape(frag) + pats.append(frag.replace(r'\*', '[^.]*')) + return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + + +def match_hostname(cert, hostname): + """Verify that *cert* (in decoded format as returned by + SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules + are mostly followed, but IP addresses are not accepted for *hostname*. + + CertificateError is raised on failure. On success, the function + returns nothing. 
+ """ + if not cert: + raise ValueError("empty or no certificate") + dnsnames = [] + san = cert.get('subjectAltName', ()) + for key, value in san: + if key == 'DNS': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if not san: + # The subject is only checked when subjectAltName is empty + for sub in cert.get('subject', ()): + for key, value in sub: + # XXX according to RFC 2818, the most specific Common Name + # must be used. + if key == 'commonName': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if len(dnsnames) > 1: + raise CertificateError("hostname %r " + "doesn't match either of %s" + % (hostname, ', '.join(map(repr, dnsnames)))) + elif len(dnsnames) == 1: + raise CertificateError("hostname %r " + "doesn't match %r" + % (hostname, dnsnames[0])) + else: + raise CertificateError("no appropriate commonName or " + "subjectAltName fields were found") diff --git a/src/leap/soledad/u1db/remote/utils.py b/src/leap/soledad/u1db/remote/utils.py new file mode 100644 index 00000000..14cedea9 --- /dev/null +++ b/src/leap/soledad/u1db/remote/utils.py @@ -0,0 +1,23 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
+ +"""Utilities for details of the protocol.""" + + +def check_and_strip_comma(line): + if line and line[-1] == ',': + return line[:-1], True + return line, False diff --git a/src/leap/soledad/u1db/sync.py b/src/leap/soledad/u1db/sync.py new file mode 100644 index 00000000..3375d097 --- /dev/null +++ b/src/leap/soledad/u1db/sync.py @@ -0,0 +1,304 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""The synchronization utilities for U1DB.""" +from itertools import izip + +import u1db +from u1db import errors + + +class Synchronizer(object): + """Collect the state around synchronizing 2 U1DB replicas. + + Synchronization is bi-directional, in that new items in the source are sent + to the target, and new items in the target are returned to the source. + However, it still recognizes that one side is initiating the request. Also, + at the moment, conflicts are only created in the source. + """ + + def __init__(self, source, sync_target): + """Create a new Synchronization object. + + :param source: A Database + :param sync_target: A SyncTarget + """ + self.source = source + self.sync_target = sync_target + self.target_replica_uid = None + self.num_inserted = 0 + + def _insert_doc_from_target(self, doc, replica_gen, trans_id): + """Try to insert synced document from target. 
+ + Implements TAKE OTHER semantics: any document from the target + that is in conflict will be taken as the new official value, + while the current conflicting value will be stored alongside + as a conflict. In the process indexes will be updated etc. + + :return: None + """ + # Increases self.num_inserted depending whether the document + # was effectively inserted. + state, _ = self.source._put_doc_if_newer(doc, save_conflict=True, + replica_uid=self.target_replica_uid, replica_gen=replica_gen, + replica_trans_id=trans_id) + if state == 'inserted': + self.num_inserted += 1 + elif state == 'converged': + # magical convergence + pass + elif state == 'superseded': + # we have something newer, will be taken care of at the next sync + pass + else: + assert state == 'conflicted' + # The doc was saved as a conflict, so the database was updated + self.num_inserted += 1 + + def _record_sync_info_with_the_target(self, start_generation): + """Record our new after sync generation with the target if gapless. + + Any documents received from the target will cause the local + database to increment its generation. We do not want to send + them back to the target in a future sync. However, there could + also be concurrent updates from another process doing eg + 'put_doc' while the sync was running. And we do want to + synchronize those documents. We can tell if there was a + concurrent update by comparing our new generation number + versus the generation we started, and how many documents we + inserted from the target. If it matches exactly, then we can + record with the target that they are fully up to date with our + new generation. 
+ """ + cur_gen, trans_id = self.source._get_generation_info() + if (cur_gen == start_generation + self.num_inserted + and self.num_inserted > 0): + self.sync_target.record_sync_info( + self.source._replica_uid, cur_gen, trans_id) + + def sync(self, callback=None, autocreate=False): + """Synchronize documents between source and target.""" + sync_target = self.sync_target + # get target identifier, its current generation, + # and its last-seen database generation for this source + try: + (self.target_replica_uid, target_gen, target_trans_id, + target_my_gen, target_my_trans_id) = sync_target.get_sync_info( + self.source._replica_uid) + except errors.DatabaseDoesNotExist: + if not autocreate: + raise + # will try to ask sync_exchange() to create the db + self.target_replica_uid = None + target_gen, target_trans_id = 0, '' + target_my_gen, target_my_trans_id = 0, '' + def ensure_callback(replica_uid): + self.target_replica_uid = replica_uid + else: + ensure_callback = None + # validate the generation and transaction id the target knows about us + self.source.validate_gen_and_trans_id( + target_my_gen, target_my_trans_id) + # what's changed since that generation and this current gen + my_gen, _, changes = self.source.whats_changed(target_my_gen) + + # this source last-seen database generation for the target + if self.target_replica_uid is None: + target_last_known_gen, target_last_known_trans_id = 0, '' + else: + target_last_known_gen, target_last_known_trans_id = \ + self.source._get_replica_gen_and_trans_id(self.target_replica_uid) + if not changes and target_last_known_gen == target_gen: + if target_trans_id != target_last_known_trans_id: + raise errors.InvalidTransactionId + return my_gen + changed_doc_ids = [doc_id for doc_id, _, _ in changes] + # prepare to send all the changed docs + docs_to_send = self.source.get_docs(changed_doc_ids, + check_for_conflicts=False, include_deleted=True) + # TODO: there must be a way to not iterate twice + docs_by_generation = 
zip( + docs_to_send, (gen for _, gen, _ in changes), + (trans for _, _, trans in changes)) + + # exchange documents and try to insert the returned ones with + # the target, return target synced-up-to gen + new_gen, new_trans_id = sync_target.sync_exchange( + docs_by_generation, self.source._replica_uid, + target_last_known_gen, target_last_known_trans_id, + self._insert_doc_from_target, ensure_callback=ensure_callback) + # record target synced-up-to generation including applying what we sent + self.source._set_replica_gen_and_trans_id( + self.target_replica_uid, new_gen, new_trans_id) + + # if gapless record current reached generation with target + self._record_sync_info_with_the_target(my_gen) + + return my_gen + + +class SyncExchange(object): + """Steps and state for carrying through a sync exchange on a target.""" + + def __init__(self, db, source_replica_uid, last_known_generation): + self._db = db + self.source_replica_uid = source_replica_uid + self.source_last_known_generation = last_known_generation + self.seen_ids = {} # incoming ids not superseded + self.changes_to_return = None + self.new_gen = None + self.new_trans_id = None + # for tests + self._incoming_trace = [] + self._trace_hook = None + self._db._last_exchange_log = { + 'receive': {'docs': self._incoming_trace}, + 'return': None + } + + def _set_trace_hook(self, cb): + self._trace_hook = cb + + def _trace(self, state): + if not self._trace_hook: + return + self._trace_hook(state) + + def insert_doc_from_source(self, doc, source_gen, trans_id): + """Try to insert synced document from source. + + Conflicting documents are not inserted but will be sent over + to the sync source. + + It keeps track of progress by storing the document source + generation as well. + + The 1st step of a sync exchange is to call this repeatedly to + try insert all incoming documents from the source. + + :param doc: A Document object. + :param source_gen: The source generation of doc. 
+ :return: None + """ + state, at_gen = self._db._put_doc_if_newer(doc, save_conflict=False, + replica_uid=self.source_replica_uid, replica_gen=source_gen, + replica_trans_id=trans_id) + if state == 'inserted': + self.seen_ids[doc.doc_id] = at_gen + elif state == 'converged': + # magical convergence + self.seen_ids[doc.doc_id] = at_gen + elif state == 'superseded': + # we have something newer that we will return + pass + else: + # conflict that we will return + assert state == 'conflicted' + # for tests + self._incoming_trace.append((doc.doc_id, doc.rev)) + self._db._last_exchange_log['receive'].update({ + 'source_uid': self.source_replica_uid, + 'source_gen': source_gen + }) + + def find_changes_to_return(self): + """Find changes to return. + + Find changes since last_known_generation in db generation + order using whats_changed. It excludes document ids that have + already been considered (superseded by the sender, etc). + + :return: new_generation - the generation of this database + which the caller can consider themselves to be synchronized after + processing the returned documents. + """ + self._db._last_exchange_log['receive'].update({ # for tests + 'last_known_gen': self.source_last_known_generation + }) + self._trace('before whats_changed') + gen, trans_id, changes = self._db.whats_changed( + self.source_last_known_generation) + self._trace('after whats_changed') + self.new_gen = gen + self.new_trans_id = trans_id + seen_ids = self.seen_ids + # changed docs that weren't superseded by or converged with + self.changes_to_return = [ + (doc_id, gen, trans_id) for (doc_id, gen, trans_id) in changes + # there was a subsequent update + if doc_id not in seen_ids or seen_ids.get(doc_id) < gen] + return self.new_gen + + def return_docs(self, return_doc_cb): + """Return the changed documents and their last change generation + repeatedly invoking the callback return_doc_cb. + + The final step of a sync exchange. 
+ + :param: return_doc_cb(doc, gen, trans_id): is a callback + used to return the documents with their last change generation + to the target replica. + :return: None + """ + changes_to_return = self.changes_to_return + # return docs, including conflicts + changed_doc_ids = [doc_id for doc_id, _, _ in changes_to_return] + self._trace('before get_docs') + docs = self._db.get_docs( + changed_doc_ids, check_for_conflicts=False, include_deleted=True) + + docs_by_gen = izip( + docs, (gen for _, gen, _ in changes_to_return), + (trans_id for _, _, trans_id in changes_to_return)) + _outgoing_trace = [] # for tests + for doc, gen, trans_id in docs_by_gen: + return_doc_cb(doc, gen, trans_id) + _outgoing_trace.append((doc.doc_id, doc.rev)) + # for tests + self._db._last_exchange_log['return'] = { + 'docs': _outgoing_trace, + 'last_gen': self.new_gen + } + + +class LocalSyncTarget(u1db.SyncTarget): + """Common sync target implementation logic for all local sync targets.""" + + def __init__(self, db): + self._db = db + self._trace_hook = None + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + self._db.validate_gen_and_trans_id( + last_known_generation, last_known_trans_id) + sync_exch = SyncExchange( + self._db, source_replica_uid, last_known_generation) + if self._trace_hook: + sync_exch._set_trace_hook(self._trace_hook) + # 1st step: try to insert incoming docs and record progress + for doc, doc_gen, trans_id in docs_by_generations: + sync_exch.insert_doc_from_source(doc, doc_gen, trans_id) + # 2nd step: find changed documents (including conflicts) to return + new_gen = sync_exch.find_changes_to_return() + # final step: return docs and record source replica sync point + sync_exch.return_docs(return_doc_cb) + return new_gen, sync_exch.new_trans_id + + def _set_trace_hook(self, cb): + self._trace_hook = cb diff --git a/src/leap/soledad/u1db/tests/__init__.py 
b/src/leap/soledad/u1db/tests/__init__.py new file mode 100644 index 00000000..b8e16b15 --- /dev/null +++ b/src/leap/soledad/u1db/tests/__init__.py @@ -0,0 +1,463 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Test infrastructure for U1DB""" + +import copy +import shutil +import socket +import tempfile +import threading + +try: + import simplejson as json +except ImportError: + import json # noqa + +from wsgiref import simple_server + +from oauth import oauth +from sqlite3 import dbapi2 +from StringIO import StringIO + +import testscenarios +import testtools + +from u1db import ( + errors, + Document, + ) +from u1db.backends import ( + inmemory, + sqlite_backend, + ) +from u1db.remote import ( + server_state, + ) + +try: + from u1db.tests import c_backend_wrapper + c_backend_error = None +except ImportError, e: + c_backend_wrapper = None # noqa + c_backend_error = e + +# Setting this means that failing assertions will not include this module in +# their traceback. However testtools doesn't seem to set it, and we don't want +# this level to be omitted, but the lower levels to be shown. +# __unittest = 1 + + +class TestCase(testtools.TestCase): + + def createTempDir(self, prefix='u1db-tmp-'): + """Create a temporary directory to do some work in. + + This directory will be scheduled for cleanup when the test ends. 
+ """ + tempdir = tempfile.mkdtemp(prefix=prefix) + self.addCleanup(shutil.rmtree, tempdir) + return tempdir + + def make_document(self, doc_id, doc_rev, content, has_conflicts=False): + return self.make_document_for_test( + self, doc_id, doc_rev, content, has_conflicts) + + def make_document_for_test(self, test, doc_id, doc_rev, content, + has_conflicts): + return make_document_for_test( + test, doc_id, doc_rev, content, has_conflicts) + + def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id)) + + def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content, + has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True)) + + def assertGetDocConflicts(self, db, doc_id, conflicts): + """Assert what conflicts are stored for a given doc_id. + + :param conflicts: A list of (doc_rev, content) pairs. + The first item must match the first item returned from the + database, however the rest can be returned in any order. 
+ """ + if conflicts: + conflicts = [(rev, (json.loads(cont) if isinstance(cont, basestring) + else cont)) for (rev, cont) in conflicts] + conflicts = conflicts[:1] + sorted(conflicts[1:]) + actual = db.get_doc_conflicts(doc_id) + if actual: + actual = [(doc.rev, (json.loads(doc.get_json()) + if doc.get_json() is not None else None)) for doc in actual] + actual = actual[:1] + sorted(actual[1:]) + self.assertEqual(conflicts, actual) + + +def multiply_scenarios(a_scenarios, b_scenarios): + """Create the cross-product of scenarios.""" + + all_scenarios = [] + for a_name, a_attrs in a_scenarios: + for b_name, b_attrs in b_scenarios: + name = '%s,%s' % (a_name, b_name) + attrs = dict(a_attrs) + attrs.update(b_attrs) + all_scenarios.append((name, attrs)) + return all_scenarios + + +simple_doc = '{"key": "value"}' +nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}' + + +def make_memory_database_for_test(test, replica_uid): + return inmemory.InMemoryDatabase(replica_uid) + + +def copy_memory_database_for_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. 
+ new_db = inmemory.InMemoryDatabase(db._replica_uid) + new_db._transaction_log = db._transaction_log[:] + new_db._docs = copy.deepcopy(db._docs) + new_db._conflicts = copy.deepcopy(db._conflicts) + new_db._indexes = copy.deepcopy(db._indexes) + new_db._factory = db._factory + return new_db + + +def make_sqlite_partial_expanded_for_test(test, replica_uid): + db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + db._set_replica_uid(replica_uid) + return db + + +def copy_sqlite_partial_expanded_for_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. + new_db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + tmpfile = StringIO() + for line in db._db_handle.iterdump(): + if not 'sqlite_sequence' in line: # work around bug in iterdump + tmpfile.write('%s\n' % line) + tmpfile.seek(0) + new_db._db_handle = dbapi2.connect(':memory:') + new_db._db_handle.cursor().executescript(tmpfile.read()) + new_db._db_handle.commit() + new_db._set_replica_uid(db._replica_uid) + new_db._factory = db._factory + return new_db + + +def make_document_for_test(test, doc_id, rev, content, has_conflicts=False): + return Document(doc_id, rev, content, has_conflicts=has_conflicts) + + +def make_c_database_for_test(test, replica_uid): + if c_backend_wrapper is None: + test.skipTest('c_backend_wrapper is not available') + db = c_backend_wrapper.CDatabase(':memory:') + db._set_replica_uid(replica_uid) + return db + + +def copy_c_database_for_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. 
USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. + if c_backend_wrapper is None: + test.skipTest('c_backend_wrapper is not available') + new_db = db._copy(db) + return new_db + + +def make_c_document_for_test(test, doc_id, rev, content, has_conflicts=False): + if c_backend_wrapper is None: + test.skipTest('c_backend_wrapper is not available') + return c_backend_wrapper.make_document( + doc_id, rev, content, has_conflicts=has_conflicts) + + +LOCAL_DATABASES_SCENARIOS = [ + ('mem', {'make_database_for_test': make_memory_database_for_test, + 'copy_database_for_test': copy_memory_database_for_test, + 'make_document_for_test': make_document_for_test}), + ('sql', {'make_database_for_test': + make_sqlite_partial_expanded_for_test, + 'copy_database_for_test': + copy_sqlite_partial_expanded_for_test, + 'make_document_for_test': make_document_for_test}), + ] + + +C_DATABASE_SCENARIOS = [ + ('c', {'make_database_for_test': make_c_database_for_test, + 'copy_database_for_test': copy_c_database_for_test, + 'make_document_for_test': make_c_document_for_test})] + + +class DatabaseBaseTests(TestCase): + + accept_fixed_trans_id = False # set to True assertTransactionLog + # is happy with all trans ids = '' + + scenarios = LOCAL_DATABASES_SCENARIOS + + def create_database(self, replica_uid): + return self.make_database_for_test(self, replica_uid) + + def copy_database(self, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES + # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST + # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS + # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND + # NINJA TO YOUR HOUSE. 
+ return self.copy_database_for_test(self, db) + + def setUp(self): + super(DatabaseBaseTests, self).setUp() + self.db = self.create_database('test') + + def tearDown(self): + # TODO: Add close_database parameterization + # self.close_database(self.db) + super(DatabaseBaseTests, self).tearDown() + + def assertTransactionLog(self, doc_ids, db): + """Assert that the given docs are in the transaction log.""" + log = db._get_transaction_log() + just_ids = [] + seen_transactions = set() + for doc_id, transaction_id in log: + just_ids.append(doc_id) + self.assertIsNot(None, transaction_id, + "Transaction id should not be None") + if transaction_id == '' and self.accept_fixed_trans_id: + continue + self.assertNotEqual('', transaction_id, + "Transaction id should be a unique string") + self.assertTrue(transaction_id.startswith('T-')) + self.assertNotIn(transaction_id, seen_transactions) + seen_transactions.add(transaction_id) + self.assertEqual(doc_ids, just_ids) + + def getLastTransId(self, db): + """Return the transaction id for the last database update.""" + return self.db._get_transaction_log()[-1][-1] + + +class ServerStateForTests(server_state.ServerState): + """Used in the test suite, so we don't have to touch disk, etc.""" + + def __init__(self): + super(ServerStateForTests, self).__init__() + self._dbs = {} + + def open_database(self, path): + try: + return self._dbs[path] + except KeyError: + raise errors.DatabaseDoesNotExist + + def check_database(self, path): + # cares only about the possible exception + self.open_database(path) + + def ensure_database(self, path): + try: + db = self.open_database(path) + except errors.DatabaseDoesNotExist: + db = self._create_database(path) + return db, db._replica_uid + + def _copy_database(self, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES + # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST + # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS + # RATHER 
THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND + # NINJA TO YOUR HOUSE. + new_db = copy_memory_database_for_test(None, db) + path = db._replica_uid + while path in self._dbs: + path += 'copy' + self._dbs[path] = new_db + return new_db + + def _create_database(self, path): + db = inmemory.InMemoryDatabase(path) + self._dbs[path] = db + return db + + def delete_database(self, path): + del self._dbs[path] + + +class ResponderForTests(object): + """Responder for tests.""" + _started = False + sent_response = False + status = None + + def start_response(self, status='success', **kwargs): + self._started = True + self.status = status + self.kwargs = kwargs + + def send_response(self, status='success', **kwargs): + self.start_response(status, **kwargs) + self.finish_response() + + def finish_response(self): + self.sent_response = True + + +class TestCaseWithServer(TestCase): + + @staticmethod + def server_def(): + # hook point + # should return (ServerClass, "shutdown method name", "url_scheme") + class _RequestHandler(simple_server.WSGIRequestHandler): + def log_request(*args): + pass # suppress + + def make_server(host_port, application): + assert application, "forgot to override make_app(_with_state)?" 
+ srv = simple_server.WSGIServer(host_port, _RequestHandler) + # patch the value in if it's None + if getattr(application, 'base_url', 1) is None: + application.base_url = "http://%s:%s" % srv.server_address + srv.set_app(application) + return srv + + return make_server, "shutdown", "http" + + @staticmethod + def make_app_with_state(state): + # hook point + return None + + def make_app(self): + # potential hook point + self.request_state = ServerStateForTests() + return self.make_app_with_state(self.request_state) + + def setUp(self): + super(TestCaseWithServer, self).setUp() + self.server = self.server_thread = None + + @property + def url_scheme(self): + return self.server_def()[-1] + + def startServer(self): + server_def = self.server_def() + server_class, shutdown_meth, _ = server_def + application = self.make_app() + self.server = server_class(('127.0.0.1', 0), application) + self.server_thread = threading.Thread(target=self.server.serve_forever, + kwargs=dict(poll_interval=0.01)) + self.server_thread.start() + self.addCleanup(self.server_thread.join) + self.addCleanup(getattr(self.server, shutdown_meth)) + + def getURL(self, path=None): + host, port = self.server.server_address + if path is None: + path = '' + return '%s://%s:%s/%s' % (self.url_scheme, host, port, path) + + +def socket_pair(): + """Return a pair of TCP sockets connected to each other. + + Unlike socket.socketpair, this should work on Windows. 
+ """ + sock_pair = getattr(socket, 'socket_pair', None) + if sock_pair: + return sock_pair(socket.AF_INET, socket.SOCK_STREAM) + listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_sock.bind(('127.0.0.1', 0)) + listen_sock.listen(1) + client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_sock.connect(listen_sock.getsockname()) + server_sock, addr = listen_sock.accept() + listen_sock.close() + return server_sock, client_sock + + +# OAuth related testing + +consumer1 = oauth.OAuthConsumer('K1', 'S1') +token1 = oauth.OAuthToken('kkkk1', 'XYZ') +consumer2 = oauth.OAuthConsumer('K2', 'S2') +token2 = oauth.OAuthToken('kkkk2', 'ZYX') +token3 = oauth.OAuthToken('kkkk3', 'ZYX') + + +class TestingOAuthDataStore(oauth.OAuthDataStore): + """In memory predefined OAuthDataStore for testing.""" + + consumers = { + consumer1.key: consumer1, + consumer2.key: consumer2, + } + + tokens = { + token1.key: token1, + token2.key: token2 + } + + def lookup_consumer(self, key): + return self.consumers.get(key) + + def lookup_token(self, token_type, token_token): + return self.tokens.get(token_token) + + def lookup_nonce(self, oauth_consumer, oauth_token, nonce): + return None + +testingOAuthStore = TestingOAuthDataStore() + +sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1() +sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT() + + +def load_with_scenarios(loader, standard_tests, pattern): + """Load the tests in a given module. + + This just applies testscenarios.generate_scenarios to all the tests that + are present. We do it at load time rather than at run time, because it + plays nicer with various tools. 
+ """ + suite = loader.suiteClass() + suite.addTests(testscenarios.generate_scenarios(standard_tests)) + return suite diff --git a/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx b/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx new file mode 100644 index 00000000..8a4b600d --- /dev/null +++ b/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx @@ -0,0 +1,1541 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . +# +"""A Cython wrapper around the C implementation of U1DB Database backend.""" + +cdef extern from "Python.h": + object PyString_FromStringAndSize(char *s, Py_ssize_t n) + int PyString_AsStringAndSize(object o, char **buf, Py_ssize_t *length + ) except -1 + char *PyString_AsString(object) except NULL + char *PyString_AS_STRING(object) + char *strdup(char *) + void *calloc(size_t, size_t) + void free(void *) + ctypedef struct FILE: + pass + fprintf(FILE *, char *, ...) 
+ FILE *stderr + size_t strlen(char *) + +cdef extern from "stdarg.h": + ctypedef struct va_list: + pass + void va_start(va_list, void*) + void va_start_int "va_start" (va_list, int) + void va_end(va_list) + +cdef extern from "u1db/u1db.h": + ctypedef struct u1database: + pass + ctypedef struct u1db_document: + char *doc_id + size_t doc_id_len + char *doc_rev + size_t doc_rev_len + char *json + size_t json_len + int has_conflicts + # Note: u1query is actually defined in u1db_internal.h, and in u1db.h it is + # just an opaque pointer. However, older versions of Cython don't let + # you have a forward declaration and a full declaration, so we just + # expose the whole thing here. + ctypedef struct u1query: + char *index_name + int num_fields + char **fields + cdef struct u1db_oauth_creds: + int auth_kind + char *consumer_key + char *consumer_secret + char *token_key + char *token_secret + ctypedef union u1db_creds + ctypedef u1db_creds* const_u1db_creds_ptr "const u1db_creds *" + + ctypedef char* const_char_ptr "const char*" + ctypedef int (*u1db_doc_callback)(void *context, u1db_document *doc) + ctypedef int (*u1db_key_callback)(void *context, int num_fields, + const_char_ptr *key) + ctypedef int (*u1db_doc_gen_callback)(void *context, + u1db_document *doc, int gen, const_char_ptr trans_id) + ctypedef int (*u1db_trans_info_callback)(void *context, + const_char_ptr doc_id, int gen, const_char_ptr trans_id) + + u1database * u1db_open(char *fname) + void u1db_free(u1database **) + int u1db_set_replica_uid(u1database *, char *replica_uid) + int u1db_set_document_size_limit(u1database *, int limit) + int u1db_get_replica_uid(u1database *, const_char_ptr *replica_uid) + int u1db_create_doc_from_json(u1database *db, char *json, char *doc_id, + u1db_document **doc) + int u1db_delete_doc(u1database *db, u1db_document *doc) + int u1db_get_doc(u1database *db, char *doc_id, int include_deleted, + u1db_document **doc) + int u1db_get_docs(u1database *db, int n_doc_ids, 
const_char_ptr *doc_ids, + int check_for_conflicts, int include_deleted, + void *context, u1db_doc_callback cb) + int u1db_get_all_docs(u1database *db, int include_deleted, int *generation, + void *context, u1db_doc_callback cb) + int u1db_put_doc(u1database *db, u1db_document *doc) + int u1db__validate_source(u1database *db, const_char_ptr replica_uid, + int replica_gen, const_char_ptr replica_trans_id) + int u1db__put_doc_if_newer(u1database *db, u1db_document *doc, + int save_conflict, char *replica_uid, + int replica_gen, char *replica_trans_id, + int *state, int *at_gen) + int u1db_resolve_doc(u1database *db, u1db_document *doc, + int n_revs, const_char_ptr *revs) + int u1db_delete_doc(u1database *db, u1db_document *doc) + int u1db_whats_changed(u1database *db, int *gen, char **trans_id, + void *context, u1db_trans_info_callback cb) + int u1db__get_transaction_log(u1database *db, void *context, + u1db_trans_info_callback cb) + int u1db_get_doc_conflicts(u1database *db, char *doc_id, void *context, + u1db_doc_callback cb) + int u1db_sync(u1database *db, const_char_ptr url, + const_u1db_creds_ptr creds, int *local_gen) nogil + int u1db_create_index_list(u1database *db, char *index_name, + int n_expressions, const_char_ptr *expressions) + int u1db_create_index(u1database *db, char *index_name, int n_expressions, + ...) + int u1db_get_from_index_list(u1database *db, u1query *query, void *context, + u1db_doc_callback cb, int n_values, + const_char_ptr *values) + int u1db_get_from_index(u1database *db, u1query *query, void *context, + u1db_doc_callback cb, int n_values, char *val0, + ...) 
+ int u1db_get_range_from_index(u1database *db, u1query *query, + void *context, u1db_doc_callback cb, + int n_values, const_char_ptr *start_values, + const_char_ptr *end_values) + int u1db_delete_index(u1database *db, char *index_name) + int u1db_list_indexes(u1database *db, void *context, + int (*cb)(void *context, const_char_ptr index_name, + int n_expressions, const_char_ptr *expressions)) + int u1db_get_index_keys(u1database *db, char *index_name, void *context, + u1db_key_callback cb) + int u1db_simple_lookup1(u1database *db, char *index_name, char *val1, + void *context, u1db_doc_callback cb) + int u1db_query_init(u1database *db, char *index_name, u1query **query) + void u1db_free_query(u1query **query) + + int U1DB_OK + int U1DB_INVALID_PARAMETER + int U1DB_REVISION_CONFLICT + int U1DB_INVALID_DOC_ID + int U1DB_DOCUMENT_ALREADY_DELETED + int U1DB_DOCUMENT_DOES_NOT_EXIST + int U1DB_NOT_IMPLEMENTED + int U1DB_INVALID_JSON + int U1DB_DOCUMENT_TOO_BIG + int U1DB_USER_QUOTA_EXCEEDED + int U1DB_INVALID_VALUE_FOR_INDEX + int U1DB_INVALID_FIELD_SPECIFIER + int U1DB_INVALID_GLOBBING + int U1DB_BROKEN_SYNC_STREAM + int U1DB_DUPLICATE_INDEX_NAME + int U1DB_INDEX_DOES_NOT_EXIST + int U1DB_INVALID_GENERATION + int U1DB_INVALID_TRANSACTION_ID + int U1DB_INVALID_TRANSFORMATION_FUNCTION + int U1DB_UNKNOWN_OPERATION + int U1DB_INTERNAL_ERROR + int U1DB_TARGET_UNAVAILABLE + + int U1DB_INSERTED + int U1DB_SUPERSEDED + int U1DB_CONVERGED + int U1DB_CONFLICTED + + int U1DB_OAUTH_AUTH + + void u1db_free_doc(u1db_document **doc) + int u1db_doc_set_json(u1db_document *doc, char *json) + int u1db_doc_get_size(u1db_document *doc) + + +cdef extern from "u1db/u1db_internal.h": + ctypedef struct u1db_row: + u1db_row *next + int num_columns + int *column_sizes + unsigned char **columns + + ctypedef struct u1db_table: + int status + u1db_row *first_row + + ctypedef struct u1db_record: + u1db_record *next + char *doc_id + char *doc_rev + char *doc + + ctypedef struct u1db_sync_exchange: + 
int target_gen + int num_doc_ids + char **doc_ids_to_return + int *gen_for_doc_ids + const_char_ptr *trans_ids_for_doc_ids + + ctypedef int (*u1db__trace_callback)(void *context, const_char_ptr state) + ctypedef struct u1db_sync_target: + int (*get_sync_info)(u1db_sync_target *st, char *source_replica_uid, + const_char_ptr *st_replica_uid, int *st_gen, + char **st_trans_id, int *source_gen, + char **source_trans_id) nogil + int (*record_sync_info)(u1db_sync_target *st, + char *source_replica_uid, int source_gen, char *trans_id) nogil + int (*sync_exchange)(u1db_sync_target *st, + char *source_replica_uid, int n_docs, + u1db_document **docs, int *generations, + const_char_ptr *trans_ids, + int *target_gen, char **target_trans_id, + void *context, u1db_doc_gen_callback cb, + void *ensure_callback) nogil + int (*sync_exchange_doc_ids)(u1db_sync_target *st, + u1database *source_db, int n_doc_ids, + const_char_ptr *doc_ids, int *generations, + const_char_ptr *trans_ids, + int *target_gen, char **target_trans_id, + void *context, + u1db_doc_gen_callback cb, + void *ensure_callback) nogil + int (*get_sync_exchange)(u1db_sync_target *st, + char *source_replica_uid, + int last_known_source_gen, + u1db_sync_exchange **exchange) nogil + void (*finalize_sync_exchange)(u1db_sync_target *st, + u1db_sync_exchange **exchange) nogil + int (*_set_trace_hook)(u1db_sync_target *st, + void *context, u1db__trace_callback cb) nogil + + + void u1db__set_zero_delays() + int u1db__get_generation(u1database *, int *db_rev) + int u1db__get_document_size_limit(u1database *, int *limit) + int u1db__get_generation_info(u1database *, int *db_rev, char **trans_id) + int u1db__get_trans_id_for_gen(u1database *, int db_rev, char **trans_id) + int u1db_validate_gen_and_trans_id(u1database *, int db_rev, + const_char_ptr trans_id) + char *u1db__allocate_doc_id(u1database *) + int u1db__sql_close(u1database *) + u1database *u1db__copy(u1database *) + int u1db__sql_is_open(u1database *) + u1db_table 
*u1db__sql_run(u1database *, char *sql, size_t n) + void u1db__free_table(u1db_table **table) + u1db_record *u1db__create_record(char *doc_id, char *doc_rev, char *doc) + void u1db__free_records(u1db_record **) + + int u1db__allocate_document(char *doc_id, char *revision, char *content, + int has_conflicts, u1db_document **result) + int u1db__generate_hex_uuid(char *) + + int u1db__get_replica_gen_and_trans_id(u1database *db, char *replica_uid, + int *generation, char **trans_id) + int u1db__set_replica_gen_and_trans_id(u1database *db, char *replica_uid, + int generation, char *trans_id) + int u1db__sync_get_machine_info(u1database *db, char *other_replica_uid, + int *other_db_rev, char **my_replica_uid, + int *my_db_rev) + int u1db__sync_record_machine_info(u1database *db, char *replica_uid, + int db_rev) + int u1db__sync_exchange_seen_ids(u1db_sync_exchange *se, int *n_ids, + const_char_ptr **doc_ids) + int u1db__format_query(int n_fields, const_char_ptr *values, char **buf, + int *wildcard) + int u1db__get_sync_target(u1database *db, u1db_sync_target **sync_target) + int u1db__free_sync_target(u1db_sync_target **sync_target) + int u1db__sync_db_to_target(u1database *db, u1db_sync_target *target, + int *local_gen_before_sync) nogil + + int u1db__sync_exchange_insert_doc_from_source(u1db_sync_exchange *se, + u1db_document *doc, int source_gen, const_char_ptr trans_id) + int u1db__sync_exchange_find_doc_ids_to_return(u1db_sync_exchange *se) + int u1db__sync_exchange_return_docs(u1db_sync_exchange *se, void *context, + int (*cb)(void *context, + u1db_document *doc, int gen, + const_char_ptr trans_id)) + int u1db__create_http_sync_target(char *url, u1db_sync_target **target) + int u1db__create_oauth_http_sync_target(char *url, + char *consumer_key, char *consumer_secret, + char *token_key, char *token_secret, + u1db_sync_target **target) + +cdef extern from "u1db/u1db_http_internal.h": + int u1db__format_sync_url(u1db_sync_target *st, + const_char_ptr 
source_replica_uid, char **sync_url) + int u1db__get_oauth_authorization(u1db_sync_target *st, + char *http_method, char *url, + char **oauth_authorization) + + +cdef extern from "u1db/u1db_vectorclock.h": + ctypedef struct u1db_vectorclock_item: + char *replica_uid + int generation + + ctypedef struct u1db_vectorclock: + int num_items + u1db_vectorclock_item *items + + u1db_vectorclock *u1db__vectorclock_from_str(char *s) + void u1db__free_vectorclock(u1db_vectorclock **clock) + int u1db__vectorclock_increment(u1db_vectorclock *clock, char *replica_uid) + int u1db__vectorclock_maximize(u1db_vectorclock *clock, + u1db_vectorclock *other) + int u1db__vectorclock_as_str(u1db_vectorclock *clock, char **result) + int u1db__vectorclock_is_newer(u1db_vectorclock *maybe_newer, + u1db_vectorclock *older) + +from u1db import errors +from sqlite3 import dbapi2 + + +cdef int _append_trans_info_to_list(void *context, const_char_ptr doc_id, + int generation, + const_char_ptr trans_id) with gil: + a_list = (context) + doc = doc_id + a_list.append((doc, generation, trans_id)) + return 0 + + +cdef int _append_doc_to_list(void *context, u1db_document *doc) with gil: + a_list = context + pydoc = CDocument() + pydoc._doc = doc + a_list.append(pydoc) + return 0 + +cdef int _append_key_to_list(void *context, int num_fields, + const_char_ptr *key) with gil: + a_list = (context) + field_list = [] + for i from 0 <= i < num_fields: + field = key[i] + field_list.append(field.decode('utf-8')) + a_list.append(tuple(field_list)) + return 0 + +cdef _list_to_array(lst, const_char_ptr **res, int *count): + cdef const_char_ptr *tmp + count[0] = len(lst) + tmp = calloc(sizeof(char*), count[0]) + for idx, x in enumerate(lst): + tmp[idx] = x + res[0] = tmp + +cdef _list_to_str_array(lst, const_char_ptr **res, int *count): + cdef const_char_ptr *tmp + count[0] = len(lst) + tmp = calloc(sizeof(char*), count[0]) + new_objs = [] + for idx, x in enumerate(lst): + if isinstance(x, unicode): + x = 
x.encode('utf-8') + new_objs.append(x) + tmp[idx] = x + res[0] = tmp + return new_objs + + +cdef int _append_index_definition_to_list(void *context, + const_char_ptr index_name, int n_expressions, + const_char_ptr *expressions) with gil: + cdef int i + + a_list = (context) + exp_list = [] + for i from 0 <= i < n_expressions: + s = expressions[i] + exp_list.append(s.decode('utf-8')) + a_list.append((index_name, exp_list)) + return 0 + + +cdef int return_doc_cb_wrapper(void *context, u1db_document *doc, + int gen, const_char_ptr trans_id) with gil: + cdef CDocument pydoc + user_cb = context + pydoc = CDocument() + pydoc._doc = doc + try: + user_cb(pydoc, gen, trans_id) + except Exception, e: + # We suppress the exception here, because intermediating through the C + # layer gets a bit crazy + return U1DB_INVALID_PARAMETER + return U1DB_OK + + +cdef int _trace_hook(void *context, const_char_ptr state) with gil: + if context == NULL: + return U1DB_INVALID_PARAMETER + ctx = context + try: + ctx(state) + except: + # Note: It would be nice if we could map the Python exception into + # something in C + return U1DB_INTERNAL_ERROR + return U1DB_OK + + +cdef char *_ensure_str(object obj, object extra_objs) except NULL: + """Ensure that we have the UTF-8 representation of a parameter. + + :param obj: A Unicode or String object. + :param extra_objs: This should be a Python list. If we have to convert obj + from being a Unicode object, this will hold the PyString object so that + we know the char* lifetime will be correct. + :return: A C pointer to the UTF-8 representation. + """ + if isinstance(obj, unicode): + obj = obj.encode('utf-8') + extra_objs.append(obj) + return PyString_AsString(obj) + + +def _format_query(fields): + """Wrapper around u1db__format_query for testing.""" + cdef int status + cdef char *buf + cdef int wildcard[10] + cdef const_char_ptr *values + cdef int n_values + + # keep a reference to new_objs so that the pointers in expressions + # remain valid. 
+ new_objs = _list_to_str_array(fields, &values, &n_values) + try: + status = u1db__format_query(n_values, values, &buf, wildcard) + finally: + free(values) + handle_status("format_query", status) + if buf == NULL: + res = None + else: + res = buf + free(buf) + w = [] + for i in range(len(fields)): + w.append(wildcard[i]) + return res, w + + +def make_document(doc_id, rev, content, has_conflicts=False): + cdef u1db_document *doc + cdef char *c_content = NULL, *c_rev = NULL, *c_doc_id = NULL + cdef int conflict + + if has_conflicts: + conflict = 1 + else: + conflict = 0 + if doc_id is None: + c_doc_id = NULL + else: + c_doc_id = doc_id + if content is None: + c_content = NULL + else: + c_content = content + if rev is None: + c_rev = NULL + else: + c_rev = rev + handle_status( + "make_document", + u1db__allocate_document(c_doc_id, c_rev, c_content, conflict, &doc)) + pydoc = CDocument() + pydoc._doc = doc + return pydoc + + +def generate_hex_uuid(): + uuid = PyString_FromStringAndSize(NULL, 32) + handle_status( + "Failed to generate uuid", + u1db__generate_hex_uuid(PyString_AS_STRING(uuid))) + return uuid + + +cdef class CDocument(object): + """A thin wrapper around the C Document struct.""" + + cdef u1db_document *_doc + + def __init__(self): + self._doc = NULL + + def __dealloc__(self): + u1db_free_doc(&self._doc) + + property doc_id: + def __get__(self): + if self._doc.doc_id == NULL: + return None + return PyString_FromStringAndSize( + self._doc.doc_id, self._doc.doc_id_len) + + property rev: + def __get__(self): + if self._doc.doc_rev == NULL: + return None + return PyString_FromStringAndSize( + self._doc.doc_rev, self._doc.doc_rev_len) + + def get_json(self): + if self._doc.json == NULL: + return None + return PyString_FromStringAndSize( + self._doc.json, self._doc.json_len) + + def set_json(self, val): + u1db_doc_set_json(self._doc, val) + + def get_size(self): + return u1db_doc_get_size(self._doc) + + property has_conflicts: + def __get__(self): + if 
self._doc.has_conflicts: + return True + return False + + def __repr__(self): + if self._doc.has_conflicts: + extra = ', conflicted' + else: + extra = '' + return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id, + self.rev, extra, self.get_json()) + + def __hash__(self): + raise NotImplementedError(self.__hash__) + + def __richcmp__(self, other, int t): + try: + if t == 0: # Py_LT < + return ((self.doc_id, self.rev, self.get_json()) + < (other.doc_id, other.rev, other.get_json())) + elif t == 2: # Py_EQ == + return (self.doc_id == other.doc_id + and self.rev == other.rev + and self.get_json() == other.get_json() + and self.has_conflicts == other.has_conflicts) + except AttributeError: + # Fall through to NotImplemented + pass + + return NotImplemented + + +cdef object safe_str(const_char_ptr s): + if s == NULL: + return None + return s + + +cdef class CQuery: + + cdef u1query *_query + + def __init__(self): + self._query = NULL + + def __dealloc__(self): + u1db_free_query(&self._query) + + def _check(self): + if self._query == NULL: + raise RuntimeError("No valid _query.") + + property index_name: + def __get__(self): + self._check() + return safe_str(self._query.index_name) + + property num_fields: + def __get__(self): + self._check() + return self._query.num_fields + + property fields: + def __get__(self): + cdef int i + self._check() + fields = [] + for i from 0 <= i < self._query.num_fields: + fields.append(safe_str(self._query.fields[i])) + return fields + + +cdef handle_status(context, int status): + if status == U1DB_OK: + return + if status == U1DB_REVISION_CONFLICT: + raise errors.RevisionConflict() + if status == U1DB_INVALID_DOC_ID: + raise errors.InvalidDocId() + if status == U1DB_DOCUMENT_ALREADY_DELETED: + raise errors.DocumentAlreadyDeleted() + if status == U1DB_DOCUMENT_DOES_NOT_EXIST: + raise errors.DocumentDoesNotExist() + if status == U1DB_INVALID_PARAMETER: + raise RuntimeError('Bad parameters supplied') + if status == 
U1DB_NOT_IMPLEMENTED: + raise NotImplementedError("Functionality not implemented yet: %s" + % (context,)) + if status == U1DB_INVALID_VALUE_FOR_INDEX: + raise errors.InvalidValueForIndex() + if status == U1DB_INVALID_GLOBBING: + raise errors.InvalidGlobbing() + if status == U1DB_INTERNAL_ERROR: + raise errors.U1DBError("internal error") + if status == U1DB_BROKEN_SYNC_STREAM: + raise errors.BrokenSyncStream() + if status == U1DB_CONFLICTED: + raise errors.ConflictedDoc() + if status == U1DB_DUPLICATE_INDEX_NAME: + raise errors.IndexNameTakenError() + if status == U1DB_INDEX_DOES_NOT_EXIST: + raise errors.IndexDoesNotExist + if status == U1DB_INVALID_GENERATION: + raise errors.InvalidGeneration + if status == U1DB_INVALID_TRANSACTION_ID: + raise errors.InvalidTransactionId + if status == U1DB_TARGET_UNAVAILABLE: + raise errors.Unavailable + if status == U1DB_INVALID_JSON: + raise errors.InvalidJSON + if status == U1DB_DOCUMENT_TOO_BIG: + raise errors.DocumentTooBig + if status == U1DB_USER_QUOTA_EXCEEDED: + raise errors.UserQuotaExceeded + if status == U1DB_INVALID_TRANSFORMATION_FUNCTION: + raise errors.IndexDefinitionParseError + if status == U1DB_UNKNOWN_OPERATION: + raise errors.IndexDefinitionParseError + if status == U1DB_INVALID_FIELD_SPECIFIER: + raise errors.IndexDefinitionParseError() + raise RuntimeError('%s (status: %s)' % (context, status)) + + +cdef class CDatabase +cdef class CSyncTarget + +cdef class CSyncExchange(object): + + cdef u1db_sync_exchange *_exchange + cdef CSyncTarget _target + + def __init__(self, CSyncTarget target, source_replica_uid, source_gen): + self._target = target + assert self._target._st.get_sync_exchange != NULL, \ + "get_sync_exchange is NULL?" 
+ handle_status("get_sync_exchange", + self._target._st.get_sync_exchange(self._target._st, + source_replica_uid, source_gen, &self._exchange)) + + def __dealloc__(self): + if self._target is not None and self._target._st != NULL: + self._target._st.finalize_sync_exchange(self._target._st, + &self._exchange) + + def _check(self): + if self._exchange == NULL: + raise RuntimeError("self._exchange is NULL") + + property target_gen: + def __get__(self): + self._check() + return self._exchange.target_gen + + def insert_doc_from_source(self, CDocument doc, source_gen, + source_trans_id): + self._check() + handle_status("insert_doc_from_source", + u1db__sync_exchange_insert_doc_from_source(self._exchange, + doc._doc, source_gen, source_trans_id)) + + def find_doc_ids_to_return(self): + self._check() + handle_status("find_doc_ids_to_return", + u1db__sync_exchange_find_doc_ids_to_return(self._exchange)) + + def return_docs(self, return_doc_cb): + self._check() + handle_status("return_docs", + u1db__sync_exchange_return_docs(self._exchange, + return_doc_cb, &return_doc_cb_wrapper)) + + def get_seen_ids(self): + cdef const_char_ptr *seen_ids + cdef int i, n_ids + self._check() + handle_status("sync_exchange_seen_ids", + u1db__sync_exchange_seen_ids(self._exchange, &n_ids, &seen_ids)) + res = [] + for i from 0 <= i < n_ids: + res.append(seen_ids[i]) + if (seen_ids != NULL): + free(seen_ids) + return res + + def get_doc_ids_to_return(self): + self._check() + res = [] + if (self._exchange.num_doc_ids > 0 + and self._exchange.doc_ids_to_return != NULL): + for i from 0 <= i < self._exchange.num_doc_ids: + res.append( + (self._exchange.doc_ids_to_return[i], + self._exchange.gen_for_doc_ids[i], + self._exchange.trans_ids_for_doc_ids[i])) + return res + + +cdef class CSyncTarget(object): + + cdef u1db_sync_target *_st + cdef CDatabase _db + + def __init__(self): + self._db = None + self._st = NULL + u1db__set_zero_delays() + + def __dealloc__(self): + 
u1db__free_sync_target(&self._st) + + def _check(self): + if self._st == NULL: + raise RuntimeError("self._st is NULL") + + def get_sync_info(self, source_replica_uid): + cdef const_char_ptr st_replica_uid = NULL + cdef int st_gen = 0, source_gen = 0, status + cdef char *trans_id = NULL + cdef char *st_trans_id = NULL + cdef char *c_source_replica_uid = NULL + + self._check() + assert self._st.get_sync_info != NULL, "get_sync_info is NULL?" + c_source_replica_uid = source_replica_uid + with nogil: + status = self._st.get_sync_info(self._st, c_source_replica_uid, + &st_replica_uid, &st_gen, &st_trans_id, &source_gen, &trans_id) + handle_status("get_sync_info", status) + res_trans_id = None + res_st_trans_id = None + if trans_id != NULL: + res_trans_id = trans_id + free(trans_id) + if st_trans_id != NULL: + res_st_trans_id = st_trans_id + free(st_trans_id) + return ( + safe_str(st_replica_uid), st_gen, res_st_trans_id, source_gen, + res_trans_id) + + def record_sync_info(self, source_replica_uid, source_gen, source_trans_id): + cdef int status + cdef int c_source_gen + cdef char *c_source_replica_uid = NULL + cdef char *c_source_trans_id = NULL + + self._check() + assert self._st.record_sync_info != NULL, "record_sync_info is NULL?" 
+ c_source_replica_uid = source_replica_uid + c_source_gen = source_gen + c_source_trans_id = source_trans_id + with nogil: + status = self._st.record_sync_info( + self._st, c_source_replica_uid, c_source_gen, + c_source_trans_id) + handle_status("record_sync_info", status) + + def _get_sync_exchange(self, source_replica_uid, source_gen): + self._check() + return CSyncExchange(self, source_replica_uid, source_gen) + + def sync_exchange_doc_ids(self, source_db, doc_id_generations, + last_known_generation, last_known_trans_id, + return_doc_cb): + cdef const_char_ptr *doc_ids + cdef int *generations + cdef int num_doc_ids + cdef int target_gen + cdef char *target_trans_id = NULL + cdef int status + cdef CDatabase sdb + + self._check() + assert self._st.sync_exchange_doc_ids != NULL, "sync_exchange_doc_ids is NULL?" + sdb = source_db + num_doc_ids = len(doc_id_generations) + doc_ids = calloc(num_doc_ids, sizeof(char *)) + if doc_ids == NULL: + raise MemoryError + generations = calloc(num_doc_ids, sizeof(int)) + if generations == NULL: + free(doc_ids) + raise MemoryError + trans_ids = calloc(num_doc_ids, sizeof(char *)) + if trans_ids == NULL: + raise MemoryError + res_trans_id = '' + try: + for i, (doc_id, gen, trans_id) in enumerate(doc_id_generations): + doc_ids[i] = PyString_AsString(doc_id) + generations[i] = gen + trans_ids[i] = trans_id + target_gen = last_known_generation + if last_known_trans_id is not None: + target_trans_id = last_known_trans_id + with nogil: + status = self._st.sync_exchange_doc_ids(self._st, sdb._db, + num_doc_ids, doc_ids, generations, trans_ids, + &target_gen, &target_trans_id, + return_doc_cb, return_doc_cb_wrapper, NULL) + handle_status("sync_exchange_doc_ids", status) + if target_trans_id != NULL: + res_trans_id = target_trans_id + finally: + if target_trans_id != NULL: + free(target_trans_id) + if doc_ids != NULL: + free(doc_ids) + if generations != NULL: + free(generations) + if trans_ids != NULL: + free(trans_ids) + return 
target_gen, res_trans_id + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + cdef CDocument cur_doc + cdef u1db_document **docs = NULL + cdef int *generations = NULL + cdef const_char_ptr *trans_ids = NULL + cdef char *target_trans_id = NULL + cdef char *c_source_replica_uid = NULL + cdef int i, count, status, target_gen + assert ensure_callback is None # interface difference + + self._check() + assert self._st.sync_exchange != NULL, "sync_exchange is NULL?" + count = len(docs_by_generations) + res_trans_id = '' + try: + docs = calloc(count, sizeof(u1db_document*)) + if docs == NULL: + raise MemoryError + generations = calloc(count, sizeof(int)) + if generations == NULL: + raise MemoryError + trans_ids = calloc(count, sizeof(char*)) + if trans_ids == NULL: + raise MemoryError + for i from 0 <= i < count: + cur_doc = docs_by_generations[i][0] + generations[i] = docs_by_generations[i][1] + trans_ids[i] = docs_by_generations[i][2] + docs[i] = cur_doc._doc + target_gen = last_known_generation + if last_known_trans_id is not None: + target_trans_id = last_known_trans_id + c_source_replica_uid = source_replica_uid + with nogil: + status = self._st.sync_exchange( + self._st, c_source_replica_uid, count, docs, generations, + trans_ids, &target_gen, &target_trans_id, + return_doc_cb, return_doc_cb_wrapper, NULL) + handle_status("sync_exchange", status) + finally: + if docs != NULL: + free(docs) + if generations != NULL: + free(generations) + if trans_ids != NULL: + free(trans_ids) + if target_trans_id != NULL: + res_trans_id = target_trans_id + free(target_trans_id) + return target_gen, res_trans_id + + def _set_trace_hook(self, cb): + self._check() + assert self._st._set_trace_hook != NULL, "_set_trace_hook is NULL?" 
+ handle_status("_set_trace_hook", + self._st._set_trace_hook(self._st, cb, _trace_hook)) + + _set_trace_hook_shallow = _set_trace_hook + + +cdef class CDatabase(object): + """A thin wrapper/shim to interact with the C implementation. + + Functionality should not be written here. It is only provided as a way to + expose the C API to the python test suite. + """ + + cdef public object _filename + cdef u1database *_db + cdef public object _supports_indexes + + def __init__(self, filename): + self._supports_indexes = False + self._filename = filename + self._db = u1db_open(self._filename) + + def __dealloc__(self): + u1db_free(&self._db) + + def close(self): + return u1db__sql_close(self._db) + + def _copy(self, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. 
+ new_db = CDatabase(':memory:') + u1db_free(&new_db._db) + new_db._db = u1db__copy(self._db) + return new_db + + def _sql_is_open(self): + if self._db == NULL: + return True + return u1db__sql_is_open(self._db) + + property _replica_uid: + def __get__(self): + cdef const_char_ptr val + cdef int status + status = u1db_get_replica_uid(self._db, &val) + if status != 0: + if val != NULL: + err = str(val) + else: + err = "" + raise RuntimeError("Failed to get_replica_uid: %d %s" + % (status, err)) + if val == NULL: + return None + return str(val) + + def _set_replica_uid(self, replica_uid): + cdef int status + status = u1db_set_replica_uid(self._db, replica_uid) + if status != 0: + raise RuntimeError('replica_uid could not be set to %s, error: %d' + % (replica_uid, status)) + + property document_size_limit: + def __get__(self): + cdef int limit + handle_status("document_size_limit", + u1db__get_document_size_limit(self._db, &limit)) + return limit + + def set_document_size_limit(self, limit): + cdef int status + status = u1db_set_document_size_limit(self._db, limit) + if status != 0: + raise RuntimeError( + "document_size_limit could not be set to %d, error: %d", + (limit, status)) + + def _allocate_doc_id(self): + cdef char *val + val = u1db__allocate_doc_id(self._db) + if val == NULL: + raise RuntimeError("Failed to allocate document id") + s = str(val) + free(val) + return s + + def _run_sql(self, sql): + cdef u1db_table *tbl + cdef u1db_row *cur_row + cdef size_t n + cdef int i + + if self._db == NULL: + raise RuntimeError("called _run_sql with a NULL pointer.") + tbl = u1db__sql_run(self._db, sql, len(sql)) + if tbl == NULL: + raise MemoryError("Failed to allocate table memory.") + try: + if tbl.status != 0: + raise RuntimeError("Status was not 0: %d" % (tbl.status,)) + # Now convert the table into python + res = [] + cur_row = tbl.first_row + while cur_row != NULL: + row = [] + for i from 0 <= i < cur_row.num_columns: + row.append(PyString_FromStringAndSize( + 
(cur_row.columns[i]), cur_row.column_sizes[i])) + res.append(tuple(row)) + cur_row = cur_row.next + return res + finally: + u1db__free_table(&tbl) + + def create_doc_from_json(self, json, doc_id=None): + cdef u1db_document *doc = NULL + cdef char *c_doc_id + + if doc_id is None: + c_doc_id = NULL + else: + c_doc_id = doc_id + handle_status('Failed to create_doc', + u1db_create_doc_from_json(self._db, json, c_doc_id, &doc)) + pydoc = CDocument() + pydoc._doc = doc + return pydoc + + def put_doc(self, CDocument doc): + handle_status("Failed to put_doc", + u1db_put_doc(self._db, doc._doc)) + return doc.rev + + def _validate_source(self, replica_uid, replica_gen, replica_trans_id): + cdef const_char_ptr c_uid, c_trans_id + cdef int c_gen = 0 + + c_uid = replica_uid + c_trans_id = replica_trans_id + c_gen = replica_gen + handle_status( + "invalid generation or transaction id", + u1db__validate_source(self._db, c_uid, c_gen, c_trans_id)) + + def _put_doc_if_newer(self, CDocument doc, save_conflict, replica_uid=None, + replica_gen=None, replica_trans_id=None): + cdef char *c_uid, *c_trans_id + cdef int gen, state = 0, at_gen = -1 + + if replica_uid is None: + c_uid = NULL + else: + c_uid = replica_uid + if replica_trans_id is None: + c_trans_id = NULL + else: + c_trans_id = replica_trans_id + if replica_gen is None: + gen = 0 + else: + gen = replica_gen + handle_status("Failed to _put_doc_if_newer", + u1db__put_doc_if_newer(self._db, doc._doc, save_conflict, + c_uid, gen, c_trans_id, &state, &at_gen)) + if state == U1DB_INSERTED: + return 'inserted', at_gen + elif state == U1DB_SUPERSEDED: + return 'superseded', at_gen + elif state == U1DB_CONVERGED: + return 'converged', at_gen + elif state == U1DB_CONFLICTED: + return 'conflicted', at_gen + else: + raise RuntimeError("Unknown _put_doc_if_newer state: %d" % (state,)) + + def get_doc(self, doc_id, include_deleted=False): + cdef u1db_document *doc = NULL + deleted = 1 if include_deleted else 0 + handle_status("get_doc 
failed", + u1db_get_doc(self._db, doc_id, deleted, &doc)) + if doc == NULL: + return None + pydoc = CDocument() + pydoc._doc = doc + return pydoc + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + cdef int n_doc_ids, conflicts + cdef const_char_ptr *c_doc_ids + + _list_to_array(doc_ids, &c_doc_ids, &n_doc_ids) + deleted = 1 if include_deleted else 0 + conflicts = 1 if check_for_conflicts else 0 + a_list = [] + handle_status("get_docs", + u1db_get_docs(self._db, n_doc_ids, c_doc_ids, + conflicts, deleted, a_list, _append_doc_to_list)) + free(c_doc_ids) + return a_list + + def get_all_docs(self, include_deleted=False): + cdef int c_generation + + a_list = [] + deleted = 1 if include_deleted else 0 + generation = 0 + c_generation = generation + handle_status( + "get_all_docs", u1db_get_all_docs( + self._db, deleted, &c_generation, a_list, + _append_doc_to_list)) + return (c_generation, a_list) + + def resolve_doc(self, CDocument doc, conflicted_doc_revs): + cdef const_char_ptr *revs + cdef int n_revs + + _list_to_array(conflicted_doc_revs, &revs, &n_revs) + handle_status("resolve_doc", + u1db_resolve_doc(self._db, doc._doc, n_revs, revs)) + free(revs) + + def get_doc_conflicts(self, doc_id): + conflict_docs = [] + handle_status("get_doc_conflicts", + u1db_get_doc_conflicts(self._db, doc_id, conflict_docs, + _append_doc_to_list)) + return conflict_docs + + def delete_doc(self, CDocument doc): + handle_status( + "Failed to delete %s" % (doc,), + u1db_delete_doc(self._db, doc._doc)) + + def whats_changed(self, generation=0): + cdef int c_generation + cdef int status + cdef char *trans_id = NULL + + a_list = [] + c_generation = generation + res_trans_id = '' + status = u1db_whats_changed(self._db, &c_generation, &trans_id, + a_list, _append_trans_info_to_list) + try: + handle_status("whats_changed", status) + finally: + if trans_id != NULL: + res_trans_id = trans_id + free(trans_id) + return c_generation, res_trans_id, a_list + + def 
_get_transaction_log(self): + a_list = [] + handle_status("_get_transaction_log", + u1db__get_transaction_log(self._db, a_list, + _append_trans_info_to_list)) + return [(doc_id, trans_id) for doc_id, gen, trans_id in a_list] + + def _get_generation(self): + cdef int generation + handle_status("get_generation", + u1db__get_generation(self._db, &generation)) + return generation + + def _get_generation_info(self): + cdef int generation + cdef char *trans_id + handle_status("get_generation_info", + u1db__get_generation_info(self._db, &generation, &trans_id)) + raw_trans_id = None + if trans_id != NULL: + raw_trans_id = trans_id + free(trans_id) + return generation, raw_trans_id + + def validate_gen_and_trans_id(self, generation, trans_id): + handle_status( + "validate_gen_and_trans_id", + u1db_validate_gen_and_trans_id(self._db, generation, trans_id)) + + def _get_trans_id_for_gen(self, generation): + cdef char *trans_id = NULL + + handle_status( + "_get_trans_id_for_gen", + u1db__get_trans_id_for_gen(self._db, generation, &trans_id)) + raw_trans_id = None + if trans_id != NULL: + raw_trans_id = trans_id + free(trans_id) + return raw_trans_id + + def _get_replica_gen_and_trans_id(self, replica_uid): + cdef int generation, status + cdef char *trans_id = NULL + + status = u1db__get_replica_gen_and_trans_id( + self._db, replica_uid, &generation, &trans_id) + handle_status("_get_replica_gen_and_trans_id", status) + raw_trans_id = None + if trans_id != NULL: + raw_trans_id = trans_id + free(trans_id) + return generation, raw_trans_id + + def _set_replica_gen_and_trans_id(self, replica_uid, generation, trans_id): + handle_status("_set_replica_gen_and_trans_id", + u1db__set_replica_gen_and_trans_id( + self._db, replica_uid, generation, trans_id)) + + def create_index_list(self, index_name, index_expressions): + cdef const_char_ptr *expressions + cdef int n_expressions + + # keep a reference to new_objs so that the pointers in expressions + # remain valid. 
+ new_objs = _list_to_str_array( + index_expressions, &expressions, &n_expressions) + try: + status = u1db_create_index_list( + self._db, index_name, n_expressions, expressions) + finally: + free(expressions) + handle_status("create_index", status) + + def create_index(self, index_name, *index_expressions): + extra = [] + if len(index_expressions) == 0: + status = u1db_create_index(self._db, index_name, 0, NULL) + elif len(index_expressions) == 1: + status = u1db_create_index( + self._db, index_name, 1, + _ensure_str(index_expressions[0], extra)) + elif len(index_expressions) == 2: + status = u1db_create_index( + self._db, index_name, 2, + _ensure_str(index_expressions[0], extra), + _ensure_str(index_expressions[1], extra)) + elif len(index_expressions) == 3: + status = u1db_create_index( + self._db, index_name, 3, + _ensure_str(index_expressions[0], extra), + _ensure_str(index_expressions[1], extra), + _ensure_str(index_expressions[2], extra)) + elif len(index_expressions) == 4: + status = u1db_create_index( + self._db, index_name, 4, + _ensure_str(index_expressions[0], extra), + _ensure_str(index_expressions[1], extra), + _ensure_str(index_expressions[2], extra), + _ensure_str(index_expressions[3], extra)) + else: + status = U1DB_NOT_IMPLEMENTED + handle_status("create_index", status) + + def sync(self, url, creds=None): + cdef const_char_ptr c_url + cdef int local_gen = 0 + cdef u1db_oauth_creds _oauth_creds + cdef u1db_creds *_creds = NULL + c_url = url + if creds is not None: + _oauth_creds.auth_kind = U1DB_OAUTH_AUTH + _oauth_creds.consumer_key = creds['oauth']['consumer_key'] + _oauth_creds.consumer_secret = creds['oauth']['consumer_secret'] + _oauth_creds.token_key = creds['oauth']['token_key'] + _oauth_creds.token_secret = creds['oauth']['token_secret'] + _creds = &_oauth_creds + with nogil: + status = u1db_sync(self._db, c_url, _creds, &local_gen) + handle_status("sync", status) + return local_gen + + def list_indexes(self): + a_list = [] + 
handle_status("list_indexes", + u1db_list_indexes(self._db, a_list, + _append_index_definition_to_list)) + return a_list + + def delete_index(self, index_name): + handle_status("delete_index", + u1db_delete_index(self._db, index_name)) + + def get_from_index_list(self, index_name, key_values): + cdef const_char_ptr *values + cdef int n_values + cdef CQuery query + + query = self._query_init(index_name) + res = [] + # keep a reference to new_objs so that the pointers in expressions + # remain valid. + new_objs = _list_to_str_array(key_values, &values, &n_values) + try: + handle_status( + "get_from_index", u1db_get_from_index_list( + self._db, query._query, res, _append_doc_to_list, + n_values, values)) + finally: + free(values) + return res + + def get_from_index(self, index_name, *key_values): + cdef CQuery query + cdef int status + + extra = [] + query = self._query_init(index_name) + res = [] + status = U1DB_OK + if len(key_values) == 0: + status = u1db_get_from_index(self._db, query._query, + res, _append_doc_to_list, 0, NULL) + elif len(key_values) == 1: + status = u1db_get_from_index(self._db, query._query, + res, _append_doc_to_list, 1, + _ensure_str(key_values[0], extra)) + elif len(key_values) == 2: + status = u1db_get_from_index(self._db, query._query, + res, _append_doc_to_list, 2, + _ensure_str(key_values[0], extra), + _ensure_str(key_values[1], extra)) + elif len(key_values) == 3: + status = u1db_get_from_index(self._db, query._query, + res, _append_doc_to_list, 3, + _ensure_str(key_values[0], extra), + _ensure_str(key_values[1], extra), + _ensure_str(key_values[2], extra)) + elif len(key_values) == 4: + status = u1db_get_from_index(self._db, query._query, + res, _append_doc_to_list, 4, + _ensure_str(key_values[0], extra), + _ensure_str(key_values[1], extra), + _ensure_str(key_values[2], extra), + _ensure_str(key_values[3], extra)) + else: + status = U1DB_NOT_IMPLEMENTED + handle_status("get_from_index", status) + return res + + def 
get_range_from_index(self, index_name, start_value=None, + end_value=None): + cdef CQuery query + cdef const_char_ptr *start_values + cdef int n_values + cdef const_char_ptr *end_values + + if start_value is not None: + if isinstance(start_value, basestring): + start_value = (start_value,) + new_objs_1 = _list_to_str_array( + start_value, &start_values, &n_values) + else: + n_values = 0 + start_values = NULL + if end_value is not None: + if isinstance(end_value, basestring): + end_value = (end_value,) + new_objs_2 = _list_to_str_array( + end_value, &end_values, &n_values) + else: + end_values = NULL + query = self._query_init(index_name) + res = [] + try: + handle_status("get_range_from_index", + u1db_get_range_from_index( + self._db, query._query, res, _append_doc_to_list, + n_values, start_values, end_values)) + finally: + if start_values != NULL: + free(start_values) + if end_values != NULL: + free(end_values) + return res + + def get_index_keys(self, index_name): + cdef int status + keys = [] + status = U1DB_OK + status = u1db_get_index_keys( + self._db, index_name, keys, _append_key_to_list) + handle_status("get_index_keys", status) + return keys + + def _query_init(self, index_name): + cdef CQuery query + query = CQuery() + handle_status("query_init", + u1db_query_init(self._db, index_name, &query._query)) + return query + + def get_sync_target(self): + cdef CSyncTarget target + target = CSyncTarget() + target._db = self + handle_status("get_sync_target", + u1db__get_sync_target(target._db._db, &target._st)) + return target + + +cdef class VectorClockRev: + + cdef u1db_vectorclock *_clock + + def __init__(self, s): + if s is None: + self._clock = u1db__vectorclock_from_str(NULL) + else: + self._clock = u1db__vectorclock_from_str(s) + + def __dealloc__(self): + u1db__free_vectorclock(&self._clock) + + def __repr__(self): + cdef int status + cdef char *res + if self._clock == NULL: + return '%s(None)' % (self.__class__.__name__,) + status = 
u1db__vectorclock_as_str(self._clock, &res) + if status != U1DB_OK: + return '%s()' % (status,) + if res == NULL: + val = '%s(NULL)' % (self.__class__.__name__,) + else: + val = '%s(%s)' % (self.__class__.__name__, res) + free(res) + return val + + def as_dict(self): + cdef u1db_vectorclock *cur + cdef int i + cdef int gen + if self._clock == NULL: + return None + res = {} + for i from 0 <= i < self._clock.num_items: + gen = self._clock.items[i].generation + res[self._clock.items[i].replica_uid] = gen + return res + + def as_str(self): + cdef int status + cdef char *res + + status = u1db__vectorclock_as_str(self._clock, &res) + if status != U1DB_OK: + raise RuntimeError("Failed to VectorClockRev.as_str(): %d" % (status,)) + if res == NULL: + s = None + else: + s = res + free(res) + return s + + def increment(self, replica_uid): + cdef int status + + status = u1db__vectorclock_increment(self._clock, replica_uid) + if status != U1DB_OK: + raise RuntimeError("Failed to increment: %d" % (status,)) + + def maximize(self, vcr): + cdef int status + cdef VectorClockRev other + + other = vcr + status = u1db__vectorclock_maximize(self._clock, other._clock) + if status != U1DB_OK: + raise RuntimeError("Failed to maximize: %d" % (status,)) + + def is_newer(self, vcr): + cdef int is_newer + cdef VectorClockRev other + + other = vcr + is_newer = u1db__vectorclock_is_newer(self._clock, other._clock) + if is_newer == 0: + return False + elif is_newer == 1: + return True + else: + raise RuntimeError("Failed to is_newer: %d" % (is_newer,)) + + +def sync_db_to_target(db, target): + """Sync the data between a CDatabase and a CSyncTarget""" + cdef CDatabase cdb + cdef CSyncTarget ctarget + cdef int local_gen = 0, status + + cdb = db + ctarget = target + with nogil: + status = u1db__sync_db_to_target(cdb._db, ctarget._st, &local_gen) + handle_status("sync_db_to_target", status) + return local_gen + + +def create_http_sync_target(url): + cdef CSyncTarget target + + target = CSyncTarget() 
+ handle_status("create_http_sync_target", + u1db__create_http_sync_target(url, &target._st)) + return target + + +def create_oauth_http_sync_target(url, consumer_key, consumer_secret, + token_key, token_secret): + cdef CSyncTarget target + + target = CSyncTarget() + handle_status("create_http_sync_target", + u1db__create_oauth_http_sync_target(url, consumer_key, consumer_secret, + token_key, token_secret, + &target._st)) + return target + + +def _format_sync_url(target, source_replica_uid): + cdef CSyncTarget st + cdef char *sync_url = NULL + cdef object res + st = target + handle_status("format_sync_url", + u1db__format_sync_url(st._st, source_replica_uid, &sync_url)) + if sync_url == NULL: + res = None + else: + res = sync_url + free(sync_url) + return res + + +def _get_oauth_authorization(target, method, url): + cdef CSyncTarget st + cdef char *auth = NULL + + st = target + handle_status("get_oauth_authorization", + u1db__get_oauth_authorization(st._st, method, url, &auth)) + res = None + if auth != NULL: + res = auth + free(auth) + return res diff --git a/src/leap/soledad/u1db/tests/commandline/__init__.py b/src/leap/soledad/u1db/tests/commandline/__init__.py new file mode 100644 index 00000000..007cecd3 --- /dev/null +++ b/src/leap/soledad/u1db/tests/commandline/__init__.py @@ -0,0 +1,47 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
import errno
import time


def safe_close(process, timeout=0.1):
    """Shut down the process in the nicest fashion you can manage.

    The process is first asked to terminate.  If it is still alive once
    'timeout' seconds have elapsed it is killed and then waited on.

    :param process: A subprocess.Popen object.
    :param timeout: We'll try to send 'SIGTERM' but if the process is alive
        longer than 'timeout' (in seconds), we'll send SIGKILL.
    :return: None.
    """
    if process.poll() is not None:
        # Already exited, nothing to do.
        return
    try:
        process.terminate()
    except OSError as e:
        if e.errno == errno.ESRCH:
            # Process has exited
            return
        # Any other OSError is ignored here; we fall through to the polling
        # loop and, if needed, to kill() below.
    tend = time.time() + timeout
    while time.time() < tend:
        if process.poll() is not None:
            return
        # Short sleep so that we do not spin while waiting for the exit.
        time.sleep(0.01)
    try:
        process.kill()
    except OSError as e:
        if e.errno == errno.ESRCH:
            # Process has exited
            return
        # Any other OSError is ignored here and we still wait() below.
    # Collect the exit status of the killed process.
    process.wait()


# ---------------------------------------------------------------------------
# Patch header of the next file, carried over from the original line:
#
# diff --git a/src/leap/soledad/u1db/tests/commandline/test_client.py b/src/leap/soledad/u1db/tests/commandline/test_client.py
# new file mode 100644
# index 00000000..78ca21eb
# --- /dev/null
# +++ b/src/leap/soledad/u1db/tests/commandline/test_client.py
# @@ -0,0 +1,916 @@
# +# Copyright 2011 Canonical Ltd.
# +#
# +# This file is part of u1db.
# +#
# +# u1db is free software: you can redistribute it and/or modify
# +# it under the terms of the GNU Lesser General Public License version 3
# +# as published by the Free Software Foundation.
# +#
# +# u1db is distributed in the hope that it will be useful,
# +# but WITHOUT ANY WARRANTY; without even the implied warranty of
# +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# +# GNU Lesser General Public License for more details.
# +#
# +# You should have received a copy of the GNU Lesser General Public License
# +# along with u1db.  If not, see <http://www.gnu.org/licenses/>.
+ +import cStringIO +import os +import sys +try: + import simplejson as json +except ImportError: + import json # noqa +import subprocess + +from u1db import ( + errors, + open as u1db_open, + tests, + vectorclock, + ) +from u1db.commandline import ( + client, + serve, + ) +from u1db.tests.commandline import safe_close +from u1db.tests import test_remote_sync_target + + +class TestArgs(tests.TestCase): + """These tests are meant to test just the argument parsing. + + Each Command should have at least one test, possibly more if it allows + optional arguments, etc. + """ + + def setUp(self): + super(TestArgs, self).setUp() + self.parser = client.client_commands.make_argparser() + + def parse_args(self, args): + # ArgumentParser.parse_args doesn't play very nicely with a test suite, + # so we trap SystemExit in case something is wrong with the args we're + # parsing. + try: + return self.parser.parse_args(args) + except SystemExit: + raise AssertionError('got SystemExit') + + def test_create(self): + args = self.parse_args(['create', 'test.db']) + self.assertEqual(client.CmdCreate, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual(None, args.doc_id) + self.assertEqual(None, args.infile) + + def test_create_custom_doc_id(self): + args = self.parse_args(['create', '--id', 'xyz', 'test.db']) + self.assertEqual(client.CmdCreate, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('xyz', args.doc_id) + self.assertEqual(None, args.infile) + + def test_delete(self): + args = self.parse_args(['delete', 'test.db', 'doc-id', 'doc-rev']) + self.assertEqual(client.CmdDelete, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('doc-id', args.doc_id) + self.assertEqual('doc-rev', args.doc_rev) + + def test_get(self): + args = self.parse_args(['get', 'test.db', 'doc-id']) + self.assertEqual(client.CmdGet, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('doc-id', 
args.doc_id) + self.assertEqual(None, args.outfile) + + def test_get_dash(self): + args = self.parse_args(['get', 'test.db', 'doc-id', '-']) + self.assertEqual(client.CmdGet, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('doc-id', args.doc_id) + self.assertEqual(sys.stdout, args.outfile) + + def test_init_db(self): + args = self.parse_args( + ['init-db', 'test.db', '--replica-uid=replica-uid']) + self.assertEqual(client.CmdInitDB, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('replica-uid', args.replica_uid) + + def test_init_db_no_replica(self): + args = self.parse_args(['init-db', 'test.db']) + self.assertEqual(client.CmdInitDB, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertIs(None, args.replica_uid) + + def test_put(self): + args = self.parse_args(['put', 'test.db', 'doc-id', 'old-doc-rev']) + self.assertEqual(client.CmdPut, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('doc-id', args.doc_id) + self.assertEqual('old-doc-rev', args.doc_rev) + self.assertEqual(None, args.infile) + + def test_sync(self): + args = self.parse_args(['sync', 'source', 'target']) + self.assertEqual(client.CmdSync, args.subcommand) + self.assertEqual('source', args.source) + self.assertEqual('target', args.target) + + def test_create_index(self): + args = self.parse_args(['create-index', 'db', 'index', 'expression']) + self.assertEqual(client.CmdCreateIndex, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('index', args.index) + self.assertEqual(['expression'], args.expression) + + def test_create_index_multi_expression(self): + args = self.parse_args(['create-index', 'db', 'index', 'e1', 'e2']) + self.assertEqual(client.CmdCreateIndex, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('index', args.index) + self.assertEqual(['e1', 'e2'], args.expression) + + def test_list_indexes(self): + args = 
self.parse_args(['list-indexes', 'db']) + self.assertEqual(client.CmdListIndexes, args.subcommand) + self.assertEqual('db', args.database) + + def test_delete_index(self): + args = self.parse_args(['delete-index', 'db', 'index']) + self.assertEqual(client.CmdDeleteIndex, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('index', args.index) + + def test_get_index_keys(self): + args = self.parse_args(['get-index-keys', 'db', 'index']) + self.assertEqual(client.CmdGetIndexKeys, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('index', args.index) + + def test_get_from_index(self): + args = self.parse_args(['get-from-index', 'db', 'index', 'foo']) + self.assertEqual(client.CmdGetFromIndex, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('index', args.index) + self.assertEqual(['foo'], args.values) + + def test_get_doc_conflicts(self): + args = self.parse_args(['get-doc-conflicts', 'db', 'doc-id']) + self.assertEqual(client.CmdGetDocConflicts, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('doc-id', args.doc_id) + + def test_resolve(self): + args = self.parse_args( + ['resolve-doc', 'db', 'doc-id', 'rev:1', 'other:1']) + self.assertEqual(client.CmdResolve, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('doc-id', args.doc_id) + self.assertEqual(['rev:1', 'other:1'], args.doc_revs) + self.assertEqual(None, args.infile) + + +class TestCaseWithDB(tests.TestCase): + """These next tests are meant to have one class per Command. + + It is meant to test the inner workings of each command. The detailed + testing should happen in these classes. Stuff like how it handles errors, + etc. should be done here. 
+ """ + + def setUp(self): + super(TestCaseWithDB, self).setUp() + self.working_dir = self.createTempDir() + self.db_path = self.working_dir + '/test.db' + self.db = u1db_open(self.db_path, create=True) + self.db._set_replica_uid('test') + self.addCleanup(self.db.close) + + def make_command(self, cls, stdin_content=''): + inf = cStringIO.StringIO(stdin_content) + out = cStringIO.StringIO() + err = cStringIO.StringIO() + return cls(inf, out, err) + + +class TestCmdCreate(TestCaseWithDB): + + def test_create(self): + cmd = self.make_command(client.CmdCreate) + inf = cStringIO.StringIO(tests.simple_doc) + cmd.run(self.db_path, inf, 'test-id') + doc = self.db.get_doc('test-id') + self.assertEqual(tests.simple_doc, doc.get_json()) + self.assertFalse(doc.has_conflicts) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('id: test-id\nrev: %s\n' % (doc.rev,), + cmd.stderr.getvalue()) + + +class TestCmdDelete(TestCaseWithDB): + + def test_delete(self): + doc = self.db.create_doc_from_json(tests.simple_doc) + cmd = self.make_command(client.CmdDelete) + cmd.run(self.db_path, doc.doc_id, doc.rev) + doc2 = self.db.get_doc(doc.doc_id, include_deleted=True) + self.assertEqual(doc.doc_id, doc2.doc_id) + self.assertNotEqual(doc.rev, doc2.rev) + self.assertIs(None, doc2.get_json()) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('rev: %s\n' % (doc2.rev,), cmd.stderr.getvalue()) + + def test_delete_fails_if_nonexistent(self): + doc = self.db.create_doc_from_json(tests.simple_doc) + db2_path = self.db_path + '.typo' + cmd = self.make_command(client.CmdDelete) + # TODO: We should really not be showing a traceback here. But we need + # to teach the commandline infrastructure how to handle + # exceptions. + # However, we *do* want to test that the db doesn't get created + # by accident. 
+ self.assertRaises(errors.DatabaseDoesNotExist, + cmd.run, db2_path, doc.doc_id, doc.rev) + self.assertFalse(os.path.exists(db2_path)) + + def test_delete_no_such_doc(self): + cmd = self.make_command(client.CmdDelete) + # TODO: We should really not be showing a traceback here. But we need + # to teach the commandline infrastructure how to handle + # exceptions. + self.assertRaises(errors.DocumentDoesNotExist, + cmd.run, self.db_path, 'no-doc-id', 'no-rev') + + def test_delete_bad_rev(self): + doc = self.db.create_doc_from_json(tests.simple_doc) + cmd = self.make_command(client.CmdDelete) + self.assertRaises(errors.RevisionConflict, + cmd.run, self.db_path, doc.doc_id, 'not-the-actual-doc-rev:1') + # TODO: Test that we get a pretty output. + + +class TestCmdGet(TestCaseWithDB): + + def setUp(self): + super(TestCmdGet, self).setUp() + self.doc = self.db.create_doc_from_json( + tests.simple_doc, doc_id='my-test-doc') + + def test_get_simple(self): + cmd = self.make_command(client.CmdGet) + cmd.run(self.db_path, 'my-test-doc', None) + self.assertEqual(tests.simple_doc + "\n", cmd.stdout.getvalue()) + self.assertEqual('rev: %s\n' % (self.doc.rev,), + cmd.stderr.getvalue()) + + def test_get_conflict(self): + doc = self.make_document('my-test-doc', 'other:1', '{}', False) + self.db._put_doc_if_newer( + doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + cmd = self.make_command(client.CmdGet) + cmd.run(self.db_path, 'my-test-doc', None) + self.assertEqual('{}\n', cmd.stdout.getvalue()) + self.assertEqual('rev: %s\nDocument has conflicts.\n' % (doc.rev,), + cmd.stderr.getvalue()) + + def test_get_fail(self): + cmd = self.make_command(client.CmdGet) + result = cmd.run(self.db_path, 'doc-not-there', None) + self.assertEqual(1, result) + self.assertEqual("", cmd.stdout.getvalue()) + self.assertTrue("not found" in cmd.stderr.getvalue()) + + def test_get_no_database(self): + cmd = self.make_command(client.CmdGet) + retval = cmd.run(self.db_path 
+ "__DOES_NOT_EXIST", "my-doc", None) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + +class TestCmdGetDocConflicts(TestCaseWithDB): + + def setUp(self): + super(TestCmdGetDocConflicts, self).setUp() + self.doc1 = self.db.create_doc_from_json( + tests.simple_doc, doc_id='my-doc') + self.doc2 = self.make_document('my-doc', 'other:1', '{}', False) + self.db._put_doc_if_newer( + self.doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + + def test_get_doc_conflicts_none(self): + self.db.create_doc_from_json(tests.simple_doc, doc_id='a-doc') + cmd = self.make_command(client.CmdGetDocConflicts) + cmd.run(self.db_path, 'a-doc') + self.assertEqual([], json.loads(cmd.stdout.getvalue())) + self.assertEqual('', cmd.stderr.getvalue()) + + def test_get_doc_conflicts_simple(self): + cmd = self.make_command(client.CmdGetDocConflicts) + cmd.run(self.db_path, 'my-doc') + self.assertEqual( + [dict(rev=self.doc2.rev, content=self.doc2.content), + dict(rev=self.doc1.rev, content=self.doc1.content)], + json.loads(cmd.stdout.getvalue())) + self.assertEqual('', cmd.stderr.getvalue()) + + def test_get_doc_conflicts_no_db(self): + cmd = self.make_command(client.CmdGetDocConflicts) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc") + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_get_doc_conflicts_no_doc(self): + cmd = self.make_command(client.CmdGetDocConflicts) + retval = cmd.run(self.db_path, "some-doc") + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Document does not exist.\n') + + +class TestCmdInit(TestCaseWithDB): + + def test_init_new(self): + path = self.working_dir + '/test2.db' + self.assertFalse(os.path.exists(path)) + cmd = 
self.make_command(client.CmdInitDB) + cmd.run(path, 'test-uid') + self.assertTrue(os.path.exists(path)) + db = u1db_open(path, create=False) + self.assertEqual('test-uid', db._replica_uid) + + def test_init_no_uid(self): + path = self.working_dir + '/test2.db' + cmd = self.make_command(client.CmdInitDB) + cmd.run(path, None) + self.assertTrue(os.path.exists(path)) + db = u1db_open(path, create=False) + self.assertIsNot(None, db._replica_uid) + + +class TestCmdPut(TestCaseWithDB): + + def setUp(self): + super(TestCmdPut, self).setUp() + self.doc = self.db.create_doc_from_json( + tests.simple_doc, doc_id='my-test-doc') + + def test_put_simple(self): + cmd = self.make_command(client.CmdPut) + inf = cStringIO.StringIO(tests.nested_doc) + cmd.run(self.db_path, 'my-test-doc', self.doc.rev, inf) + doc = self.db.get_doc('my-test-doc') + self.assertNotEqual(self.doc.rev, doc.rev) + self.assertGetDoc(self.db, 'my-test-doc', doc.rev, + tests.nested_doc, False) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('rev: %s\n' % (doc.rev,), + cmd.stderr.getvalue()) + + def test_put_no_db(self): + cmd = self.make_command(client.CmdPut) + inf = cStringIO.StringIO(tests.nested_doc) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", + 'my-test-doc', self.doc.rev, inf) + self.assertEqual(retval, 1) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('Database does not exist.\n', cmd.stderr.getvalue()) + + def test_put_no_doc(self): + cmd = self.make_command(client.CmdPut) + inf = cStringIO.StringIO(tests.nested_doc) + retval = cmd.run(self.db_path, 'no-such-doc', 'wut:1', inf) + self.assertEqual(1, retval) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('Document does not exist.\n', cmd.stderr.getvalue()) + + def test_put_doc_old_rev(self): + rev = self.doc.rev + doc = self.make_document('my-test-doc', rev, '{}', False) + self.db.put_doc(doc) + cmd = self.make_command(client.CmdPut) + inf = cStringIO.StringIO(tests.nested_doc) + retval = 
cmd.run(self.db_path, 'my-test-doc', rev, inf) + self.assertEqual(1, retval) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('Given revision is not current.\n', + cmd.stderr.getvalue()) + + def test_put_doc_w_conflicts(self): + doc = self.make_document('my-test-doc', 'other:1', '{}', False) + self.db._put_doc_if_newer( + doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + cmd = self.make_command(client.CmdPut) + inf = cStringIO.StringIO(tests.nested_doc) + retval = cmd.run(self.db_path, 'my-test-doc', 'other:1', inf) + self.assertEqual(1, retval) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('Document has conflicts.\n' + 'Inspect with get-doc-conflicts, then resolve.\n', + cmd.stderr.getvalue()) + + +class TestCmdResolve(TestCaseWithDB): + + def setUp(self): + super(TestCmdResolve, self).setUp() + self.doc1 = self.db.create_doc_from_json( + tests.simple_doc, doc_id='my-doc') + self.doc2 = self.make_document('my-doc', 'other:1', '{}', False) + self.db._put_doc_if_newer( + self.doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + + def test_resolve_simple(self): + self.assertTrue(self.db.get_doc('my-doc').has_conflicts) + cmd = self.make_command(client.CmdResolve) + inf = cStringIO.StringIO(tests.nested_doc) + cmd.run(self.db_path, 'my-doc', [self.doc1.rev, self.doc2.rev], inf) + doc = self.db.get_doc('my-doc') + vec = vectorclock.VectorClockRev(doc.rev) + self.assertTrue( + vec.is_newer(vectorclock.VectorClockRev(self.doc1.rev))) + self.assertTrue( + vec.is_newer(vectorclock.VectorClockRev(self.doc2.rev))) + self.assertGetDoc(self.db, 'my-doc', doc.rev, tests.nested_doc, False) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('rev: %s\n' % (doc.rev,), + cmd.stderr.getvalue()) + + def test_resolve_double(self): + moar = '{"x": 42}' + doc3 = self.make_document('my-doc', 'third:1', moar, False) + self.db._put_doc_if_newer( + doc3, save_conflict=True, 
replica_uid='r', replica_gen=1, + replica_trans_id='foo') + cmd = self.make_command(client.CmdResolve) + inf = cStringIO.StringIO(tests.nested_doc) + cmd.run(self.db_path, 'my-doc', [self.doc1.rev, self.doc2.rev], inf) + doc = self.db.get_doc('my-doc') + self.assertGetDoc(self.db, 'my-doc', doc.rev, moar, True) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual( + 'rev: %s\nDocument still has conflicts.\n' % (doc.rev,), + cmd.stderr.getvalue()) + + def test_resolve_no_db(self): + cmd = self.make_command(client.CmdResolve) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc", [], None) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_resolve_no_doc(self): + cmd = self.make_command(client.CmdResolve) + retval = cmd.run(self.db_path, "foo", [], None) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Document does not exist.\n') + + +class TestCmdSync(TestCaseWithDB): + + def setUp(self): + super(TestCmdSync, self).setUp() + self.db2_path = self.working_dir + '/test2.db' + self.db2 = u1db_open(self.db2_path, create=True) + self.addCleanup(self.db2.close) + self.db2._set_replica_uid('test2') + self.doc = self.db.create_doc_from_json( + tests.simple_doc, doc_id='test-id') + self.doc2 = self.db2.create_doc_from_json( + tests.nested_doc, doc_id='my-test-id') + + def test_sync(self): + cmd = self.make_command(client.CmdSync) + cmd.run(self.db_path, self.db2_path) + self.assertGetDoc(self.db2, 'test-id', self.doc.rev, tests.simple_doc, + False) + self.assertGetDoc(self.db, 'my-test-id', self.doc2.rev, + tests.nested_doc, False) + + +class TestCmdSyncRemote(tests.TestCaseWithServer, TestCaseWithDB): + + make_app_with_state = \ + staticmethod(test_remote_sync_target.make_http_app) + + def setUp(self): + super(TestCmdSyncRemote, self).setUp() + self.startServer() + self.db2 = 
self.request_state._create_database('test2.db') + + def test_sync_remote(self): + doc1 = self.db.create_doc_from_json(tests.simple_doc) + doc2 = self.db2.create_doc_from_json(tests.nested_doc) + db2_url = self.getURL('test2.db') + self.assertTrue(db2_url.startswith('http://')) + self.assertTrue(db2_url.endswith('/test2.db')) + cmd = self.make_command(client.CmdSync) + cmd.run(self.db_path, db2_url) + self.assertGetDoc(self.db2, doc1.doc_id, doc1.rev, tests.simple_doc, + False) + self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc, + False) + + +class TestCmdCreateIndex(TestCaseWithDB): + + def test_create_index(self): + cmd = self.make_command(client.CmdCreateIndex) + retval = cmd.run(self.db_path, "foo", ["bar", "baz"]) + self.assertEqual(self.db.list_indexes(), [('foo', ['bar', "baz"])]) + self.assertEqual(retval, None) # conveniently mapped to 0 + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_create_index_no_db(self): + cmd = self.make_command(client.CmdCreateIndex) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo", ["bar"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_create_dupe_index(self): + self.db.create_index("foo", "bar") + cmd = self.make_command(client.CmdCreateIndex) + retval = cmd.run(self.db_path, "foo", ["bar"]) + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_create_dupe_index_different_expression(self): + self.db.create_index("foo", "bar") + cmd = self.make_command(client.CmdCreateIndex) + retval = cmd.run(self.db_path, "foo", ["baz"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), + "There is already a different index named 'foo'.\n") + + def test_create_index_bad_expression(self): + cmd = 
self.make_command(client.CmdCreateIndex) + retval = cmd.run(self.db_path, "foo", ["WAT()"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), + 'Bad index expression.\n') + + +class TestCmdListIndexes(TestCaseWithDB): + + def test_list_no_indexes(self): + cmd = self.make_command(client.CmdListIndexes) + retval = cmd.run(self.db_path) + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_list_indexes(self): + self.db.create_index("foo", "bar", "baz") + cmd = self.make_command(client.CmdListIndexes) + retval = cmd.run(self.db_path) + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), 'foo: bar, baz\n') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_list_several_indexes(self): + self.db.create_index("foo", "bar", "baz") + self.db.create_index("bar", "baz", "foo") + self.db.create_index("baz", "foo", "bar") + cmd = self.make_command(client.CmdListIndexes) + retval = cmd.run(self.db_path) + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), + 'bar: baz, foo\n' + 'baz: foo, bar\n' + 'foo: bar, baz\n' + ) + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_list_indexes_no_db(self): + cmd = self.make_command(client.CmdListIndexes) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST") + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + +class TestCmdDeleteIndex(TestCaseWithDB): + + def test_delete_index(self): + self.db.create_index("foo", "bar", "baz") + cmd = self.make_command(client.CmdDeleteIndex) + retval = cmd.run(self.db_path, "foo") + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + self.assertEqual([], self.db.list_indexes()) + + def test_delete_index_no_db(self): + cmd = 
self.make_command(client.CmdDeleteIndex) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo") + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_delete_index_no_index(self): + cmd = self.make_command(client.CmdDeleteIndex) + retval = cmd.run(self.db_path, "foo") + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + + +class TestCmdGetIndexKeys(TestCaseWithDB): + + def test_get_index_keys(self): + self.db.create_index("foo", "bar") + self.db.create_doc_from_json('{"bar": 42}') + cmd = self.make_command(client.CmdGetIndexKeys) + retval = cmd.run(self.db_path, "foo") + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '42\n') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_get_index_keys_nonascii(self): + self.db.create_index("foo", "bar") + self.db.create_doc_from_json('{"bar": "\u00a4"}') + cmd = self.make_command(client.CmdGetIndexKeys) + retval = cmd.run(self.db_path, "foo") + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '\xc2\xa4\n') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_get_index_keys_empty(self): + self.db.create_index("foo", "bar") + cmd = self.make_command(client.CmdGetIndexKeys) + retval = cmd.run(self.db_path, "foo") + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_get_index_keys_no_db(self): + cmd = self.make_command(client.CmdGetIndexKeys) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo") + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_get_index_keys_no_index(self): + cmd = self.make_command(client.CmdGetIndexKeys) + retval = cmd.run(self.db_path, "foo") + 
self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Index does not exist.\n') + + +class TestCmdGetFromIndex(TestCaseWithDB): + + def test_get_from_index(self): + self.db.create_index("index", "key") + doc1 = self.db.create_doc_from_json(tests.simple_doc) + doc2 = self.db.create_doc_from_json(tests.nested_doc) + cmd = self.make_command(client.CmdGetFromIndex) + retval = cmd.run(self.db_path, "index", ["value"]) + self.assertEqual(retval, None) + self.assertEqual(sorted(json.loads(cmd.stdout.getvalue())), + sorted([dict(id=doc1.doc_id, + rev=doc1.rev, + content=doc1.content), + dict(id=doc2.doc_id, + rev=doc2.rev, + content=doc2.content), + ])) + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_get_from_index_empty(self): + self.db.create_index("index", "key") + cmd = self.make_command(client.CmdGetFromIndex) + retval = cmd.run(self.db_path, "index", ["value"]) + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '[]\n') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_get_from_index_no_db(self): + cmd = self.make_command(client.CmdGetFromIndex) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo", []) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_get_from_index_no_index(self): + cmd = self.make_command(client.CmdGetFromIndex) + retval = cmd.run(self.db_path, "foo", []) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Index does not exist.\n') + + def test_get_from_index_two_expr_instead_of_one(self): + self.db.create_index("index", "key1") + cmd = self.make_command(client.CmdGetFromIndex) + cmd.argv = ["XX", "YY"] + retval = cmd.run(self.db_path, "index", ["value1", "value2"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + 
self.assertEqual("Invalid query: index 'index' requires" + " 1 query expression, not 2.\n" + "For example, the following would be valid:\n" + " XX YY %r 'index' 'value1'\n" + % self.db_path, cmd.stderr.getvalue()) + + def test_get_from_index_three_expr_instead_of_two(self): + self.db.create_index("index", "key1", "key2") + cmd = self.make_command(client.CmdGetFromIndex) + cmd.argv = ["XX", "YY"] + retval = cmd.run(self.db_path, "index", ["value1", "value2", "value3"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual("Invalid query: index 'index' requires" + " 2 query expressions, not 3.\n" + "For example, the following would be valid:\n" + " XX YY %r 'index' 'value1' 'value2'\n" + % self.db_path, cmd.stderr.getvalue()) + + def test_get_from_index_one_expr_instead_of_two(self): + self.db.create_index("index", "key1", "key2") + cmd = self.make_command(client.CmdGetFromIndex) + cmd.argv = ["XX", "YY"] + retval = cmd.run(self.db_path, "index", ["value1"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual("Invalid query: index 'index' requires" + " 2 query expressions, not 1.\n" + "For example, the following would be valid:\n" + " XX YY %r 'index' 'value1' '*'\n" + % self.db_path, cmd.stderr.getvalue()) + + def test_get_from_index_cant_bad_glob(self): + self.db.create_index("index", "key1", "key2") + cmd = self.make_command(client.CmdGetFromIndex) + cmd.argv = ["XX", "YY"] + retval = cmd.run(self.db_path, "index", ["value1*", "value2"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual("Invalid query:" + " a star can only be followed by stars.\n" + "For example, the following would be valid:\n" + " XX YY %r 'index' 'value1*' '*'\n" + % self.db_path, cmd.stderr.getvalue()) + + +class RunMainHelper(object): + + def run_main(self, args, stdin=None): + if stdin is not None: + self.patch(sys, 'stdin', cStringIO.StringIO(stdin)) + stdout = 
cStringIO.StringIO() + stderr = cStringIO.StringIO() + self.patch(sys, 'stdout', stdout) + self.patch(sys, 'stderr', stderr) + try: + ret = client.main(args) + except SystemExit, e: + self.fail("Intercepted SystemExit: %s" % (e,)) + if ret is None: + ret = 0 + return ret, stdout.getvalue(), stderr.getvalue() + + +class TestCommandLine(TestCaseWithDB, RunMainHelper): + """These are meant to test that the infrastructure is fully connected. + + Each command is likely to only have one test here. Something that ensures + 'main()' knows about and can run the command correctly. Most logic-level + testing of the Command should go into its own test class above. + """ + + def _get_u1db_client_path(self): + from u1db import __path__ as u1db_path + u1db_parent_dir = os.path.dirname(u1db_path[0]) + return os.path.join(u1db_parent_dir, 'u1db-client') + + def runU1DBClient(self, args): + command = [sys.executable, self._get_u1db_client_path()] + command.extend(args) + p = subprocess.Popen(command, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self.addCleanup(safe_close, p) + return p + + def test_create_subprocess(self): + p = self.runU1DBClient(['create', '--id', 'test-id', self.db_path]) + stdout, stderr = p.communicate(tests.simple_doc) + self.assertEqual(0, p.returncode) + self.assertEqual('', stdout) + doc = self.db.get_doc('test-id') + self.assertEqual(tests.simple_doc, doc.get_json()) + self.assertFalse(doc.has_conflicts) + expected = 'id: test-id\nrev: %s\n' % (doc.rev,) + stripped = stderr.replace('\r\n', '\n') + if expected != stripped: + # When run under python-dbg, it prints out the refs after the + # actual content, so match it if we need to. 
+ expected_re = expected + '\[\d+ refs\]\n' + self.assertRegexpMatches(stripped, expected_re) + + def test_get(self): + doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') + ret, stdout, stderr = self.run_main(['get', self.db_path, 'test-id']) + self.assertEqual(0, ret) + self.assertEqual(tests.simple_doc + "\n", stdout) + self.assertEqual('rev: %s\n' % (doc.rev,), stderr) + ret, stdout, stderr = self.run_main(['get', self.db_path, 'not-there']) + self.assertEqual(1, ret) + + def test_delete(self): + doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') + ret, stdout, stderr = self.run_main( + ['delete', self.db_path, 'test-id', doc.rev]) + doc = self.db.get_doc('test-id', include_deleted=True) + self.assertEqual(0, ret) + self.assertEqual('', stdout) + self.assertEqual('rev: %s\n' % (doc.rev,), stderr) + + def test_init_db(self): + path = self.working_dir + '/test2.db' + ret, stdout, stderr = self.run_main(['init-db', path]) + u1db_open(path, create=False) + + def test_put(self): + doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') + ret, stdout, stderr = self.run_main( + ['put', self.db_path, 'test-id', doc.rev], + stdin=tests.nested_doc) + doc = self.db.get_doc('test-id') + self.assertFalse(doc.has_conflicts) + self.assertEqual(tests.nested_doc, doc.get_json()) + self.assertEqual(0, ret) + self.assertEqual('', stdout) + self.assertEqual('rev: %s\n' % (doc.rev,), stderr) + + def test_sync(self): + doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') + self.db2_path = self.working_dir + '/test2.db' + self.db2 = u1db_open(self.db2_path, create=True) + self.addCleanup(self.db2.close) + ret, stdout, stderr = self.run_main( + ['sync', self.db_path, self.db2_path]) + self.assertEqual(0, ret) + self.assertEqual('', stdout) + self.assertEqual('', stderr) + self.assertGetDoc( + self.db2, 'test-id', doc.rev, tests.simple_doc, False) + + +class TestHTTPIntegration(tests.TestCaseWithServer, 
RunMainHelper): + """Meant to test the cases where commands operate over http.""" + + def server_def(self): + def make_server(host_port, _application): + return serve.make_server(host_port[0], host_port[1], + self.working_dir) + return make_server, "shutdown", "http" + + def setUp(self): + super(TestHTTPIntegration, self).setUp() + self.working_dir = self.createTempDir(prefix='u1db-http-server-') + self.startServer() + + def getPath(self, dbname): + return os.path.join(self.working_dir, dbname) + + def test_init_db(self): + url = self.getURL('new.db') + ret, stdout, stderr = self.run_main(['init-db', url]) + u1db_open(self.getPath('new.db'), create=False) + + def test_create_get_put_delete(self): + db = u1db_open(self.getPath('test.db'), create=True) + url = self.getURL('test.db') + doc_id = '%abcd' + ret, stdout, stderr = self.run_main(['create', url, '--id', doc_id], + stdin=tests.simple_doc) + self.assertEqual(0, ret) + ret, stdout, stderr = self.run_main(['get', url, doc_id]) + self.assertEqual(0, ret) + self.assertTrue(stderr.startswith('rev: ')) + doc_rev = stderr[len('rev: '):].rstrip() + ret, stdout, stderr = self.run_main(['put', url, doc_id, doc_rev], + stdin=tests.nested_doc) + self.assertEqual(0, ret) + self.assertTrue(stderr.startswith('rev: ')) + doc_rev1 = stderr[len('rev: '):].rstrip() + self.assertGetDoc(db, doc_id, doc_rev1, tests.nested_doc, False) + ret, stdout, stderr = self.run_main(['delete', url, doc_id, doc_rev1]) + self.assertEqual(0, ret) + self.assertTrue(stderr.startswith('rev: ')) + doc_rev2 = stderr[len('rev: '):].rstrip() + self.assertGetDocIncludeDeleted(db, doc_id, doc_rev2, None, False) diff --git a/src/leap/soledad/u1db/tests/commandline/test_command.py b/src/leap/soledad/u1db/tests/commandline/test_command.py new file mode 100644 index 00000000..43580f23 --- /dev/null +++ b/src/leap/soledad/u1db/tests/commandline/test_command.py @@ -0,0 +1,105 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. 
+# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +import cStringIO +import argparse + +from u1db import ( + tests, + ) +from u1db.commandline import ( + command, + ) + + +class MyTestCommand(command.Command): + """Help String""" + + name = 'mycmd' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('foo') + parser.add_argument('--bar', dest='nbar', type=int) + + def run(self, foo, nbar): + self.stdout.write('foo: %s nbar: %d' % (foo, nbar)) + return 0 + + +def make_stdin_out_err(): + return cStringIO.StringIO(), cStringIO.StringIO(), cStringIO.StringIO() + + +class TestCommandGroup(tests.TestCase): + + def trap_system_exit(self, func, *args, **kwargs): + try: + return func(*args, **kwargs) + except SystemExit, e: + self.fail('Got SystemExit trying to run: %s' % (func,)) + + def parse_args(self, parser, args): + return self.trap_system_exit(parser.parse_args, args) + + def test_register(self): + group = command.CommandGroup() + self.assertEqual({}, group.commands) + group.register(MyTestCommand) + self.assertEqual({'mycmd': MyTestCommand}, + group.commands) + + def test_make_argparser(self): + group = command.CommandGroup(description='test-foo') + parser = group.make_argparser() + self.assertIsInstance(parser, argparse.ArgumentParser) + + def test_make_argparser_with_command(self): + group = command.CommandGroup(description='test-foo') + group.register(MyTestCommand) + parser = group.make_argparser() + args = self.parse_args(parser, ['mycmd', 
'foozizle', '--bar=10']) + self.assertEqual('foozizle', args.foo) + self.assertEqual(10, args.nbar) + self.assertEqual(MyTestCommand, args.subcommand) + + def test_run_argv(self): + group = command.CommandGroup() + group.register(MyTestCommand) + stdin, stdout, stderr = make_stdin_out_err() + ret = self.trap_system_exit(group.run_argv, + ['mycmd', 'foozizle', '--bar=10'], + stdin, stdout, stderr) + self.assertEqual(0, ret) + + +class TestCommand(tests.TestCase): + + def make_command(self): + stdin, stdout, stderr = make_stdin_out_err() + return command.Command(stdin, stdout, stderr) + + def test__init__(self): + cmd = self.make_command() + self.assertIsNot(None, cmd.stdin) + self.assertIsNot(None, cmd.stdout) + self.assertIsNot(None, cmd.stderr) + + def test_run_args(self): + stdin, stdout, stderr = make_stdin_out_err() + cmd = MyTestCommand(stdin, stdout, stderr) + res = cmd.run(foo='foozizle', nbar=10) + self.assertEqual('foo: foozizle nbar: 10', stdout.getvalue()) diff --git a/src/leap/soledad/u1db/tests/commandline/test_serve.py b/src/leap/soledad/u1db/tests/commandline/test_serve.py new file mode 100644 index 00000000..6397eabe --- /dev/null +++ b/src/leap/soledad/u1db/tests/commandline/test_serve.py @@ -0,0 +1,101 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
+ +import os +import socket +import subprocess +import sys + +from u1db import ( + __version__ as _u1db_version, + open as u1db_open, + tests, + ) +from u1db.remote import http_client +from u1db.tests.commandline import safe_close + + +class TestU1DBServe(tests.TestCase): + + def _get_u1db_serve_path(self): + from u1db import __path__ as u1db_path + u1db_parent_dir = os.path.dirname(u1db_path[0]) + return os.path.join(u1db_parent_dir, 'u1db-serve') + + def startU1DBServe(self, args): + command = [sys.executable, self._get_u1db_serve_path()] + command.extend(args) + p = subprocess.Popen(command, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self.addCleanup(safe_close, p) + return p + + def test_help(self): + p = self.startU1DBServe(['--help']) + stdout, stderr = p.communicate() + if stderr != '': + # stderr should normally be empty, but if we are running under + # python-dbg, it contains the following string + self.assertRegexpMatches(stderr, r'\[\d+ refs\]') + self.assertEqual(0, p.returncode) + self.assertIn('Run the U1DB server', stdout) + + def test_bind_to_port(self): + p = self.startU1DBServe([]) + starts = 'listening on:' + x = p.stdout.readline() + self.assertTrue(x.startswith(starts)) + port = int(x[len(starts):].split(":")[1]) + url = "http://127.0.0.1:%s/" % port + c = http_client.HTTPClientBase(url) + self.addCleanup(c.close) + res, _ = c._request_json('GET', []) + self.assertEqual({'version': _u1db_version}, res) + + def test_supply_port(self): + s = socket.socket() + s.bind(('127.0.0.1', 0)) + host, port = s.getsockname() + s.close() + p = self.startU1DBServe(['--port', str(port)]) + x = p.stdout.readline().strip() + self.assertEqual('listening on: 127.0.0.1:%s' % (port,), x) + url = "http://127.0.0.1:%s/" % port + c = http_client.HTTPClientBase(url) + self.addCleanup(c.close) + res, _ = c._request_json('GET', []) + self.assertEqual({'version': _u1db_version}, res) + + def test_bind_to_host(self): + p = 
self.startU1DBServe(["--host", "localhost"]) + starts = 'listening on: 127.0.0.1:' + x = p.stdout.readline() + self.assertTrue(x.startswith(starts)) + + def test_supply_working_dir(self): + tmp_dir = self.createTempDir('u1db-serve-test') + db = u1db_open(os.path.join(tmp_dir, 'landmark.db'), create=True) + db.close() + p = self.startU1DBServe(['--working-dir', tmp_dir]) + starts = 'listening on:' + x = p.stdout.readline() + self.assertTrue(x.startswith(starts)) + port = int(x[len(starts):].split(":")[1]) + url = "http://127.0.0.1:%s/landmark.db" % port + c = http_client.HTTPClientBase(url) + self.addCleanup(c.close) + res, _ = c._request_json('GET', []) + self.assertEqual({}, res) diff --git a/src/leap/soledad/u1db/tests/test_auth_middleware.py b/src/leap/soledad/u1db/tests/test_auth_middleware.py new file mode 100644 index 00000000..e765f8a7 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_auth_middleware.py @@ -0,0 +1,309 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
"""Test OAuth wsgi middleware"""
import paste.fixture
from oauth import oauth
try:
    import simplejson as json
except ImportError:
    import json  # noqa
import time

from u1db import tests

from u1db.remote.oauth_middleware import OAuthMiddleware
from u1db.remote.basic_auth_middleware import BasicAuthMiddleware, Unauthorized


BASE_URL = 'https://example.net'


def _query_string(params):
    """Render *params* as ``key=value`` pairs joined by ``&`` (no quoting)."""
    return '&'.join("%s=%s" % item for item in params.items())


def _basic_auth_headers(user, password):
    """Build an ``Authorization: Basic`` header dict for *user*/*password*."""
    credentials = '%s:%s' % (user, password)
    # NOTE: the 'base64' str codec only exists on Python 2.
    return {'Authorization': 'Basic %s' % (credentials.encode('base64'),)}


def _delete_request(url, params, consumer, token):
    """Create an unsigned OAuth DELETE request for *url* carrying *params*."""
    return oauth.OAuthRequest.from_consumer_and_token(
        consumer,
        token,
        parameters=params,
        http_url=url,
        http_method='DELETE')


class TestBasicAuthMiddleware(tests.TestCase):
    """Exercise ``BasicAuthMiddleware`` configured with prefix ``/pfx/``."""

    def setUp(self):
        """Wrap a recording WSGI app in a credential-checking middleware."""
        super(TestBasicAuthMiddleware, self).setUp()
        self.got = []

        def witness_app(environ, start_response):
            # Record what reached the wrapped app so tests can inspect it.
            start_response("200 OK", [("content-type", "text/plain")])
            seen = (environ['user_id'], environ['PATH_INFO'],
                    environ['QUERY_STRING'])
            self.got.append(seen)
            return ["ok"]

        class MyAuthMiddleware(BasicAuthMiddleware):

            def verify_user(self, environ, user, password):
                credentials_ok = (user == "correct_user" and
                                  password == "correct_password")
                if not credentials_ok:
                    raise Unauthorized
                environ['user_id'] = user

        self.auth_midw = MyAuthMiddleware(witness_app, prefix="/pfx/")
        self.app = paste.fixture.TestApp(self.auth_midw)

    def test_expect_prefix(self):
        """A path without the ``/pfx/`` prefix is answered with 400."""
        resp = self.app.delete(BASE_URL + '/foo/doc/doc-id',
                               expect_errors=True)
        self.assertEqual(400, resp.status)
        self.assertEqual('application/json', resp.header('content-type'))
        self.assertEqual('{"error": "bad request"}', resp.body)

    def test_missing_auth(self):
        """A request without credentials is answered with 401."""
        resp = self.app.delete(BASE_URL + '/pfx/foo/doc/doc-id',
                               expect_errors=True)
        self.assertEqual(401, resp.status)
        self.assertEqual('application/json', resp.header('content-type'))
        expected = {"error": "unauthorized",
                    "message": "Missing Basic Authentication."}
        self.assertEqual(expected, json.loads(resp.body))

    def test_correct_auth(self):
        """Valid credentials reach the app with the prefix stripped."""
        query = _query_string({'old_rev': 'old-rev'})
        target = BASE_URL + '/pfx/foo/doc/doc-id?%s' % (query,)
        resp = self.app.delete(
            target,
            headers=_basic_auth_headers("correct_user", "correct_password"))
        self.assertEqual(200, resp.status)
        self.assertEqual(
            [('correct_user', '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)

    def test_incorrect_auth(self):
        """A wrong password is answered with 401 and a JSON error body."""
        query = _query_string({'old_rev': 'old-rev'})
        target = BASE_URL + '/pfx/foo/doc/doc-id?%s' % (query,)
        resp = self.app.delete(
            target,
            headers=_basic_auth_headers("correct_user", "incorrect_password"),
            expect_errors=True)
        self.assertEqual(401, resp.status)
        self.assertEqual('application/json', resp.header('content-type'))
        expected = {"error": "unauthorized",
                    "message": "Incorrect password or login."}
        self.assertEqual(expected, json.loads(resp.body))


class TestOAuthMiddlewareDefaultPrefix(tests.TestCase):
    """Exercise ``OAuthMiddleware`` built without an explicit prefix."""

    def setUp(self):
        """Wrap a recording WSGI app in an OAuth middleware subclass."""
        super(TestOAuthMiddlewareDefaultPrefix, self).setUp()
        self.got = []

        def witness_app(environ, start_response):
            # Record what reached the wrapped app so tests can inspect it.
            start_response("200 OK", [("content-type", "text/plain")])
            seen = (environ['token_key'], environ['PATH_INFO'],
                    environ['QUERY_STRING'])
            self.got.append(seen)
            return ["ok"]

        class MyOAuthMiddleware(OAuthMiddleware):

            def get_oauth_data_store(self):
                return tests.testingOAuthStore

            def verify(self, environ, oauth_req):
                # Expose the verified token key to the wrapped app.
                _, token = super(MyOAuthMiddleware, self).verify(
                    environ, oauth_req)
                environ['token_key'] = token.key

        self.oauth_midw = MyOAuthMiddleware(witness_app, BASE_URL)
        self.app = paste.fixture.TestApp(self.oauth_midw)

    def test_expect_tilde(self):
        """A path that does not start with ``/~/`` is answered with 400."""
        resp = self.app.delete(BASE_URL + '/foo/doc/doc-id',
                               expect_errors=True)
        self.assertEqual(400, resp.status)
        self.assertEqual('application/json', resp.header('content-type'))
        self.assertEqual('{"error": "bad request"}', resp.body)

    def test_oauth_in_header(self):
        """OAuth data sent in the headers is accepted under ``/~/``."""
        params = {'old_rev': 'old-rev'}
        oauth_req = _delete_request(
            BASE_URL + '/~/foo/doc/doc-id', params,
            tests.consumer2, tests.token2)
        target = (oauth_req.get_normalized_http_url() + '?' +
                  _query_string(params))
        oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
                               tests.consumer2, tests.token2)
        resp = self.app.delete(target, headers=oauth_req.to_header())
        self.assertEqual(200, resp.status)
        expected = [(tests.token2.key, '/foo/doc/doc-id', 'old_rev=old-rev')]
        self.assertEqual(expected, self.got)

    def test_oauth_in_query_string(self):
        """OAuth data sent in the query string is accepted under ``/~/``."""
        oauth_req = _delete_request(
            BASE_URL + '/~/foo/doc/doc-id', {'old_rev': 'old-rev'},
            tests.consumer1, tests.token1)
        oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
                               tests.consumer1, tests.token1)
        resp = self.app.delete(oauth_req.to_url())
        self.assertEqual(200, resp.status)
        expected = [(tests.token1.key, '/foo/doc/doc-id', 'old_rev=old-rev')]
        self.assertEqual(expected, self.got)


class TestOAuthMiddleware(tests.TestCase):
    """Exercise ``OAuthMiddleware`` configured with prefix ``/pfx/``."""

    def setUp(self):
        """Wrap a recording WSGI app in an OAuth middleware subclass."""
        super(TestOAuthMiddleware, self).setUp()
        self.got = []

        def witness_app(environ, start_response):
            # Record what reached the wrapped app so tests can inspect it.
            start_response("200 OK", [("content-type", "text/plain")])
            seen = (environ['token_key'], environ['PATH_INFO'],
                    environ['QUERY_STRING'])
            self.got.append(seen)
            return ["ok"]

        class MyOAuthMiddleware(OAuthMiddleware):

            def get_oauth_data_store(self):
                return tests.testingOAuthStore

            def verify(self, environ, oauth_req):
                # Expose the verified token key to the wrapped app.
                _, token = super(MyOAuthMiddleware, self).verify(
                    environ, oauth_req)
                environ['token_key'] = token.key

        self.oauth_midw = MyOAuthMiddleware(
            witness_app, BASE_URL, prefix='/pfx/')
        self.app = paste.fixture.TestApp(self.oauth_midw)

    def test_expect_prefix(self):
        """A path without the ``/pfx/`` prefix is answered with 400."""
        resp = self.app.delete(BASE_URL + '/foo/doc/doc-id',
                               expect_errors=True)
        self.assertEqual(400, resp.status)
        self.assertEqual('application/json', resp.header('content-type'))
        self.assertEqual('{"error": "bad request"}', resp.body)

    def test_missing_oauth(self):
        """A request without OAuth data is answered with 401."""
        resp = self.app.delete(BASE_URL + '/pfx/foo/doc/doc-id',
                               expect_errors=True)
        self.assertEqual(401, resp.status)
        self.assertEqual('application/json', resp.header('content-type'))
        expected = {"error": "unauthorized", "message": "Missing OAuth."}
        self.assertEqual(expected, json.loads(resp.body))

    def test_oauth_in_query_string(self):
        """OAuth data sent in the query string is accepted."""
        oauth_req = _delete_request(
            BASE_URL + '/pfx/foo/doc/doc-id', {'old_rev': 'old-rev'},
            tests.consumer1, tests.token1)
        oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
                               tests.consumer1, tests.token1)
        resp = self.app.delete(oauth_req.to_url())
        self.assertEqual(200, resp.status)
        expected = [(tests.token1.key, '/foo/doc/doc-id', 'old_rev=old-rev')]
        self.assertEqual(expected, self.got)

    def test_oauth_invalid(self):
        """A request signed with ``tests.token3`` is answered with 401."""
        oauth_req = _delete_request(
            BASE_URL + '/pfx/foo/doc/doc-id', {'old_rev': 'old-rev'},
            tests.consumer1, tests.token3)
        oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
                               tests.consumer1, tests.token3)
        resp = self.app.delete(oauth_req.to_url(),
                               expect_errors=True)
        self.assertEqual(401, resp.status)
        self.assertEqual('application/json', resp.header('content-type'))
        err = json.loads(resp.body)
        # Only the keys and the "error" value are pinned; the message text
        # is compared against itself.
        expected = {"error": "unauthorized",
                    "message": err['message']}
        self.assertEqual(expected, err)

    def test_oauth_in_header(self):
        """OAuth data sent in the headers is accepted."""
        params = {'old_rev': 'old-rev'}
        oauth_req = _delete_request(
            BASE_URL + '/pfx/foo/doc/doc-id', params,
            tests.consumer2, tests.token2)
        target = (oauth_req.get_normalized_http_url() + '?' +
                  _query_string(params))
        oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
                               tests.consumer2, tests.token2)
        resp = self.app.delete(target, headers=oauth_req.to_header())
        self.assertEqual(200, resp.status)
        expected = [(tests.token2.key, '/foo/doc/doc-id', 'old_rev=old-rev')]
        self.assertEqual(expected, self.got)

    def test_oauth_plain_text(self):
        """A request signed with the PLAINTEXT method is accepted."""
        oauth_req = _delete_request(
            BASE_URL + '/pfx/foo/doc/doc-id', {'old_rev': 'old-rev'},
            tests.consumer1, tests.token1)
        oauth_req.sign_request(tests.sign_meth_PLAINTEXT,
                               tests.consumer1, tests.token1)
        resp = self.app.delete(oauth_req.to_url())
        self.assertEqual(200, resp.status)
        expected = [(tests.token1.key, '/foo/doc/doc-id', 'old_rev=old-rev')]
        self.assertEqual(expected, self.got)

    def test_oauth_timestamp_threshold(self):
        """A timestamp older than the threshold is rejected with 401."""
        oauth_req = _delete_request(
            BASE_URL + '/pfx/foo/doc/doc-id', {'old_rev': 'old-rev'},
            tests.consumer1, tests.token1)
        # Stamp the request five seconds in the past before signing it.
        oauth_req.set_parameter('oauth_timestamp', int(time.time()) - 5)
        oauth_req.sign_request(tests.sign_meth_PLAINTEXT,
                               tests.consumer1, tests.token1)
        # tweak threshold
        self.oauth_midw.timestamp_threshold = 1
        resp = self.app.delete(oauth_req.to_url(), expect_errors=True)
        self.assertEqual(401, resp.status)
        message = json.loads(resp.body)['message']
        self.assertIn('Expired timestamp', message)
        self.assertIn('threshold 1', message)


# Patch metadata and start of the licence header of the next file in this
# commit:
# diff --git a/src/leap/soledad/u1db/tests/test_backends.py b/src/leap/soledad/u1db/tests/test_backends.py
# new file mode 100644
# index 00000000..7a3c9e5c
# --- /dev/null
# +++ b/src/leap/soledad/u1db/tests/test_backends.py
# @@ -0,0 +1,1895 @@
# Copyright 2011 Canonical Ltd.
#
# This file is part of u1db.
+# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""The backend class for U1DB. This deals with hiding storage details.""" + +try: + import simplejson as json +except ImportError: + import json # noqa +from u1db import ( + DocumentBase, + errors, + tests, + vectorclock, + ) + +simple_doc = tests.simple_doc +nested_doc = tests.nested_doc + +from u1db.tests.test_remote_sync_target import ( + make_http_app, + make_oauth_http_app, +) + +from u1db.remote import ( + http_database, + ) + +try: + from u1db.tests import c_backend_wrapper +except ImportError: + c_backend_wrapper = None # noqa + + +def make_http_database_for_test(test, replica_uid, path='test'): + test.startServer() + test.request_state._create_database(replica_uid) + return http_database.HTTPDatabase(test.getURL(path)) + + +def copy_http_database_for_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. 
+ return test.request_state._copy_database(db) + + +def make_oauth_http_database_for_test(test, replica_uid): + http_db = make_http_database_for_test(test, replica_uid, '~/test') + http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return http_db + + +def copy_oauth_http_database_for_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. + http_db = test.request_state._copy_database(db) + http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return http_db + + +class TestAlternativeDocument(DocumentBase): + """A (not very) alternative implementation of Document.""" + + +class AllDatabaseTests(tests.DatabaseBaseTests, tests.TestCaseWithServer): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + [ + ('http', {'make_database_for_test': make_http_database_for_test, + 'copy_database_for_test': copy_http_database_for_test, + 'make_document_for_test': tests.make_document_for_test, + 'make_app_with_state': make_http_app}), + ('oauth_http', {'make_database_for_test': + make_oauth_http_database_for_test, + 'copy_database_for_test': + copy_oauth_http_database_for_test, + 'make_document_for_test': tests.make_document_for_test, + 'make_app_with_state': make_oauth_http_app}) + ] + tests.C_DATABASE_SCENARIOS + + def test_close(self): + self.db.close() + + def test_create_doc_allocating_doc_id(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertNotEqual(None, doc.doc_id) + self.assertNotEqual(None, doc.rev) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + + def test_create_doc_different_ids_same_db(self): + doc1 = 
self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertNotEqual(doc1.doc_id, doc2.doc_id) + + def test_create_doc_with_id(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my-id') + self.assertEqual('my-id', doc.doc_id) + self.assertNotEqual(None, doc.rev) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + + def test_create_doc_existing_id(self): + doc = self.db.create_doc_from_json(simple_doc) + new_content = '{"something": "else"}' + self.assertRaises( + errors.RevisionConflict, self.db.create_doc_from_json, + new_content, doc.doc_id) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + + def test_put_doc_creating_initial(self): + doc = self.make_document('my_doc_id', None, simple_doc) + new_rev = self.db.put_doc(doc) + self.assertIsNot(None, new_rev) + self.assertGetDoc(self.db, 'my_doc_id', new_rev, simple_doc, False) + + def test_put_doc_space_in_id(self): + doc = self.make_document('my doc id', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_update(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + orig_rev = doc.rev + doc.set_json('{"updated": "stuff"}') + new_rev = self.db.put_doc(doc) + self.assertNotEqual(new_rev, orig_rev) + self.assertGetDoc(self.db, 'my_doc_id', new_rev, + '{"updated": "stuff"}', False) + self.assertEqual(doc.rev, new_rev) + + def test_put_non_ascii_key(self): + content = json.dumps({u'key\xe5': u'val'}) + doc = self.db.create_doc_from_json(content, doc_id='my_doc') + self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) + + def test_put_non_ascii_value(self): + content = json.dumps({'key': u'\xe5'}) + doc = self.db.create_doc_from_json(content, doc_id='my_doc') + self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) + + def test_put_doc_refuses_no_id(self): + doc = self.make_document(None, None, simple_doc) + 
self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + doc = self.make_document("", None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_refuses_slashes(self): + doc = self.make_document('a/b', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + doc = self.make_document(r'\b', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_url_quoting_is_fine(self): + doc_id = "%2F%2Ffoo%2Fbar" + doc = self.make_document(doc_id, None, simple_doc) + new_rev = self.db.put_doc(doc) + self.assertGetDoc(self.db, doc_id, new_rev, simple_doc, False) + + def test_put_doc_refuses_non_existing_old_rev(self): + doc = self.make_document('doc-id', 'test:4', simple_doc) + self.assertRaises(errors.RevisionConflict, self.db.put_doc, doc) + + def test_put_doc_refuses_non_ascii_doc_id(self): + doc = self.make_document('d\xc3\xa5c-id', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_fails_with_bad_old_rev(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + old_rev = doc.rev + bad_doc = self.make_document(doc.doc_id, 'other:1', + '{"something": "else"}') + self.assertRaises(errors.RevisionConflict, self.db.put_doc, bad_doc) + self.assertGetDoc(self.db, 'my_doc_id', old_rev, simple_doc, False) + + def test_create_succeeds_after_delete(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) + deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) + new_doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.assertGetDoc(self.db, 'my_doc_id', new_doc.rev, simple_doc, False) + new_vc = vectorclock.VectorClockRev(new_doc.rev) + self.assertTrue( + new_vc.is_newer(deleted_vc), + "%s does not supersede %s" % (new_doc.rev, deleted_doc.rev)) + + def 
test_put_succeeds_after_delete(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) + deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) + doc2 = self.make_document('my_doc_id', None, simple_doc) + self.db.put_doc(doc2) + self.assertGetDoc(self.db, 'my_doc_id', doc2.rev, simple_doc, False) + new_vc = vectorclock.VectorClockRev(doc2.rev) + self.assertTrue( + new_vc.is_newer(deleted_vc), + "%s does not supersede %s" % (doc2.rev, deleted_doc.rev)) + + def test_get_doc_after_put(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.assertGetDoc(self.db, 'my_doc_id', doc.rev, simple_doc, False) + + def test_get_doc_nonexisting(self): + self.assertIs(None, self.db.get_doc('non-existing')) + + def test_get_doc_deleted(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + self.assertIs(None, self.db.get_doc('my_doc_id')) + + def test_get_doc_include_deleted(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + + def test_get_docs(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual([doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + + def test_get_docs_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc1) + self.assertEqual([doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + + def test_get_docs_include_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc1) + self.assertEqual( + [doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id], + include_deleted=True))) + + def 
test_get_docs_request_ordered(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual([doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + self.assertEqual([doc2, doc1], + list(self.db.get_docs([doc2.doc_id, doc1.doc_id]))) + + def test_get_docs_empty_list(self): + self.assertEqual([], list(self.db.get_docs([]))) + + def test_handles_nested_content(self): + doc = self.db.create_doc_from_json(nested_doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) + + def test_handles_doc_with_null(self): + doc = self.db.create_doc_from_json('{"key": null}') + self.assertGetDoc(self.db, doc.doc_id, doc.rev, '{"key": null}', False) + + def test_delete_doc(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + orig_rev = doc.rev + self.db.delete_doc(doc) + self.assertNotEqual(orig_rev, doc.rev) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + self.assertIs(None, self.db.get_doc(doc.doc_id)) + + def test_delete_doc_non_existent(self): + doc = self.make_document('non-existing', 'other:1', simple_doc) + self.assertRaises(errors.DocumentDoesNotExist, self.db.delete_doc, doc) + + def test_delete_doc_already_deleted(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertRaises(errors.DocumentAlreadyDeleted, + self.db.delete_doc, doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + + def test_delete_doc_bad_rev(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + doc2 = self.make_document(doc1.doc_id, 'other:1', simple_doc) + self.assertRaises(errors.RevisionConflict, self.db.delete_doc, doc2) + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + + def test_delete_doc_sets_content_to_None(self): + doc = 
self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertIs(None, doc.get_json()) + + def test_delete_doc_rev_supersedes(self): + doc = self.db.create_doc_from_json(simple_doc) + doc.set_json(nested_doc) + self.db.put_doc(doc) + doc.set_json('{"fishy": "content"}') + self.db.put_doc(doc) + old_rev = doc.rev + self.db.delete_doc(doc) + cur_vc = vectorclock.VectorClockRev(old_rev) + deleted_vc = vectorclock.VectorClockRev(doc.rev) + self.assertTrue(deleted_vc.is_newer(cur_vc), + "%s does not supersede %s" % (doc.rev, old_rev)) + + def test_delete_then_put(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + doc.set_json(nested_doc) + self.db.put_doc(doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) + + +class DocumentSizeTests(tests.DatabaseBaseTests): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def test_put_doc_refuses_oversized_documents(self): + self.db.set_document_size_limit(1) + doc = self.make_document('doc-id', None, simple_doc) + self.assertRaises(errors.DocumentTooBig, self.db.put_doc, doc) + + def test_create_doc_refuses_oversized_documents(self): + self.db.set_document_size_limit(1) + self.assertRaises( + errors.DocumentTooBig, self.db.create_doc_from_json, simple_doc, + doc_id='my_doc_id') + + def test_set_document_size_limit_zero(self): + self.db.set_document_size_limit(0) + self.assertEqual(0, self.db.document_size_limit) + + def test_set_document_size_limit(self): + self.db.set_document_size_limit(1000000) + self.assertEqual(1000000, self.db.document_size_limit) + + +class LocalDatabaseTests(tests.DatabaseBaseTests): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def test_create_doc_different_ids_diff_db(self): + doc1 = self.db.create_doc_from_json(simple_doc) + db2 = self.create_database('other-uid') + doc2 = 
db2.create_doc_from_json(simple_doc) + self.assertNotEqual(doc1.doc_id, doc2.doc_id) + + def test_put_doc_refuses_slashes_picky(self): + doc = self.make_document('/a', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_get_all_docs_empty(self): + self.assertEqual([], list(self.db.get_all_docs()[1])) + + def test_get_all_docs(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual( + sorted([doc1, doc2]), sorted(list(self.db.get_all_docs()[1]))) + + def test_get_all_docs_exclude_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc2) + self.assertEqual([doc1], list(self.db.get_all_docs()[1])) + + def test_get_all_docs_include_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc2) + self.assertEqual( + sorted([doc1, doc2]), + sorted(list(self.db.get_all_docs(include_deleted=True)[1]))) + + def test_get_all_docs_generation(self): + self.db.create_doc_from_json(simple_doc) + self.db.create_doc_from_json(nested_doc) + self.assertEqual(2, self.db.get_all_docs()[0]) + + def test_simple_put_doc_if_newer(self): + doc = self.make_document('my-doc-id', 'test:1', simple_doc) + state_at_gen = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(('inserted', 1), state_at_gen) + self.assertGetDoc(self.db, 'my-doc-id', 'test:1', simple_doc, False) + + def test_simple_put_doc_if_newer_deleted(self): + self.db.create_doc_from_json('{}', doc_id='my-doc-id') + doc = self.make_document('my-doc-id', 'test:2', None) + state_at_gen = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(('inserted', 2), state_at_gen) + self.assertGetDocIncludeDeleted( + 
self.db, 'my-doc-id', 'test:2', None, False) + + def test_put_doc_if_newer_already_superseded(self): + orig_doc = '{"new": "doc"}' + doc1 = self.db.create_doc_from_json(orig_doc) + doc1_rev1 = doc1.rev + doc1.set_json(simple_doc) + self.db.put_doc(doc1) + doc1_rev2 = doc1.rev + # Nothing is inserted, because the document is already superseded + doc = self.make_document(doc1.doc_id, doc1_rev1, orig_doc) + state, _ = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual('superseded', state) + self.assertGetDoc(self.db, doc1.doc_id, doc1_rev2, simple_doc, False) + + def test_put_doc_if_newer_autoresolve(self): + doc1 = self.db.create_doc_from_json(simple_doc) + rev = doc1.rev + doc = self.make_document(doc1.doc_id, "whatever:1", doc1.get_json()) + state, _ = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual('superseded', state) + doc2 = self.db.get_doc(doc1.doc_id) + v2 = vectorclock.VectorClockRev(doc2.rev) + self.assertTrue(v2.is_newer(vectorclock.VectorClockRev("whatever:1"))) + self.assertTrue(v2.is_newer(vectorclock.VectorClockRev(rev))) + # strictly newer locally + self.assertTrue(rev not in doc2.rev) + + def test_put_doc_if_newer_already_converged(self): + orig_doc = '{"new": "doc"}' + doc1 = self.db.create_doc_from_json(orig_doc) + state_at_gen = self.db._put_doc_if_newer( + doc1, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(('converged', 1), state_at_gen) + + def test_put_doc_if_newer_conflicted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + # Nothing is inserted, the document id is returned as would-conflict + alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + state, _ = self.db._put_doc_if_newer( + alt_doc, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual('conflicted', 
state) + # The database wasn't altered + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + + def test_put_doc_if_newer_newer_generation(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + doc = self.make_document('doc_id', 'other:2', simple_doc) + state, _ = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='other', replica_gen=2, + replica_trans_id='T-irrelevant') + self.assertEqual('inserted', state) + + def test_put_doc_if_newer_same_generation_same_txid(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + doc = self.db.create_doc_from_json(simple_doc) + self.make_document(doc.doc_id, 'other:1', simple_doc) + state, _ = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='other', replica_gen=1, + replica_trans_id='T-sid') + self.assertEqual('converged', state) + + def test_put_doc_if_newer_wrong_transaction_id(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + doc = self.make_document('doc_id', 'other:1', simple_doc) + self.assertRaises( + errors.InvalidTransactionId, + self.db._put_doc_if_newer, doc, save_conflict=False, + replica_uid='other', replica_gen=1, replica_trans_id='T-sad') + + def test_put_doc_if_newer_old_generation_older_doc(self): + orig_doc = '{"new": "doc"}' + doc = self.db.create_doc_from_json(orig_doc) + doc_rev1 = doc.rev + doc.set_json(simple_doc) + self.db.put_doc(doc) + self.db._set_replica_gen_and_trans_id('other', 3, 'T-sid') + older_doc = self.make_document(doc.doc_id, doc_rev1, simple_doc) + state, _ = self.db._put_doc_if_newer( + older_doc, save_conflict=False, replica_uid='other', replica_gen=8, + replica_trans_id='T-irrelevant') + self.assertEqual('superseded', state) + + def test_put_doc_if_newer_old_generation_newer_doc(self): + self.db._set_replica_gen_and_trans_id('other', 5, 'T-sid') + doc = self.make_document('doc_id', 'other:1', simple_doc) + self.assertRaises( + errors.InvalidGeneration, + self.db._put_doc_if_newer, 
doc, save_conflict=False, + replica_uid='other', replica_gen=1, replica_trans_id='T-sad') + + def test_put_doc_if_newer_replica_uid(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1', + nested_doc) + self.assertEqual('inserted', + self.db._put_doc_if_newer(doc2, save_conflict=False, + replica_uid='other', replica_gen=2, + replica_trans_id='T-id2')[0]) + self.assertEqual((2, 'T-id2'), self.db._get_replica_gen_and_trans_id( + 'other')) + # Compare to the old rev, should be superseded + doc2 = self.make_document(doc1.doc_id, doc1.rev, nested_doc) + self.assertEqual('superseded', + self.db._put_doc_if_newer(doc2, save_conflict=False, + replica_uid='other', replica_gen=3, + replica_trans_id='T-id3')[0]) + self.assertEqual( + (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other')) + # A conflict that isn't saved still records the sync gen, because we + # don't need to see it again + doc2 = self.make_document(doc1.doc_id, doc1.rev + '|fourth:1', + '{}') + self.assertEqual('conflicted', + self.db._put_doc_if_newer(doc2, save_conflict=False, + replica_uid='other', replica_gen=4, + replica_trans_id='T-id4')[0]) + self.assertEqual( + (4, 'T-id4'), self.db._get_replica_gen_and_trans_id('other')) + + def test__get_replica_gen_and_trans_id(self): + self.assertEqual( + (0, ''), self.db._get_replica_gen_and_trans_id('other-db')) + self.db._set_replica_gen_and_trans_id('other-db', 2, 'T-transaction') + self.assertEqual( + (2, 'T-transaction'), + self.db._get_replica_gen_and_trans_id('other-db')) + + def test_put_updates_transaction_log(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + doc.set_json('{"something": "else"}') + self.db.put_doc(doc) + self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual((2, last_trans_id, 
[(doc.doc_id, 2, last_trans_id)]), + self.db.whats_changed()) + + def test_delete_updates_transaction_log(self): + doc = self.db.create_doc_from_json(simple_doc) + db_gen, _, _ = self.db.whats_changed() + self.db.delete_doc(doc) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]), + self.db.whats_changed(db_gen)) + + def test_whats_changed_initial_database(self): + self.assertEqual((0, '', []), self.db.whats_changed()) + + def test_whats_changed_returns_one_id_for_multiple_changes(self): + doc = self.db.create_doc_from_json(simple_doc) + doc.set_json('{"new": "contents"}') + self.db.put_doc(doc) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]), + self.db.whats_changed()) + self.assertEqual((2, last_trans_id, []), self.db.whats_changed(2)) + + def test_whats_changed_returns_last_edits_ascending(self): + doc = self.db.create_doc_from_json(simple_doc) + doc1 = self.db.create_doc_from_json(simple_doc) + doc.set_json('{"new": "contents"}') + self.db.delete_doc(doc1) + delete_trans_id = self.getLastTransId(self.db) + self.db.put_doc(doc) + put_trans_id = self.getLastTransId(self.db) + self.assertEqual((4, put_trans_id, + [(doc1.doc_id, 3, delete_trans_id), + (doc.doc_id, 4, put_trans_id)]), + self.db.whats_changed()) + + def test_whats_changed_doesnt_include_old_gen(self): + self.db.create_doc_from_json(simple_doc) + self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(simple_doc) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual((3, last_trans_id, [(doc2.doc_id, 3, last_trans_id)]), + self.db.whats_changed(2)) + + +class LocalDatabaseValidateGenNTransIdTests(tests.DatabaseBaseTests): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def test_validate_gen_and_trans_id(self): + self.db.create_doc_from_json(simple_doc) + gen, trans_id = self.db._get_generation_info() 
+ self.db.validate_gen_and_trans_id(gen, trans_id) + + def test_validate_gen_and_trans_id_invalid_txid(self): + self.db.create_doc_from_json(simple_doc) + gen, _ = self.db._get_generation_info() + self.assertRaises( + errors.InvalidTransactionId, + self.db.validate_gen_and_trans_id, gen, 'wrong') + + def test_validate_gen_and_trans_id_invalid_gen(self): + self.db.create_doc_from_json(simple_doc) + gen, trans_id = self.db._get_generation_info() + self.assertRaises( + errors.InvalidGeneration, + self.db.validate_gen_and_trans_id, gen + 1, trans_id) + + +class LocalDatabaseValidateSourceGenTests(tests.DatabaseBaseTests): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def test_validate_source_gen_and_trans_id_same(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + self.db._validate_source('other', 1, 'T-sid') + + def test_validate_source_gen_newer(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + self.db._validate_source('other', 2, 'T-whatevs') + + def test_validate_source_wrong_txid(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + self.assertRaises( + errors.InvalidTransactionId, + self.db._validate_source, 'other', 1, 'T-sad') + + +class LocalDatabaseWithConflictsTests(tests.DatabaseBaseTests): + # test supporting/functionality around storing conflicts + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def test_get_docs_conflicted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual([doc2], list(self.db.get_docs([doc1.doc_id]))) + + def test_get_docs_conflicts_ignored(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) 
+ self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + no_conflict_doc = self.make_document(doc1.doc_id, 'alternate:1', + nested_doc) + self.assertEqual([no_conflict_doc, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id], + check_for_conflicts=False))) + + def test_get_doc_conflicts(self): + doc = self.db.create_doc_from_json(simple_doc) + alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual([alt_doc, doc], + self.db.get_doc_conflicts(doc.doc_id)) + + def test_get_all_docs_sees_conflicts(self): + doc = self.db.create_doc_from_json(simple_doc) + alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + _, docs = self.db.get_all_docs() + self.assertTrue(list(docs)[0].has_conflicts) + + def test_get_doc_conflicts_unconflicted(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertEqual([], self.db.get_doc_conflicts(doc.doc_id)) + + def test_get_doc_conflicts_no_such_id(self): + self.assertEqual([], self.db.get_doc_conflicts('doc-id')) + + def test_resolve_doc(self): + doc = self.db.create_doc_from_json(simple_doc) + alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDocConflicts(self.db, doc.doc_id, + [('alternate:1', nested_doc), (doc.rev, simple_doc)]) + orig_rev = doc.rev + self.db.resolve_doc(doc, [alt_doc.rev, doc.rev]) + self.assertNotEqual(orig_rev, doc.rev) + self.assertFalse(doc.has_conflicts) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + self.assertGetDocConflicts(self.db, doc.doc_id, []) + + def 
test_resolve_doc_picks_biggest_vcr(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc2.rev, nested_doc), + (doc1.rev, simple_doc)]) + orig_doc1_rev = doc1.rev + self.db.resolve_doc(doc1, [doc2.rev, doc1.rev]) + self.assertFalse(doc1.has_conflicts) + self.assertNotEqual(orig_doc1_rev, doc1.rev) + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + self.assertGetDocConflicts(self.db, doc1.doc_id, []) + vcr_1 = vectorclock.VectorClockRev(orig_doc1_rev) + vcr_2 = vectorclock.VectorClockRev(doc2.rev) + vcr_new = vectorclock.VectorClockRev(doc1.rev) + self.assertTrue(vcr_new.is_newer(vcr_1)) + self.assertTrue(vcr_new.is_newer(vcr_2)) + + def test_resolve_doc_partial_not_winning(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc2.rev, nested_doc), + (doc1.rev, simple_doc)]) + content3 = '{"key": "valin3"}' + doc3 = self.make_document(doc1.doc_id, 'third:1', content3) + self.db._put_doc_if_newer( + doc3, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='bar') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc3.rev, content3), + (doc1.rev, simple_doc), + (doc2.rev, nested_doc)]) + self.db.resolve_doc(doc1, [doc2.rev, doc1.rev]) + self.assertTrue(doc1.has_conflicts) + self.assertGetDoc(self.db, doc1.doc_id, doc3.rev, content3, True) + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc3.rev, content3), + (doc1.rev, simple_doc)]) + + def test_resolve_doc_partial_winning(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = 
self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + content3 = '{"key": "valin3"}' + doc3 = self.make_document(doc1.doc_id, 'third:1', content3) + self.db._put_doc_if_newer( + doc3, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='bar') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc3.rev, content3), + (doc1.rev, simple_doc), + (doc2.rev, nested_doc)]) + self.db.resolve_doc(doc1, [doc3.rev, doc1.rev]) + self.assertTrue(doc1.has_conflicts) + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc1.rev, simple_doc), + (doc2.rev, nested_doc)]) + + def test_resolve_doc_with_delete_conflict(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc1) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc2.rev, nested_doc), + (doc1.rev, None)]) + self.db.resolve_doc(doc2, [doc1.rev, doc2.rev]) + self.assertGetDocConflicts(self.db, doc1.doc_id, []) + self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, False) + + def test_resolve_doc_with_delete_to_delete(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc1) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc2.rev, nested_doc), + (doc1.rev, None)]) + self.db.resolve_doc(doc1, [doc1.rev, doc2.rev]) + self.assertGetDocConflicts(self.db, doc1.doc_id, []) + self.assertGetDocIncludeDeleted( + self.db, doc1.doc_id, doc1.rev, None, False) + + def test_put_doc_if_newer_save_conflicted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + 
# Document is inserted as a conflict + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + state, _ = self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual('conflicted', state) + # The database was updated + self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, nested_doc, True) + + def test_force_doc_conflict_supersedes_properly(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', '{"b": 1}') + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + doc3 = self.make_document(doc1.doc_id, 'altalt:1', '{"c": 1}') + self.db._put_doc_if_newer( + doc3, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='bar') + doc22 = self.make_document(doc1.doc_id, 'alternate:2', '{"b": 2}') + self.db._put_doc_if_newer( + doc22, save_conflict=True, replica_uid='r', replica_gen=3, + replica_trans_id='zed') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [('alternate:2', doc22.get_json()), + ('altalt:1', doc3.get_json()), + (doc1.rev, simple_doc)]) + + def test_put_doc_if_newer_save_conflict_was_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc1) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertTrue(doc2.has_conflicts) + self.assertGetDoc( + self.db, doc1.doc_id, 'alternate:1', nested_doc, True) + self.assertGetDocConflicts(self.db, doc1.doc_id, + [('alternate:1', nested_doc), (doc1.rev, None)]) + + def test_put_doc_if_newer_propagates_full_resolution(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + 
replica_trans_id='foo') + resolved_vcr = vectorclock.VectorClockRev(doc1.rev) + vcr_2 = vectorclock.VectorClockRev(doc2.rev) + resolved_vcr.maximize(vcr_2) + resolved_vcr.increment('alternate') + doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(), + '{"good": 1}') + state, _ = self.db._put_doc_if_newer( + doc_resolved, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='foo2') + self.assertEqual('inserted', state) + self.assertFalse(doc_resolved.has_conflicts) + self.assertGetDocConflicts(self.db, doc1.doc_id, []) + doc3 = self.db.get_doc(doc1.doc_id) + self.assertFalse(doc3.has_conflicts) + + def test_put_doc_if_newer_propagates_partial_resolution(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'altalt:1', '{}') + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + doc3 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc3, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='foo2') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [('alternate:1', nested_doc), ('test:1', simple_doc), + ('altalt:1', '{}')]) + resolved_vcr = vectorclock.VectorClockRev(doc1.rev) + vcr_3 = vectorclock.VectorClockRev(doc3.rev) + resolved_vcr.maximize(vcr_3) + resolved_vcr.increment('alternate') + doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(), + '{"good": 1}') + state, _ = self.db._put_doc_if_newer( + doc_resolved, save_conflict=True, replica_uid='r', replica_gen=3, + replica_trans_id='foo3') + self.assertEqual('inserted', state) + self.assertTrue(doc_resolved.has_conflicts) + doc4 = self.db.get_doc(doc1.doc_id) + self.assertTrue(doc4.has_conflicts) + self.assertGetDocConflicts(self.db, doc1.doc_id, + [('alternate:2|test:1', '{"good": 1}'), ('altalt:1', '{}')]) + + def test_put_doc_if_newer_replica_uid(self): + doc1 = 
self.db.create_doc_from_json(simple_doc) + self.db._set_replica_gen_and_trans_id('other', 1, 'T-id') + doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1', + nested_doc) + self.db._put_doc_if_newer(doc2, save_conflict=True, + replica_uid='other', replica_gen=2, + replica_trans_id='T-id2') + # Conflict vs the current update + doc2 = self.make_document(doc1.doc_id, doc1.rev + '|third:3', + '{}') + self.assertEqual('conflicted', + self.db._put_doc_if_newer(doc2, save_conflict=True, + replica_uid='other', replica_gen=3, + replica_trans_id='T-id3')[0]) + self.assertEqual( + (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other')) + + def test_put_doc_if_newer_autoresolve_2(self): + # this is an ordering variant of _3, but that already works + # adding the test explicitly to catch the regression easily + doc_a1 = self.db.create_doc_from_json(simple_doc) + doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', "{}") + doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', + '{"a":"42"}') + doc_a3 = self.make_document(doc_a1.doc_id, 'test:2|other:1', "{}") + state, _ = self.db._put_doc_if_newer( + doc_a2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(state, 'inserted') + state, _ = self.db._put_doc_if_newer( + doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='foo2') + self.assertEqual(state, 'conflicted') + state, _ = self.db._put_doc_if_newer( + doc_a3, save_conflict=True, replica_uid='r', replica_gen=3, + replica_trans_id='foo3') + self.assertEqual(state, 'inserted') + self.assertFalse(self.db.get_doc(doc_a1.doc_id).has_conflicts) + + def test_put_doc_if_newer_autoresolve_3(self): + doc_a1 = self.db.create_doc_from_json(simple_doc) + doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', "{}") + doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}') + doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', "{}") + state, _ = self.db._put_doc_if_newer( + 
doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(state, 'inserted') + state, _ = self.db._put_doc_if_newer( + doc_a2, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='foo2') + self.assertEqual(state, 'conflicted') + state, _ = self.db._put_doc_if_newer( + doc_a3, save_conflict=True, replica_uid='r', replica_gen=3, + replica_trans_id='foo3') + self.assertEqual(state, 'superseded') + doc = self.db.get_doc(doc_a1.doc_id, True) + self.assertFalse(doc.has_conflicts) + rev = vectorclock.VectorClockRev(doc.rev) + rev_a3 = vectorclock.VectorClockRev('test:3') + rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1') + self.assertTrue(rev.is_newer(rev_a3)) + self.assertTrue('test:4' in doc.rev) # locally increased + self.assertTrue(rev.is_newer(rev_a1b1)) + + def test_put_doc_if_newer_autoresolve_4(self): + doc_a1 = self.db.create_doc_from_json(simple_doc) + doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', None) + doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}') + doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', None) + state, _ = self.db._put_doc_if_newer( + doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(state, 'inserted') + state, _ = self.db._put_doc_if_newer( + doc_a2, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='foo2') + self.assertEqual(state, 'conflicted') + state, _ = self.db._put_doc_if_newer( + doc_a3, save_conflict=True, replica_uid='r', replica_gen=3, + replica_trans_id='foo3') + self.assertEqual(state, 'superseded') + doc = self.db.get_doc(doc_a1.doc_id, True) + self.assertFalse(doc.has_conflicts) + rev = vectorclock.VectorClockRev(doc.rev) + rev_a3 = vectorclock.VectorClockRev('test:3') + rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1') + self.assertTrue(rev.is_newer(rev_a3)) + self.assertTrue('test:4' in doc.rev) # locally increased + 
self.assertTrue(rev.is_newer(rev_a1b1)) + + def test_put_refuses_to_update_conflicted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + content2 = '{"key": "altval"}' + doc2 = self.make_document(doc1.doc_id, 'altrev:1', content2) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, content2, True) + content3 = '{"key": "local"}' + doc2.set_json(content3) + self.assertRaises(errors.ConflictedDoc, self.db.put_doc, doc2) + + def test_delete_refuses_for_conflicted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'altrev:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, True) + self.assertRaises(errors.ConflictedDoc, self.db.delete_doc, doc2) + + +class DatabaseIndexTests(tests.DatabaseBaseTests): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def assertParseError(self, definition): + self.db.create_doc_from_json(nested_doc) + self.assertRaises( + errors.IndexDefinitionParseError, self.db.create_index, 'idx', + definition) + + def assertIndexCreatable(self, definition): + name = "idx" + self.db.create_doc_from_json(nested_doc) + self.db.create_index(name, definition) + self.assertEqual( + [(name, [definition])], self.db.list_indexes()) + + def test_create_index(self): + self.db.create_index('test-idx', 'name') + self.assertEqual([('test-idx', ['name'])], + self.db.list_indexes()) + + def test_create_index_on_non_ascii_field_name(self): + doc = self.db.create_doc_from_json(json.dumps({u'\xe5': 'value'})) + self.db.create_index('test-idx', u'\xe5') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_list_indexes_with_non_ascii_field_names(self): + self.db.create_index('test-idx', u'\xe5') + 
self.assertEqual( + [('test-idx', [u'\xe5'])], self.db.list_indexes()) + + def test_create_index_evaluates_it(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_wildcard_matches_unicode_value(self): + doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"})) + self.db.create_index('test-idx', 'key') + self.assertEqual([doc], self.db.get_from_index('test-idx', '*')) + + def test_retrieve_unicode_value_from_index(self): + doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"})) + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc], self.db.get_from_index('test-idx', u"valu\xe5")) + + def test_create_index_fails_if_name_taken(self): + self.db.create_index('test-idx', 'key') + self.assertRaises(errors.IndexNameTakenError, + self.db.create_index, + 'test-idx', 'stuff') + + def test_create_index_does_not_fail_if_name_taken_with_same_index(self): + self.db.create_index('test-idx', 'key') + self.db.create_index('test-idx', 'key') + self.assertEqual([('test-idx', ['key'])], self.db.list_indexes()) + + def test_create_index_does_not_duplicate_indexed_fields(self): + self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.db.delete_index('test-idx') + self.db.create_index('test-idx', 'key') + self.assertEqual(1, len(self.db.get_from_index('test-idx', 'value'))) + + def test_delete_index_does_not_remove_fields_from_other_indexes(self): + self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.db.create_index('test-idx2', 'key') + self.db.delete_index('test-idx') + self.assertEqual(1, len(self.db.get_from_index('test-idx2', 'value'))) + + def test_create_index_after_deleting_document(self): + doc = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc2) + 
self.db.create_index('test-idx', 'key') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_delete_index(self): + self.db.create_index('test-idx', 'key') + self.assertEqual([('test-idx', ['key'])], self.db.list_indexes()) + self.db.delete_index('test-idx') + self.assertEqual([], self.db.list_indexes()) + + def test_create_adds_to_index(self): + self.db.create_index('test-idx', 'key') + doc = self.db.create_doc_from_json(simple_doc) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_get_from_index_unmatched(self): + self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.assertEqual([], self.db.get_from_index('test-idx', 'novalue')) + + def test_create_index_multiple_exact_matches(self): + doc = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.assertEqual( + sorted([doc, doc2]), + sorted(self.db.get_from_index('test-idx', 'value'))) + + def test_get_from_index(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_get_from_index_multi(self): + content = '{"key": "value", "key2": "value2"}' + doc = self.db.create_doc_from_json(content) + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc], self.db.get_from_index('test-idx', 'value', 'value2')) + + def test_get_from_index_multi_list(self): + doc = self.db.create_doc_from_json( + '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc], self.db.get_from_index('test-idx', 'value', 'value2-1')) + self.assertEqual( + [doc], self.db.get_from_index('test-idx', 'value', 'value2-2')) + self.assertEqual( + [doc], self.db.get_from_index('test-idx', 'value', 'value2-3')) + self.assertEqual( + [('value', 
'value2-1'), ('value', 'value2-2'), + ('value', 'value2-3')], + sorted(self.db.get_index_keys('test-idx'))) + + def test_get_from_index_sees_conflicts(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key', 'key2') + alt_doc = self.make_document( + doc.doc_id, 'alternate:1', + '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}') + self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + docs = self.db.get_from_index('test-idx', 'value', 'value2-1') + self.assertTrue(docs[0].has_conflicts) + + def test_get_index_keys_multi_list_list(self): + self.db.create_doc_from_json( + '{"key": "value1-1 value1-2 value1-3", ' + '"key2": ["value2-1", "value2-2", "value2-3"]}') + self.db.create_index('test-idx', 'split_words(key)', 'key2') + self.assertEqual( + [(u'value1-1', u'value2-1'), (u'value1-1', u'value2-2'), + (u'value1-1', u'value2-3'), (u'value1-2', u'value2-1'), + (u'value1-2', u'value2-2'), (u'value1-2', u'value2-3'), + (u'value1-3', u'value2-1'), (u'value1-3', u'value2-2'), + (u'value1-3', u'value2-3')], + sorted(self.db.get_index_keys('test-idx'))) + + def test_get_from_index_multi_ordered(self): + doc1 = self.db.create_doc_from_json( + '{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value3"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + doc4 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc4, doc3, doc2, doc1], + self.db.get_from_index('test-idx', 'v*', '*')) + + def test_get_range_from_index_start_end(self): + doc1 = self.db.create_doc_from_json('{"key": "value3"}') + doc2 = self.db.create_doc_from_json('{"key": "value2"}') + self.db.create_doc_from_json('{"key": "value4"}') + self.db.create_doc_from_json('{"key": "value1"}') + 
self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc2, doc1], + self.db.get_range_from_index('test-idx', 'value2', 'value3')) + + def test_get_range_from_index_start(self): + doc1 = self.db.create_doc_from_json('{"key": "value3"}') + doc2 = self.db.create_doc_from_json('{"key": "value2"}') + doc3 = self.db.create_doc_from_json('{"key": "value4"}') + self.db.create_doc_from_json('{"key": "value1"}') + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc2, doc1, doc3], + self.db.get_range_from_index('test-idx', 'value2')) + + def test_get_range_from_index_sees_conflicts(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + alt_doc = self.make_document( + doc.doc_id, 'alternate:1', '{"key": "valuedepalue"}') + self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + docs = self.db.get_range_from_index('test-idx', 'a') + self.assertTrue(docs[0].has_conflicts) + + def test_get_range_from_index_end(self): + self.db.create_doc_from_json('{"key": "value3"}') + doc2 = self.db.create_doc_from_json('{"key": "value2"}') + self.db.create_doc_from_json('{"key": "value4"}') + doc4 = self.db.create_doc_from_json('{"key": "value1"}') + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc4, doc2], + self.db.get_range_from_index('test-idx', None, 'value2')) + + def test_get_wildcard_range_from_index_start(self): + doc1 = self.db.create_doc_from_json('{"key": "value4"}') + doc2 = self.db.create_doc_from_json('{"key": "value23"}') + doc3 = self.db.create_doc_from_json('{"key": "value2"}') + doc4 = self.db.create_doc_from_json('{"key": "value22"}') + self.db.create_doc_from_json('{"key": "value1"}') + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc3, doc4, doc2, doc1], + self.db.get_range_from_index('test-idx', 'value2*')) + + def test_get_wildcard_range_from_index_end(self): + self.db.create_doc_from_json('{"key": 
"value4"}') + doc2 = self.db.create_doc_from_json('{"key": "value23"}') + doc3 = self.db.create_doc_from_json('{"key": "value2"}') + doc4 = self.db.create_doc_from_json('{"key": "value22"}') + doc5 = self.db.create_doc_from_json('{"key": "value1"}') + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc5, doc3, doc4, doc2], + self.db.get_range_from_index('test-idx', None, 'value2*')) + + def test_get_wildcard_range_from_index_start_end(self): + self.db.create_doc_from_json('{"key": "a"}') + self.db.create_doc_from_json('{"key": "boo3"}') + doc3 = self.db.create_doc_from_json('{"key": "catalyst"}') + doc4 = self.db.create_doc_from_json('{"key": "whaever"}') + self.db.create_doc_from_json('{"key": "zerg"}') + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc3, doc4], + self.db.get_range_from_index('test-idx', 'cat*', 'zap*')) + + def test_get_range_from_index_multi_column_start_end(self): + self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value3"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc3, doc2], + self.db.get_range_from_index( + 'test-idx', ('value2', 'value2'), ('value2', 'value3'))) + + def test_get_range_from_index_multi_column_start(self): + doc1 = self.db.create_doc_from_json( + '{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value3"}') + self.db.create_doc_from_json('{"key": "value2", "key2": "value2"}') + self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc2, doc1], + self.db.get_range_from_index('test-idx', ('value2', 'value3'))) + + def test_get_range_from_index_multi_column_end(self): + 
self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value3"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + doc4 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc4, doc3, doc2], + self.db.get_range_from_index( + 'test-idx', None, ('value2', 'value3'))) + + def test_get_wildcard_range_from_index_multi_column_start(self): + doc1 = self.db.create_doc_from_json( + '{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value23"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc3, doc2, doc1], + self.db.get_range_from_index('test-idx', ('value2', 'value2*'))) + + def test_get_wildcard_range_from_index_multi_column_end(self): + self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value23"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + doc4 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc4, doc3, doc2], + self.db.get_range_from_index( + 'test-idx', None, ('value2', 'value2*'))) + + def test_get_glob_range_from_index_multi_column_start(self): + doc1 = self.db.create_doc_from_json( + '{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value23"}') + self.db.create_doc_from_json('{"key": "value1", "key2": "value2"}') + self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + 
[doc2, doc1], + self.db.get_range_from_index('test-idx', ('value2', '*'))) + + def test_get_glob_range_from_index_multi_column_end(self): + self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value23"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value2"}') + doc4 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc4, doc3, doc2], + self.db.get_range_from_index('test-idx', None, ('value2', '*'))) + + def test_get_range_from_index_illegal_wildcard_order(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_range_from_index, 'test-idx', ('*', 'v2')) + + def test_get_range_from_index_illegal_glob_after_wildcard(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_range_from_index, 'test-idx', ('*', 'v*')) + + def test_get_range_from_index_illegal_wildcard_order_end(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_range_from_index, 'test-idx', None, ('*', 'v2')) + + def test_get_range_from_index_illegal_glob_after_wildcard_end(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_range_from_index, 'test-idx', None, ('*', 'v*')) + + def test_get_from_index_fails_if_no_index(self): + self.assertRaises( + errors.IndexDoesNotExist, self.db.get_from_index, 'foo') + + def test_get_index_keys_fails_if_no_index(self): + self.assertRaises(errors.IndexDoesNotExist, + self.db.get_index_keys, + 'foo') + + def test_get_index_keys_works_if_no_docs(self): + self.db.create_index('test-idx', 'key') + self.assertEqual([], self.db.get_index_keys('test-idx')) + + def test_put_updates_index(self): + doc = 
self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + new_content = '{"key": "altval"}' + doc.set_json(new_content) + self.db.put_doc(doc) + self.assertEqual([], self.db.get_from_index('test-idx', 'value')) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'altval')) + + def test_delete_updates_index(self): + doc = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.assertEqual( + sorted([doc, doc2]), + sorted(self.db.get_from_index('test-idx', 'value'))) + self.db.delete_doc(doc) + self.assertEqual([doc2], self.db.get_from_index('test-idx', 'value')) + + def test_get_from_index_illegal_number_of_entries(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidValueForIndex, self.db.get_from_index, 'test-idx') + self.assertRaises( + errors.InvalidValueForIndex, + self.db.get_from_index, 'test-idx', 'v1') + self.assertRaises( + errors.InvalidValueForIndex, + self.db.get_from_index, 'test-idx', 'v1', 'v2', 'v3') + + def test_get_from_index_illegal_wildcard_order(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_from_index, 'test-idx', '*', 'v2') + + def test_get_from_index_illegal_glob_after_wildcard(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_from_index, 'test-idx', '*', 'v*') + + def test_get_all_from_index(self): + self.db.create_index('test-idx', 'key') + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + # This one should not be in the index + self.db.create_doc_from_json('{"no": "key"}') + diff_value_doc = '{"key": "diff value"}' + doc4 = self.db.create_doc_from_json(diff_value_doc) + # This is essentially a 'prefix' match, but we match every entry. 
+ self.assertEqual( + sorted([doc1, doc2, doc4]), + sorted(self.db.get_from_index('test-idx', '*'))) + + def test_get_all_from_index_ordered(self): + self.db.create_index('test-idx', 'key') + doc1 = self.db.create_doc_from_json('{"key": "value x"}') + doc2 = self.db.create_doc_from_json('{"key": "value b"}') + doc3 = self.db.create_doc_from_json('{"key": "value a"}') + doc4 = self.db.create_doc_from_json('{"key": "value m"}') + # This is essentially a 'prefix' match, but we match every entry. + self.assertEqual( + [doc3, doc2, doc4, doc1], self.db.get_from_index('test-idx', '*')) + + def test_put_updates_when_adding_key(self): + doc = self.db.create_doc_from_json("{}") + self.db.create_index('test-idx', 'key') + self.assertEqual([], self.db.get_from_index('test-idx', '*')) + doc.set_json(simple_doc) + self.db.put_doc(doc) + self.assertEqual([doc], self.db.get_from_index('test-idx', '*')) + + def test_get_from_index_empty_string(self): + self.db.create_index('test-idx', 'key') + doc1 = self.db.create_doc_from_json(simple_doc) + content2 = '{"key": ""}' + doc2 = self.db.create_doc_from_json(content2) + self.assertEqual([doc2], self.db.get_from_index('test-idx', '')) + # Empty string matches the wildcard. + self.assertEqual( + sorted([doc1, doc2]), + sorted(self.db.get_from_index('test-idx', '*'))) + + def test_get_from_index_not_null(self): + self.db.create_index('test-idx', 'key') + doc1 = self.db.create_doc_from_json(simple_doc) + self.db.create_doc_from_json('{"key": null}') + self.assertEqual([doc1], self.db.get_from_index('test-idx', '*')) + + def test_get_partial_from_index(self): + content1 = '{"k1": "v1", "k2": "v2"}' + content2 = '{"k1": "v1", "k2": "x2"}' + content3 = '{"k1": "v1", "k2": "y2"}' + # doc4 has a different k1 value, so it doesn't match the prefix. 
+ content4 = '{"k1": "NN", "k2": "v2"}' + doc1 = self.db.create_doc_from_json(content1) + doc2 = self.db.create_doc_from_json(content2) + doc3 = self.db.create_doc_from_json(content3) + self.db.create_doc_from_json(content4) + self.db.create_index('test-idx', 'k1', 'k2') + self.assertEqual( + sorted([doc1, doc2, doc3]), + sorted(self.db.get_from_index('test-idx', "v1", "*"))) + + def test_get_glob_match(self): + # Note: the exact glob syntax is probably subject to change + content1 = '{"k1": "v1", "k2": "v1"}' + content2 = '{"k1": "v1", "k2": "v2"}' + content3 = '{"k1": "v1", "k2": "v3"}' + # doc4 has a different k2 prefix value, so it doesn't match + content4 = '{"k1": "v1", "k2": "ZZ"}' + self.db.create_index('test-idx', 'k1', 'k2') + doc1 = self.db.create_doc_from_json(content1) + doc2 = self.db.create_doc_from_json(content2) + doc3 = self.db.create_doc_from_json(content3) + self.db.create_doc_from_json(content4) + self.assertEqual( + sorted([doc1, doc2, doc3]), + sorted(self.db.get_from_index('test-idx', "v1", "v*"))) + + def test_nested_index(self): + doc = self.db.create_doc_from_json(nested_doc) + self.db.create_index('test-idx', 'sub.doc') + self.assertEqual( + [doc], self.db.get_from_index('test-idx', 'underneath')) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual( + sorted([doc, doc2]), + sorted(self.db.get_from_index('test-idx', 'underneath'))) + + def test_nested_nonexistent(self): + self.db.create_doc_from_json(nested_doc) + # sub exists, but sub.foo does not: + self.db.create_index('test-idx', 'sub.foo') + self.assertEqual([], self.db.get_from_index('test-idx', '*')) + + def test_nested_nonexistent2(self): + self.db.create_doc_from_json(nested_doc) + self.db.create_index('test-idx', 'sub.foo.bar.baz.qux.fnord') + self.assertEqual([], self.db.get_from_index('test-idx', '*')) + + def test_nested_traverses_lists(self): + # subpath finds dicts in list + doc = self.db.create_doc_from_json( + '{"foo": [{"zap": "bar"}, {"zap": "baz"}]}') 
+ # subpath only finds dicts in list + self.db.create_doc_from_json('{"foo": ["zap", "baz"]}') + self.db.create_index('test-idx', 'foo.zap') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'bar')) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'baz')) + + def test_nested_list_traversal(self): + # subpath finds dicts in list + doc = self.db.create_doc_from_json( + '{"foo": [{"zap": [{"qux": "fnord"}, {"qux": "zombo"}]},' + '{"zap": "baz"}]}') + # subpath only finds dicts in list + self.db.create_index('test-idx', 'foo.zap.qux') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'fnord')) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'zombo')) + + def test_index_list1(self): + self.db.create_index("index", "name") + content = '{"name": ["foo", "bar"]}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "bar") + self.assertEqual([doc], rows) + + def test_index_list2(self): + self.db.create_index("index", "name") + content = '{"name": ["foo", "bar"]}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_get_from_index_case_sensitive(self): + self.db.create_index('test-idx', 'key') + doc1 = self.db.create_doc_from_json(simple_doc) + self.assertEqual([], self.db.get_from_index('test-idx', 'V*')) + self.assertEqual([doc1], self.db.get_from_index('test-idx', 'v*')) + + def test_get_from_index_illegal_glob_before_value(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_from_index, 'test-idx', 'v*', 'v2') + + def test_get_from_index_illegal_glob_after_glob(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_from_index, 'test-idx', 'v*', 'v*') + + def test_get_from_index_with_sql_wildcards(self): + self.db.create_index('test-idx', 'key') + content1 = '{"key": "va%lue"}' + content2 
= '{"key": "value"}' + content3 = '{"key": "va_lue"}' + doc1 = self.db.create_doc_from_json(content1) + self.db.create_doc_from_json(content2) + doc3 = self.db.create_doc_from_json(content3) + # The '%' in the search should be treated literally, not as a sql + # globbing character. + self.assertEqual([doc1], self.db.get_from_index('test-idx', 'va%*')) + # Same for '_' + self.assertEqual([doc3], self.db.get_from_index('test-idx', 'va_*')) + + def test_get_from_index_with_lower(self): + self.db.create_index("index", "lower(name)") + content = '{"name": "Foo"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_get_from_index_with_lower_matches_same_case(self): + self.db.create_index("index", "lower(name)") + content = '{"name": "foo"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_index_lower_doesnt_match_different_case(self): + self.db.create_index("index", "lower(name)") + content = '{"name": "Foo"}' + self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "Foo") + self.assertEqual([], rows) + + def test_index_lower_doesnt_match_other_index(self): + self.db.create_index("index", "lower(name)") + self.db.create_index("other_index", "name") + content = '{"name": "Foo"}' + self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "Foo") + self.assertEqual(0, len(rows)) + + def test_index_split_words_match_first(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": "foo bar"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_index_split_words_match_second(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": "foo bar"}' + doc = self.db.create_doc_from_json(content) + rows = 
self.db.get_from_index("index", "bar") + self.assertEqual([doc], rows) + + def test_index_split_words_match_both(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": "foo foo"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_index_split_words_double_space(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": "foo bar"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "bar") + self.assertEqual([doc], rows) + + def test_index_split_words_leading_space(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": " foo bar"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_index_split_words_trailing_space(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": "foo bar "}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "bar") + self.assertEqual([doc], rows) + + def test_get_from_index_with_number(self): + self.db.create_index("index", "number(foo, 5)") + content = '{"foo": 12}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "00012") + self.assertEqual([doc], rows) + + def test_get_from_index_with_number_bigger_than_padding(self): + self.db.create_index("index", "number(foo, 5)") + content = '{"foo": 123456}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "123456") + self.assertEqual([doc], rows) + + def test_number_mapping_ignores_non_numbers(self): + self.db.create_index("index", "number(foo, 5)") + content = '{"foo": 56}' + doc1 = self.db.create_doc_from_json(content) + content = '{"foo": "this is not a maigret painting"}' + self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "*") + 
self.assertEqual([doc1], rows) + + def test_get_from_index_with_bool(self): + self.db.create_index("index", "bool(foo)") + content = '{"foo": true}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "1") + self.assertEqual([doc], rows) + + def test_get_from_index_with_bool_false(self): + self.db.create_index("index", "bool(foo)") + content = '{"foo": false}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "0") + self.assertEqual([doc], rows) + + def test_get_from_index_with_non_bool(self): + self.db.create_index("index", "bool(foo)") + content = '{"foo": 42}' + self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "*") + self.assertEqual([], rows) + + def test_get_from_index_with_combine(self): + self.db.create_index("index", "combine(foo, bar)") + content = '{"foo": "value1", "bar": "value2"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "value1") + self.assertEqual([doc], rows) + rows = self.db.get_from_index("index", "value2") + self.assertEqual([doc], rows) + + def test_get_complex_combine(self): + self.db.create_index( + "index", "combine(number(foo, 5), lower(bar), split_words(baz))") + content = '{"foo": 12, "bar": "ALLCAPS", "baz": "qux nox"}' + doc = self.db.create_doc_from_json(content) + content = '{"foo": "not a number", "bar": "something"}' + doc2 = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "00012") + self.assertEqual([doc], rows) + rows = self.db.get_from_index("index", "allcaps") + self.assertEqual([doc], rows) + rows = self.db.get_from_index("index", "nox") + self.assertEqual([doc], rows) + rows = self.db.get_from_index("index", "something") + self.assertEqual([doc2], rows) + + def test_get_index_keys_from_index(self): + self.db.create_index('test-idx', 'key') + content1 = '{"key": "value1"}' + content2 = '{"key": "value2"}' + content3 = '{"key": "value2"}' + 
self.db.create_doc_from_json(content1) + self.db.create_doc_from_json(content2) + self.db.create_doc_from_json(content3) + self.assertEqual( + [('value1',), ('value2',)], + sorted(self.db.get_index_keys('test-idx'))) + + def test_get_index_keys_from_multicolumn_index(self): + self.db.create_index('test-idx', 'key1', 'key2') + content1 = '{"key1": "value1", "key2": "val2-1"}' + content2 = '{"key1": "value2", "key2": "val2-2"}' + content3 = '{"key1": "value2", "key2": "val2-2"}' + content4 = '{"key1": "value2", "key2": "val3"}' + self.db.create_doc_from_json(content1) + self.db.create_doc_from_json(content2) + self.db.create_doc_from_json(content3) + self.db.create_doc_from_json(content4) + self.assertEqual([ + ('value1', 'val2-1'), + ('value2', 'val2-2'), + ('value2', 'val3')], + sorted(self.db.get_index_keys('test-idx'))) + + def test_empty_expr(self): + self.assertParseError('') + + def test_nested_unknown_operation(self): + self.assertParseError('unknown_operation(field1)') + + def test_parse_missing_close_paren(self): + self.assertParseError("lower(a") + + def test_parse_trailing_close_paren(self): + self.assertParseError("lower(ab))") + + def test_parse_trailing_chars(self): + self.assertParseError("lower(ab)adsf") + + def test_parse_empty_op(self): + self.assertParseError("(ab)") + + def test_parse_top_level_commas(self): + self.assertParseError("a, b") + + def test_invalid_field_name(self): + self.assertParseError("a.") + + def test_invalid_inner_field_name(self): + self.assertParseError("lower(a.)") + + def test_gobbledigook(self): + self.assertParseError("(@#@cc @#!*DFJSXV(()jccd") + + def test_leading_space(self): + self.assertIndexCreatable(" lower(a)") + + def test_trailing_space(self): + self.assertIndexCreatable("lower(a) ") + + def test_spaces_before_open_paren(self): + self.assertIndexCreatable("lower (a)") + + def test_spaces_after_open_paren(self): + self.assertIndexCreatable("lower( a)") + + def test_spaces_before_close_paren(self): + 
self.assertIndexCreatable("lower(a )") + + def test_spaces_before_comma(self): + self.assertIndexCreatable("combine(a , b , c)") + + def test_spaces_after_comma(self): + self.assertIndexCreatable("combine(a, b, c)") + + def test_all_together_now(self): + self.assertParseError(' (a) ') + + def test_all_together_now2(self): + self.assertParseError('combine(lower(x)x,foo)') + + +class PythonBackendTests(tests.DatabaseBaseTests): + + def setUp(self): + super(PythonBackendTests, self).setUp() + self.simple_doc = json.loads(simple_doc) + + def test_create_doc_with_factory(self): + self.db.set_document_factory(TestAlternativeDocument) + doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id') + self.assertTrue(isinstance(doc, TestAlternativeDocument)) + + def test_get_doc_after_put_with_factory(self): + doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id') + self.db.set_document_factory(TestAlternativeDocument) + result = self.db.get_doc('my_doc_id') + self.assertTrue(isinstance(result, TestAlternativeDocument)) + self.assertEqual(doc.doc_id, result.doc_id) + self.assertEqual(doc.rev, result.rev) + self.assertEqual(doc.get_json(), result.get_json()) + self.assertEqual(False, result.has_conflicts) + + def test_get_doc_nonexisting_with_factory(self): + self.db.set_document_factory(TestAlternativeDocument) + self.assertIs(None, self.db.get_doc('non-existing')) + + def test_get_all_docs_with_factory(self): + self.db.set_document_factory(TestAlternativeDocument) + self.db.create_doc(self.simple_doc) + self.assertTrue(isinstance( + list(self.db.get_all_docs()[1])[0], TestAlternativeDocument)) + + def test_get_docs_conflicted_with_factory(self): + self.db.set_document_factory(TestAlternativeDocument) + doc1 = self.db.create_doc(self.simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertTrue( + isinstance( + 
list(self.db.get_docs([doc1.doc_id]))[0], + TestAlternativeDocument)) + + def test_get_from_index_with_factory(self): + self.db.set_document_factory(TestAlternativeDocument) + self.db.create_doc(self.simple_doc) + self.db.create_index('test-idx', 'key') + self.assertTrue( + isinstance( + self.db.get_from_index('test-idx', 'value')[0], + TestAlternativeDocument)) + + def test_sync_exchange_updates_indexes(self): + doc = self.db.create_doc(self.simple_doc) + self.db.create_index('test-idx', 'key') + new_content = '{"key": "altval"}' + other_rev = 'test:1|z:2' + st = self.db.get_sync_target() + + def ignore(doc_id, doc_rev, doc): + pass + + doc_other = self.make_document(doc.doc_id, other_rev, new_content) + docs_by_gen = [(doc_other, 10, 'T-sid')] + st.sync_exchange( + docs_by_gen, 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=ignore) + self.assertGetDoc(self.db, doc.doc_id, other_rev, new_content, False) + self.assertEqual( + [doc_other], self.db.get_from_index('test-idx', 'altval')) + self.assertEqual([], self.db.get_from_index('test-idx', 'value')) + + +# Use a custom loader to apply the scenarios at load time. +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_c_backend.py b/src/leap/soledad/u1db/tests/test_c_backend.py new file mode 100644 index 00000000..bdd2aec7 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_c_backend.py @@ -0,0 +1,634 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. 
+# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +try: + import simplejson as json +except ImportError: + import json # noqa +from u1db import ( + Document, + errors, + tests, + ) +from u1db.tests import c_backend_wrapper, c_backend_error +from u1db.tests.test_remote_sync_target import ( + make_http_app, + make_oauth_http_app + ) + + +class TestCDatabaseExists(tests.TestCase): + + def test_c_backend_compiled(self): + if c_backend_wrapper is None: + self.fail("Could not import the c_backend_wrapper module." + " Was it compiled properly?\n%s" % (c_backend_error,)) + + +# Rather than lots of failing tests, we have the above check to test that the +# module exists, and all these tests just get skipped +class BackendTests(tests.TestCase): + + def setUp(self): + super(BackendTests, self).setUp() + if c_backend_wrapper is None: + self.skipTest("The c_backend_wrapper could not be imported") + + +class TestCDatabase(BackendTests): + + def test_exists(self): + if c_backend_wrapper is None: + self.fail("Could not import the c_backend_wrapper module." 
+ " Was it compiled properly?") + db = c_backend_wrapper.CDatabase(':memory:') + self.assertEqual(':memory:', db._filename) + + def test__is_closed(self): + db = c_backend_wrapper.CDatabase(':memory:') + self.assertTrue(db._sql_is_open()) + db.close() + self.assertFalse(db._sql_is_open()) + + def test__run_sql(self): + db = c_backend_wrapper.CDatabase(':memory:') + self.assertTrue(db._sql_is_open()) + self.assertEqual([], db._run_sql('CREATE TABLE test (id INTEGER)')) + self.assertEqual([], db._run_sql('INSERT INTO test VALUES (1)')) + self.assertEqual([('1',)], db._run_sql('SELECT * FROM test')) + + def test__get_generation(self): + db = c_backend_wrapper.CDatabase(':memory:') + self.assertEqual(0, db._get_generation()) + db.create_doc_from_json(tests.simple_doc) + self.assertEqual(1, db._get_generation()) + + def test__get_generation_info(self): + db = c_backend_wrapper.CDatabase(':memory:') + self.assertEqual((0, ''), db._get_generation_info()) + db.create_doc_from_json(tests.simple_doc) + info = db._get_generation_info() + self.assertEqual(1, info[0]) + self.assertTrue(info[1].startswith('T-')) + + def test__set_replica_uid(self): + db = c_backend_wrapper.CDatabase(':memory:') + self.assertIsNot(None, db._replica_uid) + db._set_replica_uid('foo') + self.assertEqual([('foo',)], db._run_sql( + "SELECT value FROM u1db_config WHERE name='replica_uid'")) + + def test_default_replica_uid(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.assertIsNot(None, self.db._replica_uid) + self.assertEqual(32, len(self.db._replica_uid)) + # casting to an int from the uid *is* the check for correct behavior. 
+ int(self.db._replica_uid, 16) + + def test_get_conflicts_with_borked_data(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + # We add an entry to conflicts, but not to documents, which is an + # invalid situation + self.db._run_sql("INSERT INTO conflicts" + " VALUES ('doc-id', 'doc-rev', '{}')") + self.assertRaises(Exception, self.db.get_doc_conflicts, 'doc-id') + + def test_create_index_list(self): + # We manually poke data into the DB, so that we test just the "get_doc" + # code, rather than also testing the index management code. + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.simple_doc) + self.db.create_index_list("key-idx", ["key"]) + docs = self.db.get_from_index('key-idx', 'value') + self.assertEqual([doc], docs) + + def test_create_index_list_on_non_ascii_field_name(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(json.dumps({u'\xe5': 'value'})) + self.db.create_index_list('test-idx', [u'\xe5']) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_list_indexes_with_non_ascii_field_names(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.db.create_index_list('test-idx', [u'\xe5']) + self.assertEqual( + [('test-idx', [u'\xe5'])], self.db.list_indexes()) + + def test_create_index_evaluates_it(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.simple_doc) + self.db.create_index_list('test-idx', ['key']) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_wildcard_matches_unicode_value(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"})) + self.db.create_index_list('test-idx', ['key']) + self.assertEqual([doc], self.db.get_from_index('test-idx', '*')) + + def test_create_index_fails_if_name_taken(self): + self.db = c_backend_wrapper.CDatabase(':memory:') 
+ self.db.create_index_list('test-idx', ['key']) + self.assertRaises(errors.IndexNameTakenError, + self.db.create_index_list, + 'test-idx', ['stuff']) + + def test_create_index_does_not_fail_if_name_taken_with_same_index(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.db.create_index_list('test-idx', ['key']) + self.db.create_index_list('test-idx', ['key']) + self.assertEqual([('test-idx', ['key'])], self.db.list_indexes()) + + def test_create_index_after_deleting_document(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.simple_doc) + doc2 = self.db.create_doc_from_json(tests.simple_doc) + self.db.delete_doc(doc2) + self.db.create_index_list('test-idx', ['key']) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_get_from_index(self): + # We manually poke data into the DB, so that we test just the "get_doc" + # code, rather than also testing the index management code. + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.simple_doc) + self.db.create_index("key-idx", "key") + docs = self.db.get_from_index('key-idx', 'value') + self.assertEqual([doc], docs) + + def test_get_from_index_list(self): + # We manually poke data into the DB, so that we test just the "get_doc" + # code, rather than also testing the index management code. 
+ self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.simple_doc) + self.db.create_index("key-idx", "key") + docs = self.db.get_from_index_list('key-idx', ['value']) + self.assertEqual([doc], docs) + + def test_get_from_index_list_multi(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + content = '{"key": "value", "key2": "value2"}' + doc = self.db.create_doc_from_json(content) + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc], + self.db.get_from_index_list('test-idx', ['value', 'value2'])) + + def test_get_from_index_list_multi_ordered(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc1 = self.db.create_doc_from_json( + '{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value3"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + doc4 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc4, doc3, doc2, doc1], + self.db.get_from_index_list('test-idx', ['v*', '*'])) + + def test_get_from_index_2(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.nested_doc) + self.db.create_index("multi-idx", "key", "sub.doc") + docs = self.db.get_from_index('multi-idx', 'value', 'underneath') + self.assertEqual([doc], docs) + + def test_get_index_keys(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.db.create_doc_from_json(tests.simple_doc) + self.db.create_index("key-idx", "key") + keys = self.db.get_index_keys('key-idx') + self.assertEqual([("value",)], keys) + + def test__query_init_one_field(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.db.create_index("key-idx", "key") + query = self.db._query_init("key-idx") + self.assertEqual("key-idx", query.index_name) + self.assertEqual(1, query.num_fields) + 
self.assertEqual(["key"], query.fields) + + def test__query_init_two_fields(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.db.create_index("two-idx", "key", "key2") + query = self.db._query_init("two-idx") + self.assertEqual("two-idx", query.index_name) + self.assertEqual(2, query.num_fields) + self.assertEqual(["key", "key2"], query.fields) + + def assertFormatQueryEquals(self, expected, wildcards, fields): + val, w = c_backend_wrapper._format_query(fields) + self.assertEqual(expected, val) + self.assertEqual(wildcards, w) + + def test__format_query(self): + self.assertFormatQueryEquals( + "SELECT d0.doc_id FROM document_fields d0" + " WHERE d0.field_name = ? AND d0.value = ? ORDER BY d0.value", + [0], ["1"]) + self.assertFormatQueryEquals( + "SELECT d0.doc_id" + " FROM document_fields d0, document_fields d1" + " WHERE d0.field_name = ? AND d0.value = ?" + " AND d0.doc_id = d1.doc_id" + " AND d1.field_name = ? AND d1.value = ?" + " ORDER BY d0.value, d1.value", + [0, 0], ["1", "2"]) + self.assertFormatQueryEquals( + "SELECT d0.doc_id" + " FROM document_fields d0, document_fields d1, document_fields d2" + " WHERE d0.field_name = ? AND d0.value = ?" + " AND d0.doc_id = d1.doc_id" + " AND d1.field_name = ? AND d1.value = ?" + " AND d0.doc_id = d2.doc_id" + " AND d2.field_name = ? AND d2.value = ?" + " ORDER BY d0.value, d1.value, d2.value", + [0, 0, 0], ["1", "2", "3"]) + + def test__format_query_wildcard(self): + self.assertFormatQueryEquals( + "SELECT d0.doc_id FROM document_fields d0" + " WHERE d0.field_name = ? AND d0.value NOT NULL ORDER BY d0.value", + [1], ["*"]) + self.assertFormatQueryEquals( + "SELECT d0.doc_id" + " FROM document_fields d0, document_fields d1" + " WHERE d0.field_name = ? AND d0.value = ?" + " AND d0.doc_id = d1.doc_id" + " AND d1.field_name = ? 
AND d1.value NOT NULL" + " ORDER BY d0.value, d1.value", + [0, 1], ["1", "*"]) + + def test__format_query_glob(self): + self.assertFormatQueryEquals( + "SELECT d0.doc_id FROM document_fields d0" + " WHERE d0.field_name = ? AND d0.value GLOB ? ORDER BY d0.value", + [2], ["1*"]) + + +class TestCSyncTarget(BackendTests): + + def setUp(self): + super(TestCSyncTarget, self).setUp() + self.db = c_backend_wrapper.CDatabase(':memory:') + self.st = self.db.get_sync_target() + + def test_attached_to_db(self): + self.assertEqual( + self.db._replica_uid, self.st.get_sync_info("misc")[0]) + + def test_get_sync_exchange(self): + exc = self.st._get_sync_exchange("source-uid", 10) + self.assertIsNot(None, exc) + + def test_sync_exchange_insert_doc_from_source(self): + exc = self.st._get_sync_exchange("source-uid", 5) + doc = c_backend_wrapper.make_document('doc-id', 'replica:1', + tests.simple_doc) + self.assertEqual([], exc.get_seen_ids()) + exc.insert_doc_from_source(doc, 10, 'T-sid') + self.assertGetDoc(self.db, 'doc-id', 'replica:1', tests.simple_doc, + False) + self.assertEqual( + (10, 'T-sid'), self.db._get_replica_gen_and_trans_id('source-uid')) + self.assertEqual(['doc-id'], exc.get_seen_ids()) + + def test_sync_exchange_conflicted_doc(self): + doc = self.db.create_doc_from_json(tests.simple_doc) + exc = self.st._get_sync_exchange("source-uid", 5) + doc2 = c_backend_wrapper.make_document(doc.doc_id, 'replica:1', + tests.nested_doc) + self.assertEqual([], exc.get_seen_ids()) + # The insert should be rejected and the doc_id not considered 'seen' + exc.insert_doc_from_source(doc2, 10, 'T-sid') + self.assertGetDoc( + self.db, doc.doc_id, doc.rev, tests.simple_doc, False) + self.assertEqual([], exc.get_seen_ids()) + + def test_sync_exchange_find_doc_ids(self): + doc = self.db.create_doc_from_json(tests.simple_doc) + exc = self.st._get_sync_exchange("source-uid", 0) + self.assertEqual(0, exc.target_gen) + exc.find_doc_ids_to_return() + doc_id = exc.get_doc_ids_to_return()[0] + 
self.assertEqual( + (doc.doc_id, 1), doc_id[:-1]) + self.assertTrue(doc_id[-1].startswith('T-')) + self.assertEqual(1, exc.target_gen) + + def test_sync_exchange_find_doc_ids_not_including_recently_inserted(self): + doc1 = self.db.create_doc_from_json(tests.simple_doc) + doc2 = self.db.create_doc_from_json(tests.nested_doc) + exc = self.st._get_sync_exchange("source-uid", 0) + doc3 = c_backend_wrapper.make_document(doc1.doc_id, + doc1.rev + "|zreplica:2", tests.simple_doc) + exc.insert_doc_from_source(doc3, 10, 'T-sid') + exc.find_doc_ids_to_return() + self.assertEqual( + (doc2.doc_id, 2), exc.get_doc_ids_to_return()[0][:-1]) + self.assertEqual(3, exc.target_gen) + + def test_sync_exchange_return_docs(self): + returned = [] + + def return_doc_cb(doc, gen, trans_id): + returned.append((doc, gen, trans_id)) + + doc1 = self.db.create_doc_from_json(tests.simple_doc) + exc = self.st._get_sync_exchange("source-uid", 0) + exc.find_doc_ids_to_return() + exc.return_docs(return_doc_cb) + self.assertEqual((doc1, 1), returned[0][:-1]) + + def test_sync_exchange_doc_ids(self): + doc1 = self.db.create_doc_from_json(tests.simple_doc, doc_id='doc-1') + db2 = c_backend_wrapper.CDatabase(':memory:') + doc2 = db2.create_doc_from_json(tests.nested_doc, doc_id='doc-2') + returned = [] + + def return_doc_cb(doc, gen, trans_id): + returned.append((doc, gen, trans_id)) + + val = self.st.sync_exchange_doc_ids( + db2, [(doc2.doc_id, 1, 'T-sid')], 0, None, return_doc_cb) + last_trans_id = self.db._get_transaction_log()[-1][1] + self.assertEqual(2, self.db._get_generation()) + self.assertEqual((2, last_trans_id), val) + self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc, + False) + self.assertEqual((doc1, 1), returned[0][:-1]) + + +class TestCHTTPSyncTarget(BackendTests): + + def test_format_sync_url(self): + target = c_backend_wrapper.create_http_sync_target("http://base_url") + self.assertEqual("http://base_url/sync-from/replica-uid", + 
c_backend_wrapper._format_sync_url(target, "replica-uid")) + + def test_format_sync_url_escapes(self): + # The base_url should not get munged (we assume it is already a + # properly formed URL), but the replica-uid should get properly escaped + target = c_backend_wrapper.create_http_sync_target( + "http://host/base%2Ctest/") + self.assertEqual("http://host/base%2Ctest/sync-from/replica%2Cuid", + c_backend_wrapper._format_sync_url(target, "replica,uid")) + + def test_format_refuses_non_http(self): + db = c_backend_wrapper.CDatabase(':memory:') + target = db.get_sync_target() + self.assertRaises(RuntimeError, + c_backend_wrapper._format_sync_url, target, 'replica,uid') + + def test_oauth_credentials(self): + target = c_backend_wrapper.create_oauth_http_sync_target( + "http://host/base%2Ctest/", + 'consumer-key', 'consumer-secret', 'token-key', 'token-secret') + auth = c_backend_wrapper._get_oauth_authorization(target, + "GET", "http://host/base%2Ctest/sync-from/abcd-efg") + self.assertIsNot(None, auth) + self.assertTrue(auth.startswith('Authorization: OAuth realm="", ')) + self.assertNotIn('http://host/base', auth) + self.assertIn('oauth_nonce="', auth) + self.assertIn('oauth_timestamp="', auth) + self.assertIn('oauth_consumer_key="consumer-key"', auth) + self.assertIn('oauth_signature_method="HMAC-SHA1"', auth) + self.assertIn('oauth_version="1.0"', auth) + self.assertIn('oauth_token="token-key"', auth) + self.assertIn('oauth_signature="', auth) + + +class TestSyncCtoHTTPViaC(tests.TestCaseWithServer): + + make_app_with_state = staticmethod(make_http_app) + + def setUp(self): + super(TestSyncCtoHTTPViaC, self).setUp() + if c_backend_wrapper is None: + self.skipTest("The c_backend_wrapper could not be imported") + self.startServer() + + def test_trivial_sync(self): + mem_db = self.request_state._create_database('test.db') + mem_doc = mem_db.create_doc_from_json(tests.nested_doc) + url = self.getURL('test.db') + target = c_backend_wrapper.create_http_sync_target(url) 
+ db = c_backend_wrapper.CDatabase(':memory:') + doc = db.create_doc_from_json(tests.simple_doc) + c_backend_wrapper.sync_db_to_target(db, target) + self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) + self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), + False) + + def test_unavailable(self): + mem_db = self.request_state._create_database('test.db') + mem_db.create_doc_from_json(tests.nested_doc) + tries = [] + + def wrapper(instance, *args, **kwargs): + tries.append(None) + raise errors.Unavailable + + mem_db.whats_changed = wrapper + url = self.getURL('test.db') + target = c_backend_wrapper.create_http_sync_target(url) + db = c_backend_wrapper.CDatabase(':memory:') + db.create_doc_from_json(tests.simple_doc) + self.assertRaises( + errors.Unavailable, c_backend_wrapper.sync_db_to_target, db, + target) + self.assertEqual(5, len(tries)) + + def test_unavailable_then_available(self): + mem_db = self.request_state._create_database('test.db') + mem_doc = mem_db.create_doc_from_json(tests.nested_doc) + orig_whatschanged = mem_db.whats_changed + tries = [] + + def wrapper(instance, *args, **kwargs): + if len(tries) < 1: + tries.append(None) + raise errors.Unavailable + return orig_whatschanged(instance, *args, **kwargs) + + mem_db.whats_changed = wrapper + url = self.getURL('test.db') + target = c_backend_wrapper.create_http_sync_target(url) + db = c_backend_wrapper.CDatabase(':memory:') + doc = db.create_doc_from_json(tests.simple_doc) + c_backend_wrapper.sync_db_to_target(db, target) + self.assertEqual(1, len(tries)) + self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) + self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), + False) + + def test_db_sync(self): + mem_db = self.request_state._create_database('test.db') + mem_doc = mem_db.create_doc_from_json(tests.nested_doc) + url = self.getURL('test.db') + db = c_backend_wrapper.CDatabase(':memory:') + doc = 
db.create_doc_from_json(tests.simple_doc) + local_gen_before_sync = db.sync(url) + gen, _, changes = db.whats_changed(local_gen_before_sync) + self.assertEqual(1, len(changes)) + self.assertEqual(mem_doc.doc_id, changes[0][0]) + self.assertEqual(1, gen - local_gen_before_sync) + self.assertEqual(1, local_gen_before_sync) + self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) + self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), + False) + + +class TestSyncCtoOAuthHTTPViaC(tests.TestCaseWithServer): + + make_app_with_state = staticmethod(make_oauth_http_app) + + def setUp(self): + super(TestSyncCtoOAuthHTTPViaC, self).setUp() + if c_backend_wrapper is None: + self.skipTest("The c_backend_wrapper could not be imported") + self.startServer() + + def test_trivial_sync(self): + mem_db = self.request_state._create_database('test.db') + mem_doc = mem_db.create_doc_from_json(tests.nested_doc) + url = self.getURL('~/test.db') + target = c_backend_wrapper.create_oauth_http_sync_target(url, + tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + db = c_backend_wrapper.CDatabase(':memory:') + doc = db.create_doc_from_json(tests.simple_doc) + c_backend_wrapper.sync_db_to_target(db, target) + self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) + self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), + False) + + +class TestVectorClock(BackendTests): + + def create_vcr(self, rev): + return c_backend_wrapper.VectorClockRev(rev) + + def test_parse_empty(self): + self.assertEqual('VectorClockRev()', + repr(self.create_vcr(''))) + + def test_parse_invalid(self): + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('x'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('x:a'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:a'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('x:a|y:1'))) + 
self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:2a'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1||'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:2|'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:2|:'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:2|m:'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:|m:3'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|:|m:3'))) + + def test_parse_single(self): + self.assertEqual('VectorClockRev(test:1)', + repr(self.create_vcr('test:1'))) + + def test_parse_multi(self): + self.assertEqual('VectorClockRev(test:1|z:2)', + repr(self.create_vcr('test:1|z:2'))) + self.assertEqual('VectorClockRev(ab:1|bc:2|cd:3|de:4|ef:5)', + repr(self.create_vcr('ab:1|bc:2|cd:3|de:4|ef:5'))) + self.assertEqual('VectorClockRev(a:2|b:1)', + repr(self.create_vcr('b:1|a:2'))) + + +class TestCDocument(BackendTests): + + def make_document(self, *args, **kwargs): + return c_backend_wrapper.make_document(*args, **kwargs) + + def test_create(self): + self.make_document('doc-id', 'uid:1', tests.simple_doc) + + def assertPyDocEqualCDoc(self, *args, **kwargs): + cdoc = self.make_document(*args, **kwargs) + pydoc = Document(*args, **kwargs) + self.assertEqual(pydoc, cdoc) + self.assertEqual(cdoc, pydoc) + + def test_cmp_to_pydoc_equal(self): + self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc) + self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc, + has_conflicts=False) + self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc, + has_conflicts=True) + + def test_cmp_to_pydoc_not_equal_conflicts(self): + cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + pydoc = Document('doc-id', 'uid:1', tests.simple_doc, + has_conflicts=True) + self.assertNotEqual(cdoc, 
pydoc) + self.assertNotEqual(pydoc, cdoc) + + def test_cmp_to_pydoc_not_equal_doc_id(self): + cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + pydoc = Document('doc2-id', 'uid:1', tests.simple_doc) + self.assertNotEqual(cdoc, pydoc) + self.assertNotEqual(pydoc, cdoc) + + def test_cmp_to_pydoc_not_equal_doc_rev(self): + cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + pydoc = Document('doc-id', 'uid:2', tests.simple_doc) + self.assertNotEqual(cdoc, pydoc) + self.assertNotEqual(pydoc, cdoc) + + def test_cmp_to_pydoc_not_equal_content(self): + cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + pydoc = Document('doc-id', 'uid:1', tests.nested_doc) + self.assertNotEqual(cdoc, pydoc) + self.assertNotEqual(pydoc, cdoc) + + +class TestUUID(BackendTests): + + def test_uuid4_conformance(self): + uuids = set() + for i in range(20): + uuid = c_backend_wrapper.generate_hex_uuid() + self.assertIsInstance(uuid, str) + self.assertEqual(32, len(uuid)) + # This will raise ValueError if it isn't a valid hex string + long(uuid, 16) + # Version 4 uuids have 2 other requirements, the high 4 bits of the + # seventh byte are always '0x4', and the middle bits of byte 9 are + # always set + self.assertEqual('4', uuid[12]) + self.assertTrue(uuid[16] in '89ab') + self.assertTrue(uuid not in uuids) + uuids.add(uuid) diff --git a/src/leap/soledad/u1db/tests/test_common_backend.py b/src/leap/soledad/u1db/tests/test_common_backend.py new file mode 100644 index 00000000..8c7c7ed9 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_common_backend.py @@ -0,0 +1,33 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. 
+# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Test common backend bits.""" + +from u1db import ( + backends, + tests, + ) + + +class TestCommonBackendImpl(tests.TestCase): + + def test__allocate_doc_id(self): + db = backends.CommonBackend() + doc_id1 = db._allocate_doc_id() + self.assertTrue(doc_id1.startswith('D-')) + self.assertEqual(34, len(doc_id1)) + int(doc_id1[len('D-'):], 16) + self.assertNotEqual(doc_id1, db._allocate_doc_id()) diff --git a/src/leap/soledad/u1db/tests/test_document.py b/src/leap/soledad/u1db/tests/test_document.py new file mode 100644 index 00000000..20f254b9 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_document.py @@ -0,0 +1,148 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
+ + +from u1db import errors, tests + + +class TestDocument(tests.TestCase): + + scenarios = ([( + 'py', {'make_document_for_test': tests.make_document_for_test})] + + tests.C_DATABASE_SCENARIOS) + + def test_create_doc(self): + doc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + self.assertEqual('doc-id', doc.doc_id) + self.assertEqual('uid:1', doc.rev) + self.assertEqual(tests.simple_doc, doc.get_json()) + self.assertFalse(doc.has_conflicts) + + def test__repr__(self): + doc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + self.assertEqual( + '%s(doc-id, uid:1, \'{"key": "value"}\')' + % (doc.__class__.__name__,), + repr(doc)) + + def test__repr__conflicted(self): + doc = self.make_document('doc-id', 'uid:1', tests.simple_doc, + has_conflicts=True) + self.assertEqual( + '%s(doc-id, uid:1, conflicted, \'{"key": "value"}\')' + % (doc.__class__.__name__,), + repr(doc)) + + def test__lt__(self): + doc_a = self.make_document('a', 'b', '{}') + doc_b = self.make_document('b', 'b', '{}') + self.assertTrue(doc_a < doc_b) + self.assertTrue(doc_b > doc_a) + doc_aa = self.make_document('a', 'a', '{}') + self.assertTrue(doc_aa < doc_a) + + def test__eq__(self): + doc_a = self.make_document('a', 'b', '{}') + doc_b = self.make_document('a', 'b', '{}') + self.assertTrue(doc_a == doc_b) + doc_b = self.make_document('a', 'b', '{}', has_conflicts=True) + self.assertFalse(doc_a == doc_b) + + def test_non_json_dict(self): + self.assertRaises( + errors.InvalidJSON, self.make_document, 'id', 'uid:1', + '"not a json dictionary"') + + def test_non_json(self): + self.assertRaises( + errors.InvalidJSON, self.make_document, 'id', 'uid:1', + 'not a json dictionary') + + def test_get_size(self): + doc_a = self.make_document('a', 'b', '{"some": "content"}') + self.assertEqual( + len('a' + 'b' + '{"some": "content"}'), doc_a.get_size()) + + def test_get_size_empty_document(self): + doc_a = self.make_document('a', 'b', None) + self.assertEqual(len('a' + 'b'), 
doc_a.get_size()) + + +class TestPyDocument(tests.TestCase): + + scenarios = ([( + 'py', {'make_document_for_test': tests.make_document_for_test})]) + + def test_get_content(self): + doc = self.make_document('id', 'rev', '{"content":""}') + self.assertEqual({"content": ""}, doc.content) + doc.set_json('{"content": "new"}') + self.assertEqual({"content": "new"}, doc.content) + + def test_set_content(self): + doc = self.make_document('id', 'rev', '{"content":""}') + doc.content = {"content": "new"} + self.assertEqual('{"content": "new"}', doc.get_json()) + + def test_set_bad_content(self): + doc = self.make_document('id', 'rev', '{"content":""}') + self.assertRaises( + errors.InvalidContent, setattr, doc, 'content', + '{"content": "new"}') + + def test_is_tombstone(self): + doc_a = self.make_document('a', 'b', '{}') + self.assertFalse(doc_a.is_tombstone()) + doc_a.set_json(None) + self.assertTrue(doc_a.is_tombstone()) + + def test_make_tombstone(self): + doc_a = self.make_document('a', 'b', '{}') + self.assertFalse(doc_a.is_tombstone()) + doc_a.make_tombstone() + self.assertTrue(doc_a.is_tombstone()) + + def test_same_content_as(self): + doc_a = self.make_document('a', 'b', '{}') + doc_b = self.make_document('d', 'e', '{}') + self.assertTrue(doc_a.same_content_as(doc_b)) + doc_b = self.make_document('p', 'q', '{}', has_conflicts=True) + self.assertTrue(doc_a.same_content_as(doc_b)) + doc_b.content['key'] = 'value' + self.assertFalse(doc_a.same_content_as(doc_b)) + + def test_same_content_as_json_order(self): + doc_a = self.make_document( + 'a', 'b', '{"key1": "val1", "key2": "val2"}') + doc_b = self.make_document( + 'c', 'd', '{"key2": "val2", "key1": "val1"}') + self.assertTrue(doc_a.same_content_as(doc_b)) + + def test_set_json(self): + doc = self.make_document('id', 'rev', '{"content":""}') + doc.set_json('{"content": "new"}') + self.assertEqual('{"content": "new"}', doc.get_json()) + + def test_set_json_non_dict(self): + doc = self.make_document('id', 'rev', 
'{"content":""}') + self.assertRaises(errors.InvalidJSON, doc.set_json, '"is not a dict"') + + def test_set_json_error(self): + doc = self.make_document('id', 'rev', '{"content":""}') + self.assertRaises(errors.InvalidJSON, doc.set_json, 'is not json') + + +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_errors.py b/src/leap/soledad/u1db/tests/test_errors.py new file mode 100644 index 00000000..0e089ede --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_errors.py @@ -0,0 +1,61 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
+ +"""Tests error infrastructure.""" + +from u1db import ( + errors, + tests, + ) + + +class TestError(tests.TestCase): + + def test_error_base(self): + err = errors.U1DBError() + self.assertEqual("error", err.wire_description) + self.assertIs(None, err.message) + + err = errors.U1DBError("Message.") + self.assertEqual("error", err.wire_description) + self.assertEqual("Message.", err.message) + + def test_HTTPError(self): + err = errors.HTTPError(500) + self.assertEqual(500, err.status) + self.assertIs(None, err.wire_description) + self.assertIs(None, err.message) + + err = errors.HTTPError(500, "Crash.") + self.assertEqual(500, err.status) + self.assertIs(None, err.wire_description) + self.assertEqual("Crash.", err.message) + + def test_HTTPError_str(self): + err = errors.HTTPError(500) + self.assertEqual("HTTPError(500)", str(err)) + + err = errors.HTTPError(500, "ERROR") + self.assertEqual("HTTPError(500, 'ERROR')", str(err)) + + def test_Unvailable(self): + err = errors.Unavailable() + self.assertEqual(503, err.status) + self.assertEqual("Unavailable()", str(err)) + + err = errors.Unavailable("DOWN") + self.assertEqual("DOWN", err.message) + self.assertEqual("Unavailable('DOWN')", str(err)) diff --git a/src/leap/soledad/u1db/tests/test_http_app.py b/src/leap/soledad/u1db/tests/test_http_app.py new file mode 100644 index 00000000..13522693 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_http_app.py @@ -0,0 +1,1133 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. 
+# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Test the WSGI app.""" + +import paste.fixture +import sys +try: + import simplejson as json +except ImportError: + import json # noqa +import StringIO + +from u1db import ( + __version__ as _u1db_version, + errors, + sync, + tests, + ) + +from u1db.remote import ( + http_app, + http_errors, + ) + + +class TestFencedReader(tests.TestCase): + + def test_init(self): + reader = http_app._FencedReader(StringIO.StringIO(""), 25, 100) + self.assertEqual(25, reader.remaining) + + def test_read_chunk(self): + inp = StringIO.StringIO("abcdef") + reader = http_app._FencedReader(inp, 5, 10) + data = reader.read_chunk(2) + self.assertEqual("ab", data) + self.assertEqual(2, inp.tell()) + self.assertEqual(3, reader.remaining) + + def test_read_chunk_remaining(self): + inp = StringIO.StringIO("abcdef") + reader = http_app._FencedReader(inp, 4, 10) + data = reader.read_chunk(9999) + self.assertEqual("abcd", data) + self.assertEqual(4, inp.tell()) + self.assertEqual(0, reader.remaining) + + def test_read_chunk_nothing_left(self): + inp = StringIO.StringIO("abc") + reader = http_app._FencedReader(inp, 2, 10) + reader.read_chunk(2) + self.assertEqual(2, inp.tell()) + self.assertEqual(0, reader.remaining) + data = reader.read_chunk(2) + self.assertEqual("", data) + self.assertEqual(2, inp.tell()) + self.assertEqual(0, reader.remaining) + + def test_read_chunk_kept(self): + inp = StringIO.StringIO("abcde") + reader = http_app._FencedReader(inp, 4, 10) + reader._kept = "xyz" + data = reader.read_chunk(2) # atmost ignored + self.assertEqual("xyz", data) + self.assertEqual(0, inp.tell()) + self.assertEqual(4, reader.remaining) + self.assertIsNone(reader._kept) + + def test_getline(self): + inp = StringIO.StringIO("abc\r\nde") + reader = http_app._FencedReader(inp, 6, 10) + reader.MAXCHUNK = 6 + line = reader.getline() + self.assertEqual("abc\r\n", line) + 
self.assertEqual("d", reader._kept) + + def test_getline_exact(self): + inp = StringIO.StringIO("abcd\r\nef") + reader = http_app._FencedReader(inp, 6, 10) + reader.MAXCHUNK = 6 + line = reader.getline() + self.assertEqual("abcd\r\n", line) + self.assertIs(None, reader._kept) + + def test_getline_no_newline(self): + inp = StringIO.StringIO("abcd") + reader = http_app._FencedReader(inp, 4, 10) + reader.MAXCHUNK = 6 + line = reader.getline() + self.assertEqual("abcd", line) + + def test_getline_many_chunks(self): + inp = StringIO.StringIO("abcde\r\nf") + reader = http_app._FencedReader(inp, 8, 10) + reader.MAXCHUNK = 4 + line = reader.getline() + self.assertEqual("abcde\r\n", line) + self.assertEqual("f", reader._kept) + line = reader.getline() + self.assertEqual("f", line) + + def test_getline_empty(self): + inp = StringIO.StringIO("") + reader = http_app._FencedReader(inp, 0, 10) + reader.MAXCHUNK = 4 + line = reader.getline() + self.assertEqual("", line) + line = reader.getline() + self.assertEqual("", line) + + def test_getline_just_newline(self): + inp = StringIO.StringIO("\r\n") + reader = http_app._FencedReader(inp, 2, 10) + reader.MAXCHUNK = 4 + line = reader.getline() + self.assertEqual("\r\n", line) + line = reader.getline() + self.assertEqual("", line) + + def test_getline_too_large(self): + inp = StringIO.StringIO("x" * 50) + reader = http_app._FencedReader(inp, 50, 25) + reader.MAXCHUNK = 4 + self.assertRaises(http_app.BadRequest, reader.getline) + + def test_getline_too_large_complete(self): + inp = StringIO.StringIO("x" * 25 + "\r\n") + reader = http_app._FencedReader(inp, 50, 25) + reader.MAXCHUNK = 4 + self.assertRaises(http_app.BadRequest, reader.getline) + + +class TestHTTPMethodDecorator(tests.TestCase): + + def test_args(self): + @http_app.http_method() + def f(self, a, b): + return self, a, b + res = f("self", {"a": "x", "b": "y"}, None) + self.assertEqual(("self", "x", "y"), res) + + def test_args_missing(self): + @http_app.http_method() + def 
f(self, a, b): + return a, b + self.assertRaises(http_app.BadRequest, f, "self", {"a": "x"}, None) + + def test_args_unexpected(self): + @http_app.http_method() + def f(self, a): + return a + self.assertRaises(http_app.BadRequest, f, "self", + {"a": "x", "c": "z"}, None) + + def test_args_default(self): + @http_app.http_method() + def f(self, a, b="z"): + return a, b + res = f("self", {"a": "x"}, None) + self.assertEqual(("x", "z"), res) + + def test_args_conversion(self): + @http_app.http_method(b=int) + def f(self, a, b): + return self, a, b + res = f("self", {"a": "x", "b": "2"}, None) + self.assertEqual(("self", "x", 2), res) + + self.assertRaises(http_app.BadRequest, f, "self", + {"a": "x", "b": "foo"}, None) + + def test_args_conversion_with_default(self): + @http_app.http_method(b=str) + def f(self, a, b=None): + return self, a, b + res = f("self", {"a": "x"}, None) + self.assertEqual(("self", "x", None), res) + + def test_args_content(self): + @http_app.http_method() + def f(self, a, content): + return a, content + res = f(self, {"a": "x"}, "CONTENT") + self.assertEqual(("x", "CONTENT"), res) + + def test_args_content_as_args(self): + @http_app.http_method(b=int, content_as_args=True) + def f(self, a, b): + return self, a, b + res = f("self", {"a": "x"}, '{"b": "2"}') + self.assertEqual(("self", "x", 2), res) + + self.assertRaises(http_app.BadRequest, f, "self", {}, 'not-json') + + def test_args_content_no_query(self): + @http_app.http_method(no_query=True, + content_as_args=True) + def f(self, a='a', b='b'): + return a, b + res = f("self", {}, '{"b": "y"}') + self.assertEqual(('a', 'y'), res) + + self.assertRaises(http_app.BadRequest, f, "self", {'a': 'x'}, + '{"b": "y"}') + + +class TestResource(object): + + @http_app.http_method() + def get(self, a, b): + self.args = dict(a=a, b=b) + return 'Get' + + @http_app.http_method() + def put(self, a, content): + self.args = dict(a=a) + self.content = content + return 'Put' + + 
@http_app.http_method(content_as_args=True) + def put_args(self, a, b): + self.args = dict(a=a, b=b) + self.order = ['a'] + self.entries = [] + + @http_app.http_method() + def put_stream_entry(self, content): + self.entries.append(content) + self.order.append('s') + + def put_end(self): + self.order.append('e') + return "Put/end" + + +class parameters: + max_request_size = 200000 + max_entry_size = 100000 + + +class TestHTTPInvocationByMethodWithBody(tests.TestCase): + + def test_get(self): + resource = TestResource() + environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'GET'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + res = invoke() + self.assertEqual('Get', res) + self.assertEqual({'a': '1', 'b': '2'}, resource.args) + + def test_put_json(self): + resource = TestResource() + body = '{"body": true}' + environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO(body), + 'CONTENT_LENGTH': str(len(body)), + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + res = invoke() + self.assertEqual('Put', res) + self.assertEqual({'a': '1'}, resource.args) + self.assertEqual('{"body": true}', resource.content) + + def test_put_sync_stream(self): + resource = TestResource() + body = ( + '[\r\n' + '{"b": 2},\r\n' # args + '{"entry": "x"},\r\n' # stream entry + '{"entry": "y"}\r\n' # stream entry + ']' + ) + environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO(body), + 'CONTENT_LENGTH': str(len(body)), + 'CONTENT_TYPE': 'application/x-u1db-sync-stream'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + res = invoke() + self.assertEqual('Put/end', res) + self.assertEqual({'a': '1', 'b': 2}, resource.args) + self.assertEqual( + ['{"entry": "x"}', '{"entry": "y"}'], resource.entries) + self.assertEqual(['a', 's', 's', 'e'], resource.order) + + def 
_put_sync_stream(self, body): + resource = TestResource() + environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO(body), + 'CONTENT_LENGTH': str(len(body)), + 'CONTENT_TYPE': 'application/x-u1db-sync-stream'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + invoke() + + def test_put_sync_stream_wrong_start(self): + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "{}\r\n]") + + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "\r\n{}\r\n]") + + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "") + + def test_put_sync_stream_wrong_end(self): + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n{}") + + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n") + + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n{}\r\n]\r\n...") + + def test_put_sync_stream_missing_comma(self): + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n{}\r\n{}\r\n]") + + def test_put_sync_stream_extra_comma(self): + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n{},\r\n]") + + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n{},\r\n{},\r\n]") + + def test_bad_request_decode_failure(self): + resource = TestResource() + environ = {'QUERY_STRING': 'a=\xff', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('{}'), + 'CONTENT_LENGTH': '2', + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_unsupported_content_type(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('{}'), + 'CONTENT_LENGTH': '2', + 'CONTENT_TYPE': 'text/plain'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + 
self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_content_length_too_large(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('{}'), + 'CONTENT_LENGTH': '10000', + 'CONTENT_TYPE': 'text/plain'} + + resource.max_request_size = 5000 + resource.max_entry_size = sys.maxint # we don't get to use this + + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_no_content_length(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('a'), + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_invalid_content_length(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('abc'), + 'CONTENT_LENGTH': '1unk', + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_empty_body(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO(''), + 'CONTENT_LENGTH': '0', + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_unsupported_method_get_like(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'DELETE'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_unsupported_method_put_like(self): + resource = TestResource() + 
environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('{}'), + 'CONTENT_LENGTH': '2', + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_unsupported_method_put_like_multi_json(self): + resource = TestResource() + body = '{}\r\n{}\r\n' + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'POST', + 'wsgi.input': StringIO.StringIO(body), + 'CONTENT_LENGTH': str(len(body)), + 'CONTENT_TYPE': 'application/x-u1db-multi-json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + +class TestHTTPResponder(tests.TestCase): + + def start_response(self, status, headers): + self.status = status + self.headers = dict(headers) + self.response_body = [] + + def write(data): + self.response_body.append(data) + + return write + + def test_send_response_content_w_headers(self): + responder = http_app.HTTPResponder(self.start_response) + responder.send_response_content('foo', headers={'x-a': '1'}) + self.assertEqual('200 OK', self.status) + self.assertEqual({'content-type': 'application/json', + 'cache-control': 'no-cache', + 'x-a': '1', 'content-length': '3'}, self.headers) + self.assertEqual([], self.response_body) + self.assertEqual(['foo'], responder.content) + + def test_send_response_json(self): + responder = http_app.HTTPResponder(self.start_response) + responder.send_response_json(value='success') + self.assertEqual('200 OK', self.status) + expected_body = '{"value": "success"}\r\n' + self.assertEqual({'content-type': 'application/json', + 'content-length': str(len(expected_body)), + 'cache-control': 'no-cache'}, self.headers) + self.assertEqual([], self.response_body) + self.assertEqual([expected_body], responder.content) + + def test_send_response_json_status_fail(self): + responder = 
http_app.HTTPResponder(self.start_response) + responder.send_response_json(400) + self.assertEqual('400 Bad Request', self.status) + expected_body = '{}\r\n' + self.assertEqual({'content-type': 'application/json', + 'content-length': str(len(expected_body)), + 'cache-control': 'no-cache'}, self.headers) + self.assertEqual([], self.response_body) + self.assertEqual([expected_body], responder.content) + + def test_start_finish_response_status_fail(self): + responder = http_app.HTTPResponder(self.start_response) + responder.start_response(404, {'error': 'not found'}) + responder.finish_response() + self.assertEqual('404 Not Found', self.status) + self.assertEqual({'content-type': 'application/json', + 'cache-control': 'no-cache'}, self.headers) + self.assertEqual(['{"error": "not found"}\r\n'], self.response_body) + self.assertEqual([], responder.content) + + def test_send_stream_entry(self): + responder = http_app.HTTPResponder(self.start_response) + responder.content_type = "application/x-u1db-multi-json" + responder.start_response(200) + responder.start_stream() + responder.stream_entry({'entry': 1}) + responder.stream_entry({'entry': 2}) + responder.end_stream() + responder.finish_response() + self.assertEqual('200 OK', self.status) + self.assertEqual({'content-type': 'application/x-u1db-multi-json', + 'cache-control': 'no-cache'}, self.headers) + self.assertEqual(['[', + '\r\n', '{"entry": 1}', + ',\r\n', '{"entry": 2}', + '\r\n]\r\n'], self.response_body) + self.assertEqual([], responder.content) + + def test_send_stream_w_error(self): + responder = http_app.HTTPResponder(self.start_response) + responder.content_type = "application/x-u1db-multi-json" + responder.start_response(200) + responder.start_stream() + responder.stream_entry({'entry': 1}) + responder.send_response_json(503, error="unavailable") + self.assertEqual('200 OK', self.status) + self.assertEqual({'content-type': 'application/x-u1db-multi-json', + 'cache-control': 'no-cache'}, self.headers) + 
self.assertEqual(['[', + '\r\n', '{"entry": 1}'], self.response_body) + self.assertEqual([',\r\n', '{"error": "unavailable"}\r\n'], + responder.content) + + +class TestHTTPApp(tests.TestCase): + + def setUp(self): + super(TestHTTPApp, self).setUp() + self.state = tests.ServerStateForTests() + self.http_app = http_app.HTTPApp(self.state) + self.app = paste.fixture.TestApp(self.http_app) + self.db0 = self.state._create_database('db0') + + def test_bad_request_broken(self): + resp = self.app.put('/db0/doc/doc1', params='{"x": 1}', + headers={'content-type': 'application/foo'}, + expect_errors=True) + self.assertEqual(400, resp.status) + + def test_bad_request_dispatch(self): + resp = self.app.put('/db0/foo/doc1', params='{"x": 1}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(400, resp.status) + + def test_version(self): + resp = self.app.get('/') + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({"version": _u1db_version}, json.loads(resp.body)) + + def test_create_database(self): + resp = self.app.put('/db1', params='{}', + headers={'content-type': 'application/json'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'ok': True}, json.loads(resp.body)) + + resp = self.app.put('/db1', params='{}', + headers={'content-type': 'application/json'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'ok': True}, json.loads(resp.body)) + + def test_delete_database(self): + resp = self.app.delete('/db0') + self.assertEqual(200, resp.status) + self.assertRaises(errors.DatabaseDoesNotExist, + self.state.check_database, 'db0') + + def test_get_database(self): + resp = self.app.get('/db0') + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({}, 
json.loads(resp.body)) + + def test_valid_database_names(self): + resp = self.app.get('/a-database', expect_errors=True) + self.assertEqual(404, resp.status) + + resp = self.app.get('/db1', expect_errors=True) + self.assertEqual(404, resp.status) + + resp = self.app.get('/0', expect_errors=True) + self.assertEqual(404, resp.status) + + resp = self.app.get('/0-0', expect_errors=True) + self.assertEqual(404, resp.status) + + resp = self.app.get('/org.future', expect_errors=True) + self.assertEqual(404, resp.status) + + def test_invalid_database_names(self): + resp = self.app.get('/.a', expect_errors=True) + self.assertEqual(400, resp.status) + + resp = self.app.get('/-a', expect_errors=True) + self.assertEqual(400, resp.status) + + resp = self.app.get('/_a', expect_errors=True) + self.assertEqual(400, resp.status) + + def test_put_doc_create(self): + resp = self.app.put('/db0/doc/doc1', params='{"x": 1}', + headers={'content-type': 'application/json'}) + doc = self.db0.get_doc('doc1') + self.assertEqual(201, resp.status) # created + self.assertEqual('{"x": 1}', doc.get_json()) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'rev': doc.rev}, json.loads(resp.body)) + + def test_put_doc(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev, + params='{"x": 2}', + headers={'content-type': 'application/json'}) + doc = self.db0.get_doc('doc1') + self.assertEqual(200, resp.status) + self.assertEqual('{"x": 2}', doc.get_json()) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'rev': doc.rev}, json.loads(resp.body)) + + def test_put_doc_too_large(self): + self.http_app.max_request_size = 15000 + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev, + params='{"%s": 2}' % ('z' * 16000), + headers={'content-type': 'application/json'}, + expect_errors=True) 
+ self.assertEqual(400, resp.status) + + def test_delete_doc(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + resp = self.app.delete('/db0/doc/doc1?old_rev=%s' % doc.rev) + doc = self.db0.get_doc('doc1', include_deleted=True) + self.assertEqual(None, doc.content) + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'rev': doc.rev}, json.loads(resp.body)) + + def test_get_doc(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + resp = self.app.get('/db0/doc/%s' % doc.doc_id) + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual('{"x": 1}', resp.body) + self.assertEqual(doc.rev, resp.header('x-u1db-rev')) + self.assertEqual('false', resp.header('x-u1db-has-conflicts')) + + def test_get_doc_non_existing(self): + resp = self.app.get('/db0/doc/not-there', expect_errors=True) + self.assertEqual(404, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": "document does not exist"}, json.loads(resp.body)) + self.assertEqual('', resp.header('x-u1db-rev')) + self.assertEqual('false', resp.header('x-u1db-has-conflicts')) + + def test_get_doc_deleted(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + self.db0.delete_doc(doc) + resp = self.app.get('/db0/doc/doc1', expect_errors=True) + self.assertEqual(404, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": errors.DocumentDoesNotExist.wire_description}, + json.loads(resp.body)) + + def test_get_doc_deleted_explicit_exclude(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + self.db0.delete_doc(doc) + resp = self.app.get( + '/db0/doc/doc1?include_deleted=false', expect_errors=True) + self.assertEqual(404, resp.status) + self.assertEqual('application/json', 
resp.header('content-type')) + self.assertEqual( + {"error": errors.DocumentDoesNotExist.wire_description}, + json.loads(resp.body)) + + def test_get_deleted_doc(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + self.db0.delete_doc(doc) + resp = self.app.get( + '/db0/doc/doc1?include_deleted=true', expect_errors=True) + self.assertEqual(404, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": errors.DOCUMENT_DELETED}, json.loads(resp.body)) + self.assertEqual(doc.rev, resp.header('x-u1db-rev')) + self.assertEqual('false', resp.header('x-u1db-has-conflicts')) + + def test_get_doc_non_existing_dabase(self): + resp = self.app.get('/not-there/doc/doc1', expect_errors=True) + self.assertEqual(404, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": "database does not exist"}, json.loads(resp.body)) + + def test_get_docs(self): + doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') + ids = ','.join([doc1.doc_id, doc2.doc_id]) + resp = self.app.get('/db0/docs?doc_ids=%s' % ids) + self.assertEqual(200, resp.status) + self.assertEqual( + 'application/json', resp.header('content-type')) + expected = [ + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1", + "has_conflicts": False}, + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2", + "has_conflicts": False}] + self.assertEqual(expected, json.loads(resp.body)) + + def test_get_docs_missing_doc_ids(self): + resp = self.app.get('/db0/docs', expect_errors=True) + self.assertEqual(400, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": "missing document ids"}, json.loads(resp.body)) + + def test_get_docs_empty_doc_ids(self): + resp = self.app.get('/db0/docs?doc_ids=', expect_errors=True) + self.assertEqual(400, resp.status) + 
self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": "missing document ids"}, json.loads(resp.body)) + + def test_get_docs_percent(self): + doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc%1') + doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') + ids = ','.join([doc1.doc_id, doc2.doc_id]) + resp = self.app.get('/db0/docs?doc_ids=%s' % ids) + self.assertEqual(200, resp.status) + self.assertEqual( + 'application/json', resp.header('content-type')) + expected = [ + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc%1", + "has_conflicts": False}, + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2", + "has_conflicts": False}] + self.assertEqual(expected, json.loads(resp.body)) + + def test_get_docs_deleted(self): + doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') + self.db0.delete_doc(doc2) + ids = ','.join([doc1.doc_id, doc2.doc_id]) + resp = self.app.get('/db0/docs?doc_ids=%s' % ids) + self.assertEqual(200, resp.status) + self.assertEqual( + 'application/json', resp.header('content-type')) + expected = [ + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1", + "has_conflicts": False}] + self.assertEqual(expected, json.loads(resp.body)) + + def test_get_docs_include_deleted(self): + doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') + self.db0.delete_doc(doc2) + ids = ','.join([doc1.doc_id, doc2.doc_id]) + resp = self.app.get('/db0/docs?doc_ids=%s&include_deleted=true' % ids) + self.assertEqual(200, resp.status) + self.assertEqual( + 'application/json', resp.header('content-type')) + expected = [ + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1", + "has_conflicts": False}, + {"content": None, "doc_rev": "db0:2", "doc_id": "doc2", + "has_conflicts": False}] + self.assertEqual(expected, 
json.loads(resp.body)) + + def test_get_sync_info(self): + self.db0._set_replica_gen_and_trans_id('other-id', 1, 'T-transid') + resp = self.app.get('/db0/sync-from/other-id') + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual(dict(target_replica_uid='db0', + target_replica_generation=0, + target_replica_transaction_id='', + source_replica_uid='other-id', + source_replica_generation=1, + source_transaction_id='T-transid'), + json.loads(resp.body)) + + def test_record_sync_info(self): + resp = self.app.put('/db0/sync-from/other-id', + params='{"generation": 2, "transaction_id": "T-transid"}', + headers={'content-type': 'application/json'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'ok': True}, json.loads(resp.body)) + self.assertEqual( + (2, 'T-transid'), + self.db0._get_replica_gen_and_trans_id('other-id')) + + def test_sync_exchange_send(self): + entries = { + 10: {'id': 'doc-here', 'rev': 'replica:1', 'content': + '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'}, + 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content': + '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'} + } + + gens = [] + _do_set_replica_gen_and_trans_id = \ + self.db0._do_set_replica_gen_and_trans_id + + def set_sync_generation_witness(other_uid, other_gen, other_trans_id): + gens.append((other_uid, other_gen)) + _do_set_replica_gen_and_trans_id( + other_uid, other_gen, other_trans_id) + self.assertGetDoc(self.db0, entries[other_gen]['id'], + entries[other_gen]['rev'], + entries[other_gen]['content'], False) + + self.patch( + self.db0, '_do_set_replica_gen_and_trans_id', + set_sync_generation_witness) + + args = dict(last_known_generation=0) + body = ("[\r\n" + + "%s,\r\n" % json.dumps(args) + + "%s,\r\n" % json.dumps(entries[10]) + + "%s\r\n" % json.dumps(entries[11]) + + "]\r\n") + resp = self.app.post('/db0/sync-from/replica', + 
params=body, + headers={'content-type': + 'application/x-u1db-sync-stream'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/x-u1db-sync-stream', + resp.header('content-type')) + bits = resp.body.split('\r\n') + self.assertEqual('[', bits[0]) + last_trans_id = self.db0._get_transaction_log()[-1][1] + self.assertEqual({'new_generation': 2, + 'new_transaction_id': last_trans_id}, + json.loads(bits[1])) + self.assertEqual(']', bits[2]) + self.assertEqual('', bits[3]) + self.assertEqual([('replica', 10), ('replica', 11)], gens) + + def test_sync_exchange_send_ensure(self): + entries = { + 10: {'id': 'doc-here', 'rev': 'replica:1', 'content': + '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'}, + 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content': + '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'} + } + + args = dict(last_known_generation=0, ensure=True) + body = ("[\r\n" + + "%s,\r\n" % json.dumps(args) + + "%s,\r\n" % json.dumps(entries[10]) + + "%s\r\n" % json.dumps(entries[11]) + + "]\r\n") + resp = self.app.post('/dbnew/sync-from/replica', + params=body, + headers={'content-type': + 'application/x-u1db-sync-stream'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/x-u1db-sync-stream', + resp.header('content-type')) + bits = resp.body.split('\r\n') + self.assertEqual('[', bits[0]) + dbnew = self.state.open_database("dbnew") + last_trans_id = dbnew._get_transaction_log()[-1][1] + self.assertEqual({'new_generation': 2, + 'new_transaction_id': last_trans_id, + 'replica_uid': dbnew._replica_uid}, + json.loads(bits[1])) + self.assertEqual(']', bits[2]) + self.assertEqual('', bits[3]) + + def test_sync_exchange_send_entry_too_large(self): + self.patch(http_app.SyncResource, 'max_request_size', 20000) + self.patch(http_app.SyncResource, 'max_entry_size', 10000) + entries = { + 10: {'id': 'doc-here', 'rev': 'replica:1', 'content': + '{"value": "%s"}' % ('H' * 11000), 'gen': 10}, + } + args = dict(last_known_generation=0) 
+ body = ("[\r\n" + + "%s,\r\n" % json.dumps(args) + + "%s\r\n" % json.dumps(entries[10]) + + "]\r\n") + resp = self.app.post('/db0/sync-from/replica', + params=body, + headers={'content-type': + 'application/x-u1db-sync-stream'}, + expect_errors=True) + self.assertEqual(400, resp.status) + + def test_sync_exchange_receive(self): + doc = self.db0.create_doc_from_json('{"value": "there"}') + doc2 = self.db0.create_doc_from_json('{"value": "there2"}') + args = dict(last_known_generation=0) + body = "[\r\n%s\r\n]" % json.dumps(args) + resp = self.app.post('/db0/sync-from/replica', + params=body, + headers={'content-type': + 'application/x-u1db-sync-stream'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/x-u1db-sync-stream', + resp.header('content-type')) + parts = resp.body.splitlines() + self.assertEqual(5, len(parts)) + self.assertEqual('[', parts[0]) + last_trans_id = self.db0._get_transaction_log()[-1][1] + self.assertEqual({'new_generation': 2, + 'new_transaction_id': last_trans_id}, + json.loads(parts[1].rstrip(","))) + part2 = json.loads(parts[2].rstrip(",")) + self.assertTrue(part2['trans_id'].startswith('T-')) + self.assertEqual('{"value": "there"}', part2['content']) + self.assertEqual(doc.rev, part2['rev']) + self.assertEqual(doc.doc_id, part2['id']) + self.assertEqual(1, part2['gen']) + part3 = json.loads(parts[3].rstrip(",")) + self.assertTrue(part3['trans_id'].startswith('T-')) + self.assertEqual('{"value": "there2"}', part3['content']) + self.assertEqual(doc2.rev, part3['rev']) + self.assertEqual(doc2.doc_id, part3['id']) + self.assertEqual(2, part3['gen']) + self.assertEqual(']', parts[4]) + + def test_sync_exchange_error_in_stream(self): + args = dict(last_known_generation=0) + body = "[\r\n%s\r\n]" % json.dumps(args) + + def boom(self, return_doc_cb): + raise errors.Unavailable + + self.patch(sync.SyncExchange, 'return_docs', + boom) + resp = self.app.post('/db0/sync-from/replica', + params=body, + headers={'content-type': + 
'application/x-u1db-sync-stream'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/x-u1db-sync-stream', + resp.header('content-type')) + parts = resp.body.splitlines() + self.assertEqual(3, len(parts)) + self.assertEqual('[', parts[0]) + self.assertEqual({'new_generation': 0, 'new_transaction_id': ''}, + json.loads(parts[1].rstrip(","))) + self.assertEqual({'error': 'unavailable'}, json.loads(parts[2])) + + +class TestRequestHooks(tests.TestCase): + + def setUp(self): + super(TestRequestHooks, self).setUp() + self.state = tests.ServerStateForTests() + self.http_app = http_app.HTTPApp(self.state) + self.app = paste.fixture.TestApp(self.http_app) + self.db0 = self.state._create_database('db0') + + def test_begin_and_done(self): + calls = [] + + def begin(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append('begin') + + def done(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append('done') + + self.http_app.request_begin = begin + self.http_app.request_done = done + + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + self.app.get('/db0/doc/%s' % doc.doc_id) + + self.assertEqual(['begin', 'done'], calls) + + def test_bad_request(self): + calls = [] + + def begin(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append('begin') + + def bad_request(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append('bad-request') + + self.http_app.request_begin = begin + self.http_app.request_bad_request = bad_request + # shouldn't be called + self.http_app.request_done = lambda env: 1 / 0 + + resp = self.app.put('/db0/foo/doc1', params='{"x": 1}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(400, resp.status) + self.assertEqual(['begin', 'bad-request'], calls) + + +class TestHTTPErrors(tests.TestCase): + + def test_wire_description_to_status(self): + self.assertNotIn("error", http_errors.wire_description_to_status) + + +class 
TestHTTPAppErrorHandling(tests.TestCase): + + def setUp(self): + super(TestHTTPAppErrorHandling, self).setUp() + self.exc = None + self.state = tests.ServerStateForTests() + + class ErroringResource(object): + + def post(_, args, content): + raise self.exc + + def lookup_resource(environ, responder): + return ErroringResource() + + self.http_app = http_app.HTTPApp(self.state) + self.http_app._lookup_resource = lookup_resource + self.app = paste.fixture.TestApp(self.http_app) + + def test_RevisionConflict_etc(self): + self.exc = errors.RevisionConflict() + resp = self.app.post('/req', params='{}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(409, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({"error": "revision conflict"}, + json.loads(resp.body)) + + def test_Unavailable(self): + self.exc = errors.Unavailable + resp = self.app.post('/req', params='{}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(503, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({"error": "unavailable"}, + json.loads(resp.body)) + + def test_generic_u1db_errors(self): + self.exc = errors.U1DBError() + resp = self.app.post('/req', params='{}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(500, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({"error": "error"}, + json.loads(resp.body)) + + def test_generic_u1db_errors_hooks(self): + calls = [] + + def begin(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append('begin') + + def u1db_error(environ, exc): + self.assertTrue('PATH_INFO' in environ) + calls.append(('error', exc)) + + self.http_app.request_begin = begin + self.http_app.request_u1db_error = u1db_error + # shouldn't be called + self.http_app.request_done = lambda env: 1 / 0 + + self.exc = 
errors.U1DBError() + resp = self.app.post('/req', params='{}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(500, resp.status) + self.assertEqual(['begin', ('error', self.exc)], calls) + + def test_failure(self): + class Failure(Exception): + pass + self.exc = Failure() + self.assertRaises(Failure, self.app.post, '/req', params='{}', + headers={'content-type': 'application/json'}) + + def test_failure_hooks(self): + class Failure(Exception): + pass + calls = [] + + def begin(environ): + calls.append('begin') + + def failed(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append(('failed', sys.exc_info())) + + self.http_app.request_begin = begin + self.http_app.request_failed = failed + # shouldn't be called + self.http_app.request_done = lambda env: 1 / 0 + + self.exc = Failure() + self.assertRaises(Failure, self.app.post, '/req', params='{}', + headers={'content-type': 'application/json'}) + + self.assertEqual(2, len(calls)) + self.assertEqual('begin', calls[0]) + marker, (exc_type, exc, tb) = calls[1] + self.assertEqual('failed', marker) + self.assertEqual(self.exc, exc) + + +class TestPluggableSyncExchange(tests.TestCase): + + def setUp(self): + super(TestPluggableSyncExchange, self).setUp() + self.state = tests.ServerStateForTests() + self.state.ensure_database('foo') + + def test_plugging(self): + + class MySyncExchange(object): + def __init__(self, db, source_replica_uid, last_known_generation): + pass + + class MySyncResource(http_app.SyncResource): + sync_exchange_class = MySyncExchange + + sync_res = MySyncResource('foo', 'src', self.state, None) + sync_res.post_args( + {'last_known_generation': 0, 'last_known_trans_id': None}, '{}') + self.assertIsInstance(sync_res.sync_exch, MySyncExchange) diff --git a/src/leap/soledad/u1db/tests/test_http_client.py b/src/leap/soledad/u1db/tests/test_http_client.py new file mode 100644 index 00000000..115c8aaa --- /dev/null +++ 
b/src/leap/soledad/u1db/tests/test_http_client.py @@ -0,0 +1,361 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Tests for HTTPDatabase""" + +from oauth import oauth +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db import ( + errors, + tests, + ) +from u1db.remote import ( + http_client, + ) + + +class TestEncoder(tests.TestCase): + + def test_encode_string(self): + self.assertEqual("foo", http_client._encode_query_parameter("foo")) + + def test_encode_true(self): + self.assertEqual("true", http_client._encode_query_parameter(True)) + + def test_encode_false(self): + self.assertEqual("false", http_client._encode_query_parameter(False)) + + +class TestHTTPClientBase(tests.TestCaseWithServer): + + def setUp(self): + super(TestHTTPClientBase, self).setUp() + self.errors = 0 + + def app(self, environ, start_response): + if environ['PATH_INFO'].endswith('echo'): + start_response("200 OK", [('Content-Type', 'application/json')]) + ret = {} + for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'): + ret[name] = environ[name] + if environ['REQUEST_METHOD'] in ('PUT', 'POST'): + ret['CONTENT_TYPE'] = environ['CONTENT_TYPE'] + content_length = int(environ['CONTENT_LENGTH']) + ret['body'] = environ['wsgi.input'].read(content_length) + return [json.dumps(ret)] + elif environ['PATH_INFO'].endswith('error_then_accept'): + if self.errors >= 3: + start_response( + 
"200 OK", [('Content-Type', 'application/json')]) + ret = {} + for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'): + ret[name] = environ[name] + if environ['REQUEST_METHOD'] in ('PUT', 'POST'): + ret['CONTENT_TYPE'] = environ['CONTENT_TYPE'] + content_length = int(environ['CONTENT_LENGTH']) + ret['body'] = '{"oki": "doki"}' + return [json.dumps(ret)] + self.errors += 1 + content_length = int(environ['CONTENT_LENGTH']) + error = json.loads( + environ['wsgi.input'].read(content_length)) + response = error['response'] + # In debug mode, wsgiref has an assertion that the status parameter + # is a 'str' object. However error['status'] returns a unicode + # object. + status = str(error['status']) + if isinstance(response, unicode): + response = str(response) + if isinstance(response, str): + start_response(status, [('Content-Type', 'text/plain')]) + return [str(response)] + else: + start_response(status, [('Content-Type', 'application/json')]) + return [json.dumps(response)] + elif environ['PATH_INFO'].endswith('error'): + self.errors += 1 + content_length = int(environ['CONTENT_LENGTH']) + error = json.loads( + environ['wsgi.input'].read(content_length)) + response = error['response'] + # In debug mode, wsgiref has an assertion that the status parameter + # is a 'str' object. However error['status'] returns a unicode + # object. 
+ status = str(error['status']) + if isinstance(response, unicode): + response = str(response) + if isinstance(response, str): + start_response(status, [('Content-Type', 'text/plain')]) + return [str(response)] + else: + start_response(status, [('Content-Type', 'application/json')]) + return [json.dumps(response)] + elif '/oauth' in environ['PATH_INFO']: + base_url = self.getURL('').rstrip('/') + oauth_req = oauth.OAuthRequest.from_request( + http_method=environ['REQUEST_METHOD'], + http_url=base_url + environ['PATH_INFO'], + headers={'Authorization': environ['HTTP_AUTHORIZATION']}, + query_string=environ['QUERY_STRING'] + ) + oauth_server = oauth.OAuthServer(tests.testingOAuthStore) + oauth_server.add_signature_method(tests.sign_meth_HMAC_SHA1) + try: + consumer, token, params = oauth_server.verify_request( + oauth_req) + except oauth.OAuthError, e: + start_response("401 Unauthorized", + [('Content-Type', 'application/json')]) + return [json.dumps({"error": "unauthorized", + "message": e.message})] + start_response("200 OK", [('Content-Type', 'application/json')]) + return [json.dumps([environ['PATH_INFO'], token.key, params])] + + def make_app(self): + return self.app + + def getClient(self, **kwds): + self.startServer() + return http_client.HTTPClientBase(self.getURL('dbase'), **kwds) + + def test_construct(self): + self.startServer() + url = self.getURL() + cli = http_client.HTTPClientBase(url) + self.assertEqual(url, cli._url.geturl()) + self.assertIs(None, cli._conn) + + def test_parse_url(self): + cli = http_client.HTTPClientBase( + '%s://127.0.0.1:12345/' % self.url_scheme) + self.assertEqual(self.url_scheme, cli._url.scheme) + self.assertEqual('127.0.0.1', cli._url.hostname) + self.assertEqual(12345, cli._url.port) + self.assertEqual('/', cli._url.path) + + def test__ensure_connection(self): + cli = self.getClient() + self.assertIs(None, cli._conn) + cli._ensure_connection() + self.assertIsNot(None, cli._conn) + conn = cli._conn + cli._ensure_connection() 
+ self.assertIs(conn, cli._conn) + + def test_close(self): + cli = self.getClient() + cli._ensure_connection() + cli.close() + self.assertIs(None, cli._conn) + + def test__request(self): + cli = self.getClient() + res, headers = cli._request('PUT', ['echo'], {}, {}) + self.assertEqual({'CONTENT_TYPE': 'application/json', + 'PATH_INFO': '/dbase/echo', + 'QUERY_STRING': '', + 'body': '{}', + 'REQUEST_METHOD': 'PUT'}, json.loads(res)) + + res, headers = cli._request('GET', ['doc', 'echo'], {'a': 1}) + self.assertEqual({'PATH_INFO': '/dbase/doc/echo', + 'QUERY_STRING': 'a=1', + 'REQUEST_METHOD': 'GET'}, json.loads(res)) + + res, headers = cli._request('GET', ['doc', '%FFFF', 'echo'], {'a': 1}) + self.assertEqual({'PATH_INFO': '/dbase/doc/%FFFF/echo', + 'QUERY_STRING': 'a=1', + 'REQUEST_METHOD': 'GET'}, json.loads(res)) + + res, headers = cli._request('POST', ['echo'], {'b': 2}, 'Body', + 'application/x-test') + self.assertEqual({'CONTENT_TYPE': 'application/x-test', + 'PATH_INFO': '/dbase/echo', + 'QUERY_STRING': 'b=2', + 'body': 'Body', + 'REQUEST_METHOD': 'POST'}, json.loads(res)) + + def test__request_json(self): + cli = self.getClient() + res, headers = cli._request_json( + 'POST', ['echo'], {'b': 2}, {'a': 'x'}) + self.assertEqual('application/json', headers['content-type']) + self.assertEqual({'CONTENT_TYPE': 'application/json', + 'PATH_INFO': '/dbase/echo', + 'QUERY_STRING': 'b=2', + 'body': '{"a": "x"}', + 'REQUEST_METHOD': 'POST'}, res) + + def test_unspecified_http_error(self): + cli = self.getClient() + self.assertRaises(errors.HTTPError, + cli._request_json, 'POST', ['error'], {}, + {'status': "500 Internal Error", + 'response': "Crash."}) + try: + cli._request_json('POST', ['error'], {}, + {'status': "500 Internal Error", + 'response': "Fail."}) + except errors.HTTPError, e: + pass + + self.assertEqual(500, e.status) + self.assertEqual("Fail.", e.message) + self.assertTrue("content-type" in e.headers) + + def test_revision_conflict(self): + cli = 
self.getClient() + self.assertRaises(errors.RevisionConflict, + cli._request_json, 'POST', ['error'], {}, + {'status': "409 Conflict", + 'response': {"error": "revision conflict"}}) + + def test_unavailable_proper(self): + cli = self.getClient() + cli._delays = (0, 0, 0, 0, 0) + self.assertRaises(errors.Unavailable, + cli._request_json, 'POST', ['error'], {}, + {'status': "503 Service Unavailable", + 'response': {"error": "unavailable"}}) + self.assertEqual(5, self.errors) + + def test_unavailable_then_available(self): + cli = self.getClient() + cli._delays = (0, 0, 0, 0, 0) + res, headers = cli._request_json( + 'POST', ['error_then_accept'], {'b': 2}, + {'status': "503 Service Unavailable", + 'response': {"error": "unavailable"}}) + self.assertEqual('application/json', headers['content-type']) + self.assertEqual({'CONTENT_TYPE': 'application/json', + 'PATH_INFO': '/dbase/error_then_accept', + 'QUERY_STRING': 'b=2', + 'body': '{"oki": "doki"}', + 'REQUEST_METHOD': 'POST'}, res) + self.assertEqual(3, self.errors) + + def test_unavailable_random_source(self): + cli = self.getClient() + cli._delays = (0, 0, 0, 0, 0) + try: + cli._request_json('POST', ['error'], {}, + {'status': "503 Service Unavailable", + 'response': "random unavailable."}) + except errors.Unavailable, e: + pass + + self.assertEqual(503, e.status) + self.assertEqual("random unavailable.", e.message) + self.assertTrue("content-type" in e.headers) + self.assertEqual(5, self.errors) + + def test_document_too_big(self): + cli = self.getClient() + self.assertRaises(errors.DocumentTooBig, + cli._request_json, 'POST', ['error'], {}, + {'status': "403 Forbidden", + 'response': {"error": "document too big"}}) + + def test_user_quota_exceeded(self): + cli = self.getClient() + self.assertRaises(errors.UserQuotaExceeded, + cli._request_json, 'POST', ['error'], {}, + {'status': "403 Forbidden", + 'response': {"error": "user quota exceeded"}}) + + def test_user_needs_subscription(self): + cli = self.getClient() + 
self.assertRaises(errors.SubscriptionNeeded, + cli._request_json, 'POST', ['error'], {}, + {'status': "403 Forbidden", + 'response': {"error": "user needs subscription"}}) + + def test_generic_u1db_error(self): + cli = self.getClient() + self.assertRaises(errors.U1DBError, + cli._request_json, 'POST', ['error'], {}, + {'status': "400 Bad Request", + 'response': {"error": "error"}}) + try: + cli._request_json('POST', ['error'], {}, + {'status': "400 Bad Request", + 'response': {"error": "error"}}) + except errors.U1DBError, e: + pass + self.assertIs(e.__class__, errors.U1DBError) + + def test_unspecified_bad_request(self): + cli = self.getClient() + self.assertRaises(errors.HTTPError, + cli._request_json, 'POST', ['error'], {}, + {'status': "400 Bad Request", + 'response': ""}) + try: + cli._request_json('POST', ['error'], {}, + {'status': "400 Bad Request", + 'response': ""}) + except errors.HTTPError, e: + pass + + self.assertEqual(400, e.status) + self.assertEqual("", e.message) + self.assertTrue("content-type" in e.headers) + + def test_oauth(self): + cli = self.getClient() + cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + params = {'x': u'\xf0', 'y': "foo"} + res, headers = cli._request('GET', ['doc', 'oauth'], params) + self.assertEqual( + ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res)) + + # oauth does its own internal quoting + params = {'x': u'\xf0', 'y': "foo"} + res, headers = cli._request('GET', ['doc', 'oauth', 'foo bar'], params) + self.assertEqual( + ['/dbase/doc/oauth/foo bar', tests.token1.key, params], + json.loads(res)) + + def test_oauth_ctr_creds(self): + cli = self.getClient(creds={'oauth': { + 'consumer_key': tests.consumer1.key, + 'consumer_secret': tests.consumer1.secret, + 'token_key': tests.token1.key, + 'token_secret': tests.token1.secret, + }}) + params = {'x': u'\xf0', 'y': "foo"} + res, headers = cli._request('GET', ['doc', 'oauth'], params) + 
self.assertEqual( + ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res)) + + def test_unknown_creds(self): + self.assertRaises(errors.UnknownAuthMethod, + self.getClient, creds={'foo': {}}) + self.assertRaises(errors.UnknownAuthMethod, + self.getClient, creds={}) + + def test_oauth_Unauthorized(self): + cli = self.getClient() + cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, "WRONG") + params = {'y': 'foo'} + self.assertRaises(errors.Unauthorized, cli._request, 'GET', + ['doc', 'oauth'], params) diff --git a/src/leap/soledad/u1db/tests/test_http_database.py b/src/leap/soledad/u1db/tests/test_http_database.py new file mode 100644 index 00000000..c8e7eb76 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_http_database.py @@ -0,0 +1,256 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
+ +"""Tests for HTTPDatabase""" + +import inspect +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db import ( + errors, + Document, + tests, + ) +from u1db.remote import ( + http_database, + http_target, + ) +from u1db.tests.test_remote_sync_target import ( + make_http_app, +) + + +class TestHTTPDatabaseSimpleOperations(tests.TestCase): + + def setUp(self): + super(TestHTTPDatabaseSimpleOperations, self).setUp() + self.db = http_database.HTTPDatabase('dbase') + self.db._conn = object() # crash if used + self.got = None + self.response_val = None + + def _request(method, url_parts, params=None, body=None, + content_type=None): + self.got = method, url_parts, params, body, content_type + if isinstance(self.response_val, Exception): + raise self.response_val + return self.response_val + + def _request_json(method, url_parts, params=None, body=None, + content_type=None): + self.got = method, url_parts, params, body, content_type + if isinstance(self.response_val, Exception): + raise self.response_val + return self.response_val + + self.db._request = _request + self.db._request_json = _request_json + + def test__sanity_same_signature(self): + my_request_sig = inspect.getargspec(self.db._request) + my_request_sig = (['self'] + my_request_sig[0],) + my_request_sig[1:] + self.assertEqual(my_request_sig, + inspect.getargspec(http_database.HTTPDatabase._request)) + my_request_json_sig = inspect.getargspec(self.db._request_json) + my_request_json_sig = ((['self'] + my_request_json_sig[0],) + + my_request_json_sig[1:]) + self.assertEqual(my_request_json_sig, + inspect.getargspec(http_database.HTTPDatabase._request_json)) + + def test__ensure(self): + self.response_val = {'ok': True}, {} + self.db._ensure() + self.assertEqual(('PUT', [], {}, {}, None), self.got) + + def test__delete(self): + self.response_val = {'ok': True}, {} + self.db._delete() + self.assertEqual(('DELETE', [], {}, {}, None), self.got) + + def test__check(self): + 
self.response_val = {}, {} + res = self.db._check() + self.assertEqual({}, res) + self.assertEqual(('GET', [], None, None, None), self.got) + + def test_put_doc(self): + self.response_val = {'rev': 'doc-rev'}, {} + doc = Document('doc-id', None, '{"v": 1}') + res = self.db.put_doc(doc) + self.assertEqual('doc-rev', res) + self.assertEqual('doc-rev', doc.rev) + self.assertEqual(('PUT', ['doc', 'doc-id'], {}, + '{"v": 1}', 'application/json'), self.got) + + self.response_val = {'rev': 'doc-rev-2'}, {} + doc.content = {"v": 2} + res = self.db.put_doc(doc) + self.assertEqual('doc-rev-2', res) + self.assertEqual('doc-rev-2', doc.rev) + self.assertEqual(('PUT', ['doc', 'doc-id'], {'old_rev': 'doc-rev'}, + '{"v": 2}', 'application/json'), self.got) + + def test_get_doc(self): + self.response_val = '{"v": 2}', {'x-u1db-rev': 'doc-rev', + 'x-u1db-has-conflicts': 'false'} + self.assertGetDoc(self.db, 'doc-id', 'doc-rev', '{"v": 2}', False) + self.assertEqual( + ('GET', ['doc', 'doc-id'], {'include_deleted': False}, None, None), + self.got) + + def test_get_doc_non_existing(self): + self.response_val = errors.DocumentDoesNotExist() + self.assertIs(None, self.db.get_doc('not-there')) + self.assertEqual( + ('GET', ['doc', 'not-there'], {'include_deleted': False}, None, + None), self.got) + + def test_get_doc_deleted(self): + self.response_val = errors.DocumentDoesNotExist() + self.assertIs(None, self.db.get_doc('deleted')) + self.assertEqual( + ('GET', ['doc', 'deleted'], {'include_deleted': False}, None, + None), self.got) + + def test_get_doc_deleted_include_deleted(self): + self.response_val = errors.HTTPError(404, + json.dumps( + {"error": errors.DOCUMENT_DELETED} + ), + {'x-u1db-rev': 'doc-rev-gone', + 'x-u1db-has-conflicts': 'false'}) + doc = self.db.get_doc('deleted', include_deleted=True) + self.assertEqual('deleted', doc.doc_id) + self.assertEqual('doc-rev-gone', doc.rev) + self.assertIs(None, doc.content) + self.assertEqual( + ('GET', ['doc', 'deleted'], 
{'include_deleted': True}, None, None), + self.got) + + def test_get_doc_pass_through_errors(self): + self.response_val = errors.HTTPError(500, 'Crash.') + self.assertRaises(errors.HTTPError, + self.db.get_doc, 'something-something') + + def test_create_doc_with_id(self): + self.response_val = {'rev': 'doc-rev'}, {} + new_doc = self.db.create_doc_from_json('{"v": 1}', doc_id='doc-id') + self.assertEqual('doc-rev', new_doc.rev) + self.assertEqual('doc-id', new_doc.doc_id) + self.assertEqual('{"v": 1}', new_doc.get_json()) + self.assertEqual(('PUT', ['doc', 'doc-id'], {}, + '{"v": 1}', 'application/json'), self.got) + + def test_create_doc_without_id(self): + self.response_val = {'rev': 'doc-rev-2'}, {} + new_doc = self.db.create_doc_from_json('{"v": 3}') + self.assertEqual('D-', new_doc.doc_id[:2]) + self.assertEqual('doc-rev-2', new_doc.rev) + self.assertEqual('{"v": 3}', new_doc.get_json()) + self.assertEqual(('PUT', ['doc', new_doc.doc_id], {}, + '{"v": 3}', 'application/json'), self.got) + + def test_delete_doc(self): + self.response_val = {'rev': 'doc-rev-gone'}, {} + doc = Document('doc-id', 'doc-rev', None) + self.db.delete_doc(doc) + self.assertEqual('doc-rev-gone', doc.rev) + self.assertEqual(('DELETE', ['doc', 'doc-id'], {'old_rev': 'doc-rev'}, + None, None), self.got) + + def test_get_sync_target(self): + st = self.db.get_sync_target() + self.assertIsInstance(st, http_target.HTTPSyncTarget) + self.assertEqual(st._url, self.db._url) + + def test_get_sync_target_inherits_oauth_credentials(self): + self.db.set_oauth_credentials(tests.consumer1.key, + tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + st = self.db.get_sync_target() + self.assertEqual(self.db._creds, st._creds) + + +class TestHTTPDatabaseCtrWithCreds(tests.TestCase): + + def test_ctr_with_creds(self): + db1 = http_database.HTTPDatabase('http://dbs/db', creds={'oauth': { + 'consumer_key': tests.consumer1.key, + 'consumer_secret': tests.consumer1.secret, + 'token_key': 
tests.token1.key, + 'token_secret': tests.token1.secret + }}) + self.assertIn('oauth', db1._creds) + + +class TestHTTPDatabaseIntegration(tests.TestCaseWithServer): + + make_app_with_state = staticmethod(make_http_app) + + def setUp(self): + super(TestHTTPDatabaseIntegration, self).setUp() + self.startServer() + + def test_non_existing_db(self): + db = http_database.HTTPDatabase(self.getURL('not-there')) + self.assertRaises(errors.DatabaseDoesNotExist, db.get_doc, 'doc1') + + def test__ensure(self): + db = http_database.HTTPDatabase(self.getURL('new')) + db._ensure() + self.assertIs(None, db.get_doc('doc1')) + + def test__delete(self): + self.request_state._create_database('db0') + db = http_database.HTTPDatabase(self.getURL('db0')) + db._delete() + self.assertRaises(errors.DatabaseDoesNotExist, + self.request_state.check_database, 'db0') + + def test_open_database_existing(self): + self.request_state._create_database('db0') + db = http_database.HTTPDatabase.open_database(self.getURL('db0'), + create=False) + self.assertIs(None, db.get_doc('doc1')) + + def test_open_database_non_existing(self): + self.assertRaises(errors.DatabaseDoesNotExist, + http_database.HTTPDatabase.open_database, + self.getURL('not-there'), + create=False) + + def test_open_database_create(self): + db = http_database.HTTPDatabase.open_database(self.getURL('new'), + create=True) + self.assertIs(None, db.get_doc('doc1')) + + def test_delete_database_existing(self): + self.request_state._create_database('db0') + http_database.HTTPDatabase.delete_database(self.getURL('db0')) + self.assertRaises(errors.DatabaseDoesNotExist, + self.request_state.check_database, 'db0') + + def test_doc_ids_needing_quoting(self): + db0 = self.request_state._create_database('db0') + db = http_database.HTTPDatabase.open_database(self.getURL('db0'), + create=False) + doc = Document('%fff', None, '{}') + db.put_doc(doc) + self.assertGetDoc(db0, '%fff', doc.rev, '{}', False) + self.assertGetDoc(db, '%fff', doc.rev, '{}', 
False) diff --git a/src/leap/soledad/u1db/tests/test_https.py b/src/leap/soledad/u1db/tests/test_https.py new file mode 100644 index 00000000..67681c8a --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_https.py @@ -0,0 +1,117 @@ +"""Test support for client-side https support.""" + +import os +import ssl +import sys + +from paste import httpserver + +from u1db import ( + tests, + ) +from u1db.remote import ( + http_client, + http_target, + ) + +from u1db.tests.test_remote_sync_target import ( + make_oauth_http_app, + ) + + +def https_server_def(): + def make_server(host_port, application): + from OpenSSL import SSL + cert_file = os.path.join(os.path.dirname(__file__), 'testing-certs', + 'testing.cert') + key_file = os.path.join(os.path.dirname(__file__), 'testing-certs', + 'testing.key') + ssl_context = SSL.Context(SSL.SSLv23_METHOD) + ssl_context.use_privatekey_file(key_file) + ssl_context.use_certificate_chain_file(cert_file) + srv = httpserver.WSGIServerBase(application, host_port, + httpserver.WSGIHandler, + ssl_context=ssl_context + ) + + def shutdown_request(req): + req.shutdown() + srv.close_request(req) + + srv.shutdown_request = shutdown_request + application.base_url = "https://localhost:%s" % srv.server_address[1] + return srv + return make_server, "shutdown", "https" + + +def oauth_https_sync_target(test, host, path): + _, port = test.server.server_address + st = http_target.HTTPSyncTarget('https://%s:%d/~/%s' % (host, port, path)) + st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return st + + +class TestHttpSyncTargetHttpsSupport(tests.TestCaseWithServer): + + scenarios = [ + ('oauth_https', {'server_def': https_server_def, + 'make_app_with_state': make_oauth_http_app, + 'make_document_for_test': tests.make_document_for_test, + 'sync_target': oauth_https_sync_target + }), + ] + + def setUp(self): + try: + import OpenSSL # noqa + except ImportError: + self.skipTest("Requires pyOpenSSL") 
+ self.cacert_pem = os.path.join(os.path.dirname(__file__), + 'testing-certs', 'cacert.pem') + super(TestHttpSyncTargetHttpsSupport, self).setUp() + + def getSyncTarget(self, host, path=None): + if self.server is None: + self.startServer() + return self.sync_target(self, host, path) + + def test_working(self): + self.startServer() + db = self.request_state._create_database('test') + self.patch(http_client, 'CA_CERTS', self.cacert_pem) + remote_target = self.getSyncTarget('localhost', 'test') + remote_target.record_sync_info('other-id', 2, 'T-id') + self.assertEqual( + (2, 'T-id'), db._get_replica_gen_and_trans_id('other-id')) + + def test_cannot_verify_cert(self): + if not sys.platform.startswith('linux'): + self.skipTest( + "XXX certificate verification happens on linux only for now") + self.startServer() + # don't print expected traceback server-side + self.server.handle_error = lambda req, cli_addr: None + self.request_state._create_database('test') + remote_target = self.getSyncTarget('localhost', 'test') + try: + remote_target.record_sync_info('other-id', 2, 'T-id') + except ssl.SSLError, e: + self.assertIn("certificate verify failed", str(e)) + else: + self.fail("certificate verification should have failed.") + + def test_host_mismatch(self): + if not sys.platform.startswith('linux'): + self.skipTest( + "XXX certificate verification happens on linux only for now") + self.startServer() + self.request_state._create_database('test') + self.patch(http_client, 'CA_CERTS', self.cacert_pem) + remote_target = self.getSyncTarget('127.0.0.1', 'test') + self.assertRaises( + http_client.CertificateError, remote_target.record_sync_info, + 'other-id', 2, 'T-id') + + +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_inmemory.py b/src/leap/soledad/u1db/tests/test_inmemory.py new file mode 100644 index 00000000..255a1e08 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_inmemory.py @@ -0,0 +1,128 @@ +# Copyright 2011 Canonical Ltd. 
+# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Test in-memory backend internals.""" + +from u1db import ( + errors, + tests, + ) +from u1db.backends import inmemory + + +simple_doc = '{"key": "value"}' + + +class TestInMemoryDatabaseInternals(tests.TestCase): + + def setUp(self): + super(TestInMemoryDatabaseInternals, self).setUp() + self.db = inmemory.InMemoryDatabase('test') + + def test__allocate_doc_rev_from_None(self): + self.assertEqual('test:1', self.db._allocate_doc_rev(None)) + + def test__allocate_doc_rev_incremental(self): + self.assertEqual('test:2', self.db._allocate_doc_rev('test:1')) + + def test__allocate_doc_rev_other(self): + self.assertEqual('replica:1|test:1', + self.db._allocate_doc_rev('replica:1')) + + def test__get_replica_uid(self): + self.assertEqual('test', self.db._replica_uid) + + +class TestInMemoryIndex(tests.TestCase): + + def test_has_name_and_definition(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + self.assertEqual('idx-name', idx._name) + self.assertEqual(['key'], idx._definition) + + def test_evaluate_json(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + self.assertEqual(['value'], idx.evaluate_json(simple_doc)) + + def test_evaluate_json_field_None(self): + idx = inmemory.InMemoryIndex('idx-name', ['missing']) + self.assertEqual([], idx.evaluate_json(simple_doc)) + + def test_evaluate_json_subfield_None(self): + idx = inmemory.InMemoryIndex('idx-name', ['key', 'missing']) 
+ self.assertEqual([], idx.evaluate_json(simple_doc)) + + def test_evaluate_multi_index(self): + doc = '{"key": "value", "key2": "value2"}' + idx = inmemory.InMemoryIndex('idx-name', ['key', 'key2']) + self.assertEqual(['value\x01value2'], + idx.evaluate_json(doc)) + + def test_update_ignores_None(self): + idx = inmemory.InMemoryIndex('idx-name', ['nokey']) + idx.add_json('doc-id', simple_doc) + self.assertEqual({}, idx._values) + + def test_update_adds_entry(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + self.assertEqual({'value': ['doc-id']}, idx._values) + + def test_remove_json(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + self.assertEqual({'value': ['doc-id']}, idx._values) + idx.remove_json('doc-id', simple_doc) + self.assertEqual({}, idx._values) + + def test_remove_json_multiple(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + idx.add_json('doc2-id', simple_doc) + self.assertEqual({'value': ['doc-id', 'doc2-id']}, idx._values) + idx.remove_json('doc-id', simple_doc) + self.assertEqual({'value': ['doc2-id']}, idx._values) + + def test_keys(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + self.assertEqual(['value'], idx.keys()) + + def test_lookup(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + self.assertEqual(['doc-id'], idx.lookup(['value'])) + + def test_lookup_multi(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + idx.add_json('doc2-id', simple_doc) + self.assertEqual(['doc-id', 'doc2-id'], idx.lookup(['value'])) + + def test__find_non_wildcards(self): + idx = inmemory.InMemoryIndex('idx-name', ['k1', 'k2', 'k3']) + self.assertEqual(-1, idx._find_non_wildcards(('a', 'b', 'c'))) + self.assertEqual(2, idx._find_non_wildcards(('a', 'b', '*'))) + 
self.assertEqual(3, idx._find_non_wildcards(('a', 'b', 'c*'))) + self.assertEqual(2, idx._find_non_wildcards(('a', 'b*', '*'))) + self.assertEqual(0, idx._find_non_wildcards(('*', '*', '*'))) + self.assertEqual(1, idx._find_non_wildcards(('a*', '*', '*'))) + self.assertRaises(errors.InvalidValueForIndex, + idx._find_non_wildcards, ('a', 'b')) + self.assertRaises(errors.InvalidValueForIndex, + idx._find_non_wildcards, ('a', 'b', 'c', 'd')) + self.assertRaises(errors.InvalidGlobbing, + idx._find_non_wildcards, ('*', 'b', 'c')) diff --git a/src/leap/soledad/u1db/tests/test_open.py b/src/leap/soledad/u1db/tests/test_open.py new file mode 100644 index 00000000..fbeb0cfd --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_open.py @@ -0,0 +1,69 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
+ +"""Test u1db.open""" + +import os + +from u1db import ( + errors, + open as u1db_open, + tests, + ) +from u1db.backends import sqlite_backend +from u1db.tests.test_backends import TestAlternativeDocument + + +class TestU1DBOpen(tests.TestCase): + + def setUp(self): + super(TestU1DBOpen, self).setUp() + tmpdir = self.createTempDir() + self.db_path = tmpdir + '/test.db' + + def test_open_no_create(self): + self.assertRaises(errors.DatabaseDoesNotExist, + u1db_open, self.db_path, create=False) + self.assertFalse(os.path.exists(self.db_path)) + + def test_open_create(self): + db = u1db_open(self.db_path, create=True) + self.addCleanup(db.close) + self.assertTrue(os.path.exists(self.db_path)) + self.assertIsInstance(db, sqlite_backend.SQLiteDatabase) + + def test_open_with_factory(self): + db = u1db_open(self.db_path, create=True, + document_factory=TestAlternativeDocument) + self.addCleanup(db.close) + self.assertEqual(TestAlternativeDocument, db._factory) + + def test_open_existing(self): + db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path) + self.addCleanup(db.close) + doc = db.create_doc_from_json(tests.simple_doc) + # Even though create=True, we shouldn't wipe the db + db2 = u1db_open(self.db_path, create=True) + self.addCleanup(db2.close) + doc2 = db2.get_doc(doc.doc_id) + self.assertEqual(doc, doc2) + + def test_open_existing_no_create(self): + db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path) + self.addCleanup(db.close) + db2 = u1db_open(self.db_path, create=False) + self.addCleanup(db2.close) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) diff --git a/src/leap/soledad/u1db/tests/test_query_parser.py b/src/leap/soledad/u1db/tests/test_query_parser.py new file mode 100644 index 00000000..ee374267 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_query_parser.py @@ -0,0 +1,443 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. 
+# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +from u1db import ( + errors, + query_parser, + tests, + ) + + +trivial_raw_doc = {} + + +class TestFieldName(tests.TestCase): + + def test_check_fieldname_valid(self): + self.assertIsNone(query_parser.check_fieldname("foo")) + + def test_check_fieldname_invalid(self): + self.assertRaises( + errors.IndexDefinitionParseError, query_parser.check_fieldname, + "foo.") + + +class TestMakeTree(tests.TestCase): + + def setUp(self): + super(TestMakeTree, self).setUp() + self.parser = query_parser.Parser() + + def assertParseError(self, definition): + self.assertRaises( + errors.IndexDefinitionParseError, self.parser.parse, + definition) + + def test_single_field(self): + self.assertIsInstance( + self.parser.parse('f'), query_parser.ExtractField) + + def test_single_mapping(self): + self.assertIsInstance( + self.parser.parse('bool(field1)'), query_parser.Bool) + + def test_nested_mapping(self): + self.assertIsInstance( + self.parser.parse('lower(split_words(field1))'), + query_parser.Lower) + + def test_nested_branching_mapping(self): + self.assertIsInstance( + self.parser.parse( + 'combine(lower(field1), split_words(field2), ' + 'number(field3, 5))'), query_parser.Combine) + + def test_single_mapping_multiple_fields(self): + self.assertIsInstance( + self.parser.parse('number(field1, 5)'), query_parser.Number) + + def test_unknown_mapping(self): + self.assertParseError('mapping(whatever)') + + def 
test_parse_missing_close_paren(self): + self.assertParseError("lower(a") + + def test_parse_trailing_chars(self): + self.assertParseError("lower(ab))") + + def test_parse_empty_op(self): + self.assertParseError("(ab)") + + def test_parse_top_level_commas(self): + self.assertParseError("a, b") + + def test_invalid_field_name(self): + self.assertParseError("a.") + + def test_invalid_inner_field_name(self): + self.assertParseError("lower(a.)") + + def test_gobbledigook(self): + self.assertParseError("(@#@cc @#!*DFJSXV(()jccd") + + def test_leading_space(self): + self.assertIsInstance( + self.parser.parse(" lower(a)"), query_parser.Lower) + + def test_trailing_space(self): + self.assertIsInstance( + self.parser.parse("lower(a) "), query_parser.Lower) + + def test_spaces_before_open_paren(self): + self.assertIsInstance( + self.parser.parse("lower (a)"), query_parser.Lower) + + def test_spaces_after_open_paren(self): + self.assertIsInstance( + self.parser.parse("lower( a)"), query_parser.Lower) + + def test_spaces_before_close_paren(self): + self.assertIsInstance( + self.parser.parse("lower(a )"), query_parser.Lower) + + def test_spaces_before_comma(self): + self.assertIsInstance( + self.parser.parse("number(a , 5)"), query_parser.Number) + + def test_spaces_after_comma(self): + self.assertIsInstance( + self.parser.parse("number(a, 5)"), query_parser.Number) + + +class TestStaticGetter(tests.TestCase): + + def test_returns_string(self): + getter = query_parser.StaticGetter('foo') + self.assertEqual(['foo'], getter.get(trivial_raw_doc)) + + def test_returns_int(self): + getter = query_parser.StaticGetter(9) + self.assertEqual([9], getter.get(trivial_raw_doc)) + + def test_returns_float(self): + getter = query_parser.StaticGetter(9.2) + self.assertEqual([9.2], getter.get(trivial_raw_doc)) + + def test_returns_None(self): + getter = query_parser.StaticGetter(None) + self.assertEqual([], getter.get(trivial_raw_doc)) + + def test_returns_list(self): + getter = 
query_parser.StaticGetter(['a', 'b']) + self.assertEqual(['a', 'b'], getter.get(trivial_raw_doc)) + + +class TestExtractField(tests.TestCase): + + def assertExtractField(self, expected, field_name, raw_doc): + getter = query_parser.ExtractField(field_name) + self.assertEqual(expected, getter.get(raw_doc)) + + def test_get_value(self): + self.assertExtractField(['bar'], 'foo', {'foo': 'bar'}) + + def test_get_value_None(self): + self.assertExtractField([], 'foo', {'foo': None}) + + def test_get_value_missing_key(self): + self.assertExtractField([], 'foo', {}) + + def test_get_value_subfield(self): + self.assertExtractField(['bar'], 'foo.baz', {'foo': {'baz': 'bar'}}) + + def test_get_value_subfield_missing(self): + self.assertExtractField([], 'foo.baz', {'foo': 'bar'}) + + def test_get_value_dict(self): + self.assertExtractField([], 'foo', {'foo': {'baz': 'bar'}}) + + def test_get_value_list(self): + self.assertExtractField(['bar', 'zap'], 'foo', {'foo': ['bar', 'zap']}) + + def test_get_value_mixed_list(self): + self.assertExtractField(['bar', 'zap'], 'foo', + {'foo': ['bar', ['baa'], 'zap', {'bing': 9}]}) + + def test_get_value_list_of_dicts(self): + self.assertExtractField([], 'foo', {'foo': [{'zap': 'bar'}]}) + + def test_get_value_list_of_dicts2(self): + self.assertExtractField( + ['bar', 'baz'], 'foo.zap', + {'foo': [{'zap': 'bar'}, {'zap': 'baz'}]}) + + def test_get_value_int(self): + self.assertExtractField([9], 'foo', {'foo': 9}) + + def test_get_value_float(self): + self.assertExtractField([9.2], 'foo', {'foo': 9.2}) + + def test_get_value_bool(self): + self.assertExtractField([True], 'foo', {'foo': True}) + self.assertExtractField([False], 'foo', {'foo': False}) + + +class TestLower(tests.TestCase): + + def assertLowerGets(self, expected, input_val): + getter = query_parser.Lower(query_parser.StaticGetter(input_val)) + out_val = getter.get(trivial_raw_doc) + self.assertEqual(sorted(expected), sorted(out_val)) + + def test_inner_returns_None(self): + 
self.assertLowerGets([], None) + + def test_inner_returns_string(self): + self.assertLowerGets(['foo'], 'fOo') + + def test_inner_returns_list(self): + self.assertLowerGets(['foo', 'bar'], ['fOo', 'bAr']) + + def test_inner_returns_int(self): + self.assertLowerGets([], 9) + + def test_inner_returns_float(self): + self.assertLowerGets([], 9.0) + + def test_inner_returns_bool(self): + self.assertLowerGets([], True) + + def test_inner_returns_list_containing_int(self): + self.assertLowerGets(['foo', 'bar'], ['fOo', 9, 'bAr']) + + def test_inner_returns_list_containing_float(self): + self.assertLowerGets(['foo', 'bar'], ['fOo', 9.2, 'bAr']) + + def test_inner_returns_list_containing_bool(self): + self.assertLowerGets(['foo', 'bar'], ['fOo', True, 'bAr']) + + def test_inner_returns_list_containing_list(self): + # TODO: Should this be unfolding the inner list? + self.assertLowerGets(['foo', 'bar'], ['fOo', ['bAa'], 'bAr']) + + def test_inner_returns_list_containing_dict(self): + self.assertLowerGets(['foo', 'bar'], ['fOo', {'baa': 'xam'}, 'bAr']) + + +class TestSplitWords(tests.TestCase): + + def assertSplitWords(self, expected, value): + getter = query_parser.SplitWords(query_parser.StaticGetter(value)) + self.assertEqual(sorted(expected), sorted(getter.get(trivial_raw_doc))) + + def test_inner_returns_None(self): + self.assertSplitWords([], None) + + def test_inner_returns_string(self): + self.assertSplitWords(['foo', 'bar'], 'foo bar') + + def test_inner_returns_list(self): + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo baz', 'bar sux']) + + def test_deduplicates(self): + self.assertSplitWords(['bar'], ['bar', 'bar', 'bar']) + + def test_inner_returns_int(self): + self.assertSplitWords([], 9) + + def test_inner_returns_float(self): + self.assertSplitWords([], 9.2) + + def test_inner_returns_bool(self): + self.assertSplitWords([], True) + + def test_inner_returns_list_containing_int(self): + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo 
baz', 9, 'bar sux']) + + def test_inner_returns_list_containing_float(self): + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo baz', 9.2, 'bar sux']) + + def test_inner_returns_list_containing_bool(self): + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo baz', True, 'bar sux']) + + def test_inner_returns_list_containing_list(self): + # TODO: Expand sub-lists? + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo baz', ['baa'], 'bar sux']) + + def test_inner_returns_list_containing_dict(self): + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo baz', {'baa': 'xam'}, 'bar sux']) + + +class TestNumber(tests.TestCase): + + def assertNumber(self, expected, value, padding=5): + """Assert number transformation produced expected values.""" + getter = query_parser.Number(query_parser.StaticGetter(value), padding) + self.assertEqual(expected, getter.get(trivial_raw_doc)) + + def test_inner_returns_None(self): + """None is thrown away.""" + self.assertNumber([], None) + + def test_inner_returns_int(self): + """A single integer is converted to zero padded strings.""" + self.assertNumber(['00009'], 9) + + def test_inner_returns_list(self): + """Integers are converted to zero padded strings.""" + self.assertNumber(['00009', '00235'], [9, 235]) + + def test_inner_returns_string(self): + """A string is thrown away.""" + self.assertNumber([], 'foo bar') + + def test_inner_returns_float(self): + """A float is thrown away.""" + self.assertNumber([], 9.2) + + def test_inner_returns_bool(self): + """A boolean is thrown away.""" + self.assertNumber([], True) + + def test_inner_returns_list_containing_strings(self): + """Strings in a list are thrown away.""" + self.assertNumber(['00009'], ['foo baz', 9, 'bar sux']) + + def test_inner_returns_list_containing_float(self): + """Floats in a list are thrown away.""" + self.assertNumber( + ['00083', '00073'], [83, 9.2, 73]) + + def test_inner_returns_list_containing_bool(self): + """Booleans in a list 
are thrown away.""" + self.assertNumber( + ['00083', '00073'], [83, True, 73]) + + def test_inner_returns_list_containing_list(self): + """Lists in a list are thrown away.""" + # TODO: Expand sub-lists? + self.assertNumber( + ['00012', '03333'], [12, [29], 3333]) + + def test_inner_returns_list_containing_dict(self): + """Dicts in a list are thrown away.""" + self.assertNumber( + ['00012', '00001'], [12, {54: 89}, 1]) + + +class TestIsNull(tests.TestCase): + + def assertIsNull(self, value): + getter = query_parser.IsNull(query_parser.StaticGetter(value)) + self.assertEqual([True], getter.get(trivial_raw_doc)) + + def assertIsNotNull(self, value): + getter = query_parser.IsNull(query_parser.StaticGetter(value)) + self.assertEqual([False], getter.get(trivial_raw_doc)) + + def test_inner_returns_None(self): + self.assertIsNull(None) + + def test_inner_returns_string(self): + self.assertIsNotNull('foo') + + def test_inner_returns_list(self): + self.assertIsNotNull(['foo', 'bar']) + + def test_inner_returns_empty_list(self): + # TODO: is this the behavior we want? + self.assertIsNull([]) + + def test_inner_returns_int(self): + self.assertIsNotNull(9) + + def test_inner_returns_float(self): + self.assertIsNotNull(9.2) + + def test_inner_returns_bool(self): + self.assertIsNotNull(True) + + # TODO: What about a dict? Inner is likely to return None, even though the + # attribute does exist... 
class TestParser(tests.TestCase):
    """Exercise query_parser.Parser on valid and invalid index expressions."""

    def parse(self, spec):
        """Return the result of parsing one expression with a new Parser."""
        return query_parser.Parser().parse(spec)

    def parse_all(self, specs):
        """Return the result of parsing several expressions with a new Parser."""
        return query_parser.Parser().parse_all(specs)

    def assertParseError(self, definition):
        """Assert that parsing `definition` raises IndexDefinitionParseError."""
        self.assertRaises(
            errors.IndexDefinitionParseError, self.parse, definition)

    # -- malformed definitions ------------------------------------------------

    def test_parse_empty_string(self):
        self.assertParseError("")

    def test_parse_dotted_field_nothing_after_dot(self):
        self.assertParseError("a.")

    def test_parse_missing_close_on_transformation(self):
        self.assertParseError("lower(a")

    def test_parse_missing_field_in_transformation(self):
        self.assertParseError("lower()")

    def test_parse_trailing_chars(self):
        self.assertParseError("lower(ab))")

    def test_parse_empty_op(self):
        self.assertParseError("(ab)")

    def test_parse_unknown_op(self):
        self.assertParseError("no_such_operation(field)")

    def test_parse_wrong_arg_type(self):
        self.assertParseError("number(field, fnord)")

    # -- well-formed definitions ----------------------------------------------

    def test_parse_field(self):
        parsed = self.parse("a")
        self.assertIsInstance(parsed, query_parser.ExtractField)
        self.assertEqual(["a"], parsed.field)

    def test_parse_dotted_field(self):
        parsed = self.parse("a.b")
        self.assertIsInstance(parsed, query_parser.ExtractField)
        self.assertEqual(["a", "b"], parsed.field)

    def test_parse_transformation(self):
        outer = self.parse("lower(a)")
        self.assertIsInstance(outer, query_parser.Lower)
        wrapped = outer.inner
        self.assertIsInstance(wrapped, query_parser.ExtractField)
        self.assertEqual(["a"], wrapped.field)

    def test_parse_all(self):
        field_names = ["a", "b"]
        getters = self.parse_all(field_names)
        self.assertEqual(2, len(getters))
        # Check each getter in order: type first, then the field path.
        for position, field_name in enumerate(field_names):
            self.assertIsInstance(
                getters[position], query_parser.ExtractField)
            self.assertEqual([field_name], getters[position].field)
a/src/leap/soledad/u1db/tests/test_remote_sync_target.py b/src/leap/soledad/u1db/tests/test_remote_sync_target.py new file mode 100644 index 00000000..3e0d8995 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_remote_sync_target.py @@ -0,0 +1,314 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Tests for the remote sync targets""" + +import cStringIO + +from u1db import ( + errors, + tests, + ) +from u1db.remote import ( + http_app, + http_target, + oauth_middleware, + ) + + +class TestHTTPSyncTargetBasics(tests.TestCase): + + def test_parse_url(self): + remote_target = http_target.HTTPSyncTarget('http://127.0.0.1:12345/') + self.assertEqual('http', remote_target._url.scheme) + self.assertEqual('127.0.0.1', remote_target._url.hostname) + self.assertEqual(12345, remote_target._url.port) + self.assertEqual('/', remote_target._url.path) + + +class TestParsingSyncStream(tests.TestCase): + + def test_wrong_start(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "{}\r\n]", None) + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "\r\n{}\r\n]", None) + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "", None) + + def test_wrong_end(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "[\r\n{}", 
None) + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "[\r\n", None) + + def test_missing_comma(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, + '[\r\n{}\r\n{"id": "i", "rev": "r", ' + '"content": "c", "gen": 3}\r\n]', None) + + def test_no_entries(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "[\r\n]", None) + + def test_extra_comma(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "[\r\n{},\r\n]", None) + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, + '[\r\n{},\r\n{"id": "i", "rev": "r", ' + '"content": "{}", "gen": 3, "trans_id": "T-sid"}' + ',\r\n]', + lambda doc, gen, trans_id: None) + + def test_error_in_stream(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.Unavailable, + tgt._parse_sync_stream, + '[\r\n{"new_generation": 0},' + '\r\n{"error": "unavailable"}\r\n', None) + + self.assertRaises(errors.Unavailable, + tgt._parse_sync_stream, + '[\r\n{"error": "unavailable"}\r\n', None) + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, + '[\r\n{"error": "?"}\r\n', None) + + +def make_http_app(state): + return http_app.HTTPApp(state) + + +def http_sync_target(test, path): + return http_target.HTTPSyncTarget(test.getURL(path)) + + +def make_oauth_http_app(state): + app = http_app.HTTPApp(state) + application = oauth_middleware.OAuthMiddleware(app, None, prefix='/~/') + application.get_oauth_data_store = lambda: tests.testingOAuthStore + return application + + +def oauth_http_sync_target(test, path): + st = http_sync_target(test, '~/' + path) + st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return st + + +class 
TestRemoteSyncTargets(tests.TestCaseWithServer): + + scenarios = [ + ('http', {'make_app_with_state': make_http_app, + 'make_document_for_test': tests.make_document_for_test, + 'sync_target': http_sync_target}), + ('oauth_http', {'make_app_with_state': make_oauth_http_app, + 'make_document_for_test': tests.make_document_for_test, + 'sync_target': oauth_http_sync_target}), + ] + + def getSyncTarget(self, path=None): + if self.server is None: + self.startServer() + return self.sync_target(self, path) + + def test_get_sync_info(self): + self.startServer() + db = self.request_state._create_database('test') + db._set_replica_gen_and_trans_id('other-id', 1, 'T-transid') + remote_target = self.getSyncTarget('test') + self.assertEqual(('test', 0, '', 1, 'T-transid'), + remote_target.get_sync_info('other-id')) + + def test_record_sync_info(self): + self.startServer() + db = self.request_state._create_database('test') + remote_target = self.getSyncTarget('test') + remote_target.record_sync_info('other-id', 2, 'T-transid') + self.assertEqual( + (2, 'T-transid'), db._get_replica_gen_and_trans_id('other-id')) + + def test_sync_exchange_send(self): + self.startServer() + db = self.request_state._create_database('test') + remote_target = self.getSyncTarget('test') + other_docs = [] + + def receive_doc(doc): + other_docs.append((doc.doc_id, doc.rev, doc.get_json())) + + doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}') + new_gen, trans_id = remote_target.sync_exchange( + [(doc, 10, 'T-sid')], 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=receive_doc) + self.assertEqual(1, new_gen) + self.assertGetDoc( + db, 'doc-here', 'replica:1', '{"value": "here"}', False) + + def test_sync_exchange_send_failure_and_retry_scenario(self): + self.startServer() + + def blackhole_getstderr(inst): + return cStringIO.StringIO() + + self.patch(self.server.RequestHandlerClass, 'get_stderr', + blackhole_getstderr) + db = 
self.request_state._create_database('test') + _put_doc_if_newer = db._put_doc_if_newer + trigger_ids = ['doc-here2'] + + def bomb_put_doc_if_newer(doc, save_conflict, + replica_uid=None, replica_gen=None, + replica_trans_id=None): + if doc.doc_id in trigger_ids: + raise Exception + return _put_doc_if_newer(doc, save_conflict=save_conflict, + replica_uid=replica_uid, replica_gen=replica_gen, + replica_trans_id=replica_trans_id) + self.patch(db, '_put_doc_if_newer', bomb_put_doc_if_newer) + remote_target = self.getSyncTarget('test') + other_changes = [] + + def receive_doc(doc, gen, trans_id): + other_changes.append( + (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) + + doc1 = self.make_document('doc-here', 'replica:1', '{"value": "here"}') + doc2 = self.make_document('doc-here2', 'replica:1', + '{"value": "here2"}') + self.assertRaises( + errors.HTTPError, + remote_target.sync_exchange, + [(doc1, 10, 'T-sid'), (doc2, 11, 'T-sud')], + 'replica', last_known_generation=0, last_known_trans_id=None, + return_doc_cb=receive_doc) + self.assertGetDoc(db, 'doc-here', 'replica:1', '{"value": "here"}', + False) + self.assertEqual( + (10, 'T-sid'), db._get_replica_gen_and_trans_id('replica')) + self.assertEqual([], other_changes) + # retry + trigger_ids = [] + new_gen, trans_id = remote_target.sync_exchange( + [(doc2, 11, 'T-sud')], 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=receive_doc) + self.assertGetDoc(db, 'doc-here2', 'replica:1', '{"value": "here2"}', + False) + self.assertEqual( + (11, 'T-sud'), db._get_replica_gen_and_trans_id('replica')) + self.assertEqual(2, new_gen) + # bounced back to us + self.assertEqual( + ('doc-here', 'replica:1', '{"value": "here"}', 1), + other_changes[0][:-1]) + + def test_sync_exchange_in_stream_error(self): + self.startServer() + + def blackhole_getstderr(inst): + return cStringIO.StringIO() + + self.patch(self.server.RequestHandlerClass, 'get_stderr', + blackhole_getstderr) + db = 
self.request_state._create_database('test') + doc = db.create_doc_from_json('{"value": "there"}') + + def bomb_get_docs(doc_ids, check_for_conflicts=None, + include_deleted=False): + yield doc + # delayed failure case + raise errors.Unavailable + + self.patch(db, 'get_docs', bomb_get_docs) + remote_target = self.getSyncTarget('test') + other_changes = [] + + def receive_doc(doc, gen, trans_id): + other_changes.append( + (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) + + self.assertRaises( + errors.Unavailable, remote_target.sync_exchange, [], 'replica', + last_known_generation=0, last_known_trans_id=None, + return_doc_cb=receive_doc) + self.assertEqual( + (doc.doc_id, doc.rev, '{"value": "there"}', 1), + other_changes[0][:-1]) + + def test_sync_exchange_receive(self): + self.startServer() + db = self.request_state._create_database('test') + doc = db.create_doc_from_json('{"value": "there"}') + remote_target = self.getSyncTarget('test') + other_changes = [] + + def receive_doc(doc, gen, trans_id): + other_changes.append( + (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) + + new_gen, trans_id = remote_target.sync_exchange( + [], 'replica', last_known_generation=0, last_known_trans_id=None, + return_doc_cb=receive_doc) + self.assertEqual(1, new_gen) + self.assertEqual( + (doc.doc_id, doc.rev, '{"value": "there"}', 1), + other_changes[0][:-1]) + + def test_sync_exchange_send_ensure_callback(self): + self.startServer() + remote_target = self.getSyncTarget('test') + other_docs = [] + replica_uid_box = [] + + def receive_doc(doc): + other_docs.append((doc.doc_id, doc.rev, doc.get_json())) + + def ensure_cb(replica_uid): + replica_uid_box.append(replica_uid) + + doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}') + new_gen, trans_id = remote_target.sync_exchange( + [(doc, 10, 'T-sid')], 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=receive_doc, + ensure_callback=ensure_cb) + self.assertEqual(1, new_gen) + db = 
self.request_state.open_database('test') + self.assertEqual(1, len(replica_uid_box)) + self.assertEqual(db._replica_uid, replica_uid_box[0]) + self.assertGetDoc( + db, 'doc-here', 'replica:1', '{"value": "here"}', False) + + +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_remote_utils.py b/src/leap/soledad/u1db/tests/test_remote_utils.py new file mode 100644 index 00000000..959cd882 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_remote_utils.py @@ -0,0 +1,36 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Tests for protocol details utils.""" + +from u1db.tests import TestCase +from u1db.remote import utils + + +class TestUtils(TestCase): + + def test_check_and_strip_comma(self): + line, comma = utils.check_and_strip_comma("abc,") + self.assertTrue(comma) + self.assertEqual("abc", line) + + line, comma = utils.check_and_strip_comma("abc") + self.assertFalse(comma) + self.assertEqual("abc", line) + + line, comma = utils.check_and_strip_comma("") + self.assertFalse(comma) + self.assertEqual("", line) diff --git a/src/leap/soledad/u1db/tests/test_server_state.py b/src/leap/soledad/u1db/tests/test_server_state.py new file mode 100644 index 00000000..fc3f1282 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_server_state.py @@ -0,0 +1,93 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. 
+# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Tests for server state object.""" + +import os + +from u1db import ( + errors, + tests, + ) +from u1db.remote import ( + server_state, + ) +from u1db.backends import sqlite_backend + + +class TestServerState(tests.TestCase): + + def setUp(self): + super(TestServerState, self).setUp() + self.state = server_state.ServerState() + + def test_set_workingdir(self): + tempdir = self.createTempDir() + self.state.set_workingdir(tempdir) + self.assertTrue(self.state._relpath('path').startswith(tempdir)) + + def test_open_database(self): + tempdir = self.createTempDir() + self.state.set_workingdir(tempdir) + path = tempdir + '/test.db' + self.assertFalse(os.path.exists(path)) + # Create the db, but don't do anything with it + sqlite_backend.SQLitePartialExpandDatabase(path) + db = self.state.open_database('test.db') + self.assertIsInstance(db, sqlite_backend.SQLitePartialExpandDatabase) + + def test_check_database(self): + tempdir = self.createTempDir() + self.state.set_workingdir(tempdir) + path = tempdir + '/test.db' + self.assertFalse(os.path.exists(path)) + + # doesn't exist => raises + self.assertRaises(errors.DatabaseDoesNotExist, + self.state.check_database, 'test.db') + + # Create the db, but don't do anything with it + sqlite_backend.SQLitePartialExpandDatabase(path) + # exists => returns + res = self.state.check_database('test.db') + self.assertIsNone(res) + + def test_ensure_database(self): + tempdir = 
self.createTempDir() + self.state.set_workingdir(tempdir) + path = tempdir + '/test.db' + self.assertFalse(os.path.exists(path)) + db, replica_uid = self.state.ensure_database('test.db') + self.assertIsInstance(db, sqlite_backend.SQLitePartialExpandDatabase) + self.assertEqual(db._replica_uid, replica_uid) + self.assertTrue(os.path.exists(path)) + db2 = self.state.open_database('test.db') + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + + def test_delete_database(self): + tempdir = self.createTempDir() + self.state.set_workingdir(tempdir) + path = tempdir + '/test.db' + db, _ = self.state.ensure_database('test.db') + db.close() + self.state.delete_database('test.db') + self.assertFalse(os.path.exists(path)) + + def test_delete_database_DoesNotExist(self): + tempdir = self.createTempDir() + self.state.set_workingdir(tempdir) + self.assertRaises(errors.DatabaseDoesNotExist, + self.state.delete_database, 'test.db') diff --git a/src/leap/soledad/u1db/tests/test_sqlite_backend.py b/src/leap/soledad/u1db/tests/test_sqlite_backend.py new file mode 100644 index 00000000..73330789 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_sqlite_backend.py @@ -0,0 +1,493 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
+ +"""Test sqlite backend internals.""" + +import os +import time +import threading + +from sqlite3 import dbapi2 + +from u1db import ( + errors, + tests, + query_parser, + ) +from u1db.backends import sqlite_backend +from u1db.tests.test_backends import TestAlternativeDocument + + +simple_doc = '{"key": "value"}' +nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}' + + +class TestSQLiteDatabase(tests.TestCase): + + def test_atomic_initialize(self): + tmpdir = self.createTempDir() + dbname = os.path.join(tmpdir, 'atomic.db') + + t2 = None # will be a thread + + class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + _index_storage_value = "testing" + + def __init__(self, dbname, ntry): + self._try = ntry + self._is_initialized_invocations = 0 + super(SQLiteDatabaseTesting, self).__init__(dbname) + + def _is_initialized(self, c): + res = super(SQLiteDatabaseTesting, self)._is_initialized(c) + if self._try == 1: + self._is_initialized_invocations += 1 + if self._is_initialized_invocations == 2: + t2.start() + # hard to do better and have a generic test + time.sleep(0.05) + return res + + outcome2 = [] + + def second_try(): + try: + db2 = SQLiteDatabaseTesting(dbname, 2) + except Exception, e: + outcome2.append(e) + else: + outcome2.append(db2) + + t2 = threading.Thread(target=second_try) + db1 = SQLiteDatabaseTesting(dbname, 1) + t2.join() + + self.assertIsInstance(outcome2[0], SQLiteDatabaseTesting) + db2 = outcome2[0] + self.assertTrue(db2._is_initialized(db1._get_sqlite_handle().cursor())) + + +class TestSQLitePartialExpandDatabase(tests.TestCase): + + def setUp(self): + super(TestSQLitePartialExpandDatabase, self).setUp() + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db._set_replica_uid('test') + + def test_create_database(self): + raw_db = self.db._get_sqlite_handle() + self.assertNotEqual(None, raw_db) + + def test_default_replica_uid(self): + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + 
self.assertIsNot(None, self.db._replica_uid) + self.assertEqual(32, len(self.db._replica_uid)) + int(self.db._replica_uid, 16) + + def test__close_sqlite_handle(self): + raw_db = self.db._get_sqlite_handle() + self.db._close_sqlite_handle() + self.assertRaises(dbapi2.ProgrammingError, + raw_db.cursor) + + def test_create_database_initializes_schema(self): + raw_db = self.db._get_sqlite_handle() + c = raw_db.cursor() + c.execute("SELECT * FROM u1db_config") + config = dict([(r[0], r[1]) for r in c.fetchall()]) + self.assertEqual({'sql_schema': '0', 'replica_uid': 'test', + 'index_storage': 'expand referenced'}, config) + + # These tables must exist, though we don't care what is in them yet + c.execute("SELECT * FROM transaction_log") + c.execute("SELECT * FROM document") + c.execute("SELECT * FROM document_fields") + c.execute("SELECT * FROM sync_log") + c.execute("SELECT * FROM conflicts") + c.execute("SELECT * FROM index_definitions") + + def test__parse_index(self): + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + g = self.db._parse_index_definition('fieldname') + self.assertIsInstance(g, query_parser.ExtractField) + self.assertEqual(['fieldname'], g.field) + + def test__update_indexes(self): + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + g = self.db._parse_index_definition('fieldname') + c = self.db._get_sqlite_handle().cursor() + self.db._update_indexes('doc-id', {'fieldname': 'val'}, + [('fieldname', g)], c) + c.execute('SELECT doc_id, field_name, value FROM document_fields') + self.assertEqual([('doc-id', 'fieldname', 'val')], + c.fetchall()) + + def test__set_replica_uid(self): + # Start from scratch, so that replica_uid isn't set. 
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.assertIsNot(None, self.db._real_replica_uid) + self.assertIsNot(None, self.db._replica_uid) + self.db._set_replica_uid('foo') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT value FROM u1db_config WHERE name='replica_uid'") + self.assertEqual(('foo',), c.fetchone()) + self.assertEqual('foo', self.db._real_replica_uid) + self.assertEqual('foo', self.db._replica_uid) + self.db._close_sqlite_handle() + self.assertEqual('foo', self.db._replica_uid) + + def test__get_generation(self): + self.assertEqual(0, self.db._get_generation()) + + def test__get_generation_info(self): + self.assertEqual((0, ''), self.db._get_generation_info()) + + def test_create_index(self): + self.db.create_index('test-idx', "key") + self.assertEqual([('test-idx', ["key"])], self.db.list_indexes()) + + def test_create_index_multiple_fields(self): + self.db.create_index('test-idx', "key", "key2") + self.assertEqual([('test-idx', ["key", "key2"])], + self.db.list_indexes()) + + def test__get_index_definition(self): + self.db.create_index('test-idx', "key", "key2") + # TODO: How would you test that an index is getting used for an SQL + # request? + self.assertEqual(["key", "key2"], + self.db._get_index_definition('test-idx')) + + def test_list_index_mixed(self): + # Make sure that we properly order the output + c = self.db._get_sqlite_handle().cursor() + # We intentionally insert the data in weird ordering, to make sure the + # query still gets it back correctly. 
+ c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", + [('idx-1', 0, 'key10'), + ('idx-2', 2, 'key22'), + ('idx-1', 1, 'key11'), + ('idx-2', 0, 'key20'), + ('idx-2', 1, 'key21')]) + self.assertEqual([('idx-1', ['key10', 'key11']), + ('idx-2', ['key20', 'key21', 'key22'])], + self.db.list_indexes()) + + def test_no_indexes_no_document_fields(self): + self.db.create_doc_from_json( + '{"key1": "val1", "key2": "val2"}') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([], c.fetchall()) + + def test_create_extracts_fields(self): + doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + doc2 = self.db.create_doc_from_json('{"key1": "valx", "key2": "valy"}') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([], c.fetchall()) + self.db.create_index('test', 'key1', 'key2') + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual(sorted( + [(doc1.doc_id, "key1", "val1"), + (doc1.doc_id, "key2", "val2"), + (doc2.doc_id, "key1", "valx"), + (doc2.doc_id, "key2", "valy"), + ]), sorted(c.fetchall())) + + def test_put_updates_fields(self): + self.db.create_index('test', 'key1', 'key2') + doc1 = self.db.create_doc_from_json( + '{"key1": "val1", "key2": "val2"}') + doc1.content = {"key1": "val1", "key2": "valy"} + self.db.put_doc(doc1) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, "key1", "val1"), + (doc1.doc_id, "key2", "valy"), + ], c.fetchall()) + + def test_put_updates_nested_fields(self): + self.db.create_index('test', 'key', 'sub.doc') + doc1 = self.db.create_doc_from_json(nested_doc) + c = 
self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, "key", "value"), + (doc1.doc_id, "sub.doc", "underneath"), + ], c.fetchall()) + + def test__ensure_schema_rollback(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/rollback.db' + + class SQLitePartialExpandDbTesting( + sqlite_backend.SQLitePartialExpandDatabase): + + def _set_replica_uid_in_transaction(self, uid): + super(SQLitePartialExpandDbTesting, + self)._set_replica_uid_in_transaction(uid) + if fail: + raise Exception() + + db = SQLitePartialExpandDbTesting.__new__(SQLitePartialExpandDbTesting) + db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed + fail = True + self.assertRaises(Exception, db._ensure_schema) + fail = False + db._initialize(db._db_handle.cursor()) + + def test__open_database(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/test.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase._open_database(path) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + + def test__open_database_with_factory(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/test.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase._open_database( + path, document_factory=TestAlternativeDocument) + self.assertEqual(TestAlternativeDocument, db2._factory) + + def test__open_database_non_existent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase._open_database, path) + + def test__open_database_during_init(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/initialised.db' + db = 
sqlite_backend.SQLitePartialExpandDatabase.__new__( + sqlite_backend.SQLitePartialExpandDatabase) + db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed + self.addCleanup(db.close) + observed = [] + + class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 + + @classmethod + def _which_index_storage(cls, c): + res = super(SQLiteDatabaseTesting, cls)._which_index_storage(c) + db._ensure_schema() # init db + observed.append(res[0]) + return res + + db2 = SQLiteDatabaseTesting._open_database(path) + self.addCleanup(db2.close) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + self.assertEqual([None, + sqlite_backend.SQLitePartialExpandDatabase._index_storage_value], + observed) + + def test__open_database_invalid(self): + class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 + temp_dir = self.createTempDir(prefix='u1db-test-') + path1 = temp_dir + '/invalid1.db' + with open(path1, 'wb') as f: + f.write("") + self.assertRaises(dbapi2.OperationalError, + SQLiteDatabaseTesting._open_database, path1) + with open(path1, 'wb') as f: + f.write("invalid") + self.assertRaises(dbapi2.DatabaseError, + SQLiteDatabaseTesting._open_database, path1) + + def test_open_database_existing(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/existing.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + + def test_open_database_with_factory(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/existing.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase.open_database( + path, create=False, document_factory=TestAlternativeDocument) + self.assertEqual(TestAlternativeDocument, db2._factory) + + def 
test_open_database_create(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/new.sqlite' + sqlite_backend.SQLiteDatabase.open_database(path, create=True) + db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + + def test_open_database_non_existent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase.open_database, path, + create=False) + + def test_delete_database_existent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/new.sqlite' + db = sqlite_backend.SQLiteDatabase.open_database(path, create=True) + db.close() + sqlite_backend.SQLiteDatabase.delete_database(path) + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase.open_database, path, + create=False) + + def test_delete_database_nonexistent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase.delete_database, path) + + def test__get_indexed_fields(self): + self.db.create_index('idx1', 'a', 'b') + self.assertEqual(set(['a', 'b']), self.db._get_indexed_fields()) + self.db.create_index('idx2', 'b', 'c') + self.assertEqual(set(['a', 'b', 'c']), self.db._get_indexed_fields()) + + def test_indexed_fields_expanded(self): + self.db.create_index('idx1', 'key1') + doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + self.assertEqual(set(['key1']), self.db._get_indexed_fields()) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall()) + + def test_create_index_updates_fields(self): + doc1 = 
self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + self.db.create_index('idx1', 'key1') + self.assertEqual(set(['key1']), self.db._get_indexed_fields()) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall()) + + def assertFormatQueryEquals(self, exp_statement, exp_args, definition, + values): + statement, args = self.db._format_query(definition, values) + self.assertEqual(exp_statement, statement) + self.assertEqual(exp_args, args) + + def test__format_query(self): + self.assertFormatQueryEquals( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, document_fields d0 LEFT OUTER JOIN conflicts c ON " + "c.doc_id = d.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name " + "= ? AND d0.value = ? GROUP BY d.doc_id, d.doc_rev, d.content " + "ORDER BY d0.value;", ["key1", "a"], + ["key1"], ["a"]) + + def test__format_query2(self): + self.assertFormatQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value = ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value = ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ["key1", "a", "key2", "b", "key3", "c"], + ["key1", "key2", "key3"], ["a", "b", "c"]) + + def test__format_query_wildcard(self): + self.assertFormatQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? 
AND ' + 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value GLOB ? AND d.doc_id = d2.doc_id AND d2.field_name = ? ' + 'AND d2.value NOT NULL GROUP BY d.doc_id, d.doc_rev, d.content ' + 'ORDER BY d0.value, d1.value, d2.value;', + ["key1", "a", "key2", "b*", "key3"], ["key1", "key2", "key3"], + ["a", "b*", "*"]) + + def assertFormatRangeQueryEquals(self, exp_statement, exp_args, definition, + start_value, end_value): + statement, args = self.db._format_range_query( + definition, start_value, end_value) + self.assertEqual(exp_statement, statement) + self.assertEqual(exp_args, args) + + def test__format_range_query(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value >= ? AND d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c', 'key1', 'p', 'key2', 'q', + 'key3', 'r'], + ["key1", "key2", "key3"], ["a", "b", "c"], ["p", "q", "r"]) + + def test__format_range_query_no_start(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value <= ? 
AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c'], + ["key1", "key2", "key3"], None, ["a", "b", "c"]) + + def test__format_range_query_no_end(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value >= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c'], + ["key1", "key2", "key3"], ["a", "b", "c"], None) + + def test__format_range_query_wildcard(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value NOT NULL AND d.doc_id = d0.doc_id AND d0.field_name = ? ' + 'AND d0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? ' + 'AND (d1.value < ? OR d1.value GLOB ?) AND d.doc_id = d2.doc_id ' + 'AND d2.field_name = ? 
AND d2.value NOT NULL GROUP BY d.doc_id, ' + 'd.doc_rev, d.content ORDER BY d0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'key1', 'p', 'key2', 'q', 'q*', + 'key3'], + ["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"]) diff --git a/src/leap/soledad/u1db/tests/test_sync.py b/src/leap/soledad/u1db/tests/test_sync.py new file mode 100644 index 00000000..f2a925f0 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_sync.py @@ -0,0 +1,1285 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. 
+ +"""Tests for U1DB sync targets and the Synchronizer.""" + +import os +from wsgiref import simple_server + +from u1db import ( + errors, + sync, + tests, + vectorclock, + SyncTarget, + ) +from u1db.backends import ( + inmemory, + ) +from u1db.remote import ( + http_target, + ) + +from u1db.tests.test_remote_sync_target import ( + make_http_app, + make_oauth_http_app, + ) + +simple_doc = tests.simple_doc +nested_doc = tests.nested_doc + + +def _make_local_db_and_target(test): + db = test.create_database('test') + st = db.get_sync_target() + return db, st + + +def _make_local_db_and_http_target(test, path='test'): + test.startServer() + db = test.request_state._create_database(os.path.basename(path)) + st = http_target.HTTPSyncTarget.connect(test.getURL(path)) + return db, st + + +def _make_c_db_and_c_http_target(test, path='test'): + test.startServer() + db = test.request_state._create_database(os.path.basename(path)) + url = test.getURL(path) + st = tests.c_backend_wrapper.create_http_sync_target(url) + return db, st + + +def _make_local_db_and_oauth_http_target(test): + db, st = _make_local_db_and_http_target(test, '~/test') + st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return db, st + + +def _make_c_db_and_oauth_http_target(test, path='~/test'): + test.startServer() + db = test.request_state._create_database(os.path.basename(path)) + url = test.getURL(path) + st = tests.c_backend_wrapper.create_oauth_http_sync_target(url, + tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return db, st + + +target_scenarios = [ + ('local', {'create_db_and_target': _make_local_db_and_target}), + ('http', {'create_db_and_target': _make_local_db_and_http_target, + 'make_app_with_state': make_http_app}), + ('oauth_http', {'create_db_and_target': + _make_local_db_and_oauth_http_target, + 'make_app_with_state': make_oauth_http_app}), + ] + +c_db_scenarios = [ + ('local,c', 
{'create_db_and_target': _make_local_db_and_target, + 'make_database_for_test': tests.make_c_database_for_test, + 'copy_database_for_test': tests.copy_c_database_for_test, + 'make_document_for_test': tests.make_c_document_for_test, + 'whitebox': False}), + ('http,c', {'create_db_and_target': _make_c_db_and_c_http_target, + 'make_database_for_test': tests.make_c_database_for_test, + 'copy_database_for_test': tests.copy_c_database_for_test, + 'make_document_for_test': tests.make_c_document_for_test, + 'make_app_with_state': make_http_app, + 'whitebox': False}), + ('oauth_http,c', {'create_db_and_target': _make_c_db_and_oauth_http_target, + 'make_database_for_test': tests.make_c_database_for_test, + 'copy_database_for_test': tests.copy_c_database_for_test, + 'make_document_for_test': tests.make_c_document_for_test, + 'make_app_with_state': make_oauth_http_app, + 'whitebox': False}), + ] + + +class DatabaseSyncTargetTests(tests.DatabaseBaseTests, + tests.TestCaseWithServer): + + scenarios = (tests.multiply_scenarios(tests.DatabaseBaseTests.scenarios, + target_scenarios) + + c_db_scenarios) + # whitebox true means self.db is the actual local db object + # against which the sync is performed + whitebox = True + + def setUp(self): + super(DatabaseSyncTargetTests, self).setUp() + self.db, self.st = self.create_db_and_target(self) + self.other_changes = [] + + def tearDown(self): + # We delete them explicitly, so that connections are cleanly closed + del self.st + self.db.close() + del self.db + super(DatabaseSyncTargetTests, self).tearDown() + + def receive_doc(self, doc, gen, trans_id): + self.other_changes.append( + (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) + + def set_trace_hook(self, callback, shallow=False): + setter = (self.st._set_trace_hook if not shallow else + self.st._set_trace_hook_shallow) + try: + setter(callback) + except NotImplementedError: + self.skipTest("%s does not implement _set_trace_hook" + % (self.st.__class__.__name__,)) + + def 
test_get_sync_target(self): + self.assertIsNot(None, self.st) + + def test_get_sync_info(self): + self.assertEqual( + ('test', 0, '', 0, ''), self.st.get_sync_info('other')) + + def test_create_doc_updates_sync_info(self): + self.assertEqual( + ('test', 0, '', 0, ''), self.st.get_sync_info('other')) + self.db.create_doc_from_json(simple_doc) + self.assertEqual(1, self.st.get_sync_info('other')[1]) + + def test_record_sync_info(self): + self.st.record_sync_info('replica', 10, 'T-transid') + self.assertEqual( + ('test', 0, '', 10, 'T-transid'), self.st.get_sync_info('replica')) + + def test_sync_exchange(self): + docs_by_gen = [ + (self.make_document('doc-id', 'replica:1', simple_doc), 10, + 'T-sid')] + new_gen, trans_id = self.st.sync_exchange( + docs_by_gen, 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False) + self.assertTransactionLog(['doc-id'], self.db) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual(([], 1, last_trans_id), + (self.other_changes, new_gen, trans_id)) + self.assertEqual(10, self.st.get_sync_info('replica')[3]) + + def test_sync_exchange_deleted(self): + doc = self.db.create_doc_from_json('{}') + edit_rev = 'replica:1|' + doc.rev + docs_by_gen = [ + (self.make_document(doc.doc_id, edit_rev, None), 10, 'T-sid')] + new_gen, trans_id = self.st.sync_exchange( + docs_by_gen, 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, edit_rev, None, False) + self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual(([], 2, last_trans_id), + (self.other_changes, new_gen, trans_id)) + self.assertEqual(10, self.st.get_sync_info('replica')[3]) + + def test_sync_exchange_push_many(self): + docs_by_gen = [ + (self.make_document('doc-id', 'replica:1', simple_doc), 
10, 'T-1'), + (self.make_document('doc-id2', 'replica:1', nested_doc), 11, + 'T-2')] + new_gen, trans_id = self.st.sync_exchange( + docs_by_gen, 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False) + self.assertGetDoc(self.db, 'doc-id2', 'replica:1', nested_doc, False) + self.assertTransactionLog(['doc-id', 'doc-id2'], self.db) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual(([], 2, last_trans_id), + (self.other_changes, new_gen, trans_id)) + self.assertEqual(11, self.st.get_sync_info('replica')[3]) + + def test_sync_exchange_refuses_conflicts(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + new_doc = '{"key": "altval"}' + docs_by_gen = [ + (self.make_document(doc.doc_id, 'replica:1', new_doc), 10, + 'T-sid')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id], self.db) + self.assertEqual( + (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1]) + self.assertEqual(1, new_gen) + if self.whitebox: + self.assertEqual(self.db._last_exchange_log['return'], + {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]}) + + def test_sync_exchange_ignores_convergence(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + gen, txid = self.db._get_generation_info() + docs_by_gen = [ + (self.make_document(doc.doc_id, doc.rev, simple_doc), 10, 'T-sid')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'replica', last_known_generation=gen, + last_known_trans_id=txid, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id], self.db) + self.assertEqual(([], 1), (self.other_changes, new_gen)) + + def test_sync_exchange_returns_new_docs(self): + doc = 
self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + new_gen, _ = self.st.sync_exchange( + [], 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id], self.db) + self.assertEqual( + (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1]) + self.assertEqual(1, new_gen) + if self.whitebox: + self.assertEqual(self.db._last_exchange_log['return'], + {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]}) + + def test_sync_exchange_returns_deleted_docs(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) + new_gen, _ = self.st.sync_exchange( + [], 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) + self.assertEqual( + (doc.doc_id, doc.rev, None, 2), self.other_changes[0][:-1]) + self.assertEqual(2, new_gen) + if self.whitebox: + self.assertEqual(self.db._last_exchange_log['return'], + {'last_gen': 2, 'docs': [(doc.doc_id, doc.rev)]}) + + def test_sync_exchange_returns_many_new_docs(self): + doc = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db) + new_gen, _ = self.st.sync_exchange( + [], 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db) + self.assertEqual(2, new_gen) + self.assertEqual( + [(doc.doc_id, doc.rev, simple_doc, 1), + (doc2.doc_id, doc2.rev, nested_doc, 2)], + [c[:-1] for c in self.other_changes]) + if self.whitebox: + self.assertEqual( + self.db._last_exchange_log['return'], + {'last_gen': 2, 'docs': + [(doc.doc_id, doc.rev), (doc2.doc_id, doc2.rev)]}) + + def test_sync_exchange_getting_newer_docs(self): + 
doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + new_doc = '{"key": "altval"}' + docs_by_gen = [ + (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10, + 'T-sid')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) + self.assertEqual(([], 2), (self.other_changes, new_gen)) + + def test_sync_exchange_with_concurrent_updates_of_synced_doc(self): + expected = [] + + def before_whatschanged_cb(state): + if state != 'before whats_changed': + return + cont = '{"key": "cuncurrent"}' + conc_rev = self.db.put_doc( + self.make_document(doc.doc_id, 'test:1|z:2', cont)) + expected.append((doc.doc_id, conc_rev, cont, 3)) + + self.set_trace_hook(before_whatschanged_cb) + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + new_doc = '{"key": "altval"}' + docs_by_gen = [ + (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10, + 'T-sid')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertEqual(expected, [c[:-1] for c in self.other_changes]) + self.assertEqual(3, new_gen) + + def test_sync_exchange_with_concurrent_updates(self): + + def after_whatschanged_cb(state): + if state != 'after whats_changed': + return + self.db.create_doc_from_json('{"new": "doc"}') + + self.set_trace_hook(after_whatschanged_cb) + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + new_doc = '{"key": "altval"}' + docs_by_gen = [ + (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10, + 'T-sid')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + 
self.assertEqual(([], 2), (self.other_changes, new_gen)) + + def test_sync_exchange_converged_handling(self): + doc = self.db.create_doc_from_json(simple_doc) + docs_by_gen = [ + (self.make_document('new', 'other:1', '{}'), 4, 'T-foo'), + (self.make_document(doc.doc_id, doc.rev, doc.get_json()), 5, + 'T-bar')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertEqual(([], 2), (self.other_changes, new_gen)) + + def test_sync_exchange_detect_incomplete_exchange(self): + def before_get_docs_explode(state): + if state != 'before get_docs': + return + raise errors.U1DBError("fail") + self.set_trace_hook(before_get_docs_explode) + # suppress traceback printing in the wsgiref server + self.patch(simple_server.ServerHandler, + 'log_exception', lambda h, exc_info: None) + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + self.assertRaises( + (errors.U1DBError, errors.BrokenSyncStream), + self.st.sync_exchange, [], 'other-replica', + last_known_generation=0, last_known_trans_id=None, + return_doc_cb=self.receive_doc) + + def test_sync_exchange_doc_ids(self): + sync_exchange_doc_ids = getattr(self.st, 'sync_exchange_doc_ids', None) + if sync_exchange_doc_ids is None: + self.skipTest("sync_exchange_doc_ids not implemented") + db2 = self.create_database('test2') + doc = db2.create_doc_from_json(simple_doc) + new_gen, trans_id = sync_exchange_doc_ids( + db2, [(doc.doc_id, 10, 'T-sid')], 0, None, + return_doc_cb=self.receive_doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + self.assertTransactionLog([doc.doc_id], self.db) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual(([], 1, last_trans_id), + (self.other_changes, new_gen, trans_id)) + self.assertEqual(10, self.st.get_sync_info(db2._replica_uid)[3]) + + def test__set_trace_hook(self): + called = [] + + def cb(state): + 
called.append(state) + + self.set_trace_hook(cb) + self.st.sync_exchange([], 'replica', 0, None, self.receive_doc) + self.st.record_sync_info('replica', 0, 'T-sid') + self.assertEqual(['before whats_changed', + 'after whats_changed', + 'before get_docs', + 'record_sync_info', + ], + called) + + def test__set_trace_hook_shallow(self): + if (self.st._set_trace_hook_shallow == self.st._set_trace_hook + or self.st._set_trace_hook_shallow.im_func == + SyncTarget._set_trace_hook_shallow.im_func): + # shallow same as full + expected = ['before whats_changed', + 'after whats_changed', + 'before get_docs', + 'record_sync_info', + ] + else: + expected = ['sync_exchange', 'record_sync_info'] + + called = [] + + def cb(state): + called.append(state) + + self.set_trace_hook(cb, shallow=True) + self.st.sync_exchange([], 'replica', 0, None, self.receive_doc) + self.st.record_sync_info('replica', 0, 'T-sid') + self.assertEqual(expected, called) + + +def sync_via_synchronizer(test, db_source, db_target, trace_hook=None, + trace_hook_shallow=None): + target = db_target.get_sync_target() + trace_hook = trace_hook or trace_hook_shallow + if trace_hook: + target._set_trace_hook(trace_hook) + return sync.Synchronizer(db_source, target).sync() + + +sync_scenarios = [] +for name, scenario in tests.LOCAL_DATABASES_SCENARIOS: + scenario = dict(scenario) + scenario['do_sync'] = sync_via_synchronizer + sync_scenarios.append((name, scenario)) + scenario = dict(scenario) + + +def make_database_for_http_test(test, replica_uid): + if test.server is None: + test.startServer() + db = test.request_state._create_database(replica_uid) + try: + http_at = test._http_at + except AttributeError: + http_at = test._http_at = {} + http_at[db] = replica_uid + return db + + +def copy_database_for_http_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE 
CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR HOUSE. + if test.server is None: + test.startServer() + new_db = test.request_state._copy_database(db) + try: + http_at = test._http_at + except AttributeError: + http_at = test._http_at = {} + path = db._replica_uid + while path in http_at.values(): + path += 'copy' + http_at[new_db] = path + return new_db + + +def sync_via_synchronizer_and_http(test, db_source, db_target, + trace_hook=None, trace_hook_shallow=None): + if trace_hook: + test.skipTest("full trace hook unsupported over http") + path = test._http_at[db_target] + target = http_target.HTTPSyncTarget.connect(test.getURL(path)) + if trace_hook_shallow: + target._set_trace_hook_shallow(trace_hook_shallow) + return sync.Synchronizer(db_source, target).sync() + + +sync_scenarios.append(('pyhttp', { + 'make_database_for_test': make_database_for_http_test, + 'copy_database_for_test': copy_database_for_http_test, + 'make_document_for_test': tests.make_document_for_test, + 'make_app_with_state': make_http_app, + 'do_sync': sync_via_synchronizer_and_http + })) + + +if tests.c_backend_wrapper is not None: + # TODO: We should hook up sync tests with an HTTP target + def sync_via_c_sync(test, db_source, db_target, trace_hook=None, + trace_hook_shallow=None): + target = db_target.get_sync_target() + trace_hook = trace_hook or trace_hook_shallow + if trace_hook: + target._set_trace_hook(trace_hook) + return tests.c_backend_wrapper.sync_db_to_target(db_source, target) + + for name, scenario in tests.C_DATABASE_SCENARIOS: + scenario = dict(scenario) + scenario['do_sync'] = sync_via_synchronizer + sync_scenarios.append((name + ',pysync', scenario)) + scenario = dict(scenario) + scenario['do_sync'] = sync_via_c_sync + sync_scenarios.append((name + ',csync', scenario)) + + +class DatabaseSyncTests(tests.DatabaseBaseTests, + tests.TestCaseWithServer): + + scenarios = sync_scenarios + do_sync = None # set by scenarios + + def 
create_database(self, replica_uid, sync_role=None): + if replica_uid == 'test' and sync_role is None: + # created up the chain by base class but unused + return None + db = self.create_database_for_role(replica_uid, sync_role) + if sync_role: + self._use_tracking[db] = (replica_uid, sync_role) + return db + + def create_database_for_role(self, replica_uid, sync_role): + # hook point for reuse + return super(DatabaseSyncTests, self).create_database(replica_uid) + + def copy_database(self, db, sync_role=None): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES + # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST + # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS + # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND + # NINJA TO YOUR HOUSE. + db_copy = super(DatabaseSyncTests, self).copy_database(db) + name, orig_sync_role = self._use_tracking[db] + self._use_tracking[db_copy] = (name + '(copy)', sync_role + or orig_sync_role) + return db_copy + + def sync(self, db_from, db_to, trace_hook=None, + trace_hook_shallow=None): + from_name, from_sync_role = self._use_tracking[db_from] + to_name, to_sync_role = self._use_tracking[db_to] + if from_sync_role not in ('source', 'both'): + raise Exception("%s marked for %s use but used as source" % + (from_name, from_sync_role)) + if to_sync_role not in ('target', 'both'): + raise Exception("%s marked for %s use but used as target" % + (to_name, to_sync_role)) + return self.do_sync(self, db_from, db_to, trace_hook, + trace_hook_shallow) + + def setUp(self): + self._use_tracking = {} + super(DatabaseSyncTests, self).setUp() + + def assertLastExchangeLog(self, db, expected): + log = getattr(db, '_last_exchange_log', None) + if log is None: + return + self.assertEqual(expected, log) + + def test_sync_tracks_db_generation_of_other(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + 
self.assertEqual(0, self.sync(self.db1, self.db2)) + self.assertEqual( + (0, ''), self.db1._get_replica_gen_and_trans_id('test2')) + self.assertEqual( + (0, ''), self.db2._get_replica_gen_and_trans_id('test1')) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [], 'last_known_gen': 0}, + 'return': {'docs': [], 'last_gen': 0}}) + + def test_sync_autoresolves(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc1 = self.db1.create_doc_from_json(simple_doc, doc_id='doc') + rev1 = doc1.rev + doc2 = self.db2.create_doc_from_json(simple_doc, doc_id='doc') + rev2 = doc2.rev + self.sync(self.db1, self.db2) + doc = self.db1.get_doc('doc') + self.assertFalse(doc.has_conflicts) + self.assertEqual(doc.rev, self.db2.get_doc('doc').rev) + v = vectorclock.VectorClockRev(doc.rev) + self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev1))) + self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev2))) + + def test_sync_autoresolves_moar(self): + # here we test that when a database that has a conflicted document is + # the source of a sync, and the target database has a revision of the + # conflicted document that is newer than the source database's, and + # that target's database's document's content is the same as the + # source's document's conflict's, the source's document's conflict gets + # autoresolved, and the source's document's revision bumped. 
+ # + # idea is as follows: + # A B + # a1 - + # `-------> + # a1 a1 + # v v + # a2 a1b1 + # `-------> + # a1b1+a2 a1b1 + # v + # a1b1+a2 a1b2 (a1b2 has same content as a2) + # `-------> + # a3b2 a1b2 (autoresolved) + # `-------> + # a3b2 a3b2 + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(simple_doc, doc_id='doc') + self.sync(self.db1, self.db2) + for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]: + doc = db.get_doc('doc') + doc.set_json(content) + db.put_doc(doc) + self.sync(self.db1, self.db2) + # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict + doc = self.db1.get_doc('doc') + rev1 = doc.rev + self.assertTrue(doc.has_conflicts) + # set db2 to have a doc of {} (same as db1 before the conflict) + doc = self.db2.get_doc('doc') + doc.set_json('{}') + self.db2.put_doc(doc) + rev2 = doc.rev + # sync it across + self.sync(self.db1, self.db2) + # tadaa! + doc = self.db1.get_doc('doc') + self.assertFalse(doc.has_conflicts) + vec1 = vectorclock.VectorClockRev(rev1) + vec2 = vectorclock.VectorClockRev(rev2) + vec3 = vectorclock.VectorClockRev(doc.rev) + self.assertTrue(vec3.is_newer(vec1)) + self.assertTrue(vec3.is_newer(vec2)) + # because the conflict is on the source, sync it another time + self.sync(self.db1, self.db2) + # make sure db2 now has the exact same thing + self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc')) + + def test_sync_autoresolves_moar_backwards(self): + # here we test that when a database that has a conflicted document is + # the target of a sync, and the source database has a revision of the + # conflicted document that is newer than the target database's, and + # that source's database's document's content is the same as the + # target's document's conflict's, the target's document's conflict gets + # autoresolved, and the document's revision bumped. 
+ # + # idea is as follows: + # A B + # a1 - + # `-------> + # a1 a1 + # v v + # a2 a1b1 + # `-------> + # a1b1+a2 a1b1 + # v + # a1b1+a2 a1b2 (a1b2 has same content as a2) + # <-------' + # a3b2 a3b2 (autoresolved and propagated) + self.db1 = self.create_database('test1', 'both') + self.db2 = self.create_database('test2', 'both') + self.db1.create_doc_from_json(simple_doc, doc_id='doc') + self.sync(self.db1, self.db2) + for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]: + doc = db.get_doc('doc') + doc.set_json(content) + db.put_doc(doc) + self.sync(self.db1, self.db2) + # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict + doc = self.db1.get_doc('doc') + rev1 = doc.rev + self.assertTrue(doc.has_conflicts) + revc = self.db1.get_doc_conflicts('doc')[-1].rev + # set db2 to have a doc of {} (same as db1 before the conflict) + doc = self.db2.get_doc('doc') + doc.set_json('{}') + self.db2.put_doc(doc) + rev2 = doc.rev + # sync it across + self.sync(self.db2, self.db1) + # tadaa! 
+ doc = self.db1.get_doc('doc') + self.assertFalse(doc.has_conflicts) + vec1 = vectorclock.VectorClockRev(rev1) + vec2 = vectorclock.VectorClockRev(rev2) + vec3 = vectorclock.VectorClockRev(doc.rev) + vecc = vectorclock.VectorClockRev(revc) + self.assertTrue(vec3.is_newer(vec1)) + self.assertTrue(vec3.is_newer(vec2)) + self.assertTrue(vec3.is_newer(vecc)) + # make sure db2 now has the exact same thing + self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc')) + + def test_sync_autoresolves_moar_backwards_three(self): + # same as autoresolves_moar_backwards, but with three databases (note + # all the syncs go in the same direction -- this is a more natural + # scenario): + # + # A B C + # a1 - - + # `-------> + # a1 a1 - + # `-------> + # a1 a1 a1 + # v v + # a2 a1b1 a1 + # `-------------------> + # a2 a1b1 a2 + # `-------> + # a2+a1b1 a2 + # v + # a2 a2+a1b1 a2c1 (same as a1b1) + # `-------------------> + # a2c1 a2+a1b1 a2c1 + # `-------> + # a2b2c1 a2b2c1 a2c1 + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'both') + self.db3 = self.create_database('test3', 'target') + self.db1.create_doc_from_json(simple_doc, doc_id='doc') + self.sync(self.db1, self.db2) + self.sync(self.db2, self.db3) + for db, content in [(self.db2, '{"hi": 42}'), + (self.db1, '{}'), + ]: + doc = db.get_doc('doc') + doc.set_json(content) + db.put_doc(doc) + self.sync(self.db1, self.db3) + self.sync(self.db2, self.db3) + # db2 and db3 now both have a doc of {}, but db2 has a + # conflict + doc = self.db2.get_doc('doc') + self.assertTrue(doc.has_conflicts) + revc = self.db2.get_doc_conflicts('doc')[-1].rev + self.assertEqual('{}', doc.get_json()) + self.assertEqual(self.db3.get_doc('doc').get_json(), doc.get_json()) + self.assertEqual(self.db3.get_doc('doc').rev, doc.rev) + # set db3 to have a doc of {hi:42} (same as db2 before the conflict) + doc = self.db3.get_doc('doc') + doc.set_json('{"hi": 42}') + self.db3.put_doc(doc) + rev3 = 
doc.rev + # sync it across to db1 + self.sync(self.db1, self.db3) + # db1 now has hi:42, with a rev that is newer than db2's doc + doc = self.db1.get_doc('doc') + rev1 = doc.rev + self.assertFalse(doc.has_conflicts) + self.assertEqual('{"hi": 42}', doc.get_json()) + VCR = vectorclock.VectorClockRev + self.assertTrue(VCR(rev1).is_newer(VCR(self.db2.get_doc('doc').rev))) + # so sync it to db2 + self.sync(self.db1, self.db2) + # tadaa! + doc = self.db2.get_doc('doc') + self.assertFalse(doc.has_conflicts) + # db2's revision of the document is strictly newer than db1's before + # the sync, and db3's before that sync way back when + self.assertTrue(VCR(doc.rev).is_newer(VCR(rev1))) + self.assertTrue(VCR(doc.rev).is_newer(VCR(rev3))) + self.assertTrue(VCR(doc.rev).is_newer(VCR(revc))) + # make sure both dbs now have the exact same thing + self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc')) + + def test_sync_puts_changes(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc = self.db1.create_doc_from_json(simple_doc) + self.assertEqual(1, self.sync(self.db1, self.db2)) + self.assertGetDoc(self.db2, doc.doc_id, doc.rev, simple_doc, False) + self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0]) + self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0]) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [(doc.doc_id, doc.rev)], + 'source_uid': 'test1', + 'source_gen': 1, 'last_known_gen': 0}, + 'return': {'docs': [], 'last_gen': 1}}) + + def test_sync_pulls_changes(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc = self.db2.create_doc_from_json(simple_doc) + self.db1.create_index('test-idx', 'key') + self.assertEqual(0, self.sync(self.db1, self.db2)) + self.assertGetDoc(self.db1, doc.doc_id, doc.rev, simple_doc, False) + self.assertEqual(1, 
self.db1._get_replica_gen_and_trans_id('test2')[0]) + self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0]) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [], 'last_known_gen': 0}, + 'return': {'docs': [(doc.doc_id, doc.rev)], + 'last_gen': 1}}) + self.assertEqual([doc], self.db1.get_from_index('test-idx', 'value')) + + def test_sync_pulling_doesnt_update_other_if_changed(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc = self.db2.create_doc_from_json(simple_doc) + # After the local side has sent its list of docs, before we start + # receiving the "targets" response, we update the local database with a + # new record. + # When we finish synchronizing, we can notice that something locally + # was updated, and we cannot tell c2 our new updated generation + + def before_get_docs(state): + if state != 'before get_docs': + return + self.db1.create_doc_from_json(simple_doc) + + self.assertEqual(0, self.sync(self.db1, self.db2, + trace_hook=before_get_docs)) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [], 'last_known_gen': 0}, + 'return': {'docs': [(doc.doc_id, doc.rev)], + 'last_gen': 1}}) + self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0]) + # c2 should not have gotten a '_record_sync_info' call, because the + # local database had been updated more than just by the messages + # returned from c2. 
+ self.assertEqual( + (0, ''), self.db2._get_replica_gen_and_trans_id('test1')) + + def test_sync_doesnt_update_other_if_nothing_pulled(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(simple_doc) + + def no_record_sync_info(state): + if state != 'record_sync_info': + return + self.fail('SyncTarget.record_sync_info was called') + self.assertEqual(1, self.sync(self.db1, self.db2, + trace_hook_shallow=no_record_sync_info)) + self.assertEqual( + 1, + self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid)[0]) + + def test_sync_ignores_convergence(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'both') + doc = self.db1.create_doc_from_json(simple_doc) + self.db3 = self.create_database('test3', 'target') + self.assertEqual(1, self.sync(self.db1, self.db3)) + self.assertEqual(0, self.sync(self.db2, self.db3)) + self.assertEqual(1, self.sync(self.db1, self.db2)) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [(doc.doc_id, doc.rev)], + 'source_uid': 'test1', + 'source_gen': 1, 'last_known_gen': 0}, + 'return': {'docs': [], 'last_gen': 1}}) + + def test_sync_ignores_superseded(self): + self.db1 = self.create_database('test1', 'both') + self.db2 = self.create_database('test2', 'both') + doc = self.db1.create_doc_from_json(simple_doc) + doc_rev1 = doc.rev + self.db3 = self.create_database('test3', 'target') + self.sync(self.db1, self.db3) + self.sync(self.db2, self.db3) + new_content = '{"key": "altval"}' + doc.set_json(new_content) + self.db1.put_doc(doc) + doc_rev2 = doc.rev + self.sync(self.db2, self.db1) + self.assertLastExchangeLog(self.db1, + {'receive': {'docs': [(doc.doc_id, doc_rev1)], + 'source_uid': 'test2', + 'source_gen': 1, 'last_known_gen': 0}, + 'return': {'docs': [(doc.doc_id, doc_rev2)], + 'last_gen': 2}}) + self.assertGetDoc(self.db1, doc.doc_id, doc_rev2, new_content, False) + + def 
test_sync_sees_remote_conflicted(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc1 = self.db1.create_doc_from_json(simple_doc) + doc_id = doc1.doc_id + doc1_rev = doc1.rev + self.db1.create_index('test-idx', 'key') + new_doc = '{"key": "altval"}' + doc2 = self.db2.create_doc_from_json(new_doc, doc_id=doc_id) + doc2_rev = doc2.rev + self.assertTransactionLog([doc1.doc_id], self.db1) + self.sync(self.db1, self.db2) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [(doc_id, doc1_rev)], + 'source_uid': 'test1', + 'source_gen': 1, 'last_known_gen': 0}, + 'return': {'docs': [(doc_id, doc2_rev)], + 'last_gen': 1}}) + self.assertTransactionLog([doc_id, doc_id], self.db1) + self.assertGetDoc(self.db1, doc_id, doc2_rev, new_doc, True) + self.assertGetDoc(self.db2, doc_id, doc2_rev, new_doc, False) + from_idx = self.db1.get_from_index('test-idx', 'altval')[0] + self.assertEqual(doc2.doc_id, from_idx.doc_id) + self.assertEqual(doc2.rev, from_idx.rev) + self.assertTrue(from_idx.has_conflicts) + self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) + + def test_sync_sees_remote_delete_conflicted(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc1 = self.db1.create_doc_from_json(simple_doc) + doc_id = doc1.doc_id + self.db1.create_index('test-idx', 'key') + self.sync(self.db1, self.db2) + doc2 = self.make_document(doc1.doc_id, doc1.rev, doc1.get_json()) + new_doc = '{"key": "altval"}' + doc1.set_json(new_doc) + self.db1.put_doc(doc1) + self.db2.delete_doc(doc2) + self.assertTransactionLog([doc_id, doc_id], self.db1) + self.sync(self.db1, self.db2) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [(doc_id, doc1.rev)], + 'source_uid': 'test1', + 'source_gen': 2, 'last_known_gen': 1}, + 'return': {'docs': [(doc_id, doc2.rev)], + 'last_gen': 2}}) + self.assertTransactionLog([doc_id, doc_id, doc_id], self.db1) 
+ self.assertGetDocIncludeDeleted(self.db1, doc_id, doc2.rev, None, True) + self.assertGetDocIncludeDeleted( + self.db2, doc_id, doc2.rev, None, False) + self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) + + def test_sync_local_race_conflicted(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc = self.db1.create_doc_from_json(simple_doc) + doc_id = doc.doc_id + doc1_rev = doc.rev + self.db1.create_index('test-idx', 'key') + self.sync(self.db1, self.db2) + content1 = '{"key": "localval"}' + content2 = '{"key": "altval"}' + doc.set_json(content2) + self.db2.put_doc(doc) + doc2_rev2 = doc.rev + triggered = [] + + def after_whatschanged(state): + if state != 'after whats_changed': + return + triggered.append(True) + doc = self.make_document(doc_id, doc1_rev, content1) + self.db1.put_doc(doc) + + self.sync(self.db1, self.db2, trace_hook=after_whatschanged) + self.assertEqual([True], triggered) + self.assertGetDoc(self.db1, doc_id, doc2_rev2, content2, True) + from_idx = self.db1.get_from_index('test-idx', 'altval')[0] + self.assertEqual(doc.doc_id, from_idx.doc_id) + self.assertEqual(doc.rev, from_idx.rev) + self.assertTrue(from_idx.has_conflicts) + self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) + self.assertEqual([], self.db1.get_from_index('test-idx', 'localval')) + + def test_sync_propagates_deletes(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'both') + doc1 = self.db1.create_doc_from_json(simple_doc) + doc_id = doc1.doc_id + self.db1.create_index('test-idx', 'key') + self.sync(self.db1, self.db2) + self.db2.create_index('test-idx', 'key') + self.db3 = self.create_database('test3', 'target') + self.sync(self.db1, self.db3) + self.db1.delete_doc(doc1) + deleted_rev = doc1.rev + self.sync(self.db1, self.db2) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [(doc_id, deleted_rev)], + 'source_uid': 
'test1', + 'source_gen': 2, 'last_known_gen': 1}, + 'return': {'docs': [], 'last_gen': 2}}) + self.assertGetDocIncludeDeleted( + self.db1, doc_id, deleted_rev, None, False) + self.assertGetDocIncludeDeleted( + self.db2, doc_id, deleted_rev, None, False) + self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) + self.assertEqual([], self.db2.get_from_index('test-idx', 'value')) + self.sync(self.db2, self.db3) + self.assertLastExchangeLog(self.db3, + {'receive': {'docs': [(doc_id, deleted_rev)], + 'source_uid': 'test2', + 'source_gen': 2, 'last_known_gen': 0}, + 'return': {'docs': [], 'last_gen': 2}}) + self.assertGetDocIncludeDeleted( + self.db3, doc_id, deleted_rev, None, False) + + def test_sync_propagates_resolution(self): + self.db1 = self.create_database('test1', 'both') + self.db2 = self.create_database('test2', 'both') + doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc') + db3 = self.create_database('test3', 'both') + self.sync(self.db2, self.db1) + self.assertEqual( + self.db1._get_generation_info(), + self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid)) + self.assertEqual( + self.db2._get_generation_info(), + self.db1._get_replica_gen_and_trans_id(self.db2._replica_uid)) + self.sync(db3, self.db1) + # update on 2 + doc2 = self.make_document('the-doc', doc1.rev, '{"a": 2}') + self.db2.put_doc(doc2) + self.sync(self.db2, db3) + self.assertEqual(db3.get_doc('the-doc').rev, doc2.rev) + # update on 1 + doc1.set_json('{"a": 3}') + self.db1.put_doc(doc1) + # conflicts + self.sync(self.db2, self.db1) + self.sync(db3, self.db1) + self.assertTrue(self.db2.get_doc('the-doc').has_conflicts) + self.assertTrue(db3.get_doc('the-doc').has_conflicts) + # resolve + conflicts = self.db2.get_doc_conflicts('the-doc') + doc4 = self.make_document('the-doc', None, '{"a": 4}') + revs = [doc.rev for doc in conflicts] + self.db2.resolve_doc(doc4, revs) + doc2 = self.db2.get_doc('the-doc') + self.assertEqual(doc4.get_json(), doc2.get_json()) + 
self.assertFalse(doc2.has_conflicts) + self.sync(self.db2, db3) + doc3 = db3.get_doc('the-doc') + self.assertEqual(doc4.get_json(), doc3.get_json()) + self.assertFalse(doc3.has_conflicts) + + def test_sync_supersedes_conflicts(self): + self.db1 = self.create_database('test1', 'both') + self.db2 = self.create_database('test2', 'target') + db3 = self.create_database('test3', 'both') + doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc') + self.db2.create_doc_from_json('{"b": 1}', doc_id='the-doc') + db3.create_doc_from_json('{"c": 1}', doc_id='the-doc') + self.sync(db3, self.db1) + self.assertEqual( + self.db1._get_generation_info(), + db3._get_replica_gen_and_trans_id(self.db1._replica_uid)) + self.assertEqual( + db3._get_generation_info(), + self.db1._get_replica_gen_and_trans_id(db3._replica_uid)) + self.sync(db3, self.db2) + self.assertEqual( + self.db2._get_generation_info(), + db3._get_replica_gen_and_trans_id(self.db2._replica_uid)) + self.assertEqual( + db3._get_generation_info(), + self.db2._get_replica_gen_and_trans_id(db3._replica_uid)) + self.assertEqual(3, len(db3.get_doc_conflicts('the-doc'))) + doc1.set_json('{"a": 2}') + self.db1.put_doc(doc1) + self.sync(db3, self.db1) + # original doc1 should have been removed from conflicts + self.assertEqual(3, len(db3.get_doc_conflicts('the-doc'))) + + def test_sync_stops_after_get_sync_info(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(tests.simple_doc) + self.sync(self.db1, self.db2) + + def put_hook(state): + self.fail("Tracehook triggered for %s" % (state,)) + + self.sync(self.db1, self.db2, trace_hook_shallow=put_hook) + + def test_sync_detects_rollback_in_source(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1') + self.sync(self.db1, self.db2) + db1_copy = 
self.copy_database(self.db1) + self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2') + self.sync(self.db1, self.db2) + self.assertRaises( + errors.InvalidGeneration, self.sync, db1_copy, self.db2) + + def test_sync_detects_rollback_in_target(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") + self.sync(self.db1, self.db2) + db2_copy = self.copy_database(self.db2) + self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2') + self.sync(self.db1, self.db2) + self.assertRaises( + errors.InvalidGeneration, self.sync, self.db1, db2_copy) + + def test_sync_detects_diverged_source(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + db3 = self.copy_database(self.db1) + self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") + db3.create_doc_from_json(tests.simple_doc, doc_id="divergent") + self.sync(self.db1, self.db2) + self.assertRaises( + errors.InvalidTransactionId, self.sync, db3, self.db2) + + def test_sync_detects_diverged_target(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + db3 = self.copy_database(self.db2) + db3.create_doc_from_json(tests.nested_doc, doc_id="divergent") + self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") + self.sync(self.db1, self.db2) + self.assertRaises( + errors.InvalidTransactionId, self.sync, self.db1, db3) + + def test_sync_detects_rollback_and_divergence_in_source(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1') + self.sync(self.db1, self.db2) + db1_copy = self.copy_database(self.db1) + self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2') + self.db1.create_doc_from_json(tests.simple_doc, 
doc_id='doc3') + self.sync(self.db1, self.db2) + db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2') + db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3') + self.assertRaises( + errors.InvalidTransactionId, self.sync, db1_copy, self.db2) + + def test_sync_detects_rollback_and_divergence_in_target(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") + self.sync(self.db1, self.db2) + db2_copy = self.copy_database(self.db2) + self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2') + self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc3') + self.sync(self.db1, self.db2) + db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2') + db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3') + self.assertRaises( + errors.InvalidTransactionId, self.sync, self.db1, db2_copy) + + +class TestDbSync(tests.TestCaseWithServer): + """Test db.sync remote sync shortcut""" + + scenarios = [ + ('py-http', { + 'make_app_with_state': make_http_app, + 'make_database_for_test': tests.make_memory_database_for_test, + }), + ('c-http', { + 'make_app_with_state': make_http_app, + 'make_database_for_test': tests.make_c_database_for_test + }), + ('py-oauth-http', { + 'make_app_with_state': make_oauth_http_app, + 'make_database_for_test': tests.make_memory_database_for_test, + 'oauth': True + }), + ('c-oauth-http', { + 'make_app_with_state': make_oauth_http_app, + 'make_database_for_test': tests.make_c_database_for_test, + 'oauth': True + }), + ] + + oauth = False + + def do_sync(self, target_name): + if self.oauth: + path = '~/' + target_name + extra = dict(creds={'oauth': { + 'consumer_key': tests.consumer1.key, + 'consumer_secret': tests.consumer1.secret, + 'token_key': tests.token1.key, + 'token_secret': tests.token1.secret + }}) + else: + path = target_name + extra = {} + target_url = self.getURL(path) + return 
self.db.sync(target_url, **extra) + + def setUp(self): + super(TestDbSync, self).setUp() + self.startServer() + self.db = self.make_database_for_test(self, 'test1') + self.db2 = self.request_state._create_database('test2.db') + + def test_db_sync(self): + doc1 = self.db.create_doc_from_json(tests.simple_doc) + doc2 = self.db2.create_doc_from_json(tests.nested_doc) + local_gen_before_sync = self.do_sync('test2.db') + gen, _, changes = self.db.whats_changed(local_gen_before_sync) + self.assertEqual(1, len(changes)) + self.assertEqual(doc2.doc_id, changes[0][0]) + self.assertEqual(1, gen - local_gen_before_sync) + self.assertGetDoc(self.db2, doc1.doc_id, doc1.rev, tests.simple_doc, + False) + self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc, + False) + + def test_db_sync_autocreate(self): + doc1 = self.db.create_doc_from_json(tests.simple_doc) + local_gen_before_sync = self.do_sync('test3.db') + gen, _, changes = self.db.whats_changed(local_gen_before_sync) + self.assertEqual(0, gen - local_gen_before_sync) + db3 = self.request_state.open_database('test3.db') + gen, _, changes = db3.whats_changed() + self.assertEqual(1, len(changes)) + self.assertEqual(doc1.doc_id, changes[0][0]) + self.assertGetDoc(db3, doc1.doc_id, doc1.rev, tests.simple_doc, + False) + t_gen, _ = self.db._get_replica_gen_and_trans_id('test3.db') + s_gen, _ = db3._get_replica_gen_and_trans_id('test1') + self.assertEqual(1, t_gen) + self.assertEqual(1, s_gen) + + +class TestRemoteSyncIntegration(tests.TestCaseWithServer): + """Integration tests for the most common sync scenario local -> remote""" + + make_app_with_state = staticmethod(make_http_app) + + def setUp(self): + super(TestRemoteSyncIntegration, self).setUp() + self.startServer() + self.db1 = inmemory.InMemoryDatabase('test1') + self.db2 = self.request_state._create_database('test2') + + def test_sync_tracks_generations_incrementally(self): + doc11 = self.db1.create_doc_from_json('{"a": 1}') + doc12 = 
self.db1.create_doc_from_json('{"a": 2}') + doc21 = self.db2.create_doc_from_json('{"b": 1}') + doc22 = self.db2.create_doc_from_json('{"b": 2}') + #sanity + self.assertEqual(2, len(self.db1._get_transaction_log())) + self.assertEqual(2, len(self.db2._get_transaction_log())) + progress1 = [] + progress2 = [] + _do_set_replica_gen_and_trans_id = \ + self.db1._do_set_replica_gen_and_trans_id + + def set_sync_generation_witness1(other_uid, other_gen, trans_id): + progress1.append((other_uid, other_gen, + [d for d, t in self.db1._get_transaction_log()[2:]])) + _do_set_replica_gen_and_trans_id(other_uid, other_gen, trans_id) + self.patch(self.db1, '_do_set_replica_gen_and_trans_id', + set_sync_generation_witness1) + _do_set_replica_gen_and_trans_id2 = \ + self.db2._do_set_replica_gen_and_trans_id + + def set_sync_generation_witness2(other_uid, other_gen, trans_id): + progress2.append((other_uid, other_gen, + [d for d, t in self.db2._get_transaction_log()[2:]])) + _do_set_replica_gen_and_trans_id2(other_uid, other_gen, trans_id) + self.patch(self.db2, '_do_set_replica_gen_and_trans_id', + set_sync_generation_witness2) + + db2_url = self.getURL('test2') + self.db1.sync(db2_url) + + self.assertEqual([('test2', 1, [doc21.doc_id]), + ('test2', 2, [doc21.doc_id, doc22.doc_id]), + ('test2', 4, [doc21.doc_id, doc22.doc_id])], + progress1) + self.assertEqual([('test1', 1, [doc11.doc_id]), + ('test1', 2, [doc11.doc_id, doc12.doc_id]), + ('test1', 4, [doc11.doc_id, doc12.doc_id])], + progress2) + + +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_test_infrastructure.py b/src/leap/soledad/u1db/tests/test_test_infrastructure.py new file mode 100644 index 00000000..b79e0516 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_test_infrastructure.py @@ -0,0 +1,41 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. 
+# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""Tests for test infrastructure bits""" + +from wsgiref import simple_server + +from u1db import ( + tests, + ) + + +class TestTestCaseWithServer(tests.TestCaseWithServer): + + def make_app(self): + return "app" + + @staticmethod + def server_def(): + def make_server(host_port, application): + assert application == "app" + return simple_server.WSGIServer(host_port, None) + return (make_server, "shutdown", "http") + + def test_getURL(self): + self.startServer() + url = self.getURL() + self.assertTrue(url.startswith('http://127.0.0.1:')) diff --git a/src/leap/soledad/u1db/tests/test_vectorclock.py b/src/leap/soledad/u1db/tests/test_vectorclock.py new file mode 100644 index 00000000..72baf246 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_vectorclock.py @@ -0,0 +1,121 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . 
+ +"""VectorClockRev helper class tests.""" + +from u1db import tests, vectorclock + +try: + from u1db.tests import c_backend_wrapper +except ImportError: + c_backend_wrapper = None + + +c_vectorclock_scenarios = [] +if c_backend_wrapper is not None: + c_vectorclock_scenarios.append( + ('c', {'create_vcr': c_backend_wrapper.VectorClockRev})) + + +class TestVectorClockRev(tests.TestCase): + + scenarios = [('py', {'create_vcr': vectorclock.VectorClockRev}) + ] + c_vectorclock_scenarios + + def assertIsNewer(self, newer_rev, older_rev): + new_vcr = self.create_vcr(newer_rev) + old_vcr = self.create_vcr(older_rev) + self.assertTrue(new_vcr.is_newer(old_vcr)) + self.assertFalse(old_vcr.is_newer(new_vcr)) + + def assertIsConflicted(self, rev_a, rev_b): + vcr_a = self.create_vcr(rev_a) + vcr_b = self.create_vcr(rev_b) + self.assertFalse(vcr_a.is_newer(vcr_b)) + self.assertFalse(vcr_b.is_newer(vcr_a)) + + def assertRoundTrips(self, rev): + self.assertEqual(rev, self.create_vcr(rev).as_str()) + + def test__is_newer_doc_rev(self): + self.assertIsNewer('test:1', None) + self.assertIsNewer('test:2', 'test:1') + self.assertIsNewer('other:2|test:1', 'other:1|test:1') + self.assertIsNewer('other:1|test:1', 'other:1') + self.assertIsNewer('a:2|b:1', 'b:1') + self.assertIsNewer('a:1|b:2', 'a:1') + self.assertIsConflicted('other:2|test:1', 'other:1|test:2') + self.assertIsConflicted('other:1|test:1', 'other:2') + self.assertIsConflicted('test:1', 'test:1') + + def test_None(self): + vcr = self.create_vcr(None) + self.assertEqual('', vcr.as_str()) + + def test_round_trips(self): + self.assertRoundTrips('test:1') + self.assertRoundTrips('a:1|b:2') + self.assertRoundTrips('alternate:2|test:1') + + def test_handles_sort_order(self): + self.assertEqual('a:1|b:2', self.create_vcr('b:2|a:1').as_str()) + # Last one out of place + self.assertEqual('a:1|b:2|c:3|d:4|e:5|f:6', + self.create_vcr('f:6|a:1|b:2|c:3|d:4|e:5').as_str()) + # Fully reversed + self.assertEqual('a:1|b:2|c:3|d:4|e:5|f:6', 
+ self.create_vcr('f:6|e:5|d:4|c:3|b:2|a:1').as_str()) + + def assertIncrement(self, original, replica_uid, after_increment): + vcr = self.create_vcr(original) + vcr.increment(replica_uid) + self.assertEqual(after_increment, vcr.as_str()) + + def test_increment(self): + self.assertIncrement(None, 'test', 'test:1') + self.assertIncrement('test:1', 'test', 'test:2') + + def test_increment_adds_uid(self): + self.assertIncrement('other:1', 'test', 'other:1|test:1') + self.assertIncrement('a:1|ab:2', 'aa', 'a:1|aa:1|ab:2') + + def test_increment_update_partial(self): + self.assertIncrement('a:1|ab:2', 'a', 'a:2|ab:2') + self.assertIncrement('a:2|ab:2', 'ab', 'a:2|ab:3') + + def test_increment_appends_uid(self): + self.assertIncrement('b:2', 'c', 'b:2|c:1') + + def assertMaximize(self, rev1, rev2, maximized): + vcr1 = self.create_vcr(rev1) + vcr2 = self.create_vcr(rev2) + vcr1.maximize(vcr2) + self.assertEqual(maximized, vcr1.as_str()) + # reset vcr1 to maximize the other way + vcr1 = self.create_vcr(rev1) + vcr2.maximize(vcr1) + self.assertEqual(maximized, vcr2.as_str()) + + def test_maximize(self): + self.assertMaximize(None, None, '') + self.assertMaximize(None, 'x:1', 'x:1') + self.assertMaximize('x:1', 'y:1', 'x:1|y:1') + self.assertMaximize('x:2', 'x:1', 'x:2') + self.assertMaximize('x:2', 'x:1|y:2', 'x:2|y:2') + self.assertMaximize('a:1|c:2|e:3', 'b:3|d:4|f:5', + 'a:1|b:3|c:2|d:4|e:3|f:5') + +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/testing-certs/Makefile b/src/leap/soledad/u1db/tests/testing-certs/Makefile new file mode 100644 index 00000000..2385e75b --- /dev/null +++ b/src/leap/soledad/u1db/tests/testing-certs/Makefile @@ -0,0 +1,35 @@ +CATOP=./demoCA +ORIG_CONF=/usr/lib/ssl/openssl.cnf +ELEVEN_YEARS=-days 4015 + +init: + cp $(ORIG_CONF) ca.conf + install -d $(CATOP) + install -d $(CATOP)/certs + install -d $(CATOP)/crl + install -d $(CATOP)/newcerts + install -d $(CATOP)/private + touch $(CATOP)/index.txt + echo 
01>$(CATOP)/crlnumber + @echo '**** Making CA certificate ...' + openssl req -nodes -new \ + -newkey rsa -keyout $(CATOP)/private/cakey.pem \ + -out $(CATOP)/careq.pem \ + -multivalue-rdn \ + -subj "/C=UK/ST=-/O=u1db LOCAL TESTING ONLY, DO NO TRUST/CN=u1db testing CA" + openssl ca -config ./ca.conf -create_serial \ + -out $(CATOP)/cacert.pem $(ELEVEN_YEARS) -batch \ + -keyfile $(CATOP)/private/cakey.pem -selfsign \ + -extensions v3_ca -infiles $(CATOP)/careq.pem + +pems: + cp ./demoCA/cacert.pem . + openssl req -new -config ca.conf \ + -multivalue-rdn \ + -subj "/O=u1db LOCAL TESTING ONLY, DO NOT TRUST/CN=localhost" \ + -nodes -keyout testing.key -out newreq.pem $(ELEVEN_YEARS) + openssl ca -batch -config ./ca.conf $(ELEVEN_YEARS) \ + -policy policy_anything \ + -out testing.cert -infiles newreq.pem + +.PHONY: init pems diff --git a/src/leap/soledad/u1db/tests/testing-certs/cacert.pem b/src/leap/soledad/u1db/tests/testing-certs/cacert.pem new file mode 100644 index 00000000..c019a730 --- /dev/null +++ b/src/leap/soledad/u1db/tests/testing-certs/cacert.pem @@ -0,0 +1,58 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + e4:de:01:76:c4:78:78:7e + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA + Validity + Not Before: May 3 11:11:11 2012 GMT + Not After : May 1 11:11:11 2023 GMT + Subject: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (1024 bit) + Modulus: + 00:bc:91:a5:7f:7d:37:f7:06:c7:db:5b:83:6a:6b: + 63:c3:8b:5c:f7:84:4d:97:6d:d4:be:bf:e7:79:a8: + c1:03:57:ec:90:d4:20:e7:02:95:d9:a6:49:e3:f9: + 9a:ea:37:b9:b2:02:62:ab:40:d3:42:bb:4a:4e:a2: + 47:71:0f:1d:a2:c5:94:a1:cf:35:d3:23:32:42:c0: + 1e:8d:cb:08:58:fb:8a:5c:3e:ea:eb:d5:2c:ed:d6: + aa:09:b4:b5:7d:e3:45:c9:ae:c2:82:b2:ae:c0:81: + bc:24:06:65:a9:e7:e0:61:ac:25:ee:53:d3:d7:be: + 22:f7:00:a2:ad:c6:0e:3a:39 + Exponent: 65537 
(0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D + X509v3 Authority Key Identifier: + keyid:DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D + + X509v3 Basic Constraints: + CA:TRUE + Signature Algorithm: sha1WithRSAEncryption + 72:9b:c1:f7:07:65:83:36:25:4e:01:2f:b7:4a:f2:a4:00:28: + 80:c7:56:2c:32:39:90:13:61:4b:bb:12:c5:44:9d:42:57:85: + 28:19:70:69:e1:43:c8:bd:11:f6:94:df:91:2d:c3:ea:82:8d: + b4:8f:5d:47:a3:00:99:53:29:93:27:6c:c5:da:c1:20:6f:ab: + ec:4a:be:34:f3:8f:02:e5:0c:c0:03:ac:2b:33:41:71:4f:0a: + 72:5a:b4:26:1a:7f:81:bc:c0:95:8a:06:87:a8:11:9f:5c:73: + 38:df:5a:69:40:21:29:ad:46:23:56:75:e1:e9:8b:10:18:4c: + 7b:54 +-----BEGIN CERTIFICATE----- +MIICkjCCAfugAwIBAgIJAOTeAXbEeHh+MA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV +BAYTAlVLMQowCAYDVQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcg +T05MWSwgRE8gTk8gVFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTAeFw0x +MjA1MDMxMTExMTFaFw0yMzA1MDExMTExMTFaMGIxCzAJBgNVBAYTAlVLMQowCAYD +VQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcgT05MWSwgRE8gTk8g +VFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTCBnzANBgkqhkiG9w0BAQEF +AAOBjQAwgYkCgYEAvJGlf3039wbH21uDamtjw4tc94RNl23Uvr/neajBA1fskNQg +5wKV2aZJ4/ma6je5sgJiq0DTQrtKTqJHcQ8dosWUoc810yMyQsAejcsIWPuKXD7q +69Us7daqCbS1feNFya7CgrKuwIG8JAZlqefgYawl7lPT174i9wCircYOOjkCAwEA +AaNQME4wHQYDVR0OBBYEFNs9k1FsMhVUjxBQ/ElPNhUou5VtMB8GA1UdIwQYMBaA +FNs9k1FsMhVUjxBQ/ElPNhUou5VtMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF +BQADgYEAcpvB9wdlgzYlTgEvt0rypAAogMdWLDI5kBNhS7sSxUSdQleFKBlwaeFD +yL0R9pTfkS3D6oKNtI9dR6MAmVMpkydsxdrBIG+r7Eq+NPOPAuUMwAOsKzNBcU8K +clq0Jhp/gbzAlYoGh6gRn1xzON9aaUAhKa1GI1Z14emLEBhMe1Q= +-----END CERTIFICATE----- diff --git a/src/leap/soledad/u1db/tests/testing-certs/testing.cert b/src/leap/soledad/u1db/tests/testing-certs/testing.cert new file mode 100644 index 00000000..985684fb --- /dev/null +++ b/src/leap/soledad/u1db/tests/testing-certs/testing.cert @@ -0,0 +1,61 @@ +Certificate: + Data: + Version: 3 (0x2) + 
Serial Number: + e4:de:01:76:c4:78:78:7f + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA + Validity + Not Before: May 3 11:11:14 2012 GMT + Not After : May 1 11:11:14 2023 GMT + Subject: O=u1db LOCAL TESTING ONLY, DO NOT TRUST, CN=localhost + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (1024 bit) + Modulus: + 00:c6:1d:72:d3:c5:e4:fc:d1:4c:d9:e4:08:3e:90: + 10:ce:3f:1f:87:4a:1d:4f:7f:2a:5a:52:c9:65:4f: + d9:2c:bf:69:75:18:1a:b5:c9:09:32:00:47:f5:60: + aa:c6:dd:3a:87:37:5f:16:be:de:29:b5:ea:fc:41: + 7e:eb:77:bb:df:63:c3:06:1e:ed:e9:a0:67:1a:f1: + ec:e1:9d:f7:9c:8f:1c:fa:c3:66:7b:39:dc:70:ae: + 09:1b:9c:c0:9a:c4:90:77:45:8e:39:95:a9:2f:92: + 43:bd:27:07:5a:99:51:6e:76:a0:af:dd:b1:2c:8f: + ca:8b:8c:47:0d:f6:6e:fc:69 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + Netscape Comment: + OpenSSL Generated Certificate + X509v3 Subject Key Identifier: + 1C:63:85:E1:1D:F3:89:2E:6C:4E:3F:FB:D0:10:64:5A:C1:22:6A:2A + X509v3 Authority Key Identifier: + keyid:DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D + + Signature Algorithm: sha1WithRSAEncryption + 1d:6d:3e:bd:93:fd:bd:3e:17:b8:9f:f0:99:7f:db:50:5c:b2: + 01:42:03:b5:d5:94:05:d3:f6:8e:80:82:55:47:1f:58:f2:18: + 6c:ab:ef:43:2c:2f:10:e1:7c:c4:5c:cc:ac:50:50:22:42:aa: + 35:33:f5:b9:f3:a6:66:55:d9:36:f4:f2:e4:d4:d9:b5:2c:52: + 66:d4:21:17:97:22:b8:9b:d7:0e:7c:3d:ce:85:19:ca:c4:d2: + 58:62:31:c6:18:3e:44:fc:f4:30:b6:95:87:ee:21:4a:08:f0: + af:3c:8f:c4:ba:5e:a1:5c:37:1a:7d:7b:fe:66:ae:62:50:17: + 31:ca +-----BEGIN CERTIFICATE----- +MIICnzCCAgigAwIBAgIJAOTeAXbEeHh/MA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV +BAYTAlVLMQowCAYDVQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcg +T05MWSwgRE8gTk8gVFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTAeFw0x +MjA1MDMxMTExMTRaFw0yMzA1MDExMTExMTRaMEQxLjAsBgNVBAoMJXUxZGIgTE9D +QUwgVEVTVElORyBPTkxZLCBETyBOT1QgVFJVU1QxEjAQBgNVBAMMCWxvY2FsaG9z 
+dDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAxh1y08Xk/NFM2eQIPpAQzj8f +h0odT38qWlLJZU/ZLL9pdRgatckJMgBH9WCqxt06hzdfFr7eKbXq/EF+63e732PD +Bh7t6aBnGvHs4Z33nI8c+sNmeznccK4JG5zAmsSQd0WOOZWpL5JDvScHWplRbnag +r92xLI/Ki4xHDfZu/GkCAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0E +HxYdT3BlblNTTCBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFBxjheEd +84kubE4/+9AQZFrBImoqMB8GA1UdIwQYMBaAFNs9k1FsMhVUjxBQ/ElPNhUou5Vt +MA0GCSqGSIb3DQEBBQUAA4GBAB1tPr2T/b0+F7if8Jl/21BcsgFCA7XVlAXT9o6A +glVHH1jyGGyr70MsLxDhfMRczKxQUCJCqjUz9bnzpmZV2Tb08uTU2bUsUmbUIReX +Irib1w58Pc6FGcrE0lhiMcYYPkT89DC2lYfuIUoI8K88j8S6XqFcNxp9e/5mrmJQ +FzHK +-----END CERTIFICATE----- diff --git a/src/leap/soledad/u1db/tests/testing-certs/testing.key b/src/leap/soledad/u1db/tests/testing-certs/testing.key new file mode 100644 index 00000000..d83d4920 --- /dev/null +++ b/src/leap/soledad/u1db/tests/testing-certs/testing.key @@ -0,0 +1,16 @@ +-----BEGIN PRIVATE KEY----- +MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMYdctPF5PzRTNnk +CD6QEM4/H4dKHU9/KlpSyWVP2Sy/aXUYGrXJCTIAR/VgqsbdOoc3Xxa+3im16vxB +fut3u99jwwYe7emgZxrx7OGd95yPHPrDZns53HCuCRucwJrEkHdFjjmVqS+SQ70n +B1qZUW52oK/dsSyPyouMRw32bvxpAgMBAAECgYBs3lXxhjg1rhabTjIxnx19GTcM +M3Az9V+izweZQu3HJ1CeZiaXauhAr+LbNsniCkRVddotN6oCJdQB10QVxXBZc9Jz +HPJ4zxtZfRZlNMTMmG7eLWrfxpgWnb/BUjDb40yy1nhr9yhDUnI/8RoHDRHnAEHZ +/CnHGUrqcVcrY5zJAQJBAPLhBJg9W88JVmcOKdWxRgs7dLHnZb999Kv1V5mczmAi +jvGvbUmucqOqke6pTUHNYyNHqU6pySzGUi2cH+BAkFECQQDQ0VoAOysg6FVoT15v +tGh57t5sTiCZZ7PS8jwvtThsgA+vcf6c16XWzXgjGXSap4r2QDOY2rI5lsWLaQ8T ++fyZAkAfyFJRmbXp4c7srW3MCOahkaYzoZQu+syJtBFCiMJ40gzik5I5khpuUGPI +V19EvRu8AiSlppIsycb3MPb64XgBAkEAy7DrUf5le5wmc7G4NM6OeyJ+5LbxJbL6 +vnJ8My1a9LuWkVVpQCU7J+UVo2dZTuLPspW9vwTVhUeFOxAoHRxlQQJAFem93f7m +el2BkB2EFqU3onPejkZ5UrDmfmeOQR1axMQNSXqSxcJxqa16Ru1BWV2gcWRbwajQ +oc+kuJThu/r/Ug== +-----END PRIVATE KEY----- diff --git a/src/leap/soledad/u1db/vectorclock.py b/src/leap/soledad/u1db/vectorclock.py new file mode 100644 index 00000000..42bceaa8 --- /dev/null +++ 
class VectorClockRev(object):
    """Track vector clocks for multiple replica ids.

    This allows simple comparison to determine if one VectorClockRev is
    newer/older/in-conflict-with another VectorClockRev without having to
    examine history. Every replica has a strictly increasing revision. When
    creating a new revision, they include all revisions for all other replicas
    which the new revision dominates, and increment their own revision to
    something greater than the current value.
    """

    def __init__(self, value):
        """Build a clock from its string form.

        :param value: 'uid:counter|uid:counter' string; None or '' gives an
            empty clock
        """
        self._values = self._expand(value)

    def __repr__(self):
        return '%s(%s)' % (self.__class__.__name__, self.as_str())

    def as_str(self):
        """Return the 'uid:counter|uid:counter' form, sorted by replica uid.

        An empty clock is rendered as the empty string.
        """
        return '|'.join(['%s:%d' % (m, r) for m, r
                         in sorted(self._values.items())])

    def _expand(self, value):
        """Parse the string form into a {replica_uid: counter} dict."""
        result = {}
        # as_str() renders an empty clock as '', so '' must parse back to an
        # empty clock just like None does instead of failing to unpack.
        if not value:
            return result
        for replica_info in value.split('|'):
            replica_uid, counter = replica_info.split(':')
            result[replica_uid] = int(counter)
        return result

    def is_newer(self, other):
        """Is this VectorClockRev strictly newer than other.

        :param other: the VectorClockRev to compare against
        :return: True when every counter known to other is matched or
            exceeded here and at least one counter here is greater or
            unknown to other; False otherwise (including conflicts)
        """
        if not self._values:
            return False
        if not other._values:
            return True
        this_is_newer = False
        # Entries are popped as they are matched; leftovers are replicas
        # that only `other` knows about.
        other_expand = dict(other._values)
        for key, value in self._values.items():
            if key in other_expand:
                other_value = other_expand.pop(key)
                if other_value > value:
                    return False
                elif other_value < value:
                    this_is_newer = True
            else:
                this_is_newer = True
        if other_expand:
            return False
        return this_is_newer

    def increment(self, replica_uid):
        """Increase the 'replica_uid' section of this vector clock.

        The clock is updated in place and nothing is returned; call
        as_str() to obtain the new value. An unknown replica starts at 1.
        """
        self._values[replica_uid] = self._values.get(replica_uid, 0) + 1

    def maximize(self, other_vcr):
        """Raise every counter to at least the value it has in other_vcr.

        Replicas present only in other_vcr are added. Modifies this clock
        in place.
        """
        for replica_uid, counter in other_vcr._values.items():
            if replica_uid not in self._values:
                self._values[replica_uid] = counter
            else:
                this_counter = self._values[replica_uid]
                if this_counter < counter:
                    self._values[replica_uid] = counter
b/src/leap/soledad/swiftclient/__init__.py new file mode 100644 index 00000000..ba0b41a3 --- /dev/null +++ b/src/leap/soledad/swiftclient/__init__.py @@ -0,0 +1,5 @@ +# -*- encoding: utf-8 -*- +"""" +OpenStack Swift Python client binding. +""" +from client import * diff --git a/src/leap/soledad/swiftclient/client.py b/src/leap/soledad/swiftclient/client.py new file mode 100644 index 00000000..79e6594f --- /dev/null +++ b/src/leap/soledad/swiftclient/client.py @@ -0,0 +1,1056 @@ +# Copyright (c) 2010-2012 OpenStack, LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
logger = logging.getLogger("swiftclient")


def http_log(args, kwargs, resp, body):
    """
    Log an HTTP exchange at DEBUG level as a curl-like command line.

    Nothing is logged unless the SWIFTCLIENT_DEBUG environment variable is
    set or the ``swiftclient`` logger is already enabled for DEBUG.

    :param args: sequence of request parts (URL, HTTP method, ...); HTTP
                 method names are rendered as ``-X METHOD``
    :param kwargs: dict describing the request; the optional keys
                   'headers', 'raw_body' and 'body' are logged when present
    :param resp: response object; only its ``status`` attribute is read
    :param body: response body, logged only when truthy
    """
    if os.environ.get('SWIFTCLIENT_DEBUG', False):
        logger.setLevel(logging.DEBUG)
        # This function runs once per request, so the debug handler must be
        # installed a single time or every message is emitted N times.
        if not any(getattr(handler, '_swiftclient_debug', False)
                   for handler in logger.handlers):
            ch = logging.StreamHandler()
            ch._swiftclient_debug = True
            logger.addHandler(ch)
    elif not logger.isEnabledFor(logging.DEBUG):
        return

    string_parts = ['curl -i']
    for element in args:
        if element in ('GET', 'POST', 'PUT', 'HEAD', 'DELETE'):
            string_parts.append(' -X %s' % element)
        else:
            string_parts.append(' %s' % element)

    if 'headers' in kwargs:
        for element in kwargs['headers']:
            header = ' -H "%s: %s"' % (element, kwargs['headers'][element])
            string_parts.append(header)

    logger.debug("REQ: %s\n", "".join(string_parts))
    if 'raw_body' in kwargs:
        logger.debug("REQ BODY (RAW): %s\n", kwargs['raw_body'])
    if 'body' in kwargs:
        logger.debug("REQ BODY: %s\n", kwargs['body'])

    logger.debug("RESP STATUS: %s\n", resp.status)
    if body:
        logger.debug("RESP BODY: %s\n", body)


def quote(value, safe='/'):
    """
    Patched version of urllib.quote that encodes utf8 strings before quoting

    :param value: string to quote; unicode values are encoded to UTF-8 first
    :param safe: characters that must not be quoted, '/' by default
    :returns: the quoted string
    """
    if isinstance(value, unicode):
        value = value.encode('utf8')
    return _quote(value, safe)
class ClientException(Exception):
    """
    Error raised by the client functions when a request fails.

    Besides the message, the exception carries optional pieces of HTTP
    context (scheme, host, port, path, query, status, reason, device and
    response content). ``str()`` appends every non-empty piece to the
    message.
    """

    def __init__(self, msg, http_scheme='', http_host='', http_port='',
                 http_path='', http_query='', http_status=0, http_reason='',
                 http_device='', http_response_content=''):
        super(ClientException, self).__init__(msg)
        self.msg = msg
        self.http_scheme = http_scheme
        self.http_host = http_host
        self.http_port = http_port
        self.http_path = http_path
        self.http_query = http_query
        self.http_status = http_status
        self.http_reason = http_reason
        self.http_device = http_device
        self.http_response_content = http_response_content

    def __str__(self):
        # First the URL-like part, assembled only from the pieces present.
        detail = ''
        if self.http_scheme:
            detail += '%s://' % self.http_scheme
        if self.http_host:
            detail += self.http_host
        if self.http_port:
            detail += ':%s' % self.http_port
        if self.http_path:
            detail += self.http_path
        if self.http_query:
            detail += '?%s' % self.http_query

        # Status, reason and device are joined to whatever came before, or
        # start the detail text in their stand-alone form.
        if self.http_status:
            detail = ('%s %s' % (detail, self.http_status) if detail
                      else str(self.http_status))
        if self.http_reason:
            detail = ('%s %s' % (detail, self.http_reason) if detail
                      else '- %s' % self.http_reason)
        if self.http_device:
            detail = ('%s: device %s' % (detail, self.http_device) if detail
                      else 'device %s' % self.http_device)

        content = self.http_response_content
        if content:
            if len(content) > 60:
                detail += ' [first 60 chars of response] %s' % content[:60]
            else:
                detail += ' %s' % content

        if detail:
            return '%s: %s' % (self.msg, detail)
        return self.msg
def json_request(method, url, **kwargs):
    """
    Send a request with a JSON-encoded body and decode the JSON response.

    :param method: HTTP method to use
    :param url: URL to send the request to; only its path is requested
    :param kwargs: extra keyword arguments passed to the connection's
                   request(); 'body', when given, is serialized to JSON
                   and the Content-Type header is set accordingly
    :returns: tuple of (response object, decoded response body)
    :raises ClientException: the response body is empty or not valid JSON,
                             or the status is outside the 2xx range
    """
    # Copy the headers: the dict passed in belongs to the caller and must
    # not gain a Content-Type entry as a side effect of this call.
    kwargs['headers'] = dict(kwargs.get('headers') or {})
    if 'body' in kwargs:
        kwargs['headers']['Content-Type'] = 'application/json'
        kwargs['body'] = json_dumps(kwargs['body'])
    parsed, conn = http_connection(url)
    conn.request(method, parsed.path, **kwargs)
    resp = conn.getresponse()
    body = resp.read()
    http_log((url, method,), kwargs, resp, body)
    if body:
        try:
            body = json_loads(body)
        except ValueError:
            # An undecodable body is treated like a missing one below.
            body = None
    if not body or resp.status < 200 or resp.status >= 300:
        raise ClientException('Auth GET failed', http_scheme=parsed.scheme,
                              http_host=conn.host,
                              http_port=conn.port,
                              http_path=parsed.path,
                              http_status=resp.status,
                              http_reason=resp.reason)
    return resp, body
def _get_auth_v2_0(url, user, tenant_name, key, snet):
    """
    Authenticate against an auth 2.0 endpoint.

    :param url: base authentication URL; 'tokens' is appended to it
    :param user: user name to authenticate as
    :param tenant_name: tenant/account name
    :param key: password for the user
    :param snet: if true, prefix the storage host name with 'snet-'
    :returns: tuple of (storage URL, auth token id)
    :raises ClientException: the response has no object-store endpoint or
                             lacks the expected keys
    """
    body = {'auth':
            {'passwordCredentials': {'password': key, 'username': user},
             'tenantName': tenant_name}}
    # urljoin() replaces the last path segment of a base that does not end
    # in '/', which would turn '.../v2.0' into '.../tokens'.
    base_url = url if url.endswith('/') else url + '/'
    token_url = urljoin(base_url, "tokens")
    resp, body = json_request("POST", token_url, body=body)
    token_id = None
    try:
        url = None
        catalogs = body['access']['serviceCatalog']
        for service in catalogs:
            if service['type'] == 'object-store':
                url = service['endpoints'][0]['publicURL']
        token_id = body['access']['token']['id']
        if not url:
            raise ClientException("There is no object-store endpoint "
                                  "on this auth server.")
    except (KeyError, IndexError):
        raise ClientException("Error while getting answers from auth server")

    if snet:
        parsed = list(urlparse(url))
        # Second item in the list is the netloc
        parsed[1] = 'snet-' + parsed[1]
        url = urlunparse(parsed)

    return url, token_id
def get_account(url, token, marker=None, limit=None, prefix=None,
                http_conn=None, full_listing=False):
    """
    Get a listing of containers for the account.

    :param url: storage URL
    :param token: auth token
    :param marker: marker query
    :param limit: limit query
    :param prefix: prefix query
    :param http_conn: HTTP connection object (If None, it will create the
                      conn object)
    :param full_listing: if True, return a full listing, else returns a max
                         of 10000 listings
    :returns: a tuple of (response headers, a list of containers) The response
              headers will be a dict and all header names will be lowercase.
    :raises ClientException: HTTP GET request failed
    """
    if not http_conn:
        http_conn = http_connection(url)
    if full_listing:
        # Fetch page after page, using the last name seen as the next
        # marker, until an empty page comes back.
        rv = get_account(url, token, marker, limit, prefix, http_conn)
        listing = rv[1]
        while listing:
            marker = listing[-1]['name']
            listing = \
                get_account(url, token, marker, limit, prefix, http_conn)[1]
            if listing:
                rv[1].extend(listing)
        return rv
    parsed, conn = http_conn
    qs = 'format=json'
    if marker:
        qs += '&marker=%s' % quote(marker)
    if limit:
        qs += '&limit=%d' % limit
    if prefix:
        qs += '&prefix=%s' % quote(prefix)
    full_path = '%s?%s' % (parsed.path, qs)
    headers = {'X-Auth-Token': token}
    conn.request('GET', full_path, '',
                 headers)
    resp = conn.getresponse()
    body = resp.read()
    http_log(("%s?%s" % (url, qs), 'GET',), {'headers': headers}, resp, body)

    resp_headers = {}
    for header, value in resp.getheaders():
        resp_headers[header.lower()] = value
    if resp.status < 200 or resp.status >= 300:
        raise ClientException('Account GET failed', http_scheme=parsed.scheme,
                              http_host=conn.host, http_port=conn.port,
                              http_path=parsed.path, http_query=qs,
                              http_status=resp.status, http_reason=resp.reason,
                              http_response_content=body)
    if resp.status == 204:
        # No content: there is no JSON document to decode.
        return resp_headers, []
    return resp_headers, json_loads(body)
def post_account(url, token, headers, http_conn=None):
    """
    Update an account's metadata.

    :param url: storage URL
    :param token: auth token
    :param headers: additional headers to include in the request; the dict
                    is not modified
    :param http_conn: HTTP connection object (If None, it will create the
                      conn object)
    :raises ClientException: HTTP POST request failed
    """
    if http_conn:
        parsed, conn = http_conn
    else:
        parsed, conn = http_connection(url)
    method = 'POST'
    # Copy before adding the token so the caller's dict stays untouched.
    headers = dict(headers or {})
    headers['X-Auth-Token'] = token
    conn.request(method, parsed.path, '', headers)
    resp = conn.getresponse()
    body = resp.read()
    http_log((url, method,), {'headers': headers}, resp, body)
    if resp.status < 200 or resp.status >= 300:
        raise ClientException('Account POST failed',
                              http_scheme=parsed.scheme,
                              http_host=conn.host,
                              http_port=conn.port,
                              http_path=parsed.path,
                              http_status=resp.status,
                              http_reason=resp.reason,
                              http_response_content=body)
def head_container(url, token, container, http_conn=None, headers=None):
    """
    Get container stats.

    :param url: storage URL
    :param token: auth token
    :param container: container name to get stats for
    :param http_conn: HTTP connection object (If None, it will create the
                      conn object)
    :param headers: additional headers to include in the request
    :returns: a dict containing the response's headers (all header names will
              be lowercase)
    :raises ClientException: HTTP HEAD request failed
    """
    parsed, conn = http_conn if http_conn else http_connection(url)
    path = '%s/%s' % (parsed.path, quote(container))
    method = 'HEAD'
    req_headers = {'X-Auth-Token': token}
    if headers:
        req_headers.update(headers)
    conn.request(method, path, '', req_headers)
    resp = conn.getresponse()
    body = resp.read()
    http_log(('%s?%s' % (url, path), method,),
             {'headers': req_headers}, resp, body)

    if not 200 <= resp.status < 300:
        raise ClientException('Container HEAD failed',
                              http_scheme=parsed.scheme, http_host=conn.host,
                              http_port=conn.port, http_path=path,
                              http_status=resp.status, http_reason=resp.reason,
                              http_response_content=body)
    # Header names are normalized to lowercase for the caller.
    return dict((name.lower(), value) for name, value in resp.getheaders())
def post_container(url, token, container, headers, http_conn=None):
    """
    Update a container's metadata.

    :param url: storage URL
    :param token: auth token
    :param container: container name to update
    :param headers: additional headers to include in the request; the dict
                    is not modified
    :param http_conn: HTTP connection object (If None, it will create the
                      conn object)
    :raises ClientException: HTTP POST request failed
    """
    if http_conn:
        parsed, conn = http_conn
    else:
        parsed, conn = http_connection(url)
    path = '%s/%s' % (parsed.path, quote(container))
    method = 'POST'
    # Copy before adding the token so the caller's dict stays untouched.
    headers = dict(headers or {})
    headers['X-Auth-Token'] = token
    conn.request(method, path, '', headers)
    resp = conn.getresponse()
    body = resp.read()
    http_log(('%s?%s' % (url, path), method,),
             {'headers': headers}, resp, body)
    if resp.status < 200 or resp.status >= 300:
        raise ClientException('Container POST failed',
                              http_scheme=parsed.scheme, http_host=conn.host,
                              http_port=conn.port, http_path=path,
                              http_status=resp.status, http_reason=resp.reason,
                              http_response_content=body)
def get_object(url, token, container, name, http_conn=None,
               resp_chunk_size=None):
    """
    Get an object

    :param url: storage URL
    :param token: auth token
    :param container: container name that the object is in
    :param name: object name to get
    :param http_conn: HTTP connection object (If None, it will create the
                      conn object)
    :param resp_chunk_size: if defined, chunk size of data to read. NOTE: If
                            you specify a resp_chunk_size you must fully read
                            the object's contents before making another
                            request.
    :returns: a tuple of (response headers, the object's contents) The response
              headers will be a dict and all header names will be lowercase.
              With resp_chunk_size set, the contents are a generator
              yielding chunks instead of a string.
    :raises ClientException: HTTP GET request failed
    """
    if http_conn:
        parsed, conn = http_conn
    else:
        parsed, conn = http_connection(url)
    path = '%s/%s/%s' % (parsed.path, quote(container), quote(name))
    method = 'GET'
    headers = {'X-Auth-Token': token}
    conn.request(method, path, '', headers)
    resp = conn.getresponse()
    if resp.status < 200 or resp.status >= 300:
        body = resp.read()
        # Log the method actually sent, not a hard-coded 'POST'.
        http_log(('%s?%s' % (url, path), method,),
                 {'headers': headers}, resp, body)
        raise ClientException('Object GET failed', http_scheme=parsed.scheme,
                              http_host=conn.host, http_port=conn.port,
                              http_path=path, http_status=resp.status,
                              http_reason=resp.reason,
                              http_response_content=body)
    if resp_chunk_size:

        def _object_body():
            # Lazily read the response; nothing is consumed until the
            # caller iterates.
            buf = resp.read(resp_chunk_size)
            while buf:
                yield buf
                buf = resp.read(resp_chunk_size)
        object_body = _object_body()
    else:
        object_body = resp.read()
    resp_headers = {}
    for header, value in resp.getheaders():
        resp_headers[header.lower()] = value
    http_log(('%s?%s' % (url, path), method,),
             {'headers': headers}, resp, object_body)
    return resp_headers, object_body
send as content-length header; also limits + the amount read from contents; if None, it will be + computed via the contents or chunked transfer + encoding will be used + :param etag: etag of contents; if None, no etag will be sent + :param chunk_size: chunk size of data to write; default 65536 + :param content_type: value to send as content-type header; if None, no + content-type will be set (remote end will likely try + to auto-detect it) + :param headers: additional headers to include in the request, if any + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :param proxy: proxy to connect through, if any; None by default; str of the + format 'http://127.0.0.1:8888' to set one + :returns: etag from server response + :raises ClientException: HTTP PUT request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url, proxy=proxy) + path = parsed.path + if container: + path = '%s/%s' % (path.rstrip('/'), quote(container)) + if name: + path = '%s/%s' % (path.rstrip('/'), quote(name)) + if headers: + headers = dict(headers) + else: + headers = {} + if token: + headers['X-Auth-Token'] = token + if etag: + headers['ETag'] = etag.strip('"') + if content_length is not None: + headers['Content-Length'] = str(content_length) + else: + for n, v in headers.iteritems(): + if n.lower() == 'content-length': + content_length = int(v) + if content_type is not None: + headers['Content-Type'] = content_type + if not contents: + headers['Content-Length'] = '0' + if hasattr(contents, 'read'): + conn.putrequest('PUT', path) + for header, value in headers.iteritems(): + conn.putheader(header, value) + if content_length is None: + conn.putheader('Transfer-Encoding', 'chunked') + conn.endheaders() + chunk = contents.read(chunk_size) + while chunk: + conn.send('%x\r\n%s\r\n' % (len(chunk), chunk)) + chunk = contents.read(chunk_size) + conn.send('0\r\n\r\n') + else: + conn.endheaders() + left = content_length + 
def post_object(url, token, container, name, headers, http_conn=None):
    """
    Update object metadata

    :param url: storage URL
    :param token: auth token
    :param container: container name that the object is in
    :param name: name of the object to update
    :param headers: additional headers to include in the request; the dict
                    is not modified
    :param http_conn: HTTP connection object (If None, it will create the
                      conn object)
    :raises ClientException: HTTP POST request failed
    """
    if http_conn:
        parsed, conn = http_conn
    else:
        parsed, conn = http_connection(url)
    path = '%s/%s/%s' % (parsed.path, quote(container), quote(name))
    # Copy before adding the token so the caller's dict stays untouched.
    headers = dict(headers or {})
    headers['X-Auth-Token'] = token
    conn.request('POST', path, '', headers)
    resp = conn.getresponse()
    body = resp.read()
    http_log(('%s?%s' % (url, path), 'POST',),
             {'headers': headers}, resp, body)
    if resp.status < 200 or resp.status >= 300:
        raise ClientException('Object POST failed', http_scheme=parsed.scheme,
                              http_host=conn.host, http_port=conn.port,
                              http_path=path, http_status=resp.status,
                              http_reason=resp.reason,
                              http_response_content=body)
container: container name that the object is in; if None, the + container name is expected to be part of the url + :param name: object name to delete; if None, the object name is expected to + be part of the url + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :param headers: additional headers to include in the request + :param proxy: proxy to connect through, if any; None by default; str of the + format 'http://127.0.0.1:8888' to set one + :raises ClientException: HTTP DELETE request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url, proxy=proxy) + path = parsed.path + if container: + path = '%s/%s' % (path.rstrip('/'), quote(container)) + if name: + path = '%s/%s' % (path.rstrip('/'), quote(name)) + if headers: + headers = dict(headers) + else: + headers = {} + if token: + headers['X-Auth-Token'] = token + conn.request('DELETE', path, '', headers) + resp = conn.getresponse() + body = resp.read() + http_log(('%s?%s' % (url, path), 'POST',), + {'headers': headers}, resp, body) + if resp.status < 200 or resp.status >= 300: + raise ClientException('Object DELETE failed', + http_scheme=parsed.scheme, http_host=conn.host, + http_port=conn.port, http_path=path, + http_status=resp.status, http_reason=resp.reason, + http_response_content=body) + + +class Connection(object): + """Convenience class to make requests that will also retry the request""" + + def __init__(self, authurl, user, key, retries=5, preauthurl=None, + preauthtoken=None, snet=False, starting_backoff=1, + tenant_name=None, + auth_version="1"): + """ + :param authurl: authentication URL + :param user: user name to authenticate as + :param key: key/password to authenticate with + :param retries: Number of times to retry the request before failing + :param preauthurl: storage URL (if you have already authenticated) + :param preauthtoken: authentication token (if you have already + authenticated) + :param snet: use 
SERVICENET internal network default is False + :param auth_version: OpenStack auth version, default is 1.0 + :param tenant_name: The tenant/account name, required when connecting + to a auth 2.0 system. + """ + self.authurl = authurl + self.user = user + self.key = key + self.retries = retries + self.http_conn = None + self.url = preauthurl + self.token = preauthtoken + self.attempts = 0 + self.snet = snet + self.starting_backoff = starting_backoff + self.auth_version = auth_version + self.tenant_name = tenant_name + + def get_auth(self): + return get_auth(self.authurl, self.user, + self.key, snet=self.snet, + tenant_name=self.tenant_name, + auth_version=self.auth_version) + + def http_connection(self): + return http_connection(self.url) + + def _retry(self, reset_func, func, *args, **kwargs): + self.attempts = 0 + backoff = self.starting_backoff + while self.attempts <= self.retries: + self.attempts += 1 + try: + if not self.url or not self.token: + self.url, self.token = self.get_auth() + self.http_conn = None + if not self.http_conn: + self.http_conn = self.http_connection() + kwargs['http_conn'] = self.http_conn + rv = func(self.url, self.token, *args, **kwargs) + return rv + except (socket.error, HTTPException): + if self.attempts > self.retries: + raise + self.http_conn = None + except ClientException, err: + if self.attempts > self.retries: + raise + if err.http_status == 401: + self.url = self.token = None + if self.attempts > 1: + raise + elif err.http_status == 408: + self.http_conn = None + elif 500 <= err.http_status <= 599: + pass + else: + raise + sleep(backoff) + backoff *= 2 + if reset_func: + reset_func(func, *args, **kwargs) + + def head_account(self): + """Wrapper for :func:`head_account`""" + return self._retry(None, head_account) + + def get_account(self, marker=None, limit=None, prefix=None, + full_listing=False): + """Wrapper for :func:`get_account`""" + # TODO(unknown): With full_listing=True this will restart the entire + # listing with 
each retry. Need to make a better version that just + # retries where it left off. + return self._retry(None, get_account, marker=marker, limit=limit, + prefix=prefix, full_listing=full_listing) + + def post_account(self, headers): + """Wrapper for :func:`post_account`""" + return self._retry(None, post_account, headers) + + def head_container(self, container): + """Wrapper for :func:`head_container`""" + return self._retry(None, head_container, container) + + def get_container(self, container, marker=None, limit=None, prefix=None, + delimiter=None, full_listing=False): + """Wrapper for :func:`get_container`""" + # TODO(unknown): With full_listing=True this will restart the entire + # listing with each retry. Need to make a better version that just + # retries where it left off. + return self._retry(None, get_container, container, marker=marker, + limit=limit, prefix=prefix, delimiter=delimiter, + full_listing=full_listing) + + def put_container(self, container, headers=None): + """Wrapper for :func:`put_container`""" + return self._retry(None, put_container, container, headers=headers) + + def post_container(self, container, headers): + """Wrapper for :func:`post_container`""" + return self._retry(None, post_container, container, headers) + + def delete_container(self, container): + """Wrapper for :func:`delete_container`""" + return self._retry(None, delete_container, container) + + def head_object(self, container, obj): + """Wrapper for :func:`head_object`""" + return self._retry(None, head_object, container, obj) + + def get_object(self, container, obj, resp_chunk_size=None): + """Wrapper for :func:`get_object`""" + return self._retry(None, get_object, container, obj, + resp_chunk_size=resp_chunk_size) + + def put_object(self, container, obj, contents, content_length=None, + etag=None, chunk_size=65536, content_type=None, + headers=None): + """Wrapper for :func:`put_object`""" + + def _default_reset(*args, **kwargs): + raise ClientException('put_object(%r, %r, 
...) failure and no ' + 'ability to reset contents for reupload.' + % (container, obj)) + + reset_func = _default_reset + tell = getattr(contents, 'tell', None) + seek = getattr(contents, 'seek', None) + if tell and seek: + orig_pos = tell() + reset_func = lambda *a, **k: seek(orig_pos) + elif not contents: + reset_func = lambda *a, **k: None + + return self._retry(reset_func, put_object, container, obj, contents, + content_length=content_length, etag=etag, + chunk_size=chunk_size, content_type=content_type, + headers=headers) + + def post_object(self, container, obj, headers): + """Wrapper for :func:`post_object`""" + return self._retry(None, post_object, container, obj, headers) + + def delete_object(self, container, obj): + """Wrapper for :func:`delete_object`""" + return self._retry(None, delete_object, container, obj) diff --git a/src/leap/soledad/swiftclient/openstack/__init__.py b/src/leap/soledad/swiftclient/openstack/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/leap/soledad/swiftclient/openstack/common/__init__.py b/src/leap/soledad/swiftclient/openstack/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/leap/soledad/swiftclient/openstack/common/setup.py b/src/leap/soledad/swiftclient/openstack/common/setup.py new file mode 100644 index 00000000..caf06fa5 --- /dev/null +++ b/src/leap/soledad/swiftclient/openstack/common/setup.py @@ -0,0 +1,342 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Utilities with minimum-depends for use in setup.py +""" + +import datetime +import os +import re +import subprocess +import sys + +from setuptools.command import sdist + + +def parse_mailmap(mailmap='.mailmap'): + mapping = {} + if os.path.exists(mailmap): + fp = open(mailmap, 'r') + for l in fp: + l = l.strip() + if not l.startswith('#') and ' ' in l: + canonical_email, alias = l.split(' ') + mapping[alias] = canonical_email + return mapping + + +def canonicalize_emails(changelog, mapping): + """Takes in a string and an email alias mapping and replaces all + instances of the aliases in the string with their real email. + """ + for alias, email in mapping.iteritems(): + changelog = changelog.replace(alias, email) + return changelog + + +# Get requirements from the first file that exists +def get_reqs_from_files(requirements_files): + reqs_in = [] + for requirements_file in requirements_files: + if os.path.exists(requirements_file): + return open(requirements_file, 'r').read().split('\n') + return [] + + +def parse_requirements(requirements_files=['requirements.txt', + 'tools/pip-requires']): + requirements = [] + for line in get_reqs_from_files(requirements_files): + # For the requirements list, we need to inject only the portion + # after egg= so that distutils knows the package it's looking for + # such as: + # -e git://github.com/openstack/nova/master#egg=nova + if re.match(r'\s*-e\s+', line): + requirements.append(re.sub(r'\s*-e\s+.*#egg=(.*)$', r'\1', + line)) + # such as: + # http://github.com/openstack/nova/zipball/master#egg=nova + elif re.match(r'\s*https?:', line): + requirements.append(re.sub(r'\s*https?:.*#egg=(.*)$', r'\1', + line)) + # -f lines are for index locations, and don't get used here + elif re.match(r'\s*-f\s+', line): + pass + # argparse is part of the standard library starting with 2.7 + # adding it to the requirements list 
screws distro installs + elif line == 'argparse' and sys.version_info >= (2, 7): + pass + else: + requirements.append(line) + + return requirements + + +def parse_dependency_links(requirements_files=['requirements.txt', + 'tools/pip-requires']): + dependency_links = [] + # dependency_links inject alternate locations to find packages listed + # in requirements + for line in get_reqs_from_files(requirements_files): + # skip comments and blank lines + if re.match(r'(\s*#)|(\s*$)', line): + continue + # lines with -e or -f need the whole line, minus the flag + if re.match(r'\s*-[ef]\s+', line): + dependency_links.append(re.sub(r'\s*-[ef]\s+', '', line)) + # lines that are only urls can go in unmolested + elif re.match(r'\s*https?:', line): + dependency_links.append(line) + return dependency_links + + +def write_requirements(): + venv = os.environ.get('VIRTUAL_ENV', None) + if venv is not None: + with open("requirements.txt", "w") as req_file: + output = subprocess.Popen(["pip", "-E", venv, "freeze", "-l"], + stdout=subprocess.PIPE) + requirements = output.communicate()[0].strip() + req_file.write(requirements) + + +def _run_shell_command(cmd): + output = subprocess.Popen(["/bin/sh", "-c", cmd], + stdout=subprocess.PIPE) + out = output.communicate() + if len(out) == 0: + return None + if len(out[0].strip()) == 0: + return None + return out[0].strip() + + +def _get_git_next_version_suffix(branch_name): + datestamp = datetime.datetime.now().strftime('%Y%m%d') + if branch_name == 'milestone-proposed': + revno_prefix = "r" + else: + revno_prefix = "" + _run_shell_command("git fetch origin +refs/meta/*:refs/remotes/meta/*") + milestone_cmd = "git show meta/openstack/release:%s" % branch_name + milestonever = _run_shell_command(milestone_cmd) + if not milestonever: + milestonever = "" + post_version = _get_git_post_version() + revno = post_version.split(".")[-1] + return "%s~%s.%s%s" % (milestonever, datestamp, revno_prefix, revno) + + +def _get_git_current_tag(): + return 
_run_shell_command("git tag --contains HEAD") + + +def _get_git_tag_info(): + return _run_shell_command("git describe --tags") + + +def _get_git_post_version(): + current_tag = _get_git_current_tag() + if current_tag is not None: + return current_tag + else: + tag_info = _get_git_tag_info() + if tag_info is None: + base_version = "0.0" + cmd = "git --no-pager log --oneline" + out = _run_shell_command(cmd) + revno = len(out.split("\n")) + else: + tag_infos = tag_info.split("-") + base_version = "-".join(tag_infos[:-2]) + revno = tag_infos[-2] + return "%s.%s" % (base_version, revno) + + +def write_git_changelog(): + """Write a changelog based on the git changelog.""" + if os.path.isdir('.git'): + git_log_cmd = 'git log --stat' + changelog = _run_shell_command(git_log_cmd) + mailmap = parse_mailmap() + with open("ChangeLog", "w") as changelog_file: + changelog_file.write(canonicalize_emails(changelog, mailmap)) + + +def generate_authors(): + """Create AUTHORS file using git commits.""" + jenkins_email = 'jenkins@review.openstack.org' + old_authors = 'AUTHORS.in' + new_authors = 'AUTHORS' + if os.path.isdir('.git'): + # don't include jenkins email address in AUTHORS file + git_log_cmd = ("git log --format='%aN <%aE>' | sort -u | " + "grep -v " + jenkins_email) + changelog = _run_shell_command(git_log_cmd) + mailmap = parse_mailmap() + with open(new_authors, 'w') as new_authors_fh: + new_authors_fh.write(canonicalize_emails(changelog, mailmap)) + if os.path.exists(old_authors): + with open(old_authors, "r") as old_authors_fh: + new_authors_fh.write('\n' + old_authors_fh.read()) + +_rst_template = """%(heading)s +%(underline)s + +.. automodule:: %(module)s + :members: + :undoc-members: + :show-inheritance: +""" + + +def read_versioninfo(project): + """Read the versioninfo file. 
If it doesn't exist, we're in a github + zipball, and there's really know way to know what version we really + are, but that should be ok, because the utility of that should be + just about nil if this code path is in use in the first place.""" + versioninfo_path = os.path.join(project, 'versioninfo') + if os.path.exists(versioninfo_path): + with open(versioninfo_path, 'r') as vinfo: + version = vinfo.read().strip() + else: + version = "0.0.0" + return version + + +def write_versioninfo(project, version): + """Write a simple file containing the version of the package.""" + open(os.path.join(project, 'versioninfo'), 'w').write("%s\n" % version) + + +def get_cmdclass(): + """Return dict of commands to run from setup.py.""" + + cmdclass = dict() + + def _find_modules(arg, dirname, files): + for filename in files: + if filename.endswith('.py') and filename != '__init__.py': + arg["%s.%s" % (dirname.replace('/', '.'), + filename[:-3])] = True + + class LocalSDist(sdist.sdist): + """Builds the ChangeLog and Authors files from VC first.""" + + def run(self): + write_git_changelog() + generate_authors() + # sdist.sdist is an old style class, can't use super() + sdist.sdist.run(self) + + cmdclass['sdist'] = LocalSDist + + # If Sphinx is installed on the box running setup.py, + # enable setup.py to build the documentation, otherwise, + # just ignore it + try: + from sphinx.setup_command import BuildDoc + + class LocalBuildDoc(BuildDoc): + def generate_autoindex(self): + print "**Autodocumenting from %s" % os.path.abspath(os.curdir) + modules = {} + option_dict = self.distribution.get_option_dict('build_sphinx') + source_dir = os.path.join(option_dict['source_dir'][1], 'api') + if not os.path.exists(source_dir): + os.makedirs(source_dir) + for pkg in self.distribution.packages: + if '.' 
not in pkg: + os.path.walk(pkg, _find_modules, modules) + module_list = modules.keys() + module_list.sort() + autoindex_filename = os.path.join(source_dir, 'autoindex.rst') + with open(autoindex_filename, 'w') as autoindex: + autoindex.write(""".. toctree:: + :maxdepth: 1 + +""") + for module in module_list: + output_filename = os.path.join(source_dir, + "%s.rst" % module) + heading = "The :mod:`%s` Module" % module + underline = "=" * len(heading) + values = dict(module=module, heading=heading, + underline=underline) + + print "Generating %s" % output_filename + with open(output_filename, 'w') as output_file: + output_file.write(_rst_template % values) + autoindex.write(" %s.rst\n" % module) + + def run(self): + if not os.getenv('SPHINX_DEBUG'): + self.generate_autoindex() + + for builder in ['html', 'man']: + self.builder = builder + self.finalize_options() + self.project = self.distribution.get_name() + self.version = self.distribution.get_version() + self.release = self.distribution.get_version() + BuildDoc.run(self) + cmdclass['build_sphinx'] = LocalBuildDoc + except ImportError: + pass + + return cmdclass + + +def get_git_branchname(): + for branch in _run_shell_command("git branch --color=never").split("\n"): + if branch.startswith('*'): + _branch_name = branch.split()[1].strip() + if _branch_name == "(no": + _branch_name = "no-branch" + return _branch_name + + +def get_pre_version(projectname, base_version): + """Return a version which is based""" + if os.path.isdir('.git'): + current_tag = _get_git_current_tag() + if current_tag is not None: + version = current_tag + else: + branch_name = os.getenv('BRANCHNAME', + os.getenv('GERRIT_REFNAME', + get_git_branchname())) + version_suffix = _get_git_next_version_suffix(branch_name) + version = "%s~%s" % (base_version, version_suffix) + write_versioninfo(projectname, version) + return version.split('~')[0] + else: + version = read_versioninfo(projectname) + return version.split('~')[0] + + +def 
get_post_version(projectname): + """Return a version which is equal to the tag that's on the current + revision if there is one, or tag plus number of additional revisions + if the current revision has no tag.""" + + if os.path.isdir('.git'): + version = _get_git_post_version() + write_versioninfo(projectname, version) + return version + return read_versioninfo(projectname) diff --git a/src/leap/soledad/swiftclient/versioninfo b/src/leap/soledad/swiftclient/versioninfo new file mode 100644 index 00000000..524cb552 --- /dev/null +++ b/src/leap/soledad/swiftclient/versioninfo @@ -0,0 +1 @@ +1.1.1 -- cgit v1.2.3 From 8febf0c6f71395bbc8a24440beb28dfb719ba01c Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 3 Dec 2012 10:32:54 -0200 Subject: Add LeapDocument methods for encrypting/decrypting --- src/leap/soledad/__init__.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index 3d685635..94286370 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -16,6 +16,7 @@ from u1db import ( ) from swiftclient import client +import base64 class OpenStackDatabase(CommonBackend): @@ -148,10 +149,20 @@ class OpenStackDatabase(CommonBackend): class LeapDocument(Document): def get_content_encrypted(self): - raise NotImplementedError(self.get_content_encrypted) + """ + Returns document's json serialization encrypted with user's public key. + """ + # TODO: replace for openpgp encryption with users's pub key. + return base64.b64encode(self.get_json()) def set_content_encrypted(self): - raise NotImplementedError(self.set_content_encrypted) + """ + Set document's content based on encrypted version of json string. + """ + # TODO: + # - replace for openpgp decryption using user's priv key. + # - raise error if unsuccessful. 
+ return self.set_json(base64.b64decode(self.get_json())) class OpenStackSyncTarget(CommonSyncTarget): -- cgit v1.2.3 From dae0dacb59e2b06b681fab88ddefb038b7e16bb6 Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 3 Dec 2012 11:06:23 -0200 Subject: LeapSyncTarget encodes/decodes before/after syncing --- src/leap/soledad/__init__.py | 87 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 2 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index 94286370..5174d818 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -14,6 +14,7 @@ from u1db import ( query_parser, vectorclock, ) +from u1db.remote.http_target import HTTPSyncTarget from swiftclient import client import base64 @@ -148,14 +149,20 @@ class OpenStackDatabase(CommonBackend): class LeapDocument(Document): - def get_content_encrypted(self): + def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, + encrypted_json=None): + super(Document, self).__init__(doc_id, rev, json, has_conflicts) + if encrypted_json: + self.set_encrypted_json(encrypted_json) + + def get_encrypted_json(self): """ Returns document's json serialization encrypted with user's public key. """ # TODO: replace for openpgp encryption with users's pub key. return base64.b64encode(self.get_json()) - def set_content_encrypted(self): + def set_encrypted_json(self): """ Set document's content based on encrypted version of json string. 
""" @@ -165,6 +172,82 @@ class LeapDocument(Document): return self.set_json(base64.b64decode(self.get_json())) +class LeapSyncTarget(HTTPSyncTarget): + + def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): + parts = data.splitlines() # one at a time + if not parts or parts[0] != '[': + raise BrokenSyncStream + data = parts[1:-1] + comma = False + if data: + line, comma = utils.check_and_strip_comma(data[0]) + res = json.loads(line) + if ensure_callback and 'replica_uid' in res: + ensure_callback(res['replica_uid']) + for entry in data[1:]: + if not comma: # missing in between comma + raise BrokenSyncStream + line, comma = utils.check_and_strip_comma(entry) + entry = json.loads(line) + doc = LeapDocument(entry['id'], entry['rev'], + encrypted_json=entry['content']) + return_doc_cb(doc, entry['gen'], entry['trans_id']) + if parts[-1] != ']': + try: + partdic = json.loads(parts[-1]) + except ValueError: + pass + else: + if isinstance(partdic, dict): + self._error(partdic) + raise BrokenSyncStream + if not data or comma: # no entries or bad extra comma + raise BrokenSyncStream + return res + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + self._ensure_connection() + if self._trace_hook: # for tests + self._trace_hook('sync_exchange') + url = '%s/sync-from/%s' % (self._url.path, source_replica_uid) + self._conn.putrequest('POST', url) + self._conn.putheader('content-type', 'application/x-u1db-sync-stream') + for header_name, header_value in self._sign_request('POST', url, {}): + self._conn.putheader(header_name, header_value) + entries = ['['] + size = 1 + + def prepare(**dic): + entry = comma + '\r\n' + json.dumps(dic) + entries.append(entry) + return len(entry) + + comma = '' + size += prepare( + last_known_generation=last_known_generation, + last_known_trans_id=last_known_trans_id, + ensure=ensure_callback is not None) + comma = ',' + for 
doc, gen, trans_id in docs_by_generations: + size += prepare(id=doc.doc_id, rev=doc.rev, + content=doc.get_encrypted_json(), + gen=gen, trans_id=trans_id) + entries.append('\r\n]') + size += len(entries[-1]) + self._conn.putheader('content-length', str(size)) + self._conn.endheaders() + for entry in entries: + self._conn.send(entry) + entries = None + data, _ = self._response() + res = self._parse_sync_stream(data, return_doc_cb, ensure_callback) + data = None + return res['new_generation'], res['new_transaction_id'] + + class OpenStackSyncTarget(CommonSyncTarget): def get_sync_info(self, source_replica_uid): -- cgit v1.2.3 From 7a932811c018bb30b584451d4fe114cf69ab420c Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 3 Dec 2012 11:13:51 -0200 Subject: Split leap infrastructure and openstack backend in different files. --- src/leap/soledad/__init__.py | 256 +----------------------------------------- src/leap/soledad/leap.py | 114 +++++++++++++++++++ src/leap/soledad/openstack.py | 141 +++++++++++++++++++++++ 3 files changed, 257 insertions(+), 254 deletions(-) create mode 100644 src/leap/soledad/leap.py create mode 100644 src/leap/soledad/openstack.py (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index 5174d818..6ba64a61 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -2,257 +2,5 @@ """A U1DB implementation that uses OpenStack Swift as its persistence layer.""" -try: - import simplejson as json -except ImportError: - import json # noqa - -from u1db.backends import CommonBackend, CommonSyncTarget -from u1db import ( - Document, - errors, - query_parser, - vectorclock, - ) -from u1db.remote.http_target import HTTPSyncTarget - -from swiftclient import client -import base64 - - -class OpenStackDatabase(CommonBackend): - """A U1DB implementation that uses OpenStack as its persistence layer.""" - - def __init__(self, auth_url, user, auth_key): - """Create a new OpenStack data container.""" - 
self._auth_url = auth_url - self._user = user - self._auth_key = auth_key - self.set_document_factory(LeapDocument) - self._connection = swiftclient.Connection(self._auth_url, self._user, - self._auth_key) - - #------------------------------------------------------------------------- - # implemented methods from Database - #------------------------------------------------------------------------- - - def set_document_factory(self, factory): - self._factory = factory - - def set_document_size_limit(self, limit): - raise NotImplementedError(self.set_document_size_limit) - - def whats_changed(self, old_generation=0): - raise NotImplementedError(self.whats_changed) - - def get_doc(self, doc_id, include_deleted=False): - raise NotImplementedError(self.get_doc) - - def get_all_docs(self, include_deleted=False): - """Get all documents from the database.""" - raise NotImplementedError(self.get_all_docs) - - def put_doc(self, doc): - raise NotImplementedError(self.put_doc) - - def delete_doc(self, doc): - raise NotImplementedError(self.delete_doc) - - # start of index-related methods: these are not supported by this backend. - - def create_index(self, index_name, *index_expressions): - return False - - def delete_index(self, index_name): - return False - - def list_indexes(self): - return [] - - def get_from_index(self, index_name, *key_values): - return [] - - def get_range_from_index(self, index_name, start_value=None, - end_value=None): - return [] - - def get_index_keys(self, index_name): - return [] - - # end of index-related methods: these are not supported by this backend. 
- - def get_doc_conflicts(self, doc_id): - return [] - - def resolve_doc(self, doc, conflicted_doc_revs): - raise NotImplementedError(self.resolve_doc) - - def get_sync_target(self): - return OpenStackSyncTarget(self) - - def close(self): - raise NotImplementedError(self.close) - - def sync(self, url, creds=None, autocreate=True): - raise NotImplementedError(self.close) - - def _get_replica_gen_and_trans_id(self, other_replica_uid): - raise NotImplementedError(self._get_replica_gen_and_trans_id) - - def _set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - raise NotImplementedError(self._set_replica_gen_and_trans_id) - - #------------------------------------------------------------------------- - # implemented methods from CommonBackend - #------------------------------------------------------------------------- - - def _get_generation(self): - raise NotImplementedError(self._get_generation) - - def _get_generation_info(self): - raise NotImplementedError(self._get_generation_info) - - def _get_doc(self, doc_id, check_for_conflicts=False): - """Get just the document content, without fancy handling.""" - raise NotImplementedError(self._get_doc) - - def _has_conflicts(self, doc_id): - raise NotImplementedError(self._has_conflicts) - - def _get_transaction_log(self): - raise NotImplementedError(self._get_transaction_log) - - def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): - raise NotImplementedError(self._put_and_update_indexes) - - - def _get_trans_id_for_gen(self, generation): - raise NotImplementedError(self._get_trans_id_for_gen) - - #------------------------------------------------------------------------- - # OpenStack specific methods - #------------------------------------------------------------------------- - - def _is_initialized(self, c): - raise NotImplementedError(self._is_initialized) - - def _initialize(self, c): - raise NotImplementedError(self._initialize) - - def _get_auth(self): - 
self._url, self._auth_token = self._connection.get_auth(self._auth_url, - self._user, - self._auth_key) - return self._url, self.auth_token - - -class LeapDocument(Document): - - def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, - encrypted_json=None): - super(Document, self).__init__(doc_id, rev, json, has_conflicts) - if encrypted_json: - self.set_encrypted_json(encrypted_json) - - def get_encrypted_json(self): - """ - Returns document's json serialization encrypted with user's public key. - """ - # TODO: replace for openpgp encryption with users's pub key. - return base64.b64encode(self.get_json()) - - def set_encrypted_json(self): - """ - Set document's content based on encrypted version of json string. - """ - # TODO: - # - replace for openpgp decryption using user's priv key. - # - raise error if unsuccessful. - return self.set_json(base64.b64decode(self.get_json())) - - -class LeapSyncTarget(HTTPSyncTarget): - - def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): - parts = data.splitlines() # one at a time - if not parts or parts[0] != '[': - raise BrokenSyncStream - data = parts[1:-1] - comma = False - if data: - line, comma = utils.check_and_strip_comma(data[0]) - res = json.loads(line) - if ensure_callback and 'replica_uid' in res: - ensure_callback(res['replica_uid']) - for entry in data[1:]: - if not comma: # missing in between comma - raise BrokenSyncStream - line, comma = utils.check_and_strip_comma(entry) - entry = json.loads(line) - doc = LeapDocument(entry['id'], entry['rev'], - encrypted_json=entry['content']) - return_doc_cb(doc, entry['gen'], entry['trans_id']) - if parts[-1] != ']': - try: - partdic = json.loads(parts[-1]) - except ValueError: - pass - else: - if isinstance(partdic, dict): - self._error(partdic) - raise BrokenSyncStream - if not data or comma: # no entries or bad extra comma - raise BrokenSyncStream - return res - - def sync_exchange(self, docs_by_generations, source_replica_uid, - 
last_known_generation, last_known_trans_id, - return_doc_cb, ensure_callback=None): - self._ensure_connection() - if self._trace_hook: # for tests - self._trace_hook('sync_exchange') - url = '%s/sync-from/%s' % (self._url.path, source_replica_uid) - self._conn.putrequest('POST', url) - self._conn.putheader('content-type', 'application/x-u1db-sync-stream') - for header_name, header_value in self._sign_request('POST', url, {}): - self._conn.putheader(header_name, header_value) - entries = ['['] - size = 1 - - def prepare(**dic): - entry = comma + '\r\n' + json.dumps(dic) - entries.append(entry) - return len(entry) - - comma = '' - size += prepare( - last_known_generation=last_known_generation, - last_known_trans_id=last_known_trans_id, - ensure=ensure_callback is not None) - comma = ',' - for doc, gen, trans_id in docs_by_generations: - size += prepare(id=doc.doc_id, rev=doc.rev, - content=doc.get_encrypted_json(), - gen=gen, trans_id=trans_id) - entries.append('\r\n]') - size += len(entries[-1]) - self._conn.putheader('content-length', str(size)) - self._conn.endheaders() - for entry in entries: - self._conn.send(entry) - entries = None - data, _ = self._response() - res = self._parse_sync_stream(data, return_doc_cb, ensure_callback) - data = None - return res['new_generation'], res['new_transaction_id'] - - -class OpenStackSyncTarget(CommonSyncTarget): - - def get_sync_info(self, source_replica_uid): - raise NotImplementedError(self.get_sync_info) - - def record_sync_info(self, source_replica_uid, source_replica_generation, - source_replica_transaction_id): - raise NotImplementedError(self.record_sync_info) +from leap import * +from openstack import * diff --git a/src/leap/soledad/leap.py b/src/leap/soledad/leap.py new file mode 100644 index 00000000..08330618 --- /dev/null +++ b/src/leap/soledad/leap.py @@ -0,0 +1,114 @@ +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db import Document +from u1db.remote.http_target import 
HTTPSyncTarget +import base64 + + +class LeapDocument(Document): + """ + LEAP Documents are standard u1db documents with cabability of returning an + encrypted version of the document json string as well as setting document + content based on an encrypted version of json string. + """ + + def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, + encrypted_json=None): + super(Document, self).__init__(doc_id, rev, json, has_conflicts) + if encrypted_json: + self.set_encrypted_json(encrypted_json) + + def get_encrypted_json(self): + """ + Returns document's json serialization encrypted with user's public key. + """ + # TODO: replace for openpgp encryption with users's pub key. + return base64.b64encode(self.get_json()) + + def set_encrypted_json(self): + """ + Set document's content based on encrypted version of json string. + """ + # TODO: + # - replace for openpgp decryption using user's priv key. + # - raise error if unsuccessful. + return self.set_json(base64.b64decode(self.get_json())) + + +class LeapSyncTarget(HTTPSyncTarget): + + def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): + parts = data.splitlines() # one at a time + if not parts or parts[0] != '[': + raise BrokenSyncStream + data = parts[1:-1] + comma = False + if data: + line, comma = utils.check_and_strip_comma(data[0]) + res = json.loads(line) + if ensure_callback and 'replica_uid' in res: + ensure_callback(res['replica_uid']) + for entry in data[1:]: + if not comma: # missing in between comma + raise BrokenSyncStream + line, comma = utils.check_and_strip_comma(entry) + entry = json.loads(line) + doc = LeapDocument(entry['id'], entry['rev'], + encrypted_json=entry['content']) + return_doc_cb(doc, entry['gen'], entry['trans_id']) + if parts[-1] != ']': + try: + partdic = json.loads(parts[-1]) + except ValueError: + pass + else: + if isinstance(partdic, dict): + self._error(partdic) + raise BrokenSyncStream + if not data or comma: # no entries or bad extra comma 
+ raise BrokenSyncStream + return res + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + self._ensure_connection() + if self._trace_hook: # for tests + self._trace_hook('sync_exchange') + url = '%s/sync-from/%s' % (self._url.path, source_replica_uid) + self._conn.putrequest('POST', url) + self._conn.putheader('content-type', 'application/x-u1db-sync-stream') + for header_name, header_value in self._sign_request('POST', url, {}): + self._conn.putheader(header_name, header_value) + entries = ['['] + size = 1 + + def prepare(**dic): + entry = comma + '\r\n' + json.dumps(dic) + entries.append(entry) + return len(entry) + + comma = '' + size += prepare( + last_known_generation=last_known_generation, + last_known_trans_id=last_known_trans_id, + ensure=ensure_callback is not None) + comma = ',' + for doc, gen, trans_id in docs_by_generations: + size += prepare(id=doc.doc_id, rev=doc.rev, + content=doc.get_encrypted_json(), + gen=gen, trans_id=trans_id) + entries.append('\r\n]') + size += len(entries[-1]) + self._conn.putheader('content-length', str(size)) + self._conn.endheaders() + for entry in entries: + self._conn.send(entry) + entries = None + data, _ = self._response() + res = self._parse_sync_stream(data, return_doc_cb, ensure_callback) + data = None + return res['new_generation'], res['new_transaction_id'] diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py new file mode 100644 index 00000000..514a4c58 --- /dev/null +++ b/src/leap/soledad/openstack.py @@ -0,0 +1,141 @@ +from u1db.backends import CommonBackend +from leap import * +from u1db.remote.http_target import HTTPSyncTarget +from swiftclient import client + + +class OpenStackDatabase(CommonBackend): + """A U1DB implementation that uses OpenStack as its persistence layer.""" + + def __init__(self, auth_url, user, auth_key): + """Create a new OpenStack data container.""" + 
self._auth_url = auth_url + self._user = user + self._auth_key = auth_key + self.set_document_factory(LeapDocument) + self._connection = swiftclient.Connection(self._auth_url, self._user, + self._auth_key) + + #------------------------------------------------------------------------- + # implemented methods from Database + #------------------------------------------------------------------------- + + def set_document_factory(self, factory): + self._factory = factory + + def set_document_size_limit(self, limit): + raise NotImplementedError(self.set_document_size_limit) + + def whats_changed(self, old_generation=0): + raise NotImplementedError(self.whats_changed) + + def get_doc(self, doc_id, include_deleted=False): + raise NotImplementedError(self.get_doc) + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + raise NotImplementedError(self.get_all_docs) + + def put_doc(self, doc): + raise NotImplementedError(self.put_doc) + + def delete_doc(self, doc): + raise NotImplementedError(self.delete_doc) + + # start of index-related methods: these are not supported by this backend. + + def create_index(self, index_name, *index_expressions): + return False + + def delete_index(self, index_name): + return False + + def list_indexes(self): + return [] + + def get_from_index(self, index_name, *key_values): + return [] + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + return [] + + def get_index_keys(self, index_name): + return [] + + # end of index-related methods: these are not supported by this backend. 
+ + def get_doc_conflicts(self, doc_id): + return [] + + def resolve_doc(self, doc, conflicted_doc_revs): + raise NotImplementedError(self.resolve_doc) + + def get_sync_target(self): + return OpenStackSyncTarget(self) + + def close(self): + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + raise NotImplementedError(self.close) + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + raise NotImplementedError(self._get_replica_gen_and_trans_id) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + raise NotImplementedError(self._set_replica_gen_and_trans_id) + + #------------------------------------------------------------------------- + # implemented methods from CommonBackend + #------------------------------------------------------------------------- + + def _get_generation(self): + raise NotImplementedError(self._get_generation) + + def _get_generation_info(self): + raise NotImplementedError(self._get_generation_info) + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling.""" + raise NotImplementedError(self._get_doc) + + def _has_conflicts(self, doc_id): + raise NotImplementedError(self._has_conflicts) + + def _get_transaction_log(self): + raise NotImplementedError(self._get_transaction_log) + + def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): + raise NotImplementedError(self._put_and_update_indexes) + + + def _get_trans_id_for_gen(self, generation): + raise NotImplementedError(self._get_trans_id_for_gen) + + #------------------------------------------------------------------------- + # OpenStack specific methods + #------------------------------------------------------------------------- + + def _is_initialized(self, c): + raise NotImplementedError(self._is_initialized) + + def _initialize(self, c): + raise NotImplementedError(self._initialize) + + def _get_auth(self): + 
self._url, self._auth_token = self._connection.get_auth(self._auth_url, + self._user, + self._auth_key) + return self._url, self.auth_token + + +class OpenStackSyncTarget(HTTPSyncTarget): + + def get_sync_info(self, source_replica_uid): + raise NotImplementedError(self.get_sync_info) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + raise NotImplementedError(self.record_sync_info) -- cgit v1.2.3 From 9c63f2becc0caa1f684852224375b54f828cc42e Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 3 Dec 2012 11:28:21 -0200 Subject: LeapDocument can set and get 'valid' encrypted json --- src/leap/soledad/leap.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/leap.py b/src/leap/soledad/leap.py index 08330618..863e63f8 100644 --- a/src/leap/soledad/leap.py +++ b/src/leap/soledad/leap.py @@ -26,21 +26,27 @@ class LeapDocument(Document): Returns document's json serialization encrypted with user's public key. """ # TODO: replace for openpgp encryption with users's pub key. - return base64.b64encode(self.get_json()) + return json.dumps({'cyphertext':base64.b64encode(self.get_json())}) - def set_encrypted_json(self): + def set_encrypted_json(self, encrypted_json): """ Set document's content based on encrypted version of json string. """ # TODO: # - replace for openpgp decryption using user's priv key. # - raise error if unsuccessful. - return self.set_json(base64.b64decode(self.get_json())) + cyphertext = json.loads(encrypted_json)['cyphertext'] + plaintext = base64.b64decode(cyphertext) + return self.set_json(plaintext) class LeapSyncTarget(HTTPSyncTarget): def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): + """ + Does the same as parent's method but ensures incoming content will be + decrypted. 
+ """ parts = data.splitlines() # one at a time if not parts or parts[0] != '[': raise BrokenSyncStream @@ -75,6 +81,9 @@ class LeapSyncTarget(HTTPSyncTarget): def sync_exchange(self, docs_by_generations, source_replica_uid, last_known_generation, last_known_trans_id, return_doc_cb, ensure_callback=None): + """ + Does the same as parent's method but encrypts content before syncing. + """ self._ensure_connection() if self._trace_hook: # for tests self._trace_hook('sync_exchange') -- cgit v1.2.3 From 2980e61298dc3a17715ce5693470c3d7f3a86497 Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 3 Dec 2012 12:12:01 -0200 Subject: Add python-gnupg simple wrapper --- src/leap/soledad/__init__.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index 6ba64a61..7991f898 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -4,3 +4,40 @@ from leap import * from openstack import * + +import gnupg + +class GPGWrapper(): + """ + This is a temporary class for handling GPG requests, and should be + replaced by a more general class used throughout the project. + """ + + GNUPG_HOME = "~/.config/leap/gnupg" + GNUPG_BINARY = "/usr/bin/gpg" # this has to be changed based on OS + + def __init__(self, gpghome=GNUPG_HOME, gpgbinary=GNUPG_BINARY): + self.gpg = gnupg.GPG(gnupghome=gpghome, gpgbinary=gpgbinary) + + def find_key(self, email): + """ + Find user's key based on their email. + """ + for key in self.gpg.list_keys(): + for uid in key['uids']: + if re.search(email, uid): + return key + raise LookupError("GnuPG public key for %s not found!" 
% email) + + def encrypt(self, data, recipient, sign=None, always_trust=False, + passphrase=None, symmetric=False): + return self.gpg.encrypt(data, recipient, sign=sign, + always_trust=always_trust, + passphrase=passphrase, symmetric=symmetric) + + def decrypt(self, data, always_trust=False, passphrase=None): + return self.gpg.decrypt(data, always_trust=always_trust, + passphrase=passphrase) + + def import_keys(self, data): + return self.gpg.import_keys(data) -- cgit v1.2.3 From cff9a6ed359f3cfc8ec3e7ad94f159acfc5a4fd8 Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 3 Dec 2012 14:13:56 -0200 Subject: Add default key to Leap Document --- src/leap/soledad/leap.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'src/leap') diff --git a/src/leap/soledad/leap.py b/src/leap/soledad/leap.py index 863e63f8..46f95a1a 100644 --- a/src/leap/soledad/leap.py +++ b/src/leap/soledad/leap.py @@ -5,6 +5,7 @@ except ImportError: from u1db import Document from u1db.remote.http_target import HTTPSyncTarget +from u1db.remote.http_database import HTTPDatabase import base64 @@ -16,10 +17,11 @@ class LeapDocument(Document): """ def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, - encrypted_json=None): + encrypted_json=None, default_key=None): super(Document, self).__init__(doc_id, rev, json, has_conflicts) if encrypted_json: self.set_encrypted_json(encrypted_json) + self._default_key = default_key def get_encrypted_json(self): """ -- cgit v1.2.3 From af0e22caca57a04b81f2f74eccdc3599178210c0 Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 3 Dec 2012 14:15:50 -0200 Subject: Add LeapDatabase that uses LeapSyncTarget. 
--- src/leap/soledad/leap.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'src/leap') diff --git a/src/leap/soledad/leap.py b/src/leap/soledad/leap.py index 46f95a1a..e81c6b0c 100644 --- a/src/leap/soledad/leap.py +++ b/src/leap/soledad/leap.py @@ -42,6 +42,15 @@ class LeapDocument(Document): return self.set_json(plaintext) +class LeapDatabase(HTTPDatabase): + """Implement the HTTP remote database API to a Leap server.""" + + def get_sync_target(self): + st = LeapSyncTarget(self._url.geturl()) + st._creds = self._creds + return st + + class LeapSyncTarget(HTTPSyncTarget): def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): -- cgit v1.2.3 From 2812f05c7997766a0527628877a28efd39e0ff1c Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 3 Dec 2012 14:48:44 -0200 Subject: LeapDatabase can statically open an delete dbs. --- src/leap/soledad/leap.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'src/leap') diff --git a/src/leap/soledad/leap.py b/src/leap/soledad/leap.py index e81c6b0c..c9243587 100644 --- a/src/leap/soledad/leap.py +++ b/src/leap/soledad/leap.py @@ -45,6 +45,18 @@ class LeapDocument(Document): class LeapDatabase(HTTPDatabase): """Implement the HTTP remote database API to a Leap server.""" + @staticmethod + def open_database(url, create): + db = LeapDatabase(url) + db.open(create) + return db + + @staticmethod + def delete_database(url): + db = LeapDatabase(url) + db._delete() + db.close() + def get_sync_target(self): st = LeapSyncTarget(self._url.geturl()) st._creds = self._creds -- cgit v1.2.3 From d6196f88f390d1ee8d4a3f26aa4881fe15bcd2e0 Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 3 Dec 2012 14:56:36 -0200 Subject: Fix get_auth parameters. 
--- src/leap/soledad/openstack.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index 514a4c58..9a8a6166 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -15,6 +15,7 @@ class OpenStackDatabase(CommonBackend): self.set_document_factory(LeapDocument) self._connection = swiftclient.Connection(self._auth_url, self._user, self._auth_key) + self._get_auth() #------------------------------------------------------------------------- # implemented methods from Database @@ -125,9 +126,7 @@ class OpenStackDatabase(CommonBackend): raise NotImplementedError(self._initialize) def _get_auth(self): - self._url, self._auth_token = self._connection.get_auth(self._auth_url, - self._user, - self._auth_key) + self._url, self._auth_token = self._connection.get_auth() return self._url, self.auth_token -- cgit v1.2.3 From 22d517e97d81c5630b85dbf55c40f2716d608e96 Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 3 Dec 2012 15:26:37 -0200 Subject: Add method get_doc for OpenStack backend --- src/leap/soledad/openstack.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index 9a8a6166..9bb4fddd 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -7,11 +7,12 @@ from swiftclient import client class OpenStackDatabase(CommonBackend): """A U1DB implementation that uses OpenStack as its persistence layer.""" - def __init__(self, auth_url, user, auth_key): + def __init__(self, auth_url, user, auth_key, container): """Create a new OpenStack data container.""" self._auth_url = auth_url self._user = user self._auth_key = auth_key + self._container = container self.set_document_factory(LeapDocument) self._connection = swiftclient.Connection(self._auth_url, self._user, self._auth_key) @@ -31,7 +32,11 @@ class 
OpenStackDatabase(CommonBackend): raise NotImplementedError(self.whats_changed) def get_doc(self, doc_id, include_deleted=False): - raise NotImplementedError(self.get_doc) + # TODO: support deleted docs? + headers = self._connection.head_object(self._container, doc_id) + rev = headers['x-object-meta-rev'] + response, contents = self._connection.get_object(self._container, doc_id) + return self._factory(doc_id, rev, contents) def get_all_docs(self, include_deleted=False): """Get all documents from the database.""" -- cgit v1.2.3 From 26f2abf21f295700c0f8fdf3bd62667562f01ea3 Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 3 Dec 2012 16:08:49 -0200 Subject: Add put_object for u1db OpenStack backend. --- src/leap/soledad/openstack.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index 9bb4fddd..25f1a404 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -43,7 +43,16 @@ class OpenStackDatabase(CommonBackend): raise NotImplementedError(self.get_all_docs) def put_doc(self, doc): - raise NotImplementedError(self.put_doc) + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + # TODO: check for conflicts? 
+ new_rev = self._allocate_doc_rev(doc.rev) + headers = { 'X-Object-Meta-Rev' : new_rev } + self._connection.put_object(self._container, doc_id, doc.get_json(), + headers=headers) + return new_rev def delete_doc(self, doc): raise NotImplementedError(self.delete_doc) -- cgit v1.2.3 From b4a8d6f10ebcd7d8cf284d7bd18138d074695aff Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 4 Dec 2012 10:42:48 -0200 Subject: Add simple encoding test --- src/leap/soledad/leap.py | 2 +- src/leap/soledad/tests/__init__.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 src/leap/soledad/tests/__init__.py (limited to 'src/leap') diff --git a/src/leap/soledad/leap.py b/src/leap/soledad/leap.py index c9243587..41bcf15a 100644 --- a/src/leap/soledad/leap.py +++ b/src/leap/soledad/leap.py @@ -18,7 +18,7 @@ class LeapDocument(Document): def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, encrypted_json=None, default_key=None): - super(Document, self).__init__(doc_id, rev, json, has_conflicts) + super(LeapDocument, self).__init__(doc_id, rev, json, has_conflicts) if encrypted_json: self.set_encrypted_json(encrypted_json) self._default_key = default_key diff --git a/src/leap/soledad/tests/__init__.py b/src/leap/soledad/tests/__init__.py new file mode 100644 index 00000000..88cddef0 --- /dev/null +++ b/src/leap/soledad/tests/__init__.py @@ -0,0 +1,36 @@ +try: + import simplejson as json +except ImportError: + import json # noqa + +import unittest +import os + +import u1db +from soledad import leap + +class EncryptedSyncTestCase(unittest.TestCase): + + PREFIX = '/var/tmp' + db1_path = "%s/db1.u1db" % PREFIX + db2_path = "%s/db2.u1db" % PREFIX + + def setUp(self): + self.db1 = u1db.open(self.db1_path, create=True, + document_factory=leap.LeapDocument) + self.db2 = u1db.open(self.db2_path, create=True, + document_factory=leap.LeapDocument) + + def tearDown(self): + os.unlink(self.db1_path) + os.unlink(self.db2_path) 
+ + def test_encoding(self): + doc1 = self.db1.create_doc({ 'key' : 'val' }) + enc1 = doc1.get_encrypted_json() + doc2 = leap.LeapDocument(doc_id=doc1.doc_id, json=doc1.get_json()) + enc2 = doc2.get_encrypted_json() + self.assertEqual(enc1, enc2, 'incorrect document encoding') + +if __name__ == '__main__': + unittest.main() -- cgit v1.2.3 From 717aa819bf23209a676d965774f75a71e729bb01 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 4 Dec 2012 10:49:33 -0200 Subject: Add gnupg to README as dependency --- src/leap/soledad/README | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'src/leap') diff --git a/src/leap/soledad/README b/src/leap/soledad/README index dc448374..2ece8145 100644 --- a/src/leap/soledad/README +++ b/src/leap/soledad/README @@ -4,10 +4,12 @@ Soledad -- Synchronization Of Locally Encrypted Data Among Devices Dependencies ------------ -Soledad uses the following python libraries: +Soledad depends on the following python libraries: * u1db 0.1.4 [1] * python-swiftclient 1.1.1 [2] + * python-gnupg 0.3.1 [3] [1] http://pypi.python.org/pypi/u1db/0.1.4 [2] https://launchpad.net/python-swiftclient +[3] http://packages.python.org/python-gnupg/index.html -- cgit v1.2.3 From dc84b200916a5f6677f6b1735fd58a6383b0734e Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 4 Dec 2012 11:22:51 -0200 Subject: Basic encryption/decryption of Document's json content --- src/leap/soledad/leap.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/leap.py b/src/leap/soledad/leap.py index 41bcf15a..b9d253d9 100644 --- a/src/leap/soledad/leap.py +++ b/src/leap/soledad/leap.py @@ -9,6 +9,10 @@ from u1db.remote.http_database import HTTPDatabase import base64 +class NoDefaultKey(Exception): + pass + + class LeapDocument(Document): """ LEAP Documents are standard u1db documents with cabability of returning an @@ -17,28 +21,31 @@ class LeapDocument(Document): """ def __init__(self, 
doc_id=None, rev=None, json='{}', has_conflicts=False, - encrypted_json=None, default_key=None): + encrypted_json=None, default_key=None, gpg_wrapper=None): super(LeapDocument, self).__init__(doc_id, rev, json, has_conflicts) if encrypted_json: self.set_encrypted_json(encrypted_json) + if gpg_wrapper: + self._gpg = gpg_wrapper + else: + self._gpg = GPGWrapper() self._default_key = default_key def get_encrypted_json(self): """ Returns document's json serialization encrypted with user's public key. """ - # TODO: replace for openpgp encryption with users's pub key. - return json.dumps({'cyphertext':base64.b64encode(self.get_json())}) + if self._default_key is None: + raise NoDefaultKey() + cyphertext = self._gpg.encrypt(self.get_json(), self._default_key) + return json.dumps({'cyphertext' : cyphetext}) def set_encrypted_json(self, encrypted_json): """ Set document's content based on encrypted version of json string. """ - # TODO: - # - replace for openpgp decryption using user's priv key. - # - raise error if unsuccessful. 
cyphertext = json.loads(encrypted_json)['cyphertext'] - plaintext = base64.b64decode(cyphertext) + plaintext = self._gpg.decrypt(cyphertext) return self.set_json(plaintext) -- cgit v1.2.3 From a2c076b8d142ea75721dd25a655f72fc9457f222 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 4 Dec 2012 12:06:31 -0200 Subject: Fix json encrypt/decrypt --- src/leap/soledad/leap.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/leap.py b/src/leap/soledad/leap.py index b9d253d9..853906a3 100644 --- a/src/leap/soledad/leap.py +++ b/src/leap/soledad/leap.py @@ -23,12 +23,12 @@ class LeapDocument(Document): def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, encrypted_json=None, default_key=None, gpg_wrapper=None): super(LeapDocument, self).__init__(doc_id, rev, json, has_conflicts) - if encrypted_json: - self.set_encrypted_json(encrypted_json) if gpg_wrapper: self._gpg = gpg_wrapper else: self._gpg = GPGWrapper() + if encrypted_json: + self.set_encrypted_json(encrypted_json) self._default_key = default_key def get_encrypted_json(self): @@ -37,15 +37,18 @@ class LeapDocument(Document): """ if self._default_key is None: raise NoDefaultKey() - cyphertext = self._gpg.encrypt(self.get_json(), self._default_key) - return json.dumps({'cyphertext' : cyphetext}) + cyphertext = self._gpg.encrypt(self.get_json(), + self._default_key, + always_trust = True) + # TODO: always trust? + return json.dumps({'cyphertext' : str(cyphertext)}) def set_encrypted_json(self, encrypted_json): """ Set document's content based on encrypted version of json string. 
""" cyphertext = json.loads(encrypted_json)['cyphertext'] - plaintext = self._gpg.decrypt(cyphertext) + plaintext = str(self._gpg.decrypt(cyphertext)) return self.set_json(plaintext) -- cgit v1.2.3 From 346bac9e40c0003090b6d526e68c6c1d1983fbdf Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 4 Dec 2012 12:06:58 -0200 Subject: Add test for setting/getting encrypted json contents. --- src/leap/soledad/tests/__init__.py | 203 ++++++++++++++++++++++++++++++++++--- 1 file changed, 188 insertions(+), 15 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/tests/__init__.py b/src/leap/soledad/tests/__init__.py index 88cddef0..61eb3f35 100644 --- a/src/leap/soledad/tests/__init__.py +++ b/src/leap/soledad/tests/__init__.py @@ -7,30 +7,203 @@ import unittest import os import u1db -from soledad import leap +from soledad import leap, GPGWrapper + class EncryptedSyncTestCase(unittest.TestCase): - PREFIX = '/var/tmp' - db1_path = "%s/db1.u1db" % PREFIX - db2_path = "%s/db2.u1db" % PREFIX + PREFIX = "/var/tmp" + GNUPG_HOME = "%s/gnupg" % PREFIX + DB1_FILE = "%s/db1.u1db" % PREFIX + DB2_FILE = "%s/db2.u1db" % PREFIX def setUp(self): - self.db1 = u1db.open(self.db1_path, create=True, + self.db1 = u1db.open(self.DB1_FILE, create=True, document_factory=leap.LeapDocument) - self.db2 = u1db.open(self.db2_path, create=True, + self.db2 = u1db.open(self.DB2_FILE, create=True, document_factory=leap.LeapDocument) + self.gpg = GPGWrapper(gpghome=self.GNUPG_HOME) + self.gpg.import_keys(PUBLIC_KEY) + self.gpg.import_keys(PRIVATE_KEY) def tearDown(self): - os.unlink(self.db1_path) - os.unlink(self.db2_path) - - def test_encoding(self): - doc1 = self.db1.create_doc({ 'key' : 'val' }) - enc1 = doc1.get_encrypted_json() - doc2 = leap.LeapDocument(doc_id=doc1.doc_id, json=doc1.get_json()) - enc2 = doc2.get_encrypted_json() - self.assertEqual(enc1, enc2, 'incorrect document encoding') + os.unlink(self.DB1_FILE) + os.unlink(self.DB2_FILE) + + def test_get_set_encrypted(self): + doc1 = 
leap.LeapDocument(gpg_wrapper = self.gpg, + default_key = KEY_FINGERPRINT) + doc1.content = { 'key' : 'val' } + doc2 = leap.LeapDocument(doc_id=doc1.doc_id, + encrypted_json=doc1.get_encrypted_json(), + gpg_wrapper=self.gpg, + default_key = KEY_FINGERPRINT) + res1 = doc1.get_json() + res2 = doc2.get_json() + self.assertEqual(res1, res2, 'incorrect document encoding') + +# Key material for testing +KEY_FINGERPRINT = "E36E738D69173C13D709E44F2F455E2824D18DDF" +PUBLIC_KEY = """ +-----BEGIN PGP PUBLIC KEY BLOCK----- +Version: GnuPG v1.4.10 (GNU/Linux) + +mQINBFC9+dkBEADNRfwV23TWEoGc/x0wWH1P7PlXt8MnC2Z1kKaKKmfnglVrpOiz +iLWoiU58sfZ0L5vHkzXHXCBf6Eiy/EtUIvdiWAn+yASJ1mk5jZTBKO/WMAHD8wTO +zpMsFmWyg3xc4DkmFa9KQ5EVU0o/nqPeyQxNMQN7px5pPwrJtJFmPxnxm+aDkPYx +irDmz/4DeDNqXliazGJKw7efqBdlwTHkl9Akw2gwy178pmsKwHHEMOBOFFvX61AT +huKqHYmlCGSliwbrJppTG7jc1/ls3itrK+CWTg4txREkSpEVmfcASvw/ZqLbjgfs +d/INMwXnR9U81O8+7LT6yw/ca4ppcFoJD7/XJbkRiML6+bJ4Dakiy6i727BzV17g +wI1zqNvm5rAhtALKfACha6YO43aJzairO4II1wxVHvRDHZn2IuKDDephQ3Ii7/vb +hUOf6XCSmchkAcpKXUOvbxm1yfB1LRa64mMc2RcZxf4mW7KQkulBsdV5QG2276lv +U2UUy2IutXcGP5nXC+f6sJJGJeEToKJ57yiO/VWJFjKN8SvP+7AYsQSqINUuEf6H +T5gCPCraGMkTUTPXrREvu7NOohU78q6zZNaL3GW8ai7eSeANSuQ8Vzffx7Wd8Y7i +Pw9sYj0SMFs1UgjbuL6pO5ueHh+qyumbtAq2K0Bci0kqOcU4E9fNtdiovQARAQAB +tBxMZWFwIFRlc3QgS2V5IDxsZWFwQGxlYXAuc2U+iQI3BBMBCAAhBQJQvfnZAhsD +BQsJCAcDBRUKCQgLBRYCAwEAAh4BAheAAAoJEC9FXigk0Y3fT7EQAKH3IuRniOpb +T/DDIgwwjz3oxB/W0DDMyPXowlhSOuM0rgGfntBpBb3boezEXwL86NPQxNGGruF5 +hkmecSiuPSvOmQlqlS95NGQp6hNG0YaKColh+Q5NTspFXCAkFch9oqUje0LdxfSP +QfV9UpeEvGyPmk1I9EJV/YDmZ4+Djge1d7qhVZInz4Rx1NrSyF/Tc2EC0VpjQFsU +Y9Kb2YBBR7ivG6DBc8ty0jJXi7B4WjkFcUEJviQpMF2dCLdonCehYs1PqsN1N7j+ +eFjQd+hqVMJgYuSGKjvuAEfClM6MQw7+FmFwMyLgK/Ew/DttHEDCri77SPSkOGSI +txCzhTg6798f6mJr7WcXmHX1w1Vcib5FfZ8vTDFVhz/XgAgArdhPo9V6/1dgSSiB +KPQ/spsco6u5imdOhckERE0lnAYvVT6KE81TKuhF/b23u7x+Wdew6kK0EQhYA7wy +7LmlaNXc7rMBQJ9Z60CJ4JDtatBWZ0kNrt2VfdDHVdqBTOpl0CraNUjWE5YMDasr 
+K2dF5IX8D3uuYtpZnxqg0KzyLg0tzL0tvOL1C2iudgZUISZNPKbS0z0v+afuAAnx +2pTC3uezbh2Jt8SWTLhll4i0P4Ps5kZ6HQUO56O+/Z1cWovX+mQekYFmERySDR9n +3k1uAwLilJmRmepGmvYbB8HloV8HqwgguQINBFC9+dkBEAC0I/xn1uborMgDvBtf +H0sEhwnXBC849/32zic6udB6/3Efk9nzbSpL3FSOuXITZsZgCHPkKarnoQ2ztMcS +sh1ke1C5gQGms75UVmM/nS+2YI4vY8OX/GC/on2vUyncqdH+bR6xH5hx4NbWpfTs +iQHmz5C6zzS/kuabGdZyKRaZHt23WQ7JX/4zpjqbC99DjHcP9BSk7tJ8wI4bkMYD +uFVQdT9O6HwyKGYwUU4sAQRAj7XCTGvVbT0dpgJwH4RmrEtJoHAx4Whg8mJ710E0 +GCmzf2jqkNuOw76ivgk27Kge+Hw00jmJjQhHY0yVbiaoJwcRrPKzaSjEVNgrpgP3 +lXPRGQArgESsIOTeVVHQ8fhK2YtTeCY9rIiO+L0OX2xo9HK7hfHZZWL6rqymXdyS +fhzh/f6IPyHFWnvj7Brl7DR8heMikygcJqv+ed2yx7iLyCUJ10g12I48+aEj1aLe +dP7lna32iY8/Z0SHQLNH6PXO9SlPcq2aFUgKqE75A/0FMk7CunzU1OWr2ZtTLNO1 +WT/13LfOhhuEq9jTyTosn0WxBjJKq18lnhzCXlaw6EAtbA7CUwsD3CTPR56aAXFK +3I7KXOVAqggrvMe5Tpdg5drfYpI8hZovL5aAgb+7Y5ta10TcJdUhS5K3kFAWe/td +U0cmWUMDP1UMSQ5Jg6JIQVWhSwARAQABiQIfBBgBCAAJBQJQvfnZAhsMAAoJEC9F +Xigk0Y3fRwsP/i0ElYCyxeLpWJTwo1iCLkMKz2yX1lFVa9nT1BVTPOQwr/IAc5OX +NdtbJ14fUsKL5pWgW8OmrXtwZm1y4euI1RPWWubG01ouzwnGzv26UcuHeqC5orZj +cOnKtL40y8VGMm8LoicVkRJH8blPORCnaLjdOtmA3rx/v2EXrJpSa3AhOy0ZSRXk +ZSrK68AVNwamHRoBSYyo0AtaXnkPX4+tmO8X8BPfj125IljubvwZPIW9VWR9UqCE +VPfDR1XKegVb6VStIywF7kmrknM1C5qUY28rdZYWgKorw01hBGV4jTW0cqde3N51 +XT1jnIAa+NoXUM9uQoGYMiwrL7vNsLlyyiW5ayDyV92H/rIuiqhFgbJsHTlsm7I8 +oGheR784BagAA1NIKD1qEO9T6Kz9lzlDaeWS5AUKeXrb7ZJLI1TTCIZx5/DxjLqM +Tt/RFBpVo9geZQrvLUqLAMwdaUvDXC2c6DaCPXTh65oCZj/hqzlJHH+RoTWWzKI+ +BjXxgUWF9EmZUBrg68DSmI+9wuDFsjZ51BcqvJwxyfxtTaWhdoYqH/UQS+D1FP3/ +diZHHlzwVwPICzM9ooNTgbrcDzyxRkIVqsVwBq7EtzcvgYUyX53yG25Giy6YQaQ2 +ZtQ/VymwFL3XdUWV6B/hU4PVAFvO3qlOtdJ6TpE+nEWgcWjCv5g7RjXX +=MuOY +-----END PGP PUBLIC KEY BLOCK----- +""" +PRIVATE_KEY = """ +-----BEGIN PGP PRIVATE KEY BLOCK----- +Version: GnuPG v1.4.10 (GNU/Linux) + +lQcYBFC9+dkBEADNRfwV23TWEoGc/x0wWH1P7PlXt8MnC2Z1kKaKKmfnglVrpOiz +iLWoiU58sfZ0L5vHkzXHXCBf6Eiy/EtUIvdiWAn+yASJ1mk5jZTBKO/WMAHD8wTO +zpMsFmWyg3xc4DkmFa9KQ5EVU0o/nqPeyQxNMQN7px5pPwrJtJFmPxnxm+aDkPYx 
+irDmz/4DeDNqXliazGJKw7efqBdlwTHkl9Akw2gwy178pmsKwHHEMOBOFFvX61AT +huKqHYmlCGSliwbrJppTG7jc1/ls3itrK+CWTg4txREkSpEVmfcASvw/ZqLbjgfs +d/INMwXnR9U81O8+7LT6yw/ca4ppcFoJD7/XJbkRiML6+bJ4Dakiy6i727BzV17g +wI1zqNvm5rAhtALKfACha6YO43aJzairO4II1wxVHvRDHZn2IuKDDephQ3Ii7/vb +hUOf6XCSmchkAcpKXUOvbxm1yfB1LRa64mMc2RcZxf4mW7KQkulBsdV5QG2276lv +U2UUy2IutXcGP5nXC+f6sJJGJeEToKJ57yiO/VWJFjKN8SvP+7AYsQSqINUuEf6H +T5gCPCraGMkTUTPXrREvu7NOohU78q6zZNaL3GW8ai7eSeANSuQ8Vzffx7Wd8Y7i +Pw9sYj0SMFs1UgjbuL6pO5ueHh+qyumbtAq2K0Bci0kqOcU4E9fNtdiovQARAQAB +AA/+JHtlL39G1wsH9R6UEfUQJGXR9MiIiwZoKcnRB2o8+DS+OLjg0JOh8XehtuCs +E/8oGQKtQqa5bEIstX7IZoYmYFiUQi9LOzIblmp2vxOm+HKkxa4JszWci2/ZmC3t +KtaA4adl9XVnshoQ7pijuCMUKB3naBEOAxd8s9d/JeReGIYkJErdrnVfNk5N71Ds +FmH5Ll3XtEDvgBUQP3nkA6QFjpsaB94FHjL3gDwum/cxzj6pCglcvHOzEhfY0Ddb +J967FozQTaf2JW3O+w3LOqtcKWpq87B7+O61tVidQPSSuzPjCtFF0D2LC9R/Hpky +KTMQ6CaKja4MPhjwywd4QPcHGYSqjMpflvJqi+kYIt8psUK/YswWjnr3r4fbuqVY +VhtiHvnBHQjz135lUqWvEz4hM3Xpnxydx7aRlv5NlevK8+YIO5oFbWbGNTWsPZI5 +jpoFBpSsnR1Q5tnvtNHauvoWV+XN2qAOBTG+/nEbDYH6Ak3aaE9jrpTdYh0CotYF +q7csANsDy3JvkAzeU6WnYpsHHaAjqOGyiZGsLej1UcXPFMosE/aUo4WQhiS8Zx2c +zOVKOi/X5vQ2GdNT9Qolz8AriwzsvFR+bxPzyd8V6ALwDsoXvwEYinYBKK8j0OPv +OOihSR6HVsuP9NUZNU9ewiGzte/+/r6pNXHvR7wTQ8EWLcEIAN6Zyrb0bHZTIlxt +VWur/Ht2mIZrBaO50qmM5RD3T5oXzWXi/pjLrIpBMfeZR9DWfwQwjYzwqi7pxtYx +nJvbMuY505rfnMoYxb4J+cpRXV8MS7Dr1vjjLVUC9KiwSbM3gg6emfd2yuA93ihv +Pe3mffzLIiQa4mRE3wtGcioC43nWuV2K2e1KjxeFg07JhrezA/1Cak505ab/tmvP +4YmjR5c44+yL/YcQ3HdFgs4mV+nVbptRXvRcPpolJsgxPccGNdvHhsoR4gwXMS3F +RRPD2z6x8xeN73Q4KH3bm01swQdwFBZbWVfmUGLxvN7leCdfs9+iFJyqHiCIB6Iv +mQfp8F0IAOwSo8JhWN+V1dwML4EkIrM8wUb4yecNLkyR6TpPH/qXx4PxVMC+vy6x +sCtjeHIwKE+9vqnlhd5zOYh7qYXEJtYwdeDDmDbL8oks1LFfd+FyAuZXY33DLwn0 +cRYsr2OEZmaajqUB3NVmj3H4uJBN9+paFHyFSXrH68K1Fk2o3n+RSf2EiX+eICwI +L6rqoF5sSVUghBWdNegV7qfy4anwTQwrIMGjgU5S6PKW0Dr/3iO5z3qQpGPAj5OW +ATqPWkDICLbObPxD5cJlyyNE2wCA9VVc6/1d6w4EVwSq9h3/WTpATEreXXxTGptd +LNiTA1nmakBYNO2Iyo3djhaqBdWjk+EIAKtVEnJH9FAVwWOvaj1RoZMA5DnDMo7e 
+SnhrCXl8AL7Z1WInEaybasTJXn1uQ8xY52Ua4b8cbuEKRKzw/70NesFRoMLYoHTO +dyeszvhoDHberpGRTciVmpMu7Hyi33rM31K9epA4ib6QbbCHnxkWOZB+Bhgj1hJ8 +xb4RBYWiWpAYcg0+DAC3w9gfxQhtUlZPIbmbrBmrVkO2GVGUj8kH6k4UV6kUHEGY +HQWQR0HcbKcXW81ZXCCD0l7ROuEWQtTe5Jw7dJ4/QFuqZnPutXVRNOZqpl6eRShw +7X2/a29VXBpmHA95a88rSQsL+qm7Fb3prqRmuMCtrUZgFz7HLSTuUMR867QcTGVh +cCBUZXN0IEtleSA8bGVhcEBsZWFwLnNlPokCNwQTAQgAIQUCUL352QIbAwULCQgH +AwUVCgkICwUWAgMBAAIeAQIXgAAKCRAvRV4oJNGN30+xEACh9yLkZ4jqW0/wwyIM +MI896MQf1tAwzMj16MJYUjrjNK4Bn57QaQW926HsxF8C/OjT0MTRhq7heYZJnnEo +rj0rzpkJapUveTRkKeoTRtGGigqJYfkOTU7KRVwgJBXIfaKlI3tC3cX0j0H1fVKX +hLxsj5pNSPRCVf2A5mePg44HtXe6oVWSJ8+EcdTa0shf03NhAtFaY0BbFGPSm9mA +QUe4rxugwXPLctIyV4uweFo5BXFBCb4kKTBdnQi3aJwnoWLNT6rDdTe4/nhY0Hfo +alTCYGLkhio77gBHwpTOjEMO/hZhcDMi4CvxMPw7bRxAwq4u+0j0pDhkiLcQs4U4 +Ou/fH+pia+1nF5h19cNVXIm+RX2fL0wxVYc/14AIAK3YT6PVev9XYEkogSj0P7Kb +HKOruYpnToXJBERNJZwGL1U+ihPNUyroRf29t7u8flnXsOpCtBEIWAO8Muy5pWjV +3O6zAUCfWetAieCQ7WrQVmdJDa7dlX3Qx1XagUzqZdAq2jVI1hOWDA2rKytnReSF +/A97rmLaWZ8aoNCs8i4NLcy9Lbzi9QtornYGVCEmTTym0tM9L/mn7gAJ8dqUwt7n +s24dibfElky4ZZeItD+D7OZGeh0FDuejvv2dXFqL1/pkHpGBZhEckg0fZ95NbgMC +4pSZkZnqRpr2GwfB5aFfB6sIIJ0HGARQvfnZARAAtCP8Z9bm6KzIA7wbXx9LBIcJ +1wQvOPf99s4nOrnQev9xH5PZ820qS9xUjrlyE2bGYAhz5Cmq56ENs7THErIdZHtQ +uYEBprO+VFZjP50vtmCOL2PDl/xgv6J9r1Mp3KnR/m0esR+YceDW1qX07IkB5s+Q +us80v5LmmxnWcikWmR7dt1kOyV/+M6Y6mwvfQ4x3D/QUpO7SfMCOG5DGA7hVUHU/ +Tuh8MihmMFFOLAEEQI+1wkxr1W09HaYCcB+EZqxLSaBwMeFoYPJie9dBNBgps39o +6pDbjsO+or4JNuyoHvh8NNI5iY0IR2NMlW4mqCcHEazys2koxFTYK6YD95Vz0RkA +K4BErCDk3lVR0PH4StmLU3gmPayIjvi9Dl9saPRyu4Xx2WVi+q6spl3ckn4c4f3+ +iD8hxVp74+wa5ew0fIXjIpMoHCar/nndsse4i8glCddINdiOPPmhI9Wi3nT+5Z2t +9omPP2dEh0CzR+j1zvUpT3KtmhVICqhO+QP9BTJOwrp81NTlq9mbUyzTtVk/9dy3 +zoYbhKvY08k6LJ9FsQYySqtfJZ4cwl5WsOhALWwOwlMLA9wkz0eemgFxStyOylzl +QKoIK7zHuU6XYOXa32KSPIWaLy+WgIG/u2ObWtdE3CXVIUuSt5BQFnv7XVNHJllD +Az9VDEkOSYOiSEFVoUsAEQEAAQAP/1AagnZQZyzHDEgw4QELAspYHCWLXE5aZInX +wTUJhK31IgIXNn9bJ0hFiSpQR2xeMs9oYtRuPOu0P8oOFMn4/z374fkjZy8QVY3e 
+PlL+3EUeqYtkMwlGNmVw5a/NbNuNfm5Darb7pEfbYd1gPcni4MAYw7R2SG/57GbC +9gucvspHIfOSfBNLBthDzmK8xEKe1yD2eimfc2T7IRYb6hmkYfeds5GsqvGI6mwI +85h4uUHWRc5JOlhVM6yX8hSWx0L60Z3DZLChmc8maWnFXd7C8eQ6P1azJJbW71Ih +7CoK0XW4LE82vlQurSRFgTwfl7wFYszW2bOzCuhHDDtYnwH86Nsu0DC78ZVRnvxn +E8Ke/AJgrdhIOo4UAyR+aZD2+2mKd7/waOUTUrUtTzc7i8N3YXGi/EIaNReBXaq+ +ZNOp24BlFzRp+FCF/pptDW9HjPdiV09x0DgICmeZS4Gq/4vFFIahWctg52NGebT0 +Idxngjj+xDtLaZlLQoOz0n5ByjO/Wi0ANmMv1sMKCHhGvdaSws2/PbMR2r4caj8m +KXpIgdinM/wUzHJ5pZyF2U/qejsRj8Kw8KH/tfX4JCLhiaP/mgeTuWGDHeZQERAT +xPmRFHaLP9/ZhvGNh6okIYtrKjWTLGoXvKLHcrKNisBLSq+P2WeFrlme1vjvJMo/ +jPwLT5o9CADQmcbKZ+QQ1ZM9v99iDZol7SAMZX43JC019sx6GK0u6xouJBcLfeB4 +OXacTgmSYdTa9RM9fbfVpti01tJ84LV2SyL/VJq/enJF4XQPSynT/tFTn1PAor6o +tEAAd8fjKdJ6LnD5wb92SPHfQfXqI84rFEO8rUNIE/1ErT6DYifDzVCbfD2KZdoF +cOSp7TpD77sY1bs74ocBX5ejKtd+aH99D78bJSMM4pSDZsIEwnomkBHTziubPwJb +OwnATy0LmSMAWOw5rKbsh5nfwCiUTM20xp0t5JeXd+wPVWbpWqI2EnkCEN+RJr9i +7dp/ymDQ+Yt5wrsN3NwoyiexPOG91WQVCADdErHsnglVZZq9Z8Wx7KwecGCUurJ2 +H6lKudv5YOxPnAzqZS5HbpZd/nRTMZh2rdXCr5m2YOuewyYjvM757AkmUpM09zJX +MQ1S67/UX2y8/74TcRF97Ncx9HeELs92innBRXoFitnNguvcO6Esx4BTe1OdU6qR +ER3zAmVf22Le9ciXbu24DN4mleOH+OmBx7X2PqJSYW9GAMTsRB081R6EWKH7romQ +waxFrZ4DJzZ9ltyosEJn5F32StyLrFxpcrdLUoEaclZCv2qka7sZvi0EvovDVEBU +e10jOx9AOwf8Gj2ufhquQ6qgVYCzbP+YrodtkFrXRS3IsljIchj1M2ffB/0bfoUs +rtER9pLvYzCjBPg8IfGLw0o754Qbhh/ReplCRTusP/fQMybvCvfxreS3oyEriu/G +GufRomjewZ8EMHDIgUsLcYo2UHZsfF7tcazgxMGmMvazp4r8vpgrvW/8fIN/6Adu +tF+WjWDTvJLFJCe6O+BFJOWrssNrrra1zGtLC1s8s+Wfpe+bGPL5zpHeebGTwH1U +22eqgJArlEKxrfarz7W5+uHZJHSjF/K9ZvunLGD0n9GOPMpji3UO3zeM8IYoWn7E +/EWK1XbjnssNemeeTZ+sDh+qrD7BOi+vCX1IyBxbfqnQfJZvmcPWpruy1UsO+aIC +0GY8Jr3OL69dDQ21jueJAh8EGAEIAAkFAlC9+dkCGwwACgkQL0VeKCTRjd9HCw/+ +LQSVgLLF4ulYlPCjWIIuQwrPbJfWUVVr2dPUFVM85DCv8gBzk5c121snXh9Swovm +laBbw6ate3BmbXLh64jVE9Za5sbTWi7PCcbO/bpRy4d6oLmitmNw6cq0vjTLxUYy +bwuiJxWREkfxuU85EKdouN062YDevH+/YResmlJrcCE7LRlJFeRlKsrrwBU3BqYd +GgFJjKjQC1peeQ9fj62Y7xfwE9+PXbkiWO5u/Bk8hb1VZH1SoIRU98NHVcp6BVvp 
+VK0jLAXuSauSczULmpRjbyt1lhaAqivDTWEEZXiNNbRyp17c3nVdPWOcgBr42hdQ +z25CgZgyLCsvu82wuXLKJblrIPJX3Yf+si6KqEWBsmwdOWybsjygaF5HvzgFqAAD +U0goPWoQ71PorP2XOUNp5ZLkBQp5etvtkksjVNMIhnHn8PGMuoxO39EUGlWj2B5l +Cu8tSosAzB1pS8NcLZzoNoI9dOHrmgJmP+GrOUkcf5GhNZbMoj4GNfGBRYX0SZlQ +GuDrwNKYj73C4MWyNnnUFyq8nDHJ/G1NpaF2hiof9RBL4PUU/f92JkceXPBXA8gL +Mz2ig1OButwPPLFGQhWqxXAGrsS3Ny+BhTJfnfIbbkaLLphBpDZm1D9XKbAUvdd1 +RZXoH+FTg9UAW87eqU610npOkT6cRaBxaMK/mDtGNdc= +=JTFu +-----END PGP PRIVATE KEY BLOCK----- +""" if __name__ == '__main__': unittest.main() -- cgit v1.2.3 From 1b409bb1b6f5d0ae6630875f114f202823be420c Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 4 Dec 2012 16:26:11 -0200 Subject: Correct test error message --- src/leap/soledad/tests/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'src/leap') diff --git a/src/leap/soledad/tests/__init__.py b/src/leap/soledad/tests/__init__.py index 61eb3f35..0d7ae2b4 100644 --- a/src/leap/soledad/tests/__init__.py +++ b/src/leap/soledad/tests/__init__.py @@ -40,7 +40,8 @@ class EncryptedSyncTestCase(unittest.TestCase): default_key = KEY_FINGERPRINT) res1 = doc1.get_json() res2 = doc2.get_json() - self.assertEqual(res1, res2, 'incorrect document encoding') + self.assertEqual(res1, res2, 'incorrect document encryption') + # Key material for testing KEY_FINGERPRINT = "E36E738D69173C13D709E44F2F455E2824D18DDF" -- cgit v1.2.3 From a0410a70d1ad2a3965ed1d8de7929ce70d6ea5fc Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 4 Dec 2012 16:26:29 -0200 Subject: Insightful comment on gpg wrappers. 
--- src/leap/soledad/leap.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/leap.py b/src/leap/soledad/leap.py index 853906a3..2c815632 100644 --- a/src/leap/soledad/leap.py +++ b/src/leap/soledad/leap.py @@ -23,10 +23,11 @@ class LeapDocument(Document): def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, encrypted_json=None, default_key=None, gpg_wrapper=None): super(LeapDocument, self).__init__(doc_id, rev, json, has_conflicts) - if gpg_wrapper: - self._gpg = gpg_wrapper - else: + # we might want to get already initialized wrappers for testing. + if gpg_wrapper is None: self._gpg = GPGWrapper() + else: + self._gpg = gpg_wrapper if encrypted_json: self.set_encrypted_json(encrypted_json) self._default_key = default_key -- cgit v1.2.3 From 1c825cf72575b3e4be81d038e546bbe5fda7ed53 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 4 Dec 2012 20:39:04 -0200 Subject: Add transaction and sync logs as openstack documents. 
--- src/leap/soledad/openstack.py | 114 ++++++++++++++++++++++++++++++++++--- src/leap/soledad/tests/__init__.py | 46 +++++++++++++++ 2 files changed, 151 insertions(+), 9 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index 25f1a404..22a2d067 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -1,5 +1,6 @@ -from u1db.backends import CommonBackend from leap import * +from u1db import errors +from u1db.backends import CommonBackend from u1db.remote.http_target import HTTPSyncTarget from swiftclient import client @@ -96,21 +97,26 @@ class OpenStackDatabase(CommonBackend): raise NotImplementedError(self.close) def _get_replica_gen_and_trans_id(self, other_replica_uid): - raise NotImplementedError(self._get_replica_gen_and_trans_id) + self._update_u1db_data() + return self._sync_log.get_replica_gen_and_trans_id(other_replica_uid) def _set_replica_gen_and_trans_id(self, other_replica_uid, other_generation, other_transaction_id): - raise NotImplementedError(self._set_replica_gen_and_trans_id) + self._update_u1db_data() + return self._sync_log.set_replica_gen_and_trans_id(other_replica_uid, + other_generation, other_transaction_id) #------------------------------------------------------------------------- # implemented methods from CommonBackend #------------------------------------------------------------------------- def _get_generation(self): - raise NotImplementedError(self._get_generation) + self._update_u1db_data() + return self._transaction_log.get_generation() def _get_generation_info(self): - raise NotImplementedError(self._get_generation_info) + self._update_u1db_data() + return self._transaction_log.get_generation_info() def _get_doc(self, doc_id, check_for_conflicts=False): """Get just the document content, without fancy handling.""" @@ -119,15 +125,16 @@ class OpenStackDatabase(CommonBackend): def _has_conflicts(self, doc_id): raise 
NotImplementedError(self._has_conflicts) - def _get_transaction_log(self): - raise NotImplementedError(self._get_transaction_log) - def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): raise NotImplementedError(self._put_and_update_indexes) def _get_trans_id_for_gen(self, generation): - raise NotImplementedError(self._get_trans_id_for_gen) + self._update_u1db_data() + trans_id = self._transaction_log.get_trans_id_for_gen(generation) + if trans_id is None: + raise errors.InvalidGeneration + return trans_id #------------------------------------------------------------------------- # OpenStack specific methods @@ -143,6 +150,11 @@ class OpenStackDatabase(CommonBackend): self._url, self._auth_token = self._connection.get_auth() return self._url, self.auth_token + def _update_u1db_data(self): + data = self.get_doc('u1db_data').content + self._transaction_log = data['transaction_log'] + self._sync_log = data['sync_log'] + class OpenStackSyncTarget(HTTPSyncTarget): @@ -152,3 +164,87 @@ class OpenStackSyncTarget(HTTPSyncTarget): def record_sync_info(self, source_replica_uid, source_replica_generation, source_replica_transaction_id): raise NotImplementedError(self.record_sync_info) + + +class SimpleLog(object): + def __init__(self, log=None): + self._log = [] + if log: + self._log = log + + def append(self, msg): + self._log.append(msg) + + def reduce(self, func, initializer=None): + return reduce(func, self._log, initializer) + + def map(self, func): + return map(func, self._log) + + +class TransactionLog(SimpleLog): + """ + A list of (generation, doc_id, transaction_id) tuples. + """ + + def get_generation(self): + """ + Return the current generation. + """ + gens = self.map(lambda x: x[0]) + if not gens: + return 0 + return max(gens) + + def get_generation_info(self): + """ + Return the current generation and transaction id. 
+ """ + if not self._log: + return(0, '') + info = self.map(lambda x: (x[0], x[2])) + return reduce(lambda x, y: x if (x[0] > y[0]) else y, info) + + def get_trans_id_for_gen(self, gen): + """ + Get the transaction id corresponding to a particular generation. + """ + log = self.reduce(lambda x, y: y if y[0] == gen else x) + if log is None: + return None + return log[2] + +class SyncLog(SimpleLog): + """ + A list of (replica_id, generation, transaction_id) tuples. + """ + + def find_by_replica_uid(self, replica_uid): + if not self._log: + return () + return self.reduce(lambda x, y: y if y[0] == replica_uid else x) + + def get_replica_gen_and_trans_id(self, other_replica_uid): + """ + Return the last known generation and transaction id for the other db + replica. + """ + info = self.find_by_replica_uid(other_replica_uid) + if not info: + return (0, '') + return (info[1], info[2]) + + def set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + """ + Set the last-known generation and transaction id for the other + database replica. 
+ """ + old_log = self._log + self._log = [] + for log in old_log: + if log[0] != other_replica_uid: + self.append(log) + self.append((other_replica_uid, other_generation, + other_transaction_id)) + diff --git a/src/leap/soledad/tests/__init__.py b/src/leap/soledad/tests/__init__.py index 0d7ae2b4..50c99dd4 100644 --- a/src/leap/soledad/tests/__init__.py +++ b/src/leap/soledad/tests/__init__.py @@ -8,6 +8,7 @@ import os import u1db from soledad import leap, GPGWrapper +from soledad.openstack import SimpleLog, TransactionLog, SyncLog class EncryptedSyncTestCase(unittest.TestCase): @@ -43,6 +44,51 @@ class EncryptedSyncTestCase(unittest.TestCase): self.assertEqual(res1, res2, 'incorrect document encryption') +class LogTestCase(unittest.TestCase): + + + def test_transaction_log(self): + data = [ + (2, "doc_3", "tran_3"), + (3, "doc_2", "tran_2"), + (1, "doc_1", "tran_1") + ] + log = TransactionLog(data) + self.assertEqual(log.get_generation(), 3, 'error getting generation') + self.assertEqual(log.get_generation_info(), (3, 'tran_2'), + 'error getting generation info') + self.assertEqual(log.get_trans_id_for_gen(1), 'tran_1', + 'error getting trans_id for gen') + self.assertEqual(log.get_trans_id_for_gen(2), 'tran_3', + 'error getting trans_id for gen') + self.assertEqual(log.get_trans_id_for_gen(3), 'tran_2', + 'error getting trans_id for gen') + + def test_sync_log(self): + data = [ + ("replica_3", 3, "tran_3"), + ("replica_2", 2, "tran_2"), + ("replica_1", 1, "tran_1") + ] + log = SyncLog(data) + # test getting + self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'), + (3, 'tran_3'), 'error getting replica gen and trans id') + self.assertEqual(log.get_replica_gen_and_trans_id('replica_2'), + (2, 'tran_2'), 'error getting replica gen and trans id') + self.assertEqual(log.get_replica_gen_and_trans_id('replica_1'), + (1, 'tran_1'), 'error getting replica gen and trans id') + # test setting + log.set_replica_gen_and_trans_id('replica_1', 2, 'tran_12') + 
self.assertEqual(len(log._log), 3, 'error in log size after setting') + self.assertEqual(log.get_replica_gen_and_trans_id('replica_1'), + (2, 'tran_12'), 'error setting replica gen and trans id') + self.assertEqual(log.get_replica_gen_and_trans_id('replica_2'), + (2, 'tran_2'), 'error setting replica gen and trans id') + self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'), + (3, 'tran_3'), 'error setting replica gen and trans id') + + # Key material for testing KEY_FINGERPRINT = "E36E738D69173C13D709E44F2F455E2824D18DDF" PUBLIC_KEY = """ -- cgit v1.2.3 From adc66753c6a98a1dbe6a41c496e71602cadfd765 Mon Sep 17 00:00:00 2001 From: drebs Date: Wed, 5 Dec 2012 10:09:51 -0200 Subject: Transaction and sync logs are updated locally and remotelly. --- src/leap/soledad/openstack.py | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index 22a2d067..8bbae8d8 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -53,6 +53,10 @@ class OpenStackDatabase(CommonBackend): headers = { 'X-Object-Meta-Rev' : new_rev } self._connection.put_object(self._container, doc_id, doc.get_json(), headers=headers) + new_gen = self._get_generation() + 1 + trans_id = self._allocate_transaction_id() + self._transaction_log.append((new_gen, doc.doc_id, trans_id)) + self._set_u1db_data() return new_rev def delete_doc(self, doc): @@ -97,25 +101,27 @@ class OpenStackDatabase(CommonBackend): raise NotImplementedError(self.close) def _get_replica_gen_and_trans_id(self, other_replica_uid): - self._update_u1db_data() + self._get_u1db_data() return self._sync_log.get_replica_gen_and_trans_id(other_replica_uid) def _set_replica_gen_and_trans_id(self, other_replica_uid, other_generation, other_transaction_id): - self._update_u1db_data() - return self._sync_log.set_replica_gen_and_trans_id(other_replica_uid, - other_generation, 
other_transaction_id) + self._get_u1db_data() + self._sync_log.set_replica_gen_and_trans_id(other_replica_uid, + other_generation, + other_transaction_id) + self._set_u1db_data() #------------------------------------------------------------------------- # implemented methods from CommonBackend #------------------------------------------------------------------------- def _get_generation(self): - self._update_u1db_data() + self._get_u1db_data() return self._transaction_log.get_generation() def _get_generation_info(self): - self._update_u1db_data() + self._get_u1db_data() return self._transaction_log.get_generation_info() def _get_doc(self, doc_id, check_for_conflicts=False): @@ -130,7 +136,7 @@ class OpenStackDatabase(CommonBackend): def _get_trans_id_for_gen(self, generation): - self._update_u1db_data() + self._get_u1db_data() trans_id = self._transaction_log.get_trans_id_for_gen(generation) if trans_id is None: raise errors.InvalidGeneration @@ -150,11 +156,17 @@ class OpenStackDatabase(CommonBackend): self._url, self._auth_token = self._connection.get_auth() return self._url, self.auth_token - def _update_u1db_data(self): + def _get_u1db_data(self): data = self.get_doc('u1db_data').content self._transaction_log = data['transaction_log'] self._sync_log = data['sync_log'] + def _set_u1db_data(self): + doc = self._factory('u1db_data') + doc.content = { 'transaction_log' : self._transaction_log, + 'sync_log' : self._sync_log } + self.put_doc(doc) + class OpenStackSyncTarget(HTTPSyncTarget): @@ -181,6 +193,9 @@ class SimpleLog(object): def map(self, func): return map(func, self._log) + def filter(self, func): + return filter(func, self._log) + class TransactionLog(SimpleLog): """ @@ -214,6 +229,7 @@ class TransactionLog(SimpleLog): return None return log[2] + class SyncLog(SimpleLog): """ A list of (replica_id, generation, transaction_id) tuples. 
@@ -240,11 +256,7 @@ class SyncLog(SimpleLog): Set the last-known generation and transaction id for the other database replica. """ - old_log = self._log - self._log = [] - for log in old_log: - if log[0] != other_replica_uid: - self.append(log) + self._log = self.filter(lambda x: x[0] != other_replica_uid) self.append((other_replica_uid, other_generation, other_transaction_id)) - + -- cgit v1.2.3 From e95726b8a7803dbb23bfca470cf4b665cf8559a4 Mon Sep 17 00:00:00 2001 From: drebs Date: Wed, 5 Dec 2012 15:36:07 -0200 Subject: OpenStack backend can find what's changed. --- src/leap/soledad/openstack.py | 60 ++++++++++++++++++++++++++++++++------ src/leap/soledad/tests/__init__.py | 33 +++++++++++++++++++-- 2 files changed, 81 insertions(+), 12 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index 8bbae8d8..7b7e656f 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -30,7 +30,9 @@ class OpenStackDatabase(CommonBackend): raise NotImplementedError(self.set_document_size_limit) def whats_changed(self, old_generation=0): - raise NotImplementedError(self.whats_changed) + # This method is implemented in TransactionLog because testing is + # easier like this for now, but it can be moved to here afterwards. + return self._transaction_log.whats_changed(old_generation) def get_doc(self, doc_id, include_deleted=False): # TODO: support deleted docs? 
@@ -179,22 +181,29 @@ class OpenStackSyncTarget(HTTPSyncTarget): class SimpleLog(object): - def __init__(self, log=None): + def __init__(self): self._log = [] - if log: - self._log = log + + def _set_log(self, log): + self._log = log + + def _get_log(self): + return self._log + + log = property( + _get_log, _set_log, doc="Log contents.") def append(self, msg): self._log.append(msg) def reduce(self, func, initializer=None): - return reduce(func, self._log, initializer) + return reduce(func, self.log, initializer) def map(self, func): - return map(func, self._log) + return map(func, self.log) def filter(self, func): - return filter(func, self._log) + return filter(func, self.log) class TransactionLog(SimpleLog): @@ -202,6 +211,15 @@ class TransactionLog(SimpleLog): A list of (generation, doc_id, transaction_id) tuples. """ + def _set_log(self, log): + self._log = log + + def _get_log(self): + return sorted(self._log, reverse=True) + + log = property( + _get_log, _set_log, doc="Log contents.") + def get_generation(self): """ Return the current generation. 
@@ -229,6 +247,30 @@ class TransactionLog(SimpleLog): return None return log[2] + def whats_changed(self, old_generation): + results = self.filter(lambda x: x[0] > old_generation) + seen = set() + changes = [] + newest_trans_id = '' + for generation, doc_id, trans_id in results: + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + if changes: + cur_gen = changes[0][1] # max generation + newest_trans_id = changes[0][2] + changes.reverse() + else: + results = self.log + if not results: + cur_gen = 0 + newest_trans_id = '' + else: + cur_gen, _, newest_trans_id = results[0] + + return cur_gen, newest_trans_id, changes + + class SyncLog(SimpleLog): """ @@ -236,7 +278,7 @@ class SyncLog(SimpleLog): """ def find_by_replica_uid(self, replica_uid): - if not self._log: + if not self.log: return () return self.reduce(lambda x, y: y if y[0] == replica_uid else x) @@ -256,7 +298,7 @@ class SyncLog(SimpleLog): Set the last-known generation and transaction id for the other database replica. 
""" - self._log = self.filter(lambda x: x[0] != other_replica_uid) + self.log = self.filter(lambda x: x[0] != other_replica_uid) self.append((other_replica_uid, other_generation, other_transaction_id)) diff --git a/src/leap/soledad/tests/__init__.py b/src/leap/soledad/tests/__init__.py index 50c99dd4..4f63648e 100644 --- a/src/leap/soledad/tests/__init__.py +++ b/src/leap/soledad/tests/__init__.py @@ -8,7 +8,11 @@ import os import u1db from soledad import leap, GPGWrapper -from soledad.openstack import SimpleLog, TransactionLog, SyncLog +from soledad.openstack import ( + SimpleLog, + TransactionLog, + SyncLog, + ) class EncryptedSyncTestCase(unittest.TestCase): @@ -53,7 +57,8 @@ class LogTestCase(unittest.TestCase): (3, "doc_2", "tran_2"), (1, "doc_1", "tran_1") ] - log = TransactionLog(data) + log = TransactionLog() + log.log = data self.assertEqual(log.get_generation(), 3, 'error getting generation') self.assertEqual(log.get_generation_info(), (3, 'tran_2'), 'error getting generation info') @@ -70,7 +75,8 @@ class LogTestCase(unittest.TestCase): ("replica_2", 2, "tran_2"), ("replica_1", 1, "tran_1") ] - log = SyncLog(data) + log = SyncLog() + log.log = data # test getting self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'), (3, 'tran_3'), 'error getting replica gen and trans id') @@ -88,6 +94,27 @@ class LogTestCase(unittest.TestCase): self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'), (3, 'tran_3'), 'error setting replica gen and trans id') + def test_whats_changed(self): + data = [ + (2, "doc_3", "tran_3"), + (3, "doc_2", "tran_2"), + (1, "doc_1", "tran_1") + ] + log = TransactionLog() + log.log = data + self.assertEqual( + log.whats_changed(3), + (3, "tran_2", []), + 'error getting whats changed.') + self.assertEqual( + log.whats_changed(2), + (3, "tran_2", [("doc_2",3,"tran_2")]), + 'error getting whats changed.') + self.assertEqual( + log.whats_changed(1), + (3, "tran_2", [("doc_3",2,"tran_3"),("doc_2",3,"tran_2")]), + 'error getting 
whats changed.') + # Key material for testing KEY_FINGERPRINT = "E36E738D69173C13D709E44F2F455E2824D18DDF" -- cgit v1.2.3 From d1bd08fd5952b8782e6fd59129fc4e2b15777617 Mon Sep 17 00:00:00 2001 From: drebs Date: Wed, 5 Dec 2012 15:44:01 -0200 Subject: Get doc split in two methods. --- src/leap/soledad/openstack.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index 7b7e656f..a7220fa8 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -34,13 +34,24 @@ class OpenStackDatabase(CommonBackend): # easier like this for now, but it can be moved to here afterwards. return self._transaction_log.whats_changed(old_generation) - def get_doc(self, doc_id, include_deleted=False): - # TODO: support deleted docs? - headers = self._connection.head_object(self._container, doc_id) - rev = headers['x-object-meta-rev'] + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling. + + Conflicts do not happen on server side, so there's no need to check + for them. 
+ """ response, contents = self._connection.get_object(self._container, doc_id) + rev = response['x-object-meta-rev'] return self._factory(doc_id, rev, contents) + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + def get_all_docs(self, include_deleted=False): """Get all documents from the database.""" raise NotImplementedError(self.get_all_docs) @@ -126,10 +137,6 @@ class OpenStackDatabase(CommonBackend): self._get_u1db_data() return self._transaction_log.get_generation_info() - def _get_doc(self, doc_id, check_for_conflicts=False): - """Get just the document content, without fancy handling.""" - raise NotImplementedError(self._get_doc) - def _has_conflicts(self, doc_id): raise NotImplementedError(self._has_conflicts) -- cgit v1.2.3 From 7d12e18de3224ba6ab21713a45b3620537f0d0cc Mon Sep 17 00:00:00 2001 From: drebs Date: Wed, 5 Dec 2012 15:46:09 -0200 Subject: What's changed updates u1db data before querying log. --- src/leap/soledad/openstack.py | 1 + 1 file changed, 1 insertion(+) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index a7220fa8..31f59e10 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -30,6 +30,7 @@ class OpenStackDatabase(CommonBackend): raise NotImplementedError(self.set_document_size_limit) def whats_changed(self, old_generation=0): + self._get_u1db_data() # This method is implemented in TransactionLog because testing is # easier like this for now, but it can be moved to here afterwards. 
return self._transaction_log.whats_changed(old_generation) -- cgit v1.2.3 From e60d2f46a5372a0a6d0d468a919eefde40d4807a Mon Sep 17 00:00:00 2001 From: drebs Date: Wed, 5 Dec 2012 15:57:14 -0200 Subject: OpenStack backend can get all docs --- src/leap/soledad/openstack.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index 31f59e10..ebb97ac5 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -55,7 +55,16 @@ class OpenStackDatabase(CommonBackend): def get_all_docs(self, include_deleted=False): """Get all documents from the database.""" - raise NotImplementedError(self.get_all_docs) + generation = self._get_generation() + results = [] + _, doc_ids = self._connection.get_container(self._container, + full_listing=True) + for doc_id in doc_ids: + doc = self._get_doc(doc_id) + if doc.content is None and not include_deleted: + continue + results.append(doc) + return (generation, results) def put_doc(self, doc): if doc.doc_id is None: -- cgit v1.2.3 From 492c3f711927e09acd044db9aa76ce7a05c946c7 Mon Sep 17 00:00:00 2001 From: drebs Date: Wed, 5 Dec 2012 16:48:56 -0200 Subject: OpenStack backend can delete docs. --- src/leap/soledad/openstack.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index ebb97ac5..e7d62751 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -41,9 +41,12 @@ class OpenStackDatabase(CommonBackend): Conflicts do not happen on server side, so there's no need to check for them. 
""" - response, contents = self._connection.get_object(self._container, doc_id) - rev = response['x-object-meta-rev'] - return self._factory(doc_id, rev, contents) + try: + response, contents = self._connection.get_object(self._container, doc_id) + rev = response['x-object-meta-rev'] + return self._factory(doc_id, rev, contents) + except: swiftclient.ClientException + return None def get_doc(self, doc_id, include_deleted=False): doc = self._get_doc(doc_id, check_for_conflicts=True) @@ -83,7 +86,20 @@ class OpenStackDatabase(CommonBackend): return new_rev def delete_doc(self, doc): - raise NotImplementedError(self.delete_doc) + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc is None: + raise errors.DocumentDoesNotExist + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + if old_doc.is_tombstone(): + raise errors.DocumentAlreadyDeleted + if old_doc.has_conflicts: + raise errors.ConflictedDoc() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + doc.make_tombstone() + self._put_doc(olddoc) + return new_rev # start of index-related methods: these are not supported by this backend. -- cgit v1.2.3 From a237e151cc83edc9d3cd2b3ee0df854e7d4b6204 Mon Sep 17 00:00:00 2001 From: drebs Date: Wed, 5 Dec 2012 16:54:07 -0200 Subject: OpenStack backend can sync. 
--- src/leap/soledad/openstack.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index e7d62751..af04465d 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -137,7 +137,10 @@ class OpenStackDatabase(CommonBackend): raise NotImplementedError(self.close) def sync(self, url, creds=None, autocreate=True): - raise NotImplementedError(self.close) + from u1db.sync import Synchronizer + from u1db.remote.http_target import OpenStackSyncTarget + return Synchronizer(self, OpenStackSyncTarget(url, creds=creds)).sync( + autocreate=autocreate) def _get_replica_gen_and_trans_id(self, other_replica_uid): self._get_u1db_data() @@ -164,7 +167,8 @@ class OpenStackDatabase(CommonBackend): return self._transaction_log.get_generation_info() def _has_conflicts(self, doc_id): - raise NotImplementedError(self._has_conflicts) + # Documents never have conflicts on server. + return False def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): raise NotImplementedError(self._put_and_update_indexes) -- cgit v1.2.3 From 1815c078a9bb4c016b354429f7618da664344236 Mon Sep 17 00:00:00 2001 From: drebs Date: Wed, 5 Dec 2012 17:04:46 -0200 Subject: OpenStack backend initialization. 
--- src/leap/soledad/openstack.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index af04465d..07ed071d 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -18,6 +18,7 @@ class OpenStackDatabase(CommonBackend): self._connection = swiftclient.Connection(self._auth_url, self._user, self._auth_key) self._get_auth() + self._ensure_u1db_data() #------------------------------------------------------------------------- # implemented methods from Database @@ -185,11 +186,29 @@ class OpenStackDatabase(CommonBackend): # OpenStack specific methods #------------------------------------------------------------------------- - def _is_initialized(self, c): - raise NotImplementedError(self._is_initialized) + def _ensure_u1db_data(self): + """ + Guarantee that u1db data exists in store. + """ + if self._is_initialized(): + return + self._initialize() - def _initialize(self, c): - raise NotImplementedError(self._initialize) + def _is_initialized(self): + """ + Verify if u1db data exists in store. + """ + if not self._get_doc('u1db_data'): + return False + return True + + def _initialize(self): + """ + Create u1db data object in store. + """ + content = { 'transaction_log' = [], + 'sync_log' = [] } + doc = self.create_doc('u1db_data', content) def _get_auth(self): self._url, self._auth_token = self._connection.get_auth() -- cgit v1.2.3 From 2cf00360bce0193d8fa73194a148c28426172043 Mon Sep 17 00:00:00 2001 From: drebs Date: Wed, 5 Dec 2012 17:10:18 -0200 Subject: Methods for OpenStack SyncTarget and typ0. 
--- src/leap/soledad/openstack.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py index 07ed071d..2c27beb3 100644 --- a/src/leap/soledad/openstack.py +++ b/src/leap/soledad/openstack.py @@ -46,7 +46,7 @@ class OpenStackDatabase(CommonBackend): response, contents = self._connection.get_object(self._container, doc_id) rev = response['x-object-meta-rev'] return self._factory(doc_id, rev, contents) - except: swiftclient.ClientException + except swiftclient.ClientException: return None def get_doc(self, doc_id, include_deleted=False): @@ -229,11 +229,20 @@ class OpenStackDatabase(CommonBackend): class OpenStackSyncTarget(HTTPSyncTarget): def get_sync_info(self, source_replica_uid): - raise NotImplementedError(self.get_sync_info) + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) def record_sync_info(self, source_replica_uid, source_replica_generation, source_replica_transaction_id): - raise NotImplementedError(self.record_sync_info) + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) class SimpleLog(object): -- cgit v1.2.3 From 584696e4dbfc13b793208dc4c5c6cdc224db5a12 Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 6 Dec 2012 11:07:53 -0200 Subject: Remove u1db and swiftclient dirs and refactor. 
--- src/leap/soledad/README | 4 + src/leap/soledad/__init__.py | 4 +- src/leap/soledad/backends/__init__.py | 0 src/leap/soledad/backends/leap.py | 157 ++ src/leap/soledad/backends/openstack.py | 369 ++++ src/leap/soledad/leap.py | 157 -- src/leap/soledad/openstack.py | 369 ---- src/leap/soledad/swiftclient/__init__.py | 5 - src/leap/soledad/swiftclient/client.py | 1056 ----------- src/leap/soledad/swiftclient/openstack/__init__.py | 0 .../swiftclient/openstack/common/__init__.py | 0 .../soledad/swiftclient/openstack/common/setup.py | 342 ---- src/leap/soledad/swiftclient/versioninfo | 1 - src/leap/soledad/tests/__init__.py | 5 +- src/leap/soledad/u1db/__init__.py | 697 ------- src/leap/soledad/u1db/backends/__init__.py | 211 --- src/leap/soledad/u1db/backends/dbschema.sql | 42 - src/leap/soledad/u1db/backends/inmemory.py | 469 ----- src/leap/soledad/u1db/backends/sqlite_backend.py | 926 ---------- src/leap/soledad/u1db/commandline/__init__.py | 15 - src/leap/soledad/u1db/commandline/client.py | 497 ----- src/leap/soledad/u1db/commandline/command.py | 80 - src/leap/soledad/u1db/commandline/serve.py | 34 - src/leap/soledad/u1db/errors.py | 189 -- src/leap/soledad/u1db/query_parser.py | 370 ---- src/leap/soledad/u1db/remote/__init__.py | 15 - .../soledad/u1db/remote/basic_auth_middleware.py | 68 - src/leap/soledad/u1db/remote/http_app.py | 629 ------- src/leap/soledad/u1db/remote/http_client.py | 218 --- src/leap/soledad/u1db/remote/http_database.py | 143 -- src/leap/soledad/u1db/remote/http_errors.py | 46 - src/leap/soledad/u1db/remote/http_target.py | 135 -- src/leap/soledad/u1db/remote/oauth_middleware.py | 89 - src/leap/soledad/u1db/remote/server_state.py | 67 - src/leap/soledad/u1db/remote/ssl_match_hostname.py | 64 - src/leap/soledad/u1db/remote/utils.py | 23 - src/leap/soledad/u1db/sync.py | 304 ---- src/leap/soledad/u1db/tests/__init__.py | 463 ----- src/leap/soledad/u1db/tests/c_backend_wrapper.pyx | 1541 ---------------- 
.../soledad/u1db/tests/commandline/__init__.py | 47 - .../soledad/u1db/tests/commandline/test_client.py | 916 ---------- .../soledad/u1db/tests/commandline/test_command.py | 105 -- .../soledad/u1db/tests/commandline/test_serve.py | 101 -- .../soledad/u1db/tests/test_auth_middleware.py | 309 ---- src/leap/soledad/u1db/tests/test_backends.py | 1895 -------------------- src/leap/soledad/u1db/tests/test_c_backend.py | 634 ------- src/leap/soledad/u1db/tests/test_common_backend.py | 33 - src/leap/soledad/u1db/tests/test_document.py | 148 -- src/leap/soledad/u1db/tests/test_errors.py | 61 - src/leap/soledad/u1db/tests/test_http_app.py | 1133 ------------ src/leap/soledad/u1db/tests/test_http_client.py | 361 ---- src/leap/soledad/u1db/tests/test_http_database.py | 256 --- src/leap/soledad/u1db/tests/test_https.py | 117 -- src/leap/soledad/u1db/tests/test_inmemory.py | 128 -- src/leap/soledad/u1db/tests/test_open.py | 69 - src/leap/soledad/u1db/tests/test_query_parser.py | 443 ----- .../soledad/u1db/tests/test_remote_sync_target.py | 314 ---- src/leap/soledad/u1db/tests/test_remote_utils.py | 36 - src/leap/soledad/u1db/tests/test_server_state.py | 93 - src/leap/soledad/u1db/tests/test_sqlite_backend.py | 493 ----- src/leap/soledad/u1db/tests/test_sync.py | 1285 ------------- .../soledad/u1db/tests/test_test_infrastructure.py | 41 - src/leap/soledad/u1db/tests/test_vectorclock.py | 121 -- src/leap/soledad/u1db/tests/testing-certs/Makefile | 35 - .../soledad/u1db/tests/testing-certs/cacert.pem | 58 - .../soledad/u1db/tests/testing-certs/testing.cert | 61 - .../soledad/u1db/tests/testing-certs/testing.key | 16 - src/leap/soledad/u1db/vectorclock.py | 89 - 68 files changed, 535 insertions(+), 18667 deletions(-) create mode 100644 src/leap/soledad/backends/__init__.py create mode 100644 src/leap/soledad/backends/leap.py create mode 100644 src/leap/soledad/backends/openstack.py delete mode 100644 src/leap/soledad/leap.py delete mode 100644 src/leap/soledad/openstack.py delete 
mode 100644 src/leap/soledad/swiftclient/__init__.py delete mode 100644 src/leap/soledad/swiftclient/client.py delete mode 100644 src/leap/soledad/swiftclient/openstack/__init__.py delete mode 100644 src/leap/soledad/swiftclient/openstack/common/__init__.py delete mode 100644 src/leap/soledad/swiftclient/openstack/common/setup.py delete mode 100644 src/leap/soledad/swiftclient/versioninfo delete mode 100644 src/leap/soledad/u1db/__init__.py delete mode 100644 src/leap/soledad/u1db/backends/__init__.py delete mode 100644 src/leap/soledad/u1db/backends/dbschema.sql delete mode 100644 src/leap/soledad/u1db/backends/inmemory.py delete mode 100644 src/leap/soledad/u1db/backends/sqlite_backend.py delete mode 100644 src/leap/soledad/u1db/commandline/__init__.py delete mode 100644 src/leap/soledad/u1db/commandline/client.py delete mode 100644 src/leap/soledad/u1db/commandline/command.py delete mode 100644 src/leap/soledad/u1db/commandline/serve.py delete mode 100644 src/leap/soledad/u1db/errors.py delete mode 100644 src/leap/soledad/u1db/query_parser.py delete mode 100644 src/leap/soledad/u1db/remote/__init__.py delete mode 100644 src/leap/soledad/u1db/remote/basic_auth_middleware.py delete mode 100644 src/leap/soledad/u1db/remote/http_app.py delete mode 100644 src/leap/soledad/u1db/remote/http_client.py delete mode 100644 src/leap/soledad/u1db/remote/http_database.py delete mode 100644 src/leap/soledad/u1db/remote/http_errors.py delete mode 100644 src/leap/soledad/u1db/remote/http_target.py delete mode 100644 src/leap/soledad/u1db/remote/oauth_middleware.py delete mode 100644 src/leap/soledad/u1db/remote/server_state.py delete mode 100644 src/leap/soledad/u1db/remote/ssl_match_hostname.py delete mode 100644 src/leap/soledad/u1db/remote/utils.py delete mode 100644 src/leap/soledad/u1db/sync.py delete mode 100644 src/leap/soledad/u1db/tests/__init__.py delete mode 100644 src/leap/soledad/u1db/tests/c_backend_wrapper.pyx delete mode 100644 
src/leap/soledad/u1db/tests/commandline/__init__.py delete mode 100644 src/leap/soledad/u1db/tests/commandline/test_client.py delete mode 100644 src/leap/soledad/u1db/tests/commandline/test_command.py delete mode 100644 src/leap/soledad/u1db/tests/commandline/test_serve.py delete mode 100644 src/leap/soledad/u1db/tests/test_auth_middleware.py delete mode 100644 src/leap/soledad/u1db/tests/test_backends.py delete mode 100644 src/leap/soledad/u1db/tests/test_c_backend.py delete mode 100644 src/leap/soledad/u1db/tests/test_common_backend.py delete mode 100644 src/leap/soledad/u1db/tests/test_document.py delete mode 100644 src/leap/soledad/u1db/tests/test_errors.py delete mode 100644 src/leap/soledad/u1db/tests/test_http_app.py delete mode 100644 src/leap/soledad/u1db/tests/test_http_client.py delete mode 100644 src/leap/soledad/u1db/tests/test_http_database.py delete mode 100644 src/leap/soledad/u1db/tests/test_https.py delete mode 100644 src/leap/soledad/u1db/tests/test_inmemory.py delete mode 100644 src/leap/soledad/u1db/tests/test_open.py delete mode 100644 src/leap/soledad/u1db/tests/test_query_parser.py delete mode 100644 src/leap/soledad/u1db/tests/test_remote_sync_target.py delete mode 100644 src/leap/soledad/u1db/tests/test_remote_utils.py delete mode 100644 src/leap/soledad/u1db/tests/test_server_state.py delete mode 100644 src/leap/soledad/u1db/tests/test_sqlite_backend.py delete mode 100644 src/leap/soledad/u1db/tests/test_sync.py delete mode 100644 src/leap/soledad/u1db/tests/test_test_infrastructure.py delete mode 100644 src/leap/soledad/u1db/tests/test_vectorclock.py delete mode 100644 src/leap/soledad/u1db/tests/testing-certs/Makefile delete mode 100644 src/leap/soledad/u1db/tests/testing-certs/cacert.pem delete mode 100644 src/leap/soledad/u1db/tests/testing-certs/testing.cert delete mode 100644 src/leap/soledad/u1db/tests/testing-certs/testing.key delete mode 100644 src/leap/soledad/u1db/vectorclock.py (limited to 'src/leap') diff --git 
a/src/leap/soledad/README b/src/leap/soledad/README index 2ece8145..de524672 100644 --- a/src/leap/soledad/README +++ b/src/leap/soledad/README @@ -1,6 +1,8 @@ Soledad -- Synchronization Of Locally Encrypted Data Among Devices ================================================================== +This software is under development, many parts of the code are still untested. + Dependencies ------------ @@ -13,3 +15,5 @@ Soledad depends on the following python libraries: [1] http://pypi.python.org/pypi/u1db/0.1.4 [2] https://launchpad.net/python-swiftclient [3] http://packages.python.org/python-gnupg/index.html + +Right now, all these libs diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index 7991f898..b7082e53 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -2,8 +2,8 @@ """A U1DB implementation that uses OpenStack Swift as its persistence layer.""" -from leap import * -from openstack import * +from backends.leap import * +from backends.openstack import * import gnupg diff --git a/src/leap/soledad/backends/__init__.py b/src/leap/soledad/backends/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py new file mode 100644 index 00000000..2c815632 --- /dev/null +++ b/src/leap/soledad/backends/leap.py @@ -0,0 +1,157 @@ +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db import Document +from u1db.remote.http_target import HTTPSyncTarget +from u1db.remote.http_database import HTTPDatabase +import base64 + + +class NoDefaultKey(Exception): + pass + + +class LeapDocument(Document): + """ + LEAP Documents are standard u1db documents with cabability of returning an + encrypted version of the document json string as well as setting document + content based on an encrypted version of json string. 
+ """ + + def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, + encrypted_json=None, default_key=None, gpg_wrapper=None): + super(LeapDocument, self).__init__(doc_id, rev, json, has_conflicts) + # we might want to get already initialized wrappers for testing. + if gpg_wrapper is None: + self._gpg = GPGWrapper() + else: + self._gpg = gpg_wrapper + if encrypted_json: + self.set_encrypted_json(encrypted_json) + self._default_key = default_key + + def get_encrypted_json(self): + """ + Returns document's json serialization encrypted with user's public key. + """ + if self._default_key is None: + raise NoDefaultKey() + cyphertext = self._gpg.encrypt(self.get_json(), + self._default_key, + always_trust = True) + # TODO: always trust? + return json.dumps({'cyphertext' : str(cyphertext)}) + + def set_encrypted_json(self, encrypted_json): + """ + Set document's content based on encrypted version of json string. + """ + cyphertext = json.loads(encrypted_json)['cyphertext'] + plaintext = str(self._gpg.decrypt(cyphertext)) + return self.set_json(plaintext) + + +class LeapDatabase(HTTPDatabase): + """Implement the HTTP remote database API to a Leap server.""" + + @staticmethod + def open_database(url, create): + db = LeapDatabase(url) + db.open(create) + return db + + @staticmethod + def delete_database(url): + db = LeapDatabase(url) + db._delete() + db.close() + + def get_sync_target(self): + st = LeapSyncTarget(self._url.geturl()) + st._creds = self._creds + return st + + +class LeapSyncTarget(HTTPSyncTarget): + + def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): + """ + Does the same as parent's method but ensures incoming content will be + decrypted. 
+ """ + parts = data.splitlines() # one at a time + if not parts or parts[0] != '[': + raise BrokenSyncStream + data = parts[1:-1] + comma = False + if data: + line, comma = utils.check_and_strip_comma(data[0]) + res = json.loads(line) + if ensure_callback and 'replica_uid' in res: + ensure_callback(res['replica_uid']) + for entry in data[1:]: + if not comma: # missing in between comma + raise BrokenSyncStream + line, comma = utils.check_and_strip_comma(entry) + entry = json.loads(line) + doc = LeapDocument(entry['id'], entry['rev'], + encrypted_json=entry['content']) + return_doc_cb(doc, entry['gen'], entry['trans_id']) + if parts[-1] != ']': + try: + partdic = json.loads(parts[-1]) + except ValueError: + pass + else: + if isinstance(partdic, dict): + self._error(partdic) + raise BrokenSyncStream + if not data or comma: # no entries or bad extra comma + raise BrokenSyncStream + return res + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + """ + Does the same as parent's method but encrypts content before syncing. 
+ """ + self._ensure_connection() + if self._trace_hook: # for tests + self._trace_hook('sync_exchange') + url = '%s/sync-from/%s' % (self._url.path, source_replica_uid) + self._conn.putrequest('POST', url) + self._conn.putheader('content-type', 'application/x-u1db-sync-stream') + for header_name, header_value in self._sign_request('POST', url, {}): + self._conn.putheader(header_name, header_value) + entries = ['['] + size = 1 + + def prepare(**dic): + entry = comma + '\r\n' + json.dumps(dic) + entries.append(entry) + return len(entry) + + comma = '' + size += prepare( + last_known_generation=last_known_generation, + last_known_trans_id=last_known_trans_id, + ensure=ensure_callback is not None) + comma = ',' + for doc, gen, trans_id in docs_by_generations: + size += prepare(id=doc.doc_id, rev=doc.rev, + content=doc.get_encrypted_json(), + gen=gen, trans_id=trans_id) + entries.append('\r\n]') + size += len(entries[-1]) + self._conn.putheader('content-length', str(size)) + self._conn.endheaders() + for entry in entries: + self._conn.send(entry) + entries = None + data, _ = self._response() + res = self._parse_sync_stream(data, return_doc_cb, ensure_callback) + data = None + return res['new_generation'], res['new_transaction_id'] diff --git a/src/leap/soledad/backends/openstack.py b/src/leap/soledad/backends/openstack.py new file mode 100644 index 00000000..ec4609b4 --- /dev/null +++ b/src/leap/soledad/backends/openstack.py @@ -0,0 +1,369 @@ +from leap import * +from u1db import errors +from u1db.backends import CommonBackend +from u1db.remote.http_target import HTTPSyncTarget +from swiftclient import client + + +class OpenStackDatabase(CommonBackend): + """A U1DB implementation that uses OpenStack as its persistence layer.""" + + def __init__(self, auth_url, user, auth_key, container): + """Create a new OpenStack data container.""" + self._auth_url = auth_url + self._user = user + self._auth_key = auth_key + self._container = container + 
self.set_document_factory(LeapDocument) + self._connection = swiftclient.Connection(self._auth_url, self._user, + self._auth_key) + self._get_auth() + self._ensure_u1db_data() + + #------------------------------------------------------------------------- + # implemented methods from Database + #------------------------------------------------------------------------- + + def set_document_factory(self, factory): + self._factory = factory + + def set_document_size_limit(self, limit): + raise NotImplementedError(self.set_document_size_limit) + + def whats_changed(self, old_generation=0): + self._get_u1db_data() + # This method is implemented in TransactionLog because testing is + # easier like this for now, but it can be moved to here afterwards. + return self._transaction_log.whats_changed(old_generation) + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling. + + Conflicts do not happen on server side, so there's no need to check + for them. 
+ """ + try: + response, contents = self._connection.get_object(self._container, doc_id) + rev = response['x-object-meta-rev'] + return self._factory(doc_id, rev, contents) + except swiftclient.ClientException: + return None + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + generation = self._get_generation() + results = [] + _, doc_ids = self._connection.get_container(self._container, + full_listing=True) + for doc_id in doc_ids: + doc = self._get_doc(doc_id) + if doc.content is None and not include_deleted: + continue + results.append(doc) + return (generation, results) + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + # TODO: check for conflicts? + new_rev = self._allocate_doc_rev(doc.rev) + headers = { 'X-Object-Meta-Rev' : new_rev } + self._connection.put_object(self._container, doc_id, doc.get_json(), + headers=headers) + new_gen = self._get_generation() + 1 + trans_id = self._allocate_transaction_id() + self._transaction_log.append((new_gen, doc.doc_id, trans_id)) + self._set_u1db_data() + return new_rev + + def delete_doc(self, doc): + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc is None: + raise errors.DocumentDoesNotExist + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + if old_doc.is_tombstone(): + raise errors.DocumentAlreadyDeleted + if old_doc.has_conflicts: + raise errors.ConflictedDoc() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + doc.make_tombstone() + self._put_doc(olddoc) + return new_rev + + # start of index-related methods: these are not supported by this backend. 
+ + def create_index(self, index_name, *index_expressions): + return False + + def delete_index(self, index_name): + return False + + def list_indexes(self): + return [] + + def get_from_index(self, index_name, *key_values): + return [] + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + return [] + + def get_index_keys(self, index_name): + return [] + + # end of index-related methods: these are not supported by this backend. + + def get_doc_conflicts(self, doc_id): + return [] + + def resolve_doc(self, doc, conflicted_doc_revs): + raise NotImplementedError(self.resolve_doc) + + def get_sync_target(self): + return OpenStackSyncTarget(self) + + def close(self): + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + from u1db.sync import Synchronizer + from u1db.remote.http_target import OpenStackSyncTarget + return Synchronizer(self, OpenStackSyncTarget(url, creds=creds)).sync( + autocreate=autocreate) + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + self._get_u1db_data() + return self._sync_log.get_replica_gen_and_trans_id(other_replica_uid) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + self._get_u1db_data() + self._sync_log.set_replica_gen_and_trans_id(other_replica_uid, + other_generation, + other_transaction_id) + self._set_u1db_data() + + #------------------------------------------------------------------------- + # implemented methods from CommonBackend + #------------------------------------------------------------------------- + + def _get_generation(self): + self._get_u1db_data() + return self._transaction_log.get_generation() + + def _get_generation_info(self): + self._get_u1db_data() + return self._transaction_log.get_generation_info() + + def _has_conflicts(self, doc_id): + # Documents never have conflicts on server. 
+ return False + + def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): + raise NotImplementedError(self._put_and_update_indexes) + + + def _get_trans_id_for_gen(self, generation): + self._get_u1db_data() + trans_id = self._transaction_log.get_trans_id_for_gen(generation) + if trans_id is None: + raise errors.InvalidGeneration + return trans_id + + #------------------------------------------------------------------------- + # OpenStack specific methods + #------------------------------------------------------------------------- + + def _ensure_u1db_data(self): + """ + Guarantee that u1db data exists in store. + """ + if self._is_initialized(): + return + self._initialize() + + def _is_initialized(self): + """ + Verify if u1db data exists in store. + """ + if not self._get_doc('u1db_data'): + return False + return True + + def _initialize(self): + """ + Create u1db data object in store. + """ + content = { 'transaction_log' : [], + 'sync_log' : [] } + doc = self.create_doc('u1db_data', content) + + def _get_auth(self): + self._url, self._auth_token = self._connection.get_auth() + return self._url, self.auth_token + + def _get_u1db_data(self): + data = self.get_doc('u1db_data').content + self._transaction_log = data['transaction_log'] + self._sync_log = data['sync_log'] + + def _set_u1db_data(self): + doc = self._factory('u1db_data') + doc.content = { 'transaction_log' : self._transaction_log, + 'sync_log' : self._sync_log } + self.put_doc(doc) + + +class OpenStackSyncTarget(HTTPSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + 
self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + +class SimpleLog(object): + def __init__(self): + self._log = [] + + def _set_log(self, log): + self._log = log + + def _get_log(self): + return self._log + + log = property( + _get_log, _set_log, doc="Log contents.") + + def append(self, msg): + self._log.append(msg) + + def reduce(self, func, initializer=None): + return reduce(func, self.log, initializer) + + def map(self, func): + return map(func, self.log) + + def filter(self, func): + return filter(func, self.log) + + +class TransactionLog(SimpleLog): + """ + A list of (generation, doc_id, transaction_id) tuples. + """ + + def _set_log(self, log): + self._log = log + + def _get_log(self): + return sorted(self._log, reverse=True) + + log = property( + _get_log, _set_log, doc="Log contents.") + + def get_generation(self): + """ + Return the current generation. + """ + gens = self.map(lambda x: x[0]) + if not gens: + return 0 + return max(gens) + + def get_generation_info(self): + """ + Return the current generation and transaction id. + """ + if not self._log: + return(0, '') + info = self.map(lambda x: (x[0], x[2])) + return reduce(lambda x, y: x if (x[0] > y[0]) else y, info) + + def get_trans_id_for_gen(self, gen): + """ + Get the transaction id corresponding to a particular generation. 
+ """ + log = self.reduce(lambda x, y: y if y[0] == gen else x) + if log is None: + return None + return log[2] + + def whats_changed(self, old_generation): + results = self.filter(lambda x: x[0] > old_generation) + seen = set() + changes = [] + newest_trans_id = '' + for generation, doc_id, trans_id in results: + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + if changes: + cur_gen = changes[0][1] # max generation + newest_trans_id = changes[0][2] + changes.reverse() + else: + results = self.log + if not results: + cur_gen = 0 + newest_trans_id = '' + else: + cur_gen, _, newest_trans_id = results[0] + + return cur_gen, newest_trans_id, changes + + + +class SyncLog(SimpleLog): + """ + A list of (replica_id, generation, transaction_id) tuples. + """ + + def find_by_replica_uid(self, replica_uid): + if not self.log: + return () + return self.reduce(lambda x, y: y if y[0] == replica_uid else x) + + def get_replica_gen_and_trans_id(self, other_replica_uid): + """ + Return the last known generation and transaction id for the other db + replica. + """ + info = self.find_by_replica_uid(other_replica_uid) + if not info: + return (0, '') + return (info[1], info[2]) + + def set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + """ + Set the last-known generation and transaction id for the other + database replica. 
+ """ + self.log = self.filter(lambda x: x[0] != other_replica_uid) + self.append((other_replica_uid, other_generation, + other_transaction_id)) + diff --git a/src/leap/soledad/leap.py b/src/leap/soledad/leap.py deleted file mode 100644 index 2c815632..00000000 --- a/src/leap/soledad/leap.py +++ /dev/null @@ -1,157 +0,0 @@ -try: - import simplejson as json -except ImportError: - import json # noqa - -from u1db import Document -from u1db.remote.http_target import HTTPSyncTarget -from u1db.remote.http_database import HTTPDatabase -import base64 - - -class NoDefaultKey(Exception): - pass - - -class LeapDocument(Document): - """ - LEAP Documents are standard u1db documents with cabability of returning an - encrypted version of the document json string as well as setting document - content based on an encrypted version of json string. - """ - - def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, - encrypted_json=None, default_key=None, gpg_wrapper=None): - super(LeapDocument, self).__init__(doc_id, rev, json, has_conflicts) - # we might want to get already initialized wrappers for testing. - if gpg_wrapper is None: - self._gpg = GPGWrapper() - else: - self._gpg = gpg_wrapper - if encrypted_json: - self.set_encrypted_json(encrypted_json) - self._default_key = default_key - - def get_encrypted_json(self): - """ - Returns document's json serialization encrypted with user's public key. - """ - if self._default_key is None: - raise NoDefaultKey() - cyphertext = self._gpg.encrypt(self.get_json(), - self._default_key, - always_trust = True) - # TODO: always trust? - return json.dumps({'cyphertext' : str(cyphertext)}) - - def set_encrypted_json(self, encrypted_json): - """ - Set document's content based on encrypted version of json string. 
- """ - cyphertext = json.loads(encrypted_json)['cyphertext'] - plaintext = str(self._gpg.decrypt(cyphertext)) - return self.set_json(plaintext) - - -class LeapDatabase(HTTPDatabase): - """Implement the HTTP remote database API to a Leap server.""" - - @staticmethod - def open_database(url, create): - db = LeapDatabase(url) - db.open(create) - return db - - @staticmethod - def delete_database(url): - db = LeapDatabase(url) - db._delete() - db.close() - - def get_sync_target(self): - st = LeapSyncTarget(self._url.geturl()) - st._creds = self._creds - return st - - -class LeapSyncTarget(HTTPSyncTarget): - - def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): - """ - Does the same as parent's method but ensures incoming content will be - decrypted. - """ - parts = data.splitlines() # one at a time - if not parts or parts[0] != '[': - raise BrokenSyncStream - data = parts[1:-1] - comma = False - if data: - line, comma = utils.check_and_strip_comma(data[0]) - res = json.loads(line) - if ensure_callback and 'replica_uid' in res: - ensure_callback(res['replica_uid']) - for entry in data[1:]: - if not comma: # missing in between comma - raise BrokenSyncStream - line, comma = utils.check_and_strip_comma(entry) - entry = json.loads(line) - doc = LeapDocument(entry['id'], entry['rev'], - encrypted_json=entry['content']) - return_doc_cb(doc, entry['gen'], entry['trans_id']) - if parts[-1] != ']': - try: - partdic = json.loads(parts[-1]) - except ValueError: - pass - else: - if isinstance(partdic, dict): - self._error(partdic) - raise BrokenSyncStream - if not data or comma: # no entries or bad extra comma - raise BrokenSyncStream - return res - - def sync_exchange(self, docs_by_generations, source_replica_uid, - last_known_generation, last_known_trans_id, - return_doc_cb, ensure_callback=None): - """ - Does the same as parent's method but encrypts content before syncing. 
- """ - self._ensure_connection() - if self._trace_hook: # for tests - self._trace_hook('sync_exchange') - url = '%s/sync-from/%s' % (self._url.path, source_replica_uid) - self._conn.putrequest('POST', url) - self._conn.putheader('content-type', 'application/x-u1db-sync-stream') - for header_name, header_value in self._sign_request('POST', url, {}): - self._conn.putheader(header_name, header_value) - entries = ['['] - size = 1 - - def prepare(**dic): - entry = comma + '\r\n' + json.dumps(dic) - entries.append(entry) - return len(entry) - - comma = '' - size += prepare( - last_known_generation=last_known_generation, - last_known_trans_id=last_known_trans_id, - ensure=ensure_callback is not None) - comma = ',' - for doc, gen, trans_id in docs_by_generations: - size += prepare(id=doc.doc_id, rev=doc.rev, - content=doc.get_encrypted_json(), - gen=gen, trans_id=trans_id) - entries.append('\r\n]') - size += len(entries[-1]) - self._conn.putheader('content-length', str(size)) - self._conn.endheaders() - for entry in entries: - self._conn.send(entry) - entries = None - data, _ = self._response() - res = self._parse_sync_stream(data, return_doc_cb, ensure_callback) - data = None - return res['new_generation'], res['new_transaction_id'] diff --git a/src/leap/soledad/openstack.py b/src/leap/soledad/openstack.py deleted file mode 100644 index 2c27beb3..00000000 --- a/src/leap/soledad/openstack.py +++ /dev/null @@ -1,369 +0,0 @@ -from leap import * -from u1db import errors -from u1db.backends import CommonBackend -from u1db.remote.http_target import HTTPSyncTarget -from swiftclient import client - - -class OpenStackDatabase(CommonBackend): - """A U1DB implementation that uses OpenStack as its persistence layer.""" - - def __init__(self, auth_url, user, auth_key, container): - """Create a new OpenStack data container.""" - self._auth_url = auth_url - self._user = user - self._auth_key = auth_key - self._container = container - self.set_document_factory(LeapDocument) - 
self._connection = swiftclient.Connection(self._auth_url, self._user, - self._auth_key) - self._get_auth() - self._ensure_u1db_data() - - #------------------------------------------------------------------------- - # implemented methods from Database - #------------------------------------------------------------------------- - - def set_document_factory(self, factory): - self._factory = factory - - def set_document_size_limit(self, limit): - raise NotImplementedError(self.set_document_size_limit) - - def whats_changed(self, old_generation=0): - self._get_u1db_data() - # This method is implemented in TransactionLog because testing is - # easier like this for now, but it can be moved to here afterwards. - return self._transaction_log.whats_changed(old_generation) - - def _get_doc(self, doc_id, check_for_conflicts=False): - """Get just the document content, without fancy handling. - - Conflicts do not happen on server side, so there's no need to check - for them. - """ - try: - response, contents = self._connection.get_object(self._container, doc_id) - rev = response['x-object-meta-rev'] - return self._factory(doc_id, rev, contents) - except swiftclient.ClientException: - return None - - def get_doc(self, doc_id, include_deleted=False): - doc = self._get_doc(doc_id, check_for_conflicts=True) - if doc is None: - return None - if doc.is_tombstone() and not include_deleted: - return None - return doc - - def get_all_docs(self, include_deleted=False): - """Get all documents from the database.""" - generation = self._get_generation() - results = [] - _, doc_ids = self._connection.get_container(self._container, - full_listing=True) - for doc_id in doc_ids: - doc = self._get_doc(doc_id) - if doc.content is None and not include_deleted: - continue - results.append(doc) - return (generation, results) - - def put_doc(self, doc): - if doc.doc_id is None: - raise errors.InvalidDocId() - self._check_doc_id(doc.doc_id) - self._check_doc_size(doc) - # TODO: check for conflicts? 
- new_rev = self._allocate_doc_rev(doc.rev) - headers = { 'X-Object-Meta-Rev' : new_rev } - self._connection.put_object(self._container, doc_id, doc.get_json(), - headers=headers) - new_gen = self._get_generation() + 1 - trans_id = self._allocate_transaction_id() - self._transaction_log.append((new_gen, doc.doc_id, trans_id)) - self._set_u1db_data() - return new_rev - - def delete_doc(self, doc): - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc is None: - raise errors.DocumentDoesNotExist - if old_doc.rev != doc.rev: - raise errors.RevisionConflict() - if old_doc.is_tombstone(): - raise errors.DocumentAlreadyDeleted - if old_doc.has_conflicts: - raise errors.ConflictedDoc() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - doc.make_tombstone() - self._put_doc(olddoc) - return new_rev - - # start of index-related methods: these are not supported by this backend. - - def create_index(self, index_name, *index_expressions): - return False - - def delete_index(self, index_name): - return False - - def list_indexes(self): - return [] - - def get_from_index(self, index_name, *key_values): - return [] - - def get_range_from_index(self, index_name, start_value=None, - end_value=None): - return [] - - def get_index_keys(self, index_name): - return [] - - # end of index-related methods: these are not supported by this backend. 
- - def get_doc_conflicts(self, doc_id): - return [] - - def resolve_doc(self, doc, conflicted_doc_revs): - raise NotImplementedError(self.resolve_doc) - - def get_sync_target(self): - return OpenStackSyncTarget(self) - - def close(self): - raise NotImplementedError(self.close) - - def sync(self, url, creds=None, autocreate=True): - from u1db.sync import Synchronizer - from u1db.remote.http_target import OpenStackSyncTarget - return Synchronizer(self, OpenStackSyncTarget(url, creds=creds)).sync( - autocreate=autocreate) - - def _get_replica_gen_and_trans_id(self, other_replica_uid): - self._get_u1db_data() - return self._sync_log.get_replica_gen_and_trans_id(other_replica_uid) - - def _set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - self._get_u1db_data() - self._sync_log.set_replica_gen_and_trans_id(other_replica_uid, - other_generation, - other_transaction_id) - self._set_u1db_data() - - #------------------------------------------------------------------------- - # implemented methods from CommonBackend - #------------------------------------------------------------------------- - - def _get_generation(self): - self._get_u1db_data() - return self._transaction_log.get_generation() - - def _get_generation_info(self): - self._get_u1db_data() - return self._transaction_log.get_generation_info() - - def _has_conflicts(self, doc_id): - # Documents never have conflicts on server. 
- return False - - def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): - raise NotImplementedError(self._put_and_update_indexes) - - - def _get_trans_id_for_gen(self, generation): - self._get_u1db_data() - trans_id = self._transaction_log.get_trans_id_for_gen(generation) - if trans_id is None: - raise errors.InvalidGeneration - return trans_id - - #------------------------------------------------------------------------- - # OpenStack specific methods - #------------------------------------------------------------------------- - - def _ensure_u1db_data(self): - """ - Guarantee that u1db data exists in store. - """ - if self._is_initialized(): - return - self._initialize() - - def _is_initialized(self): - """ - Verify if u1db data exists in store. - """ - if not self._get_doc('u1db_data'): - return False - return True - - def _initialize(self): - """ - Create u1db data object in store. - """ - content = { 'transaction_log' = [], - 'sync_log' = [] } - doc = self.create_doc('u1db_data', content) - - def _get_auth(self): - self._url, self._auth_token = self._connection.get_auth() - return self._url, self.auth_token - - def _get_u1db_data(self): - data = self.get_doc('u1db_data').content - self._transaction_log = data['transaction_log'] - self._sync_log = data['sync_log'] - - def _set_u1db_data(self): - doc = self._factory('u1db_data') - doc.content = { 'transaction_log' : self._transaction_log, - 'sync_log' : self._sync_log } - self.put_doc(doc) - - -class OpenStackSyncTarget(HTTPSyncTarget): - - def get_sync_info(self, source_replica_uid): - source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( - source_replica_uid) - my_gen, my_trans_id = self._db._get_generation_info() - return ( - self._db._replica_uid, my_gen, my_trans_id, source_gen, - source_trans_id) - - def record_sync_info(self, source_replica_uid, source_replica_generation, - source_replica_transaction_id): - if self._trace_hook: - self._trace_hook('record_sync_info') - 
self._db._set_replica_gen_and_trans_id( - source_replica_uid, source_replica_generation, - source_replica_transaction_id) - - -class SimpleLog(object): - def __init__(self): - self._log = [] - - def _set_log(self, log): - self._log = log - - def _get_log(self): - return self._log - - log = property( - _get_log, _set_log, doc="Log contents.") - - def append(self, msg): - self._log.append(msg) - - def reduce(self, func, initializer=None): - return reduce(func, self.log, initializer) - - def map(self, func): - return map(func, self.log) - - def filter(self, func): - return filter(func, self.log) - - -class TransactionLog(SimpleLog): - """ - A list of (generation, doc_id, transaction_id) tuples. - """ - - def _set_log(self, log): - self._log = log - - def _get_log(self): - return sorted(self._log, reverse=True) - - log = property( - _get_log, _set_log, doc="Log contents.") - - def get_generation(self): - """ - Return the current generation. - """ - gens = self.map(lambda x: x[0]) - if not gens: - return 0 - return max(gens) - - def get_generation_info(self): - """ - Return the current generation and transaction id. - """ - if not self._log: - return(0, '') - info = self.map(lambda x: (x[0], x[2])) - return reduce(lambda x, y: x if (x[0] > y[0]) else y, info) - - def get_trans_id_for_gen(self, gen): - """ - Get the transaction id corresponding to a particular generation. 
- """ - log = self.reduce(lambda x, y: y if y[0] == gen else x) - if log is None: - return None - return log[2] - - def whats_changed(self, old_generation): - results = self.filter(lambda x: x[0] > old_generation) - seen = set() - changes = [] - newest_trans_id = '' - for generation, doc_id, trans_id in results: - if doc_id not in seen: - changes.append((doc_id, generation, trans_id)) - seen.add(doc_id) - if changes: - cur_gen = changes[0][1] # max generation - newest_trans_id = changes[0][2] - changes.reverse() - else: - results = self.log - if not results: - cur_gen = 0 - newest_trans_id = '' - else: - cur_gen, _, newest_trans_id = results[0] - - return cur_gen, newest_trans_id, changes - - - -class SyncLog(SimpleLog): - """ - A list of (replica_id, generation, transaction_id) tuples. - """ - - def find_by_replica_uid(self, replica_uid): - if not self.log: - return () - return self.reduce(lambda x, y: y if y[0] == replica_uid else x) - - def get_replica_gen_and_trans_id(self, other_replica_uid): - """ - Return the last known generation and transaction id for the other db - replica. - """ - info = self.find_by_replica_uid(other_replica_uid) - if not info: - return (0, '') - return (info[1], info[2]) - - def set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - """ - Set the last-known generation and transaction id for the other - database replica. - """ - self.log = self.filter(lambda x: x[0] != other_replica_uid) - self.append((other_replica_uid, other_generation, - other_transaction_id)) - diff --git a/src/leap/soledad/swiftclient/__init__.py b/src/leap/soledad/swiftclient/__init__.py deleted file mode 100644 index ba0b41a3..00000000 --- a/src/leap/soledad/swiftclient/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# -*- encoding: utf-8 -*- -"""" -OpenStack Swift Python client binding. 
-""" -from client import * diff --git a/src/leap/soledad/swiftclient/client.py b/src/leap/soledad/swiftclient/client.py deleted file mode 100644 index 79e6594f..00000000 --- a/src/leap/soledad/swiftclient/client.py +++ /dev/null @@ -1,1056 +0,0 @@ -# Copyright (c) 2010-2012 OpenStack, LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Cloud Files client library used internally -""" - -import socket -import os -import logging -import httplib - -from urllib import quote as _quote -from urlparse import urlparse, urlunparse, urljoin - -try: - from eventlet.green.httplib import HTTPException, HTTPSConnection -except ImportError: - from httplib import HTTPException, HTTPSConnection - -try: - from eventlet import sleep -except ImportError: - from time import sleep - -try: - from swift.common.bufferedhttp \ - import BufferedHTTPConnection as HTTPConnection -except ImportError: - try: - from eventlet.green.httplib import HTTPConnection - except ImportError: - from httplib import HTTPConnection - -logger = logging.getLogger("swiftclient") - - -def http_log(args, kwargs, resp, body): - if os.environ.get('SWIFTCLIENT_DEBUG', False): - ch = logging.StreamHandler() - logger.setLevel(logging.DEBUG) - logger.addHandler(ch) - elif not logger.isEnabledFor(logging.DEBUG): - return - - string_parts = ['curl -i'] - for element in args: - if element in ('GET', 'POST', 'PUT', 'HEAD'): - string_parts.append(' -X %s' % element) - else: - string_parts.append(' %s' % element) - - if 
'headers' in kwargs: - for element in kwargs['headers']: - header = ' -H "%s: %s"' % (element, kwargs['headers'][element]) - string_parts.append(header) - - logger.debug("REQ: %s\n" % "".join(string_parts)) - if 'raw_body' in kwargs: - logger.debug("REQ BODY (RAW): %s\n" % (kwargs['raw_body'])) - if 'body' in kwargs: - logger.debug("REQ BODY: %s\n" % (kwargs['body'])) - - logger.debug("RESP STATUS: %s\n", resp.status) - if body: - logger.debug("RESP BODY: %s\n", body) - - -def quote(value, safe='/'): - """ - Patched version of urllib.quote that encodes utf8 strings before quoting - """ - if isinstance(value, unicode): - value = value.encode('utf8') - return _quote(value, safe) - - -# look for a real json parser first -try: - # simplejson is popular and pretty good - from simplejson import loads as json_loads - from simplejson import dumps as json_dumps -except ImportError: - # 2.6 will have a json module in the stdlib - from json import loads as json_loads - from json import dumps as json_dumps - - -class ClientException(Exception): - - def __init__(self, msg, http_scheme='', http_host='', http_port='', - http_path='', http_query='', http_status=0, http_reason='', - http_device='', http_response_content=''): - Exception.__init__(self, msg) - self.msg = msg - self.http_scheme = http_scheme - self.http_host = http_host - self.http_port = http_port - self.http_path = http_path - self.http_query = http_query - self.http_status = http_status - self.http_reason = http_reason - self.http_device = http_device - self.http_response_content = http_response_content - - def __str__(self): - a = self.msg - b = '' - if self.http_scheme: - b += '%s://' % self.http_scheme - if self.http_host: - b += self.http_host - if self.http_port: - b += ':%s' % self.http_port - if self.http_path: - b += self.http_path - if self.http_query: - b += '?%s' % self.http_query - if self.http_status: - if b: - b = '%s %s' % (b, self.http_status) - else: - b = str(self.http_status) - if 
self.http_reason: - if b: - b = '%s %s' % (b, self.http_reason) - else: - b = '- %s' % self.http_reason - if self.http_device: - if b: - b = '%s: device %s' % (b, self.http_device) - else: - b = 'device %s' % self.http_device - if self.http_response_content: - if len(self.http_response_content) <= 60: - b += ' %s' % self.http_response_content - else: - b += ' [first 60 chars of response] %s' \ - % self.http_response_content[:60] - return b and '%s: %s' % (a, b) or a - - -def http_connection(url, proxy=None): - """ - Make an HTTPConnection or HTTPSConnection - - :param url: url to connect to - :param proxy: proxy to connect through, if any; None by default; str of the - format 'http://127.0.0.1:8888' to set one - :returns: tuple of (parsed url, connection object) - :raises ClientException: Unable to handle protocol scheme - """ - parsed = urlparse(url) - proxy_parsed = urlparse(proxy) if proxy else None - if parsed.scheme == 'http': - conn = HTTPConnection((proxy_parsed if proxy else parsed).netloc) - elif parsed.scheme == 'https': - conn = HTTPSConnection((proxy_parsed if proxy else parsed).netloc) - else: - raise ClientException('Cannot handle protocol scheme %s for url %s' % - (parsed.scheme, repr(url))) - if proxy: - conn._set_tunnel(parsed.hostname, parsed.port) - return parsed, conn - - -def json_request(method, url, **kwargs): - """Takes a request in json parse it and return in json""" - kwargs.setdefault('headers', {}) - if 'body' in kwargs: - kwargs['headers']['Content-Type'] = 'application/json' - kwargs['body'] = json_dumps(kwargs['body']) - parsed, conn = http_connection(url) - conn.request(method, parsed.path, **kwargs) - resp = conn.getresponse() - body = resp.read() - http_log((url, method,), kwargs, resp, body) - if body: - try: - body = json_loads(body) - except ValueError: - body = None - if not body or resp.status < 200 or resp.status >= 300: - raise ClientException('Auth GET failed', http_scheme=parsed.scheme, - http_host=conn.host, - 
http_port=conn.port, - http_path=parsed.path, - http_status=resp.status, - http_reason=resp.reason) - return resp, body - - -def _get_auth_v1_0(url, user, key, snet): - parsed, conn = http_connection(url) - method = 'GET' - conn.request(method, parsed.path, '', - {'X-Auth-User': user, 'X-Auth-Key': key}) - resp = conn.getresponse() - body = resp.read() - url = resp.getheader('x-storage-url') - http_log((url, method,), {}, resp, body) - - # There is a side-effect on current Rackspace 1.0 server where a - # bad URL would get you that document page and a 200. We error out - # if we don't have a x-storage-url header and if we get a body. - if resp.status < 200 or resp.status >= 300 or (body and not url): - raise ClientException('Auth GET failed', http_scheme=parsed.scheme, - http_host=conn.host, http_port=conn.port, - http_path=parsed.path, http_status=resp.status, - http_reason=resp.reason) - if snet: - parsed = list(urlparse(url)) - # Second item in the list is the netloc - netloc = parsed[1] - parsed[1] = 'snet-' + netloc - url = urlunparse(parsed) - return url, resp.getheader('x-storage-token', - resp.getheader('x-auth-token')) - - -def _get_auth_v2_0(url, user, tenant_name, key, snet): - body = {'auth': - {'passwordCredentials': {'password': key, 'username': user}, - 'tenantName': tenant_name}} - token_url = urljoin(url, "tokens") - resp, body = json_request("POST", token_url, body=body) - token_id = None - try: - url = None - catalogs = body['access']['serviceCatalog'] - for service in catalogs: - if service['type'] == 'object-store': - url = service['endpoints'][0]['publicURL'] - token_id = body['access']['token']['id'] - if not url: - raise ClientException("There is no object-store endpoint " - "on this auth server.") - except(KeyError, IndexError): - raise ClientException("Error while getting answers from auth server") - - if snet: - parsed = list(urlparse(url)) - # Second item in the list is the netloc - parsed[1] = 'snet-' + parsed[1] - url = 
urlunparse(parsed) - - return url, token_id - - -def get_auth(url, user, key, snet=False, tenant_name=None, auth_version="1.0"): - """ - Get authentication/authorization credentials. - - The snet parameter is used for Rackspace's ServiceNet internal network - implementation. In this function, it simply adds *snet-* to the beginning - of the host name for the returned storage URL. With Rackspace Cloud Files, - use of this network path causes no bandwidth charges but requires the - client to be running on Rackspace's ServiceNet network. - - :param url: authentication/authorization URL - :param user: user to authenticate as - :param key: key or password for authorization - :param snet: use SERVICENET internal network (see above), default is False - :param auth_version: OpenStack auth version, default is 1.0 - :param tenant_name: The tenant/account name, required when connecting - to a auth 2.0 system. - :returns: tuple of (storage URL, auth token) - :raises: ClientException: HTTP GET request to auth URL failed - """ - if auth_version in ["1.0", "1"]: - return _get_auth_v1_0(url, user, key, snet) - elif auth_version in ["2.0", "2"]: - if not tenant_name and ':' in user: - (tenant_name, user) = user.split(':') - if not tenant_name: - raise ClientException('No tenant specified') - return _get_auth_v2_0(url, user, tenant_name, key, snet) - else: - raise ClientException('Unknown auth_version %s specified.' - % auth_version) - - -def get_account(url, token, marker=None, limit=None, prefix=None, - http_conn=None, full_listing=False): - """ - Get a listing of containers for the account. 
- - :param url: storage URL - :param token: auth token - :param marker: marker query - :param limit: limit query - :param prefix: prefix query - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :param full_listing: if True, return a full listing, else returns a max - of 10000 listings - :returns: a tuple of (response headers, a list of containers) The response - headers will be a dict and all header names will be lowercase. - :raises ClientException: HTTP GET request failed - """ - if not http_conn: - http_conn = http_connection(url) - if full_listing: - rv = get_account(url, token, marker, limit, prefix, http_conn) - listing = rv[1] - while listing: - marker = listing[-1]['name'] - listing = \ - get_account(url, token, marker, limit, prefix, http_conn)[1] - if listing: - rv[1].extend(listing) - return rv - parsed, conn = http_conn - qs = 'format=json' - if marker: - qs += '&marker=%s' % quote(marker) - if limit: - qs += '&limit=%d' % limit - if prefix: - qs += '&prefix=%s' % quote(prefix) - full_path = '%s?%s' % (parsed.path, qs) - headers = {'X-Auth-Token': token} - conn.request('GET', full_path, '', - headers) - resp = conn.getresponse() - body = resp.read() - http_log(("%s?%s" % (url, qs), 'GET',), {'headers': headers}, resp, body) - - resp_headers = {} - for header, value in resp.getheaders(): - resp_headers[header.lower()] = value - if resp.status < 200 or resp.status >= 300: - raise ClientException('Account GET failed', http_scheme=parsed.scheme, - http_host=conn.host, http_port=conn.port, - http_path=parsed.path, http_query=qs, - http_status=resp.status, http_reason=resp.reason, - http_response_content=body) - if resp.status == 204: - body - return resp_headers, [] - return resp_headers, json_loads(body) - - -def head_account(url, token, http_conn=None): - """ - Get account stats. 
- - :param url: storage URL - :param token: auth token - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :returns: a dict containing the response's headers (all header names will - be lowercase) - :raises ClientException: HTTP HEAD request failed - """ - if http_conn: - parsed, conn = http_conn - else: - parsed, conn = http_connection(url) - method = "HEAD" - headers = {'X-Auth-Token': token} - conn.request(method, parsed.path, '', headers) - resp = conn.getresponse() - body = resp.read() - http_log((url, method,), {'headers': headers}, resp, body) - if resp.status < 200 or resp.status >= 300: - raise ClientException('Account HEAD failed', http_scheme=parsed.scheme, - http_host=conn.host, http_port=conn.port, - http_path=parsed.path, http_status=resp.status, - http_reason=resp.reason, - http_response_content=body) - resp_headers = {} - for header, value in resp.getheaders(): - resp_headers[header.lower()] = value - return resp_headers - - -def post_account(url, token, headers, http_conn=None): - """ - Update an account's metadata. 
- - :param url: storage URL - :param token: auth token - :param headers: additional headers to include in the request - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :raises ClientException: HTTP POST request failed - """ - if http_conn: - parsed, conn = http_conn - else: - parsed, conn = http_connection(url) - method = 'POST' - headers['X-Auth-Token'] = token - conn.request(method, parsed.path, '', headers) - resp = conn.getresponse() - body = resp.read() - http_log((url, method,), {'headers': headers}, resp, body) - if resp.status < 200 or resp.status >= 300: - raise ClientException('Account POST failed', - http_scheme=parsed.scheme, - http_host=conn.host, - http_port=conn.port, - http_path=parsed.path, - http_status=resp.status, - http_reason=resp.reason, - http_response_content=body) - - -def get_container(url, token, container, marker=None, limit=None, - prefix=None, delimiter=None, http_conn=None, - full_listing=False): - """ - Get a listing of objects for the container. - - :param url: storage URL - :param token: auth token - :param container: container name to get a listing for - :param marker: marker query - :param limit: limit query - :param prefix: prefix query - :param delimeter: string to delimit the queries on - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :param full_listing: if True, return a full listing, else returns a max - of 10000 listings - :returns: a tuple of (response headers, a list of objects) The response - headers will be a dict and all header names will be lowercase. 
- :raises ClientException: HTTP GET request failed - """ - if not http_conn: - http_conn = http_connection(url) - if full_listing: - rv = get_container(url, token, container, marker, limit, prefix, - delimiter, http_conn) - listing = rv[1] - while listing: - if not delimiter: - marker = listing[-1]['name'] - else: - marker = listing[-1].get('name', listing[-1].get('subdir')) - listing = get_container(url, token, container, marker, limit, - prefix, delimiter, http_conn)[1] - if listing: - rv[1].extend(listing) - return rv - parsed, conn = http_conn - path = '%s/%s' % (parsed.path, quote(container)) - qs = 'format=json' - if marker: - qs += '&marker=%s' % quote(marker) - if limit: - qs += '&limit=%d' % limit - if prefix: - qs += '&prefix=%s' % quote(prefix) - if delimiter: - qs += '&delimiter=%s' % quote(delimiter) - headers = {'X-Auth-Token': token} - method = 'GET' - conn.request(method, '%s?%s' % (path, qs), '', headers) - resp = conn.getresponse() - body = resp.read() - http_log(('%s?%s' % (url, qs), method,), {'headers': headers}, resp, body) - - if resp.status < 200 or resp.status >= 300: - raise ClientException('Container GET failed', - http_scheme=parsed.scheme, http_host=conn.host, - http_port=conn.port, http_path=path, - http_query=qs, http_status=resp.status, - http_reason=resp.reason, - http_response_content=body) - resp_headers = {} - for header, value in resp.getheaders(): - resp_headers[header.lower()] = value - if resp.status == 204: - return resp_headers, [] - return resp_headers, json_loads(body) - - -def head_container(url, token, container, http_conn=None, headers=None): - """ - Get container stats. 
- - :param url: storage URL - :param token: auth token - :param container: container name to get stats for - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :returns: a dict containing the response's headers (all header names will - be lowercase) - :raises ClientException: HTTP HEAD request failed - """ - if http_conn: - parsed, conn = http_conn - else: - parsed, conn = http_connection(url) - path = '%s/%s' % (parsed.path, quote(container)) - method = 'HEAD' - req_headers = {'X-Auth-Token': token} - if headers: - req_headers.update(headers) - conn.request(method, path, '', req_headers) - resp = conn.getresponse() - body = resp.read() - http_log(('%s?%s' % (url, path), method,), - {'headers': req_headers}, resp, body) - - if resp.status < 200 or resp.status >= 300: - raise ClientException('Container HEAD failed', - http_scheme=parsed.scheme, http_host=conn.host, - http_port=conn.port, http_path=path, - http_status=resp.status, http_reason=resp.reason, - http_response_content=body) - resp_headers = {} - for header, value in resp.getheaders(): - resp_headers[header.lower()] = value - return resp_headers - - -def put_container(url, token, container, headers=None, http_conn=None): - """ - Create a container - - :param url: storage URL - :param token: auth token - :param container: container name to create - :param headers: additional headers to include in the request - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :raises ClientException: HTTP PUT request failed - """ - if http_conn: - parsed, conn = http_conn - else: - parsed, conn = http_connection(url) - path = '%s/%s' % (parsed.path, quote(container)) - method = 'PUT' - if not headers: - headers = {} - headers['X-Auth-Token'] = token - conn.request(method, path, '', headers) - resp = conn.getresponse() - body = resp.read() - http_log(('%s?%s' % (url, path), method,), - {'headers': headers}, resp, body) - if resp.status < 200 or 
resp.status >= 300: - raise ClientException('Container PUT failed', - http_scheme=parsed.scheme, http_host=conn.host, - http_port=conn.port, http_path=path, - http_status=resp.status, http_reason=resp.reason, - http_response_content=body) - - -def post_container(url, token, container, headers, http_conn=None): - """ - Update a container's metadata. - - :param url: storage URL - :param token: auth token - :param container: container name to update - :param headers: additional headers to include in the request - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :raises ClientException: HTTP POST request failed - """ - if http_conn: - parsed, conn = http_conn - else: - parsed, conn = http_connection(url) - path = '%s/%s' % (parsed.path, quote(container)) - method = 'POST' - headers['X-Auth-Token'] = token - conn.request(method, path, '', headers) - resp = conn.getresponse() - body = resp.read() - http_log(('%s?%s' % (url, path), method,), - {'headers': headers}, resp, body) - if resp.status < 200 or resp.status >= 300: - raise ClientException('Container POST failed', - http_scheme=parsed.scheme, http_host=conn.host, - http_port=conn.port, http_path=path, - http_status=resp.status, http_reason=resp.reason, - http_response_content=body) - - -def delete_container(url, token, container, http_conn=None): - """ - Delete a container - - :param url: storage URL - :param token: auth token - :param container: container name to delete - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :raises ClientException: HTTP DELETE request failed - """ - if http_conn: - parsed, conn = http_conn - else: - parsed, conn = http_connection(url) - path = '%s/%s' % (parsed.path, quote(container)) - headers = {'X-Auth-Token': token} - method = 'DELETE' - conn.request(method, path, '', headers) - resp = conn.getresponse() - body = resp.read() - http_log(('%s?%s' % (url, path), method,), - {'headers': headers}, resp, body) - 
if resp.status < 200 or resp.status >= 300: - raise ClientException('Container DELETE failed', - http_scheme=parsed.scheme, http_host=conn.host, - http_port=conn.port, http_path=path, - http_status=resp.status, http_reason=resp.reason, - http_response_content=body) - - -def get_object(url, token, container, name, http_conn=None, - resp_chunk_size=None): - """ - Get an object - - :param url: storage URL - :param token: auth token - :param container: container name that the object is in - :param name: object name to get - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :param resp_chunk_size: if defined, chunk size of data to read. NOTE: If - you specify a resp_chunk_size you must fully read - the object's contents before making another - request. - :returns: a tuple of (response headers, the object's contents) The response - headers will be a dict and all header names will be lowercase. - :raises ClientException: HTTP GET request failed - """ - if http_conn: - parsed, conn = http_conn - else: - parsed, conn = http_connection(url) - path = '%s/%s/%s' % (parsed.path, quote(container), quote(name)) - method = 'GET' - headers = {'X-Auth-Token': token} - conn.request(method, path, '', headers) - resp = conn.getresponse() - if resp.status < 200 or resp.status >= 300: - body = resp.read() - http_log(('%s?%s' % (url, path), 'POST',), - {'headers': headers}, resp, body) - raise ClientException('Object GET failed', http_scheme=parsed.scheme, - http_host=conn.host, http_port=conn.port, - http_path=path, http_status=resp.status, - http_reason=resp.reason, - http_response_content=body) - if resp_chunk_size: - - def _object_body(): - buf = resp.read(resp_chunk_size) - while buf: - yield buf - buf = resp.read(resp_chunk_size) - object_body = _object_body() - else: - object_body = resp.read() - resp_headers = {} - for header, value in resp.getheaders(): - resp_headers[header.lower()] = value - http_log(('%s?%s' % (url, path), 'POST',), - 
{'headers': headers}, resp, object_body) - return resp_headers, object_body - - -def head_object(url, token, container, name, http_conn=None): - """ - Get object info - - :param url: storage URL - :param token: auth token - :param container: container name that the object is in - :param name: object name to get info for - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :returns: a dict containing the response's headers (all header names will - be lowercase) - :raises ClientException: HTTP HEAD request failed - """ - if http_conn: - parsed, conn = http_conn - else: - parsed, conn = http_connection(url) - path = '%s/%s/%s' % (parsed.path, quote(container), quote(name)) - method = 'HEAD' - headers = {'X-Auth-Token': token} - conn.request(method, path, '', headers) - resp = conn.getresponse() - body = resp.read() - http_log(('%s?%s' % (url, path), 'POST',), - {'headers': headers}, resp, body) - if resp.status < 200 or resp.status >= 300: - raise ClientException('Object HEAD failed', http_scheme=parsed.scheme, - http_host=conn.host, http_port=conn.port, - http_path=path, http_status=resp.status, - http_reason=resp.reason, - http_response_content=body) - resp_headers = {} - for header, value in resp.getheaders(): - resp_headers[header.lower()] = value - return resp_headers - - -def put_object(url, token=None, container=None, name=None, contents=None, - content_length=None, etag=None, chunk_size=65536, - content_type=None, headers=None, http_conn=None, proxy=None): - """ - Put an object - - :param url: storage URL - :param token: auth token; if None, no token will be sent - :param container: container name that the object is in; if None, the - container name is expected to be part of the url - :param name: object name to put; if None, the object name is expected to be - part of the url - :param contents: a string or a file like object to read object data from; - if None, a zero-byte put will be done - :param content_length: value to 
send as content-length header; also limits - the amount read from contents; if None, it will be - computed via the contents or chunked transfer - encoding will be used - :param etag: etag of contents; if None, no etag will be sent - :param chunk_size: chunk size of data to write; default 65536 - :param content_type: value to send as content-type header; if None, no - content-type will be set (remote end will likely try - to auto-detect it) - :param headers: additional headers to include in the request, if any - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :param proxy: proxy to connect through, if any; None by default; str of the - format 'http://127.0.0.1:8888' to set one - :returns: etag from server response - :raises ClientException: HTTP PUT request failed - """ - if http_conn: - parsed, conn = http_conn - else: - parsed, conn = http_connection(url, proxy=proxy) - path = parsed.path - if container: - path = '%s/%s' % (path.rstrip('/'), quote(container)) - if name: - path = '%s/%s' % (path.rstrip('/'), quote(name)) - if headers: - headers = dict(headers) - else: - headers = {} - if token: - headers['X-Auth-Token'] = token - if etag: - headers['ETag'] = etag.strip('"') - if content_length is not None: - headers['Content-Length'] = str(content_length) - else: - for n, v in headers.iteritems(): - if n.lower() == 'content-length': - content_length = int(v) - if content_type is not None: - headers['Content-Type'] = content_type - if not contents: - headers['Content-Length'] = '0' - if hasattr(contents, 'read'): - conn.putrequest('PUT', path) - for header, value in headers.iteritems(): - conn.putheader(header, value) - if content_length is None: - conn.putheader('Transfer-Encoding', 'chunked') - conn.endheaders() - chunk = contents.read(chunk_size) - while chunk: - conn.send('%x\r\n%s\r\n' % (len(chunk), chunk)) - chunk = contents.read(chunk_size) - conn.send('0\r\n\r\n') - else: - conn.endheaders() - left = content_length - 
while left > 0: - size = chunk_size - if size > left: - size = left - chunk = contents.read(size) - conn.send(chunk) - left -= len(chunk) - else: - conn.request('PUT', path, contents, headers) - resp = conn.getresponse() - body = resp.read() - headers = {'X-Auth-Token': token} - http_log(('%s?%s' % (url, path), 'PUT',), - {'headers': headers}, resp, body) - if resp.status < 200 or resp.status >= 300: - raise ClientException('Object PUT failed', http_scheme=parsed.scheme, - http_host=conn.host, http_port=conn.port, - http_path=path, http_status=resp.status, - http_reason=resp.reason, - http_response_content=body) - return resp.getheader('etag', '').strip('"') - - -def post_object(url, token, container, name, headers, http_conn=None): - """ - Update object metadata - - :param url: storage URL - :param token: auth token - :param container: container name that the object is in - :param name: name of the object to update - :param headers: additional headers to include in the request - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :raises ClientException: HTTP POST request failed - """ - if http_conn: - parsed, conn = http_conn - else: - parsed, conn = http_connection(url) - path = '%s/%s/%s' % (parsed.path, quote(container), quote(name)) - headers['X-Auth-Token'] = token - conn.request('POST', path, '', headers) - resp = conn.getresponse() - body = resp.read() - http_log(('%s?%s' % (url, path), 'POST',), - {'headers': headers}, resp, body) - if resp.status < 200 or resp.status >= 300: - raise ClientException('Object POST failed', http_scheme=parsed.scheme, - http_host=conn.host, http_port=conn.port, - http_path=path, http_status=resp.status, - http_reason=resp.reason, - http_response_content=body) - - -def delete_object(url, token=None, container=None, name=None, http_conn=None, - headers=None, proxy=None): - """ - Delete object - - :param url: storage URL - :param token: auth token; if None, no token will be sent - :param 
container: container name that the object is in; if None, the - container name is expected to be part of the url - :param name: object name to delete; if None, the object name is expected to - be part of the url - :param http_conn: HTTP connection object (If None, it will create the - conn object) - :param headers: additional headers to include in the request - :param proxy: proxy to connect through, if any; None by default; str of the - format 'http://127.0.0.1:8888' to set one - :raises ClientException: HTTP DELETE request failed - """ - if http_conn: - parsed, conn = http_conn - else: - parsed, conn = http_connection(url, proxy=proxy) - path = parsed.path - if container: - path = '%s/%s' % (path.rstrip('/'), quote(container)) - if name: - path = '%s/%s' % (path.rstrip('/'), quote(name)) - if headers: - headers = dict(headers) - else: - headers = {} - if token: - headers['X-Auth-Token'] = token - conn.request('DELETE', path, '', headers) - resp = conn.getresponse() - body = resp.read() - http_log(('%s?%s' % (url, path), 'POST',), - {'headers': headers}, resp, body) - if resp.status < 200 or resp.status >= 300: - raise ClientException('Object DELETE failed', - http_scheme=parsed.scheme, http_host=conn.host, - http_port=conn.port, http_path=path, - http_status=resp.status, http_reason=resp.reason, - http_response_content=body) - - -class Connection(object): - """Convenience class to make requests that will also retry the request""" - - def __init__(self, authurl, user, key, retries=5, preauthurl=None, - preauthtoken=None, snet=False, starting_backoff=1, - tenant_name=None, - auth_version="1"): - """ - :param authurl: authentication URL - :param user: user name to authenticate as - :param key: key/password to authenticate with - :param retries: Number of times to retry the request before failing - :param preauthurl: storage URL (if you have already authenticated) - :param preauthtoken: authentication token (if you have already - authenticated) - :param snet: use 
SERVICENET internal network default is False - :param auth_version: OpenStack auth version, default is 1.0 - :param tenant_name: The tenant/account name, required when connecting - to a auth 2.0 system. - """ - self.authurl = authurl - self.user = user - self.key = key - self.retries = retries - self.http_conn = None - self.url = preauthurl - self.token = preauthtoken - self.attempts = 0 - self.snet = snet - self.starting_backoff = starting_backoff - self.auth_version = auth_version - self.tenant_name = tenant_name - - def get_auth(self): - return get_auth(self.authurl, self.user, - self.key, snet=self.snet, - tenant_name=self.tenant_name, - auth_version=self.auth_version) - - def http_connection(self): - return http_connection(self.url) - - def _retry(self, reset_func, func, *args, **kwargs): - self.attempts = 0 - backoff = self.starting_backoff - while self.attempts <= self.retries: - self.attempts += 1 - try: - if not self.url or not self.token: - self.url, self.token = self.get_auth() - self.http_conn = None - if not self.http_conn: - self.http_conn = self.http_connection() - kwargs['http_conn'] = self.http_conn - rv = func(self.url, self.token, *args, **kwargs) - return rv - except (socket.error, HTTPException): - if self.attempts > self.retries: - raise - self.http_conn = None - except ClientException, err: - if self.attempts > self.retries: - raise - if err.http_status == 401: - self.url = self.token = None - if self.attempts > 1: - raise - elif err.http_status == 408: - self.http_conn = None - elif 500 <= err.http_status <= 599: - pass - else: - raise - sleep(backoff) - backoff *= 2 - if reset_func: - reset_func(func, *args, **kwargs) - - def head_account(self): - """Wrapper for :func:`head_account`""" - return self._retry(None, head_account) - - def get_account(self, marker=None, limit=None, prefix=None, - full_listing=False): - """Wrapper for :func:`get_account`""" - # TODO(unknown): With full_listing=True this will restart the entire - # listing with 
each retry. Need to make a better version that just - # retries where it left off. - return self._retry(None, get_account, marker=marker, limit=limit, - prefix=prefix, full_listing=full_listing) - - def post_account(self, headers): - """Wrapper for :func:`post_account`""" - return self._retry(None, post_account, headers) - - def head_container(self, container): - """Wrapper for :func:`head_container`""" - return self._retry(None, head_container, container) - - def get_container(self, container, marker=None, limit=None, prefix=None, - delimiter=None, full_listing=False): - """Wrapper for :func:`get_container`""" - # TODO(unknown): With full_listing=True this will restart the entire - # listing with each retry. Need to make a better version that just - # retries where it left off. - return self._retry(None, get_container, container, marker=marker, - limit=limit, prefix=prefix, delimiter=delimiter, - full_listing=full_listing) - - def put_container(self, container, headers=None): - """Wrapper for :func:`put_container`""" - return self._retry(None, put_container, container, headers=headers) - - def post_container(self, container, headers): - """Wrapper for :func:`post_container`""" - return self._retry(None, post_container, container, headers) - - def delete_container(self, container): - """Wrapper for :func:`delete_container`""" - return self._retry(None, delete_container, container) - - def head_object(self, container, obj): - """Wrapper for :func:`head_object`""" - return self._retry(None, head_object, container, obj) - - def get_object(self, container, obj, resp_chunk_size=None): - """Wrapper for :func:`get_object`""" - return self._retry(None, get_object, container, obj, - resp_chunk_size=resp_chunk_size) - - def put_object(self, container, obj, contents, content_length=None, - etag=None, chunk_size=65536, content_type=None, - headers=None): - """Wrapper for :func:`put_object`""" - - def _default_reset(*args, **kwargs): - raise ClientException('put_object(%r, %r, 
...) failure and no ' - 'ability to reset contents for reupload.' - % (container, obj)) - - reset_func = _default_reset - tell = getattr(contents, 'tell', None) - seek = getattr(contents, 'seek', None) - if tell and seek: - orig_pos = tell() - reset_func = lambda *a, **k: seek(orig_pos) - elif not contents: - reset_func = lambda *a, **k: None - - return self._retry(reset_func, put_object, container, obj, contents, - content_length=content_length, etag=etag, - chunk_size=chunk_size, content_type=content_type, - headers=headers) - - def post_object(self, container, obj, headers): - """Wrapper for :func:`post_object`""" - return self._retry(None, post_object, container, obj, headers) - - def delete_object(self, container, obj): - """Wrapper for :func:`delete_object`""" - return self._retry(None, delete_object, container, obj) diff --git a/src/leap/soledad/swiftclient/openstack/__init__.py b/src/leap/soledad/swiftclient/openstack/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/leap/soledad/swiftclient/openstack/common/__init__.py b/src/leap/soledad/swiftclient/openstack/common/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/leap/soledad/swiftclient/openstack/common/setup.py b/src/leap/soledad/swiftclient/openstack/common/setup.py deleted file mode 100644 index caf06fa5..00000000 --- a/src/leap/soledad/swiftclient/openstack/common/setup.py +++ /dev/null @@ -1,342 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2011 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. 
You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -""" -Utilities with minimum-depends for use in setup.py -""" - -import datetime -import os -import re -import subprocess -import sys - -from setuptools.command import sdist - - -def parse_mailmap(mailmap='.mailmap'): - mapping = {} - if os.path.exists(mailmap): - fp = open(mailmap, 'r') - for l in fp: - l = l.strip() - if not l.startswith('#') and ' ' in l: - canonical_email, alias = l.split(' ') - mapping[alias] = canonical_email - return mapping - - -def canonicalize_emails(changelog, mapping): - """Takes in a string and an email alias mapping and replaces all - instances of the aliases in the string with their real email. 
- """ - for alias, email in mapping.iteritems(): - changelog = changelog.replace(alias, email) - return changelog - - -# Get requirements from the first file that exists -def get_reqs_from_files(requirements_files): - reqs_in = [] - for requirements_file in requirements_files: - if os.path.exists(requirements_file): - return open(requirements_file, 'r').read().split('\n') - return [] - - -def parse_requirements(requirements_files=['requirements.txt', - 'tools/pip-requires']): - requirements = [] - for line in get_reqs_from_files(requirements_files): - # For the requirements list, we need to inject only the portion - # after egg= so that distutils knows the package it's looking for - # such as: - # -e git://github.com/openstack/nova/master#egg=nova - if re.match(r'\s*-e\s+', line): - requirements.append(re.sub(r'\s*-e\s+.*#egg=(.*)$', r'\1', - line)) - # such as: - # http://github.com/openstack/nova/zipball/master#egg=nova - elif re.match(r'\s*https?:', line): - requirements.append(re.sub(r'\s*https?:.*#egg=(.*)$', r'\1', - line)) - # -f lines are for index locations, and don't get used here - elif re.match(r'\s*-f\s+', line): - pass - # argparse is part of the standard library starting with 2.7 - # adding it to the requirements list screws distro installs - elif line == 'argparse' and sys.version_info >= (2, 7): - pass - else: - requirements.append(line) - - return requirements - - -def parse_dependency_links(requirements_files=['requirements.txt', - 'tools/pip-requires']): - dependency_links = [] - # dependency_links inject alternate locations to find packages listed - # in requirements - for line in get_reqs_from_files(requirements_files): - # skip comments and blank lines - if re.match(r'(\s*#)|(\s*$)', line): - continue - # lines with -e or -f need the whole line, minus the flag - if re.match(r'\s*-[ef]\s+', line): - dependency_links.append(re.sub(r'\s*-[ef]\s+', '', line)) - # lines that are only urls can go in unmolested - elif re.match(r'\s*https?:', line): 
- dependency_links.append(line) - return dependency_links - - -def write_requirements(): - venv = os.environ.get('VIRTUAL_ENV', None) - if venv is not None: - with open("requirements.txt", "w") as req_file: - output = subprocess.Popen(["pip", "-E", venv, "freeze", "-l"], - stdout=subprocess.PIPE) - requirements = output.communicate()[0].strip() - req_file.write(requirements) - - -def _run_shell_command(cmd): - output = subprocess.Popen(["/bin/sh", "-c", cmd], - stdout=subprocess.PIPE) - out = output.communicate() - if len(out) == 0: - return None - if len(out[0].strip()) == 0: - return None - return out[0].strip() - - -def _get_git_next_version_suffix(branch_name): - datestamp = datetime.datetime.now().strftime('%Y%m%d') - if branch_name == 'milestone-proposed': - revno_prefix = "r" - else: - revno_prefix = "" - _run_shell_command("git fetch origin +refs/meta/*:refs/remotes/meta/*") - milestone_cmd = "git show meta/openstack/release:%s" % branch_name - milestonever = _run_shell_command(milestone_cmd) - if not milestonever: - milestonever = "" - post_version = _get_git_post_version() - revno = post_version.split(".")[-1] - return "%s~%s.%s%s" % (milestonever, datestamp, revno_prefix, revno) - - -def _get_git_current_tag(): - return _run_shell_command("git tag --contains HEAD") - - -def _get_git_tag_info(): - return _run_shell_command("git describe --tags") - - -def _get_git_post_version(): - current_tag = _get_git_current_tag() - if current_tag is not None: - return current_tag - else: - tag_info = _get_git_tag_info() - if tag_info is None: - base_version = "0.0" - cmd = "git --no-pager log --oneline" - out = _run_shell_command(cmd) - revno = len(out.split("\n")) - else: - tag_infos = tag_info.split("-") - base_version = "-".join(tag_infos[:-2]) - revno = tag_infos[-2] - return "%s.%s" % (base_version, revno) - - -def write_git_changelog(): - """Write a changelog based on the git changelog.""" - if os.path.isdir('.git'): - git_log_cmd = 'git log --stat' - changelog 
= _run_shell_command(git_log_cmd) - mailmap = parse_mailmap() - with open("ChangeLog", "w") as changelog_file: - changelog_file.write(canonicalize_emails(changelog, mailmap)) - - -def generate_authors(): - """Create AUTHORS file using git commits.""" - jenkins_email = 'jenkins@review.openstack.org' - old_authors = 'AUTHORS.in' - new_authors = 'AUTHORS' - if os.path.isdir('.git'): - # don't include jenkins email address in AUTHORS file - git_log_cmd = ("git log --format='%aN <%aE>' | sort -u | " - "grep -v " + jenkins_email) - changelog = _run_shell_command(git_log_cmd) - mailmap = parse_mailmap() - with open(new_authors, 'w') as new_authors_fh: - new_authors_fh.write(canonicalize_emails(changelog, mailmap)) - if os.path.exists(old_authors): - with open(old_authors, "r") as old_authors_fh: - new_authors_fh.write('\n' + old_authors_fh.read()) - -_rst_template = """%(heading)s -%(underline)s - -.. automodule:: %(module)s - :members: - :undoc-members: - :show-inheritance: -""" - - -def read_versioninfo(project): - """Read the versioninfo file. 
If it doesn't exist, we're in a github - zipball, and there's really know way to know what version we really - are, but that should be ok, because the utility of that should be - just about nil if this code path is in use in the first place.""" - versioninfo_path = os.path.join(project, 'versioninfo') - if os.path.exists(versioninfo_path): - with open(versioninfo_path, 'r') as vinfo: - version = vinfo.read().strip() - else: - version = "0.0.0" - return version - - -def write_versioninfo(project, version): - """Write a simple file containing the version of the package.""" - open(os.path.join(project, 'versioninfo'), 'w').write("%s\n" % version) - - -def get_cmdclass(): - """Return dict of commands to run from setup.py.""" - - cmdclass = dict() - - def _find_modules(arg, dirname, files): - for filename in files: - if filename.endswith('.py') and filename != '__init__.py': - arg["%s.%s" % (dirname.replace('/', '.'), - filename[:-3])] = True - - class LocalSDist(sdist.sdist): - """Builds the ChangeLog and Authors files from VC first.""" - - def run(self): - write_git_changelog() - generate_authors() - # sdist.sdist is an old style class, can't use super() - sdist.sdist.run(self) - - cmdclass['sdist'] = LocalSDist - - # If Sphinx is installed on the box running setup.py, - # enable setup.py to build the documentation, otherwise, - # just ignore it - try: - from sphinx.setup_command import BuildDoc - - class LocalBuildDoc(BuildDoc): - def generate_autoindex(self): - print "**Autodocumenting from %s" % os.path.abspath(os.curdir) - modules = {} - option_dict = self.distribution.get_option_dict('build_sphinx') - source_dir = os.path.join(option_dict['source_dir'][1], 'api') - if not os.path.exists(source_dir): - os.makedirs(source_dir) - for pkg in self.distribution.packages: - if '.' 
not in pkg: - os.path.walk(pkg, _find_modules, modules) - module_list = modules.keys() - module_list.sort() - autoindex_filename = os.path.join(source_dir, 'autoindex.rst') - with open(autoindex_filename, 'w') as autoindex: - autoindex.write(""".. toctree:: - :maxdepth: 1 - -""") - for module in module_list: - output_filename = os.path.join(source_dir, - "%s.rst" % module) - heading = "The :mod:`%s` Module" % module - underline = "=" * len(heading) - values = dict(module=module, heading=heading, - underline=underline) - - print "Generating %s" % output_filename - with open(output_filename, 'w') as output_file: - output_file.write(_rst_template % values) - autoindex.write(" %s.rst\n" % module) - - def run(self): - if not os.getenv('SPHINX_DEBUG'): - self.generate_autoindex() - - for builder in ['html', 'man']: - self.builder = builder - self.finalize_options() - self.project = self.distribution.get_name() - self.version = self.distribution.get_version() - self.release = self.distribution.get_version() - BuildDoc.run(self) - cmdclass['build_sphinx'] = LocalBuildDoc - except ImportError: - pass - - return cmdclass - - -def get_git_branchname(): - for branch in _run_shell_command("git branch --color=never").split("\n"): - if branch.startswith('*'): - _branch_name = branch.split()[1].strip() - if _branch_name == "(no": - _branch_name = "no-branch" - return _branch_name - - -def get_pre_version(projectname, base_version): - """Return a version which is based""" - if os.path.isdir('.git'): - current_tag = _get_git_current_tag() - if current_tag is not None: - version = current_tag - else: - branch_name = os.getenv('BRANCHNAME', - os.getenv('GERRIT_REFNAME', - get_git_branchname())) - version_suffix = _get_git_next_version_suffix(branch_name) - version = "%s~%s" % (base_version, version_suffix) - write_versioninfo(projectname, version) - return version.split('~')[0] - else: - version = read_versioninfo(projectname) - return version.split('~')[0] - - -def 
get_post_version(projectname): - """Return a version which is equal to the tag that's on the current - revision if there is one, or tag plus number of additional revisions - if the current revision has no tag.""" - - if os.path.isdir('.git'): - version = _get_git_post_version() - write_versioninfo(projectname, version) - return version - return read_versioninfo(projectname) diff --git a/src/leap/soledad/swiftclient/versioninfo b/src/leap/soledad/swiftclient/versioninfo deleted file mode 100644 index 524cb552..00000000 --- a/src/leap/soledad/swiftclient/versioninfo +++ /dev/null @@ -1 +0,0 @@ -1.1.1 diff --git a/src/leap/soledad/tests/__init__.py b/src/leap/soledad/tests/__init__.py index 4f63648e..8e0a5c52 100644 --- a/src/leap/soledad/tests/__init__.py +++ b/src/leap/soledad/tests/__init__.py @@ -7,8 +7,9 @@ import unittest import os import u1db -from soledad import leap, GPGWrapper -from soledad.openstack import ( +from soledad import GPGWrapper +from soledad.backends import leap +from soledad.backends.openstack import ( SimpleLog, TransactionLog, SyncLog, diff --git a/src/leap/soledad/u1db/__init__.py b/src/leap/soledad/u1db/__init__.py deleted file mode 100644 index ed41bb03..00000000 --- a/src/leap/soledad/u1db/__init__.py +++ /dev/null @@ -1,697 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""U1DB""" - -try: - import simplejson as json -except ImportError: - import json # noqa - -from u1db.errors import InvalidJSON, InvalidContent - -__version_info__ = (0, 1, 4) -__version__ = '.'.join(map(str, __version_info__)) - - -def open(path, create, document_factory=None): - """Open a database at the given location. - - Will raise u1db.errors.DatabaseDoesNotExist if create=False and the - database does not already exist. - - :param path: The filesystem path for the database to open. - :param create: True/False, should the database be created if it doesn't - already exist? - :param document_factory: A function that will be called with the same - parameters as Document.__init__. - :return: An instance of Database. - """ - from u1db.backends import sqlite_backend - return sqlite_backend.SQLiteDatabase.open_database( - path, create=create, document_factory=document_factory) - - -# constraints on database names (relevant for remote access, as regex) -DBNAME_CONSTRAINTS = r"[a-zA-Z0-9][a-zA-Z0-9.-]*" - -# constraints on doc ids (as regex) -# (no slashes, and no characters outside the ascii range) -DOC_ID_CONSTRAINTS = r"[a-zA-Z0-9.%_-]+" - - -class Database(object): - """A JSON Document data store. - - This data store can be synchronized with other u1db.Database instances. - """ - - def set_document_factory(self, factory): - """Set the document factory that will be used to create objects to be - returned as documents by the database. - - :param factory: A function that returns an object which at minimum must - satisfy the same interface as does the class DocumentBase. - Subclassing that class is the easiest way to create such - a function. - """ - raise NotImplementedError(self.set_document_factory) - - def set_document_size_limit(self, limit): - """Set the maximum allowed document size for this database. - - :param limit: Maximum allowed document size in bytes. 
- """ - raise NotImplementedError(self.set_document_size_limit) - - def whats_changed(self, old_generation=0): - """Return a list of documents that have changed since old_generation. - This allows APPS to only store a db generation before going - 'offline', and then when coming back online they can use this - data to update whatever extra data they are storing. - - :param old_generation: The generation of the database in the old - state. - :return: (generation, trans_id, [(doc_id, generation, trans_id),...]) - The current generation of the database, its associated transaction - id, and a list of of changed documents since old_generation, - represented by tuples with for each document its doc_id and the - generation and transaction id corresponding to the last intervening - change and sorted by generation (old changes first) - """ - raise NotImplementedError(self.whats_changed) - - def get_doc(self, doc_id, include_deleted=False): - """Get the JSON string for the given document. - - :param doc_id: The unique document identifier - :param include_deleted: If set to True, deleted documents will be - returned with empty content. Otherwise asking for a deleted - document will return None. - :return: a Document object. - """ - raise NotImplementedError(self.get_doc) - - def get_docs(self, doc_ids, check_for_conflicts=True, - include_deleted=False): - """Get the JSON content for many documents. - - :param doc_ids: A list of document identifiers. - :param check_for_conflicts: If set to False, then the conflict check - will be skipped, and 'None' will be returned instead of True/False. - :param include_deleted: If set to True, deleted documents will be - returned with empty content. Otherwise deleted documents will not - be included in the results. - :return: iterable giving the Document object for each document id - in matching doc_ids order. 
- """ - raise NotImplementedError(self.get_docs) - - def get_all_docs(self, include_deleted=False): - """Get the JSON content for all documents in the database. - - :param include_deleted: If set to True, deleted documents will be - returned with empty content. Otherwise deleted documents will not - be included in the results. - :return: (generation, [Document]) - The current generation of the database, followed by a list of all - the documents in the database. - """ - raise NotImplementedError(self.get_all_docs) - - def create_doc(self, content, doc_id=None): - """Create a new document. - - You can optionally specify the document identifier, but the document - must not already exist. See 'put_doc' if you want to override an - existing document. - If the database specifies a maximum document size and the document - exceeds it, create will fail and raise a DocumentTooBig exception. - - :param content: A Python dictionary. - :param doc_id: An optional identifier specifying the document id. - :return: Document - """ - raise NotImplementedError(self.create_doc) - - def create_doc_from_json(self, json, doc_id=None): - """Create a new document. - - You can optionally specify the document identifier, but the document - must not already exist. See 'put_doc' if you want to override an - existing document. - If the database specifies a maximum document size and the document - exceeds it, create will fail and raise a DocumentTooBig exception. - - :param json: The JSON document string - :param doc_id: An optional identifier specifying the document id. - :return: Document - """ - raise NotImplementedError(self.create_doc_from_json) - - def put_doc(self, doc): - """Update a document. - If the document currently has conflicts, put will fail. - If the database specifies a maximum document size and the document - exceeds it, put will fail and raise a DocumentTooBig exception. - - :param doc: A Document with new content. 
- :return: new_doc_rev - The new revision identifier for the document. - The Document object will also be updated. - """ - raise NotImplementedError(self.put_doc) - - def delete_doc(self, doc): - """Mark a document as deleted. - Will abort if the current revision doesn't match doc.rev. - This will also set doc.content to None. - """ - raise NotImplementedError(self.delete_doc) - - def create_index(self, index_name, *index_expressions): - """Create an named index, which can then be queried for future lookups. - Creating an index which already exists is not an error, and is cheap. - Creating an index which does not match the index_expressions of the - existing index is an error. - Creating an index will block until the expressions have been evaluated - and the index generated. - - :param index_name: A unique name which can be used as a key prefix - :param index_expressions: index expressions defining the index - information. - - Examples: - - "fieldname", or "fieldname.subfieldname" to index alphabetically - sorted on the contents of a field. - - "number(fieldname, width)", "lower(fieldname)" - """ - raise NotImplementedError(self.create_index) - - def delete_index(self, index_name): - """Remove a named index. - - :param index_name: The name of the index we are removing - """ - raise NotImplementedError(self.delete_index) - - def list_indexes(self): - """List the definitions of all known indexes. - - :return: A list of [('index-name', ['field', 'field2'])] definitions. - """ - raise NotImplementedError(self.list_indexes) - - def get_from_index(self, index_name, *key_values): - """Return documents that match the keys supplied. - - You must supply exactly the same number of values as have been defined - in the index. It is possible to do a prefix match by using '*' to - indicate a wildcard match. You can only supply '*' to trailing entries, - (eg 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) 
- It is also possible to append a '*' to the last supplied value (eg - 'val*', '*', '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') - - :param index_name: The index to query - :param key_values: values to match. eg, if you have - an index with 3 fields then you would have: - get_from_index(index_name, val1, val2, val3) - :return: List of [Document] - """ - raise NotImplementedError(self.get_from_index) - - def get_range_from_index(self, index_name, start_value, end_value): - """Return documents that fall within the specified range. - - Both ends of the range are inclusive. For both start_value and - end_value, one must supply exactly the same number of values as have - been defined in the index, or pass None. In case of a single column - index, a string is accepted as an alternative for a tuple with a single - value. It is possible to do a prefix match by using '*' to indicate - a wildcard match. You can only supply '*' to trailing entries, (eg - 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) It is also - possible to append a '*' to the last supplied value (eg 'val*', '*', - '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') - - :param index_name: The index to query - :param start_values: tuples of values that define the lower bound of - the range. eg, if you have an index with 3 fields then you would - have: (val1, val2, val3) - :param end_values: tuples of values that define the upper bound of the - range. eg, if you have an index with 3 fields then you would have: - (val1, val2, val3) - :return: List of [Document] - """ - raise NotImplementedError(self.get_range_from_index) - - def get_index_keys(self, index_name): - """Return all keys under which documents are indexed in this index. - - :param index_name: The index to query - :return: [] A list of tuples of indexed keys. - """ - raise NotImplementedError(self.get_index_keys) - - def get_doc_conflicts(self, doc_id): - """Get the list of conflicts for the given document. 
- - The order of the conflicts is such that the first entry is the value - that would be returned by "get_doc". - - :return: [doc] A list of the Document entries that are conflicted. - """ - raise NotImplementedError(self.get_doc_conflicts) - - def resolve_doc(self, doc, conflicted_doc_revs): - """Mark a document as no longer conflicted. - - We take the list of revisions that the client knows about that it is - superseding. This may be a different list from the actual current - conflicts, in which case only those are removed as conflicted. This - may fail if the conflict list is significantly different from the - supplied information. (sync could have happened in the background from - the time you GET_DOC_CONFLICTS until the point where you RESOLVE) - - :param doc: A Document with the new content to be inserted. - :param conflicted_doc_revs: A list of revisions that the new content - supersedes. - """ - raise NotImplementedError(self.resolve_doc) - - def get_sync_target(self): - """Return a SyncTarget object, for another u1db to synchronize with. - - :return: An instance of SyncTarget. - """ - raise NotImplementedError(self.get_sync_target) - - def close(self): - """Release any resources associated with this database.""" - raise NotImplementedError(self.close) - - def sync(self, url, creds=None, autocreate=True): - """Synchronize documents with remote replica exposed at url. - - :param url: the url of the target replica to sync with. - :param creds: optional dictionary giving credentials - to authorize the operation with the server. For using OAuth - the form of creds is: - {'oauth': { - 'consumer_key': ..., - 'consumer_secret': ..., - 'token_key': ..., - 'token_secret': ... - }} - :param autocreate: ask the target to create the db if non-existent. - :return: local_gen_before_sync The local generation before the - synchronisation was performed. 
This is useful to pass into - whatschanged, if an application wants to know which documents were - affected by a synchronisation. - """ - from u1db.sync import Synchronizer - from u1db.remote.http_target import HTTPSyncTarget - return Synchronizer(self, HTTPSyncTarget(url, creds=creds)).sync( - autocreate=autocreate) - - def _get_replica_gen_and_trans_id(self, other_replica_uid): - """Return the last known generation and transaction id for the other db - replica. - - When you do a synchronization with another replica, the Database keeps - track of what generation the other database replica was at, and what - the associated transaction id was. This is used to determine what data - needs to be sent, and if two databases are claiming to be the same - replica. - - :param other_replica_uid: The identifier for the other replica. - :return: (gen, trans_id) The generation and transaction id we - encountered during synchronization. If we've never synchronized - with the replica, this is (0, ''). - """ - raise NotImplementedError(self._get_replica_gen_and_trans_id) - - def _set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - """Set the last-known generation and transaction id for the other - database replica. - - We have just performed some synchronization, and we want to track what - generation the other replica was at. See also - _get_replica_gen_and_trans_id. - :param other_replica_uid: The U1DB identifier for the other replica. - :param other_generation: The generation number for the other replica. - :param other_transaction_id: The transaction id associated with the - generation. - """ - raise NotImplementedError(self._set_replica_gen_and_trans_id) - - def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, - replica_trans_id=''): - """Insert/update document into the database with a given revision. - - This api is used during synchronization operations. 
- - If a document would conflict and save_conflict is set to True, the - content will be selected as the 'current' content for doc.doc_id, - even though doc.rev doesn't supersede the currently stored revision. - The currently stored document will be added to the list of conflict - alternatives for the given doc_id. - - This forces the new content to be 'current' so that we get convergence - after synchronizing, even if people don't resolve conflicts. Users can - then notice that their content is out of date, update it, and - synchronize again. (The alternative is that users could synchronize and - think the data has propagated, but their local copy looks fine, and the - remote copy is never updated again.) - - :param doc: A Document object - :param save_conflict: If this document is a conflict, do you want to - save it as a conflict, or just ignore it. - :param replica_uid: A unique replica identifier. - :param replica_gen: The generation of the replica corresponding to the - this document. The replica arguments are optional, but are used - during synchronization. - :param replica_trans_id: The transaction_id associated with the - generation. - :return: (state, at_gen) - If we don't have doc_id already, - or if doc_rev supersedes the existing document revision, - then the content will be inserted, and state is 'inserted'. - If doc_rev is less than or equal to the existing revision, - then the put is ignored and state is respecitvely 'superseded' - or 'converged'. - If doc_rev is not strictly superseded or supersedes, then - state is 'conflicted'. The document will not be inserted if - save_conflict is False. - For 'inserted' or 'converged', at_gen is the insertion/current - generation. - """ - raise NotImplementedError(self._put_doc_if_newer) - - -class DocumentBase(object): - """Container for handling a single document. - - :ivar doc_id: Unique identifier for this document. - :ivar rev: The revision identifier of the document. 
- :ivar json_string: The JSON string for this document. - :ivar has_conflicts: Boolean indicating if this document has conflicts - """ - - def __init__(self, doc_id, rev, json_string, has_conflicts=False): - self.doc_id = doc_id - self.rev = rev - if json_string is not None: - try: - value = json.loads(json_string) - except json.JSONDecodeError: - raise InvalidJSON - if not isinstance(value, dict): - raise InvalidJSON - self._json = json_string - self.has_conflicts = has_conflicts - - def same_content_as(self, other): - """Compare the content of two documents.""" - if self._json: - c1 = json.loads(self._json) - else: - c1 = None - if other._json: - c2 = json.loads(other._json) - else: - c2 = None - return c1 == c2 - - def __repr__(self): - if self.has_conflicts: - extra = ', conflicted' - else: - extra = '' - return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id, - self.rev, extra, self.get_json()) - - def __hash__(self): - raise NotImplementedError(self.__hash__) - - def __eq__(self, other): - if not isinstance(other, Document): - return NotImplemented - return ( - self.doc_id == other.doc_id and self.rev == other.rev and - self.same_content_as(other) and self.has_conflicts == - other.has_conflicts) - - def __lt__(self, other): - """This is meant for testing, not part of the official api. - - It is implemented so that sorted([Document, Document]) can be used. - It doesn't imply that users would want their documents to be sorted in - this order. - """ - # Since this is just for testing, we don't worry about comparing - # against things that aren't a Document. 
- return ((self.doc_id, self.rev, self.get_json()) - < (other.doc_id, other.rev, other.get_json())) - - def get_json(self): - """Get the json serialization of this document.""" - if self._json is not None: - return self._json - return None - - def get_size(self): - """Calculate the total size of the document.""" - size = 0 - json = self.get_json() - if json: - size += len(json) - if self.rev: - size += len(self.rev) - if self.doc_id: - size += len(self.doc_id) - return size - - def set_json(self, json_string): - """Set the json serialization of this document.""" - if json_string is not None: - try: - value = json.loads(json_string) - except json.JSONDecodeError: - raise InvalidJSON - if not isinstance(value, dict): - raise InvalidJSON - self._json = json_string - - def make_tombstone(self): - """Make this document into a tombstone.""" - self._json = None - - def is_tombstone(self): - """Return True if the document is a tombstone, False otherwise.""" - if self._json is not None: - return False - return True - - -class Document(DocumentBase): - """Container for handling a single document. - - :ivar doc_id: Unique identifier for this document. - :ivar rev: The revision identifier of the document. - :ivar json: The JSON string for this document. - :ivar has_conflicts: Boolean indicating if this document has conflicts - """ - - # The following part of the API is optional: no implementation is forced to - # have it but if the language supports dictionaries/hashtables, it makes - # Documents a lot more user friendly. - - def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False): - # TODO: We convert the json in the superclass to check its validity so - # we might as well set _content here directly since the price is - # already being paid. 
- super(Document, self).__init__(doc_id, rev, json, has_conflicts) - self._content = None - - def same_content_as(self, other): - """Compare the content of two documents.""" - if self._json: - c1 = json.loads(self._json) - else: - c1 = self._content - if other._json: - c2 = json.loads(other._json) - else: - c2 = other._content - return c1 == c2 - - def get_json(self): - """Get the json serialization of this document.""" - json_string = super(Document, self).get_json() - if json_string is not None: - return json_string - if self._content is not None: - return json.dumps(self._content) - return None - - def set_json(self, json): - """Set the json serialization of this document.""" - self._content = None - super(Document, self).set_json(json) - - def make_tombstone(self): - """Make this document into a tombstone.""" - self._content = None - super(Document, self).make_tombstone() - - def is_tombstone(self): - """Return True if the document is a tombstone, False otherwise.""" - if self._content is not None: - return False - return super(Document, self).is_tombstone() - - def _get_content(self): - """Get the dictionary representing this document.""" - if self._json is not None: - self._content = json.loads(self._json) - self._json = None - if self._content is not None: - return self._content - return None - - def _set_content(self, content): - """Set the dictionary representing this document.""" - try: - tmp = json.dumps(content) - except TypeError: - raise InvalidContent( - "Can not be converted to JSON: %r" % (content,)) - if not tmp.startswith('{'): - raise InvalidContent( - "Can not be converted to a JSON object: %r." % (content,)) - # We might as well store the JSON at this point since we did the work - # of encoding it, and it doesn't lose any information. - self._json = tmp - self._content = None - - content = property( - _get_content, _set_content, doc="Content of the Document.") - - # End of optional part. 
- - -class SyncTarget(object): - """Functionality for using a Database as a synchronization target.""" - - def get_sync_info(self, source_replica_uid): - """Return information about known state. - - Return the replica_uid and the current database generation of this - database, and the last-seen database generation for source_replica_uid - - :param source_replica_uid: Another replica which we might have - synchronized with in the past. - :return: (target_replica_uid, target_replica_generation, - target_trans_id, source_replica_last_known_generation, - source_replica_last_known_transaction_id) - """ - raise NotImplementedError(self.get_sync_info) - - def record_sync_info(self, source_replica_uid, source_replica_generation, - source_replica_transaction_id): - """Record tip information for another replica. - - After sync_exchange has been processed, the caller will have - received new content from this replica. This call allows the - source replica instigating the sync to inform us what their - generation became after applying the documents we returned. - - This is used to allow future sync operations to not need to repeat data - that we just talked about. It also means that if this is called at the - wrong time, there can be database records that will never be - synchronized. - - :param source_replica_uid: The identifier for the source replica. - :param source_replica_generation: - The database generation for the source replica. - :param source_replica_transaction_id: The transaction id associated - with the source replica generation. - """ - raise NotImplementedError(self.record_sync_info) - - def sync_exchange(self, docs_by_generation, source_replica_uid, - last_known_generation, last_known_trans_id, - return_doc_cb, ensure_callback=None): - """Incorporate the documents sent from the source replica. - - This is not meant to be called by client code directly, but is used as - part of sync(). 
- - This adds docs to the local store, and determines documents that need - to be returned to the source replica. - - Documents must be supplied in docs_by_generation paired with - the generation of their latest change in order from the oldest - change to the newest, that means from the oldest generation to - the newest. - - Documents are also returned paired with the generation of - their latest change in order from the oldest change to the - newest. - - :param docs_by_generation: A list of [(Document, generation, - transaction_id)] tuples indicating documents which should be - updated on this replica paired with the generation and transaction - id of their latest change. - :param source_replica_uid: The source replica's identifier - :param last_known_generation: The last generation that the source - replica knows about this target replica - :param last_known_trans_id: The last transaction id that the source - replica knows about this target replica - :param: return_doc_cb(doc, gen): is a callback - used to return documents to the source replica, it will - be invoked in turn with Documents that have changed since - last_known_generation together with the generation of - their last change. - :param: ensure_callback(replica_uid): if set the target may create - the target db if not yet existent, the callback can then - be used to inform of the created db replica uid. - :return: new_generation - After applying docs_by_generation, this is - the current generation for this replica - """ - raise NotImplementedError(self.sync_exchange) - - def _set_trace_hook(self, cb): - """Set a callback that will be invoked to trace database actions. - - The callback will be passed a string indicating the current state, and - the sync target object. Implementations do not have to implement this - api, it is used by the test suite. 
- - :param cb: A callable that takes cb(state) - """ - raise NotImplementedError(self._set_trace_hook) - - def _set_trace_hook_shallow(self, cb): - """Set a callback that will be invoked to trace database actions. - - Similar to _set_trace_hook, for implementations that don't offer - state changes from the inner working of sync_exchange(). - - :param cb: A callable that takes cb(state) - """ - self._set_trace_hook(cb) diff --git a/src/leap/soledad/u1db/backends/__init__.py b/src/leap/soledad/u1db/backends/__init__.py deleted file mode 100644 index c8e5adc6..00000000 --- a/src/leap/soledad/u1db/backends/__init__.py +++ /dev/null @@ -1,211 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""Abstract classes and common implementations for the backends.""" - -import re -try: - import simplejson as json -except ImportError: - import json # noqa -import uuid - -import u1db -from u1db import ( - errors, -) -import u1db.sync -from u1db.vectorclock import VectorClockRev - - -check_doc_id_re = re.compile("^" + u1db.DOC_ID_CONSTRAINTS + "$", re.UNICODE) - - -class CommonSyncTarget(u1db.sync.LocalSyncTarget): - pass - - -class CommonBackend(u1db.Database): - - document_size_limit = 0 - - def _allocate_doc_id(self): - """Generate a unique identifier for this document.""" - return 'D-' + uuid.uuid4().hex # 'D-' stands for document - - def _allocate_transaction_id(self): - return 'T-' + uuid.uuid4().hex # 'T-' stands for transaction - - def _allocate_doc_rev(self, old_doc_rev): - vcr = VectorClockRev(old_doc_rev) - vcr.increment(self._replica_uid) - return vcr.as_str() - - def _check_doc_id(self, doc_id): - if not check_doc_id_re.match(doc_id): - raise errors.InvalidDocId() - - def _check_doc_size(self, doc): - if not self.document_size_limit: - return - if doc.get_size() > self.document_size_limit: - raise errors.DocumentTooBig - - def _get_generation(self): - """Return the current generation. - - """ - raise NotImplementedError(self._get_generation) - - def _get_generation_info(self): - """Return the current generation and transaction id. - - """ - raise NotImplementedError(self._get_generation_info) - - def _get_doc(self, doc_id, check_for_conflicts=False): - """Extract the document from storage. - - This can return None if the document doesn't exist. 
- """ - raise NotImplementedError(self._get_doc) - - def _has_conflicts(self, doc_id): - """Return True if the doc has conflicts, False otherwise.""" - raise NotImplementedError(self._has_conflicts) - - def create_doc(self, content, doc_id=None): - json_string = json.dumps(content) - if doc_id is None: - doc_id = self._allocate_doc_id() - doc = self._factory(doc_id, None, json_string) - self.put_doc(doc) - return doc - - def create_doc_from_json(self, json, doc_id=None): - if doc_id is None: - doc_id = self._allocate_doc_id() - doc = self._factory(doc_id, None, json) - self.put_doc(doc) - return doc - - def _get_transaction_log(self): - """This is only for the test suite, it is not part of the api.""" - raise NotImplementedError(self._get_transaction_log) - - def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): - raise NotImplementedError(self._put_and_update_indexes) - - def get_docs(self, doc_ids, check_for_conflicts=True, - include_deleted=False): - for doc_id in doc_ids: - doc = self._get_doc( - doc_id, check_for_conflicts=check_for_conflicts) - if doc.is_tombstone() and not include_deleted: - continue - yield doc - - def _get_trans_id_for_gen(self, generation): - """Get the transaction id corresponding to a particular generation. - - Raises an InvalidGeneration when the generation does not exist. - - """ - raise NotImplementedError(self._get_trans_id_for_gen) - - def validate_gen_and_trans_id(self, generation, trans_id): - """Validate the generation and transaction id. - - Raises an InvalidGeneration when the generation does not exist, and an - InvalidTransactionId when it does but with a different transaction id. - - """ - if generation == 0: - return - known_trans_id = self._get_trans_id_for_gen(generation) - if known_trans_id != trans_id: - raise errors.InvalidTransactionId - - def _validate_source(self, other_replica_uid, other_generation, - other_transaction_id): - """Validate the new generation and transaction id. 
- - other_generation must be greater than what we have stored for this - replica, *or* it must be the same and the transaction_id must be the - same as well. - """ - (old_generation, - old_transaction_id) = self._get_replica_gen_and_trans_id( - other_replica_uid) - if other_generation < old_generation: - raise errors.InvalidGeneration - if other_generation > old_generation: - return - if other_transaction_id == old_transaction_id: - return - raise errors.InvalidTransactionId - - def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, - replica_trans_id=''): - cur_doc = self._get_doc(doc.doc_id) - doc_vcr = VectorClockRev(doc.rev) - if cur_doc is None: - cur_vcr = VectorClockRev(None) - else: - cur_vcr = VectorClockRev(cur_doc.rev) - self._validate_source(replica_uid, replica_gen, replica_trans_id) - if doc_vcr.is_newer(cur_vcr): - rev = doc.rev - self._prune_conflicts(doc, doc_vcr) - if doc.rev != rev: - # conflicts have been autoresolved - state = 'superseded' - else: - state = 'inserted' - self._put_and_update_indexes(cur_doc, doc) - elif doc.rev == cur_doc.rev: - # magical convergence - state = 'converged' - elif cur_vcr.is_newer(doc_vcr): - # Don't add this to seen_ids, because we have something newer, - # so we should send it back, and we should not generate a - # conflict - state = 'superseded' - elif cur_doc.same_content_as(doc): - # the documents have been edited to the same thing at both ends - doc_vcr.maximize(cur_vcr) - doc_vcr.increment(self._replica_uid) - doc.rev = doc_vcr.as_str() - self._put_and_update_indexes(cur_doc, doc) - state = 'superseded' - else: - state = 'conflicted' - if save_conflict: - self._force_doc_sync_conflict(doc) - if replica_uid is not None and replica_gen is not None: - self._do_set_replica_gen_and_trans_id( - replica_uid, replica_gen, replica_trans_id) - return state, self._get_generation() - - def _ensure_maximal_rev(self, cur_rev, extra_revs): - vcr = VectorClockRev(cur_rev) - for rev in extra_revs: - 
vcr.maximize(VectorClockRev(rev)) - vcr.increment(self._replica_uid) - return vcr.as_str() - - def set_document_size_limit(self, limit): - self.document_size_limit = limit diff --git a/src/leap/soledad/u1db/backends/dbschema.sql b/src/leap/soledad/u1db/backends/dbschema.sql deleted file mode 100644 index ae027fc5..00000000 --- a/src/leap/soledad/u1db/backends/dbschema.sql +++ /dev/null @@ -1,42 +0,0 @@ --- Database schema -CREATE TABLE transaction_log ( - generation INTEGER PRIMARY KEY AUTOINCREMENT, - doc_id TEXT NOT NULL, - transaction_id TEXT NOT NULL -); -CREATE TABLE document ( - doc_id TEXT PRIMARY KEY, - doc_rev TEXT NOT NULL, - content TEXT -); -CREATE TABLE document_fields ( - doc_id TEXT NOT NULL, - field_name TEXT NOT NULL, - value TEXT -); -CREATE INDEX document_fields_field_value_doc_idx - ON document_fields(field_name, value, doc_id); - -CREATE TABLE sync_log ( - replica_uid TEXT PRIMARY KEY, - known_generation INTEGER, - known_transaction_id TEXT -); -CREATE TABLE conflicts ( - doc_id TEXT, - doc_rev TEXT, - content TEXT, - CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev) -); -CREATE TABLE index_definitions ( - name TEXT, - offset INT, - field TEXT, - CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset) -); -create index index_definitions_field on index_definitions(field); -CREATE TABLE u1db_config ( - name TEXT PRIMARY KEY, - value TEXT -); -INSERT INTO u1db_config VALUES ('sql_schema', '0'); diff --git a/src/leap/soledad/u1db/backends/inmemory.py b/src/leap/soledad/u1db/backends/inmemory.py deleted file mode 100644 index a271bb37..00000000 --- a/src/leap/soledad/u1db/backends/inmemory.py +++ /dev/null @@ -1,469 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. 
-# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""The in-memory Database class for U1DB.""" - -try: - import simplejson as json -except ImportError: - import json # noqa - -from u1db import ( - Document, - errors, - query_parser, - vectorclock, - ) -from u1db.backends import CommonBackend, CommonSyncTarget - - -def get_prefix(value): - key_prefix = '\x01'.join(value) - return key_prefix.rstrip('*') - - -class InMemoryDatabase(CommonBackend): - """A database that only stores the data internally.""" - - def __init__(self, replica_uid, document_factory=None): - self._transaction_log = [] - self._docs = {} - # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner' - self._conflicts = {} - self._other_generations = {} - self._indexes = {} - self._replica_uid = replica_uid - self._factory = document_factory or Document - - def _set_replica_uid(self, replica_uid): - """Force the replica_uid to be set.""" - self._replica_uid = replica_uid - - def set_document_factory(self, factory): - self._factory = factory - - def close(self): - # This is a no-op, We don't want to free the data because one client - # may be closing it, while another wants to inspect the results. 
- pass - - def _get_replica_gen_and_trans_id(self, other_replica_uid): - return self._other_generations.get(other_replica_uid, (0, '')) - - def _set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - self._do_set_replica_gen_and_trans_id( - other_replica_uid, other_generation, other_transaction_id) - - def _do_set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, - other_transaction_id): - # TODO: to handle race conditions, we may want to check if the current - # value is greater than this new value. - self._other_generations[other_replica_uid] = (other_generation, - other_transaction_id) - - def get_sync_target(self): - return InMemorySyncTarget(self) - - def _get_transaction_log(self): - # snapshot! - return self._transaction_log[:] - - def _get_generation(self): - return len(self._transaction_log) - - def _get_generation_info(self): - if not self._transaction_log: - return 0, '' - return len(self._transaction_log), self._transaction_log[-1][1] - - def _get_trans_id_for_gen(self, generation): - if generation == 0: - return '' - if generation > len(self._transaction_log): - raise errors.InvalidGeneration - return self._transaction_log[generation - 1][1] - - def put_doc(self, doc): - if doc.doc_id is None: - raise errors.InvalidDocId() - self._check_doc_id(doc.doc_id) - self._check_doc_size(doc) - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc and old_doc.has_conflicts: - raise errors.ConflictedDoc() - if old_doc and doc.rev is None and old_doc.is_tombstone(): - new_rev = self._allocate_doc_rev(old_doc.rev) - else: - if old_doc is not None: - if old_doc.rev != doc.rev: - raise errors.RevisionConflict() - else: - if doc.rev is not None: - raise errors.RevisionConflict() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - self._put_and_update_indexes(old_doc, doc) - return new_rev - - def _put_and_update_indexes(self, old_doc, doc): - for index in 
self._indexes.itervalues(): - if old_doc is not None and not old_doc.is_tombstone(): - index.remove_json(old_doc.doc_id, old_doc.get_json()) - if not doc.is_tombstone(): - index.add_json(doc.doc_id, doc.get_json()) - trans_id = self._allocate_transaction_id() - self._docs[doc.doc_id] = (doc.rev, doc.get_json()) - self._transaction_log.append((doc.doc_id, trans_id)) - - def _get_doc(self, doc_id, check_for_conflicts=False): - try: - doc_rev, content = self._docs[doc_id] - except KeyError: - return None - doc = self._factory(doc_id, doc_rev, content) - if check_for_conflicts: - doc.has_conflicts = (doc.doc_id in self._conflicts) - return doc - - def _has_conflicts(self, doc_id): - return doc_id in self._conflicts - - def get_doc(self, doc_id, include_deleted=False): - doc = self._get_doc(doc_id, check_for_conflicts=True) - if doc is None: - return None - if doc.is_tombstone() and not include_deleted: - return None - return doc - - def get_all_docs(self, include_deleted=False): - """Return all documents in the database.""" - generation = self._get_generation() - results = [] - for doc_id, (doc_rev, content) in self._docs.items(): - if content is None and not include_deleted: - continue - doc = self._factory(doc_id, doc_rev, content) - doc.has_conflicts = self._has_conflicts(doc_id) - results.append(doc) - return (generation, results) - - def get_doc_conflicts(self, doc_id): - if doc_id not in self._conflicts: - return [] - result = [self._get_doc(doc_id)] - result[0].has_conflicts = True - result.extend([self._factory(doc_id, rev, content) - for rev, content in self._conflicts[doc_id]]) - return result - - def _replace_conflicts(self, doc, conflicts): - if not conflicts: - del self._conflicts[doc.doc_id] - else: - self._conflicts[doc.doc_id] = conflicts - doc.has_conflicts = bool(conflicts) - - def _prune_conflicts(self, doc, doc_vcr): - if self._has_conflicts(doc.doc_id): - autoresolved = False - remaining_conflicts = [] - cur_conflicts = self._conflicts[doc.doc_id] 
- for c_rev, c_doc in cur_conflicts: - c_vcr = vectorclock.VectorClockRev(c_rev) - if doc_vcr.is_newer(c_vcr): - continue - if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)): - doc_vcr.maximize(c_vcr) - autoresolved = True - continue - remaining_conflicts.append((c_rev, c_doc)) - if autoresolved: - doc_vcr.increment(self._replica_uid) - doc.rev = doc_vcr.as_str() - self._replace_conflicts(doc, remaining_conflicts) - - def resolve_doc(self, doc, conflicted_doc_revs): - cur_doc = self._get_doc(doc.doc_id) - if cur_doc is None: - cur_rev = None - else: - cur_rev = cur_doc.rev - new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs) - superseded_revs = set(conflicted_doc_revs) - remaining_conflicts = [] - cur_conflicts = self._conflicts[doc.doc_id] - for c_rev, c_doc in cur_conflicts: - if c_rev in superseded_revs: - continue - remaining_conflicts.append((c_rev, c_doc)) - doc.rev = new_rev - if cur_rev in superseded_revs: - self._put_and_update_indexes(cur_doc, doc) - else: - remaining_conflicts.append((new_rev, doc.get_json())) - self._replace_conflicts(doc, remaining_conflicts) - - def delete_doc(self, doc): - if doc.doc_id not in self._docs: - raise errors.DocumentDoesNotExist - if self._docs[doc.doc_id][1] in ('null', None): - raise errors.DocumentAlreadyDeleted - doc.make_tombstone() - self.put_doc(doc) - - def create_index(self, index_name, *index_expressions): - if index_name in self._indexes: - if self._indexes[index_name]._definition == list( - index_expressions): - return - raise errors.IndexNameTakenError - index = InMemoryIndex(index_name, list(index_expressions)) - for doc_id, (doc_rev, doc) in self._docs.iteritems(): - if doc is not None: - index.add_json(doc_id, doc) - self._indexes[index_name] = index - - def delete_index(self, index_name): - del self._indexes[index_name] - - def list_indexes(self): - definitions = [] - for idx in self._indexes.itervalues(): - definitions.append((idx._name, idx._definition)) - return definitions - - 
def get_from_index(self, index_name, *key_values): - try: - index = self._indexes[index_name] - except KeyError: - raise errors.IndexDoesNotExist - doc_ids = index.lookup(key_values) - result = [] - for doc_id in doc_ids: - result.append(self._get_doc(doc_id, check_for_conflicts=True)) - return result - - def get_range_from_index(self, index_name, start_value=None, - end_value=None): - """Return all documents with key values in the specified range.""" - try: - index = self._indexes[index_name] - except KeyError: - raise errors.IndexDoesNotExist - if isinstance(start_value, basestring): - start_value = (start_value,) - if isinstance(end_value, basestring): - end_value = (end_value,) - doc_ids = index.lookup_range(start_value, end_value) - result = [] - for doc_id in doc_ids: - result.append(self._get_doc(doc_id, check_for_conflicts=True)) - return result - - def get_index_keys(self, index_name): - try: - index = self._indexes[index_name] - except KeyError: - raise errors.IndexDoesNotExist - keys = index.keys() - # XXX inefficiency warning - return list(set([tuple(key.split('\x01')) for key in keys])) - - def whats_changed(self, old_generation=0): - changes = [] - relevant_tail = self._transaction_log[old_generation:] - # We don't use len(self._transaction_log) because _transaction_log may - # get mutated by a concurrent operation. 
- cur_generation = old_generation + len(relevant_tail) - last_trans_id = '' - if relevant_tail: - last_trans_id = relevant_tail[-1][1] - elif self._transaction_log: - last_trans_id = self._transaction_log[-1][1] - seen = set() - generation = cur_generation - for doc_id, trans_id in reversed(relevant_tail): - if doc_id not in seen: - changes.append((doc_id, generation, trans_id)) - seen.add(doc_id) - generation -= 1 - changes.reverse() - return (cur_generation, last_trans_id, changes) - - def _force_doc_sync_conflict(self, doc): - my_doc = self._get_doc(doc.doc_id) - self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) - self._conflicts.setdefault(doc.doc_id, []).append( - (my_doc.rev, my_doc.get_json())) - doc.has_conflicts = True - self._put_and_update_indexes(my_doc, doc) - - -class InMemoryIndex(object): - """Interface for managing an Index.""" - - def __init__(self, index_name, index_definition): - self._name = index_name - self._definition = index_definition - self._values = {} - parser = query_parser.Parser() - self._getters = parser.parse_all(self._definition) - - def evaluate_json(self, doc): - """Determine the 'key' after applying this index to the doc.""" - raw = json.loads(doc) - return self.evaluate(raw) - - def evaluate(self, obj): - """Evaluate a dict object, applying this definition.""" - all_rows = [[]] - for getter in self._getters: - new_rows = [] - keys = getter.get(obj) - if not keys: - return [] - for key in keys: - new_rows.extend([row + [key] for row in all_rows]) - all_rows = new_rows - all_rows = ['\x01'.join(row) for row in all_rows] - return all_rows - - def add_json(self, doc_id, doc): - """Add this json doc to the index.""" - keys = self.evaluate_json(doc) - if not keys: - return - for key in keys: - self._values.setdefault(key, []).append(doc_id) - - def remove_json(self, doc_id, doc): - """Remove this json doc from the index.""" - keys = self.evaluate_json(doc) - if keys: - for key in keys: - doc_ids = self._values[key] - 
doc_ids.remove(doc_id) - if not doc_ids: - del self._values[key] - - def _find_non_wildcards(self, values): - """Check if this should be a wildcard match. - - Further, this will raise an exception if the syntax is improperly - defined. - - :return: The offset of the last value we need to match against. - """ - if len(values) != len(self._definition): - raise errors.InvalidValueForIndex() - is_wildcard = False - last = 0 - for idx, val in enumerate(values): - if val.endswith('*'): - if val != '*': - # We have an 'x*' style wildcard - if is_wildcard: - # We were already in wildcard mode, so this is invalid - raise errors.InvalidGlobbing - last = idx + 1 - is_wildcard = True - else: - if is_wildcard: - # We were in wildcard mode, we can't follow that with - # non-wildcard - raise errors.InvalidGlobbing - last = idx + 1 - if not is_wildcard: - return -1 - return last - - def lookup(self, values): - """Find docs that match the values.""" - last = self._find_non_wildcards(values) - if last == -1: - return self._lookup_exact(values) - else: - return self._lookup_prefix(values[:last]) - - def lookup_range(self, start_values, end_values): - """Find docs within the range.""" - # TODO: Wildly inefficient, which is unlikely to be a problem for the - # inmemory implementation. 
- if start_values: - self._find_non_wildcards(start_values) - start_values = get_prefix(start_values) - if end_values: - if self._find_non_wildcards(end_values) == -1: - exact = True - else: - exact = False - end_values = get_prefix(end_values) - found = [] - for key, doc_ids in sorted(self._values.iteritems()): - if start_values and start_values > key: - continue - if end_values and end_values < key: - if exact: - break - else: - if not key.startswith(end_values): - break - found.extend(doc_ids) - return found - - def keys(self): - """Find the indexed keys.""" - return self._values.keys() - - def _lookup_prefix(self, value): - """Find docs that match the prefix string in values.""" - # TODO: We need a different data structure to make prefix style fast, - # some sort of sorted list would work, but a plain dict doesn't. - key_prefix = get_prefix(value) - all_doc_ids = [] - for key, doc_ids in sorted(self._values.iteritems()): - if key.startswith(key_prefix): - all_doc_ids.extend(doc_ids) - return all_doc_ids - - def _lookup_exact(self, value): - """Find docs that match exactly.""" - key = '\x01'.join(value) - if key in self._values: - return self._values[key] - return () - - -class InMemorySyncTarget(CommonSyncTarget): - - def get_sync_info(self, source_replica_uid): - source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( - source_replica_uid) - my_gen, my_trans_id = self._db._get_generation_info() - return ( - self._db._replica_uid, my_gen, my_trans_id, source_gen, - source_trans_id) - - def record_sync_info(self, source_replica_uid, source_replica_generation, - source_transaction_id): - if self._trace_hook: - self._trace_hook('record_sync_info') - self._db._set_replica_gen_and_trans_id( - source_replica_uid, source_replica_generation, - source_transaction_id) diff --git a/src/leap/soledad/u1db/backends/sqlite_backend.py b/src/leap/soledad/u1db/backends/sqlite_backend.py deleted file mode 100644 index 773213b5..00000000 --- 
a/src/leap/soledad/u1db/backends/sqlite_backend.py +++ /dev/null @@ -1,926 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""A U1DB implementation that uses SQLite as its persistence layer.""" - -import errno -import os -try: - import simplejson as json -except ImportError: - import json # noqa -from sqlite3 import dbapi2 -import sys -import time -import uuid - -import pkg_resources - -from u1db.backends import CommonBackend, CommonSyncTarget -from u1db import ( - Document, - errors, - query_parser, - vectorclock, - ) - - -class SQLiteDatabase(CommonBackend): - """A U1DB implementation that uses SQLite as its persistence layer.""" - - _sqlite_registry = {} - - def __init__(self, sqlite_file, document_factory=None): - """Create a new sqlite file.""" - self._db_handle = dbapi2.connect(sqlite_file) - self._real_replica_uid = None - self._ensure_schema() - self._factory = document_factory or Document - - def set_document_factory(self, factory): - self._factory = factory - - def get_sync_target(self): - return SQLiteSyncTarget(self) - - @classmethod - def _which_index_storage(cls, c): - try: - c.execute("SELECT value FROM u1db_config" - " WHERE name = 'index_storage'") - except dbapi2.OperationalError, e: - # The table does not exist yet - return None, e - else: - return c.fetchone()[0], None - - WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5 - - @classmethod - def _open_database(cls, sqlite_file, 
document_factory=None): - if not os.path.isfile(sqlite_file): - raise errors.DatabaseDoesNotExist() - tries = 2 - while True: - # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6) - # where without re-opening the database on Windows, it - # doesn't see the transaction that was just committed - db_handle = dbapi2.connect(sqlite_file) - c = db_handle.cursor() - v, err = cls._which_index_storage(c) - db_handle.close() - if v is not None: - break - # possibly another process is initializing it, wait for it to be - # done - if tries == 0: - raise err # go for the richest error? - tries -= 1 - time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL) - return SQLiteDatabase._sqlite_registry[v]( - sqlite_file, document_factory=document_factory) - - @classmethod - def open_database(cls, sqlite_file, create, backend_cls=None, - document_factory=None): - try: - return cls._open_database( - sqlite_file, document_factory=document_factory) - except errors.DatabaseDoesNotExist: - if not create: - raise - if backend_cls is None: - # default is SQLitePartialExpandDatabase - backend_cls = SQLitePartialExpandDatabase - return backend_cls(sqlite_file, document_factory=document_factory) - - @staticmethod - def delete_database(sqlite_file): - try: - os.unlink(sqlite_file) - except OSError as ex: - if ex.errno == errno.ENOENT: - raise errors.DatabaseDoesNotExist() - raise - - @staticmethod - def register_implementation(klass): - """Register that we implement an SQLiteDatabase. - - The attribute _index_storage_value will be used as the lookup key. - """ - SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass - - def _get_sqlite_handle(self): - """Get access to the underlying sqlite database. - - This should only be used by the test suite, etc, for examining the - state of the underlying database. 
- """ - return self._db_handle - - def _close_sqlite_handle(self): - """Release access to the underlying sqlite database.""" - self._db_handle.close() - - def close(self): - self._close_sqlite_handle() - - def _is_initialized(self, c): - """Check if this database has been initialized.""" - c.execute("PRAGMA case_sensitive_like=ON") - try: - c.execute("SELECT value FROM u1db_config" - " WHERE name = 'sql_schema'") - except dbapi2.OperationalError: - # The table does not exist yet - val = None - else: - val = c.fetchone() - if val is not None: - return True - return False - - def _initialize(self, c): - """Create the schema in the database.""" - #read the script with sql commands - # TODO: Change how we set up the dependency. Most likely use something - # like lp:dirspec to grab the file from a common resource - # directory. Doesn't specifically need to be handled until we get - # to the point of packaging this. - schema_content = pkg_resources.resource_string( - __name__, 'dbschema.sql') - # Note: We'd like to use c.executescript() here, but it seems that - # executescript always commits, even if you set - # isolation_level = None, so if we want to properly handle - # exclusive locking and rollbacks between processes, we need - # to execute it line-by-line - for line in schema_content.split(';'): - if not line: - continue - c.execute(line) - #add extra fields - self._extra_schema_init(c) - # A unique identifier should be set for this replica. Implementations - # don't have to strictly use uuid here, but we do want the uid to be - # unique amongst all databases that will sync with each other. - # We might extend this to using something with hostname for easier - # debugging. 
- self._set_replica_uid_in_transaction(uuid.uuid4().hex) - c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)", - (self._index_storage_value,)) - - def _ensure_schema(self): - """Ensure that the database schema has been created.""" - old_isolation_level = self._db_handle.isolation_level - c = self._db_handle.cursor() - if self._is_initialized(c): - return - try: - # autocommit/own mgmt of transactions - self._db_handle.isolation_level = None - with self._db_handle: - # only one execution path should initialize the db - c.execute("begin exclusive") - if self._is_initialized(c): - return - self._initialize(c) - finally: - self._db_handle.isolation_level = old_isolation_level - - def _extra_schema_init(self, c): - """Add any extra fields, etc to the basic table definitions.""" - - def _parse_index_definition(self, index_field): - """Parse a field definition for an index, returning a Getter.""" - # Note: We may want to keep a Parser object around, and cache the - # Getter objects for a greater length of time. Specifically, if - # you create a bunch of indexes, and then insert 50k docs, you'll - # re-parse the indexes between puts. The time to insert the docs - # is still likely to dominate put_doc time, though. - parser = query_parser.Parser() - getter = parser.parse(index_field) - return getter - - def _update_indexes(self, doc_id, raw_doc, getters, db_cursor): - """Update document_fields for a single document. - - :param doc_id: Identifier for this document - :param raw_doc: The python dict representation of the document. - :param getters: A list of [(field_name, Getter)]. Getter.get will be - called to evaluate the index definition for this document, and the - results will be inserted into the db. - :param db_cursor: An sqlite Cursor. 
- :return: None - """ - values = [] - for field_name, getter in getters: - for idx_value in getter.get(raw_doc): - values.append((doc_id, field_name, idx_value)) - if values: - db_cursor.executemany( - "INSERT INTO document_fields VALUES (?, ?, ?)", values) - - def _set_replica_uid(self, replica_uid): - """Force the replica_uid to be set.""" - with self._db_handle: - self._set_replica_uid_in_transaction(replica_uid) - - def _set_replica_uid_in_transaction(self, replica_uid): - """Set the replica_uid. A transaction should already be held.""" - c = self._db_handle.cursor() - c.execute("INSERT OR REPLACE INTO u1db_config" - " VALUES ('replica_uid', ?)", - (replica_uid,)) - self._real_replica_uid = replica_uid - - def _get_replica_uid(self): - if self._real_replica_uid is not None: - return self._real_replica_uid - c = self._db_handle.cursor() - c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'") - val = c.fetchone() - if val is None: - return None - self._real_replica_uid = val[0] - return self._real_replica_uid - - _replica_uid = property(_get_replica_uid) - - def _get_generation(self): - c = self._db_handle.cursor() - c.execute('SELECT max(generation) FROM transaction_log') - val = c.fetchone()[0] - if val is None: - return 0 - return val - - def _get_generation_info(self): - c = self._db_handle.cursor() - c.execute( - 'SELECT max(generation), transaction_id FROM transaction_log ') - val = c.fetchone() - if val[0] is None: - return(0, '') - return val - - def _get_trans_id_for_gen(self, generation): - if generation == 0: - return '' - c = self._db_handle.cursor() - c.execute( - 'SELECT transaction_id FROM transaction_log WHERE generation = ?', - (generation,)) - val = c.fetchone() - if val is None: - raise errors.InvalidGeneration - return val[0] - - def _get_transaction_log(self): - c = self._db_handle.cursor() - c.execute("SELECT doc_id, transaction_id FROM transaction_log" - " ORDER BY generation") - return c.fetchall() - - def _get_doc(self, 
doc_id, check_for_conflicts=False): - """Get just the document content, without fancy handling.""" - c = self._db_handle.cursor() - if check_for_conflicts: - c.execute( - "SELECT document.doc_rev, document.content, " - "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN " - "conflicts ON conflicts.doc_id = document.doc_id WHERE " - "document.doc_id = ? GROUP BY document.doc_id, " - "document.doc_rev, document.content;", (doc_id,)) - else: - c.execute( - "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?", - (doc_id,)) - val = c.fetchone() - if val is None: - return None - doc_rev, content, conflicts = val - doc = self._factory(doc_id, doc_rev, content) - doc.has_conflicts = conflicts > 0 - return doc - - def _has_conflicts(self, doc_id): - c = self._db_handle.cursor() - c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1", - (doc_id,)) - val = c.fetchone() - if val is None: - return False - else: - return True - - def get_doc(self, doc_id, include_deleted=False): - doc = self._get_doc(doc_id, check_for_conflicts=True) - if doc is None: - return None - if doc.is_tombstone() and not include_deleted: - return None - return doc - - def get_all_docs(self, include_deleted=False): - """Get all documents from the database.""" - generation = self._get_generation() - results = [] - c = self._db_handle.cursor() - c.execute( - "SELECT document.doc_id, document.doc_rev, document.content, " - "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts " - "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, " - "document.doc_rev, document.content;") - rows = c.fetchall() - for doc_id, doc_rev, content, conflicts in rows: - if content is None and not include_deleted: - continue - doc = self._factory(doc_id, doc_rev, content) - doc.has_conflicts = conflicts > 0 - results.append(doc) - return (generation, results) - - def put_doc(self, doc): - if doc.doc_id is None: - raise errors.InvalidDocId() - self._check_doc_id(doc.doc_id) - 
self._check_doc_size(doc) - with self._db_handle: - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc and old_doc.has_conflicts: - raise errors.ConflictedDoc() - if old_doc and doc.rev is None and old_doc.is_tombstone(): - new_rev = self._allocate_doc_rev(old_doc.rev) - else: - if old_doc is not None: - if old_doc.rev != doc.rev: - raise errors.RevisionConflict() - else: - if doc.rev is not None: - raise errors.RevisionConflict() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - self._put_and_update_indexes(old_doc, doc) - return new_rev - - def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none): - """Convert a dict representation into named fields. - - So something like: {'key1': 'val1', 'key2': 'val2'} - gets converted into: [(doc_id, 'key1', 'val1', 0) - (doc_id, 'key2', 'val2', 0)] - :param doc_id: Just added to every record. - :param base_field: if set, these are nested keys, so each field should - be appropriately prefixed. - :param raw_doc: The python dictionary. - """ - # TODO: Handle lists - values = [] - for field_name, value in raw_doc.iteritems(): - if value is None and not save_none: - continue - if base_field: - full_name = base_field + '.' + field_name - else: - full_name = field_name - if value is None or isinstance(value, (int, float, basestring)): - values.append((doc_id, full_name, value, len(values))) - else: - subvalues = self._expand_to_fields(doc_id, full_name, value, - save_none) - for _, subfield_name, val, _ in subvalues: - values.append((doc_id, subfield_name, val, len(values))) - return values - - def _put_and_update_indexes(self, old_doc, doc): - """Actually insert a document into the database. - - This both updates the existing documents content, and any indexes that - refer to this document. 
- """ - raise NotImplementedError(self._put_and_update_indexes) - - def whats_changed(self, old_generation=0): - c = self._db_handle.cursor() - c.execute("SELECT generation, doc_id, transaction_id" - " FROM transaction_log" - " WHERE generation > ? ORDER BY generation DESC", - (old_generation,)) - results = c.fetchall() - cur_gen = old_generation - seen = set() - changes = [] - newest_trans_id = '' - for generation, doc_id, trans_id in results: - if doc_id not in seen: - changes.append((doc_id, generation, trans_id)) - seen.add(doc_id) - if changes: - cur_gen = changes[0][1] # max generation - newest_trans_id = changes[0][2] - changes.reverse() - else: - c.execute("SELECT generation, transaction_id" - " FROM transaction_log ORDER BY generation DESC LIMIT 1") - results = c.fetchone() - if not results: - cur_gen = 0 - newest_trans_id = '' - else: - cur_gen, newest_trans_id = results - - return cur_gen, newest_trans_id, changes - - def delete_doc(self, doc): - with self._db_handle: - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc is None: - raise errors.DocumentDoesNotExist - if old_doc.rev != doc.rev: - raise errors.RevisionConflict() - if old_doc.is_tombstone(): - raise errors.DocumentAlreadyDeleted - if old_doc.has_conflicts: - raise errors.ConflictedDoc() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - doc.make_tombstone() - self._put_and_update_indexes(old_doc, doc) - return new_rev - - def _get_conflicts(self, doc_id): - c = self._db_handle.cursor() - c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?", - (doc_id,)) - return [self._factory(doc_id, doc_rev, content) - for doc_rev, content in c.fetchall()] - - def get_doc_conflicts(self, doc_id): - with self._db_handle: - conflict_docs = self._get_conflicts(doc_id) - if not conflict_docs: - return [] - this_doc = self._get_doc(doc_id) - this_doc.has_conflicts = True - return [this_doc] + conflict_docs - - def _get_replica_gen_and_trans_id(self, 
other_replica_uid): - c = self._db_handle.cursor() - c.execute("SELECT known_generation, known_transaction_id FROM sync_log" - " WHERE replica_uid = ?", - (other_replica_uid,)) - val = c.fetchone() - if val is None: - other_gen = 0 - trans_id = '' - else: - other_gen = val[0] - trans_id = val[1] - return other_gen, trans_id - - def _set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - with self._db_handle: - self._do_set_replica_gen_and_trans_id( - other_replica_uid, other_generation, other_transaction_id) - - def _do_set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, - other_transaction_id): - c = self._db_handle.cursor() - c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)", - (other_replica_uid, other_generation, - other_transaction_id)) - - def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, - replica_gen=None, replica_trans_id=None): - with self._db_handle: - return super(SQLiteDatabase, self)._put_doc_if_newer(doc, - save_conflict=save_conflict, - replica_uid=replica_uid, replica_gen=replica_gen, - replica_trans_id=replica_trans_id) - - def _add_conflict(self, c, doc_id, my_doc_rev, my_content): - c.execute("INSERT INTO conflicts VALUES (?, ?, ?)", - (doc_id, my_doc_rev, my_content)) - - def _delete_conflicts(self, c, doc, conflict_revs): - deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs] - c.executemany("DELETE FROM conflicts" - " WHERE doc_id=? 
AND doc_rev=?", deleting) - doc.has_conflicts = self._has_conflicts(doc.doc_id) - - def _prune_conflicts(self, doc, doc_vcr): - if self._has_conflicts(doc.doc_id): - autoresolved = False - c_revs_to_prune = [] - for c_doc in self._get_conflicts(doc.doc_id): - c_vcr = vectorclock.VectorClockRev(c_doc.rev) - if doc_vcr.is_newer(c_vcr): - c_revs_to_prune.append(c_doc.rev) - elif doc.same_content_as(c_doc): - c_revs_to_prune.append(c_doc.rev) - doc_vcr.maximize(c_vcr) - autoresolved = True - if autoresolved: - doc_vcr.increment(self._replica_uid) - doc.rev = doc_vcr.as_str() - c = self._db_handle.cursor() - self._delete_conflicts(c, doc, c_revs_to_prune) - - def _force_doc_sync_conflict(self, doc): - my_doc = self._get_doc(doc.doc_id) - c = self._db_handle.cursor() - self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) - self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json()) - doc.has_conflicts = True - self._put_and_update_indexes(my_doc, doc) - - def resolve_doc(self, doc, conflicted_doc_revs): - with self._db_handle: - cur_doc = self._get_doc(doc.doc_id) - # TODO: https://bugs.launchpad.net/u1db/+bug/928274 - # I think we have a logic bug in resolve_doc - # Specifically, cur_doc.rev is always in the final vector - # clock of revisions that we supersede, even if it wasn't in - # conflicted_doc_revs. We still add it as a conflict, but the - # fact that _put_doc_if_newer propagates resolutions means I - # think that conflict could accidentally be resolved. We need - # to add a test for this case first. (create a rev, create a - # conflict, create another conflict, resolve the first rev - # and first conflict, then make sure that the resolved - # rev doesn't supersede the second conflict rev.) 
It *might* - # not matter, because the superseding rev is in as a - # conflict, but it does seem incorrect - new_rev = self._ensure_maximal_rev(cur_doc.rev, - conflicted_doc_revs) - superseded_revs = set(conflicted_doc_revs) - c = self._db_handle.cursor() - doc.rev = new_rev - if cur_doc.rev in superseded_revs: - self._put_and_update_indexes(cur_doc, doc) - else: - self._add_conflict(c, doc.doc_id, new_rev, doc.get_json()) - # TODO: Is there some way that we could construct a rev that would - # end up in superseded_revs, such that we add a conflict, and - # then immediately delete it? - self._delete_conflicts(c, doc, superseded_revs) - - def list_indexes(self): - """Return the list of indexes and their definitions.""" - c = self._db_handle.cursor() - # TODO: How do we test the ordering? - c.execute("SELECT name, field FROM index_definitions" - " ORDER BY name, offset") - definitions = [] - cur_name = None - for name, field in c.fetchall(): - if cur_name != name: - definitions.append((name, [])) - cur_name = name - definitions[-1][-1].append(field) - return definitions - - def _get_index_definition(self, index_name): - """Return the stored definition for a given index_name.""" - c = self._db_handle.cursor() - c.execute("SELECT field FROM index_definitions" - " WHERE name = ? ORDER BY offset", (index_name,)) - fields = [x[0] for x in c.fetchall()] - if not fields: - raise errors.IndexDoesNotExist - return fields - - @staticmethod - def _strip_glob(value): - """Remove the trailing * from a value.""" - assert value[-1] == '*' - return value[:-1] - - def _format_query(self, definition, key_values): - # First, build the definition. We join the document_fields table - # against itself, as many times as the 'width' of our definition. - # We then do a query for each key_value, one-at-a-time. - # Note: All of these strings are static, we could cache them, etc. 
- tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = ["d.doc_id = d%d.doc_id" - " AND d%d.field_name = ?" - % (i, i) for i in range(len(definition))] - wildcard_where = [novalue_where[i] - + (" AND d%d.value NOT NULL" % (i,)) - for i in range(len(definition))] - exact_where = [novalue_where[i] - + (" AND d%d.value = ?" % (i,)) - for i in range(len(definition))] - like_where = [novalue_where[i] - + (" AND d%d.value GLOB ?" % (i,)) - for i in range(len(definition))] - is_wildcard = False - # Merge the lists together, so that: - # [field1, field2, field3], [val1, val2, val3] - # Becomes: - # (field1, val1, field2, val2, field3, val3) - args = [] - where = [] - for idx, (field, value) in enumerate(zip(definition, key_values)): - args.append(field) - if value.endswith('*'): - if value == '*': - where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(like_where[idx]) - args.append(value) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(exact_where[idx]) - args.append(value) - statement = ( - "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " - "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " - "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " - "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( - ['d%d.value' % i for i in range(len(definition))]))) - return statement, args - - def get_from_index(self, index_name, *key_values): - definition = self._get_index_definition(index_name) - if len(key_values) != len(definition): - raise errors.InvalidValueForIndex() - statement, args = self._format_query(definition, key_values) - c = self._db_handle.cursor() - try: - c.execute(statement, tuple(args)) - except dbapi2.OperationalError, e: - raise dbapi2.OperationalError(str(e) + - '\nstatement: %s\nargs: %s\n' 
% (statement, args)) - res = c.fetchall() - results = [] - for row in res: - doc = self._factory(row[0], row[1], row[2]) - doc.has_conflicts = row[3] > 0 - results.append(doc) - return results - - def _format_range_query(self, definition, start_value, end_value): - tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = [ - "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in - range(len(definition))] - wildcard_where = [ - novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in - range(len(definition))] - like_where = [ - novalue_where[i] + ( - " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in - range(len(definition))] - range_where_lower = [ - novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in - range(len(definition))] - range_where_upper = [ - novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in - range(len(definition))] - args = [] - where = [] - if start_value: - if isinstance(start_value, basestring): - start_value = (start_value,) - if len(start_value) != len(definition): - raise errors.InvalidValueForIndex() - is_wildcard = False - for idx, (field, value) in enumerate(zip(definition, start_value)): - args.append(field) - if value.endswith('*'): - if value == '*': - where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(range_where_lower[idx]) - args.append(self._strip_glob(value)) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(range_where_lower[idx]) - args.append(value) - if end_value: - if isinstance(end_value, basestring): - end_value = (end_value,) - if len(end_value) != len(definition): - raise errors.InvalidValueForIndex() - is_wildcard = False - for idx, (field, value) in enumerate(zip(definition, end_value)): - args.append(field) - if value.endswith('*'): - if value == '*': - 
where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(like_where[idx]) - args.append(self._strip_glob(value)) - args.append(value) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(range_where_upper[idx]) - args.append(value) - statement = ( - "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " - "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " - "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " - "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( - ['d%d.value' % i for i in range(len(definition))]))) - return statement, args - - def get_range_from_index(self, index_name, start_value=None, - end_value=None): - """Return all documents with key values in the specified range.""" - definition = self._get_index_definition(index_name) - statement, args = self._format_range_query( - definition, start_value, end_value) - c = self._db_handle.cursor() - try: - c.execute(statement, tuple(args)) - except dbapi2.OperationalError, e: - raise dbapi2.OperationalError(str(e) + - '\nstatement: %s\nargs: %s\n' % (statement, args)) - res = c.fetchall() - results = [] - for row in res: - doc = self._factory(row[0], row[1], row[2]) - doc.has_conflicts = row[3] > 0 - results.append(doc) - return results - - def get_index_keys(self, index_name): - c = self._db_handle.cursor() - definition = self._get_index_definition(index_name) - value_fields = ', '.join([ - 'd%d.value' % i for i in range(len(definition))]) - tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = [ - "d.doc_id = d%d.doc_id AND d%d.field_name = ?" 
% (i, i) for i in - range(len(definition))] - where = [ - novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in - range(len(definition))] - statement = ( - "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( - value_fields, ', '.join(tables), ' AND '.join(where), - value_fields)) - try: - c.execute(statement, tuple(definition)) - except dbapi2.OperationalError, e: - raise dbapi2.OperationalError(str(e) + - '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) - return c.fetchall() - - def delete_index(self, index_name): - with self._db_handle: - c = self._db_handle.cursor() - c.execute("DELETE FROM index_definitions WHERE name = ?", - (index_name,)) - c.execute( - "DELETE FROM document_fields WHERE document_fields.field_name " - " NOT IN (SELECT field from index_definitions)") - - -class SQLiteSyncTarget(CommonSyncTarget): - - def get_sync_info(self, source_replica_uid): - source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( - source_replica_uid) - my_gen, my_trans_id = self._db._get_generation_info() - return ( - self._db._replica_uid, my_gen, my_trans_id, source_gen, - source_trans_id) - - def record_sync_info(self, source_replica_uid, source_replica_generation, - source_replica_transaction_id): - if self._trace_hook: - self._trace_hook('record_sync_info') - self._db._set_replica_gen_and_trans_id( - source_replica_uid, source_replica_generation, - source_replica_transaction_id) - - -class SQLitePartialExpandDatabase(SQLiteDatabase): - """An SQLite Backend that expands documents into a document_field table. - - It stores the original document text in document.doc. For fields that are - indexed, the data goes into document_fields. 
- """ - - _index_storage_value = 'expand referenced' - - def _get_indexed_fields(self): - """Determine what fields are indexed.""" - c = self._db_handle.cursor() - c.execute("SELECT field FROM index_definitions") - return set([x[0] for x in c.fetchall()]) - - def _evaluate_index(self, raw_doc, field): - parser = query_parser.Parser() - getter = parser.parse(field) - return getter.get(raw_doc) - - def _put_and_update_indexes(self, old_doc, doc): - c = self._db_handle.cursor() - if doc and not doc.is_tombstone(): - raw_doc = json.loads(doc.get_json()) - else: - raw_doc = {} - if old_doc is not None: - c.execute("UPDATE document SET doc_rev=?, content=?" - " WHERE doc_id = ?", - (doc.rev, doc.get_json(), doc.doc_id)) - c.execute("DELETE FROM document_fields WHERE doc_id = ?", - (doc.doc_id,)) - else: - c.execute("INSERT INTO document (doc_id, doc_rev, content)" - " VALUES (?, ?, ?)", - (doc.doc_id, doc.rev, doc.get_json())) - indexed_fields = self._get_indexed_fields() - if indexed_fields: - # It is expected that len(indexed_fields) is shorter than - # len(raw_doc) - getters = [(field, self._parse_index_definition(field)) - for field in indexed_fields] - self._update_indexes(doc.doc_id, raw_doc, getters, c) - trans_id = self._allocate_transaction_id() - c.execute("INSERT INTO transaction_log(doc_id, transaction_id)" - " VALUES (?, ?)", (doc.doc_id, trans_id)) - - def create_index(self, index_name, *index_expressions): - with self._db_handle: - c = self._db_handle.cursor() - cur_fields = self._get_indexed_fields() - definition = [(index_name, idx, field) - for idx, field in enumerate(index_expressions)] - try: - c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", - definition) - except dbapi2.IntegrityError as e: - stored_def = self._get_index_definition(index_name) - if stored_def == [x[-1] for x in definition]: - return - raise errors.IndexNameTakenError, e, sys.exc_info()[2] - new_fields = set( - [f for f in index_expressions if f not in cur_fields]) - 
if new_fields: - self._update_all_indexes(new_fields) - - def _iter_all_docs(self): - c = self._db_handle.cursor() - c.execute("SELECT doc_id, content FROM document") - while True: - next_rows = c.fetchmany() - if not next_rows: - break - for row in next_rows: - yield row - - def _update_all_indexes(self, new_fields): - """Iterate all the documents, and add content to document_fields. - - :param new_fields: The index definitions that need to be added. - """ - getters = [(field, self._parse_index_definition(field)) - for field in new_fields] - c = self._db_handle.cursor() - for doc_id, doc in self._iter_all_docs(): - if doc is None: - continue - raw_doc = json.loads(doc) - self._update_indexes(doc_id, raw_doc, getters, c) - -SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase) diff --git a/src/leap/soledad/u1db/commandline/__init__.py b/src/leap/soledad/u1db/commandline/__init__.py deleted file mode 100644 index 3f32e381..00000000 --- a/src/leap/soledad/u1db/commandline/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . diff --git a/src/leap/soledad/u1db/commandline/client.py b/src/leap/soledad/u1db/commandline/client.py deleted file mode 100644 index 15bf8561..00000000 --- a/src/leap/soledad/u1db/commandline/client.py +++ /dev/null @@ -1,497 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. 
-# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Commandline bindings for the u1db-client program.""" - -import argparse -import os -try: - import simplejson as json -except ImportError: - import json # noqa -import sys - -from u1db import ( - Document, - open as u1db_open, - sync, - errors, - ) -from u1db.commandline import command -from u1db.remote import ( - http_database, - http_target, - ) - - -client_commands = command.CommandGroup() - - -def set_oauth_credentials(client): - keys = os.environ.get('OAUTH_CREDENTIALS', None) - if keys is not None: - consumer_key, consumer_secret, \ - token_key, token_secret = keys.split(":") - client.set_oauth_credentials(consumer_key, consumer_secret, - token_key, token_secret) - - -class OneDbCmd(command.Command): - """Base class for commands operating on one local or remote database.""" - - def _open(self, database, create): - if database.startswith(('http://', 'https://')): - db = http_database.HTTPDatabase(database) - set_oauth_credentials(db) - db.open(create) - return db - else: - return u1db_open(database, create) - - -class CmdCreate(OneDbCmd): - """Create a new document from scratch""" - - name = 'create' - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('database', - help='The local or remote database to update', - metavar='database-path-or-url') - parser.add_argument('infile', nargs='?', default=None, - help='The file to read content from.') - parser.add_argument('--id', dest='doc_id', 
default=None, - help='Set the document identifier') - - def run(self, database, infile, doc_id): - if infile is None: - infile = self.stdin - db = self._open(database, create=False) - doc = db.create_doc_from_json(infile.read(), doc_id=doc_id) - self.stderr.write('id: %s\nrev: %s\n' % (doc.doc_id, doc.rev)) - -client_commands.register(CmdCreate) - - -class CmdDelete(OneDbCmd): - """Delete a document from the database""" - - name = 'delete' - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('database', - help='The local or remote database to update', - metavar='database-path-or-url') - parser.add_argument('doc_id', help='The document id to retrieve') - parser.add_argument('doc_rev', - help='The revision of the document (which is being superseded.)') - - def run(self, database, doc_id, doc_rev): - db = self._open(database, create=False) - doc = Document(doc_id, doc_rev, None) - db.delete_doc(doc) - self.stderr.write('rev: %s\n' % (doc.rev,)) - -client_commands.register(CmdDelete) - - -class CmdGet(OneDbCmd): - """Extract a document from the database""" - - name = 'get' - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('database', - help='The local or remote database to query', - metavar='database-path-or-url') - parser.add_argument('doc_id', help='The document id to retrieve.') - parser.add_argument('outfile', nargs='?', default=None, - help='The file to write the document to', - type=argparse.FileType('wb')) - - def run(self, database, doc_id, outfile): - if outfile is None: - outfile = self.stdout - try: - db = self._open(database, create=False) - except errors.DatabaseDoesNotExist: - self.stderr.write("Database does not exist.\n") - return 1 - doc = db.get_doc(doc_id) - if doc is None: - self.stderr.write('Document not found (id: %s)\n' % (doc_id,)) - return 1 # failed - if doc.is_tombstone(): - outfile.write('[document deleted]\n') - else: - outfile.write(doc.get_json() + '\n') - self.stderr.write('rev: 
%s\n' % (doc.rev,)) - if doc.has_conflicts: - self.stderr.write("Document has conflicts.\n") - -client_commands.register(CmdGet) - - -class CmdGetDocConflicts(OneDbCmd): - """Get the conflicts from a document""" - - name = 'get-doc-conflicts' - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('database', - help='The local database to query', - metavar='database-path') - parser.add_argument('doc_id', help='The document id to retrieve.') - - def run(self, database, doc_id): - try: - db = self._open(database, False) - except errors.DatabaseDoesNotExist: - self.stderr.write("Database does not exist.\n") - return 1 - conflicts = db.get_doc_conflicts(doc_id) - if not conflicts: - if db.get_doc(doc_id) is None: - self.stderr.write("Document does not exist.\n") - return 1 - self.stdout.write("[") - for i, doc in enumerate(conflicts): - if i: - self.stdout.write(",") - self.stdout.write( - json.dumps(dict(rev=doc.rev, content=doc.content), indent=4)) - self.stdout.write("]\n") - -client_commands.register(CmdGetDocConflicts) - - -class CmdInitDB(OneDbCmd): - """Create a new database""" - - name = 'init-db' - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('database', - help='The local or remote database to create', - metavar='database-path-or-url') - parser.add_argument('--replica-uid', default=None, - help='The unique identifier for this database (not for remote)') - - def run(self, database, replica_uid): - db = self._open(database, create=True) - if replica_uid is not None: - db._set_replica_uid(replica_uid) - -client_commands.register(CmdInitDB) - - -class CmdPut(OneDbCmd): - """Add a document to the database""" - - name = 'put' - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('database', - help='The local or remote database to update', - metavar='database-path-or-url'), - parser.add_argument('doc_id', help='The document id to retrieve') - parser.add_argument('doc_rev', - help='The 
revision of the document (which is being superseded.)') - parser.add_argument('infile', nargs='?', default=None, - help='The filename of the document that will be used for content', - type=argparse.FileType('rb')) - - def run(self, database, doc_id, doc_rev, infile): - if infile is None: - infile = self.stdin - try: - db = self._open(database, create=False) - doc = Document(doc_id, doc_rev, infile.read()) - doc_rev = db.put_doc(doc) - self.stderr.write('rev: %s\n' % (doc_rev,)) - except errors.DatabaseDoesNotExist: - self.stderr.write("Database does not exist.\n") - except errors.RevisionConflict: - if db.get_doc(doc_id) is None: - self.stderr.write("Document does not exist.\n") - else: - self.stderr.write("Given revision is not current.\n") - except errors.ConflictedDoc: - self.stderr.write( - "Document has conflicts.\n" - "Inspect with get-doc-conflicts, then resolve.\n") - else: - return - return 1 - -client_commands.register(CmdPut) - - -class CmdResolve(OneDbCmd): - """Resolve a conflicted document""" - - name = 'resolve-doc' - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('database', - help='The local or remote database to update', - metavar='database-path-or-url'), - parser.add_argument('doc_id', help='The conflicted document id') - parser.add_argument('doc_revs', metavar="doc-rev", nargs="+", - help='The revisions that the new content supersedes') - parser.add_argument('--infile', nargs='?', default=None, - help='The filename of the document that will be used for content', - type=argparse.FileType('rb')) - - def run(self, database, doc_id, doc_revs, infile): - if infile is None: - infile = self.stdin - try: - db = self._open(database, create=False) - except errors.DatabaseDoesNotExist: - self.stderr.write("Database does not exist.\n") - return 1 - doc = db.get_doc(doc_id) - if doc is None: - self.stderr.write("Document does not exist.\n") - return 1 - doc.set_json(infile.read()) - db.resolve_doc(doc, doc_revs) - 
self.stderr.write("rev: %s\n" % db.get_doc(doc_id).rev) - if doc.has_conflicts: - self.stderr.write("Document still has conflicts.\n") - -client_commands.register(CmdResolve) - - -class CmdSync(command.Command): - """Synchronize two databases""" - - name = 'sync' - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('source', help='database to sync from') - parser.add_argument('target', help='database to sync to') - - def _open_target(self, target): - if target.startswith(('http://', 'https://')): - st = http_target.HTTPSyncTarget.connect(target) - set_oauth_credentials(st) - else: - db = u1db_open(target, create=True) - st = db.get_sync_target() - return st - - def run(self, source, target): - """Start a Sync request.""" - source_db = u1db_open(source, create=False) - st = self._open_target(target) - syncer = sync.Synchronizer(source_db, st) - syncer.sync() - source_db.close() - -client_commands.register(CmdSync) - - -class CmdCreateIndex(OneDbCmd): - """Create an index""" - - name = "create-index" - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('database', help='The local database to update', - metavar='database-path') - parser.add_argument('index', help='the name of the index') - parser.add_argument('expression', help='an index expression', - nargs='+') - - def run(self, database, index, expression): - try: - db = self._open(database, create=False) - db.create_index(index, *expression) - except errors.DatabaseDoesNotExist: - self.stderr.write("Database does not exist.\n") - return 1 - except errors.IndexNameTakenError: - self.stderr.write("There is already a different index named %r.\n" - % (index,)) - return 1 - except errors.IndexDefinitionParseError: - self.stderr.write("Bad index expression.\n") - return 1 - -client_commands.register(CmdCreateIndex) - - -class CmdListIndexes(OneDbCmd): - """List existing indexes""" - - name = "list-indexes" - - @classmethod - def _populate_subparser(cls, parser): - 
parser.add_argument('database', help='The local database to query', - metavar='database-path') - - def run(self, database): - try: - db = self._open(database, create=False) - except errors.DatabaseDoesNotExist: - self.stderr.write("Database does not exist.\n") - return 1 - for (index, expression) in db.list_indexes(): - self.stdout.write("%s: %s\n" % (index, ", ".join(expression))) - -client_commands.register(CmdListIndexes) - - -class CmdDeleteIndex(OneDbCmd): - """Delete an index""" - - name = "delete-index" - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('database', help='The local database to update', - metavar='database-path') - parser.add_argument('index', help='the name of the index') - - def run(self, database, index): - try: - db = self._open(database, create=False) - except errors.DatabaseDoesNotExist: - self.stderr.write("Database does not exist.\n") - return 1 - db.delete_index(index) - -client_commands.register(CmdDeleteIndex) - - -class CmdGetIndexKeys(OneDbCmd): - """Get the index's keys""" - - name = "get-index-keys" - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('database', help='The local database to query', - metavar='database-path') - parser.add_argument('index', help='the name of the index') - - def run(self, database, index): - try: - db = self._open(database, create=False) - for key in db.get_index_keys(index): - self.stdout.write("%s\n" % (", ".join( - [i.encode('utf-8') for i in key],))) - except errors.DatabaseDoesNotExist: - self.stderr.write("Database does not exist.\n") - except errors.IndexDoesNotExist: - self.stderr.write("Index does not exist.\n") - else: - return - return 1 - -client_commands.register(CmdGetIndexKeys) - - -class CmdGetFromIndex(OneDbCmd): - """Find documents by searching an index""" - - name = "get-from-index" - argv = None - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('database', help='The local database to query', - 
metavar='database-path') - parser.add_argument('index', help='the name of the index') - parser.add_argument('values', metavar="value", - help='the value to look up (one per index column)', - nargs="+") - - def run(self, database, index, values): - try: - db = self._open(database, create=False) - docs = db.get_from_index(index, *values) - except errors.DatabaseDoesNotExist: - self.stderr.write("Database does not exist.\n") - except errors.IndexDoesNotExist: - self.stderr.write("Index does not exist.\n") - except errors.InvalidValueForIndex: - index_def = db._get_index_definition(index) - len_diff = len(index_def) - len(values) - if len_diff == 0: - # can't happen (HAH) - raise - argv = self.argv if self.argv is not None else sys.argv - self.stderr.write( - "Invalid query: " - "index %r requires %d query expression%s%s.\n" - "For example, the following would be valid:\n" - " %s %s %r %r %s\n" - % (index, - len(index_def), - "s" if len(index_def) > 1 else "", - ", not %d" % len(values) if len(values) else "", - argv[0], argv[1], database, index, - " ".join(map(repr, - values[:len(index_def)] - + ["*" for i in range(len_diff)])), - )) - except errors.InvalidGlobbing: - argv = self.argv if self.argv is not None else sys.argv - fixed = [] - for (i, v) in enumerate(values): - fixed.append(v) - if v.endswith('*'): - break - # values has at least one element, so i is defined - fixed.extend('*' * (len(values) - i - 1)) - self.stderr.write( - "Invalid query: a star can only be followed by stars.\n" - "For example, the following would be valid:\n" - " %s %s %r %r %s\n" - % (argv[0], argv[1], database, index, - " ".join(map(repr, fixed)))) - - else: - self.stdout.write("[") - for i, doc in enumerate(docs): - if i: - self.stdout.write(",") - self.stdout.write( - json.dumps( - dict(id=doc.doc_id, rev=doc.rev, content=doc.content), - indent=4)) - self.stdout.write("]\n") - return - return 1 - -client_commands.register(CmdGetFromIndex) - - -def main(args): - return 
client_commands.run_argv(args, sys.stdin, sys.stdout, sys.stderr) diff --git a/src/leap/soledad/u1db/commandline/command.py b/src/leap/soledad/u1db/commandline/command.py deleted file mode 100644 index eace0560..00000000 --- a/src/leap/soledad/u1db/commandline/command.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Command infrastructure for u1db""" - -import argparse -import inspect - - -class CommandGroup(object): - """A collection of commands.""" - - def __init__(self, description=None): - self.commands = {} - self.description = description - - def register(self, cmd): - """Register a new command to be incorporated with this group.""" - self.commands[cmd.name] = cmd - - def make_argparser(self): - """Create an argparse.ArgumentParser""" - parser = argparse.ArgumentParser(description=self.description) - subs = parser.add_subparsers(title='commands') - for name, cmd in sorted(self.commands.iteritems()): - sub = subs.add_parser(name, help=cmd.__doc__) - sub.set_defaults(subcommand=cmd) - cmd._populate_subparser(sub) - return parser - - def run_argv(self, argv, stdin, stdout, stderr): - """Run a command, from a sys.argv[1:] style input.""" - parser = self.make_argparser() - args = parser.parse_args(argv) - cmd = args.subcommand(stdin, stdout, stderr) - params, _, _, _ = inspect.getargspec(cmd.run) - vals = [] - for param in params[1:]: - vals.append(getattr(args, param)) 
- return cmd.run(*vals) - - -class Command(object): - """Definition of a Command that can be run. - - :cvar name: The name of the command, so that you can run - 'u1db-client '. - """ - - name = None - - def __init__(self, stdin, stdout, stderr): - self.stdin = stdin - self.stdout = stdout - self.stderr = stderr - - @classmethod - def _populate_subparser(cls, parser): - """Child classes should override this to provide their arguments.""" - raise NotImplementedError(cls._populate_subparser) - - def run(self, *args): - """This is where the magic happens. - - Subclasses should implement this, requesting their specific arguments. - """ - raise NotImplementedError(self.run) diff --git a/src/leap/soledad/u1db/commandline/serve.py b/src/leap/soledad/u1db/commandline/serve.py deleted file mode 100644 index 0bb0e641..00000000 --- a/src/leap/soledad/u1db/commandline/serve.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""Build server for u1db-serve.""" - -from paste import httpserver - -from u1db.remote import ( - http_app, - server_state, - ) - - -def make_server(host, port, working_dir): - """Make a server on host and port exposing dbs living in working_dir.""" - state = server_state.ServerState() - state.set_workingdir(working_dir) - application = http_app.HTTPApp(state) - server = httpserver.WSGIServer(application, (host, port), - httpserver.WSGIHandler) - return server diff --git a/src/leap/soledad/u1db/errors.py b/src/leap/soledad/u1db/errors.py deleted file mode 100644 index 967c7c38..00000000 --- a/src/leap/soledad/u1db/errors.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""A list of errors that u1db can raise.""" - - -class U1DBError(Exception): - """Generic base class for U1DB errors.""" - - # description/tag for identifying the error during transmission (http,...) 
- wire_description = "error" - - def __init__(self, message=None): - self.message = message - - -class RevisionConflict(U1DBError): - """The document revisions supplied does not match the current version.""" - - wire_description = "revision conflict" - - -class InvalidJSON(U1DBError): - """Content was not valid json.""" - - -class InvalidContent(U1DBError): - """Content was not a python dictionary.""" - - -class InvalidDocId(U1DBError): - """A document was requested with an invalid document identifier.""" - - wire_description = "invalid document id" - - -class MissingDocIds(U1DBError): - """Needs document ids.""" - - wire_description = "missing document ids" - - -class DocumentTooBig(U1DBError): - """Document exceeds the maximum document size for this database.""" - - wire_description = "document too big" - - -class UserQuotaExceeded(U1DBError): - """Document exceeds the maximum document size for this database.""" - - wire_description = "user quota exceeded" - - -class SubscriptionNeeded(U1DBError): - """User needs a subscription to be able to use this replica..""" - - wire_description = "user needs subscription" - - -class InvalidTransactionId(U1DBError): - """Invalid transaction for generation.""" - - wire_description = "invalid transaction id" - - -class InvalidGeneration(U1DBError): - """Generation was previously synced with a different transaction id.""" - - wire_description = "invalid generation" - - -class ConflictedDoc(U1DBError): - """The document is conflicted, you must call resolve before put()""" - - -class InvalidValueForIndex(U1DBError): - """The values supplied does not match the index definition.""" - - -class InvalidGlobbing(U1DBError): - """Raised if wildcard matches are not strictly at the tail of the request. 
- """ - - -class DocumentDoesNotExist(U1DBError): - """The document does not exist.""" - - wire_description = "document does not exist" - - -class DocumentAlreadyDeleted(U1DBError): - """The document was already deleted.""" - - wire_description = "document already deleted" - - -class DatabaseDoesNotExist(U1DBError): - """The database does not exist.""" - - wire_description = "database does not exist" - - -class IndexNameTakenError(U1DBError): - """The given index name is already taken.""" - - -class IndexDefinitionParseError(U1DBError): - """The index definition cannot be parsed.""" - - -class IndexDoesNotExist(U1DBError): - """No index of that name exists.""" - - -class Unauthorized(U1DBError): - """Request wasn't authorized properly.""" - - wire_description = "unauthorized" - - -class HTTPError(U1DBError): - """Unspecific HTTP errror.""" - - wire_description = None - - def __init__(self, status, message=None, headers={}): - self.status = status - self.message = message - self.headers = headers - - def __str__(self): - if not self.message: - return "HTTPError(%d)" % self.status - else: - return "HTTPError(%d, %r)" % (self.status, self.message) - - -class Unavailable(HTTPError): - """Server not available not serve request.""" - - wire_description = "unavailable" - - def __init__(self, message=None, headers={}): - super(Unavailable, self).__init__(503, message, headers) - - def __str__(self): - if not self.message: - return "Unavailable()" - else: - return "Unavailable(%r)" % self.message - - -class BrokenSyncStream(U1DBError): - """Unterminated or otherwise broken sync exchange stream.""" - - wire_description = None - - -class UnknownAuthMethod(U1DBError): - """Unknown auhorization method.""" - - wire_description = None - - -# mapping wire (transimission) descriptions/tags for errors to the exceptions -wire_description_to_exc = dict( - (x.wire_description, x) for x in globals().values() - if getattr(x, 'wire_description', None) not in (None, "error") -) 
-wire_description_to_exc["error"] = U1DBError - - -# -# wire error descriptions not corresponding to an exception -DOCUMENT_DELETED = "document deleted" diff --git a/src/leap/soledad/u1db/query_parser.py b/src/leap/soledad/u1db/query_parser.py deleted file mode 100644 index f564821f..00000000 --- a/src/leap/soledad/u1db/query_parser.py +++ /dev/null @@ -1,370 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Code for parsing Index definitions.""" - -import re -from u1db import ( - errors, - ) - - -class Getter(object): - """Get values from a document based on a specification.""" - - def get(self, raw_doc): - """Get a value from the document. - - :param raw_doc: a python dictionary to get the value from. - :return: A list of values that match the description. - """ - raise NotImplementedError(self.get) - - -class StaticGetter(Getter): - """A getter that returns a defined value (independent of the doc).""" - - def __init__(self, value): - """Create a StaticGetter. - - :param value: the value to return when get is called. 
- """ - if value is None: - self.value = [] - elif isinstance(value, list): - self.value = value - else: - self.value = [value] - - def get(self, raw_doc): - return self.value - - -def extract_field(raw_doc, subfields, index=0): - if not isinstance(raw_doc, dict): - return [] - val = raw_doc.get(subfields[index]) - if val is None: - return [] - if index < len(subfields) - 1: - if isinstance(val, list): - results = [] - for item in val: - results.extend(extract_field(item, subfields, index + 1)) - return results - if isinstance(val, dict): - return extract_field(val, subfields, index + 1) - return [] - if isinstance(val, dict): - return [] - if isinstance(val, list): - # Strip anything in the list that isn't a simple type - return [v for v in val if not isinstance(v, (dict, list))] - return [val] - - -class ExtractField(Getter): - """Extract a field from the document.""" - - def __init__(self, field): - """Create an ExtractField object. - - When a document is passed to get() this will return a value - from the document based on the field specifier passed to - the constructor. - - None will be returned if the field is nonexistant, or refers to an - object, rather than a simple type or list of simple types. - - :param field: a specifier for the field to return. - This is either a field name, or a dotted field name. - """ - self.field = field.split('.') - - def get(self, raw_doc): - return extract_field(raw_doc, self.field) - - -class Transformation(Getter): - """A transformation on a value from another Getter.""" - - name = None - arity = 1 - args = ['expression'] - - def __init__(self, inner): - """Create a transformation. - - :param inner: the argument(s) to the transformation. - """ - self.inner = inner - - def get(self, raw_doc): - inner_values = self.inner.get(raw_doc) - assert isinstance(inner_values, list),\ - 'get() should always return a list' - return self.transform(inner_values) - - def transform(self, values): - """Transform the values. 
- - This should be implemented by subclasses to transform the - value when get() is called. - - :param values: the values from the other Getter - :return: the transformed values. - """ - raise NotImplementedError(self.transform) - - -class Lower(Transformation): - """Lowercase a string. - - This transformation will return None for non-string inputs. However, - it will lowercase any strings in a list, dropping any elements - that are not strings. - """ - - name = "lower" - - def _can_transform(self, val): - return isinstance(val, basestring) - - def transform(self, values): - if not values: - return [] - return [val.lower() for val in values if self._can_transform(val)] - - -class Number(Transformation): - """Convert an integer to a zero padded string. - - This transformation will return None for non-integer inputs. However, it - will transform any integers in a list, dropping any elements that are not - integers. - """ - - name = 'number' - arity = 2 - args = ['expression', int] - - def __init__(self, inner, number): - super(Number, self).__init__(inner) - self.padding = "%%0%sd" % number - - def _can_transform(self, val): - return isinstance(val, int) and not isinstance(val, bool) - - def transform(self, values): - """Transform any integers in values into zero padded strings.""" - if not values: - return [] - return [self.padding % (v,) for v in values if self._can_transform(v)] - - -class Bool(Transformation): - """Convert bool to string.""" - - name = "bool" - args = ['expression'] - - def _can_transform(self, val): - return isinstance(val, bool) - - def transform(self, values): - """Transform any booleans in values into strings.""" - if not values: - return [] - return [('1' if v else '0') for v in values if self._can_transform(v)] - - -class SplitWords(Transformation): - """Split a string on whitespace. - - This Getter will return [] for non-string inputs. It will however - split any strings in an input list, discarding any elements that - are not strings. 
- """ - - name = "split_words" - - def _can_transform(self, val): - return isinstance(val, basestring) - - def transform(self, values): - if not values: - return [] - result = set() - for value in values: - if self._can_transform(value): - for word in value.split(): - result.add(word) - return list(result) - - -class Combine(Transformation): - """Combine multiple expressions into a single index.""" - - name = "combine" - # variable number of args - arity = -1 - - def __init__(self, *inner): - super(Combine, self).__init__(inner) - - def get(self, raw_doc): - inner_values = [] - for inner in self.inner: - inner_values.extend(inner.get(raw_doc)) - return self.transform(inner_values) - - def transform(self, values): - return values - - -class IsNull(Transformation): - """Indicate whether the input is None. - - This Getter returns a bool indicating whether the input is nil. - """ - - name = "is_null" - - def transform(self, values): - return [len(values) == 0] - - -def check_fieldname(fieldname): - if fieldname.endswith('.'): - raise errors.IndexDefinitionParseError( - "Fieldname cannot end in '.':%s^" % (fieldname,)) - - -class Parser(object): - """Parse an index expression into a sequence of transformations.""" - - _transformations = {} - _delimiters = re.compile("\(|\)|,") - - def __init__(self): - self._tokens = [] - - def _set_expression(self, expression): - self._open_parens = 0 - self._tokens = [] - expression = expression.strip() - while expression: - delimiter = self._delimiters.search(expression) - if delimiter: - idx = delimiter.start() - if idx == 0: - result, expression = (expression[:1], expression[1:]) - self._tokens.append(result) - else: - result, expression = (expression[:idx], expression[idx:]) - result = result.strip() - if result: - self._tokens.append(result) - else: - expression = expression.strip() - if expression: - self._tokens.append(expression) - expression = None - - def _get_token(self): - if self._tokens: - return self._tokens.pop(0) - - 
def _peek_token(self): - if self._tokens: - return self._tokens[0] - - @staticmethod - def _to_getter(term): - if isinstance(term, Getter): - return term - check_fieldname(term) - return ExtractField(term) - - def _parse_op(self, op_name): - self._get_token() # '(' - op = self._transformations.get(op_name, None) - if op is None: - raise errors.IndexDefinitionParseError( - "Unknown operation: %s" % op_name) - args = [] - while True: - args.append(self._parse_term()) - sep = self._get_token() - if sep == ')': - break - if sep != ',': - raise errors.IndexDefinitionParseError( - "Unexpected token '%s' in parentheses." % (sep,)) - parsed = [] - for i, arg in enumerate(args): - arg_type = op.args[i % len(op.args)] - if arg_type == 'expression': - inner = self._to_getter(arg) - else: - try: - inner = arg_type(arg) - except ValueError, e: - raise errors.IndexDefinitionParseError( - "Invalid value %r for argument type %r " - "(%r)." % (arg, arg_type, e)) - parsed.append(inner) - return op(*parsed) - - def _parse_term(self): - term = self._get_token() - if term is None: - raise errors.IndexDefinitionParseError( - "Unexpected end of index definition.") - if term in (',', ')', '('): - raise errors.IndexDefinitionParseError( - "Unexpected token '%s' at start of expression." % (term,)) - next_token = self._peek_token() - if next_token == '(': - return self._parse_op(term) - return term - - def parse(self, expression): - self._set_expression(expression) - term = self._to_getter(self._parse_term()) - if self._peek_token(): - raise errors.IndexDefinitionParseError( - "Unexpected token '%s' after end of expression." 
- % (self._peek_token(),)) - return term - - def parse_all(self, fields): - return [self.parse(field) for field in fields] - - @classmethod - def register_transormation(cls, transform): - assert transform.name not in cls._transformations, ( - "Transform %s already registered for %s" - % (transform.name, cls._transformations[transform.name])) - cls._transformations[transform.name] = transform - - -Parser.register_transormation(SplitWords) -Parser.register_transormation(Lower) -Parser.register_transormation(Number) -Parser.register_transormation(Bool) -Parser.register_transormation(IsNull) -Parser.register_transormation(Combine) diff --git a/src/leap/soledad/u1db/remote/__init__.py b/src/leap/soledad/u1db/remote/__init__.py deleted file mode 100644 index 3f32e381..00000000 --- a/src/leap/soledad/u1db/remote/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . diff --git a/src/leap/soledad/u1db/remote/basic_auth_middleware.py b/src/leap/soledad/u1db/remote/basic_auth_middleware.py deleted file mode 100644 index a2cbff62..00000000 --- a/src/leap/soledad/u1db/remote/basic_auth_middleware.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. 
-# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . -"""U1DB Basic Auth authorisation WSGI middleware.""" -import httplib -try: - import simplejson as json -except ImportError: - import json # noqa -from wsgiref.util import shift_path_info - - -class Unauthorized(Exception): - """User authorization failed.""" - - -class BasicAuthMiddleware(object): - """U1DB Basic Auth Authorisation WSGI middleware.""" - - def __init__(self, app, prefix): - self.app = app - self.prefix = prefix - - def _error(self, start_response, status, description, message=None): - start_response("%d %s" % (status, httplib.responses[status]), - [('content-type', 'application/json')]) - err = {"error": description} - if message: - err['message'] = message - return [json.dumps(err)] - - def __call__(self, environ, start_response): - if self.prefix and not environ['PATH_INFO'].startswith(self.prefix): - return self._error(start_response, 400, "bad request") - auth = environ.get('HTTP_AUTHORIZATION') - if not auth: - return self._error(start_response, 401, "unauthorized", - "Missing Basic Authentication.") - scheme, encoded = auth.split(None, 1) - if scheme.lower() != 'basic': - return self._error( - start_response, 401, "unauthorized", - "Missing Basic Authentication") - user, password = encoded.decode('base64').split(':', 1) - try: - self.verify_user(environ, user, password) - except Unauthorized: - return self._error( - start_response, 401, "unauthorized", - "Incorrect password or login.") - del environ['HTTP_AUTHORIZATION'] - shift_path_info(environ) - return self.app(environ, start_response) - - def verify_user(self, environ, username, password): - raise NotImplementedError(self.verify_user) 
diff --git a/src/leap/soledad/u1db/remote/http_app.py b/src/leap/soledad/u1db/remote/http_app.py deleted file mode 100644 index 3d7d4248..00000000 --- a/src/leap/soledad/u1db/remote/http_app.py +++ /dev/null @@ -1,629 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""HTTP Application exposing U1DB.""" - -import functools -import httplib -import inspect -try: - import simplejson as json -except ImportError: - import json # noqa -import sys -import urlparse - -import routes.mapper - -from u1db import ( - __version__ as _u1db_version, - DBNAME_CONSTRAINTS, - Document, - errors, - sync, - ) -from u1db.remote import ( - http_errors, - utils, - ) - - -def parse_bool(expression): - """Parse boolean querystring parameter.""" - if expression == 'true': - return True - return False - - -def parse_list(expression): - if expression is None: - return [] - return [t.strip() for t in expression.split(',')] - - -def none_or_str(expression): - if expression is None: - return None - return str(expression) - - -class BadRequest(Exception): - """Bad request.""" - - -class _FencedReader(object): - """Read and get lines from a file but not past a given length.""" - - MAXCHUNK = 8192 - - def __init__(self, rfile, total, max_entry_size): - self.rfile = rfile - self.remaining = total - self.max_entry_size = max_entry_size - self._kept = None - - def read_chunk(self, atmost): - if self._kept is not None: - # ignore 
atmost, kept data should be a subchunk anyway - kept, self._kept = self._kept, None - return kept - if self.remaining == 0: - return '' - data = self.rfile.read(min(self.remaining, atmost)) - self.remaining -= len(data) - return data - - def getline(self): - line_parts = [] - size = 0 - while True: - chunk = self.read_chunk(self.MAXCHUNK) - if chunk == '': - break - nl = chunk.find("\n") - if nl != -1: - size += nl + 1 - if size > self.max_entry_size: - raise BadRequest - line_parts.append(chunk[:nl + 1]) - rest = chunk[nl + 1:] - self._kept = rest or None - break - else: - size += len(chunk) - if size > self.max_entry_size: - raise BadRequest - line_parts.append(chunk) - return ''.join(line_parts) - - -def http_method(**control): - """Decoration for handling of query arguments and content for a HTTP - method. - - args and content here are the query arguments and body of the incoming - HTTP requests. - - Match query arguments to python method arguments: - w = http_method()(f) - w(self, args, content) => args["content"]=content; - f(self, **args) - - JSON deserialize content to arguments: - w = http_method(content_as_args=True,...)(f) - w(self, args, content) => args.update(json.loads(content)); - f(self, **args) - - Support conversions (e.g int): - w = http_method(Arg=Conv,...)(f) - w(self, args, content) => args["Arg"]=Conv(args["Arg"]); - f(self, **args) - - Enforce no use of query arguments: - w = http_method(no_query=True,...)(f) - w(self, args, content) raises BadRequest if args is not empty - - Argument mismatches, deserialisation failures produce BadRequest. 
- """ - content_as_args = control.pop('content_as_args', False) - no_query = control.pop('no_query', False) - conversions = control.items() - - def wrap(f): - argspec = inspect.getargspec(f) - assert argspec.args[0] == "self" - nargs = len(argspec.args) - ndefaults = len(argspec.defaults or ()) - required_args = set(argspec.args[1:nargs - ndefaults]) - all_args = set(argspec.args) - - @functools.wraps(f) - def wrapper(self, args, content): - if no_query and args: - raise BadRequest() - if content is not None: - if content_as_args: - try: - args.update(json.loads(content)) - except ValueError: - raise BadRequest() - else: - args["content"] = content - if not (required_args <= set(args) <= all_args): - raise BadRequest("Missing required arguments.") - for name, conv in conversions: - if name not in args: - continue - try: - args[name] = conv(args[name]) - except ValueError: - raise BadRequest() - return f(self, **args) - - return wrapper - - return wrap - - -class URLToResource(object): - """Mappings from URLs to resources.""" - - def __init__(self): - self._map = routes.mapper.Mapper(controller_scan=None) - - def register(self, resource_cls): - # register - self._map.connect(None, resource_cls.url_pattern, - resource_cls=resource_cls, - requirements={"dbname": DBNAME_CONSTRAINTS}) - self._map.create_regs() - return resource_cls - - def match(self, path): - params = self._map.match(path) - if params is None: - return None, None - resource_cls = params.pop('resource_cls') - return resource_cls, params - -url_to_resource = URLToResource() - - -@url_to_resource.register -class GlobalResource(object): - """Global (root) resource.""" - - url_pattern = "/" - - def __init__(self, state, responder): - self.responder = responder - - @http_method() - def get(self): - self.responder.send_response_json(version=_u1db_version) - - -@url_to_resource.register -class DatabaseResource(object): - """Database resource.""" - - url_pattern = "/{dbname}" - - def __init__(self, dbname, 
state, responder): - self.dbname = dbname - self.state = state - self.responder = responder - - @http_method() - def get(self): - self.state.check_database(self.dbname) - self.responder.send_response_json(200) - - @http_method(content_as_args=True) - def put(self): - self.state.ensure_database(self.dbname) - self.responder.send_response_json(200, ok=True) - - @http_method() - def delete(self): - self.state.delete_database(self.dbname) - self.responder.send_response_json(200, ok=True) - - -@url_to_resource.register -class DocsResource(object): - """Documents resource.""" - - url_pattern = "/{dbname}/docs" - - def __init__(self, dbname, state, responder): - self.responder = responder - self.db = state.open_database(dbname) - - @http_method(doc_ids=parse_list, check_for_conflicts=parse_bool, - include_deleted=parse_bool) - def get(self, doc_ids=None, check_for_conflicts=True, - include_deleted=False): - if doc_ids is None: - raise errors.MissingDocIds - docs = self.db.get_docs(doc_ids, include_deleted=include_deleted) - self.responder.content_type = 'application/json' - self.responder.start_response(200) - self.responder.start_stream(), - for doc in docs: - entry = dict( - doc_id=doc.doc_id, doc_rev=doc.rev, content=doc.get_json(), - has_conflicts=doc.has_conflicts) - self.responder.stream_entry(entry) - self.responder.end_stream() - self.responder.finish_response() - - -@url_to_resource.register -class DocResource(object): - """Document resource.""" - - url_pattern = "/{dbname}/doc/{id:.*}" - - def __init__(self, dbname, id, state, responder): - self.id = id - self.responder = responder - self.db = state.open_database(dbname) - - @http_method(old_rev=str) - def put(self, content, old_rev=None): - doc = Document(self.id, old_rev, content) - doc_rev = self.db.put_doc(doc) - if old_rev is None: - status = 201 # created - else: - status = 200 - self.responder.send_response_json(status, rev=doc_rev) - - @http_method(old_rev=str) - def delete(self, old_rev=None): - doc = 
Document(self.id, old_rev, None) - self.db.delete_doc(doc) - self.responder.send_response_json(200, rev=doc.rev) - - @http_method(include_deleted=parse_bool) - def get(self, include_deleted=False): - doc = self.db.get_doc(self.id, include_deleted=include_deleted) - if doc is None: - wire_descr = errors.DocumentDoesNotExist.wire_description - self.responder.send_response_json( - http_errors.wire_description_to_status[wire_descr], - error=wire_descr, - headers={ - 'x-u1db-rev': '', - 'x-u1db-has-conflicts': 'false' - }) - return - headers = { - 'x-u1db-rev': doc.rev, - 'x-u1db-has-conflicts': json.dumps(doc.has_conflicts) - } - if doc.is_tombstone(): - self.responder.send_response_json( - http_errors.wire_description_to_status[ - errors.DOCUMENT_DELETED], - error=errors.DOCUMENT_DELETED, - headers=headers) - else: - self.responder.send_response_content( - doc.get_json(), headers=headers) - - -@url_to_resource.register -class SyncResource(object): - """Sync endpoint resource.""" - - # maximum allowed request body size - max_request_size = 15 * 1024 * 1024 # 15Mb - # maximum allowed entry/line size in request body - max_entry_size = 10 * 1024 * 1024 # 10Mb - - url_pattern = "/{dbname}/sync-from/{source_replica_uid}" - - # pluggable - sync_exchange_class = sync.SyncExchange - - def __init__(self, dbname, source_replica_uid, state, responder): - self.source_replica_uid = source_replica_uid - self.responder = responder - self.state = state - self.dbname = dbname - self.replica_uid = None - - def get_target(self): - return self.state.open_database(self.dbname).get_sync_target() - - @http_method() - def get(self): - result = self.get_target().get_sync_info(self.source_replica_uid) - self.responder.send_response_json( - target_replica_uid=result[0], target_replica_generation=result[1], - target_replica_transaction_id=result[2], - source_replica_uid=self.source_replica_uid, - source_replica_generation=result[3], - source_transaction_id=result[4]) - - 
@http_method(generation=int, - content_as_args=True, no_query=True) - def put(self, generation, transaction_id): - self.get_target().record_sync_info(self.source_replica_uid, - generation, - transaction_id) - self.responder.send_response_json(ok=True) - - # Implements the same logic as LocalSyncTarget.sync_exchange - - @http_method(last_known_generation=int, last_known_trans_id=none_or_str, - content_as_args=True) - def post_args(self, last_known_generation, last_known_trans_id=None, - ensure=False): - if ensure: - db, self.replica_uid = self.state.ensure_database(self.dbname) - else: - db = self.state.open_database(self.dbname) - db.validate_gen_and_trans_id( - last_known_generation, last_known_trans_id) - self.sync_exch = self.sync_exchange_class( - db, self.source_replica_uid, last_known_generation) - - @http_method(content_as_args=True) - def post_stream_entry(self, id, rev, content, gen, trans_id): - doc = Document(id, rev, content) - self.sync_exch.insert_doc_from_source(doc, gen, trans_id) - - def post_end(self): - - def send_doc(doc, gen, trans_id): - entry = dict(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), - gen=gen, trans_id=trans_id) - self.responder.stream_entry(entry) - - new_gen = self.sync_exch.find_changes_to_return() - self.responder.content_type = 'application/x-u1db-sync-stream' - self.responder.start_response(200) - self.responder.start_stream(), - header = {"new_generation": new_gen, - "new_transaction_id": self.sync_exch.new_trans_id} - if self.replica_uid is not None: - header['replica_uid'] = self.replica_uid - self.responder.stream_entry(header) - self.sync_exch.return_docs(send_doc) - self.responder.end_stream() - self.responder.finish_response() - - -class HTTPResponder(object): - """Encode responses from the server back to the client.""" - - # a multi document response will put args and documents - # each on one line of the response body - - def __init__(self, start_response): - self._started = False - self._stream_state = -1 - 
self._no_initial_obj = True - self.sent_response = False - self._start_response = start_response - self._write = None - self.content_type = 'application/json' - self.content = [] - - def start_response(self, status, obj_dic=None, headers={}): - """start sending response with optional first json object.""" - if self._started: - return - self._started = True - status_text = httplib.responses[status] - self._write = self._start_response('%d %s' % (status, status_text), - [('content-type', self.content_type), - ('cache-control', 'no-cache')] + - headers.items()) - # xxx version in headers - if obj_dic is not None: - self._no_initial_obj = False - self._write(json.dumps(obj_dic) + "\r\n") - - def finish_response(self): - """finish sending response.""" - self.sent_response = True - - def send_response_json(self, status=200, headers={}, **kwargs): - """send and finish response with json object body from keyword args.""" - content = json.dumps(kwargs) + "\r\n" - self.send_response_content(content, headers=headers, status=status) - - def send_response_content(self, content, status=200, headers={}): - """send and finish response with content""" - headers['content-length'] = str(len(content)) - self.start_response(status, headers=headers) - if self._stream_state == 1: - self.content = [',\r\n', content] - else: - self.content = [content] - self.finish_response() - - def start_stream(self): - "start stream (array) as part of the response." - assert self._started and self._no_initial_obj - self._stream_state = 0 - self._write("[") - - def stream_entry(self, entry): - "send stream entry as part of the response." - assert self._stream_state != -1 - if self._stream_state == 0: - self._stream_state = 1 - self._write('\r\n') - else: - self._write(',\r\n') - self._write(json.dumps(entry)) - - def end_stream(self): - "end stream (array)." 
- assert self._stream_state != -1 - self._write("\r\n]\r\n") - - -class HTTPInvocationByMethodWithBody(object): - """Invoke methods on a resource.""" - - def __init__(self, resource, environ, parameters): - self.resource = resource - self.environ = environ - self.max_request_size = getattr( - resource, 'max_request_size', parameters.max_request_size) - self.max_entry_size = getattr( - resource, 'max_entry_size', parameters.max_entry_size) - - def _lookup(self, method): - try: - return getattr(self.resource, method) - except AttributeError: - raise BadRequest() - - def __call__(self): - args = urlparse.parse_qsl(self.environ['QUERY_STRING'], - strict_parsing=False) - try: - args = dict( - (k.decode('utf-8'), v.decode('utf-8')) for k, v in args) - except ValueError: - raise BadRequest() - method = self.environ['REQUEST_METHOD'].lower() - if method in ('get', 'delete'): - meth = self._lookup(method) - return meth(args, None) - else: - # we expect content-length > 0, reconsider if we move - # to support chunked enconding - try: - content_length = int(self.environ['CONTENT_LENGTH']) - except (ValueError, KeyError): - raise BadRequest - if content_length <= 0: - raise BadRequest - if content_length > self.max_request_size: - raise BadRequest - reader = _FencedReader(self.environ['wsgi.input'], content_length, - self.max_entry_size) - content_type = self.environ.get('CONTENT_TYPE') - if content_type == 'application/json': - meth = self._lookup(method) - body = reader.read_chunk(sys.maxint) - return meth(args, body) - elif content_type == 'application/x-u1db-sync-stream': - meth_args = self._lookup('%s_args' % method) - meth_entry = self._lookup('%s_stream_entry' % method) - meth_end = self._lookup('%s_end' % method) - body_getline = reader.getline - if body_getline().strip() != '[': - raise BadRequest() - line = body_getline() - line, comma = utils.check_and_strip_comma(line.strip()) - meth_args(args, line) - while True: - line = body_getline() - entry = line.strip() - if 
entry == ']': - break - if not entry or not comma: # empty or no prec comma - raise BadRequest - entry, comma = utils.check_and_strip_comma(entry) - meth_entry({}, entry) - if comma or body_getline(): # extra comma or data - raise BadRequest - return meth_end() - else: - raise BadRequest() - - -class HTTPApp(object): - - # maximum allowed request body size - max_request_size = 15 * 1024 * 1024 # 15Mb - # maximum allowed entry/line size in request body - max_entry_size = 10 * 1024 * 1024 # 10Mb - - def __init__(self, state): - self.state = state - - def _lookup_resource(self, environ, responder): - resource_cls, params = url_to_resource.match(environ['PATH_INFO']) - if resource_cls is None: - raise BadRequest # 404 instead? - resource = resource_cls( - state=self.state, responder=responder, **params) - return resource - - def __call__(self, environ, start_response): - responder = HTTPResponder(start_response) - self.request_begin(environ) - try: - resource = self._lookup_resource(environ, responder) - HTTPInvocationByMethodWithBody(resource, environ, self)() - except errors.U1DBError, e: - self.request_u1db_error(environ, e) - status = http_errors.wire_description_to_status.get( - e.wire_description, 500) - responder.send_response_json(status, error=e.wire_description) - except BadRequest: - self.request_bad_request(environ) - responder.send_response_json(400, error="bad request") - except KeyboardInterrupt: - raise - except: - self.request_failed(environ) - raise - else: - self.request_done(environ) - return responder.content - - # hooks for tracing requests - - def request_begin(self, environ): - """Hook called at the beginning of processing a request.""" - pass - - def request_done(self, environ): - """Hook called when done processing a request.""" - pass - - def request_u1db_error(self, environ, exc): - """Hook called when processing a request resulted in a U1DBError. - - U1DBError passed as exc. 
- """ - pass - - def request_bad_request(self, environ): - """Hook called when processing a bad request. - - No actual processing was done. - """ - pass - - def request_failed(self, environ): - """Hook called when processing a request failed unexpectedly. - - Invoked from an except block, so there's interpreter exception - information available. - """ - pass diff --git a/src/leap/soledad/u1db/remote/http_client.py b/src/leap/soledad/u1db/remote/http_client.py deleted file mode 100644 index decddda3..00000000 --- a/src/leap/soledad/u1db/remote/http_client.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Base class to make requests to a remote HTTP server.""" - -import httplib -from oauth import oauth -try: - import simplejson as json -except ImportError: - import json # noqa -import socket -import ssl -import sys -import urlparse -import urllib - -from time import sleep -from u1db import ( - errors, - ) -from u1db.remote import ( - http_errors, - ) - -from u1db.remote.ssl_match_hostname import ( # noqa - CertificateError, - match_hostname, - ) - -# Ubuntu/debian -# XXX other... 
-CA_CERTS = "/etc/ssl/certs/ca-certificates.crt" - - -def _encode_query_parameter(value): - """Encode query parameter.""" - if isinstance(value, bool): - if value: - value = 'true' - else: - value = 'false' - return unicode(value).encode('utf-8') - - -class _VerifiedHTTPSConnection(httplib.HTTPSConnection): - """HTTPSConnection verifying server side certificates.""" - # derived from httplib.py - - def connect(self): - "Connect to a host on a given (SSL) port." - - sock = socket.create_connection((self.host, self.port), - self.timeout, self.source_address) - if self._tunnel_host: - self.sock = sock - self._tunnel() - if sys.platform.startswith('linux'): - cert_opts = { - 'cert_reqs': ssl.CERT_REQUIRED, - 'ca_certs': CA_CERTS - } - else: - # XXX no cert verification implemented elsewhere for now - cert_opts = {} - self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, - ssl_version=ssl.PROTOCOL_SSLv3, - **cert_opts - ) - if cert_opts: - match_hostname(self.sock.getpeercert(), self.host) - - -class HTTPClientBase(object): - """Base class to make requests to a remote HTTP server.""" - - # by default use HMAC-SHA1 OAuth signature method to not disclose - # tokens - # NB: given that the content bodies are not covered by the - # signatures though, to achieve security (against man-in-the-middle - # attacks for example) one would need HTTPS - oauth_signature_method = oauth.OAuthSignatureMethod_HMAC_SHA1() - - # Will use these delays to retry on 503 befor finally giving up. The final - # 0 is there to not wait after the final try fails. 
- _delays = (1, 1, 2, 4, 0) - - def __init__(self, url, creds=None): - self._url = urlparse.urlsplit(url) - self._conn = None - self._creds = {} - if creds is not None: - if len(creds) != 1: - raise errors.UnknownAuthMethod() - auth_meth, credentials = creds.items()[0] - try: - set_creds = getattr(self, 'set_%s_credentials' % auth_meth) - except AttributeError: - raise errors.UnknownAuthMethod(auth_meth) - set_creds(**credentials) - - def set_oauth_credentials(self, consumer_key, consumer_secret, - token_key, token_secret): - self._creds = {'oauth': ( - oauth.OAuthConsumer(consumer_key, consumer_secret), - oauth.OAuthToken(token_key, token_secret))} - - def _ensure_connection(self): - if self._conn is not None: - return - if self._url.scheme == 'https': - connClass = _VerifiedHTTPSConnection - else: - connClass = httplib.HTTPConnection - self._conn = connClass(self._url.hostname, self._url.port) - - def close(self): - if self._conn: - self._conn.close() - self._conn = None - - # xxx retry mechanism? 
- - def _error(self, respdic): - descr = respdic.get("error") - exc_cls = errors.wire_description_to_exc.get(descr) - if exc_cls is not None: - message = respdic.get("message") - raise exc_cls(message) - - def _response(self): - resp = self._conn.getresponse() - body = resp.read() - headers = dict(resp.getheaders()) - if resp.status in (200, 201): - return body, headers - elif resp.status in http_errors.ERROR_STATUSES: - try: - respdic = json.loads(body) - except ValueError: - pass - else: - self._error(respdic) - # special case - if resp.status == 503: - raise errors.Unavailable(body, headers) - raise errors.HTTPError(resp.status, body, headers) - - def _sign_request(self, method, url_query, params): - if 'oauth' in self._creds: - consumer, token = self._creds['oauth'] - full_url = "%s://%s%s" % (self._url.scheme, self._url.netloc, - url_query) - oauth_req = oauth.OAuthRequest.from_consumer_and_token( - consumer, token, - http_method=method, - parameters=params, - http_url=full_url - ) - oauth_req.sign_request( - self.oauth_signature_method, consumer, token) - # Authorization: OAuth ... - return oauth_req.to_header().items() - else: - return [] - - def _request(self, method, url_parts, params=None, body=None, - content_type=None): - self._ensure_connection() - unquoted_url = url_query = self._url.path - if url_parts: - if not url_query.endswith('/'): - url_query += '/' - unquoted_url = url_query - url_query += '/'.join(urllib.quote(part, safe='') - for part in url_parts) - # oauth performs its own quoting - unquoted_url += '/'.join(url_parts) - encoded_params = {} - if params: - for key, value in params.items(): - key = unicode(key).encode('utf-8') - encoded_params[key] = _encode_query_parameter(value) - url_query += ('?' 
+ urllib.urlencode(encoded_params)) - if body is not None and not isinstance(body, basestring): - body = json.dumps(body) - content_type = 'application/json' - headers = {} - if content_type: - headers['content-type'] = content_type - headers.update( - self._sign_request(method, unquoted_url, encoded_params)) - for delay in self._delays: - try: - self._conn.request(method, url_query, body, headers) - return self._response() - except errors.Unavailable, e: - sleep(delay) - raise e - - def _request_json(self, method, url_parts, params=None, body=None, - content_type=None): - res, headers = self._request(method, url_parts, params, body, - content_type) - return json.loads(res), headers diff --git a/src/leap/soledad/u1db/remote/http_database.py b/src/leap/soledad/u1db/remote/http_database.py deleted file mode 100644 index 6901baad..00000000 --- a/src/leap/soledad/u1db/remote/http_database.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""HTTPDatabase to access a remote db over the HTTP API.""" - -try: - import simplejson as json -except ImportError: - import json # noqa -import uuid - -from u1db import ( - Database, - Document, - errors, - ) -from u1db.remote import ( - http_client, - http_errors, - http_target, - ) - - -DOCUMENT_DELETED_STATUS = http_errors.wire_description_to_status[ - errors.DOCUMENT_DELETED] - - -class HTTPDatabase(http_client.HTTPClientBase, Database): - """Implement the Database API to a remote HTTP server.""" - - def __init__(self, url, document_factory=None, creds=None): - super(HTTPDatabase, self).__init__(url, creds=creds) - self._factory = document_factory or Document - - def set_document_factory(self, factory): - self._factory = factory - - @staticmethod - def open_database(url, create): - db = HTTPDatabase(url) - db.open(create) - return db - - @staticmethod - def delete_database(url): - db = HTTPDatabase(url) - db._delete() - db.close() - - def open(self, create): - if create: - self._ensure() - else: - self._check() - - def _check(self): - return self._request_json('GET', [])[0] - - def _ensure(self): - self._request_json('PUT', [], {}, {}) - - def _delete(self): - self._request_json('DELETE', [], {}, {}) - - def put_doc(self, doc): - if doc.doc_id is None: - raise errors.InvalidDocId() - params = {} - if doc.rev is not None: - params['old_rev'] = doc.rev - res, headers = self._request_json('PUT', ['doc', doc.doc_id], params, - doc.get_json(), 'application/json') - doc.rev = res['rev'] - return res['rev'] - - def get_doc(self, doc_id, include_deleted=False): - try: - res, headers = self._request( - 'GET', ['doc', doc_id], {"include_deleted": include_deleted}) - except errors.DocumentDoesNotExist: - return None - except errors.HTTPError, e: - if (e.status == DOCUMENT_DELETED_STATUS and - 'x-u1db-rev' in e.headers): - res = None - headers = e.headers - else: - raise - doc_rev = headers['x-u1db-rev'] - has_conflicts = json.loads(headers['x-u1db-has-conflicts']) - 
doc = self._factory(doc_id, doc_rev, res) - doc.has_conflicts = has_conflicts - return doc - - def get_docs(self, doc_ids, check_for_conflicts=True, - include_deleted=False): - if not doc_ids: - return - doc_ids = ','.join(doc_ids) - res, headers = self._request( - 'GET', ['docs'], { - "doc_ids": doc_ids, "include_deleted": include_deleted, - "check_for_conflicts": check_for_conflicts}) - for doc_dict in json.loads(res): - doc = self._factory( - doc_dict['doc_id'], doc_dict['doc_rev'], doc_dict['content']) - doc.has_conflicts = doc_dict['has_conflicts'] - yield doc - - def create_doc_from_json(self, content, doc_id=None): - if doc_id is None: - doc_id = 'D-%s' % (uuid.uuid4().hex,) - res, headers = self._request_json('PUT', ['doc', doc_id], {}, - content, 'application/json') - new_doc = self._factory(doc_id, res['rev'], content) - return new_doc - - def delete_doc(self, doc): - if doc.doc_id is None: - raise errors.InvalidDocId() - params = {'old_rev': doc.rev} - res, headers = self._request_json('DELETE', - ['doc', doc.doc_id], params) - doc.make_tombstone() - doc.rev = res['rev'] - - def get_sync_target(self): - st = http_target.HTTPSyncTarget(self._url.geturl()) - st._creds = self._creds - return st diff --git a/src/leap/soledad/u1db/remote/http_errors.py b/src/leap/soledad/u1db/remote/http_errors.py deleted file mode 100644 index 2039c5b2..00000000 --- a/src/leap/soledad/u1db/remote/http_errors.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. 
-# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Information about the encoding of errors over HTTP.""" - -from u1db import ( - errors, - ) - - -# error wire descriptions mapping to HTTP status codes -wire_description_to_status = dict([ - (errors.InvalidDocId.wire_description, 400), - (errors.MissingDocIds.wire_description, 400), - (errors.Unauthorized.wire_description, 401), - (errors.DocumentTooBig.wire_description, 403), - (errors.UserQuotaExceeded.wire_description, 403), - (errors.SubscriptionNeeded.wire_description, 403), - (errors.DatabaseDoesNotExist.wire_description, 404), - (errors.DocumentDoesNotExist.wire_description, 404), - (errors.DocumentAlreadyDeleted.wire_description, 404), - (errors.RevisionConflict.wire_description, 409), - (errors.InvalidGeneration.wire_description, 409), - (errors.InvalidTransactionId.wire_description, 409), - (errors.Unavailable.wire_description, 503), -# without matching exception - (errors.DOCUMENT_DELETED, 404) -]) - - -ERROR_STATUSES = set(wire_description_to_status.values()) -# 400 included explicitly for tests -ERROR_STATUSES.add(400) diff --git a/src/leap/soledad/u1db/remote/http_target.py b/src/leap/soledad/u1db/remote/http_target.py deleted file mode 100644 index 1028963e..00000000 --- a/src/leap/soledad/u1db/remote/http_target.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. 
-# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""SyncTarget API implementation to a remote HTTP server.""" - -try: - import simplejson as json -except ImportError: - import json # noqa - -from u1db import ( - Document, - SyncTarget, - ) -from u1db.errors import ( - BrokenSyncStream, - ) -from u1db.remote import ( - http_client, - utils, - ) - - -class HTTPSyncTarget(http_client.HTTPClientBase, SyncTarget): - """Implement the SyncTarget api to a remote HTTP server.""" - - @staticmethod - def connect(url): - return HTTPSyncTarget(url) - - def get_sync_info(self, source_replica_uid): - self._ensure_connection() - res, _ = self._request_json('GET', ['sync-from', source_replica_uid]) - return (res['target_replica_uid'], res['target_replica_generation'], - res['target_replica_transaction_id'], - res['source_replica_generation'], res['source_transaction_id']) - - def record_sync_info(self, source_replica_uid, source_replica_generation, - source_transaction_id): - self._ensure_connection() - if self._trace_hook: # for tests - self._trace_hook('record_sync_info') - self._request_json('PUT', ['sync-from', source_replica_uid], {}, - {'generation': source_replica_generation, - 'transaction_id': source_transaction_id}) - - def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): - parts = data.splitlines() # one at a time - if not parts or parts[0] != '[': - raise BrokenSyncStream - data = parts[1:-1] - comma = False - if data: - line, comma = utils.check_and_strip_comma(data[0]) - res = json.loads(line) - if ensure_callback and 'replica_uid' in res: - ensure_callback(res['replica_uid']) - for entry in data[1:]: - if not comma: # missing in between comma - raise BrokenSyncStream - line, comma = utils.check_and_strip_comma(entry) - entry = json.loads(line) - doc = Document(entry['id'], entry['rev'], entry['content']) - return_doc_cb(doc, entry['gen'], entry['trans_id']) - if parts[-1] != ']': - 
try: - partdic = json.loads(parts[-1]) - except ValueError: - pass - else: - if isinstance(partdic, dict): - self._error(partdic) - raise BrokenSyncStream - if not data or comma: # no entries or bad extra comma - raise BrokenSyncStream - return res - - def sync_exchange(self, docs_by_generations, source_replica_uid, - last_known_generation, last_known_trans_id, - return_doc_cb, ensure_callback=None): - self._ensure_connection() - if self._trace_hook: # for tests - self._trace_hook('sync_exchange') - url = '%s/sync-from/%s' % (self._url.path, source_replica_uid) - self._conn.putrequest('POST', url) - self._conn.putheader('content-type', 'application/x-u1db-sync-stream') - for header_name, header_value in self._sign_request('POST', url, {}): - self._conn.putheader(header_name, header_value) - entries = ['['] - size = 1 - - def prepare(**dic): - entry = comma + '\r\n' + json.dumps(dic) - entries.append(entry) - return len(entry) - - comma = '' - size += prepare( - last_known_generation=last_known_generation, - last_known_trans_id=last_known_trans_id, - ensure=ensure_callback is not None) - comma = ',' - for doc, gen, trans_id in docs_by_generations: - size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), - gen=gen, trans_id=trans_id) - entries.append('\r\n]') - size += len(entries[-1]) - self._conn.putheader('content-length', str(size)) - self._conn.endheaders() - for entry in entries: - self._conn.send(entry) - entries = None - data, _ = self._response() - res = self._parse_sync_stream(data, return_doc_cb, ensure_callback) - data = None - return res['new_generation'], res['new_transaction_id'] - - # for tests - _trace_hook = None - - def _set_trace_hook_shallow(self, cb): - self._trace_hook = cb diff --git a/src/leap/soledad/u1db/remote/oauth_middleware.py b/src/leap/soledad/u1db/remote/oauth_middleware.py deleted file mode 100644 index 5772580a..00000000 --- a/src/leap/soledad/u1db/remote/oauth_middleware.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 
2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . -"""U1DB OAuth authorisation WSGI middleware.""" -import httplib -from oauth import oauth -try: - import simplejson as json -except ImportError: - import json # noqa -from urllib import quote -from wsgiref.util import shift_path_info - - -sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1() -sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT() - - -class OAuthMiddleware(object): - """U1DB OAuth Authorisation WSGI middleware.""" - - # max seconds the request timestamp is allowed to be shifted - # from arrival time - timestamp_threshold = 300 - - def __init__(self, app, base_url, prefix='/~/'): - self.app = app - self.base_url = base_url - self.prefix = prefix - - def get_oauth_data_store(self): - """Provide a oauth.OAuthDataStore.""" - raise NotImplementedError(self.get_oauth_data_store) - - def _error(self, start_response, status, description, message=None): - start_response("%d %s" % (status, httplib.responses[status]), - [('content-type', 'application/json')]) - err = {"error": description} - if message: - err['message'] = message - return [json.dumps(err)] - - def __call__(self, environ, start_response): - if self.prefix and not environ['PATH_INFO'].startswith(self.prefix): - return self._error(start_response, 400, "bad request") - headers = {} - if 'HTTP_AUTHORIZATION' in environ: - headers['Authorization'] = environ['HTTP_AUTHORIZATION'] - oauth_req = 
oauth.OAuthRequest.from_request( - http_method=environ['REQUEST_METHOD'], - http_url=self.base_url + environ['PATH_INFO'], - headers=headers, - query_string=environ['QUERY_STRING'] - ) - if oauth_req is None: - return self._error(start_response, 401, "unauthorized", - "Missing OAuth.") - try: - self.verify(environ, oauth_req) - except oauth.OAuthError, e: - return self._error(start_response, 401, "unauthorized", - e.message) - shift_path_info(environ) - return self.app(environ, start_response) - - def verify(self, environ, oauth_req): - """Verify OAuth request, put user_id in the environ.""" - oauth_server = oauth.OAuthServer(self.get_oauth_data_store()) - oauth_server.timestamp_threshold = self.timestamp_threshold - oauth_server.add_signature_method(sign_meth_HMAC_SHA1) - oauth_server.add_signature_method(sign_meth_PLAINTEXT) - consumer, token, parameters = oauth_server.verify_request(oauth_req) - # filter out oauth bits - environ['QUERY_STRING'] = '&'.join("%s=%s" % (quote(k, safe=''), - quote(v, safe='')) - for k, v in parameters.iteritems()) - return consumer, token diff --git a/src/leap/soledad/u1db/remote/server_state.py b/src/leap/soledad/u1db/remote/server_state.py deleted file mode 100644 index 96581359..00000000 --- a/src/leap/soledad/u1db/remote/server_state.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""State for servers exposing a set of U1DB databases.""" -import os -import errno - -class ServerState(object): - """Passed to a Request when it is instantiated. - - This is used to track server-side state, such as working-directory, open - databases, etc. - """ - - def __init__(self): - self._workingdir = None - - def set_workingdir(self, path): - self._workingdir = path - - def _relpath(self, relpath): - # Note: We don't want to allow absolute paths here, because we - # don't want to expose the filesystem. We should also check that - # relpath doesn't have '..' in it, etc. - return self._workingdir + '/' + relpath - - def open_database(self, path): - """Open a database at the given location.""" - from u1db.backends import sqlite_backend - full_path = self._relpath(path) - return sqlite_backend.SQLiteDatabase.open_database(full_path, - create=False) - - def check_database(self, path): - """Check if the database at the given location exists. - - Simply returns if it does or raises DatabaseDoesNotExist. 
- """ - db = self.open_database(path) - db.close() - - def ensure_database(self, path): - """Ensure database at the given location.""" - from u1db.backends import sqlite_backend - full_path = self._relpath(path) - db = sqlite_backend.SQLiteDatabase.open_database(full_path, - create=True) - return db, db._replica_uid - - def delete_database(self, path): - """Delete database at the given location.""" - from u1db.backends import sqlite_backend - full_path = self._relpath(path) - sqlite_backend.SQLiteDatabase.delete_database(full_path) diff --git a/src/leap/soledad/u1db/remote/ssl_match_hostname.py b/src/leap/soledad/u1db/remote/ssl_match_hostname.py deleted file mode 100644 index fbabc177..00000000 --- a/src/leap/soledad/u1db/remote/ssl_match_hostname.py +++ /dev/null @@ -1,64 +0,0 @@ -"""The match_hostname() function from Python 3.2, essential when using SSL.""" -# XXX put it here until it's packaged - -import re - -__version__ = '3.2a3' - - -class CertificateError(ValueError): - pass - - -def _dnsname_to_pat(dn): - pats = [] - for frag in dn.split(r'.'): - if frag == '*': - # When '*' is a fragment by itself, it matches a non-empty dotless - # fragment. - pats.append('[^.]+') - else: - # Otherwise, '*' matches any dotless fragment. - frag = re.escape(frag) - pats.append(frag.replace(r'\*', '[^.]*')) - return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) - - -def match_hostname(cert, hostname): - """Verify that *cert* (in decoded format as returned by - SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules - are mostly followed, but IP addresses are not accepted for *hostname*. - - CertificateError is raised on failure. On success, the function - returns nothing. 
- """ - if not cert: - raise ValueError("empty or no certificate") - dnsnames = [] - san = cert.get('subjectAltName', ()) - for key, value in san: - if key == 'DNS': - if _dnsname_to_pat(value).match(hostname): - return - dnsnames.append(value) - if not san: - # The subject is only checked when subjectAltName is empty - for sub in cert.get('subject', ()): - for key, value in sub: - # XXX according to RFC 2818, the most specific Common Name - # must be used. - if key == 'commonName': - if _dnsname_to_pat(value).match(hostname): - return - dnsnames.append(value) - if len(dnsnames) > 1: - raise CertificateError("hostname %r " - "doesn't match either of %s" - % (hostname, ', '.join(map(repr, dnsnames)))) - elif len(dnsnames) == 1: - raise CertificateError("hostname %r " - "doesn't match %r" - % (hostname, dnsnames[0])) - else: - raise CertificateError("no appropriate commonName or " - "subjectAltName fields were found") diff --git a/src/leap/soledad/u1db/remote/utils.py b/src/leap/soledad/u1db/remote/utils.py deleted file mode 100644 index 14cedea9..00000000 --- a/src/leap/soledad/u1db/remote/utils.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""Utilities for details of the procotol.""" - - -def check_and_strip_comma(line): - if line and line[-1] == ',': - return line[:-1], True - return line, False diff --git a/src/leap/soledad/u1db/sync.py b/src/leap/soledad/u1db/sync.py deleted file mode 100644 index 3375d097..00000000 --- a/src/leap/soledad/u1db/sync.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""The synchronization utilities for U1DB.""" -from itertools import izip - -import u1db -from u1db import errors - - -class Synchronizer(object): - """Collect the state around synchronizing 2 U1DB replicas. - - Synchronization is bi-directional, in that new items in the source are sent - to the target, and new items in the target are returned to the source. - However, it still recognizes that one side is initiating the request. Also, - at the moment, conflicts are only created in the source. - """ - - def __init__(self, source, sync_target): - """Create a new Synchronization object. - - :param source: A Database - :param sync_target: A SyncTarget - """ - self.source = source - self.sync_target = sync_target - self.target_replica_uid = None - self.num_inserted = 0 - - def _insert_doc_from_target(self, doc, replica_gen, trans_id): - """Try to insert synced document from target. 
- - Implements TAKE OTHER semantics: any document from the target - that is in conflict will be taken as the new official value, - while the current conflicting value will be stored alongside - as a conflict. In the process indexes will be updated etc. - - :return: None - """ - # Increases self.num_inserted depending whether the document - # was effectively inserted. - state, _ = self.source._put_doc_if_newer(doc, save_conflict=True, - replica_uid=self.target_replica_uid, replica_gen=replica_gen, - replica_trans_id=trans_id) - if state == 'inserted': - self.num_inserted += 1 - elif state == 'converged': - # magical convergence - pass - elif state == 'superseded': - # we have something newer, will be taken care of at the next sync - pass - else: - assert state == 'conflicted' - # The doc was saved as a conflict, so the database was updated - self.num_inserted += 1 - - def _record_sync_info_with_the_target(self, start_generation): - """Record our new after sync generation with the target if gapless. - - Any documents received from the target will cause the local - database to increment its generation. We do not want to send - them back to the target in a future sync. However, there could - also be concurrent updates from another process doing eg - 'put_doc' while the sync was running. And we do want to - synchronize those documents. We can tell if there was a - concurrent update by comparing our new generation number - versus the generation we started, and how many documents we - inserted from the target. If it matches exactly, then we can - record with the target that they are fully up to date with our - new generation. 
- """ - cur_gen, trans_id = self.source._get_generation_info() - if (cur_gen == start_generation + self.num_inserted - and self.num_inserted > 0): - self.sync_target.record_sync_info( - self.source._replica_uid, cur_gen, trans_id) - - def sync(self, callback=None, autocreate=False): - """Synchronize documents between source and target.""" - sync_target = self.sync_target - # get target identifier, its current generation, - # and its last-seen database generation for this source - try: - (self.target_replica_uid, target_gen, target_trans_id, - target_my_gen, target_my_trans_id) = sync_target.get_sync_info( - self.source._replica_uid) - except errors.DatabaseDoesNotExist: - if not autocreate: - raise - # will try to ask sync_exchange() to create the db - self.target_replica_uid = None - target_gen, target_trans_id = 0, '' - target_my_gen, target_my_trans_id = 0, '' - def ensure_callback(replica_uid): - self.target_replica_uid = replica_uid - else: - ensure_callback = None - # validate the generation and transaction id the target knows about us - self.source.validate_gen_and_trans_id( - target_my_gen, target_my_trans_id) - # what's changed since that generation and this current gen - my_gen, _, changes = self.source.whats_changed(target_my_gen) - - # this source last-seen database generation for the target - if self.target_replica_uid is None: - target_last_known_gen, target_last_known_trans_id = 0, '' - else: - target_last_known_gen, target_last_known_trans_id = \ - self.source._get_replica_gen_and_trans_id(self.target_replica_uid) - if not changes and target_last_known_gen == target_gen: - if target_trans_id != target_last_known_trans_id: - raise errors.InvalidTransactionId - return my_gen - changed_doc_ids = [doc_id for doc_id, _, _ in changes] - # prepare to send all the changed docs - docs_to_send = self.source.get_docs(changed_doc_ids, - check_for_conflicts=False, include_deleted=True) - # TODO: there must be a way to not iterate twice - docs_by_generation = 
zip( - docs_to_send, (gen for _, gen, _ in changes), - (trans for _, _, trans in changes)) - - # exchange documents and try to insert the returned ones with - # the target, return target synced-up-to gen - new_gen, new_trans_id = sync_target.sync_exchange( - docs_by_generation, self.source._replica_uid, - target_last_known_gen, target_last_known_trans_id, - self._insert_doc_from_target, ensure_callback=ensure_callback) - # record target synced-up-to generation including applying what we sent - self.source._set_replica_gen_and_trans_id( - self.target_replica_uid, new_gen, new_trans_id) - - # if gapless record current reached generation with target - self._record_sync_info_with_the_target(my_gen) - - return my_gen - - -class SyncExchange(object): - """Steps and state for carrying through a sync exchange on a target.""" - - def __init__(self, db, source_replica_uid, last_known_generation): - self._db = db - self.source_replica_uid = source_replica_uid - self.source_last_known_generation = last_known_generation - self.seen_ids = {} # incoming ids not superseded - self.changes_to_return = None - self.new_gen = None - self.new_trans_id = None - # for tests - self._incoming_trace = [] - self._trace_hook = None - self._db._last_exchange_log = { - 'receive': {'docs': self._incoming_trace}, - 'return': None - } - - def _set_trace_hook(self, cb): - self._trace_hook = cb - - def _trace(self, state): - if not self._trace_hook: - return - self._trace_hook(state) - - def insert_doc_from_source(self, doc, source_gen, trans_id): - """Try to insert synced document from source. - - Conflicting documents are not inserted but will be sent over - to the sync source. - - It keeps track of progress by storing the document source - generation as well. - - The 1st step of a sync exchange is to call this repeatedly to - try insert all incoming documents from the source. - - :param doc: A Document object. - :param source_gen: The source generation of doc. 
- :return: None - """ - state, at_gen = self._db._put_doc_if_newer(doc, save_conflict=False, - replica_uid=self.source_replica_uid, replica_gen=source_gen, - replica_trans_id=trans_id) - if state == 'inserted': - self.seen_ids[doc.doc_id] = at_gen - elif state == 'converged': - # magical convergence - self.seen_ids[doc.doc_id] = at_gen - elif state == 'superseded': - # we have something newer that we will return - pass - else: - # conflict that we will returne - assert state == 'conflicted' - # for tests - self._incoming_trace.append((doc.doc_id, doc.rev)) - self._db._last_exchange_log['receive'].update({ - 'source_uid': self.source_replica_uid, - 'source_gen': source_gen - }) - - def find_changes_to_return(self): - """Find changes to return. - - Find changes since last_known_generation in db generation - order using whats_changed. It excludes documents ids that have - already been considered (superseded by the sender, etc). - - :return: new_generation - the generation of this database - which the caller can consider themselves to be synchronized after - processing the returned documents. - """ - self._db._last_exchange_log['receive'].update({ # for tests - 'last_known_gen': self.source_last_known_generation - }) - self._trace('before whats_changed') - gen, trans_id, changes = self._db.whats_changed( - self.source_last_known_generation) - self._trace('after whats_changed') - self.new_gen = gen - self.new_trans_id = trans_id - seen_ids = self.seen_ids - # changed docs that weren't superseded by or converged with - self.changes_to_return = [ - (doc_id, gen, trans_id) for (doc_id, gen, trans_id) in changes - # there was a subsequent update - if doc_id not in seen_ids or seen_ids.get(doc_id) < gen] - return self.new_gen - - def return_docs(self, return_doc_cb): - """Return the changed documents and their last change generation - repeatedly invoking the callback return_doc_cb. - - The final step of a sync exchange. 
- - :param: return_doc_cb(doc, gen, trans_id): is a callback - used to return the documents with their last change generation - to the target replica. - :return: None - """ - changes_to_return = self.changes_to_return - # return docs, including conflicts - changed_doc_ids = [doc_id for doc_id, _, _ in changes_to_return] - self._trace('before get_docs') - docs = self._db.get_docs( - changed_doc_ids, check_for_conflicts=False, include_deleted=True) - - docs_by_gen = izip( - docs, (gen for _, gen, _ in changes_to_return), - (trans_id for _, _, trans_id in changes_to_return)) - _outgoing_trace = [] # for tests - for doc, gen, trans_id in docs_by_gen: - return_doc_cb(doc, gen, trans_id) - _outgoing_trace.append((doc.doc_id, doc.rev)) - # for tests - self._db._last_exchange_log['return'] = { - 'docs': _outgoing_trace, - 'last_gen': self.new_gen - } - - -class LocalSyncTarget(u1db.SyncTarget): - """Common sync target implementation logic for all local sync targets.""" - - def __init__(self, db): - self._db = db - self._trace_hook = None - - def sync_exchange(self, docs_by_generations, source_replica_uid, - last_known_generation, last_known_trans_id, - return_doc_cb, ensure_callback=None): - self._db.validate_gen_and_trans_id( - last_known_generation, last_known_trans_id) - sync_exch = SyncExchange( - self._db, source_replica_uid, last_known_generation) - if self._trace_hook: - sync_exch._set_trace_hook(self._trace_hook) - # 1st step: try to insert incoming docs and record progress - for doc, doc_gen, trans_id in docs_by_generations: - sync_exch.insert_doc_from_source(doc, doc_gen, trans_id) - # 2nd step: find changed documents (including conflicts) to return - new_gen = sync_exch.find_changes_to_return() - # final step: return docs and record source replica sync point - sync_exch.return_docs(return_doc_cb) - return new_gen, sync_exch.new_trans_id - - def _set_trace_hook(self, cb): - self._trace_hook = cb diff --git a/src/leap/soledad/u1db/tests/__init__.py 
b/src/leap/soledad/u1db/tests/__init__.py deleted file mode 100644 index b8e16b15..00000000 --- a/src/leap/soledad/u1db/tests/__init__.py +++ /dev/null @@ -1,463 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Test infrastructure for U1DB""" - -import copy -import shutil -import socket -import tempfile -import threading - -try: - import simplejson as json -except ImportError: - import json # noqa - -from wsgiref import simple_server - -from oauth import oauth -from sqlite3 import dbapi2 -from StringIO import StringIO - -import testscenarios -import testtools - -from u1db import ( - errors, - Document, - ) -from u1db.backends import ( - inmemory, - sqlite_backend, - ) -from u1db.remote import ( - server_state, - ) - -try: - from u1db.tests import c_backend_wrapper - c_backend_error = None -except ImportError, e: - c_backend_wrapper = None # noqa - c_backend_error = e - -# Setting this means that failing assertions will not include this module in -# their traceback. However testtools doesn't seem to set it, and we don't want -# this level to be omitted, but the lower levels to be shown. -# __unittest = 1 - - -class TestCase(testtools.TestCase): - - def createTempDir(self, prefix='u1db-tmp-'): - """Create a temporary directory to do some work in. - - This directory will be scheduled for cleanup when the test ends. 
- """ - tempdir = tempfile.mkdtemp(prefix=prefix) - self.addCleanup(shutil.rmtree, tempdir) - return tempdir - - def make_document(self, doc_id, doc_rev, content, has_conflicts=False): - return self.make_document_for_test( - self, doc_id, doc_rev, content, has_conflicts) - - def make_document_for_test(self, test, doc_id, doc_rev, content, - has_conflicts): - return make_document_for_test( - test, doc_id, doc_rev, content, has_conflicts) - - def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts): - """Assert that the document in the database looks correct.""" - exp_doc = self.make_document(doc_id, doc_rev, content, - has_conflicts=has_conflicts) - self.assertEqual(exp_doc, db.get_doc(doc_id)) - - def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content, - has_conflicts): - """Assert that the document in the database looks correct.""" - exp_doc = self.make_document(doc_id, doc_rev, content, - has_conflicts=has_conflicts) - self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True)) - - def assertGetDocConflicts(self, db, doc_id, conflicts): - """Assert what conflicts are stored for a given doc_id. - - :param conflicts: A list of (doc_rev, content) pairs. - The first item must match the first item returned from the - database, however the rest can be returned in any order. 
- """ - if conflicts: - conflicts = [(rev, (json.loads(cont) if isinstance(cont, basestring) - else cont)) for (rev, cont) in conflicts] - conflicts = conflicts[:1] + sorted(conflicts[1:]) - actual = db.get_doc_conflicts(doc_id) - if actual: - actual = [(doc.rev, (json.loads(doc.get_json()) - if doc.get_json() is not None else None)) for doc in actual] - actual = actual[:1] + sorted(actual[1:]) - self.assertEqual(conflicts, actual) - - -def multiply_scenarios(a_scenarios, b_scenarios): - """Create the cross-product of scenarios.""" - - all_scenarios = [] - for a_name, a_attrs in a_scenarios: - for b_name, b_attrs in b_scenarios: - name = '%s,%s' % (a_name, b_name) - attrs = dict(a_attrs) - attrs.update(b_attrs) - all_scenarios.append((name, attrs)) - return all_scenarios - - -simple_doc = '{"key": "value"}' -nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}' - - -def make_memory_database_for_test(test, replica_uid): - return inmemory.InMemoryDatabase(replica_uid) - - -def copy_memory_database_for_test(test, db): - # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS - # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE - # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN - # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR - # HOUSE. 
- new_db = inmemory.InMemoryDatabase(db._replica_uid) - new_db._transaction_log = db._transaction_log[:] - new_db._docs = copy.deepcopy(db._docs) - new_db._conflicts = copy.deepcopy(db._conflicts) - new_db._indexes = copy.deepcopy(db._indexes) - new_db._factory = db._factory - return new_db - - -def make_sqlite_partial_expanded_for_test(test, replica_uid): - db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') - db._set_replica_uid(replica_uid) - return db - - -def copy_sqlite_partial_expanded_for_test(test, db): - # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS - # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE - # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN - # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR - # HOUSE. - new_db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') - tmpfile = StringIO() - for line in db._db_handle.iterdump(): - if not 'sqlite_sequence' in line: # work around bug in iterdump - tmpfile.write('%s\n' % line) - tmpfile.seek(0) - new_db._db_handle = dbapi2.connect(':memory:') - new_db._db_handle.cursor().executescript(tmpfile.read()) - new_db._db_handle.commit() - new_db._set_replica_uid(db._replica_uid) - new_db._factory = db._factory - return new_db - - -def make_document_for_test(test, doc_id, rev, content, has_conflicts=False): - return Document(doc_id, rev, content, has_conflicts=has_conflicts) - - -def make_c_database_for_test(test, replica_uid): - if c_backend_wrapper is None: - test.skipTest('c_backend_wrapper is not available') - db = c_backend_wrapper.CDatabase(':memory:') - db._set_replica_uid(replica_uid) - return db - - -def copy_c_database_for_test(test, db): - # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS - # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE - # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN - # CORRUPT USER DATA. 
USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR - # HOUSE. - if c_backend_wrapper is None: - test.skipTest('c_backend_wrapper is not available') - new_db = db._copy(db) - return new_db - - -def make_c_document_for_test(test, doc_id, rev, content, has_conflicts=False): - if c_backend_wrapper is None: - test.skipTest('c_backend_wrapper is not available') - return c_backend_wrapper.make_document( - doc_id, rev, content, has_conflicts=has_conflicts) - - -LOCAL_DATABASES_SCENARIOS = [ - ('mem', {'make_database_for_test': make_memory_database_for_test, - 'copy_database_for_test': copy_memory_database_for_test, - 'make_document_for_test': make_document_for_test}), - ('sql', {'make_database_for_test': - make_sqlite_partial_expanded_for_test, - 'copy_database_for_test': - copy_sqlite_partial_expanded_for_test, - 'make_document_for_test': make_document_for_test}), - ] - - -C_DATABASE_SCENARIOS = [ - ('c', {'make_database_for_test': make_c_database_for_test, - 'copy_database_for_test': copy_c_database_for_test, - 'make_document_for_test': make_c_document_for_test})] - - -class DatabaseBaseTests(TestCase): - - accept_fixed_trans_id = False # set to True assertTransactionLog - # is happy with all trans ids = '' - - scenarios = LOCAL_DATABASES_SCENARIOS - - def create_database(self, replica_uid): - return self.make_database_for_test(self, replica_uid) - - def copy_database(self, db): - # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES - # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST - # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS - # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND - # NINJA TO YOUR HOUSE. 
- return self.copy_database_for_test(self, db) - - def setUp(self): - super(DatabaseBaseTests, self).setUp() - self.db = self.create_database('test') - - def tearDown(self): - # TODO: Add close_database parameterization - # self.close_database(self.db) - super(DatabaseBaseTests, self).tearDown() - - def assertTransactionLog(self, doc_ids, db): - """Assert that the given docs are in the transaction log.""" - log = db._get_transaction_log() - just_ids = [] - seen_transactions = set() - for doc_id, transaction_id in log: - just_ids.append(doc_id) - self.assertIsNot(None, transaction_id, - "Transaction id should not be None") - if transaction_id == '' and self.accept_fixed_trans_id: - continue - self.assertNotEqual('', transaction_id, - "Transaction id should be a unique string") - self.assertTrue(transaction_id.startswith('T-')) - self.assertNotIn(transaction_id, seen_transactions) - seen_transactions.add(transaction_id) - self.assertEqual(doc_ids, just_ids) - - def getLastTransId(self, db): - """Return the transaction id for the last database update.""" - return self.db._get_transaction_log()[-1][-1] - - -class ServerStateForTests(server_state.ServerState): - """Used in the test suite, so we don't have to touch disk, etc.""" - - def __init__(self): - super(ServerStateForTests, self).__init__() - self._dbs = {} - - def open_database(self, path): - try: - return self._dbs[path] - except KeyError: - raise errors.DatabaseDoesNotExist - - def check_database(self, path): - # cares only about the possible exception - self.open_database(path) - - def ensure_database(self, path): - try: - db = self.open_database(path) - except errors.DatabaseDoesNotExist: - db = self._create_database(path) - return db, db._replica_uid - - def _copy_database(self, db): - # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES - # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST - # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS - # RATHER 
THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND - # NINJA TO YOUR HOUSE. - new_db = copy_memory_database_for_test(None, db) - path = db._replica_uid - while path in self._dbs: - path += 'copy' - self._dbs[path] = new_db - return new_db - - def _create_database(self, path): - db = inmemory.InMemoryDatabase(path) - self._dbs[path] = db - return db - - def delete_database(self, path): - del self._dbs[path] - - -class ResponderForTests(object): - """Responder for tests.""" - _started = False - sent_response = False - status = None - - def start_response(self, status='success', **kwargs): - self._started = True - self.status = status - self.kwargs = kwargs - - def send_response(self, status='success', **kwargs): - self.start_response(status, **kwargs) - self.finish_response() - - def finish_response(self): - self.sent_response = True - - -class TestCaseWithServer(TestCase): - - @staticmethod - def server_def(): - # hook point - # should return (ServerClass, "shutdown method name", "url_scheme") - class _RequestHandler(simple_server.WSGIRequestHandler): - def log_request(*args): - pass # suppress - - def make_server(host_port, application): - assert application, "forgot to override make_app(_with_state)?" 
- srv = simple_server.WSGIServer(host_port, _RequestHandler) - # patch the value in if it's None - if getattr(application, 'base_url', 1) is None: - application.base_url = "http://%s:%s" % srv.server_address - srv.set_app(application) - return srv - - return make_server, "shutdown", "http" - - @staticmethod - def make_app_with_state(state): - # hook point - return None - - def make_app(self): - # potential hook point - self.request_state = ServerStateForTests() - return self.make_app_with_state(self.request_state) - - def setUp(self): - super(TestCaseWithServer, self).setUp() - self.server = self.server_thread = None - - @property - def url_scheme(self): - return self.server_def()[-1] - - def startServer(self): - server_def = self.server_def() - server_class, shutdown_meth, _ = server_def - application = self.make_app() - self.server = server_class(('127.0.0.1', 0), application) - self.server_thread = threading.Thread(target=self.server.serve_forever, - kwargs=dict(poll_interval=0.01)) - self.server_thread.start() - self.addCleanup(self.server_thread.join) - self.addCleanup(getattr(self.server, shutdown_meth)) - - def getURL(self, path=None): - host, port = self.server.server_address - if path is None: - path = '' - return '%s://%s:%s/%s' % (self.url_scheme, host, port, path) - - -def socket_pair(): - """Return a pair of TCP sockets connected to each other. - - Unlike socket.socketpair, this should work on Windows. 
- """ - sock_pair = getattr(socket, 'socket_pair', None) - if sock_pair: - return sock_pair(socket.AF_INET, socket.SOCK_STREAM) - listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - listen_sock.bind(('127.0.0.1', 0)) - listen_sock.listen(1) - client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - client_sock.connect(listen_sock.getsockname()) - server_sock, addr = listen_sock.accept() - listen_sock.close() - return server_sock, client_sock - - -# OAuth related testing - -consumer1 = oauth.OAuthConsumer('K1', 'S1') -token1 = oauth.OAuthToken('kkkk1', 'XYZ') -consumer2 = oauth.OAuthConsumer('K2', 'S2') -token2 = oauth.OAuthToken('kkkk2', 'ZYX') -token3 = oauth.OAuthToken('kkkk3', 'ZYX') - - -class TestingOAuthDataStore(oauth.OAuthDataStore): - """In memory predefined OAuthDataStore for testing.""" - - consumers = { - consumer1.key: consumer1, - consumer2.key: consumer2, - } - - tokens = { - token1.key: token1, - token2.key: token2 - } - - def lookup_consumer(self, key): - return self.consumers.get(key) - - def lookup_token(self, token_type, token_token): - return self.tokens.get(token_token) - - def lookup_nonce(self, oauth_consumer, oauth_token, nonce): - return None - -testingOAuthStore = TestingOAuthDataStore() - -sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1() -sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT() - - -def load_with_scenarios(loader, standard_tests, pattern): - """Load the tests in a given module. - - This just applies testscenarios.generate_scenarios to all the tests that - are present. We do it at load time rather than at run time, because it - plays nicer with various tools. 
- """ - suite = loader.suiteClass() - suite.addTests(testscenarios.generate_scenarios(standard_tests)) - return suite diff --git a/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx b/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx deleted file mode 100644 index 8a4b600d..00000000 --- a/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx +++ /dev/null @@ -1,1541 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . -# -"""A Cython wrapper around the C implementation of U1DB Database backend.""" - -cdef extern from "Python.h": - object PyString_FromStringAndSize(char *s, Py_ssize_t n) - int PyString_AsStringAndSize(object o, char **buf, Py_ssize_t *length - ) except -1 - char *PyString_AsString(object) except NULL - char *PyString_AS_STRING(object) - char *strdup(char *) - void *calloc(size_t, size_t) - void free(void *) - ctypedef struct FILE: - pass - fprintf(FILE *, char *, ...) 
- FILE *stderr - size_t strlen(char *) - -cdef extern from "stdarg.h": - ctypedef struct va_list: - pass - void va_start(va_list, void*) - void va_start_int "va_start" (va_list, int) - void va_end(va_list) - -cdef extern from "u1db/u1db.h": - ctypedef struct u1database: - pass - ctypedef struct u1db_document: - char *doc_id - size_t doc_id_len - char *doc_rev - size_t doc_rev_len - char *json - size_t json_len - int has_conflicts - # Note: u1query is actually defined in u1db_internal.h, and in u1db.h it is - # just an opaque pointer. However, older versions of Cython don't let - # you have a forward declaration and a full declaration, so we just - # expose the whole thing here. - ctypedef struct u1query: - char *index_name - int num_fields - char **fields - cdef struct u1db_oauth_creds: - int auth_kind - char *consumer_key - char *consumer_secret - char *token_key - char *token_secret - ctypedef union u1db_creds - ctypedef u1db_creds* const_u1db_creds_ptr "const u1db_creds *" - - ctypedef char* const_char_ptr "const char*" - ctypedef int (*u1db_doc_callback)(void *context, u1db_document *doc) - ctypedef int (*u1db_key_callback)(void *context, int num_fields, - const_char_ptr *key) - ctypedef int (*u1db_doc_gen_callback)(void *context, - u1db_document *doc, int gen, const_char_ptr trans_id) - ctypedef int (*u1db_trans_info_callback)(void *context, - const_char_ptr doc_id, int gen, const_char_ptr trans_id) - - u1database * u1db_open(char *fname) - void u1db_free(u1database **) - int u1db_set_replica_uid(u1database *, char *replica_uid) - int u1db_set_document_size_limit(u1database *, int limit) - int u1db_get_replica_uid(u1database *, const_char_ptr *replica_uid) - int u1db_create_doc_from_json(u1database *db, char *json, char *doc_id, - u1db_document **doc) - int u1db_delete_doc(u1database *db, u1db_document *doc) - int u1db_get_doc(u1database *db, char *doc_id, int include_deleted, - u1db_document **doc) - int u1db_get_docs(u1database *db, int n_doc_ids, 
const_char_ptr *doc_ids, - int check_for_conflicts, int include_deleted, - void *context, u1db_doc_callback cb) - int u1db_get_all_docs(u1database *db, int include_deleted, int *generation, - void *context, u1db_doc_callback cb) - int u1db_put_doc(u1database *db, u1db_document *doc) - int u1db__validate_source(u1database *db, const_char_ptr replica_uid, - int replica_gen, const_char_ptr replica_trans_id) - int u1db__put_doc_if_newer(u1database *db, u1db_document *doc, - int save_conflict, char *replica_uid, - int replica_gen, char *replica_trans_id, - int *state, int *at_gen) - int u1db_resolve_doc(u1database *db, u1db_document *doc, - int n_revs, const_char_ptr *revs) - int u1db_delete_doc(u1database *db, u1db_document *doc) - int u1db_whats_changed(u1database *db, int *gen, char **trans_id, - void *context, u1db_trans_info_callback cb) - int u1db__get_transaction_log(u1database *db, void *context, - u1db_trans_info_callback cb) - int u1db_get_doc_conflicts(u1database *db, char *doc_id, void *context, - u1db_doc_callback cb) - int u1db_sync(u1database *db, const_char_ptr url, - const_u1db_creds_ptr creds, int *local_gen) nogil - int u1db_create_index_list(u1database *db, char *index_name, - int n_expressions, const_char_ptr *expressions) - int u1db_create_index(u1database *db, char *index_name, int n_expressions, - ...) - int u1db_get_from_index_list(u1database *db, u1query *query, void *context, - u1db_doc_callback cb, int n_values, - const_char_ptr *values) - int u1db_get_from_index(u1database *db, u1query *query, void *context, - u1db_doc_callback cb, int n_values, char *val0, - ...) 
- int u1db_get_range_from_index(u1database *db, u1query *query, - void *context, u1db_doc_callback cb, - int n_values, const_char_ptr *start_values, - const_char_ptr *end_values) - int u1db_delete_index(u1database *db, char *index_name) - int u1db_list_indexes(u1database *db, void *context, - int (*cb)(void *context, const_char_ptr index_name, - int n_expressions, const_char_ptr *expressions)) - int u1db_get_index_keys(u1database *db, char *index_name, void *context, - u1db_key_callback cb) - int u1db_simple_lookup1(u1database *db, char *index_name, char *val1, - void *context, u1db_doc_callback cb) - int u1db_query_init(u1database *db, char *index_name, u1query **query) - void u1db_free_query(u1query **query) - - int U1DB_OK - int U1DB_INVALID_PARAMETER - int U1DB_REVISION_CONFLICT - int U1DB_INVALID_DOC_ID - int U1DB_DOCUMENT_ALREADY_DELETED - int U1DB_DOCUMENT_DOES_NOT_EXIST - int U1DB_NOT_IMPLEMENTED - int U1DB_INVALID_JSON - int U1DB_DOCUMENT_TOO_BIG - int U1DB_USER_QUOTA_EXCEEDED - int U1DB_INVALID_VALUE_FOR_INDEX - int U1DB_INVALID_FIELD_SPECIFIER - int U1DB_INVALID_GLOBBING - int U1DB_BROKEN_SYNC_STREAM - int U1DB_DUPLICATE_INDEX_NAME - int U1DB_INDEX_DOES_NOT_EXIST - int U1DB_INVALID_GENERATION - int U1DB_INVALID_TRANSACTION_ID - int U1DB_INVALID_TRANSFORMATION_FUNCTION - int U1DB_UNKNOWN_OPERATION - int U1DB_INTERNAL_ERROR - int U1DB_TARGET_UNAVAILABLE - - int U1DB_INSERTED - int U1DB_SUPERSEDED - int U1DB_CONVERGED - int U1DB_CONFLICTED - - int U1DB_OAUTH_AUTH - - void u1db_free_doc(u1db_document **doc) - int u1db_doc_set_json(u1db_document *doc, char *json) - int u1db_doc_get_size(u1db_document *doc) - - -cdef extern from "u1db/u1db_internal.h": - ctypedef struct u1db_row: - u1db_row *next - int num_columns - int *column_sizes - unsigned char **columns - - ctypedef struct u1db_table: - int status - u1db_row *first_row - - ctypedef struct u1db_record: - u1db_record *next - char *doc_id - char *doc_rev - char *doc - - ctypedef struct u1db_sync_exchange: - 
int target_gen - int num_doc_ids - char **doc_ids_to_return - int *gen_for_doc_ids - const_char_ptr *trans_ids_for_doc_ids - - ctypedef int (*u1db__trace_callback)(void *context, const_char_ptr state) - ctypedef struct u1db_sync_target: - int (*get_sync_info)(u1db_sync_target *st, char *source_replica_uid, - const_char_ptr *st_replica_uid, int *st_gen, - char **st_trans_id, int *source_gen, - char **source_trans_id) nogil - int (*record_sync_info)(u1db_sync_target *st, - char *source_replica_uid, int source_gen, char *trans_id) nogil - int (*sync_exchange)(u1db_sync_target *st, - char *source_replica_uid, int n_docs, - u1db_document **docs, int *generations, - const_char_ptr *trans_ids, - int *target_gen, char **target_trans_id, - void *context, u1db_doc_gen_callback cb, - void *ensure_callback) nogil - int (*sync_exchange_doc_ids)(u1db_sync_target *st, - u1database *source_db, int n_doc_ids, - const_char_ptr *doc_ids, int *generations, - const_char_ptr *trans_ids, - int *target_gen, char **target_trans_id, - void *context, - u1db_doc_gen_callback cb, - void *ensure_callback) nogil - int (*get_sync_exchange)(u1db_sync_target *st, - char *source_replica_uid, - int last_known_source_gen, - u1db_sync_exchange **exchange) nogil - void (*finalize_sync_exchange)(u1db_sync_target *st, - u1db_sync_exchange **exchange) nogil - int (*_set_trace_hook)(u1db_sync_target *st, - void *context, u1db__trace_callback cb) nogil - - - void u1db__set_zero_delays() - int u1db__get_generation(u1database *, int *db_rev) - int u1db__get_document_size_limit(u1database *, int *limit) - int u1db__get_generation_info(u1database *, int *db_rev, char **trans_id) - int u1db__get_trans_id_for_gen(u1database *, int db_rev, char **trans_id) - int u1db_validate_gen_and_trans_id(u1database *, int db_rev, - const_char_ptr trans_id) - char *u1db__allocate_doc_id(u1database *) - int u1db__sql_close(u1database *) - u1database *u1db__copy(u1database *) - int u1db__sql_is_open(u1database *) - u1db_table 
*u1db__sql_run(u1database *, char *sql, size_t n) - void u1db__free_table(u1db_table **table) - u1db_record *u1db__create_record(char *doc_id, char *doc_rev, char *doc) - void u1db__free_records(u1db_record **) - - int u1db__allocate_document(char *doc_id, char *revision, char *content, - int has_conflicts, u1db_document **result) - int u1db__generate_hex_uuid(char *) - - int u1db__get_replica_gen_and_trans_id(u1database *db, char *replica_uid, - int *generation, char **trans_id) - int u1db__set_replica_gen_and_trans_id(u1database *db, char *replica_uid, - int generation, char *trans_id) - int u1db__sync_get_machine_info(u1database *db, char *other_replica_uid, - int *other_db_rev, char **my_replica_uid, - int *my_db_rev) - int u1db__sync_record_machine_info(u1database *db, char *replica_uid, - int db_rev) - int u1db__sync_exchange_seen_ids(u1db_sync_exchange *se, int *n_ids, - const_char_ptr **doc_ids) - int u1db__format_query(int n_fields, const_char_ptr *values, char **buf, - int *wildcard) - int u1db__get_sync_target(u1database *db, u1db_sync_target **sync_target) - int u1db__free_sync_target(u1db_sync_target **sync_target) - int u1db__sync_db_to_target(u1database *db, u1db_sync_target *target, - int *local_gen_before_sync) nogil - - int u1db__sync_exchange_insert_doc_from_source(u1db_sync_exchange *se, - u1db_document *doc, int source_gen, const_char_ptr trans_id) - int u1db__sync_exchange_find_doc_ids_to_return(u1db_sync_exchange *se) - int u1db__sync_exchange_return_docs(u1db_sync_exchange *se, void *context, - int (*cb)(void *context, - u1db_document *doc, int gen, - const_char_ptr trans_id)) - int u1db__create_http_sync_target(char *url, u1db_sync_target **target) - int u1db__create_oauth_http_sync_target(char *url, - char *consumer_key, char *consumer_secret, - char *token_key, char *token_secret, - u1db_sync_target **target) - -cdef extern from "u1db/u1db_http_internal.h": - int u1db__format_sync_url(u1db_sync_target *st, - const_char_ptr 
source_replica_uid, char **sync_url) - int u1db__get_oauth_authorization(u1db_sync_target *st, - char *http_method, char *url, - char **oauth_authorization) - - -cdef extern from "u1db/u1db_vectorclock.h": - ctypedef struct u1db_vectorclock_item: - char *replica_uid - int generation - - ctypedef struct u1db_vectorclock: - int num_items - u1db_vectorclock_item *items - - u1db_vectorclock *u1db__vectorclock_from_str(char *s) - void u1db__free_vectorclock(u1db_vectorclock **clock) - int u1db__vectorclock_increment(u1db_vectorclock *clock, char *replica_uid) - int u1db__vectorclock_maximize(u1db_vectorclock *clock, - u1db_vectorclock *other) - int u1db__vectorclock_as_str(u1db_vectorclock *clock, char **result) - int u1db__vectorclock_is_newer(u1db_vectorclock *maybe_newer, - u1db_vectorclock *older) - -from u1db import errors -from sqlite3 import dbapi2 - - -cdef int _append_trans_info_to_list(void *context, const_char_ptr doc_id, - int generation, - const_char_ptr trans_id) with gil: - a_list = (context) - doc = doc_id - a_list.append((doc, generation, trans_id)) - return 0 - - -cdef int _append_doc_to_list(void *context, u1db_document *doc) with gil: - a_list = context - pydoc = CDocument() - pydoc._doc = doc - a_list.append(pydoc) - return 0 - -cdef int _append_key_to_list(void *context, int num_fields, - const_char_ptr *key) with gil: - a_list = (context) - field_list = [] - for i from 0 <= i < num_fields: - field = key[i] - field_list.append(field.decode('utf-8')) - a_list.append(tuple(field_list)) - return 0 - -cdef _list_to_array(lst, const_char_ptr **res, int *count): - cdef const_char_ptr *tmp - count[0] = len(lst) - tmp = calloc(sizeof(char*), count[0]) - for idx, x in enumerate(lst): - tmp[idx] = x - res[0] = tmp - -cdef _list_to_str_array(lst, const_char_ptr **res, int *count): - cdef const_char_ptr *tmp - count[0] = len(lst) - tmp = calloc(sizeof(char*), count[0]) - new_objs = [] - for idx, x in enumerate(lst): - if isinstance(x, unicode): - x = 
x.encode('utf-8') - new_objs.append(x) - tmp[idx] = x - res[0] = tmp - return new_objs - - -cdef int _append_index_definition_to_list(void *context, - const_char_ptr index_name, int n_expressions, - const_char_ptr *expressions) with gil: - cdef int i - - a_list = (context) - exp_list = [] - for i from 0 <= i < n_expressions: - s = expressions[i] - exp_list.append(s.decode('utf-8')) - a_list.append((index_name, exp_list)) - return 0 - - -cdef int return_doc_cb_wrapper(void *context, u1db_document *doc, - int gen, const_char_ptr trans_id) with gil: - cdef CDocument pydoc - user_cb = context - pydoc = CDocument() - pydoc._doc = doc - try: - user_cb(pydoc, gen, trans_id) - except Exception, e: - # We suppress the exception here, because intermediating through the C - # layer gets a bit crazy - return U1DB_INVALID_PARAMETER - return U1DB_OK - - -cdef int _trace_hook(void *context, const_char_ptr state) with gil: - if context == NULL: - return U1DB_INVALID_PARAMETER - ctx = context - try: - ctx(state) - except: - # Note: It would be nice if we could map the Python exception into - # something in C - return U1DB_INTERNAL_ERROR - return U1DB_OK - - -cdef char *_ensure_str(object obj, object extra_objs) except NULL: - """Ensure that we have the UTF-8 representation of a parameter. - - :param obj: A Unicode or String object. - :param extra_objs: This should be a Python list. If we have to convert obj - from being a Unicode object, this will hold the PyString object so that - we know the char* lifetime will be correct. - :return: A C pointer to the UTF-8 representation. - """ - if isinstance(obj, unicode): - obj = obj.encode('utf-8') - extra_objs.append(obj) - return PyString_AsString(obj) - - -def _format_query(fields): - """Wrapper around u1db__format_query for testing.""" - cdef int status - cdef char *buf - cdef int wildcard[10] - cdef const_char_ptr *values - cdef int n_values - - # keep a reference to new_objs so that the pointers in expressions - # remain valid. 
- new_objs = _list_to_str_array(fields, &values, &n_values) - try: - status = u1db__format_query(n_values, values, &buf, wildcard) - finally: - free(values) - handle_status("format_query", status) - if buf == NULL: - res = None - else: - res = buf - free(buf) - w = [] - for i in range(len(fields)): - w.append(wildcard[i]) - return res, w - - -def make_document(doc_id, rev, content, has_conflicts=False): - cdef u1db_document *doc - cdef char *c_content = NULL, *c_rev = NULL, *c_doc_id = NULL - cdef int conflict - - if has_conflicts: - conflict = 1 - else: - conflict = 0 - if doc_id is None: - c_doc_id = NULL - else: - c_doc_id = doc_id - if content is None: - c_content = NULL - else: - c_content = content - if rev is None: - c_rev = NULL - else: - c_rev = rev - handle_status( - "make_document", - u1db__allocate_document(c_doc_id, c_rev, c_content, conflict, &doc)) - pydoc = CDocument() - pydoc._doc = doc - return pydoc - - -def generate_hex_uuid(): - uuid = PyString_FromStringAndSize(NULL, 32) - handle_status( - "Failed to generate uuid", - u1db__generate_hex_uuid(PyString_AS_STRING(uuid))) - return uuid - - -cdef class CDocument(object): - """A thin wrapper around the C Document struct.""" - - cdef u1db_document *_doc - - def __init__(self): - self._doc = NULL - - def __dealloc__(self): - u1db_free_doc(&self._doc) - - property doc_id: - def __get__(self): - if self._doc.doc_id == NULL: - return None - return PyString_FromStringAndSize( - self._doc.doc_id, self._doc.doc_id_len) - - property rev: - def __get__(self): - if self._doc.doc_rev == NULL: - return None - return PyString_FromStringAndSize( - self._doc.doc_rev, self._doc.doc_rev_len) - - def get_json(self): - if self._doc.json == NULL: - return None - return PyString_FromStringAndSize( - self._doc.json, self._doc.json_len) - - def set_json(self, val): - u1db_doc_set_json(self._doc, val) - - def get_size(self): - return u1db_doc_get_size(self._doc) - - property has_conflicts: - def __get__(self): - if 
self._doc.has_conflicts: - return True - return False - - def __repr__(self): - if self._doc.has_conflicts: - extra = ', conflicted' - else: - extra = '' - return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id, - self.rev, extra, self.get_json()) - - def __hash__(self): - raise NotImplementedError(self.__hash__) - - def __richcmp__(self, other, int t): - try: - if t == 0: # Py_LT < - return ((self.doc_id, self.rev, self.get_json()) - < (other.doc_id, other.rev, other.get_json())) - elif t == 2: # Py_EQ == - return (self.doc_id == other.doc_id - and self.rev == other.rev - and self.get_json() == other.get_json() - and self.has_conflicts == other.has_conflicts) - except AttributeError: - # Fall through to NotImplemented - pass - - return NotImplemented - - -cdef object safe_str(const_char_ptr s): - if s == NULL: - return None - return s - - -cdef class CQuery: - - cdef u1query *_query - - def __init__(self): - self._query = NULL - - def __dealloc__(self): - u1db_free_query(&self._query) - - def _check(self): - if self._query == NULL: - raise RuntimeError("No valid _query.") - - property index_name: - def __get__(self): - self._check() - return safe_str(self._query.index_name) - - property num_fields: - def __get__(self): - self._check() - return self._query.num_fields - - property fields: - def __get__(self): - cdef int i - self._check() - fields = [] - for i from 0 <= i < self._query.num_fields: - fields.append(safe_str(self._query.fields[i])) - return fields - - -cdef handle_status(context, int status): - if status == U1DB_OK: - return - if status == U1DB_REVISION_CONFLICT: - raise errors.RevisionConflict() - if status == U1DB_INVALID_DOC_ID: - raise errors.InvalidDocId() - if status == U1DB_DOCUMENT_ALREADY_DELETED: - raise errors.DocumentAlreadyDeleted() - if status == U1DB_DOCUMENT_DOES_NOT_EXIST: - raise errors.DocumentDoesNotExist() - if status == U1DB_INVALID_PARAMETER: - raise RuntimeError('Bad parameters supplied') - if status == 
U1DB_NOT_IMPLEMENTED: - raise NotImplementedError("Functionality not implemented yet: %s" - % (context,)) - if status == U1DB_INVALID_VALUE_FOR_INDEX: - raise errors.InvalidValueForIndex() - if status == U1DB_INVALID_GLOBBING: - raise errors.InvalidGlobbing() - if status == U1DB_INTERNAL_ERROR: - raise errors.U1DBError("internal error") - if status == U1DB_BROKEN_SYNC_STREAM: - raise errors.BrokenSyncStream() - if status == U1DB_CONFLICTED: - raise errors.ConflictedDoc() - if status == U1DB_DUPLICATE_INDEX_NAME: - raise errors.IndexNameTakenError() - if status == U1DB_INDEX_DOES_NOT_EXIST: - raise errors.IndexDoesNotExist - if status == U1DB_INVALID_GENERATION: - raise errors.InvalidGeneration - if status == U1DB_INVALID_TRANSACTION_ID: - raise errors.InvalidTransactionId - if status == U1DB_TARGET_UNAVAILABLE: - raise errors.Unavailable - if status == U1DB_INVALID_JSON: - raise errors.InvalidJSON - if status == U1DB_DOCUMENT_TOO_BIG: - raise errors.DocumentTooBig - if status == U1DB_USER_QUOTA_EXCEEDED: - raise errors.UserQuotaExceeded - if status == U1DB_INVALID_TRANSFORMATION_FUNCTION: - raise errors.IndexDefinitionParseError - if status == U1DB_UNKNOWN_OPERATION: - raise errors.IndexDefinitionParseError - if status == U1DB_INVALID_FIELD_SPECIFIER: - raise errors.IndexDefinitionParseError() - raise RuntimeError('%s (status: %s)' % (context, status)) - - -cdef class CDatabase -cdef class CSyncTarget - -cdef class CSyncExchange(object): - - cdef u1db_sync_exchange *_exchange - cdef CSyncTarget _target - - def __init__(self, CSyncTarget target, source_replica_uid, source_gen): - self._target = target - assert self._target._st.get_sync_exchange != NULL, \ - "get_sync_exchange is NULL?" 
- handle_status("get_sync_exchange", - self._target._st.get_sync_exchange(self._target._st, - source_replica_uid, source_gen, &self._exchange)) - - def __dealloc__(self): - if self._target is not None and self._target._st != NULL: - self._target._st.finalize_sync_exchange(self._target._st, - &self._exchange) - - def _check(self): - if self._exchange == NULL: - raise RuntimeError("self._exchange is NULL") - - property target_gen: - def __get__(self): - self._check() - return self._exchange.target_gen - - def insert_doc_from_source(self, CDocument doc, source_gen, - source_trans_id): - self._check() - handle_status("insert_doc_from_source", - u1db__sync_exchange_insert_doc_from_source(self._exchange, - doc._doc, source_gen, source_trans_id)) - - def find_doc_ids_to_return(self): - self._check() - handle_status("find_doc_ids_to_return", - u1db__sync_exchange_find_doc_ids_to_return(self._exchange)) - - def return_docs(self, return_doc_cb): - self._check() - handle_status("return_docs", - u1db__sync_exchange_return_docs(self._exchange, - return_doc_cb, &return_doc_cb_wrapper)) - - def get_seen_ids(self): - cdef const_char_ptr *seen_ids - cdef int i, n_ids - self._check() - handle_status("sync_exchange_seen_ids", - u1db__sync_exchange_seen_ids(self._exchange, &n_ids, &seen_ids)) - res = [] - for i from 0 <= i < n_ids: - res.append(seen_ids[i]) - if (seen_ids != NULL): - free(seen_ids) - return res - - def get_doc_ids_to_return(self): - self._check() - res = [] - if (self._exchange.num_doc_ids > 0 - and self._exchange.doc_ids_to_return != NULL): - for i from 0 <= i < self._exchange.num_doc_ids: - res.append( - (self._exchange.doc_ids_to_return[i], - self._exchange.gen_for_doc_ids[i], - self._exchange.trans_ids_for_doc_ids[i])) - return res - - -cdef class CSyncTarget(object): - - cdef u1db_sync_target *_st - cdef CDatabase _db - - def __init__(self): - self._db = None - self._st = NULL - u1db__set_zero_delays() - - def __dealloc__(self): - 
u1db__free_sync_target(&self._st) - - def _check(self): - if self._st == NULL: - raise RuntimeError("self._st is NULL") - - def get_sync_info(self, source_replica_uid): - cdef const_char_ptr st_replica_uid = NULL - cdef int st_gen = 0, source_gen = 0, status - cdef char *trans_id = NULL - cdef char *st_trans_id = NULL - cdef char *c_source_replica_uid = NULL - - self._check() - assert self._st.get_sync_info != NULL, "get_sync_info is NULL?" - c_source_replica_uid = source_replica_uid - with nogil: - status = self._st.get_sync_info(self._st, c_source_replica_uid, - &st_replica_uid, &st_gen, &st_trans_id, &source_gen, &trans_id) - handle_status("get_sync_info", status) - res_trans_id = None - res_st_trans_id = None - if trans_id != NULL: - res_trans_id = trans_id - free(trans_id) - if st_trans_id != NULL: - res_st_trans_id = st_trans_id - free(st_trans_id) - return ( - safe_str(st_replica_uid), st_gen, res_st_trans_id, source_gen, - res_trans_id) - - def record_sync_info(self, source_replica_uid, source_gen, source_trans_id): - cdef int status - cdef int c_source_gen - cdef char *c_source_replica_uid = NULL - cdef char *c_source_trans_id = NULL - - self._check() - assert self._st.record_sync_info != NULL, "record_sync_info is NULL?" 
- c_source_replica_uid = source_replica_uid - c_source_gen = source_gen - c_source_trans_id = source_trans_id - with nogil: - status = self._st.record_sync_info( - self._st, c_source_replica_uid, c_source_gen, - c_source_trans_id) - handle_status("record_sync_info", status) - - def _get_sync_exchange(self, source_replica_uid, source_gen): - self._check() - return CSyncExchange(self, source_replica_uid, source_gen) - - def sync_exchange_doc_ids(self, source_db, doc_id_generations, - last_known_generation, last_known_trans_id, - return_doc_cb): - cdef const_char_ptr *doc_ids - cdef int *generations - cdef int num_doc_ids - cdef int target_gen - cdef char *target_trans_id = NULL - cdef int status - cdef CDatabase sdb - - self._check() - assert self._st.sync_exchange_doc_ids != NULL, "sync_exchange_doc_ids is NULL?" - sdb = source_db - num_doc_ids = len(doc_id_generations) - doc_ids = calloc(num_doc_ids, sizeof(char *)) - if doc_ids == NULL: - raise MemoryError - generations = calloc(num_doc_ids, sizeof(int)) - if generations == NULL: - free(doc_ids) - raise MemoryError - trans_ids = calloc(num_doc_ids, sizeof(char *)) - if trans_ids == NULL: - raise MemoryError - res_trans_id = '' - try: - for i, (doc_id, gen, trans_id) in enumerate(doc_id_generations): - doc_ids[i] = PyString_AsString(doc_id) - generations[i] = gen - trans_ids[i] = trans_id - target_gen = last_known_generation - if last_known_trans_id is not None: - target_trans_id = last_known_trans_id - with nogil: - status = self._st.sync_exchange_doc_ids(self._st, sdb._db, - num_doc_ids, doc_ids, generations, trans_ids, - &target_gen, &target_trans_id, - return_doc_cb, return_doc_cb_wrapper, NULL) - handle_status("sync_exchange_doc_ids", status) - if target_trans_id != NULL: - res_trans_id = target_trans_id - finally: - if target_trans_id != NULL: - free(target_trans_id) - if doc_ids != NULL: - free(doc_ids) - if generations != NULL: - free(generations) - if trans_ids != NULL: - free(trans_ids) - return 
target_gen, res_trans_id - - def sync_exchange(self, docs_by_generations, source_replica_uid, - last_known_generation, last_known_trans_id, - return_doc_cb, ensure_callback=None): - cdef CDocument cur_doc - cdef u1db_document **docs = NULL - cdef int *generations = NULL - cdef const_char_ptr *trans_ids = NULL - cdef char *target_trans_id = NULL - cdef char *c_source_replica_uid = NULL - cdef int i, count, status, target_gen - assert ensure_callback is None # interface difference - - self._check() - assert self._st.sync_exchange != NULL, "sync_exchange is NULL?" - count = len(docs_by_generations) - res_trans_id = '' - try: - docs = calloc(count, sizeof(u1db_document*)) - if docs == NULL: - raise MemoryError - generations = calloc(count, sizeof(int)) - if generations == NULL: - raise MemoryError - trans_ids = calloc(count, sizeof(char*)) - if trans_ids == NULL: - raise MemoryError - for i from 0 <= i < count: - cur_doc = docs_by_generations[i][0] - generations[i] = docs_by_generations[i][1] - trans_ids[i] = docs_by_generations[i][2] - docs[i] = cur_doc._doc - target_gen = last_known_generation - if last_known_trans_id is not None: - target_trans_id = last_known_trans_id - c_source_replica_uid = source_replica_uid - with nogil: - status = self._st.sync_exchange( - self._st, c_source_replica_uid, count, docs, generations, - trans_ids, &target_gen, &target_trans_id, - return_doc_cb, return_doc_cb_wrapper, NULL) - handle_status("sync_exchange", status) - finally: - if docs != NULL: - free(docs) - if generations != NULL: - free(generations) - if trans_ids != NULL: - free(trans_ids) - if target_trans_id != NULL: - res_trans_id = target_trans_id - free(target_trans_id) - return target_gen, res_trans_id - - def _set_trace_hook(self, cb): - self._check() - assert self._st._set_trace_hook != NULL, "_set_trace_hook is NULL?" 
- handle_status("_set_trace_hook", - self._st._set_trace_hook(self._st, cb, _trace_hook)) - - _set_trace_hook_shallow = _set_trace_hook - - -cdef class CDatabase(object): - """A thin wrapper/shim to interact with the C implementation. - - Functionality should not be written here. It is only provided as a way to - expose the C API to the python test suite. - """ - - cdef public object _filename - cdef u1database *_db - cdef public object _supports_indexes - - def __init__(self, filename): - self._supports_indexes = False - self._filename = filename - self._db = u1db_open(self._filename) - - def __dealloc__(self): - u1db_free(&self._db) - - def close(self): - return u1db__sql_close(self._db) - - def _copy(self, db): - # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS - # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE - # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN - # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR - # HOUSE. 
- new_db = CDatabase(':memory:') - u1db_free(&new_db._db) - new_db._db = u1db__copy(self._db) - return new_db - - def _sql_is_open(self): - if self._db == NULL: - return True - return u1db__sql_is_open(self._db) - - property _replica_uid: - def __get__(self): - cdef const_char_ptr val - cdef int status - status = u1db_get_replica_uid(self._db, &val) - if status != 0: - if val != NULL: - err = str(val) - else: - err = "" - raise RuntimeError("Failed to get_replica_uid: %d %s" - % (status, err)) - if val == NULL: - return None - return str(val) - - def _set_replica_uid(self, replica_uid): - cdef int status - status = u1db_set_replica_uid(self._db, replica_uid) - if status != 0: - raise RuntimeError('replica_uid could not be set to %s, error: %d' - % (replica_uid, status)) - - property document_size_limit: - def __get__(self): - cdef int limit - handle_status("document_size_limit", - u1db__get_document_size_limit(self._db, &limit)) - return limit - - def set_document_size_limit(self, limit): - cdef int status - status = u1db_set_document_size_limit(self._db, limit) - if status != 0: - raise RuntimeError( - "document_size_limit could not be set to %d, error: %d", - (limit, status)) - - def _allocate_doc_id(self): - cdef char *val - val = u1db__allocate_doc_id(self._db) - if val == NULL: - raise RuntimeError("Failed to allocate document id") - s = str(val) - free(val) - return s - - def _run_sql(self, sql): - cdef u1db_table *tbl - cdef u1db_row *cur_row - cdef size_t n - cdef int i - - if self._db == NULL: - raise RuntimeError("called _run_sql with a NULL pointer.") - tbl = u1db__sql_run(self._db, sql, len(sql)) - if tbl == NULL: - raise MemoryError("Failed to allocate table memory.") - try: - if tbl.status != 0: - raise RuntimeError("Status was not 0: %d" % (tbl.status,)) - # Now convert the table into python - res = [] - cur_row = tbl.first_row - while cur_row != NULL: - row = [] - for i from 0 <= i < cur_row.num_columns: - row.append(PyString_FromStringAndSize( - 
(cur_row.columns[i]), cur_row.column_sizes[i])) - res.append(tuple(row)) - cur_row = cur_row.next - return res - finally: - u1db__free_table(&tbl) - - def create_doc_from_json(self, json, doc_id=None): - cdef u1db_document *doc = NULL - cdef char *c_doc_id - - if doc_id is None: - c_doc_id = NULL - else: - c_doc_id = doc_id - handle_status('Failed to create_doc', - u1db_create_doc_from_json(self._db, json, c_doc_id, &doc)) - pydoc = CDocument() - pydoc._doc = doc - return pydoc - - def put_doc(self, CDocument doc): - handle_status("Failed to put_doc", - u1db_put_doc(self._db, doc._doc)) - return doc.rev - - def _validate_source(self, replica_uid, replica_gen, replica_trans_id): - cdef const_char_ptr c_uid, c_trans_id - cdef int c_gen = 0 - - c_uid = replica_uid - c_trans_id = replica_trans_id - c_gen = replica_gen - handle_status( - "invalid generation or transaction id", - u1db__validate_source(self._db, c_uid, c_gen, c_trans_id)) - - def _put_doc_if_newer(self, CDocument doc, save_conflict, replica_uid=None, - replica_gen=None, replica_trans_id=None): - cdef char *c_uid, *c_trans_id - cdef int gen, state = 0, at_gen = -1 - - if replica_uid is None: - c_uid = NULL - else: - c_uid = replica_uid - if replica_trans_id is None: - c_trans_id = NULL - else: - c_trans_id = replica_trans_id - if replica_gen is None: - gen = 0 - else: - gen = replica_gen - handle_status("Failed to _put_doc_if_newer", - u1db__put_doc_if_newer(self._db, doc._doc, save_conflict, - c_uid, gen, c_trans_id, &state, &at_gen)) - if state == U1DB_INSERTED: - return 'inserted', at_gen - elif state == U1DB_SUPERSEDED: - return 'superseded', at_gen - elif state == U1DB_CONVERGED: - return 'converged', at_gen - elif state == U1DB_CONFLICTED: - return 'conflicted', at_gen - else: - raise RuntimeError("Unknown _put_doc_if_newer state: %d" % (state,)) - - def get_doc(self, doc_id, include_deleted=False): - cdef u1db_document *doc = NULL - deleted = 1 if include_deleted else 0 - handle_status("get_doc 
failed", - u1db_get_doc(self._db, doc_id, deleted, &doc)) - if doc == NULL: - return None - pydoc = CDocument() - pydoc._doc = doc - return pydoc - - def get_docs(self, doc_ids, check_for_conflicts=True, - include_deleted=False): - cdef int n_doc_ids, conflicts - cdef const_char_ptr *c_doc_ids - - _list_to_array(doc_ids, &c_doc_ids, &n_doc_ids) - deleted = 1 if include_deleted else 0 - conflicts = 1 if check_for_conflicts else 0 - a_list = [] - handle_status("get_docs", - u1db_get_docs(self._db, n_doc_ids, c_doc_ids, - conflicts, deleted, a_list, _append_doc_to_list)) - free(c_doc_ids) - return a_list - - def get_all_docs(self, include_deleted=False): - cdef int c_generation - - a_list = [] - deleted = 1 if include_deleted else 0 - generation = 0 - c_generation = generation - handle_status( - "get_all_docs", u1db_get_all_docs( - self._db, deleted, &c_generation, a_list, - _append_doc_to_list)) - return (c_generation, a_list) - - def resolve_doc(self, CDocument doc, conflicted_doc_revs): - cdef const_char_ptr *revs - cdef int n_revs - - _list_to_array(conflicted_doc_revs, &revs, &n_revs) - handle_status("resolve_doc", - u1db_resolve_doc(self._db, doc._doc, n_revs, revs)) - free(revs) - - def get_doc_conflicts(self, doc_id): - conflict_docs = [] - handle_status("get_doc_conflicts", - u1db_get_doc_conflicts(self._db, doc_id, conflict_docs, - _append_doc_to_list)) - return conflict_docs - - def delete_doc(self, CDocument doc): - handle_status( - "Failed to delete %s" % (doc,), - u1db_delete_doc(self._db, doc._doc)) - - def whats_changed(self, generation=0): - cdef int c_generation - cdef int status - cdef char *trans_id = NULL - - a_list = [] - c_generation = generation - res_trans_id = '' - status = u1db_whats_changed(self._db, &c_generation, &trans_id, - a_list, _append_trans_info_to_list) - try: - handle_status("whats_changed", status) - finally: - if trans_id != NULL: - res_trans_id = trans_id - free(trans_id) - return c_generation, res_trans_id, a_list - - def 
_get_transaction_log(self): - a_list = [] - handle_status("_get_transaction_log", - u1db__get_transaction_log(self._db, a_list, - _append_trans_info_to_list)) - return [(doc_id, trans_id) for doc_id, gen, trans_id in a_list] - - def _get_generation(self): - cdef int generation - handle_status("get_generation", - u1db__get_generation(self._db, &generation)) - return generation - - def _get_generation_info(self): - cdef int generation - cdef char *trans_id - handle_status("get_generation_info", - u1db__get_generation_info(self._db, &generation, &trans_id)) - raw_trans_id = None - if trans_id != NULL: - raw_trans_id = trans_id - free(trans_id) - return generation, raw_trans_id - - def validate_gen_and_trans_id(self, generation, trans_id): - handle_status( - "validate_gen_and_trans_id", - u1db_validate_gen_and_trans_id(self._db, generation, trans_id)) - - def _get_trans_id_for_gen(self, generation): - cdef char *trans_id = NULL - - handle_status( - "_get_trans_id_for_gen", - u1db__get_trans_id_for_gen(self._db, generation, &trans_id)) - raw_trans_id = None - if trans_id != NULL: - raw_trans_id = trans_id - free(trans_id) - return raw_trans_id - - def _get_replica_gen_and_trans_id(self, replica_uid): - cdef int generation, status - cdef char *trans_id = NULL - - status = u1db__get_replica_gen_and_trans_id( - self._db, replica_uid, &generation, &trans_id) - handle_status("_get_replica_gen_and_trans_id", status) - raw_trans_id = None - if trans_id != NULL: - raw_trans_id = trans_id - free(trans_id) - return generation, raw_trans_id - - def _set_replica_gen_and_trans_id(self, replica_uid, generation, trans_id): - handle_status("_set_replica_gen_and_trans_id", - u1db__set_replica_gen_and_trans_id( - self._db, replica_uid, generation, trans_id)) - - def create_index_list(self, index_name, index_expressions): - cdef const_char_ptr *expressions - cdef int n_expressions - - # keep a reference to new_objs so that the pointers in expressions - # remain valid. 
- new_objs = _list_to_str_array( - index_expressions, &expressions, &n_expressions) - try: - status = u1db_create_index_list( - self._db, index_name, n_expressions, expressions) - finally: - free(expressions) - handle_status("create_index", status) - - def create_index(self, index_name, *index_expressions): - extra = [] - if len(index_expressions) == 0: - status = u1db_create_index(self._db, index_name, 0, NULL) - elif len(index_expressions) == 1: - status = u1db_create_index( - self._db, index_name, 1, - _ensure_str(index_expressions[0], extra)) - elif len(index_expressions) == 2: - status = u1db_create_index( - self._db, index_name, 2, - _ensure_str(index_expressions[0], extra), - _ensure_str(index_expressions[1], extra)) - elif len(index_expressions) == 3: - status = u1db_create_index( - self._db, index_name, 3, - _ensure_str(index_expressions[0], extra), - _ensure_str(index_expressions[1], extra), - _ensure_str(index_expressions[2], extra)) - elif len(index_expressions) == 4: - status = u1db_create_index( - self._db, index_name, 4, - _ensure_str(index_expressions[0], extra), - _ensure_str(index_expressions[1], extra), - _ensure_str(index_expressions[2], extra), - _ensure_str(index_expressions[3], extra)) - else: - status = U1DB_NOT_IMPLEMENTED - handle_status("create_index", status) - - def sync(self, url, creds=None): - cdef const_char_ptr c_url - cdef int local_gen = 0 - cdef u1db_oauth_creds _oauth_creds - cdef u1db_creds *_creds = NULL - c_url = url - if creds is not None: - _oauth_creds.auth_kind = U1DB_OAUTH_AUTH - _oauth_creds.consumer_key = creds['oauth']['consumer_key'] - _oauth_creds.consumer_secret = creds['oauth']['consumer_secret'] - _oauth_creds.token_key = creds['oauth']['token_key'] - _oauth_creds.token_secret = creds['oauth']['token_secret'] - _creds = &_oauth_creds - with nogil: - status = u1db_sync(self._db, c_url, _creds, &local_gen) - handle_status("sync", status) - return local_gen - - def list_indexes(self): - a_list = [] - 
handle_status("list_indexes", - u1db_list_indexes(self._db, a_list, - _append_index_definition_to_list)) - return a_list - - def delete_index(self, index_name): - handle_status("delete_index", - u1db_delete_index(self._db, index_name)) - - def get_from_index_list(self, index_name, key_values): - cdef const_char_ptr *values - cdef int n_values - cdef CQuery query - - query = self._query_init(index_name) - res = [] - # keep a reference to new_objs so that the pointers in expressions - # remain valid. - new_objs = _list_to_str_array(key_values, &values, &n_values) - try: - handle_status( - "get_from_index", u1db_get_from_index_list( - self._db, query._query, res, _append_doc_to_list, - n_values, values)) - finally: - free(values) - return res - - def get_from_index(self, index_name, *key_values): - cdef CQuery query - cdef int status - - extra = [] - query = self._query_init(index_name) - res = [] - status = U1DB_OK - if len(key_values) == 0: - status = u1db_get_from_index(self._db, query._query, - res, _append_doc_to_list, 0, NULL) - elif len(key_values) == 1: - status = u1db_get_from_index(self._db, query._query, - res, _append_doc_to_list, 1, - _ensure_str(key_values[0], extra)) - elif len(key_values) == 2: - status = u1db_get_from_index(self._db, query._query, - res, _append_doc_to_list, 2, - _ensure_str(key_values[0], extra), - _ensure_str(key_values[1], extra)) - elif len(key_values) == 3: - status = u1db_get_from_index(self._db, query._query, - res, _append_doc_to_list, 3, - _ensure_str(key_values[0], extra), - _ensure_str(key_values[1], extra), - _ensure_str(key_values[2], extra)) - elif len(key_values) == 4: - status = u1db_get_from_index(self._db, query._query, - res, _append_doc_to_list, 4, - _ensure_str(key_values[0], extra), - _ensure_str(key_values[1], extra), - _ensure_str(key_values[2], extra), - _ensure_str(key_values[3], extra)) - else: - status = U1DB_NOT_IMPLEMENTED - handle_status("get_from_index", status) - return res - - def 
get_range_from_index(self, index_name, start_value=None, - end_value=None): - cdef CQuery query - cdef const_char_ptr *start_values - cdef int n_values - cdef const_char_ptr *end_values - - if start_value is not None: - if isinstance(start_value, basestring): - start_value = (start_value,) - new_objs_1 = _list_to_str_array( - start_value, &start_values, &n_values) - else: - n_values = 0 - start_values = NULL - if end_value is not None: - if isinstance(end_value, basestring): - end_value = (end_value,) - new_objs_2 = _list_to_str_array( - end_value, &end_values, &n_values) - else: - end_values = NULL - query = self._query_init(index_name) - res = [] - try: - handle_status("get_range_from_index", - u1db_get_range_from_index( - self._db, query._query, res, _append_doc_to_list, - n_values, start_values, end_values)) - finally: - if start_values != NULL: - free(start_values) - if end_values != NULL: - free(end_values) - return res - - def get_index_keys(self, index_name): - cdef int status - keys = [] - status = U1DB_OK - status = u1db_get_index_keys( - self._db, index_name, keys, _append_key_to_list) - handle_status("get_index_keys", status) - return keys - - def _query_init(self, index_name): - cdef CQuery query - query = CQuery() - handle_status("query_init", - u1db_query_init(self._db, index_name, &query._query)) - return query - - def get_sync_target(self): - cdef CSyncTarget target - target = CSyncTarget() - target._db = self - handle_status("get_sync_target", - u1db__get_sync_target(target._db._db, &target._st)) - return target - - -cdef class VectorClockRev: - - cdef u1db_vectorclock *_clock - - def __init__(self, s): - if s is None: - self._clock = u1db__vectorclock_from_str(NULL) - else: - self._clock = u1db__vectorclock_from_str(s) - - def __dealloc__(self): - u1db__free_vectorclock(&self._clock) - - def __repr__(self): - cdef int status - cdef char *res - if self._clock == NULL: - return '%s(None)' % (self.__class__.__name__,) - status = 
u1db__vectorclock_as_str(self._clock, &res) - if status != U1DB_OK: - return '%s()' % (status,) - if res == NULL: - val = '%s(NULL)' % (self.__class__.__name__,) - else: - val = '%s(%s)' % (self.__class__.__name__, res) - free(res) - return val - - def as_dict(self): - cdef u1db_vectorclock *cur - cdef int i - cdef int gen - if self._clock == NULL: - return None - res = {} - for i from 0 <= i < self._clock.num_items: - gen = self._clock.items[i].generation - res[self._clock.items[i].replica_uid] = gen - return res - - def as_str(self): - cdef int status - cdef char *res - - status = u1db__vectorclock_as_str(self._clock, &res) - if status != U1DB_OK: - raise RuntimeError("Failed to VectorClockRev.as_str(): %d" % (status,)) - if res == NULL: - s = None - else: - s = res - free(res) - return s - - def increment(self, replica_uid): - cdef int status - - status = u1db__vectorclock_increment(self._clock, replica_uid) - if status != U1DB_OK: - raise RuntimeError("Failed to increment: %d" % (status,)) - - def maximize(self, vcr): - cdef int status - cdef VectorClockRev other - - other = vcr - status = u1db__vectorclock_maximize(self._clock, other._clock) - if status != U1DB_OK: - raise RuntimeError("Failed to maximize: %d" % (status,)) - - def is_newer(self, vcr): - cdef int is_newer - cdef VectorClockRev other - - other = vcr - is_newer = u1db__vectorclock_is_newer(self._clock, other._clock) - if is_newer == 0: - return False - elif is_newer == 1: - return True - else: - raise RuntimeError("Failed to is_newer: %d" % (is_newer,)) - - -def sync_db_to_target(db, target): - """Sync the data between a CDatabase and a CSyncTarget""" - cdef CDatabase cdb - cdef CSyncTarget ctarget - cdef int local_gen = 0, status - - cdb = db - ctarget = target - with nogil: - status = u1db__sync_db_to_target(cdb._db, ctarget._st, &local_gen) - handle_status("sync_db_to_target", status) - return local_gen - - -def create_http_sync_target(url): - cdef CSyncTarget target - - target = CSyncTarget() 
- handle_status("create_http_sync_target", - u1db__create_http_sync_target(url, &target._st)) - return target - - -def create_oauth_http_sync_target(url, consumer_key, consumer_secret, - token_key, token_secret): - cdef CSyncTarget target - - target = CSyncTarget() - handle_status("create_http_sync_target", - u1db__create_oauth_http_sync_target(url, consumer_key, consumer_secret, - token_key, token_secret, - &target._st)) - return target - - -def _format_sync_url(target, source_replica_uid): - cdef CSyncTarget st - cdef char *sync_url = NULL - cdef object res - st = target - handle_status("format_sync_url", - u1db__format_sync_url(st._st, source_replica_uid, &sync_url)) - if sync_url == NULL: - res = None - else: - res = sync_url - free(sync_url) - return res - - -def _get_oauth_authorization(target, method, url): - cdef CSyncTarget st - cdef char *auth = NULL - - st = target - handle_status("get_oauth_authorization", - u1db__get_oauth_authorization(st._st, method, url, &auth)) - res = None - if auth != NULL: - res = auth - free(auth) - return res diff --git a/src/leap/soledad/u1db/tests/commandline/__init__.py b/src/leap/soledad/u1db/tests/commandline/__init__.py deleted file mode 100644 index 007cecd3..00000000 --- a/src/leap/soledad/u1db/tests/commandline/__init__.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -import errno -import time - - -def safe_close(process, timeout=0.1): - """Shutdown the process in the nicest fashion you can manage. - - :param process: A subprocess.Popen object. - :param timeout: We'll try to send 'SIGTERM' but if the process is alive - longer that 'timeout', we'll send SIGKILL. - """ - if process.poll() is not None: - return - try: - process.terminate() - except OSError, e: - if e.errno in (errno.ESRCH,): - # Process has exited - return - tend = time.time() + timeout - while time.time() < tend: - if process.poll() is not None: - return - time.sleep(0.01) - try: - process.kill() - except OSError, e: - if e.errno in (errno.ESRCH,): - # Process has exited - return - process.wait() diff --git a/src/leap/soledad/u1db/tests/commandline/test_client.py b/src/leap/soledad/u1db/tests/commandline/test_client.py deleted file mode 100644 index 78ca21eb..00000000 --- a/src/leap/soledad/u1db/tests/commandline/test_client.py +++ /dev/null @@ -1,916 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -import cStringIO -import os -import sys -try: - import simplejson as json -except ImportError: - import json # noqa -import subprocess - -from u1db import ( - errors, - open as u1db_open, - tests, - vectorclock, - ) -from u1db.commandline import ( - client, - serve, - ) -from u1db.tests.commandline import safe_close -from u1db.tests import test_remote_sync_target - - -class TestArgs(tests.TestCase): - """These tests are meant to test just the argument parsing. - - Each Command should have at least one test, possibly more if it allows - optional arguments, etc. - """ - - def setUp(self): - super(TestArgs, self).setUp() - self.parser = client.client_commands.make_argparser() - - def parse_args(self, args): - # ArgumentParser.parse_args doesn't play very nicely with a test suite, - # so we trap SystemExit in case something is wrong with the args we're - # parsing. - try: - return self.parser.parse_args(args) - except SystemExit: - raise AssertionError('got SystemExit') - - def test_create(self): - args = self.parse_args(['create', 'test.db']) - self.assertEqual(client.CmdCreate, args.subcommand) - self.assertEqual('test.db', args.database) - self.assertEqual(None, args.doc_id) - self.assertEqual(None, args.infile) - - def test_create_custom_doc_id(self): - args = self.parse_args(['create', '--id', 'xyz', 'test.db']) - self.assertEqual(client.CmdCreate, args.subcommand) - self.assertEqual('test.db', args.database) - self.assertEqual('xyz', args.doc_id) - self.assertEqual(None, args.infile) - - def test_delete(self): - args = self.parse_args(['delete', 'test.db', 'doc-id', 'doc-rev']) - self.assertEqual(client.CmdDelete, args.subcommand) - self.assertEqual('test.db', args.database) - self.assertEqual('doc-id', args.doc_id) - self.assertEqual('doc-rev', args.doc_rev) - - def test_get(self): - args = self.parse_args(['get', 'test.db', 'doc-id']) - self.assertEqual(client.CmdGet, args.subcommand) - self.assertEqual('test.db', args.database) - self.assertEqual('doc-id', 
args.doc_id) - self.assertEqual(None, args.outfile) - - def test_get_dash(self): - args = self.parse_args(['get', 'test.db', 'doc-id', '-']) - self.assertEqual(client.CmdGet, args.subcommand) - self.assertEqual('test.db', args.database) - self.assertEqual('doc-id', args.doc_id) - self.assertEqual(sys.stdout, args.outfile) - - def test_init_db(self): - args = self.parse_args( - ['init-db', 'test.db', '--replica-uid=replica-uid']) - self.assertEqual(client.CmdInitDB, args.subcommand) - self.assertEqual('test.db', args.database) - self.assertEqual('replica-uid', args.replica_uid) - - def test_init_db_no_replica(self): - args = self.parse_args(['init-db', 'test.db']) - self.assertEqual(client.CmdInitDB, args.subcommand) - self.assertEqual('test.db', args.database) - self.assertIs(None, args.replica_uid) - - def test_put(self): - args = self.parse_args(['put', 'test.db', 'doc-id', 'old-doc-rev']) - self.assertEqual(client.CmdPut, args.subcommand) - self.assertEqual('test.db', args.database) - self.assertEqual('doc-id', args.doc_id) - self.assertEqual('old-doc-rev', args.doc_rev) - self.assertEqual(None, args.infile) - - def test_sync(self): - args = self.parse_args(['sync', 'source', 'target']) - self.assertEqual(client.CmdSync, args.subcommand) - self.assertEqual('source', args.source) - self.assertEqual('target', args.target) - - def test_create_index(self): - args = self.parse_args(['create-index', 'db', 'index', 'expression']) - self.assertEqual(client.CmdCreateIndex, args.subcommand) - self.assertEqual('db', args.database) - self.assertEqual('index', args.index) - self.assertEqual(['expression'], args.expression) - - def test_create_index_multi_expression(self): - args = self.parse_args(['create-index', 'db', 'index', 'e1', 'e2']) - self.assertEqual(client.CmdCreateIndex, args.subcommand) - self.assertEqual('db', args.database) - self.assertEqual('index', args.index) - self.assertEqual(['e1', 'e2'], args.expression) - - def test_list_indexes(self): - args = 
self.parse_args(['list-indexes', 'db']) - self.assertEqual(client.CmdListIndexes, args.subcommand) - self.assertEqual('db', args.database) - - def test_delete_index(self): - args = self.parse_args(['delete-index', 'db', 'index']) - self.assertEqual(client.CmdDeleteIndex, args.subcommand) - self.assertEqual('db', args.database) - self.assertEqual('index', args.index) - - def test_get_index_keys(self): - args = self.parse_args(['get-index-keys', 'db', 'index']) - self.assertEqual(client.CmdGetIndexKeys, args.subcommand) - self.assertEqual('db', args.database) - self.assertEqual('index', args.index) - - def test_get_from_index(self): - args = self.parse_args(['get-from-index', 'db', 'index', 'foo']) - self.assertEqual(client.CmdGetFromIndex, args.subcommand) - self.assertEqual('db', args.database) - self.assertEqual('index', args.index) - self.assertEqual(['foo'], args.values) - - def test_get_doc_conflicts(self): - args = self.parse_args(['get-doc-conflicts', 'db', 'doc-id']) - self.assertEqual(client.CmdGetDocConflicts, args.subcommand) - self.assertEqual('db', args.database) - self.assertEqual('doc-id', args.doc_id) - - def test_resolve(self): - args = self.parse_args( - ['resolve-doc', 'db', 'doc-id', 'rev:1', 'other:1']) - self.assertEqual(client.CmdResolve, args.subcommand) - self.assertEqual('db', args.database) - self.assertEqual('doc-id', args.doc_id) - self.assertEqual(['rev:1', 'other:1'], args.doc_revs) - self.assertEqual(None, args.infile) - - -class TestCaseWithDB(tests.TestCase): - """These next tests are meant to have one class per Command. - - It is meant to test the inner workings of each command. The detailed - testing should happen in these classes. Stuff like how it handles errors, - etc. should be done here. 
- """ - - def setUp(self): - super(TestCaseWithDB, self).setUp() - self.working_dir = self.createTempDir() - self.db_path = self.working_dir + '/test.db' - self.db = u1db_open(self.db_path, create=True) - self.db._set_replica_uid('test') - self.addCleanup(self.db.close) - - def make_command(self, cls, stdin_content=''): - inf = cStringIO.StringIO(stdin_content) - out = cStringIO.StringIO() - err = cStringIO.StringIO() - return cls(inf, out, err) - - -class TestCmdCreate(TestCaseWithDB): - - def test_create(self): - cmd = self.make_command(client.CmdCreate) - inf = cStringIO.StringIO(tests.simple_doc) - cmd.run(self.db_path, inf, 'test-id') - doc = self.db.get_doc('test-id') - self.assertEqual(tests.simple_doc, doc.get_json()) - self.assertFalse(doc.has_conflicts) - self.assertEqual('', cmd.stdout.getvalue()) - self.assertEqual('id: test-id\nrev: %s\n' % (doc.rev,), - cmd.stderr.getvalue()) - - -class TestCmdDelete(TestCaseWithDB): - - def test_delete(self): - doc = self.db.create_doc_from_json(tests.simple_doc) - cmd = self.make_command(client.CmdDelete) - cmd.run(self.db_path, doc.doc_id, doc.rev) - doc2 = self.db.get_doc(doc.doc_id, include_deleted=True) - self.assertEqual(doc.doc_id, doc2.doc_id) - self.assertNotEqual(doc.rev, doc2.rev) - self.assertIs(None, doc2.get_json()) - self.assertEqual('', cmd.stdout.getvalue()) - self.assertEqual('rev: %s\n' % (doc2.rev,), cmd.stderr.getvalue()) - - def test_delete_fails_if_nonexistent(self): - doc = self.db.create_doc_from_json(tests.simple_doc) - db2_path = self.db_path + '.typo' - cmd = self.make_command(client.CmdDelete) - # TODO: We should really not be showing a traceback here. But we need - # to teach the commandline infrastructure how to handle - # exceptions. - # However, we *do* want to test that the db doesn't get created - # by accident. 
- self.assertRaises(errors.DatabaseDoesNotExist, - cmd.run, db2_path, doc.doc_id, doc.rev) - self.assertFalse(os.path.exists(db2_path)) - - def test_delete_no_such_doc(self): - cmd = self.make_command(client.CmdDelete) - # TODO: We should really not be showing a traceback here. But we need - # to teach the commandline infrastructure how to handle - # exceptions. - self.assertRaises(errors.DocumentDoesNotExist, - cmd.run, self.db_path, 'no-doc-id', 'no-rev') - - def test_delete_bad_rev(self): - doc = self.db.create_doc_from_json(tests.simple_doc) - cmd = self.make_command(client.CmdDelete) - self.assertRaises(errors.RevisionConflict, - cmd.run, self.db_path, doc.doc_id, 'not-the-actual-doc-rev:1') - # TODO: Test that we get a pretty output. - - -class TestCmdGet(TestCaseWithDB): - - def setUp(self): - super(TestCmdGet, self).setUp() - self.doc = self.db.create_doc_from_json( - tests.simple_doc, doc_id='my-test-doc') - - def test_get_simple(self): - cmd = self.make_command(client.CmdGet) - cmd.run(self.db_path, 'my-test-doc', None) - self.assertEqual(tests.simple_doc + "\n", cmd.stdout.getvalue()) - self.assertEqual('rev: %s\n' % (self.doc.rev,), - cmd.stderr.getvalue()) - - def test_get_conflict(self): - doc = self.make_document('my-test-doc', 'other:1', '{}', False) - self.db._put_doc_if_newer( - doc, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - cmd = self.make_command(client.CmdGet) - cmd.run(self.db_path, 'my-test-doc', None) - self.assertEqual('{}\n', cmd.stdout.getvalue()) - self.assertEqual('rev: %s\nDocument has conflicts.\n' % (doc.rev,), - cmd.stderr.getvalue()) - - def test_get_fail(self): - cmd = self.make_command(client.CmdGet) - result = cmd.run(self.db_path, 'doc-not-there', None) - self.assertEqual(1, result) - self.assertEqual("", cmd.stdout.getvalue()) - self.assertTrue("not found" in cmd.stderr.getvalue()) - - def test_get_no_database(self): - cmd = self.make_command(client.CmdGet) - retval = cmd.run(self.db_path 
+ "__DOES_NOT_EXIST", "my-doc", None) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') - - -class TestCmdGetDocConflicts(TestCaseWithDB): - - def setUp(self): - super(TestCmdGetDocConflicts, self).setUp() - self.doc1 = self.db.create_doc_from_json( - tests.simple_doc, doc_id='my-doc') - self.doc2 = self.make_document('my-doc', 'other:1', '{}', False) - self.db._put_doc_if_newer( - self.doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - - def test_get_doc_conflicts_none(self): - self.db.create_doc_from_json(tests.simple_doc, doc_id='a-doc') - cmd = self.make_command(client.CmdGetDocConflicts) - cmd.run(self.db_path, 'a-doc') - self.assertEqual([], json.loads(cmd.stdout.getvalue())) - self.assertEqual('', cmd.stderr.getvalue()) - - def test_get_doc_conflicts_simple(self): - cmd = self.make_command(client.CmdGetDocConflicts) - cmd.run(self.db_path, 'my-doc') - self.assertEqual( - [dict(rev=self.doc2.rev, content=self.doc2.content), - dict(rev=self.doc1.rev, content=self.doc1.content)], - json.loads(cmd.stdout.getvalue())) - self.assertEqual('', cmd.stderr.getvalue()) - - def test_get_doc_conflicts_no_db(self): - cmd = self.make_command(client.CmdGetDocConflicts) - retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc") - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') - - def test_get_doc_conflicts_no_doc(self): - cmd = self.make_command(client.CmdGetDocConflicts) - retval = cmd.run(self.db_path, "some-doc") - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Document does not exist.\n') - - -class TestCmdInit(TestCaseWithDB): - - def test_init_new(self): - path = self.working_dir + '/test2.db' - self.assertFalse(os.path.exists(path)) - cmd = 
self.make_command(client.CmdInitDB) - cmd.run(path, 'test-uid') - self.assertTrue(os.path.exists(path)) - db = u1db_open(path, create=False) - self.assertEqual('test-uid', db._replica_uid) - - def test_init_no_uid(self): - path = self.working_dir + '/test2.db' - cmd = self.make_command(client.CmdInitDB) - cmd.run(path, None) - self.assertTrue(os.path.exists(path)) - db = u1db_open(path, create=False) - self.assertIsNot(None, db._replica_uid) - - -class TestCmdPut(TestCaseWithDB): - - def setUp(self): - super(TestCmdPut, self).setUp() - self.doc = self.db.create_doc_from_json( - tests.simple_doc, doc_id='my-test-doc') - - def test_put_simple(self): - cmd = self.make_command(client.CmdPut) - inf = cStringIO.StringIO(tests.nested_doc) - cmd.run(self.db_path, 'my-test-doc', self.doc.rev, inf) - doc = self.db.get_doc('my-test-doc') - self.assertNotEqual(self.doc.rev, doc.rev) - self.assertGetDoc(self.db, 'my-test-doc', doc.rev, - tests.nested_doc, False) - self.assertEqual('', cmd.stdout.getvalue()) - self.assertEqual('rev: %s\n' % (doc.rev,), - cmd.stderr.getvalue()) - - def test_put_no_db(self): - cmd = self.make_command(client.CmdPut) - inf = cStringIO.StringIO(tests.nested_doc) - retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", - 'my-test-doc', self.doc.rev, inf) - self.assertEqual(retval, 1) - self.assertEqual('', cmd.stdout.getvalue()) - self.assertEqual('Database does not exist.\n', cmd.stderr.getvalue()) - - def test_put_no_doc(self): - cmd = self.make_command(client.CmdPut) - inf = cStringIO.StringIO(tests.nested_doc) - retval = cmd.run(self.db_path, 'no-such-doc', 'wut:1', inf) - self.assertEqual(1, retval) - self.assertEqual('', cmd.stdout.getvalue()) - self.assertEqual('Document does not exist.\n', cmd.stderr.getvalue()) - - def test_put_doc_old_rev(self): - rev = self.doc.rev - doc = self.make_document('my-test-doc', rev, '{}', False) - self.db.put_doc(doc) - cmd = self.make_command(client.CmdPut) - inf = cStringIO.StringIO(tests.nested_doc) - retval = 
cmd.run(self.db_path, 'my-test-doc', rev, inf) - self.assertEqual(1, retval) - self.assertEqual('', cmd.stdout.getvalue()) - self.assertEqual('Given revision is not current.\n', - cmd.stderr.getvalue()) - - def test_put_doc_w_conflicts(self): - doc = self.make_document('my-test-doc', 'other:1', '{}', False) - self.db._put_doc_if_newer( - doc, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - cmd = self.make_command(client.CmdPut) - inf = cStringIO.StringIO(tests.nested_doc) - retval = cmd.run(self.db_path, 'my-test-doc', 'other:1', inf) - self.assertEqual(1, retval) - self.assertEqual('', cmd.stdout.getvalue()) - self.assertEqual('Document has conflicts.\n' - 'Inspect with get-doc-conflicts, then resolve.\n', - cmd.stderr.getvalue()) - - -class TestCmdResolve(TestCaseWithDB): - - def setUp(self): - super(TestCmdResolve, self).setUp() - self.doc1 = self.db.create_doc_from_json( - tests.simple_doc, doc_id='my-doc') - self.doc2 = self.make_document('my-doc', 'other:1', '{}', False) - self.db._put_doc_if_newer( - self.doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - - def test_resolve_simple(self): - self.assertTrue(self.db.get_doc('my-doc').has_conflicts) - cmd = self.make_command(client.CmdResolve) - inf = cStringIO.StringIO(tests.nested_doc) - cmd.run(self.db_path, 'my-doc', [self.doc1.rev, self.doc2.rev], inf) - doc = self.db.get_doc('my-doc') - vec = vectorclock.VectorClockRev(doc.rev) - self.assertTrue( - vec.is_newer(vectorclock.VectorClockRev(self.doc1.rev))) - self.assertTrue( - vec.is_newer(vectorclock.VectorClockRev(self.doc2.rev))) - self.assertGetDoc(self.db, 'my-doc', doc.rev, tests.nested_doc, False) - self.assertEqual('', cmd.stdout.getvalue()) - self.assertEqual('rev: %s\n' % (doc.rev,), - cmd.stderr.getvalue()) - - def test_resolve_double(self): - moar = '{"x": 42}' - doc3 = self.make_document('my-doc', 'third:1', moar, False) - self.db._put_doc_if_newer( - doc3, save_conflict=True, 
replica_uid='r', replica_gen=1, - replica_trans_id='foo') - cmd = self.make_command(client.CmdResolve) - inf = cStringIO.StringIO(tests.nested_doc) - cmd.run(self.db_path, 'my-doc', [self.doc1.rev, self.doc2.rev], inf) - doc = self.db.get_doc('my-doc') - self.assertGetDoc(self.db, 'my-doc', doc.rev, moar, True) - self.assertEqual('', cmd.stdout.getvalue()) - self.assertEqual( - 'rev: %s\nDocument still has conflicts.\n' % (doc.rev,), - cmd.stderr.getvalue()) - - def test_resolve_no_db(self): - cmd = self.make_command(client.CmdResolve) - retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc", [], None) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') - - def test_resolve_no_doc(self): - cmd = self.make_command(client.CmdResolve) - retval = cmd.run(self.db_path, "foo", [], None) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Document does not exist.\n') - - -class TestCmdSync(TestCaseWithDB): - - def setUp(self): - super(TestCmdSync, self).setUp() - self.db2_path = self.working_dir + '/test2.db' - self.db2 = u1db_open(self.db2_path, create=True) - self.addCleanup(self.db2.close) - self.db2._set_replica_uid('test2') - self.doc = self.db.create_doc_from_json( - tests.simple_doc, doc_id='test-id') - self.doc2 = self.db2.create_doc_from_json( - tests.nested_doc, doc_id='my-test-id') - - def test_sync(self): - cmd = self.make_command(client.CmdSync) - cmd.run(self.db_path, self.db2_path) - self.assertGetDoc(self.db2, 'test-id', self.doc.rev, tests.simple_doc, - False) - self.assertGetDoc(self.db, 'my-test-id', self.doc2.rev, - tests.nested_doc, False) - - -class TestCmdSyncRemote(tests.TestCaseWithServer, TestCaseWithDB): - - make_app_with_state = \ - staticmethod(test_remote_sync_target.make_http_app) - - def setUp(self): - super(TestCmdSyncRemote, self).setUp() - self.startServer() - self.db2 = 
self.request_state._create_database('test2.db') - - def test_sync_remote(self): - doc1 = self.db.create_doc_from_json(tests.simple_doc) - doc2 = self.db2.create_doc_from_json(tests.nested_doc) - db2_url = self.getURL('test2.db') - self.assertTrue(db2_url.startswith('http://')) - self.assertTrue(db2_url.endswith('/test2.db')) - cmd = self.make_command(client.CmdSync) - cmd.run(self.db_path, db2_url) - self.assertGetDoc(self.db2, doc1.doc_id, doc1.rev, tests.simple_doc, - False) - self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc, - False) - - -class TestCmdCreateIndex(TestCaseWithDB): - - def test_create_index(self): - cmd = self.make_command(client.CmdCreateIndex) - retval = cmd.run(self.db_path, "foo", ["bar", "baz"]) - self.assertEqual(self.db.list_indexes(), [('foo', ['bar', "baz"])]) - self.assertEqual(retval, None) # conveniently mapped to 0 - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), '') - - def test_create_index_no_db(self): - cmd = self.make_command(client.CmdCreateIndex) - retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo", ["bar"]) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') - - def test_create_dupe_index(self): - self.db.create_index("foo", "bar") - cmd = self.make_command(client.CmdCreateIndex) - retval = cmd.run(self.db_path, "foo", ["bar"]) - self.assertEqual(retval, None) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), '') - - def test_create_dupe_index_different_expression(self): - self.db.create_index("foo", "bar") - cmd = self.make_command(client.CmdCreateIndex) - retval = cmd.run(self.db_path, "foo", ["baz"]) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), - "There is already a different index named 'foo'.\n") - - def test_create_index_bad_expression(self): - cmd = 
self.make_command(client.CmdCreateIndex) - retval = cmd.run(self.db_path, "foo", ["WAT()"]) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), - 'Bad index expression.\n') - - -class TestCmdListIndexes(TestCaseWithDB): - - def test_list_no_indexes(self): - cmd = self.make_command(client.CmdListIndexes) - retval = cmd.run(self.db_path) - self.assertEqual(retval, None) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), '') - - def test_list_indexes(self): - self.db.create_index("foo", "bar", "baz") - cmd = self.make_command(client.CmdListIndexes) - retval = cmd.run(self.db_path) - self.assertEqual(retval, None) - self.assertEqual(cmd.stdout.getvalue(), 'foo: bar, baz\n') - self.assertEqual(cmd.stderr.getvalue(), '') - - def test_list_several_indexes(self): - self.db.create_index("foo", "bar", "baz") - self.db.create_index("bar", "baz", "foo") - self.db.create_index("baz", "foo", "bar") - cmd = self.make_command(client.CmdListIndexes) - retval = cmd.run(self.db_path) - self.assertEqual(retval, None) - self.assertEqual(cmd.stdout.getvalue(), - 'bar: baz, foo\n' - 'baz: foo, bar\n' - 'foo: bar, baz\n' - ) - self.assertEqual(cmd.stderr.getvalue(), '') - - def test_list_indexes_no_db(self): - cmd = self.make_command(client.CmdListIndexes) - retval = cmd.run(self.db_path + "__DOES_NOT_EXIST") - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') - - -class TestCmdDeleteIndex(TestCaseWithDB): - - def test_delete_index(self): - self.db.create_index("foo", "bar", "baz") - cmd = self.make_command(client.CmdDeleteIndex) - retval = cmd.run(self.db_path, "foo") - self.assertEqual(retval, None) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), '') - self.assertEqual([], self.db.list_indexes()) - - def test_delete_index_no_db(self): - cmd = 
self.make_command(client.CmdDeleteIndex) - retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo") - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') - - def test_delete_index_no_index(self): - cmd = self.make_command(client.CmdDeleteIndex) - retval = cmd.run(self.db_path, "foo") - self.assertEqual(retval, None) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), '') - - -class TestCmdGetIndexKeys(TestCaseWithDB): - - def test_get_index_keys(self): - self.db.create_index("foo", "bar") - self.db.create_doc_from_json('{"bar": 42}') - cmd = self.make_command(client.CmdGetIndexKeys) - retval = cmd.run(self.db_path, "foo") - self.assertEqual(retval, None) - self.assertEqual(cmd.stdout.getvalue(), '42\n') - self.assertEqual(cmd.stderr.getvalue(), '') - - def test_get_index_keys_nonascii(self): - self.db.create_index("foo", "bar") - self.db.create_doc_from_json('{"bar": "\u00a4"}') - cmd = self.make_command(client.CmdGetIndexKeys) - retval = cmd.run(self.db_path, "foo") - self.assertEqual(retval, None) - self.assertEqual(cmd.stdout.getvalue(), '\xc2\xa4\n') - self.assertEqual(cmd.stderr.getvalue(), '') - - def test_get_index_keys_empty(self): - self.db.create_index("foo", "bar") - cmd = self.make_command(client.CmdGetIndexKeys) - retval = cmd.run(self.db_path, "foo") - self.assertEqual(retval, None) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), '') - - def test_get_index_keys_no_db(self): - cmd = self.make_command(client.CmdGetIndexKeys) - retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo") - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') - - def test_get_index_keys_no_index(self): - cmd = self.make_command(client.CmdGetIndexKeys) - retval = cmd.run(self.db_path, "foo") - 
self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Index does not exist.\n') - - -class TestCmdGetFromIndex(TestCaseWithDB): - - def test_get_from_index(self): - self.db.create_index("index", "key") - doc1 = self.db.create_doc_from_json(tests.simple_doc) - doc2 = self.db.create_doc_from_json(tests.nested_doc) - cmd = self.make_command(client.CmdGetFromIndex) - retval = cmd.run(self.db_path, "index", ["value"]) - self.assertEqual(retval, None) - self.assertEqual(sorted(json.loads(cmd.stdout.getvalue())), - sorted([dict(id=doc1.doc_id, - rev=doc1.rev, - content=doc1.content), - dict(id=doc2.doc_id, - rev=doc2.rev, - content=doc2.content), - ])) - self.assertEqual(cmd.stderr.getvalue(), '') - - def test_get_from_index_empty(self): - self.db.create_index("index", "key") - cmd = self.make_command(client.CmdGetFromIndex) - retval = cmd.run(self.db_path, "index", ["value"]) - self.assertEqual(retval, None) - self.assertEqual(cmd.stdout.getvalue(), '[]\n') - self.assertEqual(cmd.stderr.getvalue(), '') - - def test_get_from_index_no_db(self): - cmd = self.make_command(client.CmdGetFromIndex) - retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo", []) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') - - def test_get_from_index_no_index(self): - cmd = self.make_command(client.CmdGetFromIndex) - retval = cmd.run(self.db_path, "foo", []) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual(cmd.stderr.getvalue(), 'Index does not exist.\n') - - def test_get_from_index_two_expr_instead_of_one(self): - self.db.create_index("index", "key1") - cmd = self.make_command(client.CmdGetFromIndex) - cmd.argv = ["XX", "YY"] - retval = cmd.run(self.db_path, "index", ["value1", "value2"]) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - 
self.assertEqual("Invalid query: index 'index' requires" - " 1 query expression, not 2.\n" - "For example, the following would be valid:\n" - " XX YY %r 'index' 'value1'\n" - % self.db_path, cmd.stderr.getvalue()) - - def test_get_from_index_three_expr_instead_of_two(self): - self.db.create_index("index", "key1", "key2") - cmd = self.make_command(client.CmdGetFromIndex) - cmd.argv = ["XX", "YY"] - retval = cmd.run(self.db_path, "index", ["value1", "value2", "value3"]) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual("Invalid query: index 'index' requires" - " 2 query expressions, not 3.\n" - "For example, the following would be valid:\n" - " XX YY %r 'index' 'value1' 'value2'\n" - % self.db_path, cmd.stderr.getvalue()) - - def test_get_from_index_one_expr_instead_of_two(self): - self.db.create_index("index", "key1", "key2") - cmd = self.make_command(client.CmdGetFromIndex) - cmd.argv = ["XX", "YY"] - retval = cmd.run(self.db_path, "index", ["value1"]) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual("Invalid query: index 'index' requires" - " 2 query expressions, not 1.\n" - "For example, the following would be valid:\n" - " XX YY %r 'index' 'value1' '*'\n" - % self.db_path, cmd.stderr.getvalue()) - - def test_get_from_index_cant_bad_glob(self): - self.db.create_index("index", "key1", "key2") - cmd = self.make_command(client.CmdGetFromIndex) - cmd.argv = ["XX", "YY"] - retval = cmd.run(self.db_path, "index", ["value1*", "value2"]) - self.assertEqual(retval, 1) - self.assertEqual(cmd.stdout.getvalue(), '') - self.assertEqual("Invalid query:" - " a star can only be followed by stars.\n" - "For example, the following would be valid:\n" - " XX YY %r 'index' 'value1*' '*'\n" - % self.db_path, cmd.stderr.getvalue()) - - -class RunMainHelper(object): - - def run_main(self, args, stdin=None): - if stdin is not None: - self.patch(sys, 'stdin', cStringIO.StringIO(stdin)) - stdout = 
cStringIO.StringIO() - stderr = cStringIO.StringIO() - self.patch(sys, 'stdout', stdout) - self.patch(sys, 'stderr', stderr) - try: - ret = client.main(args) - except SystemExit, e: - self.fail("Intercepted SystemExit: %s" % (e,)) - if ret is None: - ret = 0 - return ret, stdout.getvalue(), stderr.getvalue() - - -class TestCommandLine(TestCaseWithDB, RunMainHelper): - """These are meant to test that the infrastructure is fully connected. - - Each command is likely to only have one test here. Something that ensures - 'main()' knows about and can run the command correctly. Most logic-level - testing of the Command should go into its own test class above. - """ - - def _get_u1db_client_path(self): - from u1db import __path__ as u1db_path - u1db_parent_dir = os.path.dirname(u1db_path[0]) - return os.path.join(u1db_parent_dir, 'u1db-client') - - def runU1DBClient(self, args): - command = [sys.executable, self._get_u1db_client_path()] - command.extend(args) - p = subprocess.Popen(command, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - self.addCleanup(safe_close, p) - return p - - def test_create_subprocess(self): - p = self.runU1DBClient(['create', '--id', 'test-id', self.db_path]) - stdout, stderr = p.communicate(tests.simple_doc) - self.assertEqual(0, p.returncode) - self.assertEqual('', stdout) - doc = self.db.get_doc('test-id') - self.assertEqual(tests.simple_doc, doc.get_json()) - self.assertFalse(doc.has_conflicts) - expected = 'id: test-id\nrev: %s\n' % (doc.rev,) - stripped = stderr.replace('\r\n', '\n') - if expected != stripped: - # When run under python-dbg, it prints out the refs after the - # actual content, so match it if we need to. 
- expected_re = expected + '\[\d+ refs\]\n' - self.assertRegexpMatches(stripped, expected_re) - - def test_get(self): - doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') - ret, stdout, stderr = self.run_main(['get', self.db_path, 'test-id']) - self.assertEqual(0, ret) - self.assertEqual(tests.simple_doc + "\n", stdout) - self.assertEqual('rev: %s\n' % (doc.rev,), stderr) - ret, stdout, stderr = self.run_main(['get', self.db_path, 'not-there']) - self.assertEqual(1, ret) - - def test_delete(self): - doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') - ret, stdout, stderr = self.run_main( - ['delete', self.db_path, 'test-id', doc.rev]) - doc = self.db.get_doc('test-id', include_deleted=True) - self.assertEqual(0, ret) - self.assertEqual('', stdout) - self.assertEqual('rev: %s\n' % (doc.rev,), stderr) - - def test_init_db(self): - path = self.working_dir + '/test2.db' - ret, stdout, stderr = self.run_main(['init-db', path]) - u1db_open(path, create=False) - - def test_put(self): - doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') - ret, stdout, stderr = self.run_main( - ['put', self.db_path, 'test-id', doc.rev], - stdin=tests.nested_doc) - doc = self.db.get_doc('test-id') - self.assertFalse(doc.has_conflicts) - self.assertEqual(tests.nested_doc, doc.get_json()) - self.assertEqual(0, ret) - self.assertEqual('', stdout) - self.assertEqual('rev: %s\n' % (doc.rev,), stderr) - - def test_sync(self): - doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') - self.db2_path = self.working_dir + '/test2.db' - self.db2 = u1db_open(self.db2_path, create=True) - self.addCleanup(self.db2.close) - ret, stdout, stderr = self.run_main( - ['sync', self.db_path, self.db2_path]) - self.assertEqual(0, ret) - self.assertEqual('', stdout) - self.assertEqual('', stderr) - self.assertGetDoc( - self.db2, 'test-id', doc.rev, tests.simple_doc, False) - - -class TestHTTPIntegration(tests.TestCaseWithServer, 
RunMainHelper): - """Meant to test the cases where commands operate over http.""" - - def server_def(self): - def make_server(host_port, _application): - return serve.make_server(host_port[0], host_port[1], - self.working_dir) - return make_server, "shutdown", "http" - - def setUp(self): - super(TestHTTPIntegration, self).setUp() - self.working_dir = self.createTempDir(prefix='u1db-http-server-') - self.startServer() - - def getPath(self, dbname): - return os.path.join(self.working_dir, dbname) - - def test_init_db(self): - url = self.getURL('new.db') - ret, stdout, stderr = self.run_main(['init-db', url]) - u1db_open(self.getPath('new.db'), create=False) - - def test_create_get_put_delete(self): - db = u1db_open(self.getPath('test.db'), create=True) - url = self.getURL('test.db') - doc_id = '%abcd' - ret, stdout, stderr = self.run_main(['create', url, '--id', doc_id], - stdin=tests.simple_doc) - self.assertEqual(0, ret) - ret, stdout, stderr = self.run_main(['get', url, doc_id]) - self.assertEqual(0, ret) - self.assertTrue(stderr.startswith('rev: ')) - doc_rev = stderr[len('rev: '):].rstrip() - ret, stdout, stderr = self.run_main(['put', url, doc_id, doc_rev], - stdin=tests.nested_doc) - self.assertEqual(0, ret) - self.assertTrue(stderr.startswith('rev: ')) - doc_rev1 = stderr[len('rev: '):].rstrip() - self.assertGetDoc(db, doc_id, doc_rev1, tests.nested_doc, False) - ret, stdout, stderr = self.run_main(['delete', url, doc_id, doc_rev1]) - self.assertEqual(0, ret) - self.assertTrue(stderr.startswith('rev: ')) - doc_rev2 = stderr[len('rev: '):].rstrip() - self.assertGetDocIncludeDeleted(db, doc_id, doc_rev2, None, False) diff --git a/src/leap/soledad/u1db/tests/commandline/test_command.py b/src/leap/soledad/u1db/tests/commandline/test_command.py deleted file mode 100644 index 43580f23..00000000 --- a/src/leap/soledad/u1db/tests/commandline/test_command.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. 
-# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -import cStringIO -import argparse - -from u1db import ( - tests, - ) -from u1db.commandline import ( - command, - ) - - -class MyTestCommand(command.Command): - """Help String""" - - name = 'mycmd' - - @classmethod - def _populate_subparser(cls, parser): - parser.add_argument('foo') - parser.add_argument('--bar', dest='nbar', type=int) - - def run(self, foo, nbar): - self.stdout.write('foo: %s nbar: %d' % (foo, nbar)) - return 0 - - -def make_stdin_out_err(): - return cStringIO.StringIO(), cStringIO.StringIO(), cStringIO.StringIO() - - -class TestCommandGroup(tests.TestCase): - - def trap_system_exit(self, func, *args, **kwargs): - try: - return func(*args, **kwargs) - except SystemExit, e: - self.fail('Got SystemExit trying to run: %s' % (func,)) - - def parse_args(self, parser, args): - return self.trap_system_exit(parser.parse_args, args) - - def test_register(self): - group = command.CommandGroup() - self.assertEqual({}, group.commands) - group.register(MyTestCommand) - self.assertEqual({'mycmd': MyTestCommand}, - group.commands) - - def test_make_argparser(self): - group = command.CommandGroup(description='test-foo') - parser = group.make_argparser() - self.assertIsInstance(parser, argparse.ArgumentParser) - - def test_make_argparser_with_command(self): - group = command.CommandGroup(description='test-foo') - group.register(MyTestCommand) - parser = group.make_argparser() - args = self.parse_args(parser, ['mycmd', 
'foozizle', '--bar=10']) - self.assertEqual('foozizle', args.foo) - self.assertEqual(10, args.nbar) - self.assertEqual(MyTestCommand, args.subcommand) - - def test_run_argv(self): - group = command.CommandGroup() - group.register(MyTestCommand) - stdin, stdout, stderr = make_stdin_out_err() - ret = self.trap_system_exit(group.run_argv, - ['mycmd', 'foozizle', '--bar=10'], - stdin, stdout, stderr) - self.assertEqual(0, ret) - - -class TestCommand(tests.TestCase): - - def make_command(self): - stdin, stdout, stderr = make_stdin_out_err() - return command.Command(stdin, stdout, stderr) - - def test__init__(self): - cmd = self.make_command() - self.assertIsNot(None, cmd.stdin) - self.assertIsNot(None, cmd.stdout) - self.assertIsNot(None, cmd.stderr) - - def test_run_args(self): - stdin, stdout, stderr = make_stdin_out_err() - cmd = MyTestCommand(stdin, stdout, stderr) - res = cmd.run(foo='foozizle', nbar=10) - self.assertEqual('foo: foozizle nbar: 10', stdout.getvalue()) diff --git a/src/leap/soledad/u1db/tests/commandline/test_serve.py b/src/leap/soledad/u1db/tests/commandline/test_serve.py deleted file mode 100644 index 6397eabe..00000000 --- a/src/leap/soledad/u1db/tests/commandline/test_serve.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -import os -import socket -import subprocess -import sys - -from u1db import ( - __version__ as _u1db_version, - open as u1db_open, - tests, - ) -from u1db.remote import http_client -from u1db.tests.commandline import safe_close - - -class TestU1DBServe(tests.TestCase): - - def _get_u1db_serve_path(self): - from u1db import __path__ as u1db_path - u1db_parent_dir = os.path.dirname(u1db_path[0]) - return os.path.join(u1db_parent_dir, 'u1db-serve') - - def startU1DBServe(self, args): - command = [sys.executable, self._get_u1db_serve_path()] - command.extend(args) - p = subprocess.Popen(command, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - self.addCleanup(safe_close, p) - return p - - def test_help(self): - p = self.startU1DBServe(['--help']) - stdout, stderr = p.communicate() - if stderr != '': - # stderr should normally be empty, but if we are running under - # python-dbg, it contains the following string - self.assertRegexpMatches(stderr, r'\[\d+ refs\]') - self.assertEqual(0, p.returncode) - self.assertIn('Run the U1DB server', stdout) - - def test_bind_to_port(self): - p = self.startU1DBServe([]) - starts = 'listening on:' - x = p.stdout.readline() - self.assertTrue(x.startswith(starts)) - port = int(x[len(starts):].split(":")[1]) - url = "http://127.0.0.1:%s/" % port - c = http_client.HTTPClientBase(url) - self.addCleanup(c.close) - res, _ = c._request_json('GET', []) - self.assertEqual({'version': _u1db_version}, res) - - def test_supply_port(self): - s = socket.socket() - s.bind(('127.0.0.1', 0)) - host, port = s.getsockname() - s.close() - p = self.startU1DBServe(['--port', str(port)]) - x = p.stdout.readline().strip() - self.assertEqual('listening on: 127.0.0.1:%s' % (port,), x) - url = "http://127.0.0.1:%s/" % port - c = http_client.HTTPClientBase(url) - self.addCleanup(c.close) - res, _ = c._request_json('GET', []) - self.assertEqual({'version': _u1db_version}, res) - - def test_bind_to_host(self): - p = 
self.startU1DBServe(["--host", "localhost"]) - starts = 'listening on: 127.0.0.1:' - x = p.stdout.readline() - self.assertTrue(x.startswith(starts)) - - def test_supply_working_dir(self): - tmp_dir = self.createTempDir('u1db-serve-test') - db = u1db_open(os.path.join(tmp_dir, 'landmark.db'), create=True) - db.close() - p = self.startU1DBServe(['--working-dir', tmp_dir]) - starts = 'listening on:' - x = p.stdout.readline() - self.assertTrue(x.startswith(starts)) - port = int(x[len(starts):].split(":")[1]) - url = "http://127.0.0.1:%s/landmark.db" % port - c = http_client.HTTPClientBase(url) - self.addCleanup(c.close) - res, _ = c._request_json('GET', []) - self.assertEqual({}, res) diff --git a/src/leap/soledad/u1db/tests/test_auth_middleware.py b/src/leap/soledad/u1db/tests/test_auth_middleware.py deleted file mode 100644 index e765f8a7..00000000 --- a/src/leap/soledad/u1db/tests/test_auth_middleware.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright 2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""Test OAuth wsgi middleware""" -import paste.fixture -from oauth import oauth -try: - import simplejson as json -except ImportError: - import json # noqa -import time - -from u1db import tests - -from u1db.remote.oauth_middleware import OAuthMiddleware -from u1db.remote.basic_auth_middleware import BasicAuthMiddleware, Unauthorized - - -BASE_URL = 'https://example.net' - - -class TestBasicAuthMiddleware(tests.TestCase): - - def setUp(self): - super(TestBasicAuthMiddleware, self).setUp() - self.got = [] - - def witness_app(environ, start_response): - start_response("200 OK", [("content-type", "text/plain")]) - self.got.append(( - environ['user_id'], environ['PATH_INFO'], - environ['QUERY_STRING'])) - return ["ok"] - - class MyAuthMiddleware(BasicAuthMiddleware): - - def verify_user(self, environ, user, password): - if user != "correct_user": - raise Unauthorized - if password != "correct_password": - raise Unauthorized - environ['user_id'] = user - - self.auth_midw = MyAuthMiddleware(witness_app, prefix="/pfx/") - self.app = paste.fixture.TestApp(self.auth_midw) - - def test_expect_prefix(self): - url = BASE_URL + '/foo/doc/doc-id' - resp = self.app.delete(url, expect_errors=True) - self.assertEqual(400, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual('{"error": "bad request"}', resp.body) - - def test_missing_auth(self): - url = BASE_URL + '/pfx/foo/doc/doc-id' - resp = self.app.delete(url, expect_errors=True) - self.assertEqual(401, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual( - {"error": "unauthorized", - "message": "Missing Basic Authentication."}, - json.loads(resp.body)) - - def test_correct_auth(self): - user = "correct_user" - password = "correct_password" - params = {'old_rev': 'old-rev'} - url = BASE_URL + '/pfx/foo/doc/doc-id?%s' % ( - '&'.join("%s=%s" % (k, v) for k, v in params.items())) - auth = '%s:%s' % (user, password) - headers = { - 
'Authorization': 'Basic %s' % (auth.encode('base64'),)} - resp = self.app.delete(url, headers=headers) - self.assertEqual(200, resp.status) - self.assertEqual( - [('correct_user', '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) - - def test_incorrect_auth(self): - user = "correct_user" - password = "incorrect_password" - params = {'old_rev': 'old-rev'} - url = BASE_URL + '/pfx/foo/doc/doc-id?%s' % ( - '&'.join("%s=%s" % (k, v) for k, v in params.items())) - auth = '%s:%s' % (user, password) - headers = { - 'Authorization': 'Basic %s' % (auth.encode('base64'),)} - resp = self.app.delete(url, headers=headers, expect_errors=True) - self.assertEqual(401, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual( - {"error": "unauthorized", - "message": "Incorrect password or login."}, - json.loads(resp.body)) - - -class TestOAuthMiddlewareDefaultPrefix(tests.TestCase): - def setUp(self): - - super(TestOAuthMiddlewareDefaultPrefix, self).setUp() - self.got = [] - - def witness_app(environ, start_response): - start_response("200 OK", [("content-type", "text/plain")]) - self.got.append((environ['token_key'], environ['PATH_INFO'], - environ['QUERY_STRING'])) - return ["ok"] - - class MyOAuthMiddleware(OAuthMiddleware): - get_oauth_data_store = lambda self: tests.testingOAuthStore - - def verify(self, environ, oauth_req): - consumer, token = super(MyOAuthMiddleware, self).verify( - environ, oauth_req) - environ['token_key'] = token.key - - self.oauth_midw = MyOAuthMiddleware(witness_app, BASE_URL) - self.app = paste.fixture.TestApp(self.oauth_midw) - - def test_expect_tilde(self): - url = BASE_URL + '/foo/doc/doc-id' - resp = self.app.delete(url, expect_errors=True) - self.assertEqual(400, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual('{"error": "bad request"}', resp.body) - - def test_oauth_in_header(self): - url = BASE_URL + '/~/foo/doc/doc-id' - params = {'old_rev': 
'old-rev'} - oauth_req = oauth.OAuthRequest.from_consumer_and_token( - tests.consumer2, - tests.token2, - parameters=params, - http_url=url, - http_method='DELETE' - ) - url = oauth_req.get_normalized_http_url() + '?' + ( - '&'.join("%s=%s" % (k, v) for k, v in params.items())) - oauth_req.sign_request(tests.sign_meth_HMAC_SHA1, - tests.consumer2, tests.token2) - resp = self.app.delete(url, headers=oauth_req.to_header()) - self.assertEqual(200, resp.status) - self.assertEqual([(tests.token2.key, - '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) - - def test_oauth_in_query_string(self): - url = BASE_URL + '/~/foo/doc/doc-id' - params = {'old_rev': 'old-rev'} - oauth_req = oauth.OAuthRequest.from_consumer_and_token( - tests.consumer1, - tests.token1, - parameters=params, - http_url=url, - http_method='DELETE' - ) - oauth_req.sign_request(tests.sign_meth_HMAC_SHA1, - tests.consumer1, tests.token1) - resp = self.app.delete(oauth_req.to_url()) - self.assertEqual(200, resp.status) - self.assertEqual([(tests.token1.key, - '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) - - -class TestOAuthMiddleware(tests.TestCase): - - def setUp(self): - super(TestOAuthMiddleware, self).setUp() - self.got = [] - - def witness_app(environ, start_response): - start_response("200 OK", [("content-type", "text/plain")]) - self.got.append((environ['token_key'], environ['PATH_INFO'], - environ['QUERY_STRING'])) - return ["ok"] - - class MyOAuthMiddleware(OAuthMiddleware): - get_oauth_data_store = lambda self: tests.testingOAuthStore - - def verify(self, environ, oauth_req): - consumer, token = super(MyOAuthMiddleware, self).verify( - environ, oauth_req) - environ['token_key'] = token.key - - self.oauth_midw = MyOAuthMiddleware( - witness_app, BASE_URL, prefix='/pfx/') - self.app = paste.fixture.TestApp(self.oauth_midw) - - def test_expect_prefix(self): - url = BASE_URL + '/foo/doc/doc-id' - resp = self.app.delete(url, expect_errors=True) - self.assertEqual(400, resp.status) - 
self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual('{"error": "bad request"}', resp.body) - - def test_missing_oauth(self): - url = BASE_URL + '/pfx/foo/doc/doc-id' - resp = self.app.delete(url, expect_errors=True) - self.assertEqual(401, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual( - {"error": "unauthorized", "message": "Missing OAuth."}, - json.loads(resp.body)) - - def test_oauth_in_query_string(self): - url = BASE_URL + '/pfx/foo/doc/doc-id' - params = {'old_rev': 'old-rev'} - oauth_req = oauth.OAuthRequest.from_consumer_and_token( - tests.consumer1, - tests.token1, - parameters=params, - http_url=url, - http_method='DELETE' - ) - oauth_req.sign_request(tests.sign_meth_HMAC_SHA1, - tests.consumer1, tests.token1) - resp = self.app.delete(oauth_req.to_url()) - self.assertEqual(200, resp.status) - self.assertEqual([(tests.token1.key, - '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) - - def test_oauth_invalid(self): - url = BASE_URL + '/pfx/foo/doc/doc-id' - params = {'old_rev': 'old-rev'} - oauth_req = oauth.OAuthRequest.from_consumer_and_token( - tests.consumer1, - tests.token3, - parameters=params, - http_url=url, - http_method='DELETE' - ) - oauth_req.sign_request(tests.sign_meth_HMAC_SHA1, - tests.consumer1, tests.token3) - resp = self.app.delete(oauth_req.to_url(), - expect_errors=True) - self.assertEqual(401, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - err = json.loads(resp.body) - self.assertEqual({"error": "unauthorized", - "message": err['message']}, - err) - - def test_oauth_in_header(self): - url = BASE_URL + '/pfx/foo/doc/doc-id' - params = {'old_rev': 'old-rev'} - oauth_req = oauth.OAuthRequest.from_consumer_and_token( - tests.consumer2, - tests.token2, - parameters=params, - http_url=url, - http_method='DELETE' - ) - url = oauth_req.get_normalized_http_url() + '?' 
+ ( - '&'.join("%s=%s" % (k, v) for k, v in params.items())) - oauth_req.sign_request(tests.sign_meth_HMAC_SHA1, - tests.consumer2, tests.token2) - resp = self.app.delete(url, headers=oauth_req.to_header()) - self.assertEqual(200, resp.status) - self.assertEqual([(tests.token2.key, - '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) - - def test_oauth_plain_text(self): - url = BASE_URL + '/pfx/foo/doc/doc-id' - params = {'old_rev': 'old-rev'} - oauth_req = oauth.OAuthRequest.from_consumer_and_token( - tests.consumer1, - tests.token1, - parameters=params, - http_url=url, - http_method='DELETE' - ) - oauth_req.sign_request(tests.sign_meth_PLAINTEXT, - tests.consumer1, tests.token1) - resp = self.app.delete(oauth_req.to_url()) - self.assertEqual(200, resp.status) - self.assertEqual([(tests.token1.key, - '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) - - def test_oauth_timestamp_threshold(self): - url = BASE_URL + '/pfx/foo/doc/doc-id' - params = {'old_rev': 'old-rev'} - oauth_req = oauth.OAuthRequest.from_consumer_and_token( - tests.consumer1, - tests.token1, - parameters=params, - http_url=url, - http_method='DELETE' - ) - oauth_req.set_parameter('oauth_timestamp', int(time.time()) - 5) - oauth_req.sign_request(tests.sign_meth_PLAINTEXT, - tests.consumer1, tests.token1) - # tweak threshold - self.oauth_midw.timestamp_threshold = 1 - resp = self.app.delete(oauth_req.to_url(), expect_errors=True) - self.assertEqual(401, resp.status) - err = json.loads(resp.body) - self.assertIn('Expired timestamp', err['message']) - self.assertIn('threshold 1', err['message']) diff --git a/src/leap/soledad/u1db/tests/test_backends.py b/src/leap/soledad/u1db/tests/test_backends.py deleted file mode 100644 index 7a3c9e5c..00000000 --- a/src/leap/soledad/u1db/tests/test_backends.py +++ /dev/null @@ -1,1895 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. 
-# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""The backend class for U1DB. This deals with hiding storage details.""" - -try: - import simplejson as json -except ImportError: - import json # noqa -from u1db import ( - DocumentBase, - errors, - tests, - vectorclock, - ) - -simple_doc = tests.simple_doc -nested_doc = tests.nested_doc - -from u1db.tests.test_remote_sync_target import ( - make_http_app, - make_oauth_http_app, -) - -from u1db.remote import ( - http_database, - ) - -try: - from u1db.tests import c_backend_wrapper -except ImportError: - c_backend_wrapper = None # noqa - - -def make_http_database_for_test(test, replica_uid, path='test'): - test.startServer() - test.request_state._create_database(replica_uid) - return http_database.HTTPDatabase(test.getURL(path)) - - -def copy_http_database_for_test(test, db): - # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS - # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE - # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN - # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR - # HOUSE. 
- return test.request_state._copy_database(db) - - -def make_oauth_http_database_for_test(test, replica_uid): - http_db = make_http_database_for_test(test, replica_uid, '~/test') - http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, - tests.token1.key, tests.token1.secret) - return http_db - - -def copy_oauth_http_database_for_test(test, db): - # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS - # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE - # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN - # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR - # HOUSE. - http_db = test.request_state._copy_database(db) - http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, - tests.token1.key, tests.token1.secret) - return http_db - - -class TestAlternativeDocument(DocumentBase): - """A (not very) alternative implementation of Document.""" - - -class AllDatabaseTests(tests.DatabaseBaseTests, tests.TestCaseWithServer): - - scenarios = tests.LOCAL_DATABASES_SCENARIOS + [ - ('http', {'make_database_for_test': make_http_database_for_test, - 'copy_database_for_test': copy_http_database_for_test, - 'make_document_for_test': tests.make_document_for_test, - 'make_app_with_state': make_http_app}), - ('oauth_http', {'make_database_for_test': - make_oauth_http_database_for_test, - 'copy_database_for_test': - copy_oauth_http_database_for_test, - 'make_document_for_test': tests.make_document_for_test, - 'make_app_with_state': make_oauth_http_app}) - ] + tests.C_DATABASE_SCENARIOS - - def test_close(self): - self.db.close() - - def test_create_doc_allocating_doc_id(self): - doc = self.db.create_doc_from_json(simple_doc) - self.assertNotEqual(None, doc.doc_id) - self.assertNotEqual(None, doc.rev) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) - - def test_create_doc_different_ids_same_db(self): - doc1 = 
self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.assertNotEqual(doc1.doc_id, doc2.doc_id) - - def test_create_doc_with_id(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my-id') - self.assertEqual('my-id', doc.doc_id) - self.assertNotEqual(None, doc.rev) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) - - def test_create_doc_existing_id(self): - doc = self.db.create_doc_from_json(simple_doc) - new_content = '{"something": "else"}' - self.assertRaises( - errors.RevisionConflict, self.db.create_doc_from_json, - new_content, doc.doc_id) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) - - def test_put_doc_creating_initial(self): - doc = self.make_document('my_doc_id', None, simple_doc) - new_rev = self.db.put_doc(doc) - self.assertIsNot(None, new_rev) - self.assertGetDoc(self.db, 'my_doc_id', new_rev, simple_doc, False) - - def test_put_doc_space_in_id(self): - doc = self.make_document('my doc id', None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - - def test_put_doc_update(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - orig_rev = doc.rev - doc.set_json('{"updated": "stuff"}') - new_rev = self.db.put_doc(doc) - self.assertNotEqual(new_rev, orig_rev) - self.assertGetDoc(self.db, 'my_doc_id', new_rev, - '{"updated": "stuff"}', False) - self.assertEqual(doc.rev, new_rev) - - def test_put_non_ascii_key(self): - content = json.dumps({u'key\xe5': u'val'}) - doc = self.db.create_doc_from_json(content, doc_id='my_doc') - self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) - - def test_put_non_ascii_value(self): - content = json.dumps({'key': u'\xe5'}) - doc = self.db.create_doc_from_json(content, doc_id='my_doc') - self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) - - def test_put_doc_refuses_no_id(self): - doc = self.make_document(None, None, simple_doc) - 
self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - doc = self.make_document("", None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - - def test_put_doc_refuses_slashes(self): - doc = self.make_document('a/b', None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - doc = self.make_document(r'\b', None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - - def test_put_doc_url_quoting_is_fine(self): - doc_id = "%2F%2Ffoo%2Fbar" - doc = self.make_document(doc_id, None, simple_doc) - new_rev = self.db.put_doc(doc) - self.assertGetDoc(self.db, doc_id, new_rev, simple_doc, False) - - def test_put_doc_refuses_non_existing_old_rev(self): - doc = self.make_document('doc-id', 'test:4', simple_doc) - self.assertRaises(errors.RevisionConflict, self.db.put_doc, doc) - - def test_put_doc_refuses_non_ascii_doc_id(self): - doc = self.make_document('d\xc3\xa5c-id', None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - - def test_put_fails_with_bad_old_rev(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - old_rev = doc.rev - bad_doc = self.make_document(doc.doc_id, 'other:1', - '{"something": "else"}') - self.assertRaises(errors.RevisionConflict, self.db.put_doc, bad_doc) - self.assertGetDoc(self.db, 'my_doc_id', old_rev, simple_doc, False) - - def test_create_succeeds_after_delete(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.db.delete_doc(doc) - deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) - deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) - new_doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.assertGetDoc(self.db, 'my_doc_id', new_doc.rev, simple_doc, False) - new_vc = vectorclock.VectorClockRev(new_doc.rev) - self.assertTrue( - new_vc.is_newer(deleted_vc), - "%s does not supersede %s" % (new_doc.rev, deleted_doc.rev)) - - def 
test_put_succeeds_after_delete(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.db.delete_doc(doc) - deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) - deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) - doc2 = self.make_document('my_doc_id', None, simple_doc) - self.db.put_doc(doc2) - self.assertGetDoc(self.db, 'my_doc_id', doc2.rev, simple_doc, False) - new_vc = vectorclock.VectorClockRev(doc2.rev) - self.assertTrue( - new_vc.is_newer(deleted_vc), - "%s does not supersede %s" % (doc2.rev, deleted_doc.rev)) - - def test_get_doc_after_put(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.assertGetDoc(self.db, 'my_doc_id', doc.rev, simple_doc, False) - - def test_get_doc_nonexisting(self): - self.assertIs(None, self.db.get_doc('non-existing')) - - def test_get_doc_deleted(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.db.delete_doc(doc) - self.assertIs(None, self.db.get_doc('my_doc_id')) - - def test_get_doc_include_deleted(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.db.delete_doc(doc) - self.assertGetDocIncludeDeleted( - self.db, doc.doc_id, doc.rev, None, False) - - def test_get_docs(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.assertEqual([doc1, doc2], - list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) - - def test_get_docs_deleted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.db.delete_doc(doc1) - self.assertEqual([doc2], - list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) - - def test_get_docs_include_deleted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.db.delete_doc(doc1) - self.assertEqual( - [doc1, doc2], - list(self.db.get_docs([doc1.doc_id, doc2.doc_id], - include_deleted=True))) - - def 
test_get_docs_request_ordered(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.assertEqual([doc1, doc2], - list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) - self.assertEqual([doc2, doc1], - list(self.db.get_docs([doc2.doc_id, doc1.doc_id]))) - - def test_get_docs_empty_list(self): - self.assertEqual([], list(self.db.get_docs([]))) - - def test_handles_nested_content(self): - doc = self.db.create_doc_from_json(nested_doc) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) - - def test_handles_doc_with_null(self): - doc = self.db.create_doc_from_json('{"key": null}') - self.assertGetDoc(self.db, doc.doc_id, doc.rev, '{"key": null}', False) - - def test_delete_doc(self): - doc = self.db.create_doc_from_json(simple_doc) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) - orig_rev = doc.rev - self.db.delete_doc(doc) - self.assertNotEqual(orig_rev, doc.rev) - self.assertGetDocIncludeDeleted( - self.db, doc.doc_id, doc.rev, None, False) - self.assertIs(None, self.db.get_doc(doc.doc_id)) - - def test_delete_doc_non_existent(self): - doc = self.make_document('non-existing', 'other:1', simple_doc) - self.assertRaises(errors.DocumentDoesNotExist, self.db.delete_doc, doc) - - def test_delete_doc_already_deleted(self): - doc = self.db.create_doc_from_json(simple_doc) - self.db.delete_doc(doc) - self.assertRaises(errors.DocumentAlreadyDeleted, - self.db.delete_doc, doc) - self.assertGetDocIncludeDeleted( - self.db, doc.doc_id, doc.rev, None, False) - - def test_delete_doc_bad_rev(self): - doc1 = self.db.create_doc_from_json(simple_doc) - self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) - doc2 = self.make_document(doc1.doc_id, 'other:1', simple_doc) - self.assertRaises(errors.RevisionConflict, self.db.delete_doc, doc2) - self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) - - def test_delete_doc_sets_content_to_None(self): - doc = 
self.db.create_doc_from_json(simple_doc) - self.db.delete_doc(doc) - self.assertIs(None, doc.get_json()) - - def test_delete_doc_rev_supersedes(self): - doc = self.db.create_doc_from_json(simple_doc) - doc.set_json(nested_doc) - self.db.put_doc(doc) - doc.set_json('{"fishy": "content"}') - self.db.put_doc(doc) - old_rev = doc.rev - self.db.delete_doc(doc) - cur_vc = vectorclock.VectorClockRev(old_rev) - deleted_vc = vectorclock.VectorClockRev(doc.rev) - self.assertTrue(deleted_vc.is_newer(cur_vc), - "%s does not supersede %s" % (doc.rev, old_rev)) - - def test_delete_then_put(self): - doc = self.db.create_doc_from_json(simple_doc) - self.db.delete_doc(doc) - self.assertGetDocIncludeDeleted( - self.db, doc.doc_id, doc.rev, None, False) - doc.set_json(nested_doc) - self.db.put_doc(doc) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) - - -class DocumentSizeTests(tests.DatabaseBaseTests): - - scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS - - def test_put_doc_refuses_oversized_documents(self): - self.db.set_document_size_limit(1) - doc = self.make_document('doc-id', None, simple_doc) - self.assertRaises(errors.DocumentTooBig, self.db.put_doc, doc) - - def test_create_doc_refuses_oversized_documents(self): - self.db.set_document_size_limit(1) - self.assertRaises( - errors.DocumentTooBig, self.db.create_doc_from_json, simple_doc, - doc_id='my_doc_id') - - def test_set_document_size_limit_zero(self): - self.db.set_document_size_limit(0) - self.assertEqual(0, self.db.document_size_limit) - - def test_set_document_size_limit(self): - self.db.set_document_size_limit(1000000) - self.assertEqual(1000000, self.db.document_size_limit) - - -class LocalDatabaseTests(tests.DatabaseBaseTests): - - scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS - - def test_create_doc_different_ids_diff_db(self): - doc1 = self.db.create_doc_from_json(simple_doc) - db2 = self.create_database('other-uid') - doc2 = 
db2.create_doc_from_json(simple_doc) - self.assertNotEqual(doc1.doc_id, doc2.doc_id) - - def test_put_doc_refuses_slashes_picky(self): - doc = self.make_document('/a', None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - - def test_get_all_docs_empty(self): - self.assertEqual([], list(self.db.get_all_docs()[1])) - - def test_get_all_docs(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.assertEqual( - sorted([doc1, doc2]), sorted(list(self.db.get_all_docs()[1]))) - - def test_get_all_docs_exclude_deleted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.db.delete_doc(doc2) - self.assertEqual([doc1], list(self.db.get_all_docs()[1])) - - def test_get_all_docs_include_deleted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.db.delete_doc(doc2) - self.assertEqual( - sorted([doc1, doc2]), - sorted(list(self.db.get_all_docs(include_deleted=True)[1]))) - - def test_get_all_docs_generation(self): - self.db.create_doc_from_json(simple_doc) - self.db.create_doc_from_json(nested_doc) - self.assertEqual(2, self.db.get_all_docs()[0]) - - def test_simple_put_doc_if_newer(self): - doc = self.make_document('my-doc-id', 'test:1', simple_doc) - state_at_gen = self.db._put_doc_if_newer( - doc, save_conflict=False, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual(('inserted', 1), state_at_gen) - self.assertGetDoc(self.db, 'my-doc-id', 'test:1', simple_doc, False) - - def test_simple_put_doc_if_newer_deleted(self): - self.db.create_doc_from_json('{}', doc_id='my-doc-id') - doc = self.make_document('my-doc-id', 'test:2', None) - state_at_gen = self.db._put_doc_if_newer( - doc, save_conflict=False, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual(('inserted', 2), state_at_gen) - self.assertGetDocIncludeDeleted( - 
self.db, 'my-doc-id', 'test:2', None, False) - - def test_put_doc_if_newer_already_superseded(self): - orig_doc = '{"new": "doc"}' - doc1 = self.db.create_doc_from_json(orig_doc) - doc1_rev1 = doc1.rev - doc1.set_json(simple_doc) - self.db.put_doc(doc1) - doc1_rev2 = doc1.rev - # Nothing is inserted, because the document is already superseded - doc = self.make_document(doc1.doc_id, doc1_rev1, orig_doc) - state, _ = self.db._put_doc_if_newer( - doc, save_conflict=False, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual('superseded', state) - self.assertGetDoc(self.db, doc1.doc_id, doc1_rev2, simple_doc, False) - - def test_put_doc_if_newer_autoresolve(self): - doc1 = self.db.create_doc_from_json(simple_doc) - rev = doc1.rev - doc = self.make_document(doc1.doc_id, "whatever:1", doc1.get_json()) - state, _ = self.db._put_doc_if_newer( - doc, save_conflict=False, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual('superseded', state) - doc2 = self.db.get_doc(doc1.doc_id) - v2 = vectorclock.VectorClockRev(doc2.rev) - self.assertTrue(v2.is_newer(vectorclock.VectorClockRev("whatever:1"))) - self.assertTrue(v2.is_newer(vectorclock.VectorClockRev(rev))) - # strictly newer locally - self.assertTrue(rev not in doc2.rev) - - def test_put_doc_if_newer_already_converged(self): - orig_doc = '{"new": "doc"}' - doc1 = self.db.create_doc_from_json(orig_doc) - state_at_gen = self.db._put_doc_if_newer( - doc1, save_conflict=False, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual(('converged', 1), state_at_gen) - - def test_put_doc_if_newer_conflicted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - # Nothing is inserted, the document id is returned as would-conflict - alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - state, _ = self.db._put_doc_if_newer( - alt_doc, save_conflict=False, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual('conflicted', 
state) - # The database wasn't altered - self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) - - def test_put_doc_if_newer_newer_generation(self): - self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') - doc = self.make_document('doc_id', 'other:2', simple_doc) - state, _ = self.db._put_doc_if_newer( - doc, save_conflict=False, replica_uid='other', replica_gen=2, - replica_trans_id='T-irrelevant') - self.assertEqual('inserted', state) - - def test_put_doc_if_newer_same_generation_same_txid(self): - self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') - doc = self.db.create_doc_from_json(simple_doc) - self.make_document(doc.doc_id, 'other:1', simple_doc) - state, _ = self.db._put_doc_if_newer( - doc, save_conflict=False, replica_uid='other', replica_gen=1, - replica_trans_id='T-sid') - self.assertEqual('converged', state) - - def test_put_doc_if_newer_wrong_transaction_id(self): - self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') - doc = self.make_document('doc_id', 'other:1', simple_doc) - self.assertRaises( - errors.InvalidTransactionId, - self.db._put_doc_if_newer, doc, save_conflict=False, - replica_uid='other', replica_gen=1, replica_trans_id='T-sad') - - def test_put_doc_if_newer_old_generation_older_doc(self): - orig_doc = '{"new": "doc"}' - doc = self.db.create_doc_from_json(orig_doc) - doc_rev1 = doc.rev - doc.set_json(simple_doc) - self.db.put_doc(doc) - self.db._set_replica_gen_and_trans_id('other', 3, 'T-sid') - older_doc = self.make_document(doc.doc_id, doc_rev1, simple_doc) - state, _ = self.db._put_doc_if_newer( - older_doc, save_conflict=False, replica_uid='other', replica_gen=8, - replica_trans_id='T-irrelevant') - self.assertEqual('superseded', state) - - def test_put_doc_if_newer_old_generation_newer_doc(self): - self.db._set_replica_gen_and_trans_id('other', 5, 'T-sid') - doc = self.make_document('doc_id', 'other:1', simple_doc) - self.assertRaises( - errors.InvalidGeneration, - self.db._put_doc_if_newer, 
doc, save_conflict=False, - replica_uid='other', replica_gen=1, replica_trans_id='T-sad') - - def test_put_doc_if_newer_replica_uid(self): - doc1 = self.db.create_doc_from_json(simple_doc) - self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') - doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1', - nested_doc) - self.assertEqual('inserted', - self.db._put_doc_if_newer(doc2, save_conflict=False, - replica_uid='other', replica_gen=2, - replica_trans_id='T-id2')[0]) - self.assertEqual((2, 'T-id2'), self.db._get_replica_gen_and_trans_id( - 'other')) - # Compare to the old rev, should be superseded - doc2 = self.make_document(doc1.doc_id, doc1.rev, nested_doc) - self.assertEqual('superseded', - self.db._put_doc_if_newer(doc2, save_conflict=False, - replica_uid='other', replica_gen=3, - replica_trans_id='T-id3')[0]) - self.assertEqual( - (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other')) - # A conflict that isn't saved still records the sync gen, because we - # don't need to see it again - doc2 = self.make_document(doc1.doc_id, doc1.rev + '|fourth:1', - '{}') - self.assertEqual('conflicted', - self.db._put_doc_if_newer(doc2, save_conflict=False, - replica_uid='other', replica_gen=4, - replica_trans_id='T-id4')[0]) - self.assertEqual( - (4, 'T-id4'), self.db._get_replica_gen_and_trans_id('other')) - - def test__get_replica_gen_and_trans_id(self): - self.assertEqual( - (0, ''), self.db._get_replica_gen_and_trans_id('other-db')) - self.db._set_replica_gen_and_trans_id('other-db', 2, 'T-transaction') - self.assertEqual( - (2, 'T-transaction'), - self.db._get_replica_gen_and_trans_id('other-db')) - - def test_put_updates_transaction_log(self): - doc = self.db.create_doc_from_json(simple_doc) - self.assertTransactionLog([doc.doc_id], self.db) - doc.set_json('{"something": "else"}') - self.db.put_doc(doc) - self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) - last_trans_id = self.getLastTransId(self.db) - self.assertEqual((2, last_trans_id, 
[(doc.doc_id, 2, last_trans_id)]), - self.db.whats_changed()) - - def test_delete_updates_transaction_log(self): - doc = self.db.create_doc_from_json(simple_doc) - db_gen, _, _ = self.db.whats_changed() - self.db.delete_doc(doc) - last_trans_id = self.getLastTransId(self.db) - self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]), - self.db.whats_changed(db_gen)) - - def test_whats_changed_initial_database(self): - self.assertEqual((0, '', []), self.db.whats_changed()) - - def test_whats_changed_returns_one_id_for_multiple_changes(self): - doc = self.db.create_doc_from_json(simple_doc) - doc.set_json('{"new": "contents"}') - self.db.put_doc(doc) - last_trans_id = self.getLastTransId(self.db) - self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]), - self.db.whats_changed()) - self.assertEqual((2, last_trans_id, []), self.db.whats_changed(2)) - - def test_whats_changed_returns_last_edits_ascending(self): - doc = self.db.create_doc_from_json(simple_doc) - doc1 = self.db.create_doc_from_json(simple_doc) - doc.set_json('{"new": "contents"}') - self.db.delete_doc(doc1) - delete_trans_id = self.getLastTransId(self.db) - self.db.put_doc(doc) - put_trans_id = self.getLastTransId(self.db) - self.assertEqual((4, put_trans_id, - [(doc1.doc_id, 3, delete_trans_id), - (doc.doc_id, 4, put_trans_id)]), - self.db.whats_changed()) - - def test_whats_changed_doesnt_include_old_gen(self): - self.db.create_doc_from_json(simple_doc) - self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(simple_doc) - last_trans_id = self.getLastTransId(self.db) - self.assertEqual((3, last_trans_id, [(doc2.doc_id, 3, last_trans_id)]), - self.db.whats_changed(2)) - - -class LocalDatabaseValidateGenNTransIdTests(tests.DatabaseBaseTests): - - scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS - - def test_validate_gen_and_trans_id(self): - self.db.create_doc_from_json(simple_doc) - gen, trans_id = self.db._get_generation_info() 
- self.db.validate_gen_and_trans_id(gen, trans_id) - - def test_validate_gen_and_trans_id_invalid_txid(self): - self.db.create_doc_from_json(simple_doc) - gen, _ = self.db._get_generation_info() - self.assertRaises( - errors.InvalidTransactionId, - self.db.validate_gen_and_trans_id, gen, 'wrong') - - def test_validate_gen_and_trans_id_invalid_gen(self): - self.db.create_doc_from_json(simple_doc) - gen, trans_id = self.db._get_generation_info() - self.assertRaises( - errors.InvalidGeneration, - self.db.validate_gen_and_trans_id, gen + 1, trans_id) - - -class LocalDatabaseValidateSourceGenTests(tests.DatabaseBaseTests): - - scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS - - def test_validate_source_gen_and_trans_id_same(self): - self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') - self.db._validate_source('other', 1, 'T-sid') - - def test_validate_source_gen_newer(self): - self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') - self.db._validate_source('other', 2, 'T-whatevs') - - def test_validate_source_wrong_txid(self): - self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') - self.assertRaises( - errors.InvalidTransactionId, - self.db._validate_source, 'other', 1, 'T-sad') - - -class LocalDatabaseWithConflictsTests(tests.DatabaseBaseTests): - # test supporting/functionality around storing conflicts - - scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS - - def test_get_docs_conflicted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual([doc2], list(self.db.get_docs([doc1.doc_id]))) - - def test_get_docs_conflicts_ignored(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) 
- self.db._put_doc_if_newer( - alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - no_conflict_doc = self.make_document(doc1.doc_id, 'alternate:1', - nested_doc) - self.assertEqual([no_conflict_doc, doc2], - list(self.db.get_docs([doc1.doc_id, doc2.doc_id], - check_for_conflicts=False))) - - def test_get_doc_conflicts(self): - doc = self.db.create_doc_from_json(simple_doc) - alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual([alt_doc, doc], - self.db.get_doc_conflicts(doc.doc_id)) - - def test_get_all_docs_sees_conflicts(self): - doc = self.db.create_doc_from_json(simple_doc) - alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - _, docs = self.db.get_all_docs() - self.assertTrue(list(docs)[0].has_conflicts) - - def test_get_doc_conflicts_unconflicted(self): - doc = self.db.create_doc_from_json(simple_doc) - self.assertEqual([], self.db.get_doc_conflicts(doc.doc_id)) - - def test_get_doc_conflicts_no_such_id(self): - self.assertEqual([], self.db.get_doc_conflicts('doc-id')) - - def test_resolve_doc(self): - doc = self.db.create_doc_from_json(simple_doc) - alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertGetDocConflicts(self.db, doc.doc_id, - [('alternate:1', nested_doc), (doc.rev, simple_doc)]) - orig_rev = doc.rev - self.db.resolve_doc(doc, [alt_doc.rev, doc.rev]) - self.assertNotEqual(orig_rev, doc.rev) - self.assertFalse(doc.has_conflicts) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) - self.assertGetDocConflicts(self.db, doc.doc_id, []) - - def 
test_resolve_doc_picks_biggest_vcr(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertGetDocConflicts(self.db, doc1.doc_id, - [(doc2.rev, nested_doc), - (doc1.rev, simple_doc)]) - orig_doc1_rev = doc1.rev - self.db.resolve_doc(doc1, [doc2.rev, doc1.rev]) - self.assertFalse(doc1.has_conflicts) - self.assertNotEqual(orig_doc1_rev, doc1.rev) - self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) - self.assertGetDocConflicts(self.db, doc1.doc_id, []) - vcr_1 = vectorclock.VectorClockRev(orig_doc1_rev) - vcr_2 = vectorclock.VectorClockRev(doc2.rev) - vcr_new = vectorclock.VectorClockRev(doc1.rev) - self.assertTrue(vcr_new.is_newer(vcr_1)) - self.assertTrue(vcr_new.is_newer(vcr_2)) - - def test_resolve_doc_partial_not_winning(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertGetDocConflicts(self.db, doc1.doc_id, - [(doc2.rev, nested_doc), - (doc1.rev, simple_doc)]) - content3 = '{"key": "valin3"}' - doc3 = self.make_document(doc1.doc_id, 'third:1', content3) - self.db._put_doc_if_newer( - doc3, save_conflict=True, replica_uid='r', replica_gen=2, - replica_trans_id='bar') - self.assertGetDocConflicts(self.db, doc1.doc_id, - [(doc3.rev, content3), - (doc1.rev, simple_doc), - (doc2.rev, nested_doc)]) - self.db.resolve_doc(doc1, [doc2.rev, doc1.rev]) - self.assertTrue(doc1.has_conflicts) - self.assertGetDoc(self.db, doc1.doc_id, doc3.rev, content3, True) - self.assertGetDocConflicts(self.db, doc1.doc_id, - [(doc3.rev, content3), - (doc1.rev, simple_doc)]) - - def test_resolve_doc_partial_winning(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = 
self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - content3 = '{"key": "valin3"}' - doc3 = self.make_document(doc1.doc_id, 'third:1', content3) - self.db._put_doc_if_newer( - doc3, save_conflict=True, replica_uid='r', replica_gen=2, - replica_trans_id='bar') - self.assertGetDocConflicts(self.db, doc1.doc_id, - [(doc3.rev, content3), - (doc1.rev, simple_doc), - (doc2.rev, nested_doc)]) - self.db.resolve_doc(doc1, [doc3.rev, doc1.rev]) - self.assertTrue(doc1.has_conflicts) - self.assertGetDocConflicts(self.db, doc1.doc_id, - [(doc1.rev, simple_doc), - (doc2.rev, nested_doc)]) - - def test_resolve_doc_with_delete_conflict(self): - doc1 = self.db.create_doc_from_json(simple_doc) - self.db.delete_doc(doc1) - doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertGetDocConflicts(self.db, doc1.doc_id, - [(doc2.rev, nested_doc), - (doc1.rev, None)]) - self.db.resolve_doc(doc2, [doc1.rev, doc2.rev]) - self.assertGetDocConflicts(self.db, doc1.doc_id, []) - self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, False) - - def test_resolve_doc_with_delete_to_delete(self): - doc1 = self.db.create_doc_from_json(simple_doc) - self.db.delete_doc(doc1) - doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertGetDocConflicts(self.db, doc1.doc_id, - [(doc2.rev, nested_doc), - (doc1.rev, None)]) - self.db.resolve_doc(doc1, [doc1.rev, doc2.rev]) - self.assertGetDocConflicts(self.db, doc1.doc_id, []) - self.assertGetDocIncludeDeleted( - self.db, doc1.doc_id, doc1.rev, None, False) - - def test_put_doc_if_newer_save_conflicted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - 
# Document is inserted as a conflict - doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - state, _ = self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual('conflicted', state) - # The database was updated - self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, nested_doc, True) - - def test_force_doc_conflict_supersedes_properly(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.make_document(doc1.doc_id, 'alternate:1', '{"b": 1}') - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - doc3 = self.make_document(doc1.doc_id, 'altalt:1', '{"c": 1}') - self.db._put_doc_if_newer( - doc3, save_conflict=True, replica_uid='r', replica_gen=2, - replica_trans_id='bar') - doc22 = self.make_document(doc1.doc_id, 'alternate:2', '{"b": 2}') - self.db._put_doc_if_newer( - doc22, save_conflict=True, replica_uid='r', replica_gen=3, - replica_trans_id='zed') - self.assertGetDocConflicts(self.db, doc1.doc_id, - [('alternate:2', doc22.get_json()), - ('altalt:1', doc3.get_json()), - (doc1.rev, simple_doc)]) - - def test_put_doc_if_newer_save_conflict_was_deleted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - self.db.delete_doc(doc1) - doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertTrue(doc2.has_conflicts) - self.assertGetDoc( - self.db, doc1.doc_id, 'alternate:1', nested_doc, True) - self.assertGetDocConflicts(self.db, doc1.doc_id, - [('alternate:1', nested_doc), (doc1.rev, None)]) - - def test_put_doc_if_newer_propagates_full_resolution(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - 
replica_trans_id='foo') - resolved_vcr = vectorclock.VectorClockRev(doc1.rev) - vcr_2 = vectorclock.VectorClockRev(doc2.rev) - resolved_vcr.maximize(vcr_2) - resolved_vcr.increment('alternate') - doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(), - '{"good": 1}') - state, _ = self.db._put_doc_if_newer( - doc_resolved, save_conflict=True, replica_uid='r', replica_gen=2, - replica_trans_id='foo2') - self.assertEqual('inserted', state) - self.assertFalse(doc_resolved.has_conflicts) - self.assertGetDocConflicts(self.db, doc1.doc_id, []) - doc3 = self.db.get_doc(doc1.doc_id) - self.assertFalse(doc3.has_conflicts) - - def test_put_doc_if_newer_propagates_partial_resolution(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.make_document(doc1.doc_id, 'altalt:1', '{}') - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - doc3 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - doc3, save_conflict=True, replica_uid='r', replica_gen=2, - replica_trans_id='foo2') - self.assertGetDocConflicts(self.db, doc1.doc_id, - [('alternate:1', nested_doc), ('test:1', simple_doc), - ('altalt:1', '{}')]) - resolved_vcr = vectorclock.VectorClockRev(doc1.rev) - vcr_3 = vectorclock.VectorClockRev(doc3.rev) - resolved_vcr.maximize(vcr_3) - resolved_vcr.increment('alternate') - doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(), - '{"good": 1}') - state, _ = self.db._put_doc_if_newer( - doc_resolved, save_conflict=True, replica_uid='r', replica_gen=3, - replica_trans_id='foo3') - self.assertEqual('inserted', state) - self.assertTrue(doc_resolved.has_conflicts) - doc4 = self.db.get_doc(doc1.doc_id) - self.assertTrue(doc4.has_conflicts) - self.assertGetDocConflicts(self.db, doc1.doc_id, - [('alternate:2|test:1', '{"good": 1}'), ('altalt:1', '{}')]) - - def test_put_doc_if_newer_replica_uid(self): - doc1 = 
self.db.create_doc_from_json(simple_doc) - self.db._set_replica_gen_and_trans_id('other', 1, 'T-id') - doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1', - nested_doc) - self.db._put_doc_if_newer(doc2, save_conflict=True, - replica_uid='other', replica_gen=2, - replica_trans_id='T-id2') - # Conflict vs the current update - doc2 = self.make_document(doc1.doc_id, doc1.rev + '|third:3', - '{}') - self.assertEqual('conflicted', - self.db._put_doc_if_newer(doc2, save_conflict=True, - replica_uid='other', replica_gen=3, - replica_trans_id='T-id3')[0]) - self.assertEqual( - (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other')) - - def test_put_doc_if_newer_autoresolve_2(self): - # this is an ordering variant of _3, but that already works - # adding the test explicitly to catch the regression easily - doc_a1 = self.db.create_doc_from_json(simple_doc) - doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', "{}") - doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', - '{"a":"42"}') - doc_a3 = self.make_document(doc_a1.doc_id, 'test:2|other:1', "{}") - state, _ = self.db._put_doc_if_newer( - doc_a2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual(state, 'inserted') - state, _ = self.db._put_doc_if_newer( - doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=2, - replica_trans_id='foo2') - self.assertEqual(state, 'conflicted') - state, _ = self.db._put_doc_if_newer( - doc_a3, save_conflict=True, replica_uid='r', replica_gen=3, - replica_trans_id='foo3') - self.assertEqual(state, 'inserted') - self.assertFalse(self.db.get_doc(doc_a1.doc_id).has_conflicts) - - def test_put_doc_if_newer_autoresolve_3(self): - doc_a1 = self.db.create_doc_from_json(simple_doc) - doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', "{}") - doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}') - doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', "{}") - state, _ = self.db._put_doc_if_newer( - 
doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual(state, 'inserted') - state, _ = self.db._put_doc_if_newer( - doc_a2, save_conflict=True, replica_uid='r', replica_gen=2, - replica_trans_id='foo2') - self.assertEqual(state, 'conflicted') - state, _ = self.db._put_doc_if_newer( - doc_a3, save_conflict=True, replica_uid='r', replica_gen=3, - replica_trans_id='foo3') - self.assertEqual(state, 'superseded') - doc = self.db.get_doc(doc_a1.doc_id, True) - self.assertFalse(doc.has_conflicts) - rev = vectorclock.VectorClockRev(doc.rev) - rev_a3 = vectorclock.VectorClockRev('test:3') - rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1') - self.assertTrue(rev.is_newer(rev_a3)) - self.assertTrue('test:4' in doc.rev) # locally increased - self.assertTrue(rev.is_newer(rev_a1b1)) - - def test_put_doc_if_newer_autoresolve_4(self): - doc_a1 = self.db.create_doc_from_json(simple_doc) - doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', None) - doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}') - doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', None) - state, _ = self.db._put_doc_if_newer( - doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertEqual(state, 'inserted') - state, _ = self.db._put_doc_if_newer( - doc_a2, save_conflict=True, replica_uid='r', replica_gen=2, - replica_trans_id='foo2') - self.assertEqual(state, 'conflicted') - state, _ = self.db._put_doc_if_newer( - doc_a3, save_conflict=True, replica_uid='r', replica_gen=3, - replica_trans_id='foo3') - self.assertEqual(state, 'superseded') - doc = self.db.get_doc(doc_a1.doc_id, True) - self.assertFalse(doc.has_conflicts) - rev = vectorclock.VectorClockRev(doc.rev) - rev_a3 = vectorclock.VectorClockRev('test:3') - rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1') - self.assertTrue(rev.is_newer(rev_a3)) - self.assertTrue('test:4' in doc.rev) # locally increased - 
self.assertTrue(rev.is_newer(rev_a1b1)) - - def test_put_refuses_to_update_conflicted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - content2 = '{"key": "altval"}' - doc2 = self.make_document(doc1.doc_id, 'altrev:1', content2) - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, content2, True) - content3 = '{"key": "local"}' - doc2.set_json(content3) - self.assertRaises(errors.ConflictedDoc, self.db.put_doc, doc2) - - def test_delete_refuses_for_conflicted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.make_document(doc1.doc_id, 'altrev:1', nested_doc) - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, True) - self.assertRaises(errors.ConflictedDoc, self.db.delete_doc, doc2) - - -class DatabaseIndexTests(tests.DatabaseBaseTests): - - scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS - - def assertParseError(self, definition): - self.db.create_doc_from_json(nested_doc) - self.assertRaises( - errors.IndexDefinitionParseError, self.db.create_index, 'idx', - definition) - - def assertIndexCreatable(self, definition): - name = "idx" - self.db.create_doc_from_json(nested_doc) - self.db.create_index(name, definition) - self.assertEqual( - [(name, [definition])], self.db.list_indexes()) - - def test_create_index(self): - self.db.create_index('test-idx', 'name') - self.assertEqual([('test-idx', ['name'])], - self.db.list_indexes()) - - def test_create_index_on_non_ascii_field_name(self): - doc = self.db.create_doc_from_json(json.dumps({u'\xe5': 'value'})) - self.db.create_index('test-idx', u'\xe5') - self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) - - def test_list_indexes_with_non_ascii_field_names(self): - self.db.create_index('test-idx', u'\xe5') - 
self.assertEqual( - [('test-idx', [u'\xe5'])], self.db.list_indexes()) - - def test_create_index_evaluates_it(self): - doc = self.db.create_doc_from_json(simple_doc) - self.db.create_index('test-idx', 'key') - self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) - - def test_wildcard_matches_unicode_value(self): - doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"})) - self.db.create_index('test-idx', 'key') - self.assertEqual([doc], self.db.get_from_index('test-idx', '*')) - - def test_retrieve_unicode_value_from_index(self): - doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"})) - self.db.create_index('test-idx', 'key') - self.assertEqual( - [doc], self.db.get_from_index('test-idx', u"valu\xe5")) - - def test_create_index_fails_if_name_taken(self): - self.db.create_index('test-idx', 'key') - self.assertRaises(errors.IndexNameTakenError, - self.db.create_index, - 'test-idx', 'stuff') - - def test_create_index_does_not_fail_if_name_taken_with_same_index(self): - self.db.create_index('test-idx', 'key') - self.db.create_index('test-idx', 'key') - self.assertEqual([('test-idx', ['key'])], self.db.list_indexes()) - - def test_create_index_does_not_duplicate_indexed_fields(self): - self.db.create_doc_from_json(simple_doc) - self.db.create_index('test-idx', 'key') - self.db.delete_index('test-idx') - self.db.create_index('test-idx', 'key') - self.assertEqual(1, len(self.db.get_from_index('test-idx', 'value'))) - - def test_delete_index_does_not_remove_fields_from_other_indexes(self): - self.db.create_doc_from_json(simple_doc) - self.db.create_index('test-idx', 'key') - self.db.create_index('test-idx2', 'key') - self.db.delete_index('test-idx') - self.assertEqual(1, len(self.db.get_from_index('test-idx2', 'value'))) - - def test_create_index_after_deleting_document(self): - doc = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(simple_doc) - self.db.delete_doc(doc2) - 
self.db.create_index('test-idx', 'key') - self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) - - def test_delete_index(self): - self.db.create_index('test-idx', 'key') - self.assertEqual([('test-idx', ['key'])], self.db.list_indexes()) - self.db.delete_index('test-idx') - self.assertEqual([], self.db.list_indexes()) - - def test_create_adds_to_index(self): - self.db.create_index('test-idx', 'key') - doc = self.db.create_doc_from_json(simple_doc) - self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) - - def test_get_from_index_unmatched(self): - self.db.create_doc_from_json(simple_doc) - self.db.create_index('test-idx', 'key') - self.assertEqual([], self.db.get_from_index('test-idx', 'novalue')) - - def test_create_index_multiple_exact_matches(self): - doc = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(simple_doc) - self.db.create_index('test-idx', 'key') - self.assertEqual( - sorted([doc, doc2]), - sorted(self.db.get_from_index('test-idx', 'value'))) - - def test_get_from_index(self): - doc = self.db.create_doc_from_json(simple_doc) - self.db.create_index('test-idx', 'key') - self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) - - def test_get_from_index_multi(self): - content = '{"key": "value", "key2": "value2"}' - doc = self.db.create_doc_from_json(content) - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - [doc], self.db.get_from_index('test-idx', 'value', 'value2')) - - def test_get_from_index_multi_list(self): - doc = self.db.create_doc_from_json( - '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}') - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - [doc], self.db.get_from_index('test-idx', 'value', 'value2-1')) - self.assertEqual( - [doc], self.db.get_from_index('test-idx', 'value', 'value2-2')) - self.assertEqual( - [doc], self.db.get_from_index('test-idx', 'value', 'value2-3')) - self.assertEqual( - [('value', 
'value2-1'), ('value', 'value2-2'), - ('value', 'value2-3')], - sorted(self.db.get_index_keys('test-idx'))) - - def test_get_from_index_sees_conflicts(self): - doc = self.db.create_doc_from_json(simple_doc) - self.db.create_index('test-idx', 'key', 'key2') - alt_doc = self.make_document( - doc.doc_id, 'alternate:1', - '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}') - self.db._put_doc_if_newer( - alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - docs = self.db.get_from_index('test-idx', 'value', 'value2-1') - self.assertTrue(docs[0].has_conflicts) - - def test_get_index_keys_multi_list_list(self): - self.db.create_doc_from_json( - '{"key": "value1-1 value1-2 value1-3", ' - '"key2": ["value2-1", "value2-2", "value2-3"]}') - self.db.create_index('test-idx', 'split_words(key)', 'key2') - self.assertEqual( - [(u'value1-1', u'value2-1'), (u'value1-1', u'value2-2'), - (u'value1-1', u'value2-3'), (u'value1-2', u'value2-1'), - (u'value1-2', u'value2-2'), (u'value1-2', u'value2-3'), - (u'value1-3', u'value2-1'), (u'value1-3', u'value2-2'), - (u'value1-3', u'value2-3')], - sorted(self.db.get_index_keys('test-idx'))) - - def test_get_from_index_multi_ordered(self): - doc1 = self.db.create_doc_from_json( - '{"key": "value3", "key2": "value4"}') - doc2 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value3"}') - doc3 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value2"}') - doc4 = self.db.create_doc_from_json( - '{"key": "value1", "key2": "value1"}') - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - [doc4, doc3, doc2, doc1], - self.db.get_from_index('test-idx', 'v*', '*')) - - def test_get_range_from_index_start_end(self): - doc1 = self.db.create_doc_from_json('{"key": "value3"}') - doc2 = self.db.create_doc_from_json('{"key": "value2"}') - self.db.create_doc_from_json('{"key": "value4"}') - self.db.create_doc_from_json('{"key": "value1"}') - 
self.db.create_index('test-idx', 'key') - self.assertEqual( - [doc2, doc1], - self.db.get_range_from_index('test-idx', 'value2', 'value3')) - - def test_get_range_from_index_start(self): - doc1 = self.db.create_doc_from_json('{"key": "value3"}') - doc2 = self.db.create_doc_from_json('{"key": "value2"}') - doc3 = self.db.create_doc_from_json('{"key": "value4"}') - self.db.create_doc_from_json('{"key": "value1"}') - self.db.create_index('test-idx', 'key') - self.assertEqual( - [doc2, doc1, doc3], - self.db.get_range_from_index('test-idx', 'value2')) - - def test_get_range_from_index_sees_conflicts(self): - doc = self.db.create_doc_from_json(simple_doc) - self.db.create_index('test-idx', 'key') - alt_doc = self.make_document( - doc.doc_id, 'alternate:1', '{"key": "valuedepalue"}') - self.db._put_doc_if_newer( - alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - docs = self.db.get_range_from_index('test-idx', 'a') - self.assertTrue(docs[0].has_conflicts) - - def test_get_range_from_index_end(self): - self.db.create_doc_from_json('{"key": "value3"}') - doc2 = self.db.create_doc_from_json('{"key": "value2"}') - self.db.create_doc_from_json('{"key": "value4"}') - doc4 = self.db.create_doc_from_json('{"key": "value1"}') - self.db.create_index('test-idx', 'key') - self.assertEqual( - [doc4, doc2], - self.db.get_range_from_index('test-idx', None, 'value2')) - - def test_get_wildcard_range_from_index_start(self): - doc1 = self.db.create_doc_from_json('{"key": "value4"}') - doc2 = self.db.create_doc_from_json('{"key": "value23"}') - doc3 = self.db.create_doc_from_json('{"key": "value2"}') - doc4 = self.db.create_doc_from_json('{"key": "value22"}') - self.db.create_doc_from_json('{"key": "value1"}') - self.db.create_index('test-idx', 'key') - self.assertEqual( - [doc3, doc4, doc2, doc1], - self.db.get_range_from_index('test-idx', 'value2*')) - - def test_get_wildcard_range_from_index_end(self): - self.db.create_doc_from_json('{"key": 
"value4"}') - doc2 = self.db.create_doc_from_json('{"key": "value23"}') - doc3 = self.db.create_doc_from_json('{"key": "value2"}') - doc4 = self.db.create_doc_from_json('{"key": "value22"}') - doc5 = self.db.create_doc_from_json('{"key": "value1"}') - self.db.create_index('test-idx', 'key') - self.assertEqual( - [doc5, doc3, doc4, doc2], - self.db.get_range_from_index('test-idx', None, 'value2*')) - - def test_get_wildcard_range_from_index_start_end(self): - self.db.create_doc_from_json('{"key": "a"}') - self.db.create_doc_from_json('{"key": "boo3"}') - doc3 = self.db.create_doc_from_json('{"key": "catalyst"}') - doc4 = self.db.create_doc_from_json('{"key": "whaever"}') - self.db.create_doc_from_json('{"key": "zerg"}') - self.db.create_index('test-idx', 'key') - self.assertEqual( - [doc3, doc4], - self.db.get_range_from_index('test-idx', 'cat*', 'zap*')) - - def test_get_range_from_index_multi_column_start_end(self): - self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') - doc2 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value3"}') - doc3 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value2"}') - self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - [doc3, doc2], - self.db.get_range_from_index( - 'test-idx', ('value2', 'value2'), ('value2', 'value3'))) - - def test_get_range_from_index_multi_column_start(self): - doc1 = self.db.create_doc_from_json( - '{"key": "value3", "key2": "value4"}') - doc2 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value3"}') - self.db.create_doc_from_json('{"key": "value2", "key2": "value2"}') - self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - [doc2, doc1], - self.db.get_range_from_index('test-idx', ('value2', 'value3'))) - - def test_get_range_from_index_multi_column_end(self): - 
self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') - doc2 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value3"}') - doc3 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value2"}') - doc4 = self.db.create_doc_from_json( - '{"key": "value1", "key2": "value1"}') - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - [doc4, doc3, doc2], - self.db.get_range_from_index( - 'test-idx', None, ('value2', 'value3'))) - - def test_get_wildcard_range_from_index_multi_column_start(self): - doc1 = self.db.create_doc_from_json( - '{"key": "value3", "key2": "value4"}') - doc2 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value23"}') - doc3 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value2"}') - self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - [doc3, doc2, doc1], - self.db.get_range_from_index('test-idx', ('value2', 'value2*'))) - - def test_get_wildcard_range_from_index_multi_column_end(self): - self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') - doc2 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value23"}') - doc3 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value2"}') - doc4 = self.db.create_doc_from_json( - '{"key": "value1", "key2": "value1"}') - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - [doc4, doc3, doc2], - self.db.get_range_from_index( - 'test-idx', None, ('value2', 'value2*'))) - - def test_get_glob_range_from_index_multi_column_start(self): - doc1 = self.db.create_doc_from_json( - '{"key": "value3", "key2": "value4"}') - doc2 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value23"}') - self.db.create_doc_from_json('{"key": "value1", "key2": "value2"}') - self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - 
[doc2, doc1], - self.db.get_range_from_index('test-idx', ('value2', '*'))) - - def test_get_glob_range_from_index_multi_column_end(self): - self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') - doc2 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value23"}') - doc3 = self.db.create_doc_from_json( - '{"key": "value1", "key2": "value2"}') - doc4 = self.db.create_doc_from_json( - '{"key": "value1", "key2": "value1"}') - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - [doc4, doc3, doc2], - self.db.get_range_from_index('test-idx', None, ('value2', '*'))) - - def test_get_range_from_index_illegal_wildcard_order(self): - self.db.create_index('test-idx', 'k1', 'k2') - self.assertRaises( - errors.InvalidGlobbing, - self.db.get_range_from_index, 'test-idx', ('*', 'v2')) - - def test_get_range_from_index_illegal_glob_after_wildcard(self): - self.db.create_index('test-idx', 'k1', 'k2') - self.assertRaises( - errors.InvalidGlobbing, - self.db.get_range_from_index, 'test-idx', ('*', 'v*')) - - def test_get_range_from_index_illegal_wildcard_order_end(self): - self.db.create_index('test-idx', 'k1', 'k2') - self.assertRaises( - errors.InvalidGlobbing, - self.db.get_range_from_index, 'test-idx', None, ('*', 'v2')) - - def test_get_range_from_index_illegal_glob_after_wildcard_end(self): - self.db.create_index('test-idx', 'k1', 'k2') - self.assertRaises( - errors.InvalidGlobbing, - self.db.get_range_from_index, 'test-idx', None, ('*', 'v*')) - - def test_get_from_index_fails_if_no_index(self): - self.assertRaises( - errors.IndexDoesNotExist, self.db.get_from_index, 'foo') - - def test_get_index_keys_fails_if_no_index(self): - self.assertRaises(errors.IndexDoesNotExist, - self.db.get_index_keys, - 'foo') - - def test_get_index_keys_works_if_no_docs(self): - self.db.create_index('test-idx', 'key') - self.assertEqual([], self.db.get_index_keys('test-idx')) - - def test_put_updates_index(self): - doc = 
self.db.create_doc_from_json(simple_doc) - self.db.create_index('test-idx', 'key') - new_content = '{"key": "altval"}' - doc.set_json(new_content) - self.db.put_doc(doc) - self.assertEqual([], self.db.get_from_index('test-idx', 'value')) - self.assertEqual([doc], self.db.get_from_index('test-idx', 'altval')) - - def test_delete_updates_index(self): - doc = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(simple_doc) - self.db.create_index('test-idx', 'key') - self.assertEqual( - sorted([doc, doc2]), - sorted(self.db.get_from_index('test-idx', 'value'))) - self.db.delete_doc(doc) - self.assertEqual([doc2], self.db.get_from_index('test-idx', 'value')) - - def test_get_from_index_illegal_number_of_entries(self): - self.db.create_index('test-idx', 'k1', 'k2') - self.assertRaises( - errors.InvalidValueForIndex, self.db.get_from_index, 'test-idx') - self.assertRaises( - errors.InvalidValueForIndex, - self.db.get_from_index, 'test-idx', 'v1') - self.assertRaises( - errors.InvalidValueForIndex, - self.db.get_from_index, 'test-idx', 'v1', 'v2', 'v3') - - def test_get_from_index_illegal_wildcard_order(self): - self.db.create_index('test-idx', 'k1', 'k2') - self.assertRaises( - errors.InvalidGlobbing, - self.db.get_from_index, 'test-idx', '*', 'v2') - - def test_get_from_index_illegal_glob_after_wildcard(self): - self.db.create_index('test-idx', 'k1', 'k2') - self.assertRaises( - errors.InvalidGlobbing, - self.db.get_from_index, 'test-idx', '*', 'v*') - - def test_get_all_from_index(self): - self.db.create_index('test-idx', 'key') - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - # This one should not be in the index - self.db.create_doc_from_json('{"no": "key"}') - diff_value_doc = '{"key": "diff value"}' - doc4 = self.db.create_doc_from_json(diff_value_doc) - # This is essentially a 'prefix' match, but we match every entry. 
- self.assertEqual( - sorted([doc1, doc2, doc4]), - sorted(self.db.get_from_index('test-idx', '*'))) - - def test_get_all_from_index_ordered(self): - self.db.create_index('test-idx', 'key') - doc1 = self.db.create_doc_from_json('{"key": "value x"}') - doc2 = self.db.create_doc_from_json('{"key": "value b"}') - doc3 = self.db.create_doc_from_json('{"key": "value a"}') - doc4 = self.db.create_doc_from_json('{"key": "value m"}') - # This is essentially a 'prefix' match, but we match every entry. - self.assertEqual( - [doc3, doc2, doc4, doc1], self.db.get_from_index('test-idx', '*')) - - def test_put_updates_when_adding_key(self): - doc = self.db.create_doc_from_json("{}") - self.db.create_index('test-idx', 'key') - self.assertEqual([], self.db.get_from_index('test-idx', '*')) - doc.set_json(simple_doc) - self.db.put_doc(doc) - self.assertEqual([doc], self.db.get_from_index('test-idx', '*')) - - def test_get_from_index_empty_string(self): - self.db.create_index('test-idx', 'key') - doc1 = self.db.create_doc_from_json(simple_doc) - content2 = '{"key": ""}' - doc2 = self.db.create_doc_from_json(content2) - self.assertEqual([doc2], self.db.get_from_index('test-idx', '')) - # Empty string matches the wildcard. - self.assertEqual( - sorted([doc1, doc2]), - sorted(self.db.get_from_index('test-idx', '*'))) - - def test_get_from_index_not_null(self): - self.db.create_index('test-idx', 'key') - doc1 = self.db.create_doc_from_json(simple_doc) - self.db.create_doc_from_json('{"key": null}') - self.assertEqual([doc1], self.db.get_from_index('test-idx', '*')) - - def test_get_partial_from_index(self): - content1 = '{"k1": "v1", "k2": "v2"}' - content2 = '{"k1": "v1", "k2": "x2"}' - content3 = '{"k1": "v1", "k2": "y2"}' - # doc4 has a different k1 value, so it doesn't match the prefix. 
- content4 = '{"k1": "NN", "k2": "v2"}' - doc1 = self.db.create_doc_from_json(content1) - doc2 = self.db.create_doc_from_json(content2) - doc3 = self.db.create_doc_from_json(content3) - self.db.create_doc_from_json(content4) - self.db.create_index('test-idx', 'k1', 'k2') - self.assertEqual( - sorted([doc1, doc2, doc3]), - sorted(self.db.get_from_index('test-idx', "v1", "*"))) - - def test_get_glob_match(self): - # Note: the exact glob syntax is probably subject to change - content1 = '{"k1": "v1", "k2": "v1"}' - content2 = '{"k1": "v1", "k2": "v2"}' - content3 = '{"k1": "v1", "k2": "v3"}' - # doc4 has a different k2 prefix value, so it doesn't match - content4 = '{"k1": "v1", "k2": "ZZ"}' - self.db.create_index('test-idx', 'k1', 'k2') - doc1 = self.db.create_doc_from_json(content1) - doc2 = self.db.create_doc_from_json(content2) - doc3 = self.db.create_doc_from_json(content3) - self.db.create_doc_from_json(content4) - self.assertEqual( - sorted([doc1, doc2, doc3]), - sorted(self.db.get_from_index('test-idx', "v1", "v*"))) - - def test_nested_index(self): - doc = self.db.create_doc_from_json(nested_doc) - self.db.create_index('test-idx', 'sub.doc') - self.assertEqual( - [doc], self.db.get_from_index('test-idx', 'underneath')) - doc2 = self.db.create_doc_from_json(nested_doc) - self.assertEqual( - sorted([doc, doc2]), - sorted(self.db.get_from_index('test-idx', 'underneath'))) - - def test_nested_nonexistent(self): - self.db.create_doc_from_json(nested_doc) - # sub exists, but sub.foo does not: - self.db.create_index('test-idx', 'sub.foo') - self.assertEqual([], self.db.get_from_index('test-idx', '*')) - - def test_nested_nonexistent2(self): - self.db.create_doc_from_json(nested_doc) - self.db.create_index('test-idx', 'sub.foo.bar.baz.qux.fnord') - self.assertEqual([], self.db.get_from_index('test-idx', '*')) - - def test_nested_traverses_lists(self): - # subpath finds dicts in list - doc = self.db.create_doc_from_json( - '{"foo": [{"zap": "bar"}, {"zap": "baz"}]}') 
- # subpath only finds dicts in list - self.db.create_doc_from_json('{"foo": ["zap", "baz"]}') - self.db.create_index('test-idx', 'foo.zap') - self.assertEqual([doc], self.db.get_from_index('test-idx', 'bar')) - self.assertEqual([doc], self.db.get_from_index('test-idx', 'baz')) - - def test_nested_list_traversal(self): - # subpath finds dicts in list - doc = self.db.create_doc_from_json( - '{"foo": [{"zap": [{"qux": "fnord"}, {"qux": "zombo"}]},' - '{"zap": "baz"}]}') - # subpath only finds dicts in list - self.db.create_index('test-idx', 'foo.zap.qux') - self.assertEqual([doc], self.db.get_from_index('test-idx', 'fnord')) - self.assertEqual([doc], self.db.get_from_index('test-idx', 'zombo')) - - def test_index_list1(self): - self.db.create_index("index", "name") - content = '{"name": ["foo", "bar"]}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "bar") - self.assertEqual([doc], rows) - - def test_index_list2(self): - self.db.create_index("index", "name") - content = '{"name": ["foo", "bar"]}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "foo") - self.assertEqual([doc], rows) - - def test_get_from_index_case_sensitive(self): - self.db.create_index('test-idx', 'key') - doc1 = self.db.create_doc_from_json(simple_doc) - self.assertEqual([], self.db.get_from_index('test-idx', 'V*')) - self.assertEqual([doc1], self.db.get_from_index('test-idx', 'v*')) - - def test_get_from_index_illegal_glob_before_value(self): - self.db.create_index('test-idx', 'k1', 'k2') - self.assertRaises( - errors.InvalidGlobbing, - self.db.get_from_index, 'test-idx', 'v*', 'v2') - - def test_get_from_index_illegal_glob_after_glob(self): - self.db.create_index('test-idx', 'k1', 'k2') - self.assertRaises( - errors.InvalidGlobbing, - self.db.get_from_index, 'test-idx', 'v*', 'v*') - - def test_get_from_index_with_sql_wildcards(self): - self.db.create_index('test-idx', 'key') - content1 = '{"key": "va%lue"}' - content2 
= '{"key": "value"}' - content3 = '{"key": "va_lue"}' - doc1 = self.db.create_doc_from_json(content1) - self.db.create_doc_from_json(content2) - doc3 = self.db.create_doc_from_json(content3) - # The '%' in the search should be treated literally, not as a sql - # globbing character. - self.assertEqual([doc1], self.db.get_from_index('test-idx', 'va%*')) - # Same for '_' - self.assertEqual([doc3], self.db.get_from_index('test-idx', 'va_*')) - - def test_get_from_index_with_lower(self): - self.db.create_index("index", "lower(name)") - content = '{"name": "Foo"}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "foo") - self.assertEqual([doc], rows) - - def test_get_from_index_with_lower_matches_same_case(self): - self.db.create_index("index", "lower(name)") - content = '{"name": "foo"}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "foo") - self.assertEqual([doc], rows) - - def test_index_lower_doesnt_match_different_case(self): - self.db.create_index("index", "lower(name)") - content = '{"name": "Foo"}' - self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "Foo") - self.assertEqual([], rows) - - def test_index_lower_doesnt_match_other_index(self): - self.db.create_index("index", "lower(name)") - self.db.create_index("other_index", "name") - content = '{"name": "Foo"}' - self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "Foo") - self.assertEqual(0, len(rows)) - - def test_index_split_words_match_first(self): - self.db.create_index("index", "split_words(name)") - content = '{"name": "foo bar"}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "foo") - self.assertEqual([doc], rows) - - def test_index_split_words_match_second(self): - self.db.create_index("index", "split_words(name)") - content = '{"name": "foo bar"}' - doc = self.db.create_doc_from_json(content) - rows = 
self.db.get_from_index("index", "bar") - self.assertEqual([doc], rows) - - def test_index_split_words_match_both(self): - self.db.create_index("index", "split_words(name)") - content = '{"name": "foo foo"}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "foo") - self.assertEqual([doc], rows) - - def test_index_split_words_double_space(self): - self.db.create_index("index", "split_words(name)") - content = '{"name": "foo bar"}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "bar") - self.assertEqual([doc], rows) - - def test_index_split_words_leading_space(self): - self.db.create_index("index", "split_words(name)") - content = '{"name": " foo bar"}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "foo") - self.assertEqual([doc], rows) - - def test_index_split_words_trailing_space(self): - self.db.create_index("index", "split_words(name)") - content = '{"name": "foo bar "}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "bar") - self.assertEqual([doc], rows) - - def test_get_from_index_with_number(self): - self.db.create_index("index", "number(foo, 5)") - content = '{"foo": 12}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "00012") - self.assertEqual([doc], rows) - - def test_get_from_index_with_number_bigger_than_padding(self): - self.db.create_index("index", "number(foo, 5)") - content = '{"foo": 123456}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "123456") - self.assertEqual([doc], rows) - - def test_number_mapping_ignores_non_numbers(self): - self.db.create_index("index", "number(foo, 5)") - content = '{"foo": 56}' - doc1 = self.db.create_doc_from_json(content) - content = '{"foo": "this is not a maigret painting"}' - self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "*") - 
self.assertEqual([doc1], rows) - - def test_get_from_index_with_bool(self): - self.db.create_index("index", "bool(foo)") - content = '{"foo": true}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "1") - self.assertEqual([doc], rows) - - def test_get_from_index_with_bool_false(self): - self.db.create_index("index", "bool(foo)") - content = '{"foo": false}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "0") - self.assertEqual([doc], rows) - - def test_get_from_index_with_non_bool(self): - self.db.create_index("index", "bool(foo)") - content = '{"foo": 42}' - self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "*") - self.assertEqual([], rows) - - def test_get_from_index_with_combine(self): - self.db.create_index("index", "combine(foo, bar)") - content = '{"foo": "value1", "bar": "value2"}' - doc = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "value1") - self.assertEqual([doc], rows) - rows = self.db.get_from_index("index", "value2") - self.assertEqual([doc], rows) - - def test_get_complex_combine(self): - self.db.create_index( - "index", "combine(number(foo, 5), lower(bar), split_words(baz))") - content = '{"foo": 12, "bar": "ALLCAPS", "baz": "qux nox"}' - doc = self.db.create_doc_from_json(content) - content = '{"foo": "not a number", "bar": "something"}' - doc2 = self.db.create_doc_from_json(content) - rows = self.db.get_from_index("index", "00012") - self.assertEqual([doc], rows) - rows = self.db.get_from_index("index", "allcaps") - self.assertEqual([doc], rows) - rows = self.db.get_from_index("index", "nox") - self.assertEqual([doc], rows) - rows = self.db.get_from_index("index", "something") - self.assertEqual([doc2], rows) - - def test_get_index_keys_from_index(self): - self.db.create_index('test-idx', 'key') - content1 = '{"key": "value1"}' - content2 = '{"key": "value2"}' - content3 = '{"key": "value2"}' - 
self.db.create_doc_from_json(content1) - self.db.create_doc_from_json(content2) - self.db.create_doc_from_json(content3) - self.assertEqual( - [('value1',), ('value2',)], - sorted(self.db.get_index_keys('test-idx'))) - - def test_get_index_keys_from_multicolumn_index(self): - self.db.create_index('test-idx', 'key1', 'key2') - content1 = '{"key1": "value1", "key2": "val2-1"}' - content2 = '{"key1": "value2", "key2": "val2-2"}' - content3 = '{"key1": "value2", "key2": "val2-2"}' - content4 = '{"key1": "value2", "key2": "val3"}' - self.db.create_doc_from_json(content1) - self.db.create_doc_from_json(content2) - self.db.create_doc_from_json(content3) - self.db.create_doc_from_json(content4) - self.assertEqual([ - ('value1', 'val2-1'), - ('value2', 'val2-2'), - ('value2', 'val3')], - sorted(self.db.get_index_keys('test-idx'))) - - def test_empty_expr(self): - self.assertParseError('') - - def test_nested_unknown_operation(self): - self.assertParseError('unknown_operation(field1)') - - def test_parse_missing_close_paren(self): - self.assertParseError("lower(a") - - def test_parse_trailing_close_paren(self): - self.assertParseError("lower(ab))") - - def test_parse_trailing_chars(self): - self.assertParseError("lower(ab)adsf") - - def test_parse_empty_op(self): - self.assertParseError("(ab)") - - def test_parse_top_level_commas(self): - self.assertParseError("a, b") - - def test_invalid_field_name(self): - self.assertParseError("a.") - - def test_invalid_inner_field_name(self): - self.assertParseError("lower(a.)") - - def test_gobbledigook(self): - self.assertParseError("(@#@cc @#!*DFJSXV(()jccd") - - def test_leading_space(self): - self.assertIndexCreatable(" lower(a)") - - def test_trailing_space(self): - self.assertIndexCreatable("lower(a) ") - - def test_spaces_before_open_paren(self): - self.assertIndexCreatable("lower (a)") - - def test_spaces_after_open_paren(self): - self.assertIndexCreatable("lower( a)") - - def test_spaces_before_close_paren(self): - 
self.assertIndexCreatable("lower(a )") - - def test_spaces_before_comma(self): - self.assertIndexCreatable("combine(a , b , c)") - - def test_spaces_after_comma(self): - self.assertIndexCreatable("combine(a, b, c)") - - def test_all_together_now(self): - self.assertParseError(' (a) ') - - def test_all_together_now2(self): - self.assertParseError('combine(lower(x)x,foo)') - - -class PythonBackendTests(tests.DatabaseBaseTests): - - def setUp(self): - super(PythonBackendTests, self).setUp() - self.simple_doc = json.loads(simple_doc) - - def test_create_doc_with_factory(self): - self.db.set_document_factory(TestAlternativeDocument) - doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id') - self.assertTrue(isinstance(doc, TestAlternativeDocument)) - - def test_get_doc_after_put_with_factory(self): - doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id') - self.db.set_document_factory(TestAlternativeDocument) - result = self.db.get_doc('my_doc_id') - self.assertTrue(isinstance(result, TestAlternativeDocument)) - self.assertEqual(doc.doc_id, result.doc_id) - self.assertEqual(doc.rev, result.rev) - self.assertEqual(doc.get_json(), result.get_json()) - self.assertEqual(False, result.has_conflicts) - - def test_get_doc_nonexisting_with_factory(self): - self.db.set_document_factory(TestAlternativeDocument) - self.assertIs(None, self.db.get_doc('non-existing')) - - def test_get_all_docs_with_factory(self): - self.db.set_document_factory(TestAlternativeDocument) - self.db.create_doc(self.simple_doc) - self.assertTrue(isinstance( - list(self.db.get_all_docs()[1])[0], TestAlternativeDocument)) - - def test_get_docs_conflicted_with_factory(self): - self.db.set_document_factory(TestAlternativeDocument) - doc1 = self.db.create_doc(self.simple_doc) - doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) - self.db._put_doc_if_newer( - doc2, save_conflict=True, replica_uid='r', replica_gen=1, - replica_trans_id='foo') - self.assertTrue( - isinstance( - 
list(self.db.get_docs([doc1.doc_id]))[0], - TestAlternativeDocument)) - - def test_get_from_index_with_factory(self): - self.db.set_document_factory(TestAlternativeDocument) - self.db.create_doc(self.simple_doc) - self.db.create_index('test-idx', 'key') - self.assertTrue( - isinstance( - self.db.get_from_index('test-idx', 'value')[0], - TestAlternativeDocument)) - - def test_sync_exchange_updates_indexes(self): - doc = self.db.create_doc(self.simple_doc) - self.db.create_index('test-idx', 'key') - new_content = '{"key": "altval"}' - other_rev = 'test:1|z:2' - st = self.db.get_sync_target() - - def ignore(doc_id, doc_rev, doc): - pass - - doc_other = self.make_document(doc.doc_id, other_rev, new_content) - docs_by_gen = [(doc_other, 10, 'T-sid')] - st.sync_exchange( - docs_by_gen, 'other-replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=ignore) - self.assertGetDoc(self.db, doc.doc_id, other_rev, new_content, False) - self.assertEqual( - [doc_other], self.db.get_from_index('test-idx', 'altval')) - self.assertEqual([], self.db.get_from_index('test-idx', 'value')) - - -# Use a custom loader to apply the scenarios at load time. -load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_c_backend.py b/src/leap/soledad/u1db/tests/test_c_backend.py deleted file mode 100644 index bdd2aec7..00000000 --- a/src/leap/soledad/u1db/tests/test_c_backend.py +++ /dev/null @@ -1,634 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. 
-# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -try: - import simplejson as json -except ImportError: - import json # noqa -from u1db import ( - Document, - errors, - tests, - ) -from u1db.tests import c_backend_wrapper, c_backend_error -from u1db.tests.test_remote_sync_target import ( - make_http_app, - make_oauth_http_app - ) - - -class TestCDatabaseExists(tests.TestCase): - - def test_c_backend_compiled(self): - if c_backend_wrapper is None: - self.fail("Could not import the c_backend_wrapper module." - " Was it compiled properly?\n%s" % (c_backend_error,)) - - -# Rather than lots of failing tests, we have the above check to test that the -# module exists, and all these tests just get skipped -class BackendTests(tests.TestCase): - - def setUp(self): - super(BackendTests, self).setUp() - if c_backend_wrapper is None: - self.skipTest("The c_backend_wrapper could not be imported") - - -class TestCDatabase(BackendTests): - - def test_exists(self): - if c_backend_wrapper is None: - self.fail("Could not import the c_backend_wrapper module." 
- " Was it compiled properly?") - db = c_backend_wrapper.CDatabase(':memory:') - self.assertEqual(':memory:', db._filename) - - def test__is_closed(self): - db = c_backend_wrapper.CDatabase(':memory:') - self.assertTrue(db._sql_is_open()) - db.close() - self.assertFalse(db._sql_is_open()) - - def test__run_sql(self): - db = c_backend_wrapper.CDatabase(':memory:') - self.assertTrue(db._sql_is_open()) - self.assertEqual([], db._run_sql('CREATE TABLE test (id INTEGER)')) - self.assertEqual([], db._run_sql('INSERT INTO test VALUES (1)')) - self.assertEqual([('1',)], db._run_sql('SELECT * FROM test')) - - def test__get_generation(self): - db = c_backend_wrapper.CDatabase(':memory:') - self.assertEqual(0, db._get_generation()) - db.create_doc_from_json(tests.simple_doc) - self.assertEqual(1, db._get_generation()) - - def test__get_generation_info(self): - db = c_backend_wrapper.CDatabase(':memory:') - self.assertEqual((0, ''), db._get_generation_info()) - db.create_doc_from_json(tests.simple_doc) - info = db._get_generation_info() - self.assertEqual(1, info[0]) - self.assertTrue(info[1].startswith('T-')) - - def test__set_replica_uid(self): - db = c_backend_wrapper.CDatabase(':memory:') - self.assertIsNot(None, db._replica_uid) - db._set_replica_uid('foo') - self.assertEqual([('foo',)], db._run_sql( - "SELECT value FROM u1db_config WHERE name='replica_uid'")) - - def test_default_replica_uid(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - self.assertIsNot(None, self.db._replica_uid) - self.assertEqual(32, len(self.db._replica_uid)) - # casting to an int from the uid *is* the check for correct behavior. 
- int(self.db._replica_uid, 16) - - def test_get_conflicts_with_borked_data(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - # We add an entry to conflicts, but not to documents, which is an - # invalid situation - self.db._run_sql("INSERT INTO conflicts" - " VALUES ('doc-id', 'doc-rev', '{}')") - self.assertRaises(Exception, self.db.get_doc_conflicts, 'doc-id') - - def test_create_index_list(self): - # We manually poke data into the DB, so that we test just the "get_doc" - # code, rather than also testing the index management code. - self.db = c_backend_wrapper.CDatabase(':memory:') - doc = self.db.create_doc_from_json(tests.simple_doc) - self.db.create_index_list("key-idx", ["key"]) - docs = self.db.get_from_index('key-idx', 'value') - self.assertEqual([doc], docs) - - def test_create_index_list_on_non_ascii_field_name(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - doc = self.db.create_doc_from_json(json.dumps({u'\xe5': 'value'})) - self.db.create_index_list('test-idx', [u'\xe5']) - self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) - - def test_list_indexes_with_non_ascii_field_names(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - self.db.create_index_list('test-idx', [u'\xe5']) - self.assertEqual( - [('test-idx', [u'\xe5'])], self.db.list_indexes()) - - def test_create_index_evaluates_it(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - doc = self.db.create_doc_from_json(tests.simple_doc) - self.db.create_index_list('test-idx', ['key']) - self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) - - def test_wildcard_matches_unicode_value(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"})) - self.db.create_index_list('test-idx', ['key']) - self.assertEqual([doc], self.db.get_from_index('test-idx', '*')) - - def test_create_index_fails_if_name_taken(self): - self.db = c_backend_wrapper.CDatabase(':memory:') 
- self.db.create_index_list('test-idx', ['key']) - self.assertRaises(errors.IndexNameTakenError, - self.db.create_index_list, - 'test-idx', ['stuff']) - - def test_create_index_does_not_fail_if_name_taken_with_same_index(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - self.db.create_index_list('test-idx', ['key']) - self.db.create_index_list('test-idx', ['key']) - self.assertEqual([('test-idx', ['key'])], self.db.list_indexes()) - - def test_create_index_after_deleting_document(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - doc = self.db.create_doc_from_json(tests.simple_doc) - doc2 = self.db.create_doc_from_json(tests.simple_doc) - self.db.delete_doc(doc2) - self.db.create_index_list('test-idx', ['key']) - self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) - - def test_get_from_index(self): - # We manually poke data into the DB, so that we test just the "get_doc" - # code, rather than also testing the index management code. - self.db = c_backend_wrapper.CDatabase(':memory:') - doc = self.db.create_doc_from_json(tests.simple_doc) - self.db.create_index("key-idx", "key") - docs = self.db.get_from_index('key-idx', 'value') - self.assertEqual([doc], docs) - - def test_get_from_index_list(self): - # We manually poke data into the DB, so that we test just the "get_doc" - # code, rather than also testing the index management code. 
- self.db = c_backend_wrapper.CDatabase(':memory:') - doc = self.db.create_doc_from_json(tests.simple_doc) - self.db.create_index("key-idx", "key") - docs = self.db.get_from_index_list('key-idx', ['value']) - self.assertEqual([doc], docs) - - def test_get_from_index_list_multi(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - content = '{"key": "value", "key2": "value2"}' - doc = self.db.create_doc_from_json(content) - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - [doc], - self.db.get_from_index_list('test-idx', ['value', 'value2'])) - - def test_get_from_index_list_multi_ordered(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - doc1 = self.db.create_doc_from_json( - '{"key": "value3", "key2": "value4"}') - doc2 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value3"}') - doc3 = self.db.create_doc_from_json( - '{"key": "value2", "key2": "value2"}') - doc4 = self.db.create_doc_from_json( - '{"key": "value1", "key2": "value1"}') - self.db.create_index('test-idx', 'key', 'key2') - self.assertEqual( - [doc4, doc3, doc2, doc1], - self.db.get_from_index_list('test-idx', ['v*', '*'])) - - def test_get_from_index_2(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - doc = self.db.create_doc_from_json(tests.nested_doc) - self.db.create_index("multi-idx", "key", "sub.doc") - docs = self.db.get_from_index('multi-idx', 'value', 'underneath') - self.assertEqual([doc], docs) - - def test_get_index_keys(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - self.db.create_doc_from_json(tests.simple_doc) - self.db.create_index("key-idx", "key") - keys = self.db.get_index_keys('key-idx') - self.assertEqual([("value",)], keys) - - def test__query_init_one_field(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - self.db.create_index("key-idx", "key") - query = self.db._query_init("key-idx") - self.assertEqual("key-idx", query.index_name) - self.assertEqual(1, query.num_fields) - 
self.assertEqual(["key"], query.fields) - - def test__query_init_two_fields(self): - self.db = c_backend_wrapper.CDatabase(':memory:') - self.db.create_index("two-idx", "key", "key2") - query = self.db._query_init("two-idx") - self.assertEqual("two-idx", query.index_name) - self.assertEqual(2, query.num_fields) - self.assertEqual(["key", "key2"], query.fields) - - def assertFormatQueryEquals(self, expected, wildcards, fields): - val, w = c_backend_wrapper._format_query(fields) - self.assertEqual(expected, val) - self.assertEqual(wildcards, w) - - def test__format_query(self): - self.assertFormatQueryEquals( - "SELECT d0.doc_id FROM document_fields d0" - " WHERE d0.field_name = ? AND d0.value = ? ORDER BY d0.value", - [0], ["1"]) - self.assertFormatQueryEquals( - "SELECT d0.doc_id" - " FROM document_fields d0, document_fields d1" - " WHERE d0.field_name = ? AND d0.value = ?" - " AND d0.doc_id = d1.doc_id" - " AND d1.field_name = ? AND d1.value = ?" - " ORDER BY d0.value, d1.value", - [0, 0], ["1", "2"]) - self.assertFormatQueryEquals( - "SELECT d0.doc_id" - " FROM document_fields d0, document_fields d1, document_fields d2" - " WHERE d0.field_name = ? AND d0.value = ?" - " AND d0.doc_id = d1.doc_id" - " AND d1.field_name = ? AND d1.value = ?" - " AND d0.doc_id = d2.doc_id" - " AND d2.field_name = ? AND d2.value = ?" - " ORDER BY d0.value, d1.value, d2.value", - [0, 0, 0], ["1", "2", "3"]) - - def test__format_query_wildcard(self): - self.assertFormatQueryEquals( - "SELECT d0.doc_id FROM document_fields d0" - " WHERE d0.field_name = ? AND d0.value NOT NULL ORDER BY d0.value", - [1], ["*"]) - self.assertFormatQueryEquals( - "SELECT d0.doc_id" - " FROM document_fields d0, document_fields d1" - " WHERE d0.field_name = ? AND d0.value = ?" - " AND d0.doc_id = d1.doc_id" - " AND d1.field_name = ? 
AND d1.value NOT NULL" - " ORDER BY d0.value, d1.value", - [0, 1], ["1", "*"]) - - def test__format_query_glob(self): - self.assertFormatQueryEquals( - "SELECT d0.doc_id FROM document_fields d0" - " WHERE d0.field_name = ? AND d0.value GLOB ? ORDER BY d0.value", - [2], ["1*"]) - - -class TestCSyncTarget(BackendTests): - - def setUp(self): - super(TestCSyncTarget, self).setUp() - self.db = c_backend_wrapper.CDatabase(':memory:') - self.st = self.db.get_sync_target() - - def test_attached_to_db(self): - self.assertEqual( - self.db._replica_uid, self.st.get_sync_info("misc")[0]) - - def test_get_sync_exchange(self): - exc = self.st._get_sync_exchange("source-uid", 10) - self.assertIsNot(None, exc) - - def test_sync_exchange_insert_doc_from_source(self): - exc = self.st._get_sync_exchange("source-uid", 5) - doc = c_backend_wrapper.make_document('doc-id', 'replica:1', - tests.simple_doc) - self.assertEqual([], exc.get_seen_ids()) - exc.insert_doc_from_source(doc, 10, 'T-sid') - self.assertGetDoc(self.db, 'doc-id', 'replica:1', tests.simple_doc, - False) - self.assertEqual( - (10, 'T-sid'), self.db._get_replica_gen_and_trans_id('source-uid')) - self.assertEqual(['doc-id'], exc.get_seen_ids()) - - def test_sync_exchange_conflicted_doc(self): - doc = self.db.create_doc_from_json(tests.simple_doc) - exc = self.st._get_sync_exchange("source-uid", 5) - doc2 = c_backend_wrapper.make_document(doc.doc_id, 'replica:1', - tests.nested_doc) - self.assertEqual([], exc.get_seen_ids()) - # The insert should be rejected and the doc_id not considered 'seen' - exc.insert_doc_from_source(doc2, 10, 'T-sid') - self.assertGetDoc( - self.db, doc.doc_id, doc.rev, tests.simple_doc, False) - self.assertEqual([], exc.get_seen_ids()) - - def test_sync_exchange_find_doc_ids(self): - doc = self.db.create_doc_from_json(tests.simple_doc) - exc = self.st._get_sync_exchange("source-uid", 0) - self.assertEqual(0, exc.target_gen) - exc.find_doc_ids_to_return() - doc_id = exc.get_doc_ids_to_return()[0] - 
self.assertEqual( - (doc.doc_id, 1), doc_id[:-1]) - self.assertTrue(doc_id[-1].startswith('T-')) - self.assertEqual(1, exc.target_gen) - - def test_sync_exchange_find_doc_ids_not_including_recently_inserted(self): - doc1 = self.db.create_doc_from_json(tests.simple_doc) - doc2 = self.db.create_doc_from_json(tests.nested_doc) - exc = self.st._get_sync_exchange("source-uid", 0) - doc3 = c_backend_wrapper.make_document(doc1.doc_id, - doc1.rev + "|zreplica:2", tests.simple_doc) - exc.insert_doc_from_source(doc3, 10, 'T-sid') - exc.find_doc_ids_to_return() - self.assertEqual( - (doc2.doc_id, 2), exc.get_doc_ids_to_return()[0][:-1]) - self.assertEqual(3, exc.target_gen) - - def test_sync_exchange_return_docs(self): - returned = [] - - def return_doc_cb(doc, gen, trans_id): - returned.append((doc, gen, trans_id)) - - doc1 = self.db.create_doc_from_json(tests.simple_doc) - exc = self.st._get_sync_exchange("source-uid", 0) - exc.find_doc_ids_to_return() - exc.return_docs(return_doc_cb) - self.assertEqual((doc1, 1), returned[0][:-1]) - - def test_sync_exchange_doc_ids(self): - doc1 = self.db.create_doc_from_json(tests.simple_doc, doc_id='doc-1') - db2 = c_backend_wrapper.CDatabase(':memory:') - doc2 = db2.create_doc_from_json(tests.nested_doc, doc_id='doc-2') - returned = [] - - def return_doc_cb(doc, gen, trans_id): - returned.append((doc, gen, trans_id)) - - val = self.st.sync_exchange_doc_ids( - db2, [(doc2.doc_id, 1, 'T-sid')], 0, None, return_doc_cb) - last_trans_id = self.db._get_transaction_log()[-1][1] - self.assertEqual(2, self.db._get_generation()) - self.assertEqual((2, last_trans_id), val) - self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc, - False) - self.assertEqual((doc1, 1), returned[0][:-1]) - - -class TestCHTTPSyncTarget(BackendTests): - - def test_format_sync_url(self): - target = c_backend_wrapper.create_http_sync_target("http://base_url") - self.assertEqual("http://base_url/sync-from/replica-uid", - 
c_backend_wrapper._format_sync_url(target, "replica-uid")) - - def test_format_sync_url_escapes(self): - # The base_url should not get munged (we assume it is already a - # properly formed URL), but the replica-uid should get properly escaped - target = c_backend_wrapper.create_http_sync_target( - "http://host/base%2Ctest/") - self.assertEqual("http://host/base%2Ctest/sync-from/replica%2Cuid", - c_backend_wrapper._format_sync_url(target, "replica,uid")) - - def test_format_refuses_non_http(self): - db = c_backend_wrapper.CDatabase(':memory:') - target = db.get_sync_target() - self.assertRaises(RuntimeError, - c_backend_wrapper._format_sync_url, target, 'replica,uid') - - def test_oauth_credentials(self): - target = c_backend_wrapper.create_oauth_http_sync_target( - "http://host/base%2Ctest/", - 'consumer-key', 'consumer-secret', 'token-key', 'token-secret') - auth = c_backend_wrapper._get_oauth_authorization(target, - "GET", "http://host/base%2Ctest/sync-from/abcd-efg") - self.assertIsNot(None, auth) - self.assertTrue(auth.startswith('Authorization: OAuth realm="", ')) - self.assertNotIn('http://host/base', auth) - self.assertIn('oauth_nonce="', auth) - self.assertIn('oauth_timestamp="', auth) - self.assertIn('oauth_consumer_key="consumer-key"', auth) - self.assertIn('oauth_signature_method="HMAC-SHA1"', auth) - self.assertIn('oauth_version="1.0"', auth) - self.assertIn('oauth_token="token-key"', auth) - self.assertIn('oauth_signature="', auth) - - -class TestSyncCtoHTTPViaC(tests.TestCaseWithServer): - - make_app_with_state = staticmethod(make_http_app) - - def setUp(self): - super(TestSyncCtoHTTPViaC, self).setUp() - if c_backend_wrapper is None: - self.skipTest("The c_backend_wrapper could not be imported") - self.startServer() - - def test_trivial_sync(self): - mem_db = self.request_state._create_database('test.db') - mem_doc = mem_db.create_doc_from_json(tests.nested_doc) - url = self.getURL('test.db') - target = c_backend_wrapper.create_http_sync_target(url) 
- db = c_backend_wrapper.CDatabase(':memory:') - doc = db.create_doc_from_json(tests.simple_doc) - c_backend_wrapper.sync_db_to_target(db, target) - self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) - self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), - False) - - def test_unavailable(self): - mem_db = self.request_state._create_database('test.db') - mem_db.create_doc_from_json(tests.nested_doc) - tries = [] - - def wrapper(instance, *args, **kwargs): - tries.append(None) - raise errors.Unavailable - - mem_db.whats_changed = wrapper - url = self.getURL('test.db') - target = c_backend_wrapper.create_http_sync_target(url) - db = c_backend_wrapper.CDatabase(':memory:') - db.create_doc_from_json(tests.simple_doc) - self.assertRaises( - errors.Unavailable, c_backend_wrapper.sync_db_to_target, db, - target) - self.assertEqual(5, len(tries)) - - def test_unavailable_then_available(self): - mem_db = self.request_state._create_database('test.db') - mem_doc = mem_db.create_doc_from_json(tests.nested_doc) - orig_whatschanged = mem_db.whats_changed - tries = [] - - def wrapper(instance, *args, **kwargs): - if len(tries) < 1: - tries.append(None) - raise errors.Unavailable - return orig_whatschanged(instance, *args, **kwargs) - - mem_db.whats_changed = wrapper - url = self.getURL('test.db') - target = c_backend_wrapper.create_http_sync_target(url) - db = c_backend_wrapper.CDatabase(':memory:') - doc = db.create_doc_from_json(tests.simple_doc) - c_backend_wrapper.sync_db_to_target(db, target) - self.assertEqual(1, len(tries)) - self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) - self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), - False) - - def test_db_sync(self): - mem_db = self.request_state._create_database('test.db') - mem_doc = mem_db.create_doc_from_json(tests.nested_doc) - url = self.getURL('test.db') - db = c_backend_wrapper.CDatabase(':memory:') - doc = 
db.create_doc_from_json(tests.simple_doc) - local_gen_before_sync = db.sync(url) - gen, _, changes = db.whats_changed(local_gen_before_sync) - self.assertEqual(1, len(changes)) - self.assertEqual(mem_doc.doc_id, changes[0][0]) - self.assertEqual(1, gen - local_gen_before_sync) - self.assertEqual(1, local_gen_before_sync) - self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) - self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), - False) - - -class TestSyncCtoOAuthHTTPViaC(tests.TestCaseWithServer): - - make_app_with_state = staticmethod(make_oauth_http_app) - - def setUp(self): - super(TestSyncCtoOAuthHTTPViaC, self).setUp() - if c_backend_wrapper is None: - self.skipTest("The c_backend_wrapper could not be imported") - self.startServer() - - def test_trivial_sync(self): - mem_db = self.request_state._create_database('test.db') - mem_doc = mem_db.create_doc_from_json(tests.nested_doc) - url = self.getURL('~/test.db') - target = c_backend_wrapper.create_oauth_http_sync_target(url, - tests.consumer1.key, tests.consumer1.secret, - tests.token1.key, tests.token1.secret) - db = c_backend_wrapper.CDatabase(':memory:') - doc = db.create_doc_from_json(tests.simple_doc) - c_backend_wrapper.sync_db_to_target(db, target) - self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) - self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), - False) - - -class TestVectorClock(BackendTests): - - def create_vcr(self, rev): - return c_backend_wrapper.VectorClockRev(rev) - - def test_parse_empty(self): - self.assertEqual('VectorClockRev()', - repr(self.create_vcr(''))) - - def test_parse_invalid(self): - self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('x'))) - self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('x:a'))) - self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('y:1|x:a'))) - self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('x:a|y:1'))) - 
self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('y:1|x:2a'))) - self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('y:1||'))) - self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('y:1|'))) - self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('y:1|x:2|'))) - self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('y:1|x:2|:'))) - self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('y:1|x:2|m:'))) - self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('y:1|x:|m:3'))) - self.assertEqual('VectorClockRev(None)', - repr(self.create_vcr('y:1|:|m:3'))) - - def test_parse_single(self): - self.assertEqual('VectorClockRev(test:1)', - repr(self.create_vcr('test:1'))) - - def test_parse_multi(self): - self.assertEqual('VectorClockRev(test:1|z:2)', - repr(self.create_vcr('test:1|z:2'))) - self.assertEqual('VectorClockRev(ab:1|bc:2|cd:3|de:4|ef:5)', - repr(self.create_vcr('ab:1|bc:2|cd:3|de:4|ef:5'))) - self.assertEqual('VectorClockRev(a:2|b:1)', - repr(self.create_vcr('b:1|a:2'))) - - -class TestCDocument(BackendTests): - - def make_document(self, *args, **kwargs): - return c_backend_wrapper.make_document(*args, **kwargs) - - def test_create(self): - self.make_document('doc-id', 'uid:1', tests.simple_doc) - - def assertPyDocEqualCDoc(self, *args, **kwargs): - cdoc = self.make_document(*args, **kwargs) - pydoc = Document(*args, **kwargs) - self.assertEqual(pydoc, cdoc) - self.assertEqual(cdoc, pydoc) - - def test_cmp_to_pydoc_equal(self): - self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc) - self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc, - has_conflicts=False) - self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc, - has_conflicts=True) - - def test_cmp_to_pydoc_not_equal_conflicts(self): - cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) - pydoc = Document('doc-id', 'uid:1', tests.simple_doc, - has_conflicts=True) - self.assertNotEqual(cdoc, 
pydoc) - self.assertNotEqual(pydoc, cdoc) - - def test_cmp_to_pydoc_not_equal_doc_id(self): - cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) - pydoc = Document('doc2-id', 'uid:1', tests.simple_doc) - self.assertNotEqual(cdoc, pydoc) - self.assertNotEqual(pydoc, cdoc) - - def test_cmp_to_pydoc_not_equal_doc_rev(self): - cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) - pydoc = Document('doc-id', 'uid:2', tests.simple_doc) - self.assertNotEqual(cdoc, pydoc) - self.assertNotEqual(pydoc, cdoc) - - def test_cmp_to_pydoc_not_equal_content(self): - cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) - pydoc = Document('doc-id', 'uid:1', tests.nested_doc) - self.assertNotEqual(cdoc, pydoc) - self.assertNotEqual(pydoc, cdoc) - - -class TestUUID(BackendTests): - - def test_uuid4_conformance(self): - uuids = set() - for i in range(20): - uuid = c_backend_wrapper.generate_hex_uuid() - self.assertIsInstance(uuid, str) - self.assertEqual(32, len(uuid)) - # This will raise ValueError if it isn't a valid hex string - long(uuid, 16) - # Version 4 uuids have 2 other requirements, the high 4 bits of the - # seventh byte are always '0x4', and the middle bits of byte 9 are - # always set - self.assertEqual('4', uuid[12]) - self.assertTrue(uuid[16] in '89ab') - self.assertTrue(uuid not in uuids) - uuids.add(uuid) diff --git a/src/leap/soledad/u1db/tests/test_common_backend.py b/src/leap/soledad/u1db/tests/test_common_backend.py deleted file mode 100644 index 8c7c7ed9..00000000 --- a/src/leap/soledad/u1db/tests/test_common_backend.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. 
-# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Test common backend bits.""" - -from u1db import ( - backends, - tests, - ) - - -class TestCommonBackendImpl(tests.TestCase): - - def test__allocate_doc_id(self): - db = backends.CommonBackend() - doc_id1 = db._allocate_doc_id() - self.assertTrue(doc_id1.startswith('D-')) - self.assertEqual(34, len(doc_id1)) - int(doc_id1[len('D-'):], 16) - self.assertNotEqual(doc_id1, db._allocate_doc_id()) diff --git a/src/leap/soledad/u1db/tests/test_document.py b/src/leap/soledad/u1db/tests/test_document.py deleted file mode 100644 index 20f254b9..00000000 --- a/src/leap/soledad/u1db/tests/test_document.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- - -from u1db import errors, tests - - -class TestDocument(tests.TestCase): - - scenarios = ([( - 'py', {'make_document_for_test': tests.make_document_for_test})] + - tests.C_DATABASE_SCENARIOS) - - def test_create_doc(self): - doc = self.make_document('doc-id', 'uid:1', tests.simple_doc) - self.assertEqual('doc-id', doc.doc_id) - self.assertEqual('uid:1', doc.rev) - self.assertEqual(tests.simple_doc, doc.get_json()) - self.assertFalse(doc.has_conflicts) - - def test__repr__(self): - doc = self.make_document('doc-id', 'uid:1', tests.simple_doc) - self.assertEqual( - '%s(doc-id, uid:1, \'{"key": "value"}\')' - % (doc.__class__.__name__,), - repr(doc)) - - def test__repr__conflicted(self): - doc = self.make_document('doc-id', 'uid:1', tests.simple_doc, - has_conflicts=True) - self.assertEqual( - '%s(doc-id, uid:1, conflicted, \'{"key": "value"}\')' - % (doc.__class__.__name__,), - repr(doc)) - - def test__lt__(self): - doc_a = self.make_document('a', 'b', '{}') - doc_b = self.make_document('b', 'b', '{}') - self.assertTrue(doc_a < doc_b) - self.assertTrue(doc_b > doc_a) - doc_aa = self.make_document('a', 'a', '{}') - self.assertTrue(doc_aa < doc_a) - - def test__eq__(self): - doc_a = self.make_document('a', 'b', '{}') - doc_b = self.make_document('a', 'b', '{}') - self.assertTrue(doc_a == doc_b) - doc_b = self.make_document('a', 'b', '{}', has_conflicts=True) - self.assertFalse(doc_a == doc_b) - - def test_non_json_dict(self): - self.assertRaises( - errors.InvalidJSON, self.make_document, 'id', 'uid:1', - '"not a json dictionary"') - - def test_non_json(self): - self.assertRaises( - errors.InvalidJSON, self.make_document, 'id', 'uid:1', - 'not a json dictionary') - - def test_get_size(self): - doc_a = self.make_document('a', 'b', '{"some": "content"}') - self.assertEqual( - len('a' + 'b' + '{"some": "content"}'), doc_a.get_size()) - - def test_get_size_empty_document(self): - doc_a = self.make_document('a', 'b', None) - self.assertEqual(len('a' + 'b'), 
doc_a.get_size()) - - -class TestPyDocument(tests.TestCase): - - scenarios = ([( - 'py', {'make_document_for_test': tests.make_document_for_test})]) - - def test_get_content(self): - doc = self.make_document('id', 'rev', '{"content":""}') - self.assertEqual({"content": ""}, doc.content) - doc.set_json('{"content": "new"}') - self.assertEqual({"content": "new"}, doc.content) - - def test_set_content(self): - doc = self.make_document('id', 'rev', '{"content":""}') - doc.content = {"content": "new"} - self.assertEqual('{"content": "new"}', doc.get_json()) - - def test_set_bad_content(self): - doc = self.make_document('id', 'rev', '{"content":""}') - self.assertRaises( - errors.InvalidContent, setattr, doc, 'content', - '{"content": "new"}') - - def test_is_tombstone(self): - doc_a = self.make_document('a', 'b', '{}') - self.assertFalse(doc_a.is_tombstone()) - doc_a.set_json(None) - self.assertTrue(doc_a.is_tombstone()) - - def test_make_tombstone(self): - doc_a = self.make_document('a', 'b', '{}') - self.assertFalse(doc_a.is_tombstone()) - doc_a.make_tombstone() - self.assertTrue(doc_a.is_tombstone()) - - def test_same_content_as(self): - doc_a = self.make_document('a', 'b', '{}') - doc_b = self.make_document('d', 'e', '{}') - self.assertTrue(doc_a.same_content_as(doc_b)) - doc_b = self.make_document('p', 'q', '{}', has_conflicts=True) - self.assertTrue(doc_a.same_content_as(doc_b)) - doc_b.content['key'] = 'value' - self.assertFalse(doc_a.same_content_as(doc_b)) - - def test_same_content_as_json_order(self): - doc_a = self.make_document( - 'a', 'b', '{"key1": "val1", "key2": "val2"}') - doc_b = self.make_document( - 'c', 'd', '{"key2": "val2", "key1": "val1"}') - self.assertTrue(doc_a.same_content_as(doc_b)) - - def test_set_json(self): - doc = self.make_document('id', 'rev', '{"content":""}') - doc.set_json('{"content": "new"}') - self.assertEqual('{"content": "new"}', doc.get_json()) - - def test_set_json_non_dict(self): - doc = self.make_document('id', 'rev', 
'{"content":""}') - self.assertRaises(errors.InvalidJSON, doc.set_json, '"is not a dict"') - - def test_set_json_error(self): - doc = self.make_document('id', 'rev', '{"content":""}') - self.assertRaises(errors.InvalidJSON, doc.set_json, 'is not json') - - -load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_errors.py b/src/leap/soledad/u1db/tests/test_errors.py deleted file mode 100644 index 0e089ede..00000000 --- a/src/leap/soledad/u1db/tests/test_errors.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""Tests error infrastructure.""" - -from u1db import ( - errors, - tests, - ) - - -class TestError(tests.TestCase): - - def test_error_base(self): - err = errors.U1DBError() - self.assertEqual("error", err.wire_description) - self.assertIs(None, err.message) - - err = errors.U1DBError("Message.") - self.assertEqual("error", err.wire_description) - self.assertEqual("Message.", err.message) - - def test_HTTPError(self): - err = errors.HTTPError(500) - self.assertEqual(500, err.status) - self.assertIs(None, err.wire_description) - self.assertIs(None, err.message) - - err = errors.HTTPError(500, "Crash.") - self.assertEqual(500, err.status) - self.assertIs(None, err.wire_description) - self.assertEqual("Crash.", err.message) - - def test_HTTPError_str(self): - err = errors.HTTPError(500) - self.assertEqual("HTTPError(500)", str(err)) - - err = errors.HTTPError(500, "ERROR") - self.assertEqual("HTTPError(500, 'ERROR')", str(err)) - - def test_Unvailable(self): - err = errors.Unavailable() - self.assertEqual(503, err.status) - self.assertEqual("Unavailable()", str(err)) - - err = errors.Unavailable("DOWN") - self.assertEqual("DOWN", err.message) - self.assertEqual("Unavailable('DOWN')", str(err)) diff --git a/src/leap/soledad/u1db/tests/test_http_app.py b/src/leap/soledad/u1db/tests/test_http_app.py deleted file mode 100644 index 13522693..00000000 --- a/src/leap/soledad/u1db/tests/test_http_app.py +++ /dev/null @@ -1,1133 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. 
-# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Test the WSGI app.""" - -import paste.fixture -import sys -try: - import simplejson as json -except ImportError: - import json # noqa -import StringIO - -from u1db import ( - __version__ as _u1db_version, - errors, - sync, - tests, - ) - -from u1db.remote import ( - http_app, - http_errors, - ) - - -class TestFencedReader(tests.TestCase): - - def test_init(self): - reader = http_app._FencedReader(StringIO.StringIO(""), 25, 100) - self.assertEqual(25, reader.remaining) - - def test_read_chunk(self): - inp = StringIO.StringIO("abcdef") - reader = http_app._FencedReader(inp, 5, 10) - data = reader.read_chunk(2) - self.assertEqual("ab", data) - self.assertEqual(2, inp.tell()) - self.assertEqual(3, reader.remaining) - - def test_read_chunk_remaining(self): - inp = StringIO.StringIO("abcdef") - reader = http_app._FencedReader(inp, 4, 10) - data = reader.read_chunk(9999) - self.assertEqual("abcd", data) - self.assertEqual(4, inp.tell()) - self.assertEqual(0, reader.remaining) - - def test_read_chunk_nothing_left(self): - inp = StringIO.StringIO("abc") - reader = http_app._FencedReader(inp, 2, 10) - reader.read_chunk(2) - self.assertEqual(2, inp.tell()) - self.assertEqual(0, reader.remaining) - data = reader.read_chunk(2) - self.assertEqual("", data) - self.assertEqual(2, inp.tell()) - self.assertEqual(0, reader.remaining) - - def test_read_chunk_kept(self): - inp = StringIO.StringIO("abcde") - reader = http_app._FencedReader(inp, 4, 10) - reader._kept = "xyz" - data = reader.read_chunk(2) # atmost ignored - self.assertEqual("xyz", data) - self.assertEqual(0, inp.tell()) - self.assertEqual(4, reader.remaining) - self.assertIsNone(reader._kept) - - def test_getline(self): - inp = StringIO.StringIO("abc\r\nde") - reader = http_app._FencedReader(inp, 6, 10) - reader.MAXCHUNK = 6 - line = reader.getline() - self.assertEqual("abc\r\n", line) - 
self.assertEqual("d", reader._kept) - - def test_getline_exact(self): - inp = StringIO.StringIO("abcd\r\nef") - reader = http_app._FencedReader(inp, 6, 10) - reader.MAXCHUNK = 6 - line = reader.getline() - self.assertEqual("abcd\r\n", line) - self.assertIs(None, reader._kept) - - def test_getline_no_newline(self): - inp = StringIO.StringIO("abcd") - reader = http_app._FencedReader(inp, 4, 10) - reader.MAXCHUNK = 6 - line = reader.getline() - self.assertEqual("abcd", line) - - def test_getline_many_chunks(self): - inp = StringIO.StringIO("abcde\r\nf") - reader = http_app._FencedReader(inp, 8, 10) - reader.MAXCHUNK = 4 - line = reader.getline() - self.assertEqual("abcde\r\n", line) - self.assertEqual("f", reader._kept) - line = reader.getline() - self.assertEqual("f", line) - - def test_getline_empty(self): - inp = StringIO.StringIO("") - reader = http_app._FencedReader(inp, 0, 10) - reader.MAXCHUNK = 4 - line = reader.getline() - self.assertEqual("", line) - line = reader.getline() - self.assertEqual("", line) - - def test_getline_just_newline(self): - inp = StringIO.StringIO("\r\n") - reader = http_app._FencedReader(inp, 2, 10) - reader.MAXCHUNK = 4 - line = reader.getline() - self.assertEqual("\r\n", line) - line = reader.getline() - self.assertEqual("", line) - - def test_getline_too_large(self): - inp = StringIO.StringIO("x" * 50) - reader = http_app._FencedReader(inp, 50, 25) - reader.MAXCHUNK = 4 - self.assertRaises(http_app.BadRequest, reader.getline) - - def test_getline_too_large_complete(self): - inp = StringIO.StringIO("x" * 25 + "\r\n") - reader = http_app._FencedReader(inp, 50, 25) - reader.MAXCHUNK = 4 - self.assertRaises(http_app.BadRequest, reader.getline) - - -class TestHTTPMethodDecorator(tests.TestCase): - - def test_args(self): - @http_app.http_method() - def f(self, a, b): - return self, a, b - res = f("self", {"a": "x", "b": "y"}, None) - self.assertEqual(("self", "x", "y"), res) - - def test_args_missing(self): - @http_app.http_method() - def 
f(self, a, b): - return a, b - self.assertRaises(http_app.BadRequest, f, "self", {"a": "x"}, None) - - def test_args_unexpected(self): - @http_app.http_method() - def f(self, a): - return a - self.assertRaises(http_app.BadRequest, f, "self", - {"a": "x", "c": "z"}, None) - - def test_args_default(self): - @http_app.http_method() - def f(self, a, b="z"): - return a, b - res = f("self", {"a": "x"}, None) - self.assertEqual(("x", "z"), res) - - def test_args_conversion(self): - @http_app.http_method(b=int) - def f(self, a, b): - return self, a, b - res = f("self", {"a": "x", "b": "2"}, None) - self.assertEqual(("self", "x", 2), res) - - self.assertRaises(http_app.BadRequest, f, "self", - {"a": "x", "b": "foo"}, None) - - def test_args_conversion_with_default(self): - @http_app.http_method(b=str) - def f(self, a, b=None): - return self, a, b - res = f("self", {"a": "x"}, None) - self.assertEqual(("self", "x", None), res) - - def test_args_content(self): - @http_app.http_method() - def f(self, a, content): - return a, content - res = f(self, {"a": "x"}, "CONTENT") - self.assertEqual(("x", "CONTENT"), res) - - def test_args_content_as_args(self): - @http_app.http_method(b=int, content_as_args=True) - def f(self, a, b): - return self, a, b - res = f("self", {"a": "x"}, '{"b": "2"}') - self.assertEqual(("self", "x", 2), res) - - self.assertRaises(http_app.BadRequest, f, "self", {}, 'not-json') - - def test_args_content_no_query(self): - @http_app.http_method(no_query=True, - content_as_args=True) - def f(self, a='a', b='b'): - return a, b - res = f("self", {}, '{"b": "y"}') - self.assertEqual(('a', 'y'), res) - - self.assertRaises(http_app.BadRequest, f, "self", {'a': 'x'}, - '{"b": "y"}') - - -class TestResource(object): - - @http_app.http_method() - def get(self, a, b): - self.args = dict(a=a, b=b) - return 'Get' - - @http_app.http_method() - def put(self, a, content): - self.args = dict(a=a) - self.content = content - return 'Put' - - 
@http_app.http_method(content_as_args=True) - def put_args(self, a, b): - self.args = dict(a=a, b=b) - self.order = ['a'] - self.entries = [] - - @http_app.http_method() - def put_stream_entry(self, content): - self.entries.append(content) - self.order.append('s') - - def put_end(self): - self.order.append('e') - return "Put/end" - - -class parameters: - max_request_size = 200000 - max_entry_size = 100000 - - -class TestHTTPInvocationByMethodWithBody(tests.TestCase): - - def test_get(self): - resource = TestResource() - environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'GET'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - res = invoke() - self.assertEqual('Get', res) - self.assertEqual({'a': '1', 'b': '2'}, resource.args) - - def test_put_json(self): - resource = TestResource() - body = '{"body": true}' - environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT', - 'wsgi.input': StringIO.StringIO(body), - 'CONTENT_LENGTH': str(len(body)), - 'CONTENT_TYPE': 'application/json'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - res = invoke() - self.assertEqual('Put', res) - self.assertEqual({'a': '1'}, resource.args) - self.assertEqual('{"body": true}', resource.content) - - def test_put_sync_stream(self): - resource = TestResource() - body = ( - '[\r\n' - '{"b": 2},\r\n' # args - '{"entry": "x"},\r\n' # stream entry - '{"entry": "y"}\r\n' # stream entry - ']' - ) - environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT', - 'wsgi.input': StringIO.StringIO(body), - 'CONTENT_LENGTH': str(len(body)), - 'CONTENT_TYPE': 'application/x-u1db-sync-stream'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - res = invoke() - self.assertEqual('Put/end', res) - self.assertEqual({'a': '1', 'b': 2}, resource.args) - self.assertEqual( - ['{"entry": "x"}', '{"entry": "y"}'], resource.entries) - self.assertEqual(['a', 's', 's', 'e'], resource.order) - - def 
_put_sync_stream(self, body): - resource = TestResource() - environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'PUT', - 'wsgi.input': StringIO.StringIO(body), - 'CONTENT_LENGTH': str(len(body)), - 'CONTENT_TYPE': 'application/x-u1db-sync-stream'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - invoke() - - def test_put_sync_stream_wrong_start(self): - self.assertRaises(http_app.BadRequest, - self._put_sync_stream, "{}\r\n]") - - self.assertRaises(http_app.BadRequest, - self._put_sync_stream, "\r\n{}\r\n]") - - self.assertRaises(http_app.BadRequest, - self._put_sync_stream, "") - - def test_put_sync_stream_wrong_end(self): - self.assertRaises(http_app.BadRequest, - self._put_sync_stream, "[\r\n{}") - - self.assertRaises(http_app.BadRequest, - self._put_sync_stream, "[\r\n") - - self.assertRaises(http_app.BadRequest, - self._put_sync_stream, "[\r\n{}\r\n]\r\n...") - - def test_put_sync_stream_missing_comma(self): - self.assertRaises(http_app.BadRequest, - self._put_sync_stream, "[\r\n{}\r\n{}\r\n]") - - def test_put_sync_stream_extra_comma(self): - self.assertRaises(http_app.BadRequest, - self._put_sync_stream, "[\r\n{},\r\n]") - - self.assertRaises(http_app.BadRequest, - self._put_sync_stream, "[\r\n{},\r\n{},\r\n]") - - def test_bad_request_decode_failure(self): - resource = TestResource() - environ = {'QUERY_STRING': 'a=\xff', 'REQUEST_METHOD': 'PUT', - 'wsgi.input': StringIO.StringIO('{}'), - 'CONTENT_LENGTH': '2', - 'CONTENT_TYPE': 'application/json'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - self.assertRaises(http_app.BadRequest, invoke) - - def test_bad_request_unsupported_content_type(self): - resource = TestResource() - environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', - 'wsgi.input': StringIO.StringIO('{}'), - 'CONTENT_LENGTH': '2', - 'CONTENT_TYPE': 'text/plain'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - 
self.assertRaises(http_app.BadRequest, invoke) - - def test_bad_request_content_length_too_large(self): - resource = TestResource() - environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', - 'wsgi.input': StringIO.StringIO('{}'), - 'CONTENT_LENGTH': '10000', - 'CONTENT_TYPE': 'text/plain'} - - resource.max_request_size = 5000 - resource.max_entry_size = sys.maxint # we don't get to use this - - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - self.assertRaises(http_app.BadRequest, invoke) - - def test_bad_request_no_content_length(self): - resource = TestResource() - environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', - 'wsgi.input': StringIO.StringIO('a'), - 'CONTENT_TYPE': 'application/json'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - self.assertRaises(http_app.BadRequest, invoke) - - def test_bad_request_invalid_content_length(self): - resource = TestResource() - environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', - 'wsgi.input': StringIO.StringIO('abc'), - 'CONTENT_LENGTH': '1unk', - 'CONTENT_TYPE': 'application/json'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - self.assertRaises(http_app.BadRequest, invoke) - - def test_bad_request_empty_body(self): - resource = TestResource() - environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', - 'wsgi.input': StringIO.StringIO(''), - 'CONTENT_LENGTH': '0', - 'CONTENT_TYPE': 'application/json'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - self.assertRaises(http_app.BadRequest, invoke) - - def test_bad_request_unsupported_method_get_like(self): - resource = TestResource() - environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'DELETE'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - self.assertRaises(http_app.BadRequest, invoke) - - def test_bad_request_unsupported_method_put_like(self): - resource = TestResource() - 
environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', - 'wsgi.input': StringIO.StringIO('{}'), - 'CONTENT_LENGTH': '2', - 'CONTENT_TYPE': 'application/json'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - self.assertRaises(http_app.BadRequest, invoke) - - def test_bad_request_unsupported_method_put_like_multi_json(self): - resource = TestResource() - body = '{}\r\n{}\r\n' - environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'POST', - 'wsgi.input': StringIO.StringIO(body), - 'CONTENT_LENGTH': str(len(body)), - 'CONTENT_TYPE': 'application/x-u1db-multi-json'} - invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, - parameters) - self.assertRaises(http_app.BadRequest, invoke) - - -class TestHTTPResponder(tests.TestCase): - - def start_response(self, status, headers): - self.status = status - self.headers = dict(headers) - self.response_body = [] - - def write(data): - self.response_body.append(data) - - return write - - def test_send_response_content_w_headers(self): - responder = http_app.HTTPResponder(self.start_response) - responder.send_response_content('foo', headers={'x-a': '1'}) - self.assertEqual('200 OK', self.status) - self.assertEqual({'content-type': 'application/json', - 'cache-control': 'no-cache', - 'x-a': '1', 'content-length': '3'}, self.headers) - self.assertEqual([], self.response_body) - self.assertEqual(['foo'], responder.content) - - def test_send_response_json(self): - responder = http_app.HTTPResponder(self.start_response) - responder.send_response_json(value='success') - self.assertEqual('200 OK', self.status) - expected_body = '{"value": "success"}\r\n' - self.assertEqual({'content-type': 'application/json', - 'content-length': str(len(expected_body)), - 'cache-control': 'no-cache'}, self.headers) - self.assertEqual([], self.response_body) - self.assertEqual([expected_body], responder.content) - - def test_send_response_json_status_fail(self): - responder = 
http_app.HTTPResponder(self.start_response) - responder.send_response_json(400) - self.assertEqual('400 Bad Request', self.status) - expected_body = '{}\r\n' - self.assertEqual({'content-type': 'application/json', - 'content-length': str(len(expected_body)), - 'cache-control': 'no-cache'}, self.headers) - self.assertEqual([], self.response_body) - self.assertEqual([expected_body], responder.content) - - def test_start_finish_response_status_fail(self): - responder = http_app.HTTPResponder(self.start_response) - responder.start_response(404, {'error': 'not found'}) - responder.finish_response() - self.assertEqual('404 Not Found', self.status) - self.assertEqual({'content-type': 'application/json', - 'cache-control': 'no-cache'}, self.headers) - self.assertEqual(['{"error": "not found"}\r\n'], self.response_body) - self.assertEqual([], responder.content) - - def test_send_stream_entry(self): - responder = http_app.HTTPResponder(self.start_response) - responder.content_type = "application/x-u1db-multi-json" - responder.start_response(200) - responder.start_stream() - responder.stream_entry({'entry': 1}) - responder.stream_entry({'entry': 2}) - responder.end_stream() - responder.finish_response() - self.assertEqual('200 OK', self.status) - self.assertEqual({'content-type': 'application/x-u1db-multi-json', - 'cache-control': 'no-cache'}, self.headers) - self.assertEqual(['[', - '\r\n', '{"entry": 1}', - ',\r\n', '{"entry": 2}', - '\r\n]\r\n'], self.response_body) - self.assertEqual([], responder.content) - - def test_send_stream_w_error(self): - responder = http_app.HTTPResponder(self.start_response) - responder.content_type = "application/x-u1db-multi-json" - responder.start_response(200) - responder.start_stream() - responder.stream_entry({'entry': 1}) - responder.send_response_json(503, error="unavailable") - self.assertEqual('200 OK', self.status) - self.assertEqual({'content-type': 'application/x-u1db-multi-json', - 'cache-control': 'no-cache'}, self.headers) - 
self.assertEqual(['[', - '\r\n', '{"entry": 1}'], self.response_body) - self.assertEqual([',\r\n', '{"error": "unavailable"}\r\n'], - responder.content) - - -class TestHTTPApp(tests.TestCase): - - def setUp(self): - super(TestHTTPApp, self).setUp() - self.state = tests.ServerStateForTests() - self.http_app = http_app.HTTPApp(self.state) - self.app = paste.fixture.TestApp(self.http_app) - self.db0 = self.state._create_database('db0') - - def test_bad_request_broken(self): - resp = self.app.put('/db0/doc/doc1', params='{"x": 1}', - headers={'content-type': 'application/foo'}, - expect_errors=True) - self.assertEqual(400, resp.status) - - def test_bad_request_dispatch(self): - resp = self.app.put('/db0/foo/doc1', params='{"x": 1}', - headers={'content-type': 'application/json'}, - expect_errors=True) - self.assertEqual(400, resp.status) - - def test_version(self): - resp = self.app.get('/') - self.assertEqual(200, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual({"version": _u1db_version}, json.loads(resp.body)) - - def test_create_database(self): - resp = self.app.put('/db1', params='{}', - headers={'content-type': 'application/json'}) - self.assertEqual(200, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual({'ok': True}, json.loads(resp.body)) - - resp = self.app.put('/db1', params='{}', - headers={'content-type': 'application/json'}) - self.assertEqual(200, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual({'ok': True}, json.loads(resp.body)) - - def test_delete_database(self): - resp = self.app.delete('/db0') - self.assertEqual(200, resp.status) - self.assertRaises(errors.DatabaseDoesNotExist, - self.state.check_database, 'db0') - - def test_get_database(self): - resp = self.app.get('/db0') - self.assertEqual(200, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual({}, 
json.loads(resp.body)) - - def test_valid_database_names(self): - resp = self.app.get('/a-database', expect_errors=True) - self.assertEqual(404, resp.status) - - resp = self.app.get('/db1', expect_errors=True) - self.assertEqual(404, resp.status) - - resp = self.app.get('/0', expect_errors=True) - self.assertEqual(404, resp.status) - - resp = self.app.get('/0-0', expect_errors=True) - self.assertEqual(404, resp.status) - - resp = self.app.get('/org.future', expect_errors=True) - self.assertEqual(404, resp.status) - - def test_invalid_database_names(self): - resp = self.app.get('/.a', expect_errors=True) - self.assertEqual(400, resp.status) - - resp = self.app.get('/-a', expect_errors=True) - self.assertEqual(400, resp.status) - - resp = self.app.get('/_a', expect_errors=True) - self.assertEqual(400, resp.status) - - def test_put_doc_create(self): - resp = self.app.put('/db0/doc/doc1', params='{"x": 1}', - headers={'content-type': 'application/json'}) - doc = self.db0.get_doc('doc1') - self.assertEqual(201, resp.status) # created - self.assertEqual('{"x": 1}', doc.get_json()) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual({'rev': doc.rev}, json.loads(resp.body)) - - def test_put_doc(self): - doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') - resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev, - params='{"x": 2}', - headers={'content-type': 'application/json'}) - doc = self.db0.get_doc('doc1') - self.assertEqual(200, resp.status) - self.assertEqual('{"x": 2}', doc.get_json()) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual({'rev': doc.rev}, json.loads(resp.body)) - - def test_put_doc_too_large(self): - self.http_app.max_request_size = 15000 - doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') - resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev, - params='{"%s": 2}' % ('z' * 16000), - headers={'content-type': 'application/json'}, - expect_errors=True) 
- self.assertEqual(400, resp.status) - - def test_delete_doc(self): - doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') - resp = self.app.delete('/db0/doc/doc1?old_rev=%s' % doc.rev) - doc = self.db0.get_doc('doc1', include_deleted=True) - self.assertEqual(None, doc.content) - self.assertEqual(200, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual({'rev': doc.rev}, json.loads(resp.body)) - - def test_get_doc(self): - doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') - resp = self.app.get('/db0/doc/%s' % doc.doc_id) - self.assertEqual(200, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual('{"x": 1}', resp.body) - self.assertEqual(doc.rev, resp.header('x-u1db-rev')) - self.assertEqual('false', resp.header('x-u1db-has-conflicts')) - - def test_get_doc_non_existing(self): - resp = self.app.get('/db0/doc/not-there', expect_errors=True) - self.assertEqual(404, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual( - {"error": "document does not exist"}, json.loads(resp.body)) - self.assertEqual('', resp.header('x-u1db-rev')) - self.assertEqual('false', resp.header('x-u1db-has-conflicts')) - - def test_get_doc_deleted(self): - doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') - self.db0.delete_doc(doc) - resp = self.app.get('/db0/doc/doc1', expect_errors=True) - self.assertEqual(404, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual( - {"error": errors.DocumentDoesNotExist.wire_description}, - json.loads(resp.body)) - - def test_get_doc_deleted_explicit_exclude(self): - doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') - self.db0.delete_doc(doc) - resp = self.app.get( - '/db0/doc/doc1?include_deleted=false', expect_errors=True) - self.assertEqual(404, resp.status) - self.assertEqual('application/json', 
resp.header('content-type')) - self.assertEqual( - {"error": errors.DocumentDoesNotExist.wire_description}, - json.loads(resp.body)) - - def test_get_deleted_doc(self): - doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') - self.db0.delete_doc(doc) - resp = self.app.get( - '/db0/doc/doc1?include_deleted=true', expect_errors=True) - self.assertEqual(404, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual( - {"error": errors.DOCUMENT_DELETED}, json.loads(resp.body)) - self.assertEqual(doc.rev, resp.header('x-u1db-rev')) - self.assertEqual('false', resp.header('x-u1db-has-conflicts')) - - def test_get_doc_non_existing_dabase(self): - resp = self.app.get('/not-there/doc/doc1', expect_errors=True) - self.assertEqual(404, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual( - {"error": "database does not exist"}, json.loads(resp.body)) - - def test_get_docs(self): - doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') - doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') - ids = ','.join([doc1.doc_id, doc2.doc_id]) - resp = self.app.get('/db0/docs?doc_ids=%s' % ids) - self.assertEqual(200, resp.status) - self.assertEqual( - 'application/json', resp.header('content-type')) - expected = [ - {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1", - "has_conflicts": False}, - {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2", - "has_conflicts": False}] - self.assertEqual(expected, json.loads(resp.body)) - - def test_get_docs_missing_doc_ids(self): - resp = self.app.get('/db0/docs', expect_errors=True) - self.assertEqual(400, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual( - {"error": "missing document ids"}, json.loads(resp.body)) - - def test_get_docs_empty_doc_ids(self): - resp = self.app.get('/db0/docs?doc_ids=', expect_errors=True) - self.assertEqual(400, resp.status) - 
self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual( - {"error": "missing document ids"}, json.loads(resp.body)) - - def test_get_docs_percent(self): - doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc%1') - doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') - ids = ','.join([doc1.doc_id, doc2.doc_id]) - resp = self.app.get('/db0/docs?doc_ids=%s' % ids) - self.assertEqual(200, resp.status) - self.assertEqual( - 'application/json', resp.header('content-type')) - expected = [ - {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc%1", - "has_conflicts": False}, - {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2", - "has_conflicts": False}] - self.assertEqual(expected, json.loads(resp.body)) - - def test_get_docs_deleted(self): - doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') - doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') - self.db0.delete_doc(doc2) - ids = ','.join([doc1.doc_id, doc2.doc_id]) - resp = self.app.get('/db0/docs?doc_ids=%s' % ids) - self.assertEqual(200, resp.status) - self.assertEqual( - 'application/json', resp.header('content-type')) - expected = [ - {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1", - "has_conflicts": False}] - self.assertEqual(expected, json.loads(resp.body)) - - def test_get_docs_include_deleted(self): - doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') - doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') - self.db0.delete_doc(doc2) - ids = ','.join([doc1.doc_id, doc2.doc_id]) - resp = self.app.get('/db0/docs?doc_ids=%s&include_deleted=true' % ids) - self.assertEqual(200, resp.status) - self.assertEqual( - 'application/json', resp.header('content-type')) - expected = [ - {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1", - "has_conflicts": False}, - {"content": None, "doc_rev": "db0:2", "doc_id": "doc2", - "has_conflicts": False}] - self.assertEqual(expected, 
json.loads(resp.body)) - - def test_get_sync_info(self): - self.db0._set_replica_gen_and_trans_id('other-id', 1, 'T-transid') - resp = self.app.get('/db0/sync-from/other-id') - self.assertEqual(200, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual(dict(target_replica_uid='db0', - target_replica_generation=0, - target_replica_transaction_id='', - source_replica_uid='other-id', - source_replica_generation=1, - source_transaction_id='T-transid'), - json.loads(resp.body)) - - def test_record_sync_info(self): - resp = self.app.put('/db0/sync-from/other-id', - params='{"generation": 2, "transaction_id": "T-transid"}', - headers={'content-type': 'application/json'}) - self.assertEqual(200, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual({'ok': True}, json.loads(resp.body)) - self.assertEqual( - (2, 'T-transid'), - self.db0._get_replica_gen_and_trans_id('other-id')) - - def test_sync_exchange_send(self): - entries = { - 10: {'id': 'doc-here', 'rev': 'replica:1', 'content': - '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'}, - 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content': - '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'} - } - - gens = [] - _do_set_replica_gen_and_trans_id = \ - self.db0._do_set_replica_gen_and_trans_id - - def set_sync_generation_witness(other_uid, other_gen, other_trans_id): - gens.append((other_uid, other_gen)) - _do_set_replica_gen_and_trans_id( - other_uid, other_gen, other_trans_id) - self.assertGetDoc(self.db0, entries[other_gen]['id'], - entries[other_gen]['rev'], - entries[other_gen]['content'], False) - - self.patch( - self.db0, '_do_set_replica_gen_and_trans_id', - set_sync_generation_witness) - - args = dict(last_known_generation=0) - body = ("[\r\n" + - "%s,\r\n" % json.dumps(args) + - "%s,\r\n" % json.dumps(entries[10]) + - "%s\r\n" % json.dumps(entries[11]) + - "]\r\n") - resp = self.app.post('/db0/sync-from/replica', - 
params=body, - headers={'content-type': - 'application/x-u1db-sync-stream'}) - self.assertEqual(200, resp.status) - self.assertEqual('application/x-u1db-sync-stream', - resp.header('content-type')) - bits = resp.body.split('\r\n') - self.assertEqual('[', bits[0]) - last_trans_id = self.db0._get_transaction_log()[-1][1] - self.assertEqual({'new_generation': 2, - 'new_transaction_id': last_trans_id}, - json.loads(bits[1])) - self.assertEqual(']', bits[2]) - self.assertEqual('', bits[3]) - self.assertEqual([('replica', 10), ('replica', 11)], gens) - - def test_sync_exchange_send_ensure(self): - entries = { - 10: {'id': 'doc-here', 'rev': 'replica:1', 'content': - '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'}, - 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content': - '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'} - } - - args = dict(last_known_generation=0, ensure=True) - body = ("[\r\n" + - "%s,\r\n" % json.dumps(args) + - "%s,\r\n" % json.dumps(entries[10]) + - "%s\r\n" % json.dumps(entries[11]) + - "]\r\n") - resp = self.app.post('/dbnew/sync-from/replica', - params=body, - headers={'content-type': - 'application/x-u1db-sync-stream'}) - self.assertEqual(200, resp.status) - self.assertEqual('application/x-u1db-sync-stream', - resp.header('content-type')) - bits = resp.body.split('\r\n') - self.assertEqual('[', bits[0]) - dbnew = self.state.open_database("dbnew") - last_trans_id = dbnew._get_transaction_log()[-1][1] - self.assertEqual({'new_generation': 2, - 'new_transaction_id': last_trans_id, - 'replica_uid': dbnew._replica_uid}, - json.loads(bits[1])) - self.assertEqual(']', bits[2]) - self.assertEqual('', bits[3]) - - def test_sync_exchange_send_entry_too_large(self): - self.patch(http_app.SyncResource, 'max_request_size', 20000) - self.patch(http_app.SyncResource, 'max_entry_size', 10000) - entries = { - 10: {'id': 'doc-here', 'rev': 'replica:1', 'content': - '{"value": "%s"}' % ('H' * 11000), 'gen': 10}, - } - args = dict(last_known_generation=0) 
- body = ("[\r\n" + - "%s,\r\n" % json.dumps(args) + - "%s\r\n" % json.dumps(entries[10]) + - "]\r\n") - resp = self.app.post('/db0/sync-from/replica', - params=body, - headers={'content-type': - 'application/x-u1db-sync-stream'}, - expect_errors=True) - self.assertEqual(400, resp.status) - - def test_sync_exchange_receive(self): - doc = self.db0.create_doc_from_json('{"value": "there"}') - doc2 = self.db0.create_doc_from_json('{"value": "there2"}') - args = dict(last_known_generation=0) - body = "[\r\n%s\r\n]" % json.dumps(args) - resp = self.app.post('/db0/sync-from/replica', - params=body, - headers={'content-type': - 'application/x-u1db-sync-stream'}) - self.assertEqual(200, resp.status) - self.assertEqual('application/x-u1db-sync-stream', - resp.header('content-type')) - parts = resp.body.splitlines() - self.assertEqual(5, len(parts)) - self.assertEqual('[', parts[0]) - last_trans_id = self.db0._get_transaction_log()[-1][1] - self.assertEqual({'new_generation': 2, - 'new_transaction_id': last_trans_id}, - json.loads(parts[1].rstrip(","))) - part2 = json.loads(parts[2].rstrip(",")) - self.assertTrue(part2['trans_id'].startswith('T-')) - self.assertEqual('{"value": "there"}', part2['content']) - self.assertEqual(doc.rev, part2['rev']) - self.assertEqual(doc.doc_id, part2['id']) - self.assertEqual(1, part2['gen']) - part3 = json.loads(parts[3].rstrip(",")) - self.assertTrue(part3['trans_id'].startswith('T-')) - self.assertEqual('{"value": "there2"}', part3['content']) - self.assertEqual(doc2.rev, part3['rev']) - self.assertEqual(doc2.doc_id, part3['id']) - self.assertEqual(2, part3['gen']) - self.assertEqual(']', parts[4]) - - def test_sync_exchange_error_in_stream(self): - args = dict(last_known_generation=0) - body = "[\r\n%s\r\n]" % json.dumps(args) - - def boom(self, return_doc_cb): - raise errors.Unavailable - - self.patch(sync.SyncExchange, 'return_docs', - boom) - resp = self.app.post('/db0/sync-from/replica', - params=body, - headers={'content-type': - 
'application/x-u1db-sync-stream'}) - self.assertEqual(200, resp.status) - self.assertEqual('application/x-u1db-sync-stream', - resp.header('content-type')) - parts = resp.body.splitlines() - self.assertEqual(3, len(parts)) - self.assertEqual('[', parts[0]) - self.assertEqual({'new_generation': 0, 'new_transaction_id': ''}, - json.loads(parts[1].rstrip(","))) - self.assertEqual({'error': 'unavailable'}, json.loads(parts[2])) - - -class TestRequestHooks(tests.TestCase): - - def setUp(self): - super(TestRequestHooks, self).setUp() - self.state = tests.ServerStateForTests() - self.http_app = http_app.HTTPApp(self.state) - self.app = paste.fixture.TestApp(self.http_app) - self.db0 = self.state._create_database('db0') - - def test_begin_and_done(self): - calls = [] - - def begin(environ): - self.assertTrue('PATH_INFO' in environ) - calls.append('begin') - - def done(environ): - self.assertTrue('PATH_INFO' in environ) - calls.append('done') - - self.http_app.request_begin = begin - self.http_app.request_done = done - - doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') - self.app.get('/db0/doc/%s' % doc.doc_id) - - self.assertEqual(['begin', 'done'], calls) - - def test_bad_request(self): - calls = [] - - def begin(environ): - self.assertTrue('PATH_INFO' in environ) - calls.append('begin') - - def bad_request(environ): - self.assertTrue('PATH_INFO' in environ) - calls.append('bad-request') - - self.http_app.request_begin = begin - self.http_app.request_bad_request = bad_request - # shouldn't be called - self.http_app.request_done = lambda env: 1 / 0 - - resp = self.app.put('/db0/foo/doc1', params='{"x": 1}', - headers={'content-type': 'application/json'}, - expect_errors=True) - self.assertEqual(400, resp.status) - self.assertEqual(['begin', 'bad-request'], calls) - - -class TestHTTPErrors(tests.TestCase): - - def test_wire_description_to_status(self): - self.assertNotIn("error", http_errors.wire_description_to_status) - - -class 
TestHTTPAppErrorHandling(tests.TestCase): - - def setUp(self): - super(TestHTTPAppErrorHandling, self).setUp() - self.exc = None - self.state = tests.ServerStateForTests() - - class ErroringResource(object): - - def post(_, args, content): - raise self.exc - - def lookup_resource(environ, responder): - return ErroringResource() - - self.http_app = http_app.HTTPApp(self.state) - self.http_app._lookup_resource = lookup_resource - self.app = paste.fixture.TestApp(self.http_app) - - def test_RevisionConflict_etc(self): - self.exc = errors.RevisionConflict() - resp = self.app.post('/req', params='{}', - headers={'content-type': 'application/json'}, - expect_errors=True) - self.assertEqual(409, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual({"error": "revision conflict"}, - json.loads(resp.body)) - - def test_Unavailable(self): - self.exc = errors.Unavailable - resp = self.app.post('/req', params='{}', - headers={'content-type': 'application/json'}, - expect_errors=True) - self.assertEqual(503, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual({"error": "unavailable"}, - json.loads(resp.body)) - - def test_generic_u1db_errors(self): - self.exc = errors.U1DBError() - resp = self.app.post('/req', params='{}', - headers={'content-type': 'application/json'}, - expect_errors=True) - self.assertEqual(500, resp.status) - self.assertEqual('application/json', resp.header('content-type')) - self.assertEqual({"error": "error"}, - json.loads(resp.body)) - - def test_generic_u1db_errors_hooks(self): - calls = [] - - def begin(environ): - self.assertTrue('PATH_INFO' in environ) - calls.append('begin') - - def u1db_error(environ, exc): - self.assertTrue('PATH_INFO' in environ) - calls.append(('error', exc)) - - self.http_app.request_begin = begin - self.http_app.request_u1db_error = u1db_error - # shouldn't be called - self.http_app.request_done = lambda env: 1 / 0 - - self.exc = 
errors.U1DBError() - resp = self.app.post('/req', params='{}', - headers={'content-type': 'application/json'}, - expect_errors=True) - self.assertEqual(500, resp.status) - self.assertEqual(['begin', ('error', self.exc)], calls) - - def test_failure(self): - class Failure(Exception): - pass - self.exc = Failure() - self.assertRaises(Failure, self.app.post, '/req', params='{}', - headers={'content-type': 'application/json'}) - - def test_failure_hooks(self): - class Failure(Exception): - pass - calls = [] - - def begin(environ): - calls.append('begin') - - def failed(environ): - self.assertTrue('PATH_INFO' in environ) - calls.append(('failed', sys.exc_info())) - - self.http_app.request_begin = begin - self.http_app.request_failed = failed - # shouldn't be called - self.http_app.request_done = lambda env: 1 / 0 - - self.exc = Failure() - self.assertRaises(Failure, self.app.post, '/req', params='{}', - headers={'content-type': 'application/json'}) - - self.assertEqual(2, len(calls)) - self.assertEqual('begin', calls[0]) - marker, (exc_type, exc, tb) = calls[1] - self.assertEqual('failed', marker) - self.assertEqual(self.exc, exc) - - -class TestPluggableSyncExchange(tests.TestCase): - - def setUp(self): - super(TestPluggableSyncExchange, self).setUp() - self.state = tests.ServerStateForTests() - self.state.ensure_database('foo') - - def test_plugging(self): - - class MySyncExchange(object): - def __init__(self, db, source_replica_uid, last_known_generation): - pass - - class MySyncResource(http_app.SyncResource): - sync_exchange_class = MySyncExchange - - sync_res = MySyncResource('foo', 'src', self.state, None) - sync_res.post_args( - {'last_known_generation': 0, 'last_known_trans_id': None}, '{}') - self.assertIsInstance(sync_res.sync_exch, MySyncExchange) diff --git a/src/leap/soledad/u1db/tests/test_http_client.py b/src/leap/soledad/u1db/tests/test_http_client.py deleted file mode 100644 index 115c8aaa..00000000 --- a/src/leap/soledad/u1db/tests/test_http_client.py 
+++ /dev/null @@ -1,361 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Tests for HTTPDatabase""" - -from oauth import oauth -try: - import simplejson as json -except ImportError: - import json # noqa - -from u1db import ( - errors, - tests, - ) -from u1db.remote import ( - http_client, - ) - - -class TestEncoder(tests.TestCase): - - def test_encode_string(self): - self.assertEqual("foo", http_client._encode_query_parameter("foo")) - - def test_encode_true(self): - self.assertEqual("true", http_client._encode_query_parameter(True)) - - def test_encode_false(self): - self.assertEqual("false", http_client._encode_query_parameter(False)) - - -class TestHTTPClientBase(tests.TestCaseWithServer): - - def setUp(self): - super(TestHTTPClientBase, self).setUp() - self.errors = 0 - - def app(self, environ, start_response): - if environ['PATH_INFO'].endswith('echo'): - start_response("200 OK", [('Content-Type', 'application/json')]) - ret = {} - for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'): - ret[name] = environ[name] - if environ['REQUEST_METHOD'] in ('PUT', 'POST'): - ret['CONTENT_TYPE'] = environ['CONTENT_TYPE'] - content_length = int(environ['CONTENT_LENGTH']) - ret['body'] = environ['wsgi.input'].read(content_length) - return [json.dumps(ret)] - elif environ['PATH_INFO'].endswith('error_then_accept'): - if self.errors >= 3: - start_response( - "200 OK", [('Content-Type', 
'application/json')]) - ret = {} - for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'): - ret[name] = environ[name] - if environ['REQUEST_METHOD'] in ('PUT', 'POST'): - ret['CONTENT_TYPE'] = environ['CONTENT_TYPE'] - content_length = int(environ['CONTENT_LENGTH']) - ret['body'] = '{"oki": "doki"}' - return [json.dumps(ret)] - self.errors += 1 - content_length = int(environ['CONTENT_LENGTH']) - error = json.loads( - environ['wsgi.input'].read(content_length)) - response = error['response'] - # In debug mode, wsgiref has an assertion that the status parameter - # is a 'str' object. However error['status'] returns a unicode - # object. - status = str(error['status']) - if isinstance(response, unicode): - response = str(response) - if isinstance(response, str): - start_response(status, [('Content-Type', 'text/plain')]) - return [str(response)] - else: - start_response(status, [('Content-Type', 'application/json')]) - return [json.dumps(response)] - elif environ['PATH_INFO'].endswith('error'): - self.errors += 1 - content_length = int(environ['CONTENT_LENGTH']) - error = json.loads( - environ['wsgi.input'].read(content_length)) - response = error['response'] - # In debug mode, wsgiref has an assertion that the status parameter - # is a 'str' object. However error['status'] returns a unicode - # object. 
- status = str(error['status']) - if isinstance(response, unicode): - response = str(response) - if isinstance(response, str): - start_response(status, [('Content-Type', 'text/plain')]) - return [str(response)] - else: - start_response(status, [('Content-Type', 'application/json')]) - return [json.dumps(response)] - elif '/oauth' in environ['PATH_INFO']: - base_url = self.getURL('').rstrip('/') - oauth_req = oauth.OAuthRequest.from_request( - http_method=environ['REQUEST_METHOD'], - http_url=base_url + environ['PATH_INFO'], - headers={'Authorization': environ['HTTP_AUTHORIZATION']}, - query_string=environ['QUERY_STRING'] - ) - oauth_server = oauth.OAuthServer(tests.testingOAuthStore) - oauth_server.add_signature_method(tests.sign_meth_HMAC_SHA1) - try: - consumer, token, params = oauth_server.verify_request( - oauth_req) - except oauth.OAuthError, e: - start_response("401 Unauthorized", - [('Content-Type', 'application/json')]) - return [json.dumps({"error": "unauthorized", - "message": e.message})] - start_response("200 OK", [('Content-Type', 'application/json')]) - return [json.dumps([environ['PATH_INFO'], token.key, params])] - - def make_app(self): - return self.app - - def getClient(self, **kwds): - self.startServer() - return http_client.HTTPClientBase(self.getURL('dbase'), **kwds) - - def test_construct(self): - self.startServer() - url = self.getURL() - cli = http_client.HTTPClientBase(url) - self.assertEqual(url, cli._url.geturl()) - self.assertIs(None, cli._conn) - - def test_parse_url(self): - cli = http_client.HTTPClientBase( - '%s://127.0.0.1:12345/' % self.url_scheme) - self.assertEqual(self.url_scheme, cli._url.scheme) - self.assertEqual('127.0.0.1', cli._url.hostname) - self.assertEqual(12345, cli._url.port) - self.assertEqual('/', cli._url.path) - - def test__ensure_connection(self): - cli = self.getClient() - self.assertIs(None, cli._conn) - cli._ensure_connection() - self.assertIsNot(None, cli._conn) - conn = cli._conn - cli._ensure_connection() 
- self.assertIs(conn, cli._conn) - - def test_close(self): - cli = self.getClient() - cli._ensure_connection() - cli.close() - self.assertIs(None, cli._conn) - - def test__request(self): - cli = self.getClient() - res, headers = cli._request('PUT', ['echo'], {}, {}) - self.assertEqual({'CONTENT_TYPE': 'application/json', - 'PATH_INFO': '/dbase/echo', - 'QUERY_STRING': '', - 'body': '{}', - 'REQUEST_METHOD': 'PUT'}, json.loads(res)) - - res, headers = cli._request('GET', ['doc', 'echo'], {'a': 1}) - self.assertEqual({'PATH_INFO': '/dbase/doc/echo', - 'QUERY_STRING': 'a=1', - 'REQUEST_METHOD': 'GET'}, json.loads(res)) - - res, headers = cli._request('GET', ['doc', '%FFFF', 'echo'], {'a': 1}) - self.assertEqual({'PATH_INFO': '/dbase/doc/%FFFF/echo', - 'QUERY_STRING': 'a=1', - 'REQUEST_METHOD': 'GET'}, json.loads(res)) - - res, headers = cli._request('POST', ['echo'], {'b': 2}, 'Body', - 'application/x-test') - self.assertEqual({'CONTENT_TYPE': 'application/x-test', - 'PATH_INFO': '/dbase/echo', - 'QUERY_STRING': 'b=2', - 'body': 'Body', - 'REQUEST_METHOD': 'POST'}, json.loads(res)) - - def test__request_json(self): - cli = self.getClient() - res, headers = cli._request_json( - 'POST', ['echo'], {'b': 2}, {'a': 'x'}) - self.assertEqual('application/json', headers['content-type']) - self.assertEqual({'CONTENT_TYPE': 'application/json', - 'PATH_INFO': '/dbase/echo', - 'QUERY_STRING': 'b=2', - 'body': '{"a": "x"}', - 'REQUEST_METHOD': 'POST'}, res) - - def test_unspecified_http_error(self): - cli = self.getClient() - self.assertRaises(errors.HTTPError, - cli._request_json, 'POST', ['error'], {}, - {'status': "500 Internal Error", - 'response': "Crash."}) - try: - cli._request_json('POST', ['error'], {}, - {'status': "500 Internal Error", - 'response': "Fail."}) - except errors.HTTPError, e: - pass - - self.assertEqual(500, e.status) - self.assertEqual("Fail.", e.message) - self.assertTrue("content-type" in e.headers) - - def test_revision_conflict(self): - cli = 
self.getClient() - self.assertRaises(errors.RevisionConflict, - cli._request_json, 'POST', ['error'], {}, - {'status': "409 Conflict", - 'response': {"error": "revision conflict"}}) - - def test_unavailable_proper(self): - cli = self.getClient() - cli._delays = (0, 0, 0, 0, 0) - self.assertRaises(errors.Unavailable, - cli._request_json, 'POST', ['error'], {}, - {'status': "503 Service Unavailable", - 'response': {"error": "unavailable"}}) - self.assertEqual(5, self.errors) - - def test_unavailable_then_available(self): - cli = self.getClient() - cli._delays = (0, 0, 0, 0, 0) - res, headers = cli._request_json( - 'POST', ['error_then_accept'], {'b': 2}, - {'status': "503 Service Unavailable", - 'response': {"error": "unavailable"}}) - self.assertEqual('application/json', headers['content-type']) - self.assertEqual({'CONTENT_TYPE': 'application/json', - 'PATH_INFO': '/dbase/error_then_accept', - 'QUERY_STRING': 'b=2', - 'body': '{"oki": "doki"}', - 'REQUEST_METHOD': 'POST'}, res) - self.assertEqual(3, self.errors) - - def test_unavailable_random_source(self): - cli = self.getClient() - cli._delays = (0, 0, 0, 0, 0) - try: - cli._request_json('POST', ['error'], {}, - {'status': "503 Service Unavailable", - 'response': "random unavailable."}) - except errors.Unavailable, e: - pass - - self.assertEqual(503, e.status) - self.assertEqual("random unavailable.", e.message) - self.assertTrue("content-type" in e.headers) - self.assertEqual(5, self.errors) - - def test_document_too_big(self): - cli = self.getClient() - self.assertRaises(errors.DocumentTooBig, - cli._request_json, 'POST', ['error'], {}, - {'status': "403 Forbidden", - 'response': {"error": "document too big"}}) - - def test_user_quota_exceeded(self): - cli = self.getClient() - self.assertRaises(errors.UserQuotaExceeded, - cli._request_json, 'POST', ['error'], {}, - {'status': "403 Forbidden", - 'response': {"error": "user quota exceeded"}}) - - def test_user_needs_subscription(self): - cli = self.getClient() - 
self.assertRaises(errors.SubscriptionNeeded, - cli._request_json, 'POST', ['error'], {}, - {'status': "403 Forbidden", - 'response': {"error": "user needs subscription"}}) - - def test_generic_u1db_error(self): - cli = self.getClient() - self.assertRaises(errors.U1DBError, - cli._request_json, 'POST', ['error'], {}, - {'status': "400 Bad Request", - 'response': {"error": "error"}}) - try: - cli._request_json('POST', ['error'], {}, - {'status': "400 Bad Request", - 'response': {"error": "error"}}) - except errors.U1DBError, e: - pass - self.assertIs(e.__class__, errors.U1DBError) - - def test_unspecified_bad_request(self): - cli = self.getClient() - self.assertRaises(errors.HTTPError, - cli._request_json, 'POST', ['error'], {}, - {'status': "400 Bad Request", - 'response': ""}) - try: - cli._request_json('POST', ['error'], {}, - {'status': "400 Bad Request", - 'response': ""}) - except errors.HTTPError, e: - pass - - self.assertEqual(400, e.status) - self.assertEqual("", e.message) - self.assertTrue("content-type" in e.headers) - - def test_oauth(self): - cli = self.getClient() - cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, - tests.token1.key, tests.token1.secret) - params = {'x': u'\xf0', 'y': "foo"} - res, headers = cli._request('GET', ['doc', 'oauth'], params) - self.assertEqual( - ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res)) - - # oauth does its own internal quoting - params = {'x': u'\xf0', 'y': "foo"} - res, headers = cli._request('GET', ['doc', 'oauth', 'foo bar'], params) - self.assertEqual( - ['/dbase/doc/oauth/foo bar', tests.token1.key, params], - json.loads(res)) - - def test_oauth_ctr_creds(self): - cli = self.getClient(creds={'oauth': { - 'consumer_key': tests.consumer1.key, - 'consumer_secret': tests.consumer1.secret, - 'token_key': tests.token1.key, - 'token_secret': tests.token1.secret, - }}) - params = {'x': u'\xf0', 'y': "foo"} - res, headers = cli._request('GET', ['doc', 'oauth'], params) - 
self.assertEqual( - ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res)) - - def test_unknown_creds(self): - self.assertRaises(errors.UnknownAuthMethod, - self.getClient, creds={'foo': {}}) - self.assertRaises(errors.UnknownAuthMethod, - self.getClient, creds={}) - - def test_oauth_Unauthorized(self): - cli = self.getClient() - cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, - tests.token1.key, "WRONG") - params = {'y': 'foo'} - self.assertRaises(errors.Unauthorized, cli._request, 'GET', - ['doc', 'oauth'], params) diff --git a/src/leap/soledad/u1db/tests/test_http_database.py b/src/leap/soledad/u1db/tests/test_http_database.py deleted file mode 100644 index c8e7eb76..00000000 --- a/src/leap/soledad/u1db/tests/test_http_database.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""Tests for HTTPDatabase""" - -import inspect -try: - import simplejson as json -except ImportError: - import json # noqa - -from u1db import ( - errors, - Document, - tests, - ) -from u1db.remote import ( - http_database, - http_target, - ) -from u1db.tests.test_remote_sync_target import ( - make_http_app, -) - - -class TestHTTPDatabaseSimpleOperations(tests.TestCase): - - def setUp(self): - super(TestHTTPDatabaseSimpleOperations, self).setUp() - self.db = http_database.HTTPDatabase('dbase') - self.db._conn = object() # crash if used - self.got = None - self.response_val = None - - def _request(method, url_parts, params=None, body=None, - content_type=None): - self.got = method, url_parts, params, body, content_type - if isinstance(self.response_val, Exception): - raise self.response_val - return self.response_val - - def _request_json(method, url_parts, params=None, body=None, - content_type=None): - self.got = method, url_parts, params, body, content_type - if isinstance(self.response_val, Exception): - raise self.response_val - return self.response_val - - self.db._request = _request - self.db._request_json = _request_json - - def test__sanity_same_signature(self): - my_request_sig = inspect.getargspec(self.db._request) - my_request_sig = (['self'] + my_request_sig[0],) + my_request_sig[1:] - self.assertEqual(my_request_sig, - inspect.getargspec(http_database.HTTPDatabase._request)) - my_request_json_sig = inspect.getargspec(self.db._request_json) - my_request_json_sig = ((['self'] + my_request_json_sig[0],) + - my_request_json_sig[1:]) - self.assertEqual(my_request_json_sig, - inspect.getargspec(http_database.HTTPDatabase._request_json)) - - def test__ensure(self): - self.response_val = {'ok': True}, {} - self.db._ensure() - self.assertEqual(('PUT', [], {}, {}, None), self.got) - - def test__delete(self): - self.response_val = {'ok': True}, {} - self.db._delete() - self.assertEqual(('DELETE', [], {}, {}, None), self.got) - - def test__check(self): - 
self.response_val = {}, {} - res = self.db._check() - self.assertEqual({}, res) - self.assertEqual(('GET', [], None, None, None), self.got) - - def test_put_doc(self): - self.response_val = {'rev': 'doc-rev'}, {} - doc = Document('doc-id', None, '{"v": 1}') - res = self.db.put_doc(doc) - self.assertEqual('doc-rev', res) - self.assertEqual('doc-rev', doc.rev) - self.assertEqual(('PUT', ['doc', 'doc-id'], {}, - '{"v": 1}', 'application/json'), self.got) - - self.response_val = {'rev': 'doc-rev-2'}, {} - doc.content = {"v": 2} - res = self.db.put_doc(doc) - self.assertEqual('doc-rev-2', res) - self.assertEqual('doc-rev-2', doc.rev) - self.assertEqual(('PUT', ['doc', 'doc-id'], {'old_rev': 'doc-rev'}, - '{"v": 2}', 'application/json'), self.got) - - def test_get_doc(self): - self.response_val = '{"v": 2}', {'x-u1db-rev': 'doc-rev', - 'x-u1db-has-conflicts': 'false'} - self.assertGetDoc(self.db, 'doc-id', 'doc-rev', '{"v": 2}', False) - self.assertEqual( - ('GET', ['doc', 'doc-id'], {'include_deleted': False}, None, None), - self.got) - - def test_get_doc_non_existing(self): - self.response_val = errors.DocumentDoesNotExist() - self.assertIs(None, self.db.get_doc('not-there')) - self.assertEqual( - ('GET', ['doc', 'not-there'], {'include_deleted': False}, None, - None), self.got) - - def test_get_doc_deleted(self): - self.response_val = errors.DocumentDoesNotExist() - self.assertIs(None, self.db.get_doc('deleted')) - self.assertEqual( - ('GET', ['doc', 'deleted'], {'include_deleted': False}, None, - None), self.got) - - def test_get_doc_deleted_include_deleted(self): - self.response_val = errors.HTTPError(404, - json.dumps( - {"error": errors.DOCUMENT_DELETED} - ), - {'x-u1db-rev': 'doc-rev-gone', - 'x-u1db-has-conflicts': 'false'}) - doc = self.db.get_doc('deleted', include_deleted=True) - self.assertEqual('deleted', doc.doc_id) - self.assertEqual('doc-rev-gone', doc.rev) - self.assertIs(None, doc.content) - self.assertEqual( - ('GET', ['doc', 'deleted'], 
{'include_deleted': True}, None, None), - self.got) - - def test_get_doc_pass_through_errors(self): - self.response_val = errors.HTTPError(500, 'Crash.') - self.assertRaises(errors.HTTPError, - self.db.get_doc, 'something-something') - - def test_create_doc_with_id(self): - self.response_val = {'rev': 'doc-rev'}, {} - new_doc = self.db.create_doc_from_json('{"v": 1}', doc_id='doc-id') - self.assertEqual('doc-rev', new_doc.rev) - self.assertEqual('doc-id', new_doc.doc_id) - self.assertEqual('{"v": 1}', new_doc.get_json()) - self.assertEqual(('PUT', ['doc', 'doc-id'], {}, - '{"v": 1}', 'application/json'), self.got) - - def test_create_doc_without_id(self): - self.response_val = {'rev': 'doc-rev-2'}, {} - new_doc = self.db.create_doc_from_json('{"v": 3}') - self.assertEqual('D-', new_doc.doc_id[:2]) - self.assertEqual('doc-rev-2', new_doc.rev) - self.assertEqual('{"v": 3}', new_doc.get_json()) - self.assertEqual(('PUT', ['doc', new_doc.doc_id], {}, - '{"v": 3}', 'application/json'), self.got) - - def test_delete_doc(self): - self.response_val = {'rev': 'doc-rev-gone'}, {} - doc = Document('doc-id', 'doc-rev', None) - self.db.delete_doc(doc) - self.assertEqual('doc-rev-gone', doc.rev) - self.assertEqual(('DELETE', ['doc', 'doc-id'], {'old_rev': 'doc-rev'}, - None, None), self.got) - - def test_get_sync_target(self): - st = self.db.get_sync_target() - self.assertIsInstance(st, http_target.HTTPSyncTarget) - self.assertEqual(st._url, self.db._url) - - def test_get_sync_target_inherits_oauth_credentials(self): - self.db.set_oauth_credentials(tests.consumer1.key, - tests.consumer1.secret, - tests.token1.key, tests.token1.secret) - st = self.db.get_sync_target() - self.assertEqual(self.db._creds, st._creds) - - -class TestHTTPDatabaseCtrWithCreds(tests.TestCase): - - def test_ctr_with_creds(self): - db1 = http_database.HTTPDatabase('http://dbs/db', creds={'oauth': { - 'consumer_key': tests.consumer1.key, - 'consumer_secret': tests.consumer1.secret, - 'token_key': 
tests.token1.key, - 'token_secret': tests.token1.secret - }}) - self.assertIn('oauth', db1._creds) - - -class TestHTTPDatabaseIntegration(tests.TestCaseWithServer): - - make_app_with_state = staticmethod(make_http_app) - - def setUp(self): - super(TestHTTPDatabaseIntegration, self).setUp() - self.startServer() - - def test_non_existing_db(self): - db = http_database.HTTPDatabase(self.getURL('not-there')) - self.assertRaises(errors.DatabaseDoesNotExist, db.get_doc, 'doc1') - - def test__ensure(self): - db = http_database.HTTPDatabase(self.getURL('new')) - db._ensure() - self.assertIs(None, db.get_doc('doc1')) - - def test__delete(self): - self.request_state._create_database('db0') - db = http_database.HTTPDatabase(self.getURL('db0')) - db._delete() - self.assertRaises(errors.DatabaseDoesNotExist, - self.request_state.check_database, 'db0') - - def test_open_database_existing(self): - self.request_state._create_database('db0') - db = http_database.HTTPDatabase.open_database(self.getURL('db0'), - create=False) - self.assertIs(None, db.get_doc('doc1')) - - def test_open_database_non_existing(self): - self.assertRaises(errors.DatabaseDoesNotExist, - http_database.HTTPDatabase.open_database, - self.getURL('not-there'), - create=False) - - def test_open_database_create(self): - db = http_database.HTTPDatabase.open_database(self.getURL('new'), - create=True) - self.assertIs(None, db.get_doc('doc1')) - - def test_delete_database_existing(self): - self.request_state._create_database('db0') - http_database.HTTPDatabase.delete_database(self.getURL('db0')) - self.assertRaises(errors.DatabaseDoesNotExist, - self.request_state.check_database, 'db0') - - def test_doc_ids_needing_quoting(self): - db0 = self.request_state._create_database('db0') - db = http_database.HTTPDatabase.open_database(self.getURL('db0'), - create=False) - doc = Document('%fff', None, '{}') - db.put_doc(doc) - self.assertGetDoc(db0, '%fff', doc.rev, '{}', False) - self.assertGetDoc(db, '%fff', doc.rev, '{}', 
False) diff --git a/src/leap/soledad/u1db/tests/test_https.py b/src/leap/soledad/u1db/tests/test_https.py deleted file mode 100644 index 67681c8a..00000000 --- a/src/leap/soledad/u1db/tests/test_https.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Test support for client-side https support.""" - -import os -import ssl -import sys - -from paste import httpserver - -from u1db import ( - tests, - ) -from u1db.remote import ( - http_client, - http_target, - ) - -from u1db.tests.test_remote_sync_target import ( - make_oauth_http_app, - ) - - -def https_server_def(): - def make_server(host_port, application): - from OpenSSL import SSL - cert_file = os.path.join(os.path.dirname(__file__), 'testing-certs', - 'testing.cert') - key_file = os.path.join(os.path.dirname(__file__), 'testing-certs', - 'testing.key') - ssl_context = SSL.Context(SSL.SSLv23_METHOD) - ssl_context.use_privatekey_file(key_file) - ssl_context.use_certificate_chain_file(cert_file) - srv = httpserver.WSGIServerBase(application, host_port, - httpserver.WSGIHandler, - ssl_context=ssl_context - ) - - def shutdown_request(req): - req.shutdown() - srv.close_request(req) - - srv.shutdown_request = shutdown_request - application.base_url = "https://localhost:%s" % srv.server_address[1] - return srv - return make_server, "shutdown", "https" - - -def oauth_https_sync_target(test, host, path): - _, port = test.server.server_address - st = http_target.HTTPSyncTarget('https://%s:%d/~/%s' % (host, port, path)) - st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, - tests.token1.key, tests.token1.secret) - return st - - -class TestHttpSyncTargetHttpsSupport(tests.TestCaseWithServer): - - scenarios = [ - ('oauth_https', {'server_def': https_server_def, - 'make_app_with_state': make_oauth_http_app, - 'make_document_for_test': tests.make_document_for_test, - 'sync_target': oauth_https_sync_target - }), - ] - - def setUp(self): - try: - import OpenSSL # noqa - except ImportError: - self.skipTest("Requires 
pyOpenSSL") - self.cacert_pem = os.path.join(os.path.dirname(__file__), - 'testing-certs', 'cacert.pem') - super(TestHttpSyncTargetHttpsSupport, self).setUp() - - def getSyncTarget(self, host, path=None): - if self.server is None: - self.startServer() - return self.sync_target(self, host, path) - - def test_working(self): - self.startServer() - db = self.request_state._create_database('test') - self.patch(http_client, 'CA_CERTS', self.cacert_pem) - remote_target = self.getSyncTarget('localhost', 'test') - remote_target.record_sync_info('other-id', 2, 'T-id') - self.assertEqual( - (2, 'T-id'), db._get_replica_gen_and_trans_id('other-id')) - - def test_cannot_verify_cert(self): - if not sys.platform.startswith('linux'): - self.skipTest( - "XXX certificate verification happens on linux only for now") - self.startServer() - # don't print expected traceback server-side - self.server.handle_error = lambda req, cli_addr: None - self.request_state._create_database('test') - remote_target = self.getSyncTarget('localhost', 'test') - try: - remote_target.record_sync_info('other-id', 2, 'T-id') - except ssl.SSLError, e: - self.assertIn("certificate verify failed", str(e)) - else: - self.fail("certificate verification should have failed.") - - def test_host_mismatch(self): - if not sys.platform.startswith('linux'): - self.skipTest( - "XXX certificate verification happens on linux only for now") - self.startServer() - self.request_state._create_database('test') - self.patch(http_client, 'CA_CERTS', self.cacert_pem) - remote_target = self.getSyncTarget('127.0.0.1', 'test') - self.assertRaises( - http_client.CertificateError, remote_target.record_sync_info, - 'other-id', 2, 'T-id') - - -load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_inmemory.py b/src/leap/soledad/u1db/tests/test_inmemory.py deleted file mode 100644 index 255a1e08..00000000 --- a/src/leap/soledad/u1db/tests/test_inmemory.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2011 
Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Test in-memory backend internals.""" - -from u1db import ( - errors, - tests, - ) -from u1db.backends import inmemory - - -simple_doc = '{"key": "value"}' - - -class TestInMemoryDatabaseInternals(tests.TestCase): - - def setUp(self): - super(TestInMemoryDatabaseInternals, self).setUp() - self.db = inmemory.InMemoryDatabase('test') - - def test__allocate_doc_rev_from_None(self): - self.assertEqual('test:1', self.db._allocate_doc_rev(None)) - - def test__allocate_doc_rev_incremental(self): - self.assertEqual('test:2', self.db._allocate_doc_rev('test:1')) - - def test__allocate_doc_rev_other(self): - self.assertEqual('replica:1|test:1', - self.db._allocate_doc_rev('replica:1')) - - def test__get_replica_uid(self): - self.assertEqual('test', self.db._replica_uid) - - -class TestInMemoryIndex(tests.TestCase): - - def test_has_name_and_definition(self): - idx = inmemory.InMemoryIndex('idx-name', ['key']) - self.assertEqual('idx-name', idx._name) - self.assertEqual(['key'], idx._definition) - - def test_evaluate_json(self): - idx = inmemory.InMemoryIndex('idx-name', ['key']) - self.assertEqual(['value'], idx.evaluate_json(simple_doc)) - - def test_evaluate_json_field_None(self): - idx = inmemory.InMemoryIndex('idx-name', ['missing']) - self.assertEqual([], idx.evaluate_json(simple_doc)) - - def test_evaluate_json_subfield_None(self): - idx = inmemory.InMemoryIndex('idx-name', 
['key', 'missing']) - self.assertEqual([], idx.evaluate_json(simple_doc)) - - def test_evaluate_multi_index(self): - doc = '{"key": "value", "key2": "value2"}' - idx = inmemory.InMemoryIndex('idx-name', ['key', 'key2']) - self.assertEqual(['value\x01value2'], - idx.evaluate_json(doc)) - - def test_update_ignores_None(self): - idx = inmemory.InMemoryIndex('idx-name', ['nokey']) - idx.add_json('doc-id', simple_doc) - self.assertEqual({}, idx._values) - - def test_update_adds_entry(self): - idx = inmemory.InMemoryIndex('idx-name', ['key']) - idx.add_json('doc-id', simple_doc) - self.assertEqual({'value': ['doc-id']}, idx._values) - - def test_remove_json(self): - idx = inmemory.InMemoryIndex('idx-name', ['key']) - idx.add_json('doc-id', simple_doc) - self.assertEqual({'value': ['doc-id']}, idx._values) - idx.remove_json('doc-id', simple_doc) - self.assertEqual({}, idx._values) - - def test_remove_json_multiple(self): - idx = inmemory.InMemoryIndex('idx-name', ['key']) - idx.add_json('doc-id', simple_doc) - idx.add_json('doc2-id', simple_doc) - self.assertEqual({'value': ['doc-id', 'doc2-id']}, idx._values) - idx.remove_json('doc-id', simple_doc) - self.assertEqual({'value': ['doc2-id']}, idx._values) - - def test_keys(self): - idx = inmemory.InMemoryIndex('idx-name', ['key']) - idx.add_json('doc-id', simple_doc) - self.assertEqual(['value'], idx.keys()) - - def test_lookup(self): - idx = inmemory.InMemoryIndex('idx-name', ['key']) - idx.add_json('doc-id', simple_doc) - self.assertEqual(['doc-id'], idx.lookup(['value'])) - - def test_lookup_multi(self): - idx = inmemory.InMemoryIndex('idx-name', ['key']) - idx.add_json('doc-id', simple_doc) - idx.add_json('doc2-id', simple_doc) - self.assertEqual(['doc-id', 'doc2-id'], idx.lookup(['value'])) - - def test__find_non_wildcards(self): - idx = inmemory.InMemoryIndex('idx-name', ['k1', 'k2', 'k3']) - self.assertEqual(-1, idx._find_non_wildcards(('a', 'b', 'c'))) - self.assertEqual(2, idx._find_non_wildcards(('a', 'b', '*'))) 
- self.assertEqual(3, idx._find_non_wildcards(('a', 'b', 'c*'))) - self.assertEqual(2, idx._find_non_wildcards(('a', 'b*', '*'))) - self.assertEqual(0, idx._find_non_wildcards(('*', '*', '*'))) - self.assertEqual(1, idx._find_non_wildcards(('a*', '*', '*'))) - self.assertRaises(errors.InvalidValueForIndex, - idx._find_non_wildcards, ('a', 'b')) - self.assertRaises(errors.InvalidValueForIndex, - idx._find_non_wildcards, ('a', 'b', 'c', 'd')) - self.assertRaises(errors.InvalidGlobbing, - idx._find_non_wildcards, ('*', 'b', 'c')) diff --git a/src/leap/soledad/u1db/tests/test_open.py b/src/leap/soledad/u1db/tests/test_open.py deleted file mode 100644 index fbeb0cfd..00000000 --- a/src/leap/soledad/u1db/tests/test_open.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""Test u1db.open""" - -import os - -from u1db import ( - errors, - open as u1db_open, - tests, - ) -from u1db.backends import sqlite_backend -from u1db.tests.test_backends import TestAlternativeDocument - - -class TestU1DBOpen(tests.TestCase): - - def setUp(self): - super(TestU1DBOpen, self).setUp() - tmpdir = self.createTempDir() - self.db_path = tmpdir + '/test.db' - - def test_open_no_create(self): - self.assertRaises(errors.DatabaseDoesNotExist, - u1db_open, self.db_path, create=False) - self.assertFalse(os.path.exists(self.db_path)) - - def test_open_create(self): - db = u1db_open(self.db_path, create=True) - self.addCleanup(db.close) - self.assertTrue(os.path.exists(self.db_path)) - self.assertIsInstance(db, sqlite_backend.SQLiteDatabase) - - def test_open_with_factory(self): - db = u1db_open(self.db_path, create=True, - document_factory=TestAlternativeDocument) - self.addCleanup(db.close) - self.assertEqual(TestAlternativeDocument, db._factory) - - def test_open_existing(self): - db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path) - self.addCleanup(db.close) - doc = db.create_doc_from_json(tests.simple_doc) - # Even though create=True, we shouldn't wipe the db - db2 = u1db_open(self.db_path, create=True) - self.addCleanup(db2.close) - doc2 = db2.get_doc(doc.doc_id) - self.assertEqual(doc, doc2) - - def test_open_existing_no_create(self): - db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path) - self.addCleanup(db.close) - db2 = u1db_open(self.db_path, create=False) - self.addCleanup(db2.close) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) diff --git a/src/leap/soledad/u1db/tests/test_query_parser.py b/src/leap/soledad/u1db/tests/test_query_parser.py deleted file mode 100644 index ee374267..00000000 --- a/src/leap/soledad/u1db/tests/test_query_parser.py +++ /dev/null @@ -1,443 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. 
-# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -from u1db import ( - errors, - query_parser, - tests, - ) - - -trivial_raw_doc = {} - - -class TestFieldName(tests.TestCase): - - def test_check_fieldname_valid(self): - self.assertIsNone(query_parser.check_fieldname("foo")) - - def test_check_fieldname_invalid(self): - self.assertRaises( - errors.IndexDefinitionParseError, query_parser.check_fieldname, - "foo.") - - -class TestMakeTree(tests.TestCase): - - def setUp(self): - super(TestMakeTree, self).setUp() - self.parser = query_parser.Parser() - - def assertParseError(self, definition): - self.assertRaises( - errors.IndexDefinitionParseError, self.parser.parse, - definition) - - def test_single_field(self): - self.assertIsInstance( - self.parser.parse('f'), query_parser.ExtractField) - - def test_single_mapping(self): - self.assertIsInstance( - self.parser.parse('bool(field1)'), query_parser.Bool) - - def test_nested_mapping(self): - self.assertIsInstance( - self.parser.parse('lower(split_words(field1))'), - query_parser.Lower) - - def test_nested_branching_mapping(self): - self.assertIsInstance( - self.parser.parse( - 'combine(lower(field1), split_words(field2), ' - 'number(field3, 5))'), query_parser.Combine) - - def test_single_mapping_multiple_fields(self): - self.assertIsInstance( - self.parser.parse('number(field1, 5)'), query_parser.Number) - - def test_unknown_mapping(self): - self.assertParseError('mapping(whatever)') - - def 
test_parse_missing_close_paren(self): - self.assertParseError("lower(a") - - def test_parse_trailing_chars(self): - self.assertParseError("lower(ab))") - - def test_parse_empty_op(self): - self.assertParseError("(ab)") - - def test_parse_top_level_commas(self): - self.assertParseError("a, b") - - def test_invalid_field_name(self): - self.assertParseError("a.") - - def test_invalid_inner_field_name(self): - self.assertParseError("lower(a.)") - - def test_gobbledigook(self): - self.assertParseError("(@#@cc @#!*DFJSXV(()jccd") - - def test_leading_space(self): - self.assertIsInstance( - self.parser.parse(" lower(a)"), query_parser.Lower) - - def test_trailing_space(self): - self.assertIsInstance( - self.parser.parse("lower(a) "), query_parser.Lower) - - def test_spaces_before_open_paren(self): - self.assertIsInstance( - self.parser.parse("lower (a)"), query_parser.Lower) - - def test_spaces_after_open_paren(self): - self.assertIsInstance( - self.parser.parse("lower( a)"), query_parser.Lower) - - def test_spaces_before_close_paren(self): - self.assertIsInstance( - self.parser.parse("lower(a )"), query_parser.Lower) - - def test_spaces_before_comma(self): - self.assertIsInstance( - self.parser.parse("number(a , 5)"), query_parser.Number) - - def test_spaces_after_comma(self): - self.assertIsInstance( - self.parser.parse("number(a, 5)"), query_parser.Number) - - -class TestStaticGetter(tests.TestCase): - - def test_returns_string(self): - getter = query_parser.StaticGetter('foo') - self.assertEqual(['foo'], getter.get(trivial_raw_doc)) - - def test_returns_int(self): - getter = query_parser.StaticGetter(9) - self.assertEqual([9], getter.get(trivial_raw_doc)) - - def test_returns_float(self): - getter = query_parser.StaticGetter(9.2) - self.assertEqual([9.2], getter.get(trivial_raw_doc)) - - def test_returns_None(self): - getter = query_parser.StaticGetter(None) - self.assertEqual([], getter.get(trivial_raw_doc)) - - def test_returns_list(self): - getter = 
query_parser.StaticGetter(['a', 'b']) - self.assertEqual(['a', 'b'], getter.get(trivial_raw_doc)) - - -class TestExtractField(tests.TestCase): - - def assertExtractField(self, expected, field_name, raw_doc): - getter = query_parser.ExtractField(field_name) - self.assertEqual(expected, getter.get(raw_doc)) - - def test_get_value(self): - self.assertExtractField(['bar'], 'foo', {'foo': 'bar'}) - - def test_get_value_None(self): - self.assertExtractField([], 'foo', {'foo': None}) - - def test_get_value_missing_key(self): - self.assertExtractField([], 'foo', {}) - - def test_get_value_subfield(self): - self.assertExtractField(['bar'], 'foo.baz', {'foo': {'baz': 'bar'}}) - - def test_get_value_subfield_missing(self): - self.assertExtractField([], 'foo.baz', {'foo': 'bar'}) - - def test_get_value_dict(self): - self.assertExtractField([], 'foo', {'foo': {'baz': 'bar'}}) - - def test_get_value_list(self): - self.assertExtractField(['bar', 'zap'], 'foo', {'foo': ['bar', 'zap']}) - - def test_get_value_mixed_list(self): - self.assertExtractField(['bar', 'zap'], 'foo', - {'foo': ['bar', ['baa'], 'zap', {'bing': 9}]}) - - def test_get_value_list_of_dicts(self): - self.assertExtractField([], 'foo', {'foo': [{'zap': 'bar'}]}) - - def test_get_value_list_of_dicts2(self): - self.assertExtractField( - ['bar', 'baz'], 'foo.zap', - {'foo': [{'zap': 'bar'}, {'zap': 'baz'}]}) - - def test_get_value_int(self): - self.assertExtractField([9], 'foo', {'foo': 9}) - - def test_get_value_float(self): - self.assertExtractField([9.2], 'foo', {'foo': 9.2}) - - def test_get_value_bool(self): - self.assertExtractField([True], 'foo', {'foo': True}) - self.assertExtractField([False], 'foo', {'foo': False}) - - -class TestLower(tests.TestCase): - - def assertLowerGets(self, expected, input_val): - getter = query_parser.Lower(query_parser.StaticGetter(input_val)) - out_val = getter.get(trivial_raw_doc) - self.assertEqual(sorted(expected), sorted(out_val)) - - def test_inner_returns_None(self): - 
self.assertLowerGets([], None) - - def test_inner_returns_string(self): - self.assertLowerGets(['foo'], 'fOo') - - def test_inner_returns_list(self): - self.assertLowerGets(['foo', 'bar'], ['fOo', 'bAr']) - - def test_inner_returns_int(self): - self.assertLowerGets([], 9) - - def test_inner_returns_float(self): - self.assertLowerGets([], 9.0) - - def test_inner_returns_bool(self): - self.assertLowerGets([], True) - - def test_inner_returns_list_containing_int(self): - self.assertLowerGets(['foo', 'bar'], ['fOo', 9, 'bAr']) - - def test_inner_returns_list_containing_float(self): - self.assertLowerGets(['foo', 'bar'], ['fOo', 9.2, 'bAr']) - - def test_inner_returns_list_containing_bool(self): - self.assertLowerGets(['foo', 'bar'], ['fOo', True, 'bAr']) - - def test_inner_returns_list_containing_list(self): - # TODO: Should this be unfolding the inner list? - self.assertLowerGets(['foo', 'bar'], ['fOo', ['bAa'], 'bAr']) - - def test_inner_returns_list_containing_dict(self): - self.assertLowerGets(['foo', 'bar'], ['fOo', {'baa': 'xam'}, 'bAr']) - - -class TestSplitWords(tests.TestCase): - - def assertSplitWords(self, expected, value): - getter = query_parser.SplitWords(query_parser.StaticGetter(value)) - self.assertEqual(sorted(expected), sorted(getter.get(trivial_raw_doc))) - - def test_inner_returns_None(self): - self.assertSplitWords([], None) - - def test_inner_returns_string(self): - self.assertSplitWords(['foo', 'bar'], 'foo bar') - - def test_inner_returns_list(self): - self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], - ['foo baz', 'bar sux']) - - def test_deduplicates(self): - self.assertSplitWords(['bar'], ['bar', 'bar', 'bar']) - - def test_inner_returns_int(self): - self.assertSplitWords([], 9) - - def test_inner_returns_float(self): - self.assertSplitWords([], 9.2) - - def test_inner_returns_bool(self): - self.assertSplitWords([], True) - - def test_inner_returns_list_containing_int(self): - self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], - ['foo 
baz', 9, 'bar sux']) - - def test_inner_returns_list_containing_float(self): - self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], - ['foo baz', 9.2, 'bar sux']) - - def test_inner_returns_list_containing_bool(self): - self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], - ['foo baz', True, 'bar sux']) - - def test_inner_returns_list_containing_list(self): - # TODO: Expand sub-lists? - self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], - ['foo baz', ['baa'], 'bar sux']) - - def test_inner_returns_list_containing_dict(self): - self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], - ['foo baz', {'baa': 'xam'}, 'bar sux']) - - -class TestNumber(tests.TestCase): - - def assertNumber(self, expected, value, padding=5): - """Assert number transformation produced expected values.""" - getter = query_parser.Number(query_parser.StaticGetter(value), padding) - self.assertEqual(expected, getter.get(trivial_raw_doc)) - - def test_inner_returns_None(self): - """None is thrown away.""" - self.assertNumber([], None) - - def test_inner_returns_int(self): - """A single integer is converted to zero padded strings.""" - self.assertNumber(['00009'], 9) - - def test_inner_returns_list(self): - """Integers are converted to zero padded strings.""" - self.assertNumber(['00009', '00235'], [9, 235]) - - def test_inner_returns_string(self): - """A string is thrown away.""" - self.assertNumber([], 'foo bar') - - def test_inner_returns_float(self): - """A float is thrown away.""" - self.assertNumber([], 9.2) - - def test_inner_returns_bool(self): - """A boolean is thrown away.""" - self.assertNumber([], True) - - def test_inner_returns_list_containing_strings(self): - """Strings in a list are thrown away.""" - self.assertNumber(['00009'], ['foo baz', 9, 'bar sux']) - - def test_inner_returns_list_containing_float(self): - """Floats in a list are thrown away.""" - self.assertNumber( - ['00083', '00073'], [83, 9.2, 73]) - - def test_inner_returns_list_containing_bool(self): - """Booleans in a list 
are thrown away.""" - self.assertNumber( - ['00083', '00073'], [83, True, 73]) - - def test_inner_returns_list_containing_list(self): - """Lists in a list are thrown away.""" - # TODO: Expand sub-lists? - self.assertNumber( - ['00012', '03333'], [12, [29], 3333]) - - def test_inner_returns_list_containing_dict(self): - """Dicts in a list are thrown away.""" - self.assertNumber( - ['00012', '00001'], [12, {54: 89}, 1]) - - -class TestIsNull(tests.TestCase): - - def assertIsNull(self, value): - getter = query_parser.IsNull(query_parser.StaticGetter(value)) - self.assertEqual([True], getter.get(trivial_raw_doc)) - - def assertIsNotNull(self, value): - getter = query_parser.IsNull(query_parser.StaticGetter(value)) - self.assertEqual([False], getter.get(trivial_raw_doc)) - - def test_inner_returns_None(self): - self.assertIsNull(None) - - def test_inner_returns_string(self): - self.assertIsNotNull('foo') - - def test_inner_returns_list(self): - self.assertIsNotNull(['foo', 'bar']) - - def test_inner_returns_empty_list(self): - # TODO: is this the behavior we want? - self.assertIsNull([]) - - def test_inner_returns_int(self): - self.assertIsNotNull(9) - - def test_inner_returns_float(self): - self.assertIsNotNull(9.2) - - def test_inner_returns_bool(self): - self.assertIsNotNull(True) - - # TODO: What about a dict? Inner is likely to return None, even though the - # attribute does exist... 
- - -class TestParser(tests.TestCase): - - def parse(self, spec): - parser = query_parser.Parser() - return parser.parse(spec) - - def parse_all(self, specs): - parser = query_parser.Parser() - return parser.parse_all(specs) - - def assertParseError(self, definition): - self.assertRaises(errors.IndexDefinitionParseError, self.parse, - definition) - - def test_parse_empty_string(self): - self.assertRaises(errors.IndexDefinitionParseError, self.parse, "") - - def test_parse_field(self): - getter = self.parse("a") - self.assertIsInstance(getter, query_parser.ExtractField) - self.assertEqual(["a"], getter.field) - - def test_parse_dotted_field(self): - getter = self.parse("a.b") - self.assertIsInstance(getter, query_parser.ExtractField) - self.assertEqual(["a", "b"], getter.field) - - def test_parse_dotted_field_nothing_after_dot(self): - self.assertParseError("a.") - - def test_parse_missing_close_on_transformation(self): - self.assertParseError("lower(a") - - def test_parse_missing_field_in_transformation(self): - self.assertParseError("lower()") - - def test_parse_trailing_chars(self): - self.assertParseError("lower(ab))") - - def test_parse_empty_op(self): - self.assertParseError("(ab)") - - def test_parse_unknown_op(self): - self.assertParseError("no_such_operation(field)") - - def test_parse_wrong_arg_type(self): - self.assertParseError("number(field, fnord)") - - def test_parse_transformation(self): - getter = self.parse("lower(a)") - self.assertIsInstance(getter, query_parser.Lower) - self.assertIsInstance(getter.inner, query_parser.ExtractField) - self.assertEqual(["a"], getter.inner.field) - - def test_parse_all(self): - getters = self.parse_all(["a", "b"]) - self.assertEqual(2, len(getters)) - self.assertIsInstance(getters[0], query_parser.ExtractField) - self.assertEqual(["a"], getters[0].field) - self.assertIsInstance(getters[1], query_parser.ExtractField) - self.assertEqual(["b"], getters[1].field) diff --git 
a/src/leap/soledad/u1db/tests/test_remote_sync_target.py b/src/leap/soledad/u1db/tests/test_remote_sync_target.py deleted file mode 100644 index 3e0d8995..00000000 --- a/src/leap/soledad/u1db/tests/test_remote_sync_target.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Tests for the remote sync targets""" - -import cStringIO - -from u1db import ( - errors, - tests, - ) -from u1db.remote import ( - http_app, - http_target, - oauth_middleware, - ) - - -class TestHTTPSyncTargetBasics(tests.TestCase): - - def test_parse_url(self): - remote_target = http_target.HTTPSyncTarget('http://127.0.0.1:12345/') - self.assertEqual('http', remote_target._url.scheme) - self.assertEqual('127.0.0.1', remote_target._url.hostname) - self.assertEqual(12345, remote_target._url.port) - self.assertEqual('/', remote_target._url.path) - - -class TestParsingSyncStream(tests.TestCase): - - def test_wrong_start(self): - tgt = http_target.HTTPSyncTarget("http://foo/foo") - - self.assertRaises(errors.BrokenSyncStream, - tgt._parse_sync_stream, "{}\r\n]", None) - - self.assertRaises(errors.BrokenSyncStream, - tgt._parse_sync_stream, "\r\n{}\r\n]", None) - - self.assertRaises(errors.BrokenSyncStream, - tgt._parse_sync_stream, "", None) - - def test_wrong_end(self): - tgt = http_target.HTTPSyncTarget("http://foo/foo") - - self.assertRaises(errors.BrokenSyncStream, - tgt._parse_sync_stream, 
"[\r\n{}", None) - - self.assertRaises(errors.BrokenSyncStream, - tgt._parse_sync_stream, "[\r\n", None) - - def test_missing_comma(self): - tgt = http_target.HTTPSyncTarget("http://foo/foo") - - self.assertRaises(errors.BrokenSyncStream, - tgt._parse_sync_stream, - '[\r\n{}\r\n{"id": "i", "rev": "r", ' - '"content": "c", "gen": 3}\r\n]', None) - - def test_no_entries(self): - tgt = http_target.HTTPSyncTarget("http://foo/foo") - - self.assertRaises(errors.BrokenSyncStream, - tgt._parse_sync_stream, "[\r\n]", None) - - def test_extra_comma(self): - tgt = http_target.HTTPSyncTarget("http://foo/foo") - - self.assertRaises(errors.BrokenSyncStream, - tgt._parse_sync_stream, "[\r\n{},\r\n]", None) - - self.assertRaises(errors.BrokenSyncStream, - tgt._parse_sync_stream, - '[\r\n{},\r\n{"id": "i", "rev": "r", ' - '"content": "{}", "gen": 3, "trans_id": "T-sid"}' - ',\r\n]', - lambda doc, gen, trans_id: None) - - def test_error_in_stream(self): - tgt = http_target.HTTPSyncTarget("http://foo/foo") - - self.assertRaises(errors.Unavailable, - tgt._parse_sync_stream, - '[\r\n{"new_generation": 0},' - '\r\n{"error": "unavailable"}\r\n', None) - - self.assertRaises(errors.Unavailable, - tgt._parse_sync_stream, - '[\r\n{"error": "unavailable"}\r\n', None) - - self.assertRaises(errors.BrokenSyncStream, - tgt._parse_sync_stream, - '[\r\n{"error": "?"}\r\n', None) - - -def make_http_app(state): - return http_app.HTTPApp(state) - - -def http_sync_target(test, path): - return http_target.HTTPSyncTarget(test.getURL(path)) - - -def make_oauth_http_app(state): - app = http_app.HTTPApp(state) - application = oauth_middleware.OAuthMiddleware(app, None, prefix='/~/') - application.get_oauth_data_store = lambda: tests.testingOAuthStore - return application - - -def oauth_http_sync_target(test, path): - st = http_sync_target(test, '~/' + path) - st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, - tests.token1.key, tests.token1.secret) - return st - - -class 
TestRemoteSyncTargets(tests.TestCaseWithServer): - - scenarios = [ - ('http', {'make_app_with_state': make_http_app, - 'make_document_for_test': tests.make_document_for_test, - 'sync_target': http_sync_target}), - ('oauth_http', {'make_app_with_state': make_oauth_http_app, - 'make_document_for_test': tests.make_document_for_test, - 'sync_target': oauth_http_sync_target}), - ] - - def getSyncTarget(self, path=None): - if self.server is None: - self.startServer() - return self.sync_target(self, path) - - def test_get_sync_info(self): - self.startServer() - db = self.request_state._create_database('test') - db._set_replica_gen_and_trans_id('other-id', 1, 'T-transid') - remote_target = self.getSyncTarget('test') - self.assertEqual(('test', 0, '', 1, 'T-transid'), - remote_target.get_sync_info('other-id')) - - def test_record_sync_info(self): - self.startServer() - db = self.request_state._create_database('test') - remote_target = self.getSyncTarget('test') - remote_target.record_sync_info('other-id', 2, 'T-transid') - self.assertEqual( - (2, 'T-transid'), db._get_replica_gen_and_trans_id('other-id')) - - def test_sync_exchange_send(self): - self.startServer() - db = self.request_state._create_database('test') - remote_target = self.getSyncTarget('test') - other_docs = [] - - def receive_doc(doc): - other_docs.append((doc.doc_id, doc.rev, doc.get_json())) - - doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}') - new_gen, trans_id = remote_target.sync_exchange( - [(doc, 10, 'T-sid')], 'replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=receive_doc) - self.assertEqual(1, new_gen) - self.assertGetDoc( - db, 'doc-here', 'replica:1', '{"value": "here"}', False) - - def test_sync_exchange_send_failure_and_retry_scenario(self): - self.startServer() - - def blackhole_getstderr(inst): - return cStringIO.StringIO() - - self.patch(self.server.RequestHandlerClass, 'get_stderr', - blackhole_getstderr) - db = 
self.request_state._create_database('test') - _put_doc_if_newer = db._put_doc_if_newer - trigger_ids = ['doc-here2'] - - def bomb_put_doc_if_newer(doc, save_conflict, - replica_uid=None, replica_gen=None, - replica_trans_id=None): - if doc.doc_id in trigger_ids: - raise Exception - return _put_doc_if_newer(doc, save_conflict=save_conflict, - replica_uid=replica_uid, replica_gen=replica_gen, - replica_trans_id=replica_trans_id) - self.patch(db, '_put_doc_if_newer', bomb_put_doc_if_newer) - remote_target = self.getSyncTarget('test') - other_changes = [] - - def receive_doc(doc, gen, trans_id): - other_changes.append( - (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) - - doc1 = self.make_document('doc-here', 'replica:1', '{"value": "here"}') - doc2 = self.make_document('doc-here2', 'replica:1', - '{"value": "here2"}') - self.assertRaises( - errors.HTTPError, - remote_target.sync_exchange, - [(doc1, 10, 'T-sid'), (doc2, 11, 'T-sud')], - 'replica', last_known_generation=0, last_known_trans_id=None, - return_doc_cb=receive_doc) - self.assertGetDoc(db, 'doc-here', 'replica:1', '{"value": "here"}', - False) - self.assertEqual( - (10, 'T-sid'), db._get_replica_gen_and_trans_id('replica')) - self.assertEqual([], other_changes) - # retry - trigger_ids = [] - new_gen, trans_id = remote_target.sync_exchange( - [(doc2, 11, 'T-sud')], 'replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=receive_doc) - self.assertGetDoc(db, 'doc-here2', 'replica:1', '{"value": "here2"}', - False) - self.assertEqual( - (11, 'T-sud'), db._get_replica_gen_and_trans_id('replica')) - self.assertEqual(2, new_gen) - # bounced back to us - self.assertEqual( - ('doc-here', 'replica:1', '{"value": "here"}', 1), - other_changes[0][:-1]) - - def test_sync_exchange_in_stream_error(self): - self.startServer() - - def blackhole_getstderr(inst): - return cStringIO.StringIO() - - self.patch(self.server.RequestHandlerClass, 'get_stderr', - blackhole_getstderr) - db = 
self.request_state._create_database('test') - doc = db.create_doc_from_json('{"value": "there"}') - - def bomb_get_docs(doc_ids, check_for_conflicts=None, - include_deleted=False): - yield doc - # delayed failure case - raise errors.Unavailable - - self.patch(db, 'get_docs', bomb_get_docs) - remote_target = self.getSyncTarget('test') - other_changes = [] - - def receive_doc(doc, gen, trans_id): - other_changes.append( - (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) - - self.assertRaises( - errors.Unavailable, remote_target.sync_exchange, [], 'replica', - last_known_generation=0, last_known_trans_id=None, - return_doc_cb=receive_doc) - self.assertEqual( - (doc.doc_id, doc.rev, '{"value": "there"}', 1), - other_changes[0][:-1]) - - def test_sync_exchange_receive(self): - self.startServer() - db = self.request_state._create_database('test') - doc = db.create_doc_from_json('{"value": "there"}') - remote_target = self.getSyncTarget('test') - other_changes = [] - - def receive_doc(doc, gen, trans_id): - other_changes.append( - (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) - - new_gen, trans_id = remote_target.sync_exchange( - [], 'replica', last_known_generation=0, last_known_trans_id=None, - return_doc_cb=receive_doc) - self.assertEqual(1, new_gen) - self.assertEqual( - (doc.doc_id, doc.rev, '{"value": "there"}', 1), - other_changes[0][:-1]) - - def test_sync_exchange_send_ensure_callback(self): - self.startServer() - remote_target = self.getSyncTarget('test') - other_docs = [] - replica_uid_box = [] - - def receive_doc(doc): - other_docs.append((doc.doc_id, doc.rev, doc.get_json())) - - def ensure_cb(replica_uid): - replica_uid_box.append(replica_uid) - - doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}') - new_gen, trans_id = remote_target.sync_exchange( - [(doc, 10, 'T-sid')], 'replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=receive_doc, - ensure_callback=ensure_cb) - self.assertEqual(1, new_gen) - db = 
self.request_state.open_database('test') - self.assertEqual(1, len(replica_uid_box)) - self.assertEqual(db._replica_uid, replica_uid_box[0]) - self.assertGetDoc( - db, 'doc-here', 'replica:1', '{"value": "here"}', False) - - -load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_remote_utils.py b/src/leap/soledad/u1db/tests/test_remote_utils.py deleted file mode 100644 index 959cd882..00000000 --- a/src/leap/soledad/u1db/tests/test_remote_utils.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Tests for protocol details utils.""" - -from u1db.tests import TestCase -from u1db.remote import utils - - -class TestUtils(TestCase): - - def test_check_and_strip_comma(self): - line, comma = utils.check_and_strip_comma("abc,") - self.assertTrue(comma) - self.assertEqual("abc", line) - - line, comma = utils.check_and_strip_comma("abc") - self.assertFalse(comma) - self.assertEqual("abc", line) - - line, comma = utils.check_and_strip_comma("") - self.assertFalse(comma) - self.assertEqual("", line) diff --git a/src/leap/soledad/u1db/tests/test_server_state.py b/src/leap/soledad/u1db/tests/test_server_state.py deleted file mode 100644 index fc3f1282..00000000 --- a/src/leap/soledad/u1db/tests/test_server_state.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. 
-# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Tests for server state object.""" - -import os - -from u1db import ( - errors, - tests, - ) -from u1db.remote import ( - server_state, - ) -from u1db.backends import sqlite_backend - - -class TestServerState(tests.TestCase): - - def setUp(self): - super(TestServerState, self).setUp() - self.state = server_state.ServerState() - - def test_set_workingdir(self): - tempdir = self.createTempDir() - self.state.set_workingdir(tempdir) - self.assertTrue(self.state._relpath('path').startswith(tempdir)) - - def test_open_database(self): - tempdir = self.createTempDir() - self.state.set_workingdir(tempdir) - path = tempdir + '/test.db' - self.assertFalse(os.path.exists(path)) - # Create the db, but don't do anything with it - sqlite_backend.SQLitePartialExpandDatabase(path) - db = self.state.open_database('test.db') - self.assertIsInstance(db, sqlite_backend.SQLitePartialExpandDatabase) - - def test_check_database(self): - tempdir = self.createTempDir() - self.state.set_workingdir(tempdir) - path = tempdir + '/test.db' - self.assertFalse(os.path.exists(path)) - - # doesn't exist => raises - self.assertRaises(errors.DatabaseDoesNotExist, - self.state.check_database, 'test.db') - - # Create the db, but don't do anything with it - sqlite_backend.SQLitePartialExpandDatabase(path) - # exists => returns - res = self.state.check_database('test.db') - self.assertIsNone(res) - - def test_ensure_database(self): - tempdir = 
self.createTempDir() - self.state.set_workingdir(tempdir) - path = tempdir + '/test.db' - self.assertFalse(os.path.exists(path)) - db, replica_uid = self.state.ensure_database('test.db') - self.assertIsInstance(db, sqlite_backend.SQLitePartialExpandDatabase) - self.assertEqual(db._replica_uid, replica_uid) - self.assertTrue(os.path.exists(path)) - db2 = self.state.open_database('test.db') - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) - - def test_delete_database(self): - tempdir = self.createTempDir() - self.state.set_workingdir(tempdir) - path = tempdir + '/test.db' - db, _ = self.state.ensure_database('test.db') - db.close() - self.state.delete_database('test.db') - self.assertFalse(os.path.exists(path)) - - def test_delete_database_DoesNotExist(self): - tempdir = self.createTempDir() - self.state.set_workingdir(tempdir) - self.assertRaises(errors.DatabaseDoesNotExist, - self.state.delete_database, 'test.db') diff --git a/src/leap/soledad/u1db/tests/test_sqlite_backend.py b/src/leap/soledad/u1db/tests/test_sqlite_backend.py deleted file mode 100644 index 73330789..00000000 --- a/src/leap/soledad/u1db/tests/test_sqlite_backend.py +++ /dev/null @@ -1,493 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""Test sqlite backend internals.""" - -import os -import time -import threading - -from sqlite3 import dbapi2 - -from u1db import ( - errors, - tests, - query_parser, - ) -from u1db.backends import sqlite_backend -from u1db.tests.test_backends import TestAlternativeDocument - - -simple_doc = '{"key": "value"}' -nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}' - - -class TestSQLiteDatabase(tests.TestCase): - - def test_atomic_initialize(self): - tmpdir = self.createTempDir() - dbname = os.path.join(tmpdir, 'atomic.db') - - t2 = None # will be a thread - - class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): - _index_storage_value = "testing" - - def __init__(self, dbname, ntry): - self._try = ntry - self._is_initialized_invocations = 0 - super(SQLiteDatabaseTesting, self).__init__(dbname) - - def _is_initialized(self, c): - res = super(SQLiteDatabaseTesting, self)._is_initialized(c) - if self._try == 1: - self._is_initialized_invocations += 1 - if self._is_initialized_invocations == 2: - t2.start() - # hard to do better and have a generic test - time.sleep(0.05) - return res - - outcome2 = [] - - def second_try(): - try: - db2 = SQLiteDatabaseTesting(dbname, 2) - except Exception, e: - outcome2.append(e) - else: - outcome2.append(db2) - - t2 = threading.Thread(target=second_try) - db1 = SQLiteDatabaseTesting(dbname, 1) - t2.join() - - self.assertIsInstance(outcome2[0], SQLiteDatabaseTesting) - db2 = outcome2[0] - self.assertTrue(db2._is_initialized(db1._get_sqlite_handle().cursor())) - - -class TestSQLitePartialExpandDatabase(tests.TestCase): - - def setUp(self): - super(TestSQLitePartialExpandDatabase, self).setUp() - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') - self.db._set_replica_uid('test') - - def test_create_database(self): - raw_db = self.db._get_sqlite_handle() - self.assertNotEqual(None, raw_db) - - def test_default_replica_uid(self): - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') - 
self.assertIsNot(None, self.db._replica_uid) - self.assertEqual(32, len(self.db._replica_uid)) - int(self.db._replica_uid, 16) - - def test__close_sqlite_handle(self): - raw_db = self.db._get_sqlite_handle() - self.db._close_sqlite_handle() - self.assertRaises(dbapi2.ProgrammingError, - raw_db.cursor) - - def test_create_database_initializes_schema(self): - raw_db = self.db._get_sqlite_handle() - c = raw_db.cursor() - c.execute("SELECT * FROM u1db_config") - config = dict([(r[0], r[1]) for r in c.fetchall()]) - self.assertEqual({'sql_schema': '0', 'replica_uid': 'test', - 'index_storage': 'expand referenced'}, config) - - # These tables must exist, though we don't care what is in them yet - c.execute("SELECT * FROM transaction_log") - c.execute("SELECT * FROM document") - c.execute("SELECT * FROM document_fields") - c.execute("SELECT * FROM sync_log") - c.execute("SELECT * FROM conflicts") - c.execute("SELECT * FROM index_definitions") - - def test__parse_index(self): - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') - g = self.db._parse_index_definition('fieldname') - self.assertIsInstance(g, query_parser.ExtractField) - self.assertEqual(['fieldname'], g.field) - - def test__update_indexes(self): - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') - g = self.db._parse_index_definition('fieldname') - c = self.db._get_sqlite_handle().cursor() - self.db._update_indexes('doc-id', {'fieldname': 'val'}, - [('fieldname', g)], c) - c.execute('SELECT doc_id, field_name, value FROM document_fields') - self.assertEqual([('doc-id', 'fieldname', 'val')], - c.fetchall()) - - def test__set_replica_uid(self): - # Start from scratch, so that replica_uid isn't set. 
- self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') - self.assertIsNot(None, self.db._real_replica_uid) - self.assertIsNot(None, self.db._replica_uid) - self.db._set_replica_uid('foo') - c = self.db._get_sqlite_handle().cursor() - c.execute("SELECT value FROM u1db_config WHERE name='replica_uid'") - self.assertEqual(('foo',), c.fetchone()) - self.assertEqual('foo', self.db._real_replica_uid) - self.assertEqual('foo', self.db._replica_uid) - self.db._close_sqlite_handle() - self.assertEqual('foo', self.db._replica_uid) - - def test__get_generation(self): - self.assertEqual(0, self.db._get_generation()) - - def test__get_generation_info(self): - self.assertEqual((0, ''), self.db._get_generation_info()) - - def test_create_index(self): - self.db.create_index('test-idx', "key") - self.assertEqual([('test-idx', ["key"])], self.db.list_indexes()) - - def test_create_index_multiple_fields(self): - self.db.create_index('test-idx', "key", "key2") - self.assertEqual([('test-idx', ["key", "key2"])], - self.db.list_indexes()) - - def test__get_index_definition(self): - self.db.create_index('test-idx', "key", "key2") - # TODO: How would you test that an index is getting used for an SQL - # request? - self.assertEqual(["key", "key2"], - self.db._get_index_definition('test-idx')) - - def test_list_index_mixed(self): - # Make sure that we properly order the output - c = self.db._get_sqlite_handle().cursor() - # We intentionally insert the data in weird ordering, to make sure the - # query still gets it back correctly. 
- c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", - [('idx-1', 0, 'key10'), - ('idx-2', 2, 'key22'), - ('idx-1', 1, 'key11'), - ('idx-2', 0, 'key20'), - ('idx-2', 1, 'key21')]) - self.assertEqual([('idx-1', ['key10', 'key11']), - ('idx-2', ['key20', 'key21', 'key22'])], - self.db.list_indexes()) - - def test_no_indexes_no_document_fields(self): - self.db.create_doc_from_json( - '{"key1": "val1", "key2": "val2"}') - c = self.db._get_sqlite_handle().cursor() - c.execute("SELECT doc_id, field_name, value FROM document_fields" - " ORDER BY doc_id, field_name, value") - self.assertEqual([], c.fetchall()) - - def test_create_extracts_fields(self): - doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') - doc2 = self.db.create_doc_from_json('{"key1": "valx", "key2": "valy"}') - c = self.db._get_sqlite_handle().cursor() - c.execute("SELECT doc_id, field_name, value FROM document_fields" - " ORDER BY doc_id, field_name, value") - self.assertEqual([], c.fetchall()) - self.db.create_index('test', 'key1', 'key2') - c.execute("SELECT doc_id, field_name, value FROM document_fields" - " ORDER BY doc_id, field_name, value") - self.assertEqual(sorted( - [(doc1.doc_id, "key1", "val1"), - (doc1.doc_id, "key2", "val2"), - (doc2.doc_id, "key1", "valx"), - (doc2.doc_id, "key2", "valy"), - ]), sorted(c.fetchall())) - - def test_put_updates_fields(self): - self.db.create_index('test', 'key1', 'key2') - doc1 = self.db.create_doc_from_json( - '{"key1": "val1", "key2": "val2"}') - doc1.content = {"key1": "val1", "key2": "valy"} - self.db.put_doc(doc1) - c = self.db._get_sqlite_handle().cursor() - c.execute("SELECT doc_id, field_name, value FROM document_fields" - " ORDER BY doc_id, field_name, value") - self.assertEqual([(doc1.doc_id, "key1", "val1"), - (doc1.doc_id, "key2", "valy"), - ], c.fetchall()) - - def test_put_updates_nested_fields(self): - self.db.create_index('test', 'key', 'sub.doc') - doc1 = self.db.create_doc_from_json(nested_doc) - c = 
self.db._get_sqlite_handle().cursor() - c.execute("SELECT doc_id, field_name, value FROM document_fields" - " ORDER BY doc_id, field_name, value") - self.assertEqual([(doc1.doc_id, "key", "value"), - (doc1.doc_id, "sub.doc", "underneath"), - ], c.fetchall()) - - def test__ensure_schema_rollback(self): - temp_dir = self.createTempDir(prefix='u1db-test-') - path = temp_dir + '/rollback.db' - - class SQLitePartialExpandDbTesting( - sqlite_backend.SQLitePartialExpandDatabase): - - def _set_replica_uid_in_transaction(self, uid): - super(SQLitePartialExpandDbTesting, - self)._set_replica_uid_in_transaction(uid) - if fail: - raise Exception() - - db = SQLitePartialExpandDbTesting.__new__(SQLitePartialExpandDbTesting) - db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed - fail = True - self.assertRaises(Exception, db._ensure_schema) - fail = False - db._initialize(db._db_handle.cursor()) - - def test__open_database(self): - temp_dir = self.createTempDir(prefix='u1db-test-') - path = temp_dir + '/test.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = sqlite_backend.SQLiteDatabase._open_database(path) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) - - def test__open_database_with_factory(self): - temp_dir = self.createTempDir(prefix='u1db-test-') - path = temp_dir + '/test.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = sqlite_backend.SQLiteDatabase._open_database( - path, document_factory=TestAlternativeDocument) - self.assertEqual(TestAlternativeDocument, db2._factory) - - def test__open_database_non_existent(self): - temp_dir = self.createTempDir(prefix='u1db-test-') - path = temp_dir + '/non-existent.sqlite' - self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase._open_database, path) - - def test__open_database_during_init(self): - temp_dir = self.createTempDir(prefix='u1db-test-') - path = temp_dir + '/initialised.db' - db = 
sqlite_backend.SQLitePartialExpandDatabase.__new__( - sqlite_backend.SQLitePartialExpandDatabase) - db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed - self.addCleanup(db.close) - observed = [] - - class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): - WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 - - @classmethod - def _which_index_storage(cls, c): - res = super(SQLiteDatabaseTesting, cls)._which_index_storage(c) - db._ensure_schema() # init db - observed.append(res[0]) - return res - - db2 = SQLiteDatabaseTesting._open_database(path) - self.addCleanup(db2.close) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) - self.assertEqual([None, - sqlite_backend.SQLitePartialExpandDatabase._index_storage_value], - observed) - - def test__open_database_invalid(self): - class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): - WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 - temp_dir = self.createTempDir(prefix='u1db-test-') - path1 = temp_dir + '/invalid1.db' - with open(path1, 'wb') as f: - f.write("") - self.assertRaises(dbapi2.OperationalError, - SQLiteDatabaseTesting._open_database, path1) - with open(path1, 'wb') as f: - f.write("invalid") - self.assertRaises(dbapi2.DatabaseError, - SQLiteDatabaseTesting._open_database, path1) - - def test_open_database_existing(self): - temp_dir = self.createTempDir(prefix='u1db-test-') - path = temp_dir + '/existing.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) - - def test_open_database_with_factory(self): - temp_dir = self.createTempDir(prefix='u1db-test-') - path = temp_dir + '/existing.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = sqlite_backend.SQLiteDatabase.open_database( - path, create=False, document_factory=TestAlternativeDocument) - self.assertEqual(TestAlternativeDocument, db2._factory) - - def 
test_open_database_create(self): - temp_dir = self.createTempDir(prefix='u1db-test-') - path = temp_dir + '/new.sqlite' - sqlite_backend.SQLiteDatabase.open_database(path, create=True) - db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) - - def test_open_database_non_existent(self): - temp_dir = self.createTempDir(prefix='u1db-test-') - path = temp_dir + '/non-existent.sqlite' - self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase.open_database, path, - create=False) - - def test_delete_database_existent(self): - temp_dir = self.createTempDir(prefix='u1db-test-') - path = temp_dir + '/new.sqlite' - db = sqlite_backend.SQLiteDatabase.open_database(path, create=True) - db.close() - sqlite_backend.SQLiteDatabase.delete_database(path) - self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase.open_database, path, - create=False) - - def test_delete_database_nonexistent(self): - temp_dir = self.createTempDir(prefix='u1db-test-') - path = temp_dir + '/non-existent.sqlite' - self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase.delete_database, path) - - def test__get_indexed_fields(self): - self.db.create_index('idx1', 'a', 'b') - self.assertEqual(set(['a', 'b']), self.db._get_indexed_fields()) - self.db.create_index('idx2', 'b', 'c') - self.assertEqual(set(['a', 'b', 'c']), self.db._get_indexed_fields()) - - def test_indexed_fields_expanded(self): - self.db.create_index('idx1', 'key1') - doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') - self.assertEqual(set(['key1']), self.db._get_indexed_fields()) - c = self.db._get_sqlite_handle().cursor() - c.execute("SELECT doc_id, field_name, value FROM document_fields" - " ORDER BY doc_id, field_name, value") - self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall()) - - def test_create_index_updates_fields(self): - doc1 = 
self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') - self.db.create_index('idx1', 'key1') - self.assertEqual(set(['key1']), self.db._get_indexed_fields()) - c = self.db._get_sqlite_handle().cursor() - c.execute("SELECT doc_id, field_name, value FROM document_fields" - " ORDER BY doc_id, field_name, value") - self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall()) - - def assertFormatQueryEquals(self, exp_statement, exp_args, definition, - values): - statement, args = self.db._format_query(definition, values) - self.assertEqual(exp_statement, statement) - self.assertEqual(exp_args, args) - - def test__format_query(self): - self.assertFormatQueryEquals( - "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " - "document d, document_fields d0 LEFT OUTER JOIN conflicts c ON " - "c.doc_id = d.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name " - "= ? AND d0.value = ? GROUP BY d.doc_id, d.doc_rev, d.content " - "ORDER BY d0.value;", ["key1", "a"], - ["key1"], ["a"]) - - def test__format_query2(self): - self.assertFormatQueryEquals( - 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' - 'document d, document_fields d0, document_fields d1, ' - 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' - 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' - 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' - 'd1.value = ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' - 'd2.value = ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' - 'd0.value, d1.value, d2.value;', - ["key1", "a", "key2", "b", "key3", "c"], - ["key1", "key2", "key3"], ["a", "b", "c"]) - - def test__format_query_wildcard(self): - self.assertFormatQueryEquals( - 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' - 'document d, document_fields d0, document_fields d1, ' - 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' - 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? 
AND ' - 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' - 'd1.value GLOB ? AND d.doc_id = d2.doc_id AND d2.field_name = ? ' - 'AND d2.value NOT NULL GROUP BY d.doc_id, d.doc_rev, d.content ' - 'ORDER BY d0.value, d1.value, d2.value;', - ["key1", "a", "key2", "b*", "key3"], ["key1", "key2", "key3"], - ["a", "b*", "*"]) - - def assertFormatRangeQueryEquals(self, exp_statement, exp_args, definition, - start_value, end_value): - statement, args = self.db._format_range_query( - definition, start_value, end_value) - self.assertEqual(exp_statement, statement) - self.assertEqual(exp_args, args) - - def test__format_range_query(self): - self.assertFormatRangeQueryEquals( - 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' - 'document d, document_fields d0, document_fields d1, ' - 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' - 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' - 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' - 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' - 'd2.value >= ? AND d.doc_id = d0.doc_id AND d0.field_name = ? AND ' - 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' - 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' - 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' - 'd0.value, d1.value, d2.value;', - ['key1', 'a', 'key2', 'b', 'key3', 'c', 'key1', 'p', 'key2', 'q', - 'key3', 'r'], - ["key1", "key2", "key3"], ["a", "b", "c"], ["p", "q", "r"]) - - def test__format_range_query_no_start(self): - self.assertFormatRangeQueryEquals( - 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' - 'document d, document_fields d0, document_fields d1, ' - 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' - 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' - 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' - 'd1.value <= ? 
AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' - 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' - 'd0.value, d1.value, d2.value;', - ['key1', 'a', 'key2', 'b', 'key3', 'c'], - ["key1", "key2", "key3"], None, ["a", "b", "c"]) - - def test__format_range_query_no_end(self): - self.assertFormatRangeQueryEquals( - 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' - 'document d, document_fields d0, document_fields d1, ' - 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' - 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' - 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' - 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' - 'd2.value >= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' - 'd0.value, d1.value, d2.value;', - ['key1', 'a', 'key2', 'b', 'key3', 'c'], - ["key1", "key2", "key3"], ["a", "b", "c"], None) - - def test__format_range_query_wildcard(self): - self.assertFormatRangeQueryEquals( - 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' - 'document d, document_fields d0, document_fields d1, ' - 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' - 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' - 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' - 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' - 'd2.value NOT NULL AND d.doc_id = d0.doc_id AND d0.field_name = ? ' - 'AND d0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? ' - 'AND (d1.value < ? OR d1.value GLOB ?) AND d.doc_id = d2.doc_id ' - 'AND d2.field_name = ? 
AND d2.value NOT NULL GROUP BY d.doc_id, ' - 'd.doc_rev, d.content ORDER BY d0.value, d1.value, d2.value;', - ['key1', 'a', 'key2', 'b', 'key3', 'key1', 'p', 'key2', 'q', 'q*', - 'key3'], - ["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"]) diff --git a/src/leap/soledad/u1db/tests/test_sync.py b/src/leap/soledad/u1db/tests/test_sync.py deleted file mode 100644 index f2a925f0..00000000 --- a/src/leap/soledad/u1db/tests/test_sync.py +++ /dev/null @@ -1,1285 +0,0 @@ -# Copyright 2011-2012 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""The Synchronization class for U1DB.""" - -import os -from wsgiref import simple_server - -from u1db import ( - errors, - sync, - tests, - vectorclock, - SyncTarget, - ) -from u1db.backends import ( - inmemory, - ) -from u1db.remote import ( - http_target, - ) - -from u1db.tests.test_remote_sync_target import ( - make_http_app, - make_oauth_http_app, - ) - -simple_doc = tests.simple_doc -nested_doc = tests.nested_doc - - -def _make_local_db_and_target(test): - db = test.create_database('test') - st = db.get_sync_target() - return db, st - - -def _make_local_db_and_http_target(test, path='test'): - test.startServer() - db = test.request_state._create_database(os.path.basename(path)) - st = http_target.HTTPSyncTarget.connect(test.getURL(path)) - return db, st - - -def _make_c_db_and_c_http_target(test, path='test'): - test.startServer() - db = test.request_state._create_database(os.path.basename(path)) - url = test.getURL(path) - st = tests.c_backend_wrapper.create_http_sync_target(url) - return db, st - - -def _make_local_db_and_oauth_http_target(test): - db, st = _make_local_db_and_http_target(test, '~/test') - st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, - tests.token1.key, tests.token1.secret) - return db, st - - -def _make_c_db_and_oauth_http_target(test, path='~/test'): - test.startServer() - db = test.request_state._create_database(os.path.basename(path)) - url = test.getURL(path) - st = tests.c_backend_wrapper.create_oauth_http_sync_target(url, - tests.consumer1.key, tests.consumer1.secret, - tests.token1.key, tests.token1.secret) - return db, st - - -target_scenarios = [ - ('local', {'create_db_and_target': _make_local_db_and_target}), - ('http', {'create_db_and_target': _make_local_db_and_http_target, - 'make_app_with_state': make_http_app}), - ('oauth_http', {'create_db_and_target': - _make_local_db_and_oauth_http_target, - 'make_app_with_state': make_oauth_http_app}), - ] - -c_db_scenarios = [ - ('local,c', 
{'create_db_and_target': _make_local_db_and_target, - 'make_database_for_test': tests.make_c_database_for_test, - 'copy_database_for_test': tests.copy_c_database_for_test, - 'make_document_for_test': tests.make_c_document_for_test, - 'whitebox': False}), - ('http,c', {'create_db_and_target': _make_c_db_and_c_http_target, - 'make_database_for_test': tests.make_c_database_for_test, - 'copy_database_for_test': tests.copy_c_database_for_test, - 'make_document_for_test': tests.make_c_document_for_test, - 'make_app_with_state': make_http_app, - 'whitebox': False}), - ('oauth_http,c', {'create_db_and_target': _make_c_db_and_oauth_http_target, - 'make_database_for_test': tests.make_c_database_for_test, - 'copy_database_for_test': tests.copy_c_database_for_test, - 'make_document_for_test': tests.make_c_document_for_test, - 'make_app_with_state': make_oauth_http_app, - 'whitebox': False}), - ] - - -class DatabaseSyncTargetTests(tests.DatabaseBaseTests, - tests.TestCaseWithServer): - - scenarios = (tests.multiply_scenarios(tests.DatabaseBaseTests.scenarios, - target_scenarios) - + c_db_scenarios) - # whitebox true means self.db is the actual local db object - # against which the sync is performed - whitebox = True - - def setUp(self): - super(DatabaseSyncTargetTests, self).setUp() - self.db, self.st = self.create_db_and_target(self) - self.other_changes = [] - - def tearDown(self): - # We delete them explicitly, so that connections are cleanly closed - del self.st - self.db.close() - del self.db - super(DatabaseSyncTargetTests, self).tearDown() - - def receive_doc(self, doc, gen, trans_id): - self.other_changes.append( - (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) - - def set_trace_hook(self, callback, shallow=False): - setter = (self.st._set_trace_hook if not shallow else - self.st._set_trace_hook_shallow) - try: - setter(callback) - except NotImplementedError: - self.skipTest("%s does not implement _set_trace_hook" - % (self.st.__class__.__name__,)) - - def 
test_get_sync_target(self): - self.assertIsNot(None, self.st) - - def test_get_sync_info(self): - self.assertEqual( - ('test', 0, '', 0, ''), self.st.get_sync_info('other')) - - def test_create_doc_updates_sync_info(self): - self.assertEqual( - ('test', 0, '', 0, ''), self.st.get_sync_info('other')) - self.db.create_doc_from_json(simple_doc) - self.assertEqual(1, self.st.get_sync_info('other')[1]) - - def test_record_sync_info(self): - self.st.record_sync_info('replica', 10, 'T-transid') - self.assertEqual( - ('test', 0, '', 10, 'T-transid'), self.st.get_sync_info('replica')) - - def test_sync_exchange(self): - docs_by_gen = [ - (self.make_document('doc-id', 'replica:1', simple_doc), 10, - 'T-sid')] - new_gen, trans_id = self.st.sync_exchange( - docs_by_gen, 'replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=self.receive_doc) - self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False) - self.assertTransactionLog(['doc-id'], self.db) - last_trans_id = self.getLastTransId(self.db) - self.assertEqual(([], 1, last_trans_id), - (self.other_changes, new_gen, last_trans_id)) - self.assertEqual(10, self.st.get_sync_info('replica')[3]) - - def test_sync_exchange_deleted(self): - doc = self.db.create_doc_from_json('{}') - edit_rev = 'replica:1|' + doc.rev - docs_by_gen = [ - (self.make_document(doc.doc_id, edit_rev, None), 10, 'T-sid')] - new_gen, trans_id = self.st.sync_exchange( - docs_by_gen, 'replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=self.receive_doc) - self.assertGetDocIncludeDeleted( - self.db, doc.doc_id, edit_rev, None, False) - self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) - last_trans_id = self.getLastTransId(self.db) - self.assertEqual(([], 2, last_trans_id), - (self.other_changes, new_gen, trans_id)) - self.assertEqual(10, self.st.get_sync_info('replica')[3]) - - def test_sync_exchange_push_many(self): - docs_by_gen = [ - (self.make_document('doc-id', 'replica:1', simple_doc), 
10, 'T-1'), - (self.make_document('doc-id2', 'replica:1', nested_doc), 11, - 'T-2')] - new_gen, trans_id = self.st.sync_exchange( - docs_by_gen, 'replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=self.receive_doc) - self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False) - self.assertGetDoc(self.db, 'doc-id2', 'replica:1', nested_doc, False) - self.assertTransactionLog(['doc-id', 'doc-id2'], self.db) - last_trans_id = self.getLastTransId(self.db) - self.assertEqual(([], 2, last_trans_id), - (self.other_changes, new_gen, trans_id)) - self.assertEqual(11, self.st.get_sync_info('replica')[3]) - - def test_sync_exchange_refuses_conflicts(self): - doc = self.db.create_doc_from_json(simple_doc) - self.assertTransactionLog([doc.doc_id], self.db) - new_doc = '{"key": "altval"}' - docs_by_gen = [ - (self.make_document(doc.doc_id, 'replica:1', new_doc), 10, - 'T-sid')] - new_gen, _ = self.st.sync_exchange( - docs_by_gen, 'replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=self.receive_doc) - self.assertTransactionLog([doc.doc_id], self.db) - self.assertEqual( - (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1]) - self.assertEqual(1, new_gen) - if self.whitebox: - self.assertEqual(self.db._last_exchange_log['return'], - {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]}) - - def test_sync_exchange_ignores_convergence(self): - doc = self.db.create_doc_from_json(simple_doc) - self.assertTransactionLog([doc.doc_id], self.db) - gen, txid = self.db._get_generation_info() - docs_by_gen = [ - (self.make_document(doc.doc_id, doc.rev, simple_doc), 10, 'T-sid')] - new_gen, _ = self.st.sync_exchange( - docs_by_gen, 'replica', last_known_generation=gen, - last_known_trans_id=txid, return_doc_cb=self.receive_doc) - self.assertTransactionLog([doc.doc_id], self.db) - self.assertEqual(([], 1), (self.other_changes, new_gen)) - - def test_sync_exchange_returns_new_docs(self): - doc = 
self.db.create_doc_from_json(simple_doc) - self.assertTransactionLog([doc.doc_id], self.db) - new_gen, _ = self.st.sync_exchange( - [], 'other-replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=self.receive_doc) - self.assertTransactionLog([doc.doc_id], self.db) - self.assertEqual( - (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1]) - self.assertEqual(1, new_gen) - if self.whitebox: - self.assertEqual(self.db._last_exchange_log['return'], - {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]}) - - def test_sync_exchange_returns_deleted_docs(self): - doc = self.db.create_doc_from_json(simple_doc) - self.db.delete_doc(doc) - self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) - new_gen, _ = self.st.sync_exchange( - [], 'other-replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=self.receive_doc) - self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) - self.assertEqual( - (doc.doc_id, doc.rev, None, 2), self.other_changes[0][:-1]) - self.assertEqual(2, new_gen) - if self.whitebox: - self.assertEqual(self.db._last_exchange_log['return'], - {'last_gen': 2, 'docs': [(doc.doc_id, doc.rev)]}) - - def test_sync_exchange_returns_many_new_docs(self): - doc = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db) - new_gen, _ = self.st.sync_exchange( - [], 'other-replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=self.receive_doc) - self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db) - self.assertEqual(2, new_gen) - self.assertEqual( - [(doc.doc_id, doc.rev, simple_doc, 1), - (doc2.doc_id, doc2.rev, nested_doc, 2)], - [c[:-1] for c in self.other_changes]) - if self.whitebox: - self.assertEqual( - self.db._last_exchange_log['return'], - {'last_gen': 2, 'docs': - [(doc.doc_id, doc.rev), (doc2.doc_id, doc2.rev)]}) - - def test_sync_exchange_getting_newer_docs(self): - 
doc = self.db.create_doc_from_json(simple_doc) - self.assertTransactionLog([doc.doc_id], self.db) - new_doc = '{"key": "altval"}' - docs_by_gen = [ - (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10, - 'T-sid')] - new_gen, _ = self.st.sync_exchange( - docs_by_gen, 'other-replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=self.receive_doc) - self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) - self.assertEqual(([], 2), (self.other_changes, new_gen)) - - def test_sync_exchange_with_concurrent_updates_of_synced_doc(self): - expected = [] - - def before_whatschanged_cb(state): - if state != 'before whats_changed': - return - cont = '{"key": "cuncurrent"}' - conc_rev = self.db.put_doc( - self.make_document(doc.doc_id, 'test:1|z:2', cont)) - expected.append((doc.doc_id, conc_rev, cont, 3)) - - self.set_trace_hook(before_whatschanged_cb) - doc = self.db.create_doc_from_json(simple_doc) - self.assertTransactionLog([doc.doc_id], self.db) - new_doc = '{"key": "altval"}' - docs_by_gen = [ - (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10, - 'T-sid')] - new_gen, _ = self.st.sync_exchange( - docs_by_gen, 'other-replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=self.receive_doc) - self.assertEqual(expected, [c[:-1] for c in self.other_changes]) - self.assertEqual(3, new_gen) - - def test_sync_exchange_with_concurrent_updates(self): - - def after_whatschanged_cb(state): - if state != 'after whats_changed': - return - self.db.create_doc_from_json('{"new": "doc"}') - - self.set_trace_hook(after_whatschanged_cb) - doc = self.db.create_doc_from_json(simple_doc) - self.assertTransactionLog([doc.doc_id], self.db) - new_doc = '{"key": "altval"}' - docs_by_gen = [ - (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10, - 'T-sid')] - new_gen, _ = self.st.sync_exchange( - docs_by_gen, 'other-replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=self.receive_doc) - 
self.assertEqual(([], 2), (self.other_changes, new_gen)) - - def test_sync_exchange_converged_handling(self): - doc = self.db.create_doc_from_json(simple_doc) - docs_by_gen = [ - (self.make_document('new', 'other:1', '{}'), 4, 'T-foo'), - (self.make_document(doc.doc_id, doc.rev, doc.get_json()), 5, - 'T-bar')] - new_gen, _ = self.st.sync_exchange( - docs_by_gen, 'other-replica', last_known_generation=0, - last_known_trans_id=None, return_doc_cb=self.receive_doc) - self.assertEqual(([], 2), (self.other_changes, new_gen)) - - def test_sync_exchange_detect_incomplete_exchange(self): - def before_get_docs_explode(state): - if state != 'before get_docs': - return - raise errors.U1DBError("fail") - self.set_trace_hook(before_get_docs_explode) - # suppress traceback printing in the wsgiref server - self.patch(simple_server.ServerHandler, - 'log_exception', lambda h, exc_info: None) - doc = self.db.create_doc_from_json(simple_doc) - self.assertTransactionLog([doc.doc_id], self.db) - self.assertRaises( - (errors.U1DBError, errors.BrokenSyncStream), - self.st.sync_exchange, [], 'other-replica', - last_known_generation=0, last_known_trans_id=None, - return_doc_cb=self.receive_doc) - - def test_sync_exchange_doc_ids(self): - sync_exchange_doc_ids = getattr(self.st, 'sync_exchange_doc_ids', None) - if sync_exchange_doc_ids is None: - self.skipTest("sync_exchange_doc_ids not implemented") - db2 = self.create_database('test2') - doc = db2.create_doc_from_json(simple_doc) - new_gen, trans_id = sync_exchange_doc_ids( - db2, [(doc.doc_id, 10, 'T-sid')], 0, None, - return_doc_cb=self.receive_doc) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) - self.assertTransactionLog([doc.doc_id], self.db) - last_trans_id = self.getLastTransId(self.db) - self.assertEqual(([], 1, last_trans_id), - (self.other_changes, new_gen, trans_id)) - self.assertEqual(10, self.st.get_sync_info(db2._replica_uid)[3]) - - def test__set_trace_hook(self): - called = [] - - def cb(state): - 
called.append(state) - - self.set_trace_hook(cb) - self.st.sync_exchange([], 'replica', 0, None, self.receive_doc) - self.st.record_sync_info('replica', 0, 'T-sid') - self.assertEqual(['before whats_changed', - 'after whats_changed', - 'before get_docs', - 'record_sync_info', - ], - called) - - def test__set_trace_hook_shallow(self): - if (self.st._set_trace_hook_shallow == self.st._set_trace_hook - or self.st._set_trace_hook_shallow.im_func == - SyncTarget._set_trace_hook_shallow.im_func): - # shallow same as full - expected = ['before whats_changed', - 'after whats_changed', - 'before get_docs', - 'record_sync_info', - ] - else: - expected = ['sync_exchange', 'record_sync_info'] - - called = [] - - def cb(state): - called.append(state) - - self.set_trace_hook(cb, shallow=True) - self.st.sync_exchange([], 'replica', 0, None, self.receive_doc) - self.st.record_sync_info('replica', 0, 'T-sid') - self.assertEqual(expected, called) - - -def sync_via_synchronizer(test, db_source, db_target, trace_hook=None, - trace_hook_shallow=None): - target = db_target.get_sync_target() - trace_hook = trace_hook or trace_hook_shallow - if trace_hook: - target._set_trace_hook(trace_hook) - return sync.Synchronizer(db_source, target).sync() - - -sync_scenarios = [] -for name, scenario in tests.LOCAL_DATABASES_SCENARIOS: - scenario = dict(scenario) - scenario['do_sync'] = sync_via_synchronizer - sync_scenarios.append((name, scenario)) - scenario = dict(scenario) - - -def make_database_for_http_test(test, replica_uid): - if test.server is None: - test.startServer() - db = test.request_state._create_database(replica_uid) - try: - http_at = test._http_at - except AttributeError: - http_at = test._http_at = {} - http_at[db] = replica_uid - return db - - -def copy_database_for_http_test(test, db): - # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS - # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE - # CORRECTLY DETECT IT HAPPENING SO THAT WE 
CAN RAISE ERRORS RATHER THAN - # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR HOUSE. - if test.server is None: - test.startServer() - new_db = test.request_state._copy_database(db) - try: - http_at = test._http_at - except AttributeError: - http_at = test._http_at = {} - path = db._replica_uid - while path in http_at.values(): - path += 'copy' - http_at[new_db] = path - return new_db - - -def sync_via_synchronizer_and_http(test, db_source, db_target, - trace_hook=None, trace_hook_shallow=None): - if trace_hook: - test.skipTest("full trace hook unsupported over http") - path = test._http_at[db_target] - target = http_target.HTTPSyncTarget.connect(test.getURL(path)) - if trace_hook_shallow: - target._set_trace_hook_shallow(trace_hook_shallow) - return sync.Synchronizer(db_source, target).sync() - - -sync_scenarios.append(('pyhttp', { - 'make_database_for_test': make_database_for_http_test, - 'copy_database_for_test': copy_database_for_http_test, - 'make_document_for_test': tests.make_document_for_test, - 'make_app_with_state': make_http_app, - 'do_sync': sync_via_synchronizer_and_http - })) - - -if tests.c_backend_wrapper is not None: - # TODO: We should hook up sync tests with an HTTP target - def sync_via_c_sync(test, db_source, db_target, trace_hook=None, - trace_hook_shallow=None): - target = db_target.get_sync_target() - trace_hook = trace_hook or trace_hook_shallow - if trace_hook: - target._set_trace_hook(trace_hook) - return tests.c_backend_wrapper.sync_db_to_target(db_source, target) - - for name, scenario in tests.C_DATABASE_SCENARIOS: - scenario = dict(scenario) - scenario['do_sync'] = sync_via_synchronizer - sync_scenarios.append((name + ',pysync', scenario)) - scenario = dict(scenario) - scenario['do_sync'] = sync_via_c_sync - sync_scenarios.append((name + ',csync', scenario)) - - -class DatabaseSyncTests(tests.DatabaseBaseTests, - tests.TestCaseWithServer): - - scenarios = sync_scenarios - do_sync = None # set by scenarios - - def 
create_database(self, replica_uid, sync_role=None): - if replica_uid == 'test' and sync_role is None: - # created up the chain by base class but unused - return None - db = self.create_database_for_role(replica_uid, sync_role) - if sync_role: - self._use_tracking[db] = (replica_uid, sync_role) - return db - - def create_database_for_role(self, replica_uid, sync_role): - # hook point for reuse - return super(DatabaseSyncTests, self).create_database(replica_uid) - - def copy_database(self, db, sync_role=None): - # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES - # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST - # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS - # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND - # NINJA TO YOUR HOUSE. - db_copy = super(DatabaseSyncTests, self).copy_database(db) - name, orig_sync_role = self._use_tracking[db] - self._use_tracking[db_copy] = (name + '(copy)', sync_role - or orig_sync_role) - return db_copy - - def sync(self, db_from, db_to, trace_hook=None, - trace_hook_shallow=None): - from_name, from_sync_role = self._use_tracking[db_from] - to_name, to_sync_role = self._use_tracking[db_to] - if from_sync_role not in ('source', 'both'): - raise Exception("%s marked for %s use but used as source" % - (from_name, from_sync_role)) - if to_sync_role not in ('target', 'both'): - raise Exception("%s marked for %s use but used as target" % - (to_name, to_sync_role)) - return self.do_sync(self, db_from, db_to, trace_hook, - trace_hook_shallow) - - def setUp(self): - self._use_tracking = {} - super(DatabaseSyncTests, self).setUp() - - def assertLastExchangeLog(self, db, expected): - log = getattr(db, '_last_exchange_log', None) - if log is None: - return - self.assertEqual(expected, log) - - def test_sync_tracks_db_generation_of_other(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - 
self.assertEqual(0, self.sync(self.db1, self.db2)) - self.assertEqual( - (0, ''), self.db1._get_replica_gen_and_trans_id('test2')) - self.assertEqual( - (0, ''), self.db2._get_replica_gen_and_trans_id('test1')) - self.assertLastExchangeLog(self.db2, - {'receive': {'docs': [], 'last_known_gen': 0}, - 'return': {'docs': [], 'last_gen': 0}}) - - def test_sync_autoresolves(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - doc1 = self.db1.create_doc_from_json(simple_doc, doc_id='doc') - rev1 = doc1.rev - doc2 = self.db2.create_doc_from_json(simple_doc, doc_id='doc') - rev2 = doc2.rev - self.sync(self.db1, self.db2) - doc = self.db1.get_doc('doc') - self.assertFalse(doc.has_conflicts) - self.assertEqual(doc.rev, self.db2.get_doc('doc').rev) - v = vectorclock.VectorClockRev(doc.rev) - self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev1))) - self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev2))) - - def test_sync_autoresolves_moar(self): - # here we test that when a database that has a conflicted document is - # the source of a sync, and the target database has a revision of the - # conflicted document that is newer than the source database's, and - # that target's database's document's content is the same as the - # source's document's conflict's, the source's document's conflict gets - # autoresolved, and the source's document's revision bumped. 
- # - # idea is as follows: - # A B - # a1 - - # `-------> - # a1 a1 - # v v - # a2 a1b1 - # `-------> - # a1b1+a2 a1b1 - # v - # a1b1+a2 a1b2 (a1b2 has same content as a2) - # `-------> - # a3b2 a1b2 (autoresolved) - # `-------> - # a3b2 a3b2 - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - self.db1.create_doc_from_json(simple_doc, doc_id='doc') - self.sync(self.db1, self.db2) - for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]: - doc = db.get_doc('doc') - doc.set_json(content) - db.put_doc(doc) - self.sync(self.db1, self.db2) - # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict - doc = self.db1.get_doc('doc') - rev1 = doc.rev - self.assertTrue(doc.has_conflicts) - # set db2 to have a doc of {} (same as db1 before the conflict) - doc = self.db2.get_doc('doc') - doc.set_json('{}') - self.db2.put_doc(doc) - rev2 = doc.rev - # sync it across - self.sync(self.db1, self.db2) - # tadaa! - doc = self.db1.get_doc('doc') - self.assertFalse(doc.has_conflicts) - vec1 = vectorclock.VectorClockRev(rev1) - vec2 = vectorclock.VectorClockRev(rev2) - vec3 = vectorclock.VectorClockRev(doc.rev) - self.assertTrue(vec3.is_newer(vec1)) - self.assertTrue(vec3.is_newer(vec2)) - # because the conflict is on the source, sync it another time - self.sync(self.db1, self.db2) - # make sure db2 now has the exact same thing - self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc')) - - def test_sync_autoresolves_moar_backwards(self): - # here we test that when a database that has a conflicted document is - # the target of a sync, and the source database has a revision of the - # conflicted document that is newer than the target database's, and - # that source's database's document's content is the same as the - # target's document's conflict's, the target's document's conflict gets - # autoresolved, and the document's revision bumped. 
- # - # idea is as follows: - # A B - # a1 - - # `-------> - # a1 a1 - # v v - # a2 a1b1 - # `-------> - # a1b1+a2 a1b1 - # v - # a1b1+a2 a1b2 (a1b2 has same content as a2) - # <-------' - # a3b2 a3b2 (autoresolved and propagated) - self.db1 = self.create_database('test1', 'both') - self.db2 = self.create_database('test2', 'both') - self.db1.create_doc_from_json(simple_doc, doc_id='doc') - self.sync(self.db1, self.db2) - for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]: - doc = db.get_doc('doc') - doc.set_json(content) - db.put_doc(doc) - self.sync(self.db1, self.db2) - # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict - doc = self.db1.get_doc('doc') - rev1 = doc.rev - self.assertTrue(doc.has_conflicts) - revc = self.db1.get_doc_conflicts('doc')[-1].rev - # set db2 to have a doc of {} (same as db1 before the conflict) - doc = self.db2.get_doc('doc') - doc.set_json('{}') - self.db2.put_doc(doc) - rev2 = doc.rev - # sync it across - self.sync(self.db2, self.db1) - # tadaa! 
- doc = self.db1.get_doc('doc') - self.assertFalse(doc.has_conflicts) - vec1 = vectorclock.VectorClockRev(rev1) - vec2 = vectorclock.VectorClockRev(rev2) - vec3 = vectorclock.VectorClockRev(doc.rev) - vecc = vectorclock.VectorClockRev(revc) - self.assertTrue(vec3.is_newer(vec1)) - self.assertTrue(vec3.is_newer(vec2)) - self.assertTrue(vec3.is_newer(vecc)) - # make sure db2 now has the exact same thing - self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc')) - - def test_sync_autoresolves_moar_backwards_three(self): - # same as autoresolves_moar_backwards, but with three databases (note - # all the syncs go in the same direction -- this is a more natural - # scenario): - # - # A B C - # a1 - - - # `-------> - # a1 a1 - - # `-------> - # a1 a1 a1 - # v v - # a2 a1b1 a1 - # `-------------------> - # a2 a1b1 a2 - # `-------> - # a2+a1b1 a2 - # v - # a2 a2+a1b1 a2c1 (same as a1b1) - # `-------------------> - # a2c1 a2+a1b1 a2c1 - # `-------> - # a2b2c1 a2b2c1 a2c1 - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'both') - self.db3 = self.create_database('test3', 'target') - self.db1.create_doc_from_json(simple_doc, doc_id='doc') - self.sync(self.db1, self.db2) - self.sync(self.db2, self.db3) - for db, content in [(self.db2, '{"hi": 42}'), - (self.db1, '{}'), - ]: - doc = db.get_doc('doc') - doc.set_json(content) - db.put_doc(doc) - self.sync(self.db1, self.db3) - self.sync(self.db2, self.db3) - # db2 and db3 now both have a doc of {}, but db2 has a - # conflict - doc = self.db2.get_doc('doc') - self.assertTrue(doc.has_conflicts) - revc = self.db2.get_doc_conflicts('doc')[-1].rev - self.assertEqual('{}', doc.get_json()) - self.assertEqual(self.db3.get_doc('doc').get_json(), doc.get_json()) - self.assertEqual(self.db3.get_doc('doc').rev, doc.rev) - # set db3 to have a doc of {hi:42} (same as db2 before the conflict) - doc = self.db3.get_doc('doc') - doc.set_json('{"hi": 42}') - self.db3.put_doc(doc) - rev3 = 
doc.rev - # sync it across to db1 - self.sync(self.db1, self.db3) - # db1 now has hi:42, with a rev that is newer than db2's doc - doc = self.db1.get_doc('doc') - rev1 = doc.rev - self.assertFalse(doc.has_conflicts) - self.assertEqual('{"hi": 42}', doc.get_json()) - VCR = vectorclock.VectorClockRev - self.assertTrue(VCR(rev1).is_newer(VCR(self.db2.get_doc('doc').rev))) - # so sync it to db2 - self.sync(self.db1, self.db2) - # tadaa! - doc = self.db2.get_doc('doc') - self.assertFalse(doc.has_conflicts) - # db2's revision of the document is strictly newer than db1's before - # the sync, and db3's before that sync way back when - self.assertTrue(VCR(doc.rev).is_newer(VCR(rev1))) - self.assertTrue(VCR(doc.rev).is_newer(VCR(rev3))) - self.assertTrue(VCR(doc.rev).is_newer(VCR(revc))) - # make sure both dbs now have the exact same thing - self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc')) - - def test_sync_puts_changes(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - doc = self.db1.create_doc_from_json(simple_doc) - self.assertEqual(1, self.sync(self.db1, self.db2)) - self.assertGetDoc(self.db2, doc.doc_id, doc.rev, simple_doc, False) - self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0]) - self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0]) - self.assertLastExchangeLog(self.db2, - {'receive': {'docs': [(doc.doc_id, doc.rev)], - 'source_uid': 'test1', - 'source_gen': 1, 'last_known_gen': 0}, - 'return': {'docs': [], 'last_gen': 1}}) - - def test_sync_pulls_changes(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - doc = self.db2.create_doc_from_json(simple_doc) - self.db1.create_index('test-idx', 'key') - self.assertEqual(0, self.sync(self.db1, self.db2)) - self.assertGetDoc(self.db1, doc.doc_id, doc.rev, simple_doc, False) - self.assertEqual(1, 
self.db1._get_replica_gen_and_trans_id('test2')[0]) - self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0]) - self.assertLastExchangeLog(self.db2, - {'receive': {'docs': [], 'last_known_gen': 0}, - 'return': {'docs': [(doc.doc_id, doc.rev)], - 'last_gen': 1}}) - self.assertEqual([doc], self.db1.get_from_index('test-idx', 'value')) - - def test_sync_pulling_doesnt_update_other_if_changed(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - doc = self.db2.create_doc_from_json(simple_doc) - # After the local side has sent its list of docs, before we start - # receiving the "targets" response, we update the local database with a - # new record. - # When we finish synchronizing, we can notice that something locally - # was updated, and we cannot tell c2 our new updated generation - - def before_get_docs(state): - if state != 'before get_docs': - return - self.db1.create_doc_from_json(simple_doc) - - self.assertEqual(0, self.sync(self.db1, self.db2, - trace_hook=before_get_docs)) - self.assertLastExchangeLog(self.db2, - {'receive': {'docs': [], 'last_known_gen': 0}, - 'return': {'docs': [(doc.doc_id, doc.rev)], - 'last_gen': 1}}) - self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0]) - # c2 should not have gotten a '_record_sync_info' call, because the - # local database had been updated more than just by the messages - # returned from c2. 
- self.assertEqual( - (0, ''), self.db2._get_replica_gen_and_trans_id('test1')) - - def test_sync_doesnt_update_other_if_nothing_pulled(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - self.db1.create_doc_from_json(simple_doc) - - def no_record_sync_info(state): - if state != 'record_sync_info': - return - self.fail('SyncTarget.record_sync_info was called') - self.assertEqual(1, self.sync(self.db1, self.db2, - trace_hook_shallow=no_record_sync_info)) - self.assertEqual( - 1, - self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid)[0]) - - def test_sync_ignores_convergence(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'both') - doc = self.db1.create_doc_from_json(simple_doc) - self.db3 = self.create_database('test3', 'target') - self.assertEqual(1, self.sync(self.db1, self.db3)) - self.assertEqual(0, self.sync(self.db2, self.db3)) - self.assertEqual(1, self.sync(self.db1, self.db2)) - self.assertLastExchangeLog(self.db2, - {'receive': {'docs': [(doc.doc_id, doc.rev)], - 'source_uid': 'test1', - 'source_gen': 1, 'last_known_gen': 0}, - 'return': {'docs': [], 'last_gen': 1}}) - - def test_sync_ignores_superseded(self): - self.db1 = self.create_database('test1', 'both') - self.db2 = self.create_database('test2', 'both') - doc = self.db1.create_doc_from_json(simple_doc) - doc_rev1 = doc.rev - self.db3 = self.create_database('test3', 'target') - self.sync(self.db1, self.db3) - self.sync(self.db2, self.db3) - new_content = '{"key": "altval"}' - doc.set_json(new_content) - self.db1.put_doc(doc) - doc_rev2 = doc.rev - self.sync(self.db2, self.db1) - self.assertLastExchangeLog(self.db1, - {'receive': {'docs': [(doc.doc_id, doc_rev1)], - 'source_uid': 'test2', - 'source_gen': 1, 'last_known_gen': 0}, - 'return': {'docs': [(doc.doc_id, doc_rev2)], - 'last_gen': 2}}) - self.assertGetDoc(self.db1, doc.doc_id, doc_rev2, new_content, False) - - def 
test_sync_sees_remote_conflicted(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - doc1 = self.db1.create_doc_from_json(simple_doc) - doc_id = doc1.doc_id - doc1_rev = doc1.rev - self.db1.create_index('test-idx', 'key') - new_doc = '{"key": "altval"}' - doc2 = self.db2.create_doc_from_json(new_doc, doc_id=doc_id) - doc2_rev = doc2.rev - self.assertTransactionLog([doc1.doc_id], self.db1) - self.sync(self.db1, self.db2) - self.assertLastExchangeLog(self.db2, - {'receive': {'docs': [(doc_id, doc1_rev)], - 'source_uid': 'test1', - 'source_gen': 1, 'last_known_gen': 0}, - 'return': {'docs': [(doc_id, doc2_rev)], - 'last_gen': 1}}) - self.assertTransactionLog([doc_id, doc_id], self.db1) - self.assertGetDoc(self.db1, doc_id, doc2_rev, new_doc, True) - self.assertGetDoc(self.db2, doc_id, doc2_rev, new_doc, False) - from_idx = self.db1.get_from_index('test-idx', 'altval')[0] - self.assertEqual(doc2.doc_id, from_idx.doc_id) - self.assertEqual(doc2.rev, from_idx.rev) - self.assertTrue(from_idx.has_conflicts) - self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) - - def test_sync_sees_remote_delete_conflicted(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - doc1 = self.db1.create_doc_from_json(simple_doc) - doc_id = doc1.doc_id - self.db1.create_index('test-idx', 'key') - self.sync(self.db1, self.db2) - doc2 = self.make_document(doc1.doc_id, doc1.rev, doc1.get_json()) - new_doc = '{"key": "altval"}' - doc1.set_json(new_doc) - self.db1.put_doc(doc1) - self.db2.delete_doc(doc2) - self.assertTransactionLog([doc_id, doc_id], self.db1) - self.sync(self.db1, self.db2) - self.assertLastExchangeLog(self.db2, - {'receive': {'docs': [(doc_id, doc1.rev)], - 'source_uid': 'test1', - 'source_gen': 2, 'last_known_gen': 1}, - 'return': {'docs': [(doc_id, doc2.rev)], - 'last_gen': 2}}) - self.assertTransactionLog([doc_id, doc_id, doc_id], self.db1) 
- self.assertGetDocIncludeDeleted(self.db1, doc_id, doc2.rev, None, True) - self.assertGetDocIncludeDeleted( - self.db2, doc_id, doc2.rev, None, False) - self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) - - def test_sync_local_race_conflicted(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - doc = self.db1.create_doc_from_json(simple_doc) - doc_id = doc.doc_id - doc1_rev = doc.rev - self.db1.create_index('test-idx', 'key') - self.sync(self.db1, self.db2) - content1 = '{"key": "localval"}' - content2 = '{"key": "altval"}' - doc.set_json(content2) - self.db2.put_doc(doc) - doc2_rev2 = doc.rev - triggered = [] - - def after_whatschanged(state): - if state != 'after whats_changed': - return - triggered.append(True) - doc = self.make_document(doc_id, doc1_rev, content1) - self.db1.put_doc(doc) - - self.sync(self.db1, self.db2, trace_hook=after_whatschanged) - self.assertEqual([True], triggered) - self.assertGetDoc(self.db1, doc_id, doc2_rev2, content2, True) - from_idx = self.db1.get_from_index('test-idx', 'altval')[0] - self.assertEqual(doc.doc_id, from_idx.doc_id) - self.assertEqual(doc.rev, from_idx.rev) - self.assertTrue(from_idx.has_conflicts) - self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) - self.assertEqual([], self.db1.get_from_index('test-idx', 'localval')) - - def test_sync_propagates_deletes(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'both') - doc1 = self.db1.create_doc_from_json(simple_doc) - doc_id = doc1.doc_id - self.db1.create_index('test-idx', 'key') - self.sync(self.db1, self.db2) - self.db2.create_index('test-idx', 'key') - self.db3 = self.create_database('test3', 'target') - self.sync(self.db1, self.db3) - self.db1.delete_doc(doc1) - deleted_rev = doc1.rev - self.sync(self.db1, self.db2) - self.assertLastExchangeLog(self.db2, - {'receive': {'docs': [(doc_id, deleted_rev)], - 'source_uid': 
'test1', - 'source_gen': 2, 'last_known_gen': 1}, - 'return': {'docs': [], 'last_gen': 2}}) - self.assertGetDocIncludeDeleted( - self.db1, doc_id, deleted_rev, None, False) - self.assertGetDocIncludeDeleted( - self.db2, doc_id, deleted_rev, None, False) - self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) - self.assertEqual([], self.db2.get_from_index('test-idx', 'value')) - self.sync(self.db2, self.db3) - self.assertLastExchangeLog(self.db3, - {'receive': {'docs': [(doc_id, deleted_rev)], - 'source_uid': 'test2', - 'source_gen': 2, 'last_known_gen': 0}, - 'return': {'docs': [], 'last_gen': 2}}) - self.assertGetDocIncludeDeleted( - self.db3, doc_id, deleted_rev, None, False) - - def test_sync_propagates_resolution(self): - self.db1 = self.create_database('test1', 'both') - self.db2 = self.create_database('test2', 'both') - doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc') - db3 = self.create_database('test3', 'both') - self.sync(self.db2, self.db1) - self.assertEqual( - self.db1._get_generation_info(), - self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid)) - self.assertEqual( - self.db2._get_generation_info(), - self.db1._get_replica_gen_and_trans_id(self.db2._replica_uid)) - self.sync(db3, self.db1) - # update on 2 - doc2 = self.make_document('the-doc', doc1.rev, '{"a": 2}') - self.db2.put_doc(doc2) - self.sync(self.db2, db3) - self.assertEqual(db3.get_doc('the-doc').rev, doc2.rev) - # update on 1 - doc1.set_json('{"a": 3}') - self.db1.put_doc(doc1) - # conflicts - self.sync(self.db2, self.db1) - self.sync(db3, self.db1) - self.assertTrue(self.db2.get_doc('the-doc').has_conflicts) - self.assertTrue(db3.get_doc('the-doc').has_conflicts) - # resolve - conflicts = self.db2.get_doc_conflicts('the-doc') - doc4 = self.make_document('the-doc', None, '{"a": 4}') - revs = [doc.rev for doc in conflicts] - self.db2.resolve_doc(doc4, revs) - doc2 = self.db2.get_doc('the-doc') - self.assertEqual(doc4.get_json(), doc2.get_json()) - 
self.assertFalse(doc2.has_conflicts) - self.sync(self.db2, db3) - doc3 = db3.get_doc('the-doc') - self.assertEqual(doc4.get_json(), doc3.get_json()) - self.assertFalse(doc3.has_conflicts) - - def test_sync_supersedes_conflicts(self): - self.db1 = self.create_database('test1', 'both') - self.db2 = self.create_database('test2', 'target') - db3 = self.create_database('test3', 'both') - doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc') - self.db2.create_doc_from_json('{"b": 1}', doc_id='the-doc') - db3.create_doc_from_json('{"c": 1}', doc_id='the-doc') - self.sync(db3, self.db1) - self.assertEqual( - self.db1._get_generation_info(), - db3._get_replica_gen_and_trans_id(self.db1._replica_uid)) - self.assertEqual( - db3._get_generation_info(), - self.db1._get_replica_gen_and_trans_id(db3._replica_uid)) - self.sync(db3, self.db2) - self.assertEqual( - self.db2._get_generation_info(), - db3._get_replica_gen_and_trans_id(self.db2._replica_uid)) - self.assertEqual( - db3._get_generation_info(), - self.db2._get_replica_gen_and_trans_id(db3._replica_uid)) - self.assertEqual(3, len(db3.get_doc_conflicts('the-doc'))) - doc1.set_json('{"a": 2}') - self.db1.put_doc(doc1) - self.sync(db3, self.db1) - # original doc1 should have been removed from conflicts - self.assertEqual(3, len(db3.get_doc_conflicts('the-doc'))) - - def test_sync_stops_after_get_sync_info(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - self.db1.create_doc_from_json(tests.simple_doc) - self.sync(self.db1, self.db2) - - def put_hook(state): - self.fail("Tracehook triggered for %s" % (state,)) - - self.sync(self.db1, self.db2, trace_hook_shallow=put_hook) - - def test_sync_detects_rollback_in_source(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1') - self.sync(self.db1, self.db2) - db1_copy = 
self.copy_database(self.db1) - self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2') - self.sync(self.db1, self.db2) - self.assertRaises( - errors.InvalidGeneration, self.sync, db1_copy, self.db2) - - def test_sync_detects_rollback_in_target(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") - self.sync(self.db1, self.db2) - db2_copy = self.copy_database(self.db2) - self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2') - self.sync(self.db1, self.db2) - self.assertRaises( - errors.InvalidGeneration, self.sync, self.db1, db2_copy) - - def test_sync_detects_diverged_source(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - db3 = self.copy_database(self.db1) - self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") - db3.create_doc_from_json(tests.simple_doc, doc_id="divergent") - self.sync(self.db1, self.db2) - self.assertRaises( - errors.InvalidTransactionId, self.sync, db3, self.db2) - - def test_sync_detects_diverged_target(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - db3 = self.copy_database(self.db2) - db3.create_doc_from_json(tests.nested_doc, doc_id="divergent") - self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") - self.sync(self.db1, self.db2) - self.assertRaises( - errors.InvalidTransactionId, self.sync, self.db1, db3) - - def test_sync_detects_rollback_and_divergence_in_source(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1') - self.sync(self.db1, self.db2) - db1_copy = self.copy_database(self.db1) - self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2') - self.db1.create_doc_from_json(tests.simple_doc, 
doc_id='doc3') - self.sync(self.db1, self.db2) - db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2') - db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3') - self.assertRaises( - errors.InvalidTransactionId, self.sync, db1_copy, self.db2) - - def test_sync_detects_rollback_and_divergence_in_target(self): - self.db1 = self.create_database('test1', 'source') - self.db2 = self.create_database('test2', 'target') - self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") - self.sync(self.db1, self.db2) - db2_copy = self.copy_database(self.db2) - self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2') - self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc3') - self.sync(self.db1, self.db2) - db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2') - db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3') - self.assertRaises( - errors.InvalidTransactionId, self.sync, self.db1, db2_copy) - - -class TestDbSync(tests.TestCaseWithServer): - """Test db.sync remote sync shortcut""" - - scenarios = [ - ('py-http', { - 'make_app_with_state': make_http_app, - 'make_database_for_test': tests.make_memory_database_for_test, - }), - ('c-http', { - 'make_app_with_state': make_http_app, - 'make_database_for_test': tests.make_c_database_for_test - }), - ('py-oauth-http', { - 'make_app_with_state': make_oauth_http_app, - 'make_database_for_test': tests.make_memory_database_for_test, - 'oauth': True - }), - ('c-oauth-http', { - 'make_app_with_state': make_oauth_http_app, - 'make_database_for_test': tests.make_c_database_for_test, - 'oauth': True - }), - ] - - oauth = False - - def do_sync(self, target_name): - if self.oauth: - path = '~/' + target_name - extra = dict(creds={'oauth': { - 'consumer_key': tests.consumer1.key, - 'consumer_secret': tests.consumer1.secret, - 'token_key': tests.token1.key, - 'token_secret': tests.token1.secret - }}) - else: - path = target_name - extra = {} - target_url = self.getURL(path) - return 
self.db.sync(target_url, **extra) - - def setUp(self): - super(TestDbSync, self).setUp() - self.startServer() - self.db = self.make_database_for_test(self, 'test1') - self.db2 = self.request_state._create_database('test2.db') - - def test_db_sync(self): - doc1 = self.db.create_doc_from_json(tests.simple_doc) - doc2 = self.db2.create_doc_from_json(tests.nested_doc) - local_gen_before_sync = self.do_sync('test2.db') - gen, _, changes = self.db.whats_changed(local_gen_before_sync) - self.assertEqual(1, len(changes)) - self.assertEqual(doc2.doc_id, changes[0][0]) - self.assertEqual(1, gen - local_gen_before_sync) - self.assertGetDoc(self.db2, doc1.doc_id, doc1.rev, tests.simple_doc, - False) - self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc, - False) - - def test_db_sync_autocreate(self): - doc1 = self.db.create_doc_from_json(tests.simple_doc) - local_gen_before_sync = self.do_sync('test3.db') - gen, _, changes = self.db.whats_changed(local_gen_before_sync) - self.assertEqual(0, gen - local_gen_before_sync) - db3 = self.request_state.open_database('test3.db') - gen, _, changes = db3.whats_changed() - self.assertEqual(1, len(changes)) - self.assertEqual(doc1.doc_id, changes[0][0]) - self.assertGetDoc(db3, doc1.doc_id, doc1.rev, tests.simple_doc, - False) - t_gen, _ = self.db._get_replica_gen_and_trans_id('test3.db') - s_gen, _ = db3._get_replica_gen_and_trans_id('test1') - self.assertEqual(1, t_gen) - self.assertEqual(1, s_gen) - - -class TestRemoteSyncIntegration(tests.TestCaseWithServer): - """Integration tests for the most common sync scenario local -> remote""" - - make_app_with_state = staticmethod(make_http_app) - - def setUp(self): - super(TestRemoteSyncIntegration, self).setUp() - self.startServer() - self.db1 = inmemory.InMemoryDatabase('test1') - self.db2 = self.request_state._create_database('test2') - - def test_sync_tracks_generations_incrementally(self): - doc11 = self.db1.create_doc_from_json('{"a": 1}') - doc12 = 
self.db1.create_doc_from_json('{"a": 2}') - doc21 = self.db2.create_doc_from_json('{"b": 1}') - doc22 = self.db2.create_doc_from_json('{"b": 2}') - #sanity - self.assertEqual(2, len(self.db1._get_transaction_log())) - self.assertEqual(2, len(self.db2._get_transaction_log())) - progress1 = [] - progress2 = [] - _do_set_replica_gen_and_trans_id = \ - self.db1._do_set_replica_gen_and_trans_id - - def set_sync_generation_witness1(other_uid, other_gen, trans_id): - progress1.append((other_uid, other_gen, - [d for d, t in self.db1._get_transaction_log()[2:]])) - _do_set_replica_gen_and_trans_id(other_uid, other_gen, trans_id) - self.patch(self.db1, '_do_set_replica_gen_and_trans_id', - set_sync_generation_witness1) - _do_set_replica_gen_and_trans_id2 = \ - self.db2._do_set_replica_gen_and_trans_id - - def set_sync_generation_witness2(other_uid, other_gen, trans_id): - progress2.append((other_uid, other_gen, - [d for d, t in self.db2._get_transaction_log()[2:]])) - _do_set_replica_gen_and_trans_id2(other_uid, other_gen, trans_id) - self.patch(self.db2, '_do_set_replica_gen_and_trans_id', - set_sync_generation_witness2) - - db2_url = self.getURL('test2') - self.db1.sync(db2_url) - - self.assertEqual([('test2', 1, [doc21.doc_id]), - ('test2', 2, [doc21.doc_id, doc22.doc_id]), - ('test2', 4, [doc21.doc_id, doc22.doc_id])], - progress1) - self.assertEqual([('test1', 1, [doc11.doc_id]), - ('test1', 2, [doc11.doc_id, doc12.doc_id]), - ('test1', 4, [doc11.doc_id, doc12.doc_id])], - progress2) - - -load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_test_infrastructure.py b/src/leap/soledad/u1db/tests/test_test_infrastructure.py deleted file mode 100644 index b79e0516..00000000 --- a/src/leap/soledad/u1db/tests/test_test_infrastructure.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. 
-# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""Tests for test infrastructure bits""" - -from wsgiref import simple_server - -from u1db import ( - tests, - ) - - -class TestTestCaseWithServer(tests.TestCaseWithServer): - - def make_app(self): - return "app" - - @staticmethod - def server_def(): - def make_server(host_port, application): - assert application == "app" - return simple_server.WSGIServer(host_port, None) - return (make_server, "shutdown", "http") - - def test_getURL(self): - self.startServer() - url = self.getURL() - self.assertTrue(url.startswith('http://127.0.0.1:')) diff --git a/src/leap/soledad/u1db/tests/test_vectorclock.py b/src/leap/soledad/u1db/tests/test_vectorclock.py deleted file mode 100644 index 72baf246..00000000 --- a/src/leap/soledad/u1db/tests/test_vectorclock.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . 
- -"""VectorClockRev helper class tests.""" - -from u1db import tests, vectorclock - -try: - from u1db.tests import c_backend_wrapper -except ImportError: - c_backend_wrapper = None - - -c_vectorclock_scenarios = [] -if c_backend_wrapper is not None: - c_vectorclock_scenarios.append( - ('c', {'create_vcr': c_backend_wrapper.VectorClockRev})) - - -class TestVectorClockRev(tests.TestCase): - - scenarios = [('py', {'create_vcr': vectorclock.VectorClockRev}) - ] + c_vectorclock_scenarios - - def assertIsNewer(self, newer_rev, older_rev): - new_vcr = self.create_vcr(newer_rev) - old_vcr = self.create_vcr(older_rev) - self.assertTrue(new_vcr.is_newer(old_vcr)) - self.assertFalse(old_vcr.is_newer(new_vcr)) - - def assertIsConflicted(self, rev_a, rev_b): - vcr_a = self.create_vcr(rev_a) - vcr_b = self.create_vcr(rev_b) - self.assertFalse(vcr_a.is_newer(vcr_b)) - self.assertFalse(vcr_b.is_newer(vcr_a)) - - def assertRoundTrips(self, rev): - self.assertEqual(rev, self.create_vcr(rev).as_str()) - - def test__is_newer_doc_rev(self): - self.assertIsNewer('test:1', None) - self.assertIsNewer('test:2', 'test:1') - self.assertIsNewer('other:2|test:1', 'other:1|test:1') - self.assertIsNewer('other:1|test:1', 'other:1') - self.assertIsNewer('a:2|b:1', 'b:1') - self.assertIsNewer('a:1|b:2', 'a:1') - self.assertIsConflicted('other:2|test:1', 'other:1|test:2') - self.assertIsConflicted('other:1|test:1', 'other:2') - self.assertIsConflicted('test:1', 'test:1') - - def test_None(self): - vcr = self.create_vcr(None) - self.assertEqual('', vcr.as_str()) - - def test_round_trips(self): - self.assertRoundTrips('test:1') - self.assertRoundTrips('a:1|b:2') - self.assertRoundTrips('alternate:2|test:1') - - def test_handles_sort_order(self): - self.assertEqual('a:1|b:2', self.create_vcr('b:2|a:1').as_str()) - # Last one out of place - self.assertEqual('a:1|b:2|c:3|d:4|e:5|f:6', - self.create_vcr('f:6|a:1|b:2|c:3|d:4|e:5').as_str()) - # Fully reversed - self.assertEqual('a:1|b:2|c:3|d:4|e:5|f:6', 
- self.create_vcr('f:6|e:5|d:4|c:3|b:2|a:1').as_str()) - - def assertIncrement(self, original, replica_uid, after_increment): - vcr = self.create_vcr(original) - vcr.increment(replica_uid) - self.assertEqual(after_increment, vcr.as_str()) - - def test_increment(self): - self.assertIncrement(None, 'test', 'test:1') - self.assertIncrement('test:1', 'test', 'test:2') - - def test_increment_adds_uid(self): - self.assertIncrement('other:1', 'test', 'other:1|test:1') - self.assertIncrement('a:1|ab:2', 'aa', 'a:1|aa:1|ab:2') - - def test_increment_update_partial(self): - self.assertIncrement('a:1|ab:2', 'a', 'a:2|ab:2') - self.assertIncrement('a:2|ab:2', 'ab', 'a:2|ab:3') - - def test_increment_appends_uid(self): - self.assertIncrement('b:2', 'c', 'b:2|c:1') - - def assertMaximize(self, rev1, rev2, maximized): - vcr1 = self.create_vcr(rev1) - vcr2 = self.create_vcr(rev2) - vcr1.maximize(vcr2) - self.assertEqual(maximized, vcr1.as_str()) - # reset vcr1 to maximize the other way - vcr1 = self.create_vcr(rev1) - vcr2.maximize(vcr1) - self.assertEqual(maximized, vcr2.as_str()) - - def test_maximize(self): - self.assertMaximize(None, None, '') - self.assertMaximize(None, 'x:1', 'x:1') - self.assertMaximize('x:1', 'y:1', 'x:1|y:1') - self.assertMaximize('x:2', 'x:1', 'x:2') - self.assertMaximize('x:2', 'x:1|y:2', 'x:2|y:2') - self.assertMaximize('a:1|c:2|e:3', 'b:3|d:4|f:5', - 'a:1|b:3|c:2|d:4|e:3|f:5') - -load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/testing-certs/Makefile b/src/leap/soledad/u1db/tests/testing-certs/Makefile deleted file mode 100644 index 2385e75b..00000000 --- a/src/leap/soledad/u1db/tests/testing-certs/Makefile +++ /dev/null @@ -1,35 +0,0 @@ -CATOP=./demoCA -ORIG_CONF=/usr/lib/ssl/openssl.cnf -ELEVEN_YEARS=-days 4015 - -init: - cp $(ORIG_CONF) ca.conf - install -d $(CATOP) - install -d $(CATOP)/certs - install -d $(CATOP)/crl - install -d $(CATOP)/newcerts - install -d $(CATOP)/private - touch $(CATOP)/index.txt - echo 
01>$(CATOP)/crlnumber - @echo '**** Making CA certificate ...' - openssl req -nodes -new \ - -newkey rsa -keyout $(CATOP)/private/cakey.pem \ - -out $(CATOP)/careq.pem \ - -multivalue-rdn \ - -subj "/C=UK/ST=-/O=u1db LOCAL TESTING ONLY, DO NO TRUST/CN=u1db testing CA" - openssl ca -config ./ca.conf -create_serial \ - -out $(CATOP)/cacert.pem $(ELEVEN_YEARS) -batch \ - -keyfile $(CATOP)/private/cakey.pem -selfsign \ - -extensions v3_ca -infiles $(CATOP)/careq.pem - -pems: - cp ./demoCA/cacert.pem . - openssl req -new -config ca.conf \ - -multivalue-rdn \ - -subj "/O=u1db LOCAL TESTING ONLY, DO NOT TRUST/CN=localhost" \ - -nodes -keyout testing.key -out newreq.pem $(ELEVEN_YEARS) - openssl ca -batch -config ./ca.conf $(ELEVEN_YEARS) \ - -policy policy_anything \ - -out testing.cert -infiles newreq.pem - -.PHONY: init pems diff --git a/src/leap/soledad/u1db/tests/testing-certs/cacert.pem b/src/leap/soledad/u1db/tests/testing-certs/cacert.pem deleted file mode 100644 index c019a730..00000000 --- a/src/leap/soledad/u1db/tests/testing-certs/cacert.pem +++ /dev/null @@ -1,58 +0,0 @@ -Certificate: - Data: - Version: 3 (0x2) - Serial Number: - e4:de:01:76:c4:78:78:7e - Signature Algorithm: sha1WithRSAEncryption - Issuer: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA - Validity - Not Before: May 3 11:11:11 2012 GMT - Not After : May 1 11:11:11 2023 GMT - Subject: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA - Subject Public Key Info: - Public Key Algorithm: rsaEncryption - Public-Key: (1024 bit) - Modulus: - 00:bc:91:a5:7f:7d:37:f7:06:c7:db:5b:83:6a:6b: - 63:c3:8b:5c:f7:84:4d:97:6d:d4:be:bf:e7:79:a8: - c1:03:57:ec:90:d4:20:e7:02:95:d9:a6:49:e3:f9: - 9a:ea:37:b9:b2:02:62:ab:40:d3:42:bb:4a:4e:a2: - 47:71:0f:1d:a2:c5:94:a1:cf:35:d3:23:32:42:c0: - 1e:8d:cb:08:58:fb:8a:5c:3e:ea:eb:d5:2c:ed:d6: - aa:09:b4:b5:7d:e3:45:c9:ae:c2:82:b2:ae:c0:81: - bc:24:06:65:a9:e7:e0:61:ac:25:ee:53:d3:d7:be: - 22:f7:00:a2:ad:c6:0e:3a:39 - Exponent: 
65537 (0x10001) - X509v3 extensions: - X509v3 Subject Key Identifier: - DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D - X509v3 Authority Key Identifier: - keyid:DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D - - X509v3 Basic Constraints: - CA:TRUE - Signature Algorithm: sha1WithRSAEncryption - 72:9b:c1:f7:07:65:83:36:25:4e:01:2f:b7:4a:f2:a4:00:28: - 80:c7:56:2c:32:39:90:13:61:4b:bb:12:c5:44:9d:42:57:85: - 28:19:70:69:e1:43:c8:bd:11:f6:94:df:91:2d:c3:ea:82:8d: - b4:8f:5d:47:a3:00:99:53:29:93:27:6c:c5:da:c1:20:6f:ab: - ec:4a:be:34:f3:8f:02:e5:0c:c0:03:ac:2b:33:41:71:4f:0a: - 72:5a:b4:26:1a:7f:81:bc:c0:95:8a:06:87:a8:11:9f:5c:73: - 38:df:5a:69:40:21:29:ad:46:23:56:75:e1:e9:8b:10:18:4c: - 7b:54 ------BEGIN CERTIFICATE----- -MIICkjCCAfugAwIBAgIJAOTeAXbEeHh+MA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV -BAYTAlVLMQowCAYDVQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcg -T05MWSwgRE8gTk8gVFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTAeFw0x -MjA1MDMxMTExMTFaFw0yMzA1MDExMTExMTFaMGIxCzAJBgNVBAYTAlVLMQowCAYD -VQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcgT05MWSwgRE8gTk8g -VFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTCBnzANBgkqhkiG9w0BAQEF -AAOBjQAwgYkCgYEAvJGlf3039wbH21uDamtjw4tc94RNl23Uvr/neajBA1fskNQg -5wKV2aZJ4/ma6je5sgJiq0DTQrtKTqJHcQ8dosWUoc810yMyQsAejcsIWPuKXD7q -69Us7daqCbS1feNFya7CgrKuwIG8JAZlqefgYawl7lPT174i9wCircYOOjkCAwEA -AaNQME4wHQYDVR0OBBYEFNs9k1FsMhVUjxBQ/ElPNhUou5VtMB8GA1UdIwQYMBaA -FNs9k1FsMhVUjxBQ/ElPNhUou5VtMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF -BQADgYEAcpvB9wdlgzYlTgEvt0rypAAogMdWLDI5kBNhS7sSxUSdQleFKBlwaeFD -yL0R9pTfkS3D6oKNtI9dR6MAmVMpkydsxdrBIG+r7Eq+NPOPAuUMwAOsKzNBcU8K -clq0Jhp/gbzAlYoGh6gRn1xzON9aaUAhKa1GI1Z14emLEBhMe1Q= ------END CERTIFICATE----- diff --git a/src/leap/soledad/u1db/tests/testing-certs/testing.cert b/src/leap/soledad/u1db/tests/testing-certs/testing.cert deleted file mode 100644 index 985684fb..00000000 --- a/src/leap/soledad/u1db/tests/testing-certs/testing.cert +++ /dev/null @@ -1,61 +0,0 @@ -Certificate: - Data: - Version: 3 
(0x2) - Serial Number: - e4:de:01:76:c4:78:78:7f - Signature Algorithm: sha1WithRSAEncryption - Issuer: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA - Validity - Not Before: May 3 11:11:14 2012 GMT - Not After : May 1 11:11:14 2023 GMT - Subject: O=u1db LOCAL TESTING ONLY, DO NOT TRUST, CN=localhost - Subject Public Key Info: - Public Key Algorithm: rsaEncryption - Public-Key: (1024 bit) - Modulus: - 00:c6:1d:72:d3:c5:e4:fc:d1:4c:d9:e4:08:3e:90: - 10:ce:3f:1f:87:4a:1d:4f:7f:2a:5a:52:c9:65:4f: - d9:2c:bf:69:75:18:1a:b5:c9:09:32:00:47:f5:60: - aa:c6:dd:3a:87:37:5f:16:be:de:29:b5:ea:fc:41: - 7e:eb:77:bb:df:63:c3:06:1e:ed:e9:a0:67:1a:f1: - ec:e1:9d:f7:9c:8f:1c:fa:c3:66:7b:39:dc:70:ae: - 09:1b:9c:c0:9a:c4:90:77:45:8e:39:95:a9:2f:92: - 43:bd:27:07:5a:99:51:6e:76:a0:af:dd:b1:2c:8f: - ca:8b:8c:47:0d:f6:6e:fc:69 - Exponent: 65537 (0x10001) - X509v3 extensions: - X509v3 Basic Constraints: - CA:FALSE - Netscape Comment: - OpenSSL Generated Certificate - X509v3 Subject Key Identifier: - 1C:63:85:E1:1D:F3:89:2E:6C:4E:3F:FB:D0:10:64:5A:C1:22:6A:2A - X509v3 Authority Key Identifier: - keyid:DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D - - Signature Algorithm: sha1WithRSAEncryption - 1d:6d:3e:bd:93:fd:bd:3e:17:b8:9f:f0:99:7f:db:50:5c:b2: - 01:42:03:b5:d5:94:05:d3:f6:8e:80:82:55:47:1f:58:f2:18: - 6c:ab:ef:43:2c:2f:10:e1:7c:c4:5c:cc:ac:50:50:22:42:aa: - 35:33:f5:b9:f3:a6:66:55:d9:36:f4:f2:e4:d4:d9:b5:2c:52: - 66:d4:21:17:97:22:b8:9b:d7:0e:7c:3d:ce:85:19:ca:c4:d2: - 58:62:31:c6:18:3e:44:fc:f4:30:b6:95:87:ee:21:4a:08:f0: - af:3c:8f:c4:ba:5e:a1:5c:37:1a:7d:7b:fe:66:ae:62:50:17: - 31:ca ------BEGIN CERTIFICATE----- -MIICnzCCAgigAwIBAgIJAOTeAXbEeHh/MA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV -BAYTAlVLMQowCAYDVQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcg -T05MWSwgRE8gTk8gVFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTAeFw0x -MjA1MDMxMTExMTRaFw0yMzA1MDExMTExMTRaMEQxLjAsBgNVBAoMJXUxZGIgTE9D -QUwgVEVTVElORyBPTkxZLCBETyBOT1QgVFJVU1QxEjAQBgNVBAMMCWxvY2FsaG9z 
-dDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAxh1y08Xk/NFM2eQIPpAQzj8f -h0odT38qWlLJZU/ZLL9pdRgatckJMgBH9WCqxt06hzdfFr7eKbXq/EF+63e732PD -Bh7t6aBnGvHs4Z33nI8c+sNmeznccK4JG5zAmsSQd0WOOZWpL5JDvScHWplRbnag -r92xLI/Ki4xHDfZu/GkCAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0E -HxYdT3BlblNTTCBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFBxjheEd -84kubE4/+9AQZFrBImoqMB8GA1UdIwQYMBaAFNs9k1FsMhVUjxBQ/ElPNhUou5Vt -MA0GCSqGSIb3DQEBBQUAA4GBAB1tPr2T/b0+F7if8Jl/21BcsgFCA7XVlAXT9o6A -glVHH1jyGGyr70MsLxDhfMRczKxQUCJCqjUz9bnzpmZV2Tb08uTU2bUsUmbUIReX -Irib1w58Pc6FGcrE0lhiMcYYPkT89DC2lYfuIUoI8K88j8S6XqFcNxp9e/5mrmJQ -FzHK ------END CERTIFICATE----- diff --git a/src/leap/soledad/u1db/tests/testing-certs/testing.key b/src/leap/soledad/u1db/tests/testing-certs/testing.key deleted file mode 100644 index d83d4920..00000000 --- a/src/leap/soledad/u1db/tests/testing-certs/testing.key +++ /dev/null @@ -1,16 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMYdctPF5PzRTNnk -CD6QEM4/H4dKHU9/KlpSyWVP2Sy/aXUYGrXJCTIAR/VgqsbdOoc3Xxa+3im16vxB -fut3u99jwwYe7emgZxrx7OGd95yPHPrDZns53HCuCRucwJrEkHdFjjmVqS+SQ70n -B1qZUW52oK/dsSyPyouMRw32bvxpAgMBAAECgYBs3lXxhjg1rhabTjIxnx19GTcM -M3Az9V+izweZQu3HJ1CeZiaXauhAr+LbNsniCkRVddotN6oCJdQB10QVxXBZc9Jz -HPJ4zxtZfRZlNMTMmG7eLWrfxpgWnb/BUjDb40yy1nhr9yhDUnI/8RoHDRHnAEHZ -/CnHGUrqcVcrY5zJAQJBAPLhBJg9W88JVmcOKdWxRgs7dLHnZb999Kv1V5mczmAi -jvGvbUmucqOqke6pTUHNYyNHqU6pySzGUi2cH+BAkFECQQDQ0VoAOysg6FVoT15v -tGh57t5sTiCZZ7PS8jwvtThsgA+vcf6c16XWzXgjGXSap4r2QDOY2rI5lsWLaQ8T -+fyZAkAfyFJRmbXp4c7srW3MCOahkaYzoZQu+syJtBFCiMJ40gzik5I5khpuUGPI -V19EvRu8AiSlppIsycb3MPb64XgBAkEAy7DrUf5le5wmc7G4NM6OeyJ+5LbxJbL6 -vnJ8My1a9LuWkVVpQCU7J+UVo2dZTuLPspW9vwTVhUeFOxAoHRxlQQJAFem93f7m -el2BkB2EFqU3onPejkZ5UrDmfmeOQR1axMQNSXqSxcJxqa16Ru1BWV2gcWRbwajQ -oc+kuJThu/r/Ug== ------END PRIVATE KEY----- diff --git a/src/leap/soledad/u1db/vectorclock.py b/src/leap/soledad/u1db/vectorclock.py deleted file mode 100644 index 42bceaa8..00000000 --- 
a/src/leap/soledad/u1db/vectorclock.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2011 Canonical Ltd. -# -# This file is part of u1db. -# -# u1db is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# u1db is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with u1db. If not, see . - -"""VectorClockRev helper class.""" - - -class VectorClockRev(object): - """Track vector clocks for multiple replica ids. - - This allows simple comparison to determine if one VectorClockRev is - newer/older/in-conflict-with another VectorClockRev without having to - examine history. Every replica has a strictly increasing revision. When - creating a new revision, they include all revisions for all other replicas - which the new revision dominates, and increment their own revision to - something greater than the current value. - """ - - def __init__(self, value): - self._values = self._expand(value) - - def __repr__(self): - s = self.as_str() - return '%s(%s)' % (self.__class__.__name__, s) - - def as_str(self): - s = '|'.join(['%s:%d' % (m, r) for m, r - in sorted(self._values.items())]) - return s - - def _expand(self, value): - result = {} - if value is None: - return result - for replica_info in value.split('|'): - replica_uid, counter = replica_info.split(':') - counter = int(counter) - result[replica_uid] = counter - return result - - def is_newer(self, other): - """Is this VectorClockRev strictly newer than other. 
- """ - if not self._values: - return False - if not other._values: - return True - this_is_newer = False - other_expand = dict(other._values) - for key, value in self._values.iteritems(): - if key in other_expand: - other_value = other_expand.pop(key) - if other_value > value: - return False - elif other_value < value: - this_is_newer = True - else: - this_is_newer = True - if other_expand: - return False - return this_is_newer - - def increment(self, replica_uid): - """Increase the 'replica_uid' section of this vector clock. - - :return: A string representing the new vector clock value - """ - self._values[replica_uid] = self._values.get(replica_uid, 0) + 1 - - def maximize(self, other_vcr): - for replica_uid, counter in other_vcr._values.iteritems(): - if replica_uid not in self._values: - self._values[replica_uid] = counter - else: - this_counter = self._values[replica_uid] - if this_counter < counter: - self._values[replica_uid] = counter -- cgit v1.2.3 From b925c880a7d604e6f3ce437d17fdd8b1bb6cbae7 Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 6 Dec 2012 11:08:11 -0200 Subject: Add sqlcipher backend. --- src/leap/soledad/backends/sqlcipher.py | 954 +++++++++++++++++++++++++++++++++ 1 file changed, 954 insertions(+) create mode 100644 src/leap/soledad/backends/sqlcipher.py (limited to 'src/leap') diff --git a/src/leap/soledad/backends/sqlcipher.py b/src/leap/soledad/backends/sqlcipher.py new file mode 100644 index 00000000..24f47eed --- /dev/null +++ b/src/leap/soledad/backends/sqlcipher.py @@ -0,0 +1,954 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see . + +"""A U1DB implementation that uses SQLCipher as its persistence layer.""" + +import errno +import os +try: + import simplejson as json +except ImportError: + import json # noqa +from sqlite3 import dbapi2 +import sys +import time +import uuid + +import pkg_resources + +from u1db.backends import CommonBackend, CommonSyncTarget +from u1db import ( + Document, + errors, + query_parser, + vectorclock, + ) + + +def open(path, create, document_factory=None, password=None): + """Open a database at the given location. + + Will raise u1db.errors.DatabaseDoesNotExist if create=False and the + database does not already exist. + + :param path: The filesystem path for the database to open. + :param create: True/False, should the database be created if it doesn't + already exist? + :param document_factory: A function that will be called with the same + parameters as Document.__init__. + :return: An instance of Database. 
+ """ + from u1db.backends import sqlite_backend + return sqlite_backend.SQLCipherDatabase.open_database( + path, create=create, document_factory=document_factory, password=password) + + +class SQLCipherDatabase(CommonBackend): + """A U1DB implementation that uses SQLCipher as its persistence layer.""" + + _sqlite_registry = {} + + @classmethod + def set_pragma_key(cls, db_handle, key): + db_handle.cursor().execute("PRAGMA key = '%s'" % key) + + def __init__(self, sqlite_file, document_factory=None, password=None): + """Create a new sqlite file.""" + self._db_handle = dbapi2.connect(sqlite_file) + if password: + SQLiteDatabase.set_pragma_key(self._db_handle, password) + self._real_replica_uid = None + self._ensure_schema() + self._factory = document_factory or Document + + def set_document_factory(self, factory): + self._factory = factory + + def get_sync_target(self): + return SQLCipherSyncTarget(self) + + @classmethod + def _which_index_storage(cls, c): + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'index_storage'") + except dbapi2.OperationalError, e: + # The table does not exist yet + return None, e + else: + return c.fetchone()[0], None + + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5 + + @classmethod + def _open_database(cls, sqlite_file, document_factory=None, password=None): + if not os.path.isfile(sqlite_file): + raise errors.DatabaseDoesNotExist() + tries = 2 + while True: + # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6) + # where without re-opening the database on Windows, it + # doesn't see the transaction that was just committed + db_handle = dbapi2.connect(sqlite_file) + if password: + SQLiteDatabase.set_pragma_key(db_handle, password) + c = db_handle.cursor() + v, err = cls._which_index_storage(c) + db_handle.close() + if v is not None: + break + # possibly another process is initializing it, wait for it to be + # done + if tries == 0: + raise err # go for the richest error? 
+ tries -= 1 + time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL) + return SQLCipherDatabase._sqlite_registry[v]( + sqlite_file, document_factory=document_factory) + + @classmethod + def open_database(cls, sqlite_file, create, backend_cls=None, + document_factory=None, password=None): + try: + return cls._open_database(sqlite_file, + document_factory=document_factory, + password=password) + except errors.DatabaseDoesNotExist: + if not create: + raise + if backend_cls is None: + # default is SQLCipherPartialExpandDatabase + backend_cls = SQLCipherPartialExpandDatabase + return backend_cls(sqlite_file, document_factory=document_factory, + password=password) + + @staticmethod + def delete_database(sqlite_file): + try: + os.unlink(sqlite_file) + except OSError as ex: + if ex.errno == errno.ENOENT: + raise errors.DatabaseDoesNotExist() + raise + + @staticmethod + def register_implementation(klass): + """Register that we implement an SQLCipherDatabase. + + The attribute _index_storage_value will be used as the lookup key. + """ + SQLCipherDatabase._sqlite_registry[klass._index_storage_value] = klass + + def _get_sqlite_handle(self): + """Get access to the underlying sqlite database. + + This should only be used by the test suite, etc, for examining the + state of the underlying database. 
+ """ + return self._db_handle + + def _close_sqlite_handle(self): + """Release access to the underlying sqlite database.""" + self._db_handle.close() + + def close(self): + self._close_sqlite_handle() + + def _is_initialized(self, c): + """Check if this database has been initialized.""" + c.execute("PRAGMA case_sensitive_like=ON") + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'sql_schema'") + except dbapi2.OperationalError: + # The table does not exist yet + val = None + else: + val = c.fetchone() + if val is not None: + return True + return False + + def _initialize(self, c): + """Create the schema in the database.""" + #read the script with sql commands + # TODO: Change how we set up the dependency. Most likely use something + # like lp:dirspec to grab the file from a common resource + # directory. Doesn't specifically need to be handled until we get + # to the point of packaging this. + schema_content = pkg_resources.resource_string( + __name__, 'dbschema.sql') + # Note: We'd like to use c.executescript() here, but it seems that + # executescript always commits, even if you set + # isolation_level = None, so if we want to properly handle + # exclusive locking and rollbacks between processes, we need + # to execute it line-by-line + for line in schema_content.split(';'): + if not line: + continue + c.execute(line) + #add extra fields + self._extra_schema_init(c) + # A unique identifier should be set for this replica. Implementations + # don't have to strictly use uuid here, but we do want the uid to be + # unique amongst all databases that will sync with each other. + # We might extend this to using something with hostname for easier + # debugging. 
+ self._set_replica_uid_in_transaction(uuid.uuid4().hex) + c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)", + (self._index_storage_value,)) + + def _ensure_schema(self): + """Ensure that the database schema has been created.""" + old_isolation_level = self._db_handle.isolation_level + c = self._db_handle.cursor() + if self._is_initialized(c): + return + try: + # autocommit/own mgmt of transactions + self._db_handle.isolation_level = None + with self._db_handle: + # only one execution path should initialize the db + c.execute("begin exclusive") + if self._is_initialized(c): + return + self._initialize(c) + finally: + self._db_handle.isolation_level = old_isolation_level + + def _extra_schema_init(self, c): + """Add any extra fields, etc to the basic table definitions.""" + + def _parse_index_definition(self, index_field): + """Parse a field definition for an index, returning a Getter.""" + # Note: We may want to keep a Parser object around, and cache the + # Getter objects for a greater length of time. Specifically, if + # you create a bunch of indexes, and then insert 50k docs, you'll + # re-parse the indexes between puts. The time to insert the docs + # is still likely to dominate put_doc time, though. + parser = query_parser.Parser() + getter = parser.parse(index_field) + return getter + + def _update_indexes(self, doc_id, raw_doc, getters, db_cursor): + """Update document_fields for a single document. + + :param doc_id: Identifier for this document + :param raw_doc: The python dict representation of the document. + :param getters: A list of [(field_name, Getter)]. Getter.get will be + called to evaluate the index definition for this document, and the + results will be inserted into the db. + :param db_cursor: An sqlite Cursor. 
+ :return: None + """ + values = [] + for field_name, getter in getters: + for idx_value in getter.get(raw_doc): + values.append((doc_id, field_name, idx_value)) + if values: + db_cursor.executemany( + "INSERT INTO document_fields VALUES (?, ?, ?)", values) + + def _set_replica_uid(self, replica_uid): + """Force the replica_uid to be set.""" + with self._db_handle: + self._set_replica_uid_in_transaction(replica_uid) + + def _set_replica_uid_in_transaction(self, replica_uid): + """Set the replica_uid. A transaction should already be held.""" + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO u1db_config" + " VALUES ('replica_uid', ?)", + (replica_uid,)) + self._real_replica_uid = replica_uid + + def _get_replica_uid(self): + if self._real_replica_uid is not None: + return self._real_replica_uid + c = self._db_handle.cursor() + c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'") + val = c.fetchone() + if val is None: + return None + self._real_replica_uid = val[0] + return self._real_replica_uid + + _replica_uid = property(_get_replica_uid) + + def _get_generation(self): + c = self._db_handle.cursor() + c.execute('SELECT max(generation) FROM transaction_log') + val = c.fetchone()[0] + if val is None: + return 0 + return val + + def _get_generation_info(self): + c = self._db_handle.cursor() + c.execute( + 'SELECT max(generation), transaction_id FROM transaction_log ') + val = c.fetchone() + if val[0] is None: + return(0, '') + return val + + def _get_trans_id_for_gen(self, generation): + if generation == 0: + return '' + c = self._db_handle.cursor() + c.execute( + 'SELECT transaction_id FROM transaction_log WHERE generation = ?', + (generation,)) + val = c.fetchone() + if val is None: + raise errors.InvalidGeneration + return val[0] + + def _get_transaction_log(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, transaction_id FROM transaction_log" + " ORDER BY generation") + return c.fetchall() + + def _get_doc(self, 
doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling.""" + c = self._db_handle.cursor() + if check_for_conflicts: + c.execute( + "SELECT document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN " + "conflicts ON conflicts.doc_id = document.doc_id WHERE " + "document.doc_id = ? GROUP BY document.doc_id, " + "document.doc_rev, document.content;", (doc_id,)) + else: + c.execute( + "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?", + (doc_id,)) + val = c.fetchone() + if val is None: + return None + doc_rev, content, conflicts = val + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + return doc + + def _has_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1", + (doc_id,)) + val = c.fetchone() + if val is None: + return False + else: + return True + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + generation = self._get_generation() + results = [] + c = self._db_handle.cursor() + c.execute( + "SELECT document.doc_id, document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts " + "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, " + "document.doc_rev, document.content;") + rows = c.fetchall() + for doc_id, doc_rev, content, conflicts in rows: + if content is None and not include_deleted: + continue + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + results.append(doc) + return (generation, results) + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + 
self._check_doc_size(doc) + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none): + """Convert a dict representation into named fields. + + So something like: {'key1': 'val1', 'key2': 'val2'} + gets converted into: [(doc_id, 'key1', 'val1', 0) + (doc_id, 'key2', 'val2', 0)] + :param doc_id: Just added to every record. + :param base_field: if set, these are nested keys, so each field should + be appropriately prefixed. + :param raw_doc: The python dictionary. + """ + # TODO: Handle lists + values = [] + for field_name, value in raw_doc.iteritems(): + if value is None and not save_none: + continue + if base_field: + full_name = base_field + '.' + field_name + else: + full_name = field_name + if value is None or isinstance(value, (int, float, basestring)): + values.append((doc_id, full_name, value, len(values))) + else: + subvalues = self._expand_to_fields(doc_id, full_name, value, + save_none) + for _, subfield_name, val, _ in subvalues: + values.append((doc_id, subfield_name, val, len(values))) + return values + + def _put_and_update_indexes(self, old_doc, doc): + """Actually insert a document into the database. + + This both updates the existing documents content, and any indexes that + refer to this document. 
+ """ + raise NotImplementedError(self._put_and_update_indexes) + + def whats_changed(self, old_generation=0): + c = self._db_handle.cursor() + c.execute("SELECT generation, doc_id, transaction_id" + " FROM transaction_log" + " WHERE generation > ? ORDER BY generation DESC", + (old_generation,)) + results = c.fetchall() + cur_gen = old_generation + seen = set() + changes = [] + newest_trans_id = '' + for generation, doc_id, trans_id in results: + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + if changes: + cur_gen = changes[0][1] # max generation + newest_trans_id = changes[0][2] + changes.reverse() + else: + c.execute("SELECT generation, transaction_id" + " FROM transaction_log ORDER BY generation DESC LIMIT 1") + results = c.fetchone() + if not results: + cur_gen = 0 + newest_trans_id = '' + else: + cur_gen, newest_trans_id = results + + return cur_gen, newest_trans_id, changes + + def delete_doc(self, doc): + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc is None: + raise errors.DocumentDoesNotExist + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + if old_doc.is_tombstone(): + raise errors.DocumentAlreadyDeleted + if old_doc.has_conflicts: + raise errors.ConflictedDoc() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + doc.make_tombstone() + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _get_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?", + (doc_id,)) + return [self._factory(doc_id, doc_rev, content) + for doc_rev, content in c.fetchall()] + + def get_doc_conflicts(self, doc_id): + with self._db_handle: + conflict_docs = self._get_conflicts(doc_id) + if not conflict_docs: + return [] + this_doc = self._get_doc(doc_id) + this_doc.has_conflicts = True + return [this_doc] + conflict_docs + + def _get_replica_gen_and_trans_id(self, 
other_replica_uid): + c = self._db_handle.cursor() + c.execute("SELECT known_generation, known_transaction_id FROM sync_log" + " WHERE replica_uid = ?", + (other_replica_uid,)) + val = c.fetchone() + if val is None: + other_gen = 0 + trans_id = '' + else: + other_gen = val[0] + trans_id = val[1] + return other_gen, trans_id + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + with self._db_handle: + self._do_set_replica_gen_and_trans_id( + other_replica_uid, other_generation, other_transaction_id) + + def _do_set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, + other_transaction_id): + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)", + (other_replica_uid, other_generation, + other_transaction_id)) + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, + replica_gen=None, replica_trans_id=None): + with self._db_handle: + return super(SQLCipherDatabase, self)._put_doc_if_newer(doc, + save_conflict=save_conflict, + replica_uid=replica_uid, replica_gen=replica_gen, + replica_trans_id=replica_trans_id) + + def _add_conflict(self, c, doc_id, my_doc_rev, my_content): + c.execute("INSERT INTO conflicts VALUES (?, ?, ?)", + (doc_id, my_doc_rev, my_content)) + + def _delete_conflicts(self, c, doc, conflict_revs): + deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs] + c.executemany("DELETE FROM conflicts" + " WHERE doc_id=? 
AND doc_rev=?", deleting) + doc.has_conflicts = self._has_conflicts(doc.doc_id) + + def _prune_conflicts(self, doc, doc_vcr): + if self._has_conflicts(doc.doc_id): + autoresolved = False + c_revs_to_prune = [] + for c_doc in self._get_conflicts(doc.doc_id): + c_vcr = vectorclock.VectorClockRev(c_doc.rev) + if doc_vcr.is_newer(c_vcr): + c_revs_to_prune.append(c_doc.rev) + elif doc.same_content_as(c_doc): + c_revs_to_prune.append(c_doc.rev) + doc_vcr.maximize(c_vcr) + autoresolved = True + if autoresolved: + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + c = self._db_handle.cursor() + self._delete_conflicts(c, doc, c_revs_to_prune) + + def _force_doc_sync_conflict(self, doc): + my_doc = self._get_doc(doc.doc_id) + c = self._db_handle.cursor() + self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) + self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json()) + doc.has_conflicts = True + self._put_and_update_indexes(my_doc, doc) + + def resolve_doc(self, doc, conflicted_doc_revs): + with self._db_handle: + cur_doc = self._get_doc(doc.doc_id) + # TODO: https://bugs.launchpad.net/u1db/+bug/928274 + # I think we have a logic bug in resolve_doc + # Specifically, cur_doc.rev is always in the final vector + # clock of revisions that we supersede, even if it wasn't in + # conflicted_doc_revs. We still add it as a conflict, but the + # fact that _put_doc_if_newer propagates resolutions means I + # think that conflict could accidentally be resolved. We need + # to add a test for this case first. (create a rev, create a + # conflict, create another conflict, resolve the first rev + # and first conflict, then make sure that the resolved + # rev doesn't supersede the second conflict rev.) 
It *might* + # not matter, because the superseding rev is in as a + # conflict, but it does seem incorrect + new_rev = self._ensure_maximal_rev(cur_doc.rev, + conflicted_doc_revs) + superseded_revs = set(conflicted_doc_revs) + c = self._db_handle.cursor() + doc.rev = new_rev + if cur_doc.rev in superseded_revs: + self._put_and_update_indexes(cur_doc, doc) + else: + self._add_conflict(c, doc.doc_id, new_rev, doc.get_json()) + # TODO: Is there some way that we could construct a rev that would + # end up in superseded_revs, such that we add a conflict, and + # then immediately delete it? + self._delete_conflicts(c, doc, superseded_revs) + + def list_indexes(self): + """Return the list of indexes and their definitions.""" + c = self._db_handle.cursor() + # TODO: How do we test the ordering? + c.execute("SELECT name, field FROM index_definitions" + " ORDER BY name, offset") + definitions = [] + cur_name = None + for name, field in c.fetchall(): + if cur_name != name: + definitions.append((name, [])) + cur_name = name + definitions[-1][-1].append(field) + return definitions + + def _get_index_definition(self, index_name): + """Return the stored definition for a given index_name.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions" + " WHERE name = ? ORDER BY offset", (index_name,)) + fields = [x[0] for x in c.fetchall()] + if not fields: + raise errors.IndexDoesNotExist + return fields + + @staticmethod + def _strip_glob(value): + """Remove the trailing * from a value.""" + assert value[-1] == '*' + return value[:-1] + + def _format_query(self, definition, key_values): + # First, build the definition. We join the document_fields table + # against itself, as many times as the 'width' of our definition. + # We then do a query for each key_value, one-at-a-time. + # Note: All of these strings are static, we could cache them, etc. 
+ tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = ["d.doc_id = d%d.doc_id" + " AND d%d.field_name = ?" + % (i, i) for i in range(len(definition))] + wildcard_where = [novalue_where[i] + + (" AND d%d.value NOT NULL" % (i,)) + for i in range(len(definition))] + exact_where = [novalue_where[i] + + (" AND d%d.value = ?" % (i,)) + for i in range(len(definition))] + like_where = [novalue_where[i] + + (" AND d%d.value GLOB ?" % (i,)) + for i in range(len(definition))] + is_wildcard = False + # Merge the lists together, so that: + # [field1, field2, field3], [val1, val2, val3] + # Becomes: + # (field1, val1, field2, val2, field3, val3) + args = [] + where = [] + for idx, (field, value) in enumerate(zip(definition, key_values)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(exact_where[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_from_index(self, index_name, *key_values): + definition = self._get_index_definition(index_name) + if len(key_values) != len(definition): + raise errors.InvalidValueForIndex() + statement, args = self._format_query(definition, key_values) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' 
% (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def _format_range_query(self, definition, start_value, end_value): + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in + range(len(definition))] + wildcard_where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + like_where = [ + novalue_where[i] + ( + " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in + range(len(definition))] + range_where_lower = [ + novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in + range(len(definition))] + range_where_upper = [ + novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in + range(len(definition))] + args = [] + where = [] + if start_value: + if isinstance(start_value, basestring): + start_value = (start_value,) + if len(start_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, start_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(self._strip_glob(value)) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(value) + if end_value: + if isinstance(end_value, basestring): + end_value = (end_value,) + if len(end_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, end_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + 
where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(self._strip_glob(value)) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_upper[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + """Return all documents with key values in the specified range.""" + definition = self._get_index_definition(index_name) + statement, args = self._format_range_query( + definition, start_value, end_value) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def get_index_keys(self, index_name): + c = self._db_handle.cursor() + definition = self._get_index_definition(index_name) + value_fields = ', '.join([ + 'd%d.value' % i for i in range(len(definition))]) + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" 
% (i, i) for i in + range(len(definition))] + where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + statement = ( + "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( + value_fields, ', '.join(tables), ' AND '.join(where), + value_fields)) + try: + c.execute(statement, tuple(definition)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) + return c.fetchall() + + def delete_index(self, index_name): + with self._db_handle: + c = self._db_handle.cursor() + c.execute("DELETE FROM index_definitions WHERE name = ?", + (index_name,)) + c.execute( + "DELETE FROM document_fields WHERE document_fields.field_name " + " NOT IN (SELECT field from index_definitions)") + + +class SQLCipherSyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + +class SQLCipherPartialExpandDatabase(SQLCipherDatabase): + """An SQLCipher Backend that expands documents into a document_field table. + + It stores the original document text in document.doc. For fields that are + indexed, the data goes into document_fields. 
+ """ + + _index_storage_value = 'expand referenced' + + def _get_indexed_fields(self): + """Determine what fields are indexed.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions") + return set([x[0] for x in c.fetchall()]) + + def _evaluate_index(self, raw_doc, field): + parser = query_parser.Parser() + getter = parser.parse(field) + return getter.get(raw_doc) + + def _put_and_update_indexes(self, old_doc, doc): + c = self._db_handle.cursor() + if doc and not doc.is_tombstone(): + raw_doc = json.loads(doc.get_json()) + else: + raw_doc = {} + if old_doc is not None: + c.execute("UPDATE document SET doc_rev=?, content=?" + " WHERE doc_id = ?", + (doc.rev, doc.get_json(), doc.doc_id)) + c.execute("DELETE FROM document_fields WHERE doc_id = ?", + (doc.doc_id,)) + else: + c.execute("INSERT INTO document (doc_id, doc_rev, content)" + " VALUES (?, ?, ?)", + (doc.doc_id, doc.rev, doc.get_json())) + indexed_fields = self._get_indexed_fields() + if indexed_fields: + # It is expected that len(indexed_fields) is shorter than + # len(raw_doc) + getters = [(field, self._parse_index_definition(field)) + for field in indexed_fields] + self._update_indexes(doc.doc_id, raw_doc, getters, c) + trans_id = self._allocate_transaction_id() + c.execute("INSERT INTO transaction_log(doc_id, transaction_id)" + " VALUES (?, ?)", (doc.doc_id, trans_id)) + + def create_index(self, index_name, *index_expressions): + with self._db_handle: + c = self._db_handle.cursor() + cur_fields = self._get_indexed_fields() + definition = [(index_name, idx, field) + for idx, field in enumerate(index_expressions)] + try: + c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", + definition) + except dbapi2.IntegrityError as e: + stored_def = self._get_index_definition(index_name) + if stored_def == [x[-1] for x in definition]: + return + raise errors.IndexNameTakenError, e, sys.exc_info()[2] + new_fields = set( + [f for f in index_expressions if f not in cur_fields]) + 
if new_fields: + self._update_all_indexes(new_fields) + + def _iter_all_docs(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, content FROM document") + while True: + next_rows = c.fetchmany() + if not next_rows: + break + for row in next_rows: + yield row + + def _update_all_indexes(self, new_fields): + """Iterate all the documents, and add content to document_fields. + + :param new_fields: The index definitions that need to be added. + """ + getters = [(field, self._parse_index_definition(field)) + for field in new_fields] + c = self._db_handle.cursor() + for doc_id, doc in self._iter_all_docs(): + if doc is None: + continue + raw_doc = json.loads(doc) + self._update_indexes(doc_id, raw_doc, getters, c) + +SQLCipherDatabase.register_implementation(SQLCipherPartialExpandDatabase) -- cgit v1.2.3 From 7cc7aee73fbf82b604988585e051da32b99dc70e Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 6 Dec 2012 11:15:42 -0200 Subject: Move log classes so all backends can use them. --- src/leap/soledad/__init__.py | 131 +++++++++++++++++++++++++++++++++ src/leap/soledad/backends/openstack.py | 124 ------------------------------- src/leap/soledad/tests/__init__.py | 6 +- 3 files changed, 134 insertions(+), 127 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index b7082e53..7f742a89 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -41,3 +41,134 @@ class GPGWrapper(): def import_keys(self, data): return self.gpg.import_keys(data) + + +#---------------------------------------------------------------------------- +# u1db Transaction and Sync logs as JSON structures. 
+#---------------------------------------------------------------------------- + +class SimpleLog(object): + def __init__(self): + self._log = [] + + def _set_log(self, log): + self._log = log + + def _get_log(self): + return self._log + + log = property( + _get_log, _set_log, doc="Log contents.") + + def append(self, msg): + self._log.append(msg) + + def reduce(self, func, initializer=None): + return reduce(func, self.log, initializer) + + def map(self, func): + return map(func, self.log) + + def filter(self, func): + return filter(func, self.log) + + +class TransactionLog(SimpleLog): + """ + An ordered list of (generation, doc_id, transaction_id) tuples. + """ + + def _set_log(self, log): + self._log = log + + def _get_log(self): + return sorted(self._log, reverse=True) + + log = property( + _get_log, _set_log, doc="Log contents.") + + def get_generation(self): + """ + Return the current generation. + """ + gens = self.map(lambda x: x[0]) + if not gens: + return 0 + return max(gens) + + def get_generation_info(self): + """ + Return the current generation and transaction id. + """ + if not self._log: + return(0, '') + info = self.map(lambda x: (x[0], x[2])) + return reduce(lambda x, y: x if (x[0] > y[0]) else y, info) + + def get_trans_id_for_gen(self, gen): + """ + Get the transaction id corresponding to a particular generation. + """ + log = self.reduce(lambda x, y: y if y[0] == gen else x) + if log is None: + return None + return log[2] + + def whats_changed(self, old_generation): + """ + Return a list of documents that have changed since old_generation. 
+ """ + results = self.filter(lambda x: x[0] > old_generation) + seen = set() + changes = [] + newest_trans_id = '' + for generation, doc_id, trans_id in results: + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + if changes: + cur_gen = changes[0][1] # max generation + newest_trans_id = changes[0][2] + changes.reverse() + else: + results = self.log + if not results: + cur_gen = 0 + newest_trans_id = '' + else: + cur_gen, _, newest_trans_id = results[0] + + return cur_gen, newest_trans_id, changes + + + +class SyncLog(SimpleLog): + """ + A list of (replica_id, generation, transaction_id) tuples. + """ + + def find_by_replica_uid(self, replica_uid): + if not self.log: + return () + return self.reduce(lambda x, y: y if y[0] == replica_uid else x) + + def get_replica_gen_and_trans_id(self, other_replica_uid): + """ + Return the last known generation and transaction id for the other db + replica. + """ + info = self.find_by_replica_uid(other_replica_uid) + if not info: + return (0, '') + return (info[1], info[2]) + + def set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + """ + Set the last-known generation and transaction id for the other + database replica. + """ + self.log = self.filter(lambda x: x[0] != other_replica_uid) + self.append((other_replica_uid, other_generation, + other_transaction_id)) + diff --git a/src/leap/soledad/backends/openstack.py b/src/leap/soledad/backends/openstack.py index ec4609b4..6c971485 100644 --- a/src/leap/soledad/backends/openstack.py +++ b/src/leap/soledad/backends/openstack.py @@ -32,8 +32,6 @@ class OpenStackDatabase(CommonBackend): def whats_changed(self, old_generation=0): self._get_u1db_data() - # This method is implemented in TransactionLog because testing is - # easier like this for now, but it can be moved to here afterwards. 
return self._transaction_log.whats_changed(old_generation) def _get_doc(self, doc_id, check_for_conflicts=False): @@ -245,125 +243,3 @@ class OpenStackSyncTarget(HTTPSyncTarget): source_replica_transaction_id) -class SimpleLog(object): - def __init__(self): - self._log = [] - - def _set_log(self, log): - self._log = log - - def _get_log(self): - return self._log - - log = property( - _get_log, _set_log, doc="Log contents.") - - def append(self, msg): - self._log.append(msg) - - def reduce(self, func, initializer=None): - return reduce(func, self.log, initializer) - - def map(self, func): - return map(func, self.log) - - def filter(self, func): - return filter(func, self.log) - - -class TransactionLog(SimpleLog): - """ - A list of (generation, doc_id, transaction_id) tuples. - """ - - def _set_log(self, log): - self._log = log - - def _get_log(self): - return sorted(self._log, reverse=True) - - log = property( - _get_log, _set_log, doc="Log contents.") - - def get_generation(self): - """ - Return the current generation. - """ - gens = self.map(lambda x: x[0]) - if not gens: - return 0 - return max(gens) - - def get_generation_info(self): - """ - Return the current generation and transaction id. - """ - if not self._log: - return(0, '') - info = self.map(lambda x: (x[0], x[2])) - return reduce(lambda x, y: x if (x[0] > y[0]) else y, info) - - def get_trans_id_for_gen(self, gen): - """ - Get the transaction id corresponding to a particular generation. 
- """ - log = self.reduce(lambda x, y: y if y[0] == gen else x) - if log is None: - return None - return log[2] - - def whats_changed(self, old_generation): - results = self.filter(lambda x: x[0] > old_generation) - seen = set() - changes = [] - newest_trans_id = '' - for generation, doc_id, trans_id in results: - if doc_id not in seen: - changes.append((doc_id, generation, trans_id)) - seen.add(doc_id) - if changes: - cur_gen = changes[0][1] # max generation - newest_trans_id = changes[0][2] - changes.reverse() - else: - results = self.log - if not results: - cur_gen = 0 - newest_trans_id = '' - else: - cur_gen, _, newest_trans_id = results[0] - - return cur_gen, newest_trans_id, changes - - - -class SyncLog(SimpleLog): - """ - A list of (replica_id, generation, transaction_id) tuples. - """ - - def find_by_replica_uid(self, replica_uid): - if not self.log: - return () - return self.reduce(lambda x, y: y if y[0] == replica_uid else x) - - def get_replica_gen_and_trans_id(self, other_replica_uid): - """ - Return the last known generation and transaction id for the other db - replica. - """ - info = self.find_by_replica_uid(other_replica_uid) - if not info: - return (0, '') - return (info[1], info[2]) - - def set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - """ - Set the last-known generation and transaction id for the other - database replica. 
- """ - self.log = self.filter(lambda x: x[0] != other_replica_uid) - self.append((other_replica_uid, other_generation, - other_transaction_id)) - diff --git a/src/leap/soledad/tests/__init__.py b/src/leap/soledad/tests/__init__.py index 8e0a5c52..b6585755 100644 --- a/src/leap/soledad/tests/__init__.py +++ b/src/leap/soledad/tests/__init__.py @@ -7,13 +7,13 @@ import unittest import os import u1db -from soledad import GPGWrapper -from soledad.backends import leap -from soledad.backends.openstack import ( +from soledad import ( + GPGWrapper, SimpleLog, TransactionLog, SyncLog, ) +from soledad.backends import leap class EncryptedSyncTestCase(unittest.TestCase): -- cgit v1.2.3 From 722de6750a2a2de2b55ab30991447bb792de11cd Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 6 Dec 2012 11:19:12 -0200 Subject: Fix dependencies version info on README --- src/leap/soledad/README | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/README b/src/leap/soledad/README index de524672..894ce6af 100644 --- a/src/leap/soledad/README +++ b/src/leap/soledad/README @@ -1,7 +1,7 @@ Soledad -- Synchronization Of Locally Encrypted Data Among Devices ================================================================== -This software is under development, many parts of the code are still untested. +This software is under development. 
Dependencies ------------ @@ -9,11 +9,9 @@ Dependencies Soledad depends on the following python libraries: * u1db 0.1.4 [1] - * python-swiftclient 1.1.1 [2] + * python-swiftclient 1.2.0 [2] * python-gnupg 0.3.1 [3] [1] http://pypi.python.org/pypi/u1db/0.1.4 -[2] https://launchpad.net/python-swiftclient -[3] http://packages.python.org/python-gnupg/index.html - -Right now, all these libs +[2] http://pypi.python.org/pypi/python-swiftclient/1.2.0 +[3] http://pypi.python.org/pypi/python-gnupg/0.3.1 -- cgit v1.2.3 From f89f2e0fe490899ecc4baf3395f3441111da328f Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 10 Dec 2012 11:00:10 -0200 Subject: Refactor to add ObjectStore class. --- src/leap/soledad/__init__.py | 2 +- src/leap/soledad/backends/objectstore.py | 153 +++++++++++++++++++++++++++++++ src/leap/soledad/backends/openstack.py | 143 +---------------------------- 3 files changed, 157 insertions(+), 141 deletions(-) create mode 100644 src/leap/soledad/backends/objectstore.py (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index 7f742a89..78f1f768 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -44,7 +44,7 @@ class GPGWrapper(): #---------------------------------------------------------------------------- -# u1db Transaction and Sync logs as JSON structures. +# u1db Transaction and Sync logs. 
#---------------------------------------------------------------------------- class SimpleLog(object): diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py new file mode 100644 index 00000000..e36df72d --- /dev/null +++ b/src/leap/soledad/backends/objectstore.py @@ -0,0 +1,153 @@ +from u1db.backends import CommonBackend + + +class ObjectStore(CommonBackend): + + def __init__(self): + self._sync_log = SyncLog() + self._transaction_log = TransactionLog() + + #------------------------------------------------------------------------- + # implemented methods from Database + #------------------------------------------------------------------------- + + def set_document_factory(self, factory): + self._factory = factory + + def set_document_size_limit(self, limit): + raise NotImplementedError(self.set_document_size_limit) + + def whats_changed(self, old_generation=0): + self._get_u1db_data() + return self._transaction_log.whats_changed(old_generation) + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def delete_doc(self, doc): + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc is None: + raise errors.DocumentDoesNotExist + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + if old_doc.is_tombstone(): + raise errors.DocumentAlreadyDeleted + if old_doc.has_conflicts: + raise errors.ConflictedDoc() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + doc.make_tombstone() + self._put_doc(olddoc) + return new_rev + + # start of index-related methods: these are not supported by this backend. 
+ + def create_index(self, index_name, *index_expressions): + return False + + def delete_index(self, index_name): + return False + + def list_indexes(self): + return [] + + def get_from_index(self, index_name, *key_values): + return [] + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + return [] + + def get_index_keys(self, index_name): + return [] + + # end of index-related methods: these are not supported by this backend. + + def get_doc_conflicts(self, doc_id): + return [] + + def resolve_doc(self, doc, conflicted_doc_revs): + raise NotImplementedError(self.resolve_doc) + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + self._get_u1db_data() + return self._sync_log.get_replica_gen_and_trans_id(other_replica_uid) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + self._get_u1db_data() + self._sync_log.set_replica_gen_and_trans_id(other_replica_uid, + other_generation, + other_transaction_id) + self._set_u1db_data() + + #------------------------------------------------------------------------- + # implemented methods from CommonBackend + #------------------------------------------------------------------------- + + def _get_generation(self): + self._get_u1db_data() + return self._transaction_log.get_generation() + + def _get_generation_info(self): + self._get_u1db_data() + return self._transaction_log.get_generation_info() + + def _has_conflicts(self, doc_id): + # Documents never have conflicts on server. + return False + + def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): + raise NotImplementedError(self._put_and_update_indexes) + + + def _get_trans_id_for_gen(self, generation): + self._get_u1db_data() + trans_id = self._transaction_log.get_trans_id_for_gen(generation) + if trans_id is None: + raise errors.InvalidGeneration + return trans_id + + def _ensure_u1db_data(self): + """ + Guarantee that u1db data exists in store. 
+ """ + if not self._is_initialized(): + self._initialize() + u1db_data = self._get_doc('u1db_data') + self._sync_log.log = u1db_data.content['sync_log'] + self._transaction_log.log = u1db_data.content['transaction_log'] + + def _is_initialized(self): + """ + Verify if u1db data exists in store. + """ + if not self._get_doc('u1db_data'): + return False + return True + + def _initialize(self): + """ + Create u1db data object in store. + """ + content = { 'transaction_log' : [], + 'sync_log' : [] } + doc = self.create_doc('u1db_data', content) + + def _get_u1db_data(self): + data = self.get_doc('u1db_data').content + self._transaction_log = data['transaction_log'] + self._sync_log = data['sync_log'] + + def _set_u1db_data(self): + doc = self._factory('u1db_data') + doc.content = { 'transaction_log' : self._transaction_log, + 'sync_log' : self._sync_log } + self.put_doc(doc) + + diff --git a/src/leap/soledad/backends/openstack.py b/src/leap/soledad/backends/openstack.py index 6c971485..f8563d81 100644 --- a/src/leap/soledad/backends/openstack.py +++ b/src/leap/soledad/backends/openstack.py @@ -1,15 +1,16 @@ -from leap import * from u1db import errors from u1db.backends import CommonBackend from u1db.remote.http_target import HTTPSyncTarget from swiftclient import client +from soledad.backends.objectstore import ObjectStore -class OpenStackDatabase(CommonBackend): +class OpenStackDatabase(ObjectStore): """A U1DB implementation that uses OpenStack as its persistence layer.""" def __init__(self, auth_url, user, auth_key, container): """Create a new OpenStack data container.""" + super(OpenStackDatabase, self) self._auth_url = auth_url self._user = user self._auth_key = auth_key @@ -24,16 +25,6 @@ class OpenStackDatabase(CommonBackend): # implemented methods from Database #------------------------------------------------------------------------- - def set_document_factory(self, factory): - self._factory = factory - - def set_document_size_limit(self, limit): - raise 
NotImplementedError(self.set_document_size_limit) - - def whats_changed(self, old_generation=0): - self._get_u1db_data() - return self._transaction_log.whats_changed(old_generation) - def _get_doc(self, doc_id, check_for_conflicts=False): """Get just the document content, without fancy handling. @@ -47,14 +38,6 @@ class OpenStackDatabase(CommonBackend): except swiftclient.ClientException: return None - def get_doc(self, doc_id, include_deleted=False): - doc = self._get_doc(doc_id, check_for_conflicts=True) - if doc is None: - return None - if doc.is_tombstone() and not include_deleted: - return None - return doc - def get_all_docs(self, include_deleted=False): """Get all documents from the database.""" generation = self._get_generation() @@ -84,51 +67,6 @@ class OpenStackDatabase(CommonBackend): self._set_u1db_data() return new_rev - def delete_doc(self, doc): - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc is None: - raise errors.DocumentDoesNotExist - if old_doc.rev != doc.rev: - raise errors.RevisionConflict() - if old_doc.is_tombstone(): - raise errors.DocumentAlreadyDeleted - if old_doc.has_conflicts: - raise errors.ConflictedDoc() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - doc.make_tombstone() - self._put_doc(olddoc) - return new_rev - - # start of index-related methods: these are not supported by this backend. - - def create_index(self, index_name, *index_expressions): - return False - - def delete_index(self, index_name): - return False - - def list_indexes(self): - return [] - - def get_from_index(self, index_name, *key_values): - return [] - - def get_range_from_index(self, index_name, start_value=None, - end_value=None): - return [] - - def get_index_keys(self, index_name): - return [] - - # end of index-related methods: these are not supported by this backend. 
- - def get_doc_conflicts(self, doc_id): - return [] - - def resolve_doc(self, doc, conflicted_doc_revs): - raise NotImplementedError(self.resolve_doc) - def get_sync_target(self): return OpenStackSyncTarget(self) @@ -141,89 +79,14 @@ class OpenStackDatabase(CommonBackend): return Synchronizer(self, OpenStackSyncTarget(url, creds=creds)).sync( autocreate=autocreate) - def _get_replica_gen_and_trans_id(self, other_replica_uid): - self._get_u1db_data() - return self._sync_log.get_replica_gen_and_trans_id(other_replica_uid) - - def _set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - self._get_u1db_data() - self._sync_log.set_replica_gen_and_trans_id(other_replica_uid, - other_generation, - other_transaction_id) - self._set_u1db_data() - - #------------------------------------------------------------------------- - # implemented methods from CommonBackend - #------------------------------------------------------------------------- - - def _get_generation(self): - self._get_u1db_data() - return self._transaction_log.get_generation() - - def _get_generation_info(self): - self._get_u1db_data() - return self._transaction_log.get_generation_info() - - def _has_conflicts(self, doc_id): - # Documents never have conflicts on server. - return False - - def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): - raise NotImplementedError(self._put_and_update_indexes) - - - def _get_trans_id_for_gen(self, generation): - self._get_u1db_data() - trans_id = self._transaction_log.get_trans_id_for_gen(generation) - if trans_id is None: - raise errors.InvalidGeneration - return trans_id - #------------------------------------------------------------------------- # OpenStack specific methods #------------------------------------------------------------------------- - def _ensure_u1db_data(self): - """ - Guarantee that u1db data exists in store. 
- """ - if self._is_initialized(): - return - self._initialize() - - def _is_initialized(self): - """ - Verify if u1db data exists in store. - """ - if not self._get_doc('u1db_data'): - return False - return True - - def _initialize(self): - """ - Create u1db data object in store. - """ - content = { 'transaction_log' : [], - 'sync_log' : [] } - doc = self.create_doc('u1db_data', content) - def _get_auth(self): self._url, self._auth_token = self._connection.get_auth() return self._url, self.auth_token - def _get_u1db_data(self): - data = self.get_doc('u1db_data').content - self._transaction_log = data['transaction_log'] - self._sync_log = data['sync_log'] - - def _set_u1db_data(self): - doc = self._factory('u1db_data') - doc.content = { 'transaction_log' : self._transaction_log, - 'sync_log' : self._sync_log } - self.put_doc(doc) - - class OpenStackSyncTarget(HTTPSyncTarget): def get_sync_info(self, source_replica_uid): -- cgit v1.2.3 From b3090f710e3777bad2a9f996444e5099883c9f03 Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 10 Dec 2012 12:05:31 -0200 Subject: Add CouchDB u1db backend. 
--- src/leap/soledad/README | 2 + src/leap/soledad/__init__.py | 5 +- src/leap/soledad/backends/couchdb.py | 97 ++++++++++++++++++++++++++++++++ src/leap/soledad/backends/objectstore.py | 26 +++++++++ src/leap/soledad/backends/openstack.py | 20 ++----- 5 files changed, 131 insertions(+), 19 deletions(-) create mode 100644 src/leap/soledad/backends/couchdb.py (limited to 'src/leap') diff --git a/src/leap/soledad/README b/src/leap/soledad/README index 894ce6af..97976b01 100644 --- a/src/leap/soledad/README +++ b/src/leap/soledad/README @@ -11,7 +11,9 @@ Soledad depends on the following python libraries: * u1db 0.1.4 [1] * python-swiftclient 1.2.0 [2] * python-gnupg 0.3.1 [3] + * CouchDB 0.8 [4] [1] http://pypi.python.org/pypi/u1db/0.1.4 [2] http://pypi.python.org/pypi/python-swiftclient/1.2.0 [3] http://pypi.python.org/pypi/python-gnupg/0.3.1 +[4] http://pypi.python.org/pypi/CouchDB/0.8 diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index 78f1f768..d07567b5 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -1,9 +1,6 @@ # License? 
-"""A U1DB implementation that uses OpenStack Swift as its persistence layer.""" - -from backends.leap import * -from backends.openstack import * +"""A U1DB implementation for using Object Stores as its persistence layer.""" import gnupg diff --git a/src/leap/soledad/backends/couchdb.py b/src/leap/soledad/backends/couchdb.py new file mode 100644 index 00000000..89b713f9 --- /dev/null +++ b/src/leap/soledad/backends/couchdb.py @@ -0,0 +1,97 @@ +from u1db import errors +from u1db.remote.http_target import HTTPSyncTarget +from couchdb import * +from soledad.backends.objectstore import ObjectStore + + +class CouchDatabase(ObjectStore): + """A U1DB implementation that uses Couch as its persistence layer.""" + + def __init__(self, url, database, full_commit=True, session=None): + """Create a new Couch data container.""" + self._url = url + self._full_commit = full_commit + self._session = session + self._server = couchdb.Server(url=self._url, + full_commit=self._full_commit, + session=self._session) + # this will ensure that transaction and sync logs exist and are + # up-to-date. + super(CouchDatabase, self) + self._database = self._server[database] + + #------------------------------------------------------------------------- + # implemented methods from Database + #------------------------------------------------------------------------- + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling. + + Conflicts do not happen on server side, so there's no need to check + for them. 
+ """ + cdoc = self._database.get(doc_id) + if cdoc is not None: + content = {} + for key, value in content: + if not key in ['_id', '_rev', '_u1db_rev']: + content[key] = value + doc = self._factory(doc_id=doc_id, rev=cdoc['_u1db_rev']) + doc.content = content + return doc + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + generation = self._get_generation() + results = [] + for doc_id in self._database: + doc = self._get_doc(doc_id) + if doc.content is None and not include_deleted: + continue + results.append(doc) + return (generation, results) + + def _put_doc(self, doc, new_rev): + # map u1db metadata to couch + content = doc.content + content['_id'] = doc.doc_id + content['_u1db_rev'] = new_rev + self._database.save(doc.content) + + def get_sync_target(self): + return CouchSyncTarget(self) + + def close(self): + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + from u1db.sync import Synchronizer + from u1db.remote.http_target import CouchSyncTarget + return Synchronizer(self, CouchSyncTarget(url, creds=creds)).sync( + autocreate=autocreate) + + #------------------------------------------------------------------------- + # Couch specific methods + #------------------------------------------------------------------------- + + # no specific methods so far. 
+ +class CouchSyncTarget(HTTPSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index e36df72d..456892b3 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -1,11 +1,17 @@ from u1db.backends import CommonBackend +from soledad import SyncLog, TransactionLog class ObjectStore(CommonBackend): def __init__(self): + # This initialization method should be called after the connection + # with the database is established, so it can ensure that u1db data is + # configured and up-to-date. 
+ self.set_document_factory(LeapDocument) self._sync_log = SyncLog() self._transaction_log = TransactionLog() + self._ensure_u1db_data() #------------------------------------------------------------------------- # implemented methods from Database @@ -29,6 +35,26 @@ class ObjectStore(CommonBackend): return None return doc + def _put_doc(self, doc) + raise NotImplementedError(self._put_doc) + + def put_doc(self, doc) + # consistency check + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + # put the document + new_rev = self._allocate_doc_rev(doc.rev) + self._put_doc(doc, new_rev) + doc.rev = new_rev + # update u1db generation and logs + new_gen = self._get_generation() + 1 + trans_id = self._allocate_transaction_id() + self._transaction_log.append((new_gen, doc.doc_id, trans_id)) + self._set_u1db_data() + return new_rev + def delete_doc(self, doc): old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) if old_doc is None: diff --git a/src/leap/soledad/backends/openstack.py b/src/leap/soledad/backends/openstack.py index f8563d81..5f2a2771 100644 --- a/src/leap/soledad/backends/openstack.py +++ b/src/leap/soledad/backends/openstack.py @@ -1,5 +1,4 @@ from u1db import errors -from u1db.backends import CommonBackend from u1db.remote.http_target import HTTPSyncTarget from swiftclient import client from soledad.backends.objectstore import ObjectStore @@ -10,16 +9,15 @@ class OpenStackDatabase(ObjectStore): def __init__(self, auth_url, user, auth_key, container): """Create a new OpenStack data container.""" - super(OpenStackDatabase, self) self._auth_url = auth_url self._user = user self._auth_key = auth_key self._container = container - self.set_document_factory(LeapDocument) self._connection = swiftclient.Connection(self._auth_url, self._user, self._auth_key) self._get_auth() - self._ensure_u1db_data() + # this will ensure transaction and sync logs exist and are up-to-date. 
+ super(OpenStackDatabase, self) #------------------------------------------------------------------------- # implemented methods from Database @@ -33,6 +31,7 @@ class OpenStackDatabase(ObjectStore): """ try: response, contents = self._connection.get_object(self._container, doc_id) + # TODO: change revision to be a dictionary element? rev = response['x-object-meta-rev'] return self._factory(doc_id, rev, contents) except swiftclient.ClientException: @@ -51,21 +50,12 @@ class OpenStackDatabase(ObjectStore): results.append(doc) return (generation, results) - def put_doc(self, doc): - if doc.doc_id is None: - raise errors.InvalidDocId() - self._check_doc_id(doc.doc_id) - self._check_doc_size(doc) - # TODO: check for conflicts? + def _put_doc(self, doc, new_rev): new_rev = self._allocate_doc_rev(doc.rev) + # TODO: change revision to be a dictionary element? headers = { 'X-Object-Meta-Rev' : new_rev } self._connection.put_object(self._container, doc_id, doc.get_json(), headers=headers) - new_gen = self._get_generation() + 1 - trans_id = self._allocate_transaction_id() - self._transaction_log.append((new_gen, doc.doc_id, trans_id)) - self._set_u1db_data() - return new_rev def get_sync_target(self): return OpenStackSyncTarget(self) -- cgit v1.2.3 From 817d4a1dab5cfce6228593ad61951e1593777eeb Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 10 Dec 2012 14:43:08 -0200 Subject: Fix lack of collons on some methods. 
--- src/leap/soledad/backends/objectstore.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index 456892b3..d9ab7cbd 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -35,10 +35,10 @@ class ObjectStore(CommonBackend): return None return doc - def _put_doc(self, doc) + def _put_doc(self, doc): raise NotImplementedError(self._put_doc) - def put_doc(self, doc) + def put_doc(self, doc): # consistency check if doc.doc_id is None: raise errors.InvalidDocId() -- cgit v1.2.3 From 002d2bfdbc4ca62733478524ec588cf0aa9f9383 Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 10 Dec 2012 18:39:56 -0200 Subject: CouchDB backend can put and get objects. --- src/leap/soledad/backends/couch.py | 115 +++++++++++++++++++++++++++++++ src/leap/soledad/backends/couchdb.py | 97 -------------------------- src/leap/soledad/backends/leap.py | 1 + src/leap/soledad/backends/objectstore.py | 43 +++++++----- src/leap/soledad/backends/openstack.py | 2 +- src/leap/soledad/tests/test_couchdb.py | 19 +++++ 6 files changed, 160 insertions(+), 117 deletions(-) create mode 100644 src/leap/soledad/backends/couch.py delete mode 100644 src/leap/soledad/backends/couchdb.py create mode 100644 src/leap/soledad/tests/test_couchdb.py (limited to 'src/leap') diff --git a/src/leap/soledad/backends/couch.py b/src/leap/soledad/backends/couch.py new file mode 100644 index 00000000..5586ea9c --- /dev/null +++ b/src/leap/soledad/backends/couch.py @@ -0,0 +1,115 @@ +from u1db import errors +from u1db.remote.http_target import HTTPSyncTarget +from couchdb.client import Server, Document +from couchdb.http import ResourceNotFound +from soledad.backends.objectstore import ObjectStore +from soledad.backends.leap import LeapDocument + + +class CouchDatabase(ObjectStore): + """A U1DB implementation that uses Couch as its persistence layer.""" + + 
def __init__(self, url, database, full_commit=True, session=None): + """Create a new Couch data container.""" + self._url = url + self._full_commit = full_commit + self._session = session + self._server = Server(url=self._url, + full_commit=self._full_commit, + session=self._session) + # this will ensure that transaction and sync logs exist and are + # up-to-date. + self.set_document_factory(LeapDocument) + try: + self._database = self._server[database] + except ResourceNotFound: + self._server.create(database) + self._database = self._server[database] + super(CouchDatabase, self).__init__() + + #------------------------------------------------------------------------- + # implemented methods from Database + #------------------------------------------------------------------------- + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling. + + Conflicts do not happen on server side, so there's no need to check + for them. + """ + cdoc = self._database.get(doc_id) + if cdoc is None: + return None + content = {} + for (key, value) in cdoc.items(): + if key not in ['_id', '_rev', 'u1db_rev']: + content[key] = value + doc = self._factory(doc_id=doc_id, rev=cdoc['u1db_rev']) + doc.content = content + return doc + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + generation = self._get_generation() + results = [] + for doc_id in self._database: + doc = self._get_doc(doc_id) + if doc.content is None and not include_deleted: + continue + results.append(doc) + return (generation, results) + + def _put_doc(self, doc): + # map u1db metadata to couch + content = doc.content + cdoc = Document() + cdoc['_id'] = doc.doc_id + cdoc['u1db_rev'] = doc.rev + for (key, value) in content.items(): + cdoc[key] = value + self._database.save(cdoc) + + def get_sync_target(self): + return CouchSyncTarget(self) + + def close(self): + raise NotImplementedError(self.close) + + def sync(self, 
url, creds=None, autocreate=True): + from u1db.sync import Synchronizer + from u1db.remote.http_target import CouchSyncTarget + return Synchronizer(self, CouchSyncTarget(url, creds=creds)).sync( + autocreate=autocreate) + + def _get_u1db_data(self): + cdoc = self._database.get(self.U1DB_DATA_DOC_ID) + self._sync_log.log = cdoc['sync_log'] + self._transaction_log.log = cdoc['transaction_log'] + self._replica_uid = cdoc['replica_uid'] + self._couch_rev = cdoc['_rev'] + + #------------------------------------------------------------------------- + # Couch specific methods + #------------------------------------------------------------------------- + + # no specific methods so far. + +class CouchSyncTarget(HTTPSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + diff --git a/src/leap/soledad/backends/couchdb.py b/src/leap/soledad/backends/couchdb.py deleted file mode 100644 index 89b713f9..00000000 --- a/src/leap/soledad/backends/couchdb.py +++ /dev/null @@ -1,97 +0,0 @@ -from u1db import errors -from u1db.remote.http_target import HTTPSyncTarget -from couchdb import * -from soledad.backends.objectstore import ObjectStore - - -class CouchDatabase(ObjectStore): - """A U1DB implementation that uses Couch as its persistence layer.""" - - def __init__(self, url, database, full_commit=True, session=None): - """Create a new Couch data container.""" - self._url = url - self._full_commit = full_commit - self._session = session - 
self._server = couchdb.Server(url=self._url, - full_commit=self._full_commit, - session=self._session) - # this will ensure that transaction and sync logs exist and are - # up-to-date. - super(CouchDatabase, self) - self._database = self._server[database] - - #------------------------------------------------------------------------- - # implemented methods from Database - #------------------------------------------------------------------------- - - def _get_doc(self, doc_id, check_for_conflicts=False): - """Get just the document content, without fancy handling. - - Conflicts do not happen on server side, so there's no need to check - for them. - """ - cdoc = self._database.get(doc_id) - if cdoc is not None: - content = {} - for key, value in content: - if not key in ['_id', '_rev', '_u1db_rev']: - content[key] = value - doc = self._factory(doc_id=doc_id, rev=cdoc['_u1db_rev']) - doc.content = content - return doc - - def get_all_docs(self, include_deleted=False): - """Get all documents from the database.""" - generation = self._get_generation() - results = [] - for doc_id in self._database: - doc = self._get_doc(doc_id) - if doc.content is None and not include_deleted: - continue - results.append(doc) - return (generation, results) - - def _put_doc(self, doc, new_rev): - # map u1db metadata to couch - content = doc.content - content['_id'] = doc.doc_id - content['_u1db_rev'] = new_rev - self._database.save(doc.content) - - def get_sync_target(self): - return CouchSyncTarget(self) - - def close(self): - raise NotImplementedError(self.close) - - def sync(self, url, creds=None, autocreate=True): - from u1db.sync import Synchronizer - from u1db.remote.http_target import CouchSyncTarget - return Synchronizer(self, CouchSyncTarget(url, creds=creds)).sync( - autocreate=autocreate) - - #------------------------------------------------------------------------- - # Couch specific methods - #------------------------------------------------------------------------- - - # no 
specific methods so far. - -class CouchSyncTarget(HTTPSyncTarget): - - def get_sync_info(self, source_replica_uid): - source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( - source_replica_uid) - my_gen, my_trans_id = self._db._get_generation_info() - return ( - self._db._replica_uid, my_gen, my_trans_id, source_gen, - source_trans_id) - - def record_sync_info(self, source_replica_uid, source_replica_generation, - source_replica_transaction_id): - if self._trace_hook: - self._trace_hook('record_sync_info') - self._db._set_replica_gen_and_trans_id( - source_replica_uid, source_replica_generation, - source_replica_transaction_id) - - diff --git a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py index 2c815632..ce00c8f3 100644 --- a/src/leap/soledad/backends/leap.py +++ b/src/leap/soledad/backends/leap.py @@ -7,6 +7,7 @@ from u1db import Document from u1db.remote.http_target import HTTPSyncTarget from u1db.remote.http_database import HTTPDatabase import base64 +from soledad import GPGWrapper class NoDefaultKey(Exception): diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index d9ab7cbd..5bd864c8 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -1,5 +1,7 @@ +import uuid from u1db.backends import CommonBackend from soledad import SyncLog, TransactionLog +from soledad.backends.leap import LeapDocument class ObjectStore(CommonBackend): @@ -45,15 +47,14 @@ class ObjectStore(CommonBackend): self._check_doc_id(doc.doc_id) self._check_doc_size(doc) # put the document - new_rev = self._allocate_doc_rev(doc.rev) - self._put_doc(doc, new_rev) - doc.rev = new_rev + doc.rev = self._allocate_doc_rev(doc.rev) + self._put_doc(doc) # update u1db generation and logs new_gen = self._get_generation() + 1 trans_id = self._allocate_transaction_id() self._transaction_log.append((new_gen, doc.doc_id, trans_id)) self._set_u1db_data() - return new_rev + 
return doc.rev def delete_doc(self, doc): old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) @@ -145,15 +146,16 @@ class ObjectStore(CommonBackend): """ if not self._is_initialized(): self._initialize() - u1db_data = self._get_doc('u1db_data') - self._sync_log.log = u1db_data.content['sync_log'] - self._transaction_log.log = u1db_data.content['transaction_log'] + self._get_u1db_data() + + U1DB_DATA_DOC_ID = 'u1db_data' def _is_initialized(self): """ Verify if u1db data exists in store. """ - if not self._get_doc('u1db_data'): + doc = self._get_doc(self.U1DB_DATA_DOC_ID) + if not self._get_doc(self.U1DB_DATA_DOC_ID): return False return True @@ -161,19 +163,22 @@ class ObjectStore(CommonBackend): """ Create u1db data object in store. """ - content = { 'transaction_log' : [], - 'sync_log' : [] } - doc = self.create_doc('u1db_data', content) + self._replica_uid = uuid.uuid4().hex + doc = self._factory(doc_id=self.U1DB_DATA_DOC_ID) + doc.content = { 'transaction_log' : [], + 'sync_log' : [], + 'replica_uid' : self._replica_uid } + self._put_doc(doc) - def _get_u1db_data(self): - data = self.get_doc('u1db_data').content - self._transaction_log = data['transaction_log'] - self._sync_log = data['sync_log'] + def _get_u1db_data(self, u1db_data_doc_id): + NotImplementedError(self._get_u1db_data) def _set_u1db_data(self): - doc = self._factory('u1db_data') - doc.content = { 'transaction_log' : self._transaction_log, - 'sync_log' : self._sync_log } - self.put_doc(doc) + doc = self._factory(doc_id=self.U1DB_DATA_DOC_ID) + doc.content = { 'transaction_log' : self._transaction_log.log, + 'sync_log' : self._sync_log.log, + 'replica_uid' : self._replica_uid, + '_rev' : self._couch_rev} + self._put_doc(doc) diff --git a/src/leap/soledad/backends/openstack.py b/src/leap/soledad/backends/openstack.py index 5f2a2771..c027231c 100644 --- a/src/leap/soledad/backends/openstack.py +++ b/src/leap/soledad/backends/openstack.py @@ -17,7 +17,7 @@ class 
OpenStackDatabase(ObjectStore): self._auth_key) self._get_auth() # this will ensure transaction and sync logs exist and are up-to-date. - super(OpenStackDatabase, self) + super(OpenStackDatabase, self).__init__() #------------------------------------------------------------------------- # implemented methods from Database diff --git a/src/leap/soledad/tests/test_couchdb.py b/src/leap/soledad/tests/test_couchdb.py new file mode 100644 index 00000000..58285086 --- /dev/null +++ b/src/leap/soledad/tests/test_couchdb.py @@ -0,0 +1,19 @@ +import unittest +from soledad.backends.couch import CouchDatabase + +class CouchTestCase(unittest.TestCase): + + def setUp(self): + self._db = CouchDatabase('http://localhost:5984', 'u1db_tests') + + def test_create_get(self): + doc1 = self._db.create_doc({"key": "value"}, doc_id="testdoc") + doc2 = self._db.get_doc('testdoc') + self.assertEqual(doc1, doc2, 'error storing/retrieving document.') + self.assertEqual(self._db._get_generation(), 1) + + def tearDown(self): + self._db._server.delete('u1db_tests') + +if __name__ == '__main__': + unittest.main() -- cgit v1.2.3 From d5816c05136c9c018b8984b5f8a104c164676e9f Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 11:47:16 -0200 Subject: Fix ObjectStore's put_doc. 
--- src/leap/soledad/backends/objectstore.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index 5bd864c8..298bdda3 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -1,5 +1,6 @@ import uuid from u1db.backends import CommonBackend +from u1db import errors from soledad import SyncLog, TransactionLog from soledad.backends.leap import LeapDocument @@ -46,8 +47,21 @@ class ObjectStore(CommonBackend): raise errors.InvalidDocId() self._check_doc_id(doc.doc_id) self._check_doc_size(doc) - # put the document - doc.rev = self._allocate_doc_rev(doc.rev) + # check if document exists + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev self._put_doc(doc) # update u1db generation and logs new_gen = self._get_generation() + 1 @@ -69,7 +83,7 @@ class ObjectStore(CommonBackend): new_rev = self._allocate_doc_rev(doc.rev) doc.rev = new_rev doc.make_tombstone() - self._put_doc(olddoc) + self._put_doc(doc) return new_rev # start of index-related methods: these are not supported by this backend. @@ -171,9 +185,15 @@ class ObjectStore(CommonBackend): self._put_doc(doc) def _get_u1db_data(self, u1db_data_doc_id): + """ + Fetch u1db configuration data from backend storage. + """ NotImplementedError(self._get_u1db_data) def _set_u1db_data(self): + """ + Save u1db configuration data on backend storage. 
+ """ doc = self._factory(doc_id=self.U1DB_DATA_DOC_ID) doc.content = { 'transaction_log' : self._transaction_log.log, 'sync_log' : self._sync_log.log, -- cgit v1.2.3 From 703224c26e868546d37e9850db75747df1f92348 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 11:47:38 -0200 Subject: Store u1db contents in couch as json string. --- src/leap/soledad/backends/couch.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/backends/couch.py b/src/leap/soledad/backends/couch.py index 5586ea9c..ed356fdd 100644 --- a/src/leap/soledad/backends/couch.py +++ b/src/leap/soledad/backends/couch.py @@ -5,6 +5,11 @@ from couchdb.http import ResourceNotFound from soledad.backends.objectstore import ObjectStore from soledad.backends.leap import LeapDocument +try: + import simplejson as json +except ImportError: + import json # noqa + class CouchDatabase(ObjectStore): """A U1DB implementation that uses Couch as its persistence layer.""" @@ -40,12 +45,11 @@ class CouchDatabase(ObjectStore): cdoc = self._database.get(doc_id) if cdoc is None: return None - content = {} - for (key, value) in cdoc.items(): - if key not in ['_id', '_rev', 'u1db_rev']: - content[key] = value doc = self._factory(doc_id=doc_id, rev=cdoc['u1db_rev']) - doc.content = content + if cdoc['u1db_json'] is not None: + doc.content = json.loads(cdoc['u1db_json']) + else: + doc.make_tombstone() return doc def get_all_docs(self, include_deleted=False): @@ -60,13 +64,20 @@ class CouchDatabase(ObjectStore): return (generation, results) def _put_doc(self, doc): - # map u1db metadata to couch - content = doc.content + # prepare couch's Document cdoc = Document() cdoc['_id'] = doc.doc_id + # we have to guarantee that couch's _rev is cosistent + old_cdoc = self._database.get(doc.doc_id) + if old_cdoc is not None: + cdoc['_rev'] = old_cdoc['_rev'] + # store u1db's rev cdoc['u1db_rev'] = doc.rev - for (key, value) in 
content.items(): - cdoc[key] = value + # store u1db's content as json string + if not doc.is_tombstone(): + cdoc['u1db_json'] = doc.get_json() + else: + cdoc['u1db_json'] = None self._database.save(cdoc) def get_sync_target(self): @@ -83,9 +94,10 @@ class CouchDatabase(ObjectStore): def _get_u1db_data(self): cdoc = self._database.get(self.U1DB_DATA_DOC_ID) - self._sync_log.log = cdoc['sync_log'] - self._transaction_log.log = cdoc['transaction_log'] - self._replica_uid = cdoc['replica_uid'] + content = json.loads(cdoc['u1db_json']) + self._sync_log.log = content['sync_log'] + self._transaction_log.log = content['transaction_log'] + self._replica_uid = content['replica_uid'] self._couch_rev = cdoc['_rev'] #------------------------------------------------------------------------- -- cgit v1.2.3 From 45908d847d09336d685dd38b698441a92570861e Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 11:48:15 -0200 Subject: Add basic tests for Couch backend. --- src/leap/soledad/tests/test_couchdb.py | 281 +++++++++++++++++++++++++++++++-- 1 file changed, 271 insertions(+), 10 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/tests/test_couchdb.py b/src/leap/soledad/tests/test_couchdb.py index 58285086..4468ae04 100644 --- a/src/leap/soledad/tests/test_couchdb.py +++ b/src/leap/soledad/tests/test_couchdb.py @@ -1,19 +1,280 @@ -import unittest +import unittest2 from soledad.backends.couch import CouchDatabase +from soledad.backends.leap import LeapDocument +from u1db import errors, vectorclock -class CouchTestCase(unittest.TestCase): +try: + import simplejson as json +except ImportError: + import json # noqa + +simple_doc = '{"key": "value"}' +nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}' + +def make_document_for_test(test, doc_id, rev, content, has_conflicts=False): + return LeapDocument(doc_id, rev, content, has_conflicts=has_conflicts) + +class CouchTestCase(unittest2.TestCase): def setUp(self): - self._db = 
CouchDatabase('http://localhost:5984', 'u1db_tests') + self.db = CouchDatabase('http://localhost:5984', 'u1db_tests') + + def make_document(self, doc_id, doc_rev, content, has_conflicts=False): + return self.make_document_for_test( + self, doc_id, doc_rev, content, has_conflicts) + + def make_document_for_test(self, test, doc_id, doc_rev, content, + has_conflicts): + return make_document_for_test( + test, doc_id, doc_rev, content, has_conflicts) + + def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id)) + + def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content, + has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True)) + + + def test_create_doc_allocating_doc_id(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertNotEqual(None, doc.doc_id) + self.assertNotEqual(None, doc.rev) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + + def test_create_doc_different_ids_same_db(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertNotEqual(doc1.doc_id, doc2.doc_id) + + def test_create_doc_with_id(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my-id') + self.assertEqual('my-id', doc.doc_id) + self.assertNotEqual(None, doc.rev) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + + def test_create_doc_existing_id(self): + doc = self.db.create_doc_from_json(simple_doc) + new_content = '{"something": "else"}' + self.assertRaises( + errors.RevisionConflict, self.db.create_doc_from_json, + new_content, doc.doc_id) + self.assertGetDoc(self.db, 
doc.doc_id, doc.rev, simple_doc, False) + + def test_put_doc_creating_initial(self): + doc = self.make_document('my_doc_id', None, simple_doc) + new_rev = self.db.put_doc(doc) + self.assertIsNot(None, new_rev) + self.assertGetDoc(self.db, 'my_doc_id', new_rev, simple_doc, False) + + def test_put_doc_space_in_id(self): + doc = self.make_document('my doc id', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_update(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + orig_rev = doc.rev + doc.set_json('{"updated": "stuff"}') + new_rev = self.db.put_doc(doc) + self.assertNotEqual(new_rev, orig_rev) + self.assertGetDoc(self.db, 'my_doc_id', new_rev, + '{"updated": "stuff"}', False) + self.assertEqual(doc.rev, new_rev) + + def test_put_non_ascii_key(self): + content = json.dumps({u'key\xe5': u'val'}) + doc = self.db.create_doc_from_json(content, doc_id='my_doc') + self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) + + def test_put_non_ascii_value(self): + content = json.dumps({'key': u'\xe5'}) + doc = self.db.create_doc_from_json(content, doc_id='my_doc') + self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) + + def test_put_doc_refuses_no_id(self): + doc = self.make_document(None, None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + doc = self.make_document("", None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_refuses_slashes(self): + doc = self.make_document('a/b', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + doc = self.make_document(r'\b', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_url_quoting_is_fine(self): + doc_id = "%2F%2Ffoo%2Fbar" + doc = self.make_document(doc_id, None, simple_doc) + new_rev = self.db.put_doc(doc) + self.assertGetDoc(self.db, doc_id, new_rev, simple_doc, False) + + def 
test_put_doc_refuses_non_existing_old_rev(self): + doc = self.make_document('doc-id', 'test:4', simple_doc) + self.assertRaises(errors.RevisionConflict, self.db.put_doc, doc) + + def test_put_doc_refuses_non_ascii_doc_id(self): + doc = self.make_document('d\xc3\xa5c-id', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_fails_with_bad_old_rev(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + old_rev = doc.rev + bad_doc = self.make_document(doc.doc_id, 'other:1', + '{"something": "else"}') + self.assertRaises(errors.RevisionConflict, self.db.put_doc, bad_doc) + self.assertGetDoc(self.db, 'my_doc_id', old_rev, simple_doc, False) + + def test_create_succeeds_after_delete(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) + deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) + new_doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.assertGetDoc(self.db, 'my_doc_id', new_doc.rev, simple_doc, False) + new_vc = vectorclock.VectorClockRev(new_doc.rev) + self.assertTrue( + new_vc.is_newer(deleted_vc), + "%s does not supersede %s" % (new_doc.rev, deleted_doc.rev)) + + def test_put_succeeds_after_delete(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) + deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) + doc2 = self.make_document('my_doc_id', None, simple_doc) + self.db.put_doc(doc2) + self.assertGetDoc(self.db, 'my_doc_id', doc2.rev, simple_doc, False) + new_vc = vectorclock.VectorClockRev(doc2.rev) + self.assertTrue( + new_vc.is_newer(deleted_vc), + "%s does not supersede %s" % (doc2.rev, deleted_doc.rev)) + + def test_get_doc_after_put(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + 
self.assertGetDoc(self.db, 'my_doc_id', doc.rev, simple_doc, False) + + def test_get_doc_nonexisting(self): + self.assertIs(None, self.db.get_doc('non-existing')) + + def test_get_doc_deleted(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + self.assertIs(None, self.db.get_doc('my_doc_id')) + + def test_get_doc_include_deleted(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + + def test_get_docs(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual([doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + + def test_get_docs_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc1) + self.assertEqual([doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + + def test_get_docs_include_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc1) + self.assertEqual( + [doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id], + include_deleted=True))) + + def test_get_docs_request_ordered(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual([doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + self.assertEqual([doc2, doc1], + list(self.db.get_docs([doc2.doc_id, doc1.doc_id]))) + + def test_get_docs_empty_list(self): + self.assertEqual([], list(self.db.get_docs([]))) + + def test_handles_nested_content(self): + doc = self.db.create_doc_from_json(nested_doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) + + def test_handles_doc_with_null(self): + doc = self.db.create_doc_from_json('{"key": null}') + 
self.assertGetDoc(self.db, doc.doc_id, doc.rev, '{"key": null}', False) + + def test_delete_doc(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + orig_rev = doc.rev + self.db.delete_doc(doc) + self.assertNotEqual(orig_rev, doc.rev) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + self.assertIs(None, self.db.get_doc(doc.doc_id)) + + def test_delete_doc_non_existent(self): + doc = self.make_document('non-existing', 'other:1', simple_doc) + self.assertRaises(errors.DocumentDoesNotExist, self.db.delete_doc, doc) + + def test_delete_doc_already_deleted(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertRaises(errors.DocumentAlreadyDeleted, + self.db.delete_doc, doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + + def test_delete_doc_bad_rev(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + doc2 = self.make_document(doc1.doc_id, 'other:1', simple_doc) + self.assertRaises(errors.RevisionConflict, self.db.delete_doc, doc2) + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + + def test_delete_doc_sets_content_to_None(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertIs(None, doc.get_json()) + + def test_delete_doc_rev_supersedes(self): + doc = self.db.create_doc_from_json(simple_doc) + doc.set_json(nested_doc) + self.db.put_doc(doc) + doc.set_json('{"fishy": "content"}') + self.db.put_doc(doc) + old_rev = doc.rev + self.db.delete_doc(doc) + cur_vc = vectorclock.VectorClockRev(old_rev) + deleted_vc = vectorclock.VectorClockRev(doc.rev) + self.assertTrue(deleted_vc.is_newer(cur_vc), + "%s does not supersede %s" % (doc.rev, old_rev)) + + def test_delete_then_put(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + 
self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + doc.set_json(nested_doc) + self.db.put_doc(doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) + - def test_create_get(self): - doc1 = self._db.create_doc({"key": "value"}, doc_id="testdoc") - doc2 = self._db.get_doc('testdoc') - self.assertEqual(doc1, doc2, 'error storing/retrieving document.') - self.assertEqual(self._db._get_generation(), 1) def tearDown(self): - self._db._server.delete('u1db_tests') + self.db._server.delete('u1db_tests') if __name__ == '__main__': - unittest.main() + unittest2.main() -- cgit v1.2.3 From 4417d89bb9bdd59d717501c6db3f2215cdeb87fb Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 12:07:28 -0200 Subject: SQLCipherDatabase now extends SQLitePartialExpandDatabase. --- src/leap/soledad/backends/sqlcipher.py | 831 +-------------------------------- 1 file changed, 3 insertions(+), 828 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/backends/sqlcipher.py b/src/leap/soledad/backends/sqlcipher.py index 24f47eed..fcdab251 100644 --- a/src/leap/soledad/backends/sqlcipher.py +++ b/src/leap/soledad/backends/sqlcipher.py @@ -30,6 +30,7 @@ import uuid import pkg_resources from u1db.backends import CommonBackend, CommonSyncTarget +from u1db.backends.sqlite_backend import SQLitePartialExpandDatabase from u1db import ( Document, errors, @@ -56,7 +57,7 @@ def open(path, create, document_factory=None, password=None): path, create=create, document_factory=document_factory, password=password) -class SQLCipherDatabase(CommonBackend): +class SQLCipherDatabase(SQLitePartialExpandDatabase): """A U1DB implementation that uses SQLCipher as its persistence layer.""" _sqlite_registry = {} @@ -74,25 +75,6 @@ class SQLCipherDatabase(CommonBackend): self._ensure_schema() self._factory = document_factory or Document - def set_document_factory(self, factory): - self._factory = factory - - def get_sync_target(self): - return 
SQLCipherSyncTarget(self) - - @classmethod - def _which_index_storage(cls, c): - try: - c.execute("SELECT value FROM u1db_config" - " WHERE name = 'index_storage'") - except dbapi2.OperationalError, e: - # The table does not exist yet - return None, e - else: - return c.fetchone()[0], None - - WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5 - @classmethod def _open_database(cls, sqlite_file, document_factory=None, password=None): if not os.path.isfile(sqlite_file): @@ -135,15 +117,6 @@ class SQLCipherDatabase(CommonBackend): return backend_cls(sqlite_file, document_factory=document_factory, password=password) - @staticmethod - def delete_database(sqlite_file): - try: - os.unlink(sqlite_file) - except OSError as ex: - if ex.errno == errno.ENOENT: - raise errors.DatabaseDoesNotExist() - raise - @staticmethod def register_implementation(klass): """Register that we implement an SQLCipherDatabase. @@ -152,803 +125,5 @@ class SQLCipherDatabase(CommonBackend): """ SQLCipherDatabase._sqlite_registry[klass._index_storage_value] = klass - def _get_sqlite_handle(self): - """Get access to the underlying sqlite database. - - This should only be used by the test suite, etc, for examining the - state of the underlying database. - """ - return self._db_handle - - def _close_sqlite_handle(self): - """Release access to the underlying sqlite database.""" - self._db_handle.close() - - def close(self): - self._close_sqlite_handle() - - def _is_initialized(self, c): - """Check if this database has been initialized.""" - c.execute("PRAGMA case_sensitive_like=ON") - try: - c.execute("SELECT value FROM u1db_config" - " WHERE name = 'sql_schema'") - except dbapi2.OperationalError: - # The table does not exist yet - val = None - else: - val = c.fetchone() - if val is not None: - return True - return False - - def _initialize(self, c): - """Create the schema in the database.""" - #read the script with sql commands - # TODO: Change how we set up the dependency. 
Most likely use something - # like lp:dirspec to grab the file from a common resource - # directory. Doesn't specifically need to be handled until we get - # to the point of packaging this. - schema_content = pkg_resources.resource_string( - __name__, 'dbschema.sql') - # Note: We'd like to use c.executescript() here, but it seems that - # executescript always commits, even if you set - # isolation_level = None, so if we want to properly handle - # exclusive locking and rollbacks between processes, we need - # to execute it line-by-line - for line in schema_content.split(';'): - if not line: - continue - c.execute(line) - #add extra fields - self._extra_schema_init(c) - # A unique identifier should be set for this replica. Implementations - # don't have to strictly use uuid here, but we do want the uid to be - # unique amongst all databases that will sync with each other. - # We might extend this to using something with hostname for easier - # debugging. - self._set_replica_uid_in_transaction(uuid.uuid4().hex) - c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)", - (self._index_storage_value,)) - - def _ensure_schema(self): - """Ensure that the database schema has been created.""" - old_isolation_level = self._db_handle.isolation_level - c = self._db_handle.cursor() - if self._is_initialized(c): - return - try: - # autocommit/own mgmt of transactions - self._db_handle.isolation_level = None - with self._db_handle: - # only one execution path should initialize the db - c.execute("begin exclusive") - if self._is_initialized(c): - return - self._initialize(c) - finally: - self._db_handle.isolation_level = old_isolation_level - - def _extra_schema_init(self, c): - """Add any extra fields, etc to the basic table definitions.""" - - def _parse_index_definition(self, index_field): - """Parse a field definition for an index, returning a Getter.""" - # Note: We may want to keep a Parser object around, and cache the - # Getter objects for a greater length of 
time. Specifically, if - # you create a bunch of indexes, and then insert 50k docs, you'll - # re-parse the indexes between puts. The time to insert the docs - # is still likely to dominate put_doc time, though. - parser = query_parser.Parser() - getter = parser.parse(index_field) - return getter - - def _update_indexes(self, doc_id, raw_doc, getters, db_cursor): - """Update document_fields for a single document. - - :param doc_id: Identifier for this document - :param raw_doc: The python dict representation of the document. - :param getters: A list of [(field_name, Getter)]. Getter.get will be - called to evaluate the index definition for this document, and the - results will be inserted into the db. - :param db_cursor: An sqlite Cursor. - :return: None - """ - values = [] - for field_name, getter in getters: - for idx_value in getter.get(raw_doc): - values.append((doc_id, field_name, idx_value)) - if values: - db_cursor.executemany( - "INSERT INTO document_fields VALUES (?, ?, ?)", values) - - def _set_replica_uid(self, replica_uid): - """Force the replica_uid to be set.""" - with self._db_handle: - self._set_replica_uid_in_transaction(replica_uid) - - def _set_replica_uid_in_transaction(self, replica_uid): - """Set the replica_uid. 
A transaction should already be held.""" - c = self._db_handle.cursor() - c.execute("INSERT OR REPLACE INTO u1db_config" - " VALUES ('replica_uid', ?)", - (replica_uid,)) - self._real_replica_uid = replica_uid - - def _get_replica_uid(self): - if self._real_replica_uid is not None: - return self._real_replica_uid - c = self._db_handle.cursor() - c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'") - val = c.fetchone() - if val is None: - return None - self._real_replica_uid = val[0] - return self._real_replica_uid - - _replica_uid = property(_get_replica_uid) - - def _get_generation(self): - c = self._db_handle.cursor() - c.execute('SELECT max(generation) FROM transaction_log') - val = c.fetchone()[0] - if val is None: - return 0 - return val - - def _get_generation_info(self): - c = self._db_handle.cursor() - c.execute( - 'SELECT max(generation), transaction_id FROM transaction_log ') - val = c.fetchone() - if val[0] is None: - return(0, '') - return val - - def _get_trans_id_for_gen(self, generation): - if generation == 0: - return '' - c = self._db_handle.cursor() - c.execute( - 'SELECT transaction_id FROM transaction_log WHERE generation = ?', - (generation,)) - val = c.fetchone() - if val is None: - raise errors.InvalidGeneration - return val[0] - - def _get_transaction_log(self): - c = self._db_handle.cursor() - c.execute("SELECT doc_id, transaction_id FROM transaction_log" - " ORDER BY generation") - return c.fetchall() - - def _get_doc(self, doc_id, check_for_conflicts=False): - """Get just the document content, without fancy handling.""" - c = self._db_handle.cursor() - if check_for_conflicts: - c.execute( - "SELECT document.doc_rev, document.content, " - "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN " - "conflicts ON conflicts.doc_id = document.doc_id WHERE " - "document.doc_id = ? 
GROUP BY document.doc_id, " - "document.doc_rev, document.content;", (doc_id,)) - else: - c.execute( - "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?", - (doc_id,)) - val = c.fetchone() - if val is None: - return None - doc_rev, content, conflicts = val - doc = self._factory(doc_id, doc_rev, content) - doc.has_conflicts = conflicts > 0 - return doc - - def _has_conflicts(self, doc_id): - c = self._db_handle.cursor() - c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1", - (doc_id,)) - val = c.fetchone() - if val is None: - return False - else: - return True - - def get_doc(self, doc_id, include_deleted=False): - doc = self._get_doc(doc_id, check_for_conflicts=True) - if doc is None: - return None - if doc.is_tombstone() and not include_deleted: - return None - return doc - - def get_all_docs(self, include_deleted=False): - """Get all documents from the database.""" - generation = self._get_generation() - results = [] - c = self._db_handle.cursor() - c.execute( - "SELECT document.doc_id, document.doc_rev, document.content, " - "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts " - "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, " - "document.doc_rev, document.content;") - rows = c.fetchall() - for doc_id, doc_rev, content, conflicts in rows: - if content is None and not include_deleted: - continue - doc = self._factory(doc_id, doc_rev, content) - doc.has_conflicts = conflicts > 0 - results.append(doc) - return (generation, results) - - def put_doc(self, doc): - if doc.doc_id is None: - raise errors.InvalidDocId() - self._check_doc_id(doc.doc_id) - self._check_doc_size(doc) - with self._db_handle: - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc and old_doc.has_conflicts: - raise errors.ConflictedDoc() - if old_doc and doc.rev is None and old_doc.is_tombstone(): - new_rev = self._allocate_doc_rev(old_doc.rev) - else: - if old_doc is not None: - if old_doc.rev != doc.rev: - raise 
errors.RevisionConflict() - else: - if doc.rev is not None: - raise errors.RevisionConflict() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - self._put_and_update_indexes(old_doc, doc) - return new_rev - - def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none): - """Convert a dict representation into named fields. - - So something like: {'key1': 'val1', 'key2': 'val2'} - gets converted into: [(doc_id, 'key1', 'val1', 0) - (doc_id, 'key2', 'val2', 0)] - :param doc_id: Just added to every record. - :param base_field: if set, these are nested keys, so each field should - be appropriately prefixed. - :param raw_doc: The python dictionary. - """ - # TODO: Handle lists - values = [] - for field_name, value in raw_doc.iteritems(): - if value is None and not save_none: - continue - if base_field: - full_name = base_field + '.' + field_name - else: - full_name = field_name - if value is None or isinstance(value, (int, float, basestring)): - values.append((doc_id, full_name, value, len(values))) - else: - subvalues = self._expand_to_fields(doc_id, full_name, value, - save_none) - for _, subfield_name, val, _ in subvalues: - values.append((doc_id, subfield_name, val, len(values))) - return values - - def _put_and_update_indexes(self, old_doc, doc): - """Actually insert a document into the database. - - This both updates the existing documents content, and any indexes that - refer to this document. - """ - raise NotImplementedError(self._put_and_update_indexes) - - def whats_changed(self, old_generation=0): - c = self._db_handle.cursor() - c.execute("SELECT generation, doc_id, transaction_id" - " FROM transaction_log" - " WHERE generation > ? 
ORDER BY generation DESC", - (old_generation,)) - results = c.fetchall() - cur_gen = old_generation - seen = set() - changes = [] - newest_trans_id = '' - for generation, doc_id, trans_id in results: - if doc_id not in seen: - changes.append((doc_id, generation, trans_id)) - seen.add(doc_id) - if changes: - cur_gen = changes[0][1] # max generation - newest_trans_id = changes[0][2] - changes.reverse() - else: - c.execute("SELECT generation, transaction_id" - " FROM transaction_log ORDER BY generation DESC LIMIT 1") - results = c.fetchone() - if not results: - cur_gen = 0 - newest_trans_id = '' - else: - cur_gen, newest_trans_id = results - - return cur_gen, newest_trans_id, changes - - def delete_doc(self, doc): - with self._db_handle: - old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) - if old_doc is None: - raise errors.DocumentDoesNotExist - if old_doc.rev != doc.rev: - raise errors.RevisionConflict() - if old_doc.is_tombstone(): - raise errors.DocumentAlreadyDeleted - if old_doc.has_conflicts: - raise errors.ConflictedDoc() - new_rev = self._allocate_doc_rev(doc.rev) - doc.rev = new_rev - doc.make_tombstone() - self._put_and_update_indexes(old_doc, doc) - return new_rev - - def _get_conflicts(self, doc_id): - c = self._db_handle.cursor() - c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?", - (doc_id,)) - return [self._factory(doc_id, doc_rev, content) - for doc_rev, content in c.fetchall()] - - def get_doc_conflicts(self, doc_id): - with self._db_handle: - conflict_docs = self._get_conflicts(doc_id) - if not conflict_docs: - return [] - this_doc = self._get_doc(doc_id) - this_doc.has_conflicts = True - return [this_doc] + conflict_docs - - def _get_replica_gen_and_trans_id(self, other_replica_uid): - c = self._db_handle.cursor() - c.execute("SELECT known_generation, known_transaction_id FROM sync_log" - " WHERE replica_uid = ?", - (other_replica_uid,)) - val = c.fetchone() - if val is None: - other_gen = 0 - trans_id = '' - else: - 
other_gen = val[0] - trans_id = val[1] - return other_gen, trans_id - - def _set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - with self._db_handle: - self._do_set_replica_gen_and_trans_id( - other_replica_uid, other_generation, other_transaction_id) - - def _do_set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, - other_transaction_id): - c = self._db_handle.cursor() - c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)", - (other_replica_uid, other_generation, - other_transaction_id)) - - def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, - replica_gen=None, replica_trans_id=None): - with self._db_handle: - return super(SQLCipherDatabase, self)._put_doc_if_newer(doc, - save_conflict=save_conflict, - replica_uid=replica_uid, replica_gen=replica_gen, - replica_trans_id=replica_trans_id) - - def _add_conflict(self, c, doc_id, my_doc_rev, my_content): - c.execute("INSERT INTO conflicts VALUES (?, ?, ?)", - (doc_id, my_doc_rev, my_content)) - - def _delete_conflicts(self, c, doc, conflict_revs): - deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs] - c.executemany("DELETE FROM conflicts" - " WHERE doc_id=? 
AND doc_rev=?", deleting) - doc.has_conflicts = self._has_conflicts(doc.doc_id) - - def _prune_conflicts(self, doc, doc_vcr): - if self._has_conflicts(doc.doc_id): - autoresolved = False - c_revs_to_prune = [] - for c_doc in self._get_conflicts(doc.doc_id): - c_vcr = vectorclock.VectorClockRev(c_doc.rev) - if doc_vcr.is_newer(c_vcr): - c_revs_to_prune.append(c_doc.rev) - elif doc.same_content_as(c_doc): - c_revs_to_prune.append(c_doc.rev) - doc_vcr.maximize(c_vcr) - autoresolved = True - if autoresolved: - doc_vcr.increment(self._replica_uid) - doc.rev = doc_vcr.as_str() - c = self._db_handle.cursor() - self._delete_conflicts(c, doc, c_revs_to_prune) - - def _force_doc_sync_conflict(self, doc): - my_doc = self._get_doc(doc.doc_id) - c = self._db_handle.cursor() - self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) - self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json()) - doc.has_conflicts = True - self._put_and_update_indexes(my_doc, doc) - - def resolve_doc(self, doc, conflicted_doc_revs): - with self._db_handle: - cur_doc = self._get_doc(doc.doc_id) - # TODO: https://bugs.launchpad.net/u1db/+bug/928274 - # I think we have a logic bug in resolve_doc - # Specifically, cur_doc.rev is always in the final vector - # clock of revisions that we supersede, even if it wasn't in - # conflicted_doc_revs. We still add it as a conflict, but the - # fact that _put_doc_if_newer propagates resolutions means I - # think that conflict could accidentally be resolved. We need - # to add a test for this case first. (create a rev, create a - # conflict, create another conflict, resolve the first rev - # and first conflict, then make sure that the resolved - # rev doesn't supersede the second conflict rev.) 
It *might* - # not matter, because the superseding rev is in as a - # conflict, but it does seem incorrect - new_rev = self._ensure_maximal_rev(cur_doc.rev, - conflicted_doc_revs) - superseded_revs = set(conflicted_doc_revs) - c = self._db_handle.cursor() - doc.rev = new_rev - if cur_doc.rev in superseded_revs: - self._put_and_update_indexes(cur_doc, doc) - else: - self._add_conflict(c, doc.doc_id, new_rev, doc.get_json()) - # TODO: Is there some way that we could construct a rev that would - # end up in superseded_revs, such that we add a conflict, and - # then immediately delete it? - self._delete_conflicts(c, doc, superseded_revs) - - def list_indexes(self): - """Return the list of indexes and their definitions.""" - c = self._db_handle.cursor() - # TODO: How do we test the ordering? - c.execute("SELECT name, field FROM index_definitions" - " ORDER BY name, offset") - definitions = [] - cur_name = None - for name, field in c.fetchall(): - if cur_name != name: - definitions.append((name, [])) - cur_name = name - definitions[-1][-1].append(field) - return definitions - - def _get_index_definition(self, index_name): - """Return the stored definition for a given index_name.""" - c = self._db_handle.cursor() - c.execute("SELECT field FROM index_definitions" - " WHERE name = ? ORDER BY offset", (index_name,)) - fields = [x[0] for x in c.fetchall()] - if not fields: - raise errors.IndexDoesNotExist - return fields - - @staticmethod - def _strip_glob(value): - """Remove the trailing * from a value.""" - assert value[-1] == '*' - return value[:-1] - - def _format_query(self, definition, key_values): - # First, build the definition. We join the document_fields table - # against itself, as many times as the 'width' of our definition. - # We then do a query for each key_value, one-at-a-time. - # Note: All of these strings are static, we could cache them, etc. 
- tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = ["d.doc_id = d%d.doc_id" - " AND d%d.field_name = ?" - % (i, i) for i in range(len(definition))] - wildcard_where = [novalue_where[i] - + (" AND d%d.value NOT NULL" % (i,)) - for i in range(len(definition))] - exact_where = [novalue_where[i] - + (" AND d%d.value = ?" % (i,)) - for i in range(len(definition))] - like_where = [novalue_where[i] - + (" AND d%d.value GLOB ?" % (i,)) - for i in range(len(definition))] - is_wildcard = False - # Merge the lists together, so that: - # [field1, field2, field3], [val1, val2, val3] - # Becomes: - # (field1, val1, field2, val2, field3, val3) - args = [] - where = [] - for idx, (field, value) in enumerate(zip(definition, key_values)): - args.append(field) - if value.endswith('*'): - if value == '*': - where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(like_where[idx]) - args.append(value) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(exact_where[idx]) - args.append(value) - statement = ( - "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " - "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " - "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " - "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( - ['d%d.value' % i for i in range(len(definition))]))) - return statement, args - - def get_from_index(self, index_name, *key_values): - definition = self._get_index_definition(index_name) - if len(key_values) != len(definition): - raise errors.InvalidValueForIndex() - statement, args = self._format_query(definition, key_values) - c = self._db_handle.cursor() - try: - c.execute(statement, tuple(args)) - except dbapi2.OperationalError, e: - raise dbapi2.OperationalError(str(e) + - '\nstatement: %s\nargs: %s\n' 
% (statement, args)) - res = c.fetchall() - results = [] - for row in res: - doc = self._factory(row[0], row[1], row[2]) - doc.has_conflicts = row[3] > 0 - results.append(doc) - return results - - def _format_range_query(self, definition, start_value, end_value): - tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = [ - "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in - range(len(definition))] - wildcard_where = [ - novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in - range(len(definition))] - like_where = [ - novalue_where[i] + ( - " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in - range(len(definition))] - range_where_lower = [ - novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in - range(len(definition))] - range_where_upper = [ - novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in - range(len(definition))] - args = [] - where = [] - if start_value: - if isinstance(start_value, basestring): - start_value = (start_value,) - if len(start_value) != len(definition): - raise errors.InvalidValueForIndex() - is_wildcard = False - for idx, (field, value) in enumerate(zip(definition, start_value)): - args.append(field) - if value.endswith('*'): - if value == '*': - where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(range_where_lower[idx]) - args.append(self._strip_glob(value)) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(range_where_lower[idx]) - args.append(value) - if end_value: - if isinstance(end_value, basestring): - end_value = (end_value,) - if len(end_value) != len(definition): - raise errors.InvalidValueForIndex() - is_wildcard = False - for idx, (field, value) in enumerate(zip(definition, end_value)): - args.append(field) - if value.endswith('*'): - if value == '*': - 
where.append(wildcard_where[idx]) - else: - # This is a glob match - if is_wildcard: - # We can't have a partial wildcard following - # another wildcard - raise errors.InvalidGlobbing - where.append(like_where[idx]) - args.append(self._strip_glob(value)) - args.append(value) - is_wildcard = True - else: - if is_wildcard: - raise errors.InvalidGlobbing - where.append(range_where_upper[idx]) - args.append(value) - statement = ( - "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " - "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " - "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " - "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( - ['d%d.value' % i for i in range(len(definition))]))) - return statement, args - - def get_range_from_index(self, index_name, start_value=None, - end_value=None): - """Return all documents with key values in the specified range.""" - definition = self._get_index_definition(index_name) - statement, args = self._format_range_query( - definition, start_value, end_value) - c = self._db_handle.cursor() - try: - c.execute(statement, tuple(args)) - except dbapi2.OperationalError, e: - raise dbapi2.OperationalError(str(e) + - '\nstatement: %s\nargs: %s\n' % (statement, args)) - res = c.fetchall() - results = [] - for row in res: - doc = self._factory(row[0], row[1], row[2]) - doc.has_conflicts = row[3] > 0 - results.append(doc) - return results - - def get_index_keys(self, index_name): - c = self._db_handle.cursor() - definition = self._get_index_definition(index_name) - value_fields = ', '.join([ - 'd%d.value' % i for i in range(len(definition))]) - tables = ["document_fields d%d" % i for i in range(len(definition))] - novalue_where = [ - "d.doc_id = d%d.doc_id AND d%d.field_name = ?" 
% (i, i) for i in - range(len(definition))] - where = [ - novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in - range(len(definition))] - statement = ( - "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( - value_fields, ', '.join(tables), ' AND '.join(where), - value_fields)) - try: - c.execute(statement, tuple(definition)) - except dbapi2.OperationalError, e: - raise dbapi2.OperationalError(str(e) + - '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) - return c.fetchall() - - def delete_index(self, index_name): - with self._db_handle: - c = self._db_handle.cursor() - c.execute("DELETE FROM index_definitions WHERE name = ?", - (index_name,)) - c.execute( - "DELETE FROM document_fields WHERE document_fields.field_name " - " NOT IN (SELECT field from index_definitions)") - - -class SQLCipherSyncTarget(CommonSyncTarget): - - def get_sync_info(self, source_replica_uid): - source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( - source_replica_uid) - my_gen, my_trans_id = self._db._get_generation_info() - return ( - self._db._replica_uid, my_gen, my_trans_id, source_gen, - source_trans_id) - - def record_sync_info(self, source_replica_uid, source_replica_generation, - source_replica_transaction_id): - if self._trace_hook: - self._trace_hook('record_sync_info') - self._db._set_replica_gen_and_trans_id( - source_replica_uid, source_replica_generation, - source_replica_transaction_id) - - -class SQLCipherPartialExpandDatabase(SQLCipherDatabase): - """An SQLCipher Backend that expands documents into a document_field table. - - It stores the original document text in document.doc. For fields that are - indexed, the data goes into document_fields. 
- """ - - _index_storage_value = 'expand referenced' - - def _get_indexed_fields(self): - """Determine what fields are indexed.""" - c = self._db_handle.cursor() - c.execute("SELECT field FROM index_definitions") - return set([x[0] for x in c.fetchall()]) - - def _evaluate_index(self, raw_doc, field): - parser = query_parser.Parser() - getter = parser.parse(field) - return getter.get(raw_doc) - - def _put_and_update_indexes(self, old_doc, doc): - c = self._db_handle.cursor() - if doc and not doc.is_tombstone(): - raw_doc = json.loads(doc.get_json()) - else: - raw_doc = {} - if old_doc is not None: - c.execute("UPDATE document SET doc_rev=?, content=?" - " WHERE doc_id = ?", - (doc.rev, doc.get_json(), doc.doc_id)) - c.execute("DELETE FROM document_fields WHERE doc_id = ?", - (doc.doc_id,)) - else: - c.execute("INSERT INTO document (doc_id, doc_rev, content)" - " VALUES (?, ?, ?)", - (doc.doc_id, doc.rev, doc.get_json())) - indexed_fields = self._get_indexed_fields() - if indexed_fields: - # It is expected that len(indexed_fields) is shorter than - # len(raw_doc) - getters = [(field, self._parse_index_definition(field)) - for field in indexed_fields] - self._update_indexes(doc.doc_id, raw_doc, getters, c) - trans_id = self._allocate_transaction_id() - c.execute("INSERT INTO transaction_log(doc_id, transaction_id)" - " VALUES (?, ?)", (doc.doc_id, trans_id)) - - def create_index(self, index_name, *index_expressions): - with self._db_handle: - c = self._db_handle.cursor() - cur_fields = self._get_indexed_fields() - definition = [(index_name, idx, field) - for idx, field in enumerate(index_expressions)] - try: - c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", - definition) - except dbapi2.IntegrityError as e: - stored_def = self._get_index_definition(index_name) - if stored_def == [x[-1] for x in definition]: - return - raise errors.IndexNameTakenError, e, sys.exc_info()[2] - new_fields = set( - [f for f in index_expressions if f not in cur_fields]) - 
if new_fields: - self._update_all_indexes(new_fields) - - def _iter_all_docs(self): - c = self._db_handle.cursor() - c.execute("SELECT doc_id, content FROM document") - while True: - next_rows = c.fetchmany() - if not next_rows: - break - for row in next_rows: - yield row - - def _update_all_indexes(self, new_fields): - """Iterate all the documents, and add content to document_fields. - - :param new_fields: The index definitions that need to be added. - """ - getters = [(field, self._parse_index_definition(field)) - for field in new_fields] - c = self._db_handle.cursor() - for doc_id, doc in self._iter_all_docs(): - if doc is None: - continue - raw_doc = json.loads(doc) - self._update_indexes(doc_id, raw_doc, getters, c) -SQLCipherDatabase.register_implementation(SQLCipherPartialExpandDatabase) +SQLCipherDatabase.register_implementation(SQLCipherDatabase) -- cgit v1.2.3 From 7823990656ac65982a1322ea049298350fb2185e Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 14:42:57 -0200 Subject: Refactor test files. 
--- src/leap/soledad/tests/__init__.py | 284 ------------------ src/leap/soledad/tests/test_encrypted.py | 211 +++++++++++++ src/leap/soledad/tests/test_logs.py | 75 +++++ src/leap/soledad/tests/test_sqlcipher.py | 494 +++++++++++++++++++++++++++++++ 4 files changed, 780 insertions(+), 284 deletions(-) create mode 100644 src/leap/soledad/tests/test_encrypted.py create mode 100644 src/leap/soledad/tests/test_logs.py create mode 100644 src/leap/soledad/tests/test_sqlcipher.py (limited to 'src/leap') diff --git a/src/leap/soledad/tests/__init__.py b/src/leap/soledad/tests/__init__.py index b6585755..e69de29b 100644 --- a/src/leap/soledad/tests/__init__.py +++ b/src/leap/soledad/tests/__init__.py @@ -1,284 +0,0 @@ -try: - import simplejson as json -except ImportError: - import json # noqa - -import unittest -import os - -import u1db -from soledad import ( - GPGWrapper, - SimpleLog, - TransactionLog, - SyncLog, - ) -from soledad.backends import leap - - -class EncryptedSyncTestCase(unittest.TestCase): - - PREFIX = "/var/tmp" - GNUPG_HOME = "%s/gnupg" % PREFIX - DB1_FILE = "%s/db1.u1db" % PREFIX - DB2_FILE = "%s/db2.u1db" % PREFIX - - def setUp(self): - self.db1 = u1db.open(self.DB1_FILE, create=True, - document_factory=leap.LeapDocument) - self.db2 = u1db.open(self.DB2_FILE, create=True, - document_factory=leap.LeapDocument) - self.gpg = GPGWrapper(gpghome=self.GNUPG_HOME) - self.gpg.import_keys(PUBLIC_KEY) - self.gpg.import_keys(PRIVATE_KEY) - - def tearDown(self): - os.unlink(self.DB1_FILE) - os.unlink(self.DB2_FILE) - - def test_get_set_encrypted(self): - doc1 = leap.LeapDocument(gpg_wrapper = self.gpg, - default_key = KEY_FINGERPRINT) - doc1.content = { 'key' : 'val' } - doc2 = leap.LeapDocument(doc_id=doc1.doc_id, - encrypted_json=doc1.get_encrypted_json(), - gpg_wrapper=self.gpg, - default_key = KEY_FINGERPRINT) - res1 = doc1.get_json() - res2 = doc2.get_json() - self.assertEqual(res1, res2, 'incorrect document encryption') - - -class 
LogTestCase(unittest.TestCase): - - - def test_transaction_log(self): - data = [ - (2, "doc_3", "tran_3"), - (3, "doc_2", "tran_2"), - (1, "doc_1", "tran_1") - ] - log = TransactionLog() - log.log = data - self.assertEqual(log.get_generation(), 3, 'error getting generation') - self.assertEqual(log.get_generation_info(), (3, 'tran_2'), - 'error getting generation info') - self.assertEqual(log.get_trans_id_for_gen(1), 'tran_1', - 'error getting trans_id for gen') - self.assertEqual(log.get_trans_id_for_gen(2), 'tran_3', - 'error getting trans_id for gen') - self.assertEqual(log.get_trans_id_for_gen(3), 'tran_2', - 'error getting trans_id for gen') - - def test_sync_log(self): - data = [ - ("replica_3", 3, "tran_3"), - ("replica_2", 2, "tran_2"), - ("replica_1", 1, "tran_1") - ] - log = SyncLog() - log.log = data - # test getting - self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'), - (3, 'tran_3'), 'error getting replica gen and trans id') - self.assertEqual(log.get_replica_gen_and_trans_id('replica_2'), - (2, 'tran_2'), 'error getting replica gen and trans id') - self.assertEqual(log.get_replica_gen_and_trans_id('replica_1'), - (1, 'tran_1'), 'error getting replica gen and trans id') - # test setting - log.set_replica_gen_and_trans_id('replica_1', 2, 'tran_12') - self.assertEqual(len(log._log), 3, 'error in log size after setting') - self.assertEqual(log.get_replica_gen_and_trans_id('replica_1'), - (2, 'tran_12'), 'error setting replica gen and trans id') - self.assertEqual(log.get_replica_gen_and_trans_id('replica_2'), - (2, 'tran_2'), 'error setting replica gen and trans id') - self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'), - (3, 'tran_3'), 'error setting replica gen and trans id') - - def test_whats_changed(self): - data = [ - (2, "doc_3", "tran_3"), - (3, "doc_2", "tran_2"), - (1, "doc_1", "tran_1") - ] - log = TransactionLog() - log.log = data - self.assertEqual( - log.whats_changed(3), - (3, "tran_2", []), - 'error getting whats 
changed.') - self.assertEqual( - log.whats_changed(2), - (3, "tran_2", [("doc_2",3,"tran_2")]), - 'error getting whats changed.') - self.assertEqual( - log.whats_changed(1), - (3, "tran_2", [("doc_3",2,"tran_3"),("doc_2",3,"tran_2")]), - 'error getting whats changed.') - - -# Key material for testing -KEY_FINGERPRINT = "E36E738D69173C13D709E44F2F455E2824D18DDF" -PUBLIC_KEY = """ ------BEGIN PGP PUBLIC KEY BLOCK----- -Version: GnuPG v1.4.10 (GNU/Linux) - -mQINBFC9+dkBEADNRfwV23TWEoGc/x0wWH1P7PlXt8MnC2Z1kKaKKmfnglVrpOiz -iLWoiU58sfZ0L5vHkzXHXCBf6Eiy/EtUIvdiWAn+yASJ1mk5jZTBKO/WMAHD8wTO -zpMsFmWyg3xc4DkmFa9KQ5EVU0o/nqPeyQxNMQN7px5pPwrJtJFmPxnxm+aDkPYx -irDmz/4DeDNqXliazGJKw7efqBdlwTHkl9Akw2gwy178pmsKwHHEMOBOFFvX61AT -huKqHYmlCGSliwbrJppTG7jc1/ls3itrK+CWTg4txREkSpEVmfcASvw/ZqLbjgfs -d/INMwXnR9U81O8+7LT6yw/ca4ppcFoJD7/XJbkRiML6+bJ4Dakiy6i727BzV17g -wI1zqNvm5rAhtALKfACha6YO43aJzairO4II1wxVHvRDHZn2IuKDDephQ3Ii7/vb -hUOf6XCSmchkAcpKXUOvbxm1yfB1LRa64mMc2RcZxf4mW7KQkulBsdV5QG2276lv -U2UUy2IutXcGP5nXC+f6sJJGJeEToKJ57yiO/VWJFjKN8SvP+7AYsQSqINUuEf6H -T5gCPCraGMkTUTPXrREvu7NOohU78q6zZNaL3GW8ai7eSeANSuQ8Vzffx7Wd8Y7i -Pw9sYj0SMFs1UgjbuL6pO5ueHh+qyumbtAq2K0Bci0kqOcU4E9fNtdiovQARAQAB -tBxMZWFwIFRlc3QgS2V5IDxsZWFwQGxlYXAuc2U+iQI3BBMBCAAhBQJQvfnZAhsD -BQsJCAcDBRUKCQgLBRYCAwEAAh4BAheAAAoJEC9FXigk0Y3fT7EQAKH3IuRniOpb -T/DDIgwwjz3oxB/W0DDMyPXowlhSOuM0rgGfntBpBb3boezEXwL86NPQxNGGruF5 -hkmecSiuPSvOmQlqlS95NGQp6hNG0YaKColh+Q5NTspFXCAkFch9oqUje0LdxfSP -QfV9UpeEvGyPmk1I9EJV/YDmZ4+Djge1d7qhVZInz4Rx1NrSyF/Tc2EC0VpjQFsU -Y9Kb2YBBR7ivG6DBc8ty0jJXi7B4WjkFcUEJviQpMF2dCLdonCehYs1PqsN1N7j+ -eFjQd+hqVMJgYuSGKjvuAEfClM6MQw7+FmFwMyLgK/Ew/DttHEDCri77SPSkOGSI -txCzhTg6798f6mJr7WcXmHX1w1Vcib5FfZ8vTDFVhz/XgAgArdhPo9V6/1dgSSiB -KPQ/spsco6u5imdOhckERE0lnAYvVT6KE81TKuhF/b23u7x+Wdew6kK0EQhYA7wy -7LmlaNXc7rMBQJ9Z60CJ4JDtatBWZ0kNrt2VfdDHVdqBTOpl0CraNUjWE5YMDasr -K2dF5IX8D3uuYtpZnxqg0KzyLg0tzL0tvOL1C2iudgZUISZNPKbS0z0v+afuAAnx -2pTC3uezbh2Jt8SWTLhll4i0P4Ps5kZ6HQUO56O+/Z1cWovX+mQekYFmERySDR9n 
-3k1uAwLilJmRmepGmvYbB8HloV8HqwgguQINBFC9+dkBEAC0I/xn1uborMgDvBtf -H0sEhwnXBC849/32zic6udB6/3Efk9nzbSpL3FSOuXITZsZgCHPkKarnoQ2ztMcS -sh1ke1C5gQGms75UVmM/nS+2YI4vY8OX/GC/on2vUyncqdH+bR6xH5hx4NbWpfTs -iQHmz5C6zzS/kuabGdZyKRaZHt23WQ7JX/4zpjqbC99DjHcP9BSk7tJ8wI4bkMYD -uFVQdT9O6HwyKGYwUU4sAQRAj7XCTGvVbT0dpgJwH4RmrEtJoHAx4Whg8mJ710E0 -GCmzf2jqkNuOw76ivgk27Kge+Hw00jmJjQhHY0yVbiaoJwcRrPKzaSjEVNgrpgP3 -lXPRGQArgESsIOTeVVHQ8fhK2YtTeCY9rIiO+L0OX2xo9HK7hfHZZWL6rqymXdyS -fhzh/f6IPyHFWnvj7Brl7DR8heMikygcJqv+ed2yx7iLyCUJ10g12I48+aEj1aLe -dP7lna32iY8/Z0SHQLNH6PXO9SlPcq2aFUgKqE75A/0FMk7CunzU1OWr2ZtTLNO1 -WT/13LfOhhuEq9jTyTosn0WxBjJKq18lnhzCXlaw6EAtbA7CUwsD3CTPR56aAXFK -3I7KXOVAqggrvMe5Tpdg5drfYpI8hZovL5aAgb+7Y5ta10TcJdUhS5K3kFAWe/td -U0cmWUMDP1UMSQ5Jg6JIQVWhSwARAQABiQIfBBgBCAAJBQJQvfnZAhsMAAoJEC9F -Xigk0Y3fRwsP/i0ElYCyxeLpWJTwo1iCLkMKz2yX1lFVa9nT1BVTPOQwr/IAc5OX -NdtbJ14fUsKL5pWgW8OmrXtwZm1y4euI1RPWWubG01ouzwnGzv26UcuHeqC5orZj -cOnKtL40y8VGMm8LoicVkRJH8blPORCnaLjdOtmA3rx/v2EXrJpSa3AhOy0ZSRXk -ZSrK68AVNwamHRoBSYyo0AtaXnkPX4+tmO8X8BPfj125IljubvwZPIW9VWR9UqCE -VPfDR1XKegVb6VStIywF7kmrknM1C5qUY28rdZYWgKorw01hBGV4jTW0cqde3N51 -XT1jnIAa+NoXUM9uQoGYMiwrL7vNsLlyyiW5ayDyV92H/rIuiqhFgbJsHTlsm7I8 -oGheR784BagAA1NIKD1qEO9T6Kz9lzlDaeWS5AUKeXrb7ZJLI1TTCIZx5/DxjLqM -Tt/RFBpVo9geZQrvLUqLAMwdaUvDXC2c6DaCPXTh65oCZj/hqzlJHH+RoTWWzKI+ -BjXxgUWF9EmZUBrg68DSmI+9wuDFsjZ51BcqvJwxyfxtTaWhdoYqH/UQS+D1FP3/ -diZHHlzwVwPICzM9ooNTgbrcDzyxRkIVqsVwBq7EtzcvgYUyX53yG25Giy6YQaQ2 -ZtQ/VymwFL3XdUWV6B/hU4PVAFvO3qlOtdJ6TpE+nEWgcWjCv5g7RjXX -=MuOY ------END PGP PUBLIC KEY BLOCK----- -""" -PRIVATE_KEY = """ ------BEGIN PGP PRIVATE KEY BLOCK----- -Version: GnuPG v1.4.10 (GNU/Linux) - -lQcYBFC9+dkBEADNRfwV23TWEoGc/x0wWH1P7PlXt8MnC2Z1kKaKKmfnglVrpOiz -iLWoiU58sfZ0L5vHkzXHXCBf6Eiy/EtUIvdiWAn+yASJ1mk5jZTBKO/WMAHD8wTO -zpMsFmWyg3xc4DkmFa9KQ5EVU0o/nqPeyQxNMQN7px5pPwrJtJFmPxnxm+aDkPYx -irDmz/4DeDNqXliazGJKw7efqBdlwTHkl9Akw2gwy178pmsKwHHEMOBOFFvX61AT -huKqHYmlCGSliwbrJppTG7jc1/ls3itrK+CWTg4txREkSpEVmfcASvw/ZqLbjgfs 
-d/INMwXnR9U81O8+7LT6yw/ca4ppcFoJD7/XJbkRiML6+bJ4Dakiy6i727BzV17g -wI1zqNvm5rAhtALKfACha6YO43aJzairO4II1wxVHvRDHZn2IuKDDephQ3Ii7/vb -hUOf6XCSmchkAcpKXUOvbxm1yfB1LRa64mMc2RcZxf4mW7KQkulBsdV5QG2276lv -U2UUy2IutXcGP5nXC+f6sJJGJeEToKJ57yiO/VWJFjKN8SvP+7AYsQSqINUuEf6H -T5gCPCraGMkTUTPXrREvu7NOohU78q6zZNaL3GW8ai7eSeANSuQ8Vzffx7Wd8Y7i -Pw9sYj0SMFs1UgjbuL6pO5ueHh+qyumbtAq2K0Bci0kqOcU4E9fNtdiovQARAQAB -AA/+JHtlL39G1wsH9R6UEfUQJGXR9MiIiwZoKcnRB2o8+DS+OLjg0JOh8XehtuCs -E/8oGQKtQqa5bEIstX7IZoYmYFiUQi9LOzIblmp2vxOm+HKkxa4JszWci2/ZmC3t -KtaA4adl9XVnshoQ7pijuCMUKB3naBEOAxd8s9d/JeReGIYkJErdrnVfNk5N71Ds -FmH5Ll3XtEDvgBUQP3nkA6QFjpsaB94FHjL3gDwum/cxzj6pCglcvHOzEhfY0Ddb -J967FozQTaf2JW3O+w3LOqtcKWpq87B7+O61tVidQPSSuzPjCtFF0D2LC9R/Hpky -KTMQ6CaKja4MPhjwywd4QPcHGYSqjMpflvJqi+kYIt8psUK/YswWjnr3r4fbuqVY -VhtiHvnBHQjz135lUqWvEz4hM3Xpnxydx7aRlv5NlevK8+YIO5oFbWbGNTWsPZI5 -jpoFBpSsnR1Q5tnvtNHauvoWV+XN2qAOBTG+/nEbDYH6Ak3aaE9jrpTdYh0CotYF -q7csANsDy3JvkAzeU6WnYpsHHaAjqOGyiZGsLej1UcXPFMosE/aUo4WQhiS8Zx2c -zOVKOi/X5vQ2GdNT9Qolz8AriwzsvFR+bxPzyd8V6ALwDsoXvwEYinYBKK8j0OPv -OOihSR6HVsuP9NUZNU9ewiGzte/+/r6pNXHvR7wTQ8EWLcEIAN6Zyrb0bHZTIlxt -VWur/Ht2mIZrBaO50qmM5RD3T5oXzWXi/pjLrIpBMfeZR9DWfwQwjYzwqi7pxtYx -nJvbMuY505rfnMoYxb4J+cpRXV8MS7Dr1vjjLVUC9KiwSbM3gg6emfd2yuA93ihv -Pe3mffzLIiQa4mRE3wtGcioC43nWuV2K2e1KjxeFg07JhrezA/1Cak505ab/tmvP -4YmjR5c44+yL/YcQ3HdFgs4mV+nVbptRXvRcPpolJsgxPccGNdvHhsoR4gwXMS3F -RRPD2z6x8xeN73Q4KH3bm01swQdwFBZbWVfmUGLxvN7leCdfs9+iFJyqHiCIB6Iv -mQfp8F0IAOwSo8JhWN+V1dwML4EkIrM8wUb4yecNLkyR6TpPH/qXx4PxVMC+vy6x -sCtjeHIwKE+9vqnlhd5zOYh7qYXEJtYwdeDDmDbL8oks1LFfd+FyAuZXY33DLwn0 -cRYsr2OEZmaajqUB3NVmj3H4uJBN9+paFHyFSXrH68K1Fk2o3n+RSf2EiX+eICwI -L6rqoF5sSVUghBWdNegV7qfy4anwTQwrIMGjgU5S6PKW0Dr/3iO5z3qQpGPAj5OW -ATqPWkDICLbObPxD5cJlyyNE2wCA9VVc6/1d6w4EVwSq9h3/WTpATEreXXxTGptd -LNiTA1nmakBYNO2Iyo3djhaqBdWjk+EIAKtVEnJH9FAVwWOvaj1RoZMA5DnDMo7e -SnhrCXl8AL7Z1WInEaybasTJXn1uQ8xY52Ua4b8cbuEKRKzw/70NesFRoMLYoHTO -dyeszvhoDHberpGRTciVmpMu7Hyi33rM31K9epA4ib6QbbCHnxkWOZB+Bhgj1hJ8 
-xb4RBYWiWpAYcg0+DAC3w9gfxQhtUlZPIbmbrBmrVkO2GVGUj8kH6k4UV6kUHEGY -HQWQR0HcbKcXW81ZXCCD0l7ROuEWQtTe5Jw7dJ4/QFuqZnPutXVRNOZqpl6eRShw -7X2/a29VXBpmHA95a88rSQsL+qm7Fb3prqRmuMCtrUZgFz7HLSTuUMR867QcTGVh -cCBUZXN0IEtleSA8bGVhcEBsZWFwLnNlPokCNwQTAQgAIQUCUL352QIbAwULCQgH -AwUVCgkICwUWAgMBAAIeAQIXgAAKCRAvRV4oJNGN30+xEACh9yLkZ4jqW0/wwyIM -MI896MQf1tAwzMj16MJYUjrjNK4Bn57QaQW926HsxF8C/OjT0MTRhq7heYZJnnEo -rj0rzpkJapUveTRkKeoTRtGGigqJYfkOTU7KRVwgJBXIfaKlI3tC3cX0j0H1fVKX -hLxsj5pNSPRCVf2A5mePg44HtXe6oVWSJ8+EcdTa0shf03NhAtFaY0BbFGPSm9mA -QUe4rxugwXPLctIyV4uweFo5BXFBCb4kKTBdnQi3aJwnoWLNT6rDdTe4/nhY0Hfo -alTCYGLkhio77gBHwpTOjEMO/hZhcDMi4CvxMPw7bRxAwq4u+0j0pDhkiLcQs4U4 -Ou/fH+pia+1nF5h19cNVXIm+RX2fL0wxVYc/14AIAK3YT6PVev9XYEkogSj0P7Kb -HKOruYpnToXJBERNJZwGL1U+ihPNUyroRf29t7u8flnXsOpCtBEIWAO8Muy5pWjV -3O6zAUCfWetAieCQ7WrQVmdJDa7dlX3Qx1XagUzqZdAq2jVI1hOWDA2rKytnReSF -/A97rmLaWZ8aoNCs8i4NLcy9Lbzi9QtornYGVCEmTTym0tM9L/mn7gAJ8dqUwt7n -s24dibfElky4ZZeItD+D7OZGeh0FDuejvv2dXFqL1/pkHpGBZhEckg0fZ95NbgMC -4pSZkZnqRpr2GwfB5aFfB6sIIJ0HGARQvfnZARAAtCP8Z9bm6KzIA7wbXx9LBIcJ -1wQvOPf99s4nOrnQev9xH5PZ820qS9xUjrlyE2bGYAhz5Cmq56ENs7THErIdZHtQ -uYEBprO+VFZjP50vtmCOL2PDl/xgv6J9r1Mp3KnR/m0esR+YceDW1qX07IkB5s+Q -us80v5LmmxnWcikWmR7dt1kOyV/+M6Y6mwvfQ4x3D/QUpO7SfMCOG5DGA7hVUHU/ -Tuh8MihmMFFOLAEEQI+1wkxr1W09HaYCcB+EZqxLSaBwMeFoYPJie9dBNBgps39o -6pDbjsO+or4JNuyoHvh8NNI5iY0IR2NMlW4mqCcHEazys2koxFTYK6YD95Vz0RkA -K4BErCDk3lVR0PH4StmLU3gmPayIjvi9Dl9saPRyu4Xx2WVi+q6spl3ckn4c4f3+ -iD8hxVp74+wa5ew0fIXjIpMoHCar/nndsse4i8glCddINdiOPPmhI9Wi3nT+5Z2t -9omPP2dEh0CzR+j1zvUpT3KtmhVICqhO+QP9BTJOwrp81NTlq9mbUyzTtVk/9dy3 -zoYbhKvY08k6LJ9FsQYySqtfJZ4cwl5WsOhALWwOwlMLA9wkz0eemgFxStyOylzl -QKoIK7zHuU6XYOXa32KSPIWaLy+WgIG/u2ObWtdE3CXVIUuSt5BQFnv7XVNHJllD -Az9VDEkOSYOiSEFVoUsAEQEAAQAP/1AagnZQZyzHDEgw4QELAspYHCWLXE5aZInX -wTUJhK31IgIXNn9bJ0hFiSpQR2xeMs9oYtRuPOu0P8oOFMn4/z374fkjZy8QVY3e -PlL+3EUeqYtkMwlGNmVw5a/NbNuNfm5Darb7pEfbYd1gPcni4MAYw7R2SG/57GbC -9gucvspHIfOSfBNLBthDzmK8xEKe1yD2eimfc2T7IRYb6hmkYfeds5GsqvGI6mwI 
-85h4uUHWRc5JOlhVM6yX8hSWx0L60Z3DZLChmc8maWnFXd7C8eQ6P1azJJbW71Ih -7CoK0XW4LE82vlQurSRFgTwfl7wFYszW2bOzCuhHDDtYnwH86Nsu0DC78ZVRnvxn -E8Ke/AJgrdhIOo4UAyR+aZD2+2mKd7/waOUTUrUtTzc7i8N3YXGi/EIaNReBXaq+ -ZNOp24BlFzRp+FCF/pptDW9HjPdiV09x0DgICmeZS4Gq/4vFFIahWctg52NGebT0 -Idxngjj+xDtLaZlLQoOz0n5ByjO/Wi0ANmMv1sMKCHhGvdaSws2/PbMR2r4caj8m -KXpIgdinM/wUzHJ5pZyF2U/qejsRj8Kw8KH/tfX4JCLhiaP/mgeTuWGDHeZQERAT -xPmRFHaLP9/ZhvGNh6okIYtrKjWTLGoXvKLHcrKNisBLSq+P2WeFrlme1vjvJMo/ -jPwLT5o9CADQmcbKZ+QQ1ZM9v99iDZol7SAMZX43JC019sx6GK0u6xouJBcLfeB4 -OXacTgmSYdTa9RM9fbfVpti01tJ84LV2SyL/VJq/enJF4XQPSynT/tFTn1PAor6o -tEAAd8fjKdJ6LnD5wb92SPHfQfXqI84rFEO8rUNIE/1ErT6DYifDzVCbfD2KZdoF -cOSp7TpD77sY1bs74ocBX5ejKtd+aH99D78bJSMM4pSDZsIEwnomkBHTziubPwJb -OwnATy0LmSMAWOw5rKbsh5nfwCiUTM20xp0t5JeXd+wPVWbpWqI2EnkCEN+RJr9i -7dp/ymDQ+Yt5wrsN3NwoyiexPOG91WQVCADdErHsnglVZZq9Z8Wx7KwecGCUurJ2 -H6lKudv5YOxPnAzqZS5HbpZd/nRTMZh2rdXCr5m2YOuewyYjvM757AkmUpM09zJX -MQ1S67/UX2y8/74TcRF97Ncx9HeELs92innBRXoFitnNguvcO6Esx4BTe1OdU6qR -ER3zAmVf22Le9ciXbu24DN4mleOH+OmBx7X2PqJSYW9GAMTsRB081R6EWKH7romQ -waxFrZ4DJzZ9ltyosEJn5F32StyLrFxpcrdLUoEaclZCv2qka7sZvi0EvovDVEBU -e10jOx9AOwf8Gj2ufhquQ6qgVYCzbP+YrodtkFrXRS3IsljIchj1M2ffB/0bfoUs -rtER9pLvYzCjBPg8IfGLw0o754Qbhh/ReplCRTusP/fQMybvCvfxreS3oyEriu/G -GufRomjewZ8EMHDIgUsLcYo2UHZsfF7tcazgxMGmMvazp4r8vpgrvW/8fIN/6Adu -tF+WjWDTvJLFJCe6O+BFJOWrssNrrra1zGtLC1s8s+Wfpe+bGPL5zpHeebGTwH1U -22eqgJArlEKxrfarz7W5+uHZJHSjF/K9ZvunLGD0n9GOPMpji3UO3zeM8IYoWn7E -/EWK1XbjnssNemeeTZ+sDh+qrD7BOi+vCX1IyBxbfqnQfJZvmcPWpruy1UsO+aIC -0GY8Jr3OL69dDQ21jueJAh8EGAEIAAkFAlC9+dkCGwwACgkQL0VeKCTRjd9HCw/+ -LQSVgLLF4ulYlPCjWIIuQwrPbJfWUVVr2dPUFVM85DCv8gBzk5c121snXh9Swovm -laBbw6ate3BmbXLh64jVE9Za5sbTWi7PCcbO/bpRy4d6oLmitmNw6cq0vjTLxUYy -bwuiJxWREkfxuU85EKdouN062YDevH+/YResmlJrcCE7LRlJFeRlKsrrwBU3BqYd -GgFJjKjQC1peeQ9fj62Y7xfwE9+PXbkiWO5u/Bk8hb1VZH1SoIRU98NHVcp6BVvp -VK0jLAXuSauSczULmpRjbyt1lhaAqivDTWEEZXiNNbRyp17c3nVdPWOcgBr42hdQ -z25CgZgyLCsvu82wuXLKJblrIPJX3Yf+si6KqEWBsmwdOWybsjygaF5HvzgFqAAD 
-U0goPWoQ71PorP2XOUNp5ZLkBQp5etvtkksjVNMIhnHn8PGMuoxO39EUGlWj2B5l -Cu8tSosAzB1pS8NcLZzoNoI9dOHrmgJmP+GrOUkcf5GhNZbMoj4GNfGBRYX0SZlQ -GuDrwNKYj73C4MWyNnnUFyq8nDHJ/G1NpaF2hiof9RBL4PUU/f92JkceXPBXA8gL -Mz2ig1OButwPPLFGQhWqxXAGrsS3Ny+BhTJfnfIbbkaLLphBpDZm1D9XKbAUvdd1 -RZXoH+FTg9UAW87eqU610npOkT6cRaBxaMK/mDtGNdc= -=JTFu ------END PGP PRIVATE KEY BLOCK----- -""" - -if __name__ == '__main__': - unittest.main() diff --git a/src/leap/soledad/tests/test_encrypted.py b/src/leap/soledad/tests/test_encrypted.py new file mode 100644 index 00000000..2333fc41 --- /dev/null +++ b/src/leap/soledad/tests/test_encrypted.py @@ -0,0 +1,211 @@ +try: + import simplejson as json +except ImportError: + import json # noqa + +import unittest2 as unittest +import os + +import u1db +from soledad import GPGWrapper +from soledad.backends.leap import LeapDocument + + +class EncryptedSyncTestCase(unittest.TestCase): + + PREFIX = "/var/tmp" + GNUPG_HOME = "%s/gnupg" % PREFIX + DB1_FILE = "%s/db1.u1db" % PREFIX + DB2_FILE = "%s/db2.u1db" % PREFIX + + def setUp(self): + self.db1 = u1db.open(self.DB1_FILE, create=True, + document_factory=LeapDocument) + self.db2 = u1db.open(self.DB2_FILE, create=True, + document_factory=LeapDocument) + self.gpg = GPGWrapper(gpghome=self.GNUPG_HOME) + self.gpg.import_keys(PUBLIC_KEY) + self.gpg.import_keys(PRIVATE_KEY) + + def tearDown(self): + os.unlink(self.DB1_FILE) + os.unlink(self.DB2_FILE) + + def test_get_set_encrypted(self): + doc1 = LeapDocument(gpg_wrapper = self.gpg, + default_key = KEY_FINGERPRINT) + doc1.content = { 'key' : 'val' } + doc2 = LeapDocument(doc_id=doc1.doc_id, + encrypted_json=doc1.get_encrypted_json(), + gpg_wrapper=self.gpg, + default_key = KEY_FINGERPRINT) + res1 = doc1.get_json() + res2 = doc2.get_json() + self.assertEqual(res1, res2, 'incorrect document encryption') + + +# Key material for testing +KEY_FINGERPRINT = "E36E738D69173C13D709E44F2F455E2824D18DDF" +PUBLIC_KEY = """ +-----BEGIN PGP PUBLIC KEY BLOCK----- +Version: GnuPG v1.4.10 
(GNU/Linux) + +mQINBFC9+dkBEADNRfwV23TWEoGc/x0wWH1P7PlXt8MnC2Z1kKaKKmfnglVrpOiz +iLWoiU58sfZ0L5vHkzXHXCBf6Eiy/EtUIvdiWAn+yASJ1mk5jZTBKO/WMAHD8wTO +zpMsFmWyg3xc4DkmFa9KQ5EVU0o/nqPeyQxNMQN7px5pPwrJtJFmPxnxm+aDkPYx +irDmz/4DeDNqXliazGJKw7efqBdlwTHkl9Akw2gwy178pmsKwHHEMOBOFFvX61AT +huKqHYmlCGSliwbrJppTG7jc1/ls3itrK+CWTg4txREkSpEVmfcASvw/ZqLbjgfs +d/INMwXnR9U81O8+7LT6yw/ca4ppcFoJD7/XJbkRiML6+bJ4Dakiy6i727BzV17g +wI1zqNvm5rAhtALKfACha6YO43aJzairO4II1wxVHvRDHZn2IuKDDephQ3Ii7/vb +hUOf6XCSmchkAcpKXUOvbxm1yfB1LRa64mMc2RcZxf4mW7KQkulBsdV5QG2276lv +U2UUy2IutXcGP5nXC+f6sJJGJeEToKJ57yiO/VWJFjKN8SvP+7AYsQSqINUuEf6H +T5gCPCraGMkTUTPXrREvu7NOohU78q6zZNaL3GW8ai7eSeANSuQ8Vzffx7Wd8Y7i +Pw9sYj0SMFs1UgjbuL6pO5ueHh+qyumbtAq2K0Bci0kqOcU4E9fNtdiovQARAQAB +tBxMZWFwIFRlc3QgS2V5IDxsZWFwQGxlYXAuc2U+iQI3BBMBCAAhBQJQvfnZAhsD +BQsJCAcDBRUKCQgLBRYCAwEAAh4BAheAAAoJEC9FXigk0Y3fT7EQAKH3IuRniOpb +T/DDIgwwjz3oxB/W0DDMyPXowlhSOuM0rgGfntBpBb3boezEXwL86NPQxNGGruF5 +hkmecSiuPSvOmQlqlS95NGQp6hNG0YaKColh+Q5NTspFXCAkFch9oqUje0LdxfSP +QfV9UpeEvGyPmk1I9EJV/YDmZ4+Djge1d7qhVZInz4Rx1NrSyF/Tc2EC0VpjQFsU +Y9Kb2YBBR7ivG6DBc8ty0jJXi7B4WjkFcUEJviQpMF2dCLdonCehYs1PqsN1N7j+ +eFjQd+hqVMJgYuSGKjvuAEfClM6MQw7+FmFwMyLgK/Ew/DttHEDCri77SPSkOGSI +txCzhTg6798f6mJr7WcXmHX1w1Vcib5FfZ8vTDFVhz/XgAgArdhPo9V6/1dgSSiB +KPQ/spsco6u5imdOhckERE0lnAYvVT6KE81TKuhF/b23u7x+Wdew6kK0EQhYA7wy +7LmlaNXc7rMBQJ9Z60CJ4JDtatBWZ0kNrt2VfdDHVdqBTOpl0CraNUjWE5YMDasr +K2dF5IX8D3uuYtpZnxqg0KzyLg0tzL0tvOL1C2iudgZUISZNPKbS0z0v+afuAAnx +2pTC3uezbh2Jt8SWTLhll4i0P4Ps5kZ6HQUO56O+/Z1cWovX+mQekYFmERySDR9n +3k1uAwLilJmRmepGmvYbB8HloV8HqwgguQINBFC9+dkBEAC0I/xn1uborMgDvBtf +H0sEhwnXBC849/32zic6udB6/3Efk9nzbSpL3FSOuXITZsZgCHPkKarnoQ2ztMcS +sh1ke1C5gQGms75UVmM/nS+2YI4vY8OX/GC/on2vUyncqdH+bR6xH5hx4NbWpfTs +iQHmz5C6zzS/kuabGdZyKRaZHt23WQ7JX/4zpjqbC99DjHcP9BSk7tJ8wI4bkMYD +uFVQdT9O6HwyKGYwUU4sAQRAj7XCTGvVbT0dpgJwH4RmrEtJoHAx4Whg8mJ710E0 +GCmzf2jqkNuOw76ivgk27Kge+Hw00jmJjQhHY0yVbiaoJwcRrPKzaSjEVNgrpgP3 +lXPRGQArgESsIOTeVVHQ8fhK2YtTeCY9rIiO+L0OX2xo9HK7hfHZZWL6rqymXdyS 
+fhzh/f6IPyHFWnvj7Brl7DR8heMikygcJqv+ed2yx7iLyCUJ10g12I48+aEj1aLe +dP7lna32iY8/Z0SHQLNH6PXO9SlPcq2aFUgKqE75A/0FMk7CunzU1OWr2ZtTLNO1 +WT/13LfOhhuEq9jTyTosn0WxBjJKq18lnhzCXlaw6EAtbA7CUwsD3CTPR56aAXFK +3I7KXOVAqggrvMe5Tpdg5drfYpI8hZovL5aAgb+7Y5ta10TcJdUhS5K3kFAWe/td +U0cmWUMDP1UMSQ5Jg6JIQVWhSwARAQABiQIfBBgBCAAJBQJQvfnZAhsMAAoJEC9F +Xigk0Y3fRwsP/i0ElYCyxeLpWJTwo1iCLkMKz2yX1lFVa9nT1BVTPOQwr/IAc5OX +NdtbJ14fUsKL5pWgW8OmrXtwZm1y4euI1RPWWubG01ouzwnGzv26UcuHeqC5orZj +cOnKtL40y8VGMm8LoicVkRJH8blPORCnaLjdOtmA3rx/v2EXrJpSa3AhOy0ZSRXk +ZSrK68AVNwamHRoBSYyo0AtaXnkPX4+tmO8X8BPfj125IljubvwZPIW9VWR9UqCE +VPfDR1XKegVb6VStIywF7kmrknM1C5qUY28rdZYWgKorw01hBGV4jTW0cqde3N51 +XT1jnIAa+NoXUM9uQoGYMiwrL7vNsLlyyiW5ayDyV92H/rIuiqhFgbJsHTlsm7I8 +oGheR784BagAA1NIKD1qEO9T6Kz9lzlDaeWS5AUKeXrb7ZJLI1TTCIZx5/DxjLqM +Tt/RFBpVo9geZQrvLUqLAMwdaUvDXC2c6DaCPXTh65oCZj/hqzlJHH+RoTWWzKI+ +BjXxgUWF9EmZUBrg68DSmI+9wuDFsjZ51BcqvJwxyfxtTaWhdoYqH/UQS+D1FP3/ +diZHHlzwVwPICzM9ooNTgbrcDzyxRkIVqsVwBq7EtzcvgYUyX53yG25Giy6YQaQ2 +ZtQ/VymwFL3XdUWV6B/hU4PVAFvO3qlOtdJ6TpE+nEWgcWjCv5g7RjXX +=MuOY +-----END PGP PUBLIC KEY BLOCK----- +""" +PRIVATE_KEY = """ +-----BEGIN PGP PRIVATE KEY BLOCK----- +Version: GnuPG v1.4.10 (GNU/Linux) + +lQcYBFC9+dkBEADNRfwV23TWEoGc/x0wWH1P7PlXt8MnC2Z1kKaKKmfnglVrpOiz +iLWoiU58sfZ0L5vHkzXHXCBf6Eiy/EtUIvdiWAn+yASJ1mk5jZTBKO/WMAHD8wTO +zpMsFmWyg3xc4DkmFa9KQ5EVU0o/nqPeyQxNMQN7px5pPwrJtJFmPxnxm+aDkPYx +irDmz/4DeDNqXliazGJKw7efqBdlwTHkl9Akw2gwy178pmsKwHHEMOBOFFvX61AT +huKqHYmlCGSliwbrJppTG7jc1/ls3itrK+CWTg4txREkSpEVmfcASvw/ZqLbjgfs +d/INMwXnR9U81O8+7LT6yw/ca4ppcFoJD7/XJbkRiML6+bJ4Dakiy6i727BzV17g +wI1zqNvm5rAhtALKfACha6YO43aJzairO4II1wxVHvRDHZn2IuKDDephQ3Ii7/vb +hUOf6XCSmchkAcpKXUOvbxm1yfB1LRa64mMc2RcZxf4mW7KQkulBsdV5QG2276lv +U2UUy2IutXcGP5nXC+f6sJJGJeEToKJ57yiO/VWJFjKN8SvP+7AYsQSqINUuEf6H +T5gCPCraGMkTUTPXrREvu7NOohU78q6zZNaL3GW8ai7eSeANSuQ8Vzffx7Wd8Y7i +Pw9sYj0SMFs1UgjbuL6pO5ueHh+qyumbtAq2K0Bci0kqOcU4E9fNtdiovQARAQAB +AA/+JHtlL39G1wsH9R6UEfUQJGXR9MiIiwZoKcnRB2o8+DS+OLjg0JOh8XehtuCs 
+E/8oGQKtQqa5bEIstX7IZoYmYFiUQi9LOzIblmp2vxOm+HKkxa4JszWci2/ZmC3t +KtaA4adl9XVnshoQ7pijuCMUKB3naBEOAxd8s9d/JeReGIYkJErdrnVfNk5N71Ds +FmH5Ll3XtEDvgBUQP3nkA6QFjpsaB94FHjL3gDwum/cxzj6pCglcvHOzEhfY0Ddb +J967FozQTaf2JW3O+w3LOqtcKWpq87B7+O61tVidQPSSuzPjCtFF0D2LC9R/Hpky +KTMQ6CaKja4MPhjwywd4QPcHGYSqjMpflvJqi+kYIt8psUK/YswWjnr3r4fbuqVY +VhtiHvnBHQjz135lUqWvEz4hM3Xpnxydx7aRlv5NlevK8+YIO5oFbWbGNTWsPZI5 +jpoFBpSsnR1Q5tnvtNHauvoWV+XN2qAOBTG+/nEbDYH6Ak3aaE9jrpTdYh0CotYF +q7csANsDy3JvkAzeU6WnYpsHHaAjqOGyiZGsLej1UcXPFMosE/aUo4WQhiS8Zx2c +zOVKOi/X5vQ2GdNT9Qolz8AriwzsvFR+bxPzyd8V6ALwDsoXvwEYinYBKK8j0OPv +OOihSR6HVsuP9NUZNU9ewiGzte/+/r6pNXHvR7wTQ8EWLcEIAN6Zyrb0bHZTIlxt +VWur/Ht2mIZrBaO50qmM5RD3T5oXzWXi/pjLrIpBMfeZR9DWfwQwjYzwqi7pxtYx +nJvbMuY505rfnMoYxb4J+cpRXV8MS7Dr1vjjLVUC9KiwSbM3gg6emfd2yuA93ihv +Pe3mffzLIiQa4mRE3wtGcioC43nWuV2K2e1KjxeFg07JhrezA/1Cak505ab/tmvP +4YmjR5c44+yL/YcQ3HdFgs4mV+nVbptRXvRcPpolJsgxPccGNdvHhsoR4gwXMS3F +RRPD2z6x8xeN73Q4KH3bm01swQdwFBZbWVfmUGLxvN7leCdfs9+iFJyqHiCIB6Iv +mQfp8F0IAOwSo8JhWN+V1dwML4EkIrM8wUb4yecNLkyR6TpPH/qXx4PxVMC+vy6x +sCtjeHIwKE+9vqnlhd5zOYh7qYXEJtYwdeDDmDbL8oks1LFfd+FyAuZXY33DLwn0 +cRYsr2OEZmaajqUB3NVmj3H4uJBN9+paFHyFSXrH68K1Fk2o3n+RSf2EiX+eICwI +L6rqoF5sSVUghBWdNegV7qfy4anwTQwrIMGjgU5S6PKW0Dr/3iO5z3qQpGPAj5OW +ATqPWkDICLbObPxD5cJlyyNE2wCA9VVc6/1d6w4EVwSq9h3/WTpATEreXXxTGptd +LNiTA1nmakBYNO2Iyo3djhaqBdWjk+EIAKtVEnJH9FAVwWOvaj1RoZMA5DnDMo7e +SnhrCXl8AL7Z1WInEaybasTJXn1uQ8xY52Ua4b8cbuEKRKzw/70NesFRoMLYoHTO +dyeszvhoDHberpGRTciVmpMu7Hyi33rM31K9epA4ib6QbbCHnxkWOZB+Bhgj1hJ8 +xb4RBYWiWpAYcg0+DAC3w9gfxQhtUlZPIbmbrBmrVkO2GVGUj8kH6k4UV6kUHEGY +HQWQR0HcbKcXW81ZXCCD0l7ROuEWQtTe5Jw7dJ4/QFuqZnPutXVRNOZqpl6eRShw +7X2/a29VXBpmHA95a88rSQsL+qm7Fb3prqRmuMCtrUZgFz7HLSTuUMR867QcTGVh +cCBUZXN0IEtleSA8bGVhcEBsZWFwLnNlPokCNwQTAQgAIQUCUL352QIbAwULCQgH +AwUVCgkICwUWAgMBAAIeAQIXgAAKCRAvRV4oJNGN30+xEACh9yLkZ4jqW0/wwyIM +MI896MQf1tAwzMj16MJYUjrjNK4Bn57QaQW926HsxF8C/OjT0MTRhq7heYZJnnEo +rj0rzpkJapUveTRkKeoTRtGGigqJYfkOTU7KRVwgJBXIfaKlI3tC3cX0j0H1fVKX 
+hLxsj5pNSPRCVf2A5mePg44HtXe6oVWSJ8+EcdTa0shf03NhAtFaY0BbFGPSm9mA +QUe4rxugwXPLctIyV4uweFo5BXFBCb4kKTBdnQi3aJwnoWLNT6rDdTe4/nhY0Hfo +alTCYGLkhio77gBHwpTOjEMO/hZhcDMi4CvxMPw7bRxAwq4u+0j0pDhkiLcQs4U4 +Ou/fH+pia+1nF5h19cNVXIm+RX2fL0wxVYc/14AIAK3YT6PVev9XYEkogSj0P7Kb +HKOruYpnToXJBERNJZwGL1U+ihPNUyroRf29t7u8flnXsOpCtBEIWAO8Muy5pWjV +3O6zAUCfWetAieCQ7WrQVmdJDa7dlX3Qx1XagUzqZdAq2jVI1hOWDA2rKytnReSF +/A97rmLaWZ8aoNCs8i4NLcy9Lbzi9QtornYGVCEmTTym0tM9L/mn7gAJ8dqUwt7n +s24dibfElky4ZZeItD+D7OZGeh0FDuejvv2dXFqL1/pkHpGBZhEckg0fZ95NbgMC +4pSZkZnqRpr2GwfB5aFfB6sIIJ0HGARQvfnZARAAtCP8Z9bm6KzIA7wbXx9LBIcJ +1wQvOPf99s4nOrnQev9xH5PZ820qS9xUjrlyE2bGYAhz5Cmq56ENs7THErIdZHtQ +uYEBprO+VFZjP50vtmCOL2PDl/xgv6J9r1Mp3KnR/m0esR+YceDW1qX07IkB5s+Q +us80v5LmmxnWcikWmR7dt1kOyV/+M6Y6mwvfQ4x3D/QUpO7SfMCOG5DGA7hVUHU/ +Tuh8MihmMFFOLAEEQI+1wkxr1W09HaYCcB+EZqxLSaBwMeFoYPJie9dBNBgps39o +6pDbjsO+or4JNuyoHvh8NNI5iY0IR2NMlW4mqCcHEazys2koxFTYK6YD95Vz0RkA +K4BErCDk3lVR0PH4StmLU3gmPayIjvi9Dl9saPRyu4Xx2WVi+q6spl3ckn4c4f3+ +iD8hxVp74+wa5ew0fIXjIpMoHCar/nndsse4i8glCddINdiOPPmhI9Wi3nT+5Z2t +9omPP2dEh0CzR+j1zvUpT3KtmhVICqhO+QP9BTJOwrp81NTlq9mbUyzTtVk/9dy3 +zoYbhKvY08k6LJ9FsQYySqtfJZ4cwl5WsOhALWwOwlMLA9wkz0eemgFxStyOylzl +QKoIK7zHuU6XYOXa32KSPIWaLy+WgIG/u2ObWtdE3CXVIUuSt5BQFnv7XVNHJllD +Az9VDEkOSYOiSEFVoUsAEQEAAQAP/1AagnZQZyzHDEgw4QELAspYHCWLXE5aZInX +wTUJhK31IgIXNn9bJ0hFiSpQR2xeMs9oYtRuPOu0P8oOFMn4/z374fkjZy8QVY3e +PlL+3EUeqYtkMwlGNmVw5a/NbNuNfm5Darb7pEfbYd1gPcni4MAYw7R2SG/57GbC +9gucvspHIfOSfBNLBthDzmK8xEKe1yD2eimfc2T7IRYb6hmkYfeds5GsqvGI6mwI +85h4uUHWRc5JOlhVM6yX8hSWx0L60Z3DZLChmc8maWnFXd7C8eQ6P1azJJbW71Ih +7CoK0XW4LE82vlQurSRFgTwfl7wFYszW2bOzCuhHDDtYnwH86Nsu0DC78ZVRnvxn +E8Ke/AJgrdhIOo4UAyR+aZD2+2mKd7/waOUTUrUtTzc7i8N3YXGi/EIaNReBXaq+ +ZNOp24BlFzRp+FCF/pptDW9HjPdiV09x0DgICmeZS4Gq/4vFFIahWctg52NGebT0 +Idxngjj+xDtLaZlLQoOz0n5ByjO/Wi0ANmMv1sMKCHhGvdaSws2/PbMR2r4caj8m +KXpIgdinM/wUzHJ5pZyF2U/qejsRj8Kw8KH/tfX4JCLhiaP/mgeTuWGDHeZQERAT +xPmRFHaLP9/ZhvGNh6okIYtrKjWTLGoXvKLHcrKNisBLSq+P2WeFrlme1vjvJMo/ 
+jPwLT5o9CADQmcbKZ+QQ1ZM9v99iDZol7SAMZX43JC019sx6GK0u6xouJBcLfeB4 +OXacTgmSYdTa9RM9fbfVpti01tJ84LV2SyL/VJq/enJF4XQPSynT/tFTn1PAor6o +tEAAd8fjKdJ6LnD5wb92SPHfQfXqI84rFEO8rUNIE/1ErT6DYifDzVCbfD2KZdoF +cOSp7TpD77sY1bs74ocBX5ejKtd+aH99D78bJSMM4pSDZsIEwnomkBHTziubPwJb +OwnATy0LmSMAWOw5rKbsh5nfwCiUTM20xp0t5JeXd+wPVWbpWqI2EnkCEN+RJr9i +7dp/ymDQ+Yt5wrsN3NwoyiexPOG91WQVCADdErHsnglVZZq9Z8Wx7KwecGCUurJ2 +H6lKudv5YOxPnAzqZS5HbpZd/nRTMZh2rdXCr5m2YOuewyYjvM757AkmUpM09zJX +MQ1S67/UX2y8/74TcRF97Ncx9HeELs92innBRXoFitnNguvcO6Esx4BTe1OdU6qR +ER3zAmVf22Le9ciXbu24DN4mleOH+OmBx7X2PqJSYW9GAMTsRB081R6EWKH7romQ +waxFrZ4DJzZ9ltyosEJn5F32StyLrFxpcrdLUoEaclZCv2qka7sZvi0EvovDVEBU +e10jOx9AOwf8Gj2ufhquQ6qgVYCzbP+YrodtkFrXRS3IsljIchj1M2ffB/0bfoUs +rtER9pLvYzCjBPg8IfGLw0o754Qbhh/ReplCRTusP/fQMybvCvfxreS3oyEriu/G +GufRomjewZ8EMHDIgUsLcYo2UHZsfF7tcazgxMGmMvazp4r8vpgrvW/8fIN/6Adu +tF+WjWDTvJLFJCe6O+BFJOWrssNrrra1zGtLC1s8s+Wfpe+bGPL5zpHeebGTwH1U +22eqgJArlEKxrfarz7W5+uHZJHSjF/K9ZvunLGD0n9GOPMpji3UO3zeM8IYoWn7E +/EWK1XbjnssNemeeTZ+sDh+qrD7BOi+vCX1IyBxbfqnQfJZvmcPWpruy1UsO+aIC +0GY8Jr3OL69dDQ21jueJAh8EGAEIAAkFAlC9+dkCGwwACgkQL0VeKCTRjd9HCw/+ +LQSVgLLF4ulYlPCjWIIuQwrPbJfWUVVr2dPUFVM85DCv8gBzk5c121snXh9Swovm +laBbw6ate3BmbXLh64jVE9Za5sbTWi7PCcbO/bpRy4d6oLmitmNw6cq0vjTLxUYy +bwuiJxWREkfxuU85EKdouN062YDevH+/YResmlJrcCE7LRlJFeRlKsrrwBU3BqYd +GgFJjKjQC1peeQ9fj62Y7xfwE9+PXbkiWO5u/Bk8hb1VZH1SoIRU98NHVcp6BVvp +VK0jLAXuSauSczULmpRjbyt1lhaAqivDTWEEZXiNNbRyp17c3nVdPWOcgBr42hdQ +z25CgZgyLCsvu82wuXLKJblrIPJX3Yf+si6KqEWBsmwdOWybsjygaF5HvzgFqAAD +U0goPWoQ71PorP2XOUNp5ZLkBQp5etvtkksjVNMIhnHn8PGMuoxO39EUGlWj2B5l +Cu8tSosAzB1pS8NcLZzoNoI9dOHrmgJmP+GrOUkcf5GhNZbMoj4GNfGBRYX0SZlQ +GuDrwNKYj73C4MWyNnnUFyq8nDHJ/G1NpaF2hiof9RBL4PUU/f92JkceXPBXA8gL +Mz2ig1OButwPPLFGQhWqxXAGrsS3Ny+BhTJfnfIbbkaLLphBpDZm1D9XKbAUvdd1 +RZXoH+FTg9UAW87eqU610npOkT6cRaBxaMK/mDtGNdc= +=JTFu +-----END PGP PRIVATE KEY BLOCK----- +""" + +if __name__ == '__main__': + unittest.main() diff --git a/src/leap/soledad/tests/test_logs.py 
b/src/leap/soledad/tests/test_logs.py new file mode 100644 index 00000000..a68e0262 --- /dev/null +++ b/src/leap/soledad/tests/test_logs.py @@ -0,0 +1,75 @@ +import unittest2 as unittest +from soledad import TransactionLog, SyncLog + + +class LogTestCase(unittest.TestCase): + + + def test_transaction_log(self): + data = [ + (2, "doc_3", "tran_3"), + (3, "doc_2", "tran_2"), + (1, "doc_1", "tran_1") + ] + log = TransactionLog() + log.log = data + self.assertEqual(log.get_generation(), 3, 'error getting generation') + self.assertEqual(log.get_generation_info(), (3, 'tran_2'), + 'error getting generation info') + self.assertEqual(log.get_trans_id_for_gen(1), 'tran_1', + 'error getting trans_id for gen') + self.assertEqual(log.get_trans_id_for_gen(2), 'tran_3', + 'error getting trans_id for gen') + self.assertEqual(log.get_trans_id_for_gen(3), 'tran_2', + 'error getting trans_id for gen') + + def test_sync_log(self): + data = [ + ("replica_3", 3, "tran_3"), + ("replica_2", 2, "tran_2"), + ("replica_1", 1, "tran_1") + ] + log = SyncLog() + log.log = data + # test getting + self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'), + (3, 'tran_3'), 'error getting replica gen and trans id') + self.assertEqual(log.get_replica_gen_and_trans_id('replica_2'), + (2, 'tran_2'), 'error getting replica gen and trans id') + self.assertEqual(log.get_replica_gen_and_trans_id('replica_1'), + (1, 'tran_1'), 'error getting replica gen and trans id') + # test setting + log.set_replica_gen_and_trans_id('replica_1', 2, 'tran_12') + self.assertEqual(len(log._log), 3, 'error in log size after setting') + self.assertEqual(log.get_replica_gen_and_trans_id('replica_1'), + (2, 'tran_12'), 'error setting replica gen and trans id') + self.assertEqual(log.get_replica_gen_and_trans_id('replica_2'), + (2, 'tran_2'), 'error setting replica gen and trans id') + self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'), + (3, 'tran_3'), 'error setting replica gen and trans id') + + def 
test_whats_changed(self): + data = [ + (2, "doc_3", "tran_3"), + (3, "doc_2", "tran_2"), + (1, "doc_1", "tran_1") + ] + log = TransactionLog() + log.log = data + self.assertEqual( + log.whats_changed(3), + (3, "tran_2", []), + 'error getting whats changed.') + self.assertEqual( + log.whats_changed(2), + (3, "tran_2", [("doc_2",3,"tran_2")]), + 'error getting whats changed.') + self.assertEqual( + log.whats_changed(1), + (3, "tran_2", [("doc_3",2,"tran_3"),("doc_2",3,"tran_2")]), + 'error getting whats changed.') + + +if __name__ == '__main__': + unittest.main() + diff --git a/src/leap/soledad/tests/test_sqlcipher.py b/src/leap/soledad/tests/test_sqlcipher.py new file mode 100644 index 00000000..46f27f73 --- /dev/null +++ b/src/leap/soledad/tests/test_sqlcipher.py @@ -0,0 +1,494 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. 
+ +"""Test sqlite backend internals.""" + +import os +import time +import threading + +from sqlite3 import dbapi2 + +from u1db import ( + errors, + tests, + query_parser, + ) +from u1db.backends import sqlite_backend +from u1db.tests.test_backends import TestAlternativeDocument + + +simple_doc = '{"key": "value"}' +nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}' + + +class TestSQLiteDatabase(tests.TestCase): + + def test_atomic_initialize(self): + tmpdir = self.createTempDir() + dbname = os.path.join(tmpdir, 'atomic.db') + + t2 = None # will be a thread + + class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + _index_storage_value = "testing" + + def __init__(self, dbname, ntry): + self._try = ntry + self._is_initialized_invocations = 0 + super(SQLiteDatabaseTesting, self).__init__(dbname) + + def _is_initialized(self, c): + res = super(SQLiteDatabaseTesting, self)._is_initialized(c) + if self._try == 1: + self._is_initialized_invocations += 1 + if self._is_initialized_invocations == 2: + t2.start() + # hard to do better and have a generic test + time.sleep(0.05) + return res + + outcome2 = [] + + def second_try(): + try: + db2 = SQLiteDatabaseTesting(dbname, 2) + except Exception, e: + outcome2.append(e) + else: + outcome2.append(db2) + + t2 = threading.Thread(target=second_try) + db1 = SQLiteDatabaseTesting(dbname, 1) + t2.join() + + self.assertIsInstance(outcome2[0], SQLiteDatabaseTesting) + db2 = outcome2[0] + self.assertTrue(db2._is_initialized(db1._get_sqlite_handle().cursor())) + + +class TestSQLitePartialExpandDatabase(tests.TestCase): + + def setUp(self): + super(TestSQLitePartialExpandDatabase, self).setUp() + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db._set_replica_uid('test') + + def test_create_database(self): + raw_db = self.db._get_sqlite_handle() + self.assertNotEqual(None, raw_db) + + def test_default_replica_uid(self): + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + 
self.assertIsNot(None, self.db._replica_uid) + self.assertEqual(32, len(self.db._replica_uid)) + int(self.db._replica_uid, 16) + + def test__close_sqlite_handle(self): + raw_db = self.db._get_sqlite_handle() + self.db._close_sqlite_handle() + self.assertRaises(dbapi2.ProgrammingError, + raw_db.cursor) + + def test_create_database_initializes_schema(self): + raw_db = self.db._get_sqlite_handle() + c = raw_db.cursor() + c.execute("SELECT * FROM u1db_config") + config = dict([(r[0], r[1]) for r in c.fetchall()]) + self.assertEqual({'sql_schema': '0', 'replica_uid': 'test', + 'index_storage': 'expand referenced'}, config) + + # These tables must exist, though we don't care what is in them yet + c.execute("SELECT * FROM transaction_log") + c.execute("SELECT * FROM document") + c.execute("SELECT * FROM document_fields") + c.execute("SELECT * FROM sync_log") + c.execute("SELECT * FROM conflicts") + c.execute("SELECT * FROM index_definitions") + + def test__parse_index(self): + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + g = self.db._parse_index_definition('fieldname') + self.assertIsInstance(g, query_parser.ExtractField) + self.assertEqual(['fieldname'], g.field) + + def test__update_indexes(self): + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + g = self.db._parse_index_definition('fieldname') + c = self.db._get_sqlite_handle().cursor() + self.db._update_indexes('doc-id', {'fieldname': 'val'}, + [('fieldname', g)], c) + c.execute('SELECT doc_id, field_name, value FROM document_fields') + self.assertEqual([('doc-id', 'fieldname', 'val')], + c.fetchall()) + + def test__set_replica_uid(self): + # Start from scratch, so that replica_uid isn't set. 
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.assertIsNot(None, self.db._real_replica_uid) + self.assertIsNot(None, self.db._replica_uid) + self.db._set_replica_uid('foo') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT value FROM u1db_config WHERE name='replica_uid'") + self.assertEqual(('foo',), c.fetchone()) + self.assertEqual('foo', self.db._real_replica_uid) + self.assertEqual('foo', self.db._replica_uid) + self.db._close_sqlite_handle() + self.assertEqual('foo', self.db._replica_uid) + + def test__get_generation(self): + self.assertEqual(0, self.db._get_generation()) + + def test__get_generation_info(self): + self.assertEqual((0, ''), self.db._get_generation_info()) + + def test_create_index(self): + self.db.create_index('test-idx', "key") + self.assertEqual([('test-idx', ["key"])], self.db.list_indexes()) + + def test_create_index_multiple_fields(self): + self.db.create_index('test-idx', "key", "key2") + self.assertEqual([('test-idx', ["key", "key2"])], + self.db.list_indexes()) + + def test__get_index_definition(self): + self.db.create_index('test-idx', "key", "key2") + # TODO: How would you test that an index is getting used for an SQL + # request? + self.assertEqual(["key", "key2"], + self.db._get_index_definition('test-idx')) + + def test_list_index_mixed(self): + # Make sure that we properly order the output + c = self.db._get_sqlite_handle().cursor() + # We intentionally insert the data in weird ordering, to make sure the + # query still gets it back correctly. 
+ c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", + [('idx-1', 0, 'key10'), + ('idx-2', 2, 'key22'), + ('idx-1', 1, 'key11'), + ('idx-2', 0, 'key20'), + ('idx-2', 1, 'key21')]) + self.assertEqual([('idx-1', ['key10', 'key11']), + ('idx-2', ['key20', 'key21', 'key22'])], + self.db.list_indexes()) + + def test_no_indexes_no_document_fields(self): + self.db.create_doc_from_json( + '{"key1": "val1", "key2": "val2"}') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([], c.fetchall()) + + def test_create_extracts_fields(self): + doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + doc2 = self.db.create_doc_from_json('{"key1": "valx", "key2": "valy"}') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([], c.fetchall()) + self.db.create_index('test', 'key1', 'key2') + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual(sorted( + [(doc1.doc_id, "key1", "val1"), + (doc1.doc_id, "key2", "val2"), + (doc2.doc_id, "key1", "valx"), + (doc2.doc_id, "key2", "valy"), + ]), sorted(c.fetchall())) + + def test_put_updates_fields(self): + self.db.create_index('test', 'key1', 'key2') + doc1 = self.db.create_doc_from_json( + '{"key1": "val1", "key2": "val2"}') + doc1.content = {"key1": "val1", "key2": "valy"} + self.db.put_doc(doc1) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, "key1", "val1"), + (doc1.doc_id, "key2", "valy"), + ], c.fetchall()) + + def test_put_updates_nested_fields(self): + self.db.create_index('test', 'key', 'sub.doc') + doc1 = self.db.create_doc_from_json(nested_doc) + c = 
self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, "key", "value"), + (doc1.doc_id, "sub.doc", "underneath"), + ], c.fetchall()) + + def test__ensure_schema_rollback(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/rollback.db' + + class SQLitePartialExpandDbTesting( + sqlite_backend.SQLitePartialExpandDatabase): + + def _set_replica_uid_in_transaction(self, uid): + super(SQLitePartialExpandDbTesting, + self)._set_replica_uid_in_transaction(uid) + if fail: + raise Exception() + + db = SQLitePartialExpandDbTesting.__new__(SQLitePartialExpandDbTesting) + db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed + fail = True + self.assertRaises(Exception, db._ensure_schema) + fail = False + db._initialize(db._db_handle.cursor()) + + def test__open_database(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/test.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase._open_database(path) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + + def test__open_database_with_factory(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/test.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase._open_database( + path, document_factory=TestAlternativeDocument) + self.assertEqual(TestAlternativeDocument, db2._factory) + + def test__open_database_non_existent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase._open_database, path) + + def test__open_database_during_init(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/initialised.db' + db = 
sqlite_backend.SQLitePartialExpandDatabase.__new__( + sqlite_backend.SQLitePartialExpandDatabase) + db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed + self.addCleanup(db.close) + observed = [] + + class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 + + @classmethod + def _which_index_storage(cls, c): + res = super(SQLiteDatabaseTesting, cls)._which_index_storage(c) + db._ensure_schema() # init db + observed.append(res[0]) + return res + + db2 = SQLiteDatabaseTesting._open_database(path) + self.addCleanup(db2.close) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + self.assertEqual([None, + sqlite_backend.SQLitePartialExpandDatabase._index_storage_value], + observed) + + def test__open_database_invalid(self): + class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 + temp_dir = self.createTempDir(prefix='u1db-test-') + path1 = temp_dir + '/invalid1.db' + with open(path1, 'wb') as f: + f.write("") + self.assertRaises(dbapi2.OperationalError, + SQLiteDatabaseTesting._open_database, path1) + with open(path1, 'wb') as f: + f.write("invalid") + self.assertRaises(dbapi2.DatabaseError, + SQLiteDatabaseTesting._open_database, path1) + + def test_open_database_existing(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/existing.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + + def test_open_database_with_factory(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/existing.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase.open_database( + path, create=False, document_factory=TestAlternativeDocument) + self.assertEqual(TestAlternativeDocument, db2._factory) + + def 
test_open_database_create(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/new.sqlite' + sqlite_backend.SQLiteDatabase.open_database(path, create=True) + db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + + def test_open_database_non_existent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase.open_database, path, + create=False) + + def test_delete_database_existent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/new.sqlite' + db = sqlite_backend.SQLiteDatabase.open_database(path, create=True) + db.close() + sqlite_backend.SQLiteDatabase.delete_database(path) + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase.open_database, path, + create=False) + + def test_delete_database_nonexistent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase.delete_database, path) + + def test__get_indexed_fields(self): + self.db.create_index('idx1', 'a', 'b') + self.assertEqual(set(['a', 'b']), self.db._get_indexed_fields()) + self.db.create_index('idx2', 'b', 'c') + self.assertEqual(set(['a', 'b', 'c']), self.db._get_indexed_fields()) + + def test_indexed_fields_expanded(self): + self.db.create_index('idx1', 'key1') + doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + self.assertEqual(set(['key1']), self.db._get_indexed_fields()) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall()) + + def test_create_index_updates_fields(self): + doc1 = 
self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + self.db.create_index('idx1', 'key1') + self.assertEqual(set(['key1']), self.db._get_indexed_fields()) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall()) + + def assertFormatQueryEquals(self, exp_statement, exp_args, definition, + values): + statement, args = self.db._format_query(definition, values) + self.assertEqual(exp_statement, statement) + self.assertEqual(exp_args, args) + + def test__format_query(self): + self.assertFormatQueryEquals( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, document_fields d0 LEFT OUTER JOIN conflicts c ON " + "c.doc_id = d.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name " + "= ? AND d0.value = ? GROUP BY d.doc_id, d.doc_rev, d.content " + "ORDER BY d0.value;", ["key1", "a"], + ["key1"], ["a"]) + + def test__format_query2(self): + self.assertFormatQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value = ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value = ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ["key1", "a", "key2", "b", "key3", "c"], + ["key1", "key2", "key3"], ["a", "b", "c"]) + + def test__format_query_wildcard(self): + self.assertFormatQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? 
AND ' + 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value GLOB ? AND d.doc_id = d2.doc_id AND d2.field_name = ? ' + 'AND d2.value NOT NULL GROUP BY d.doc_id, d.doc_rev, d.content ' + 'ORDER BY d0.value, d1.value, d2.value;', + ["key1", "a", "key2", "b*", "key3"], ["key1", "key2", "key3"], + ["a", "b*", "*"]) + + def assertFormatRangeQueryEquals(self, exp_statement, exp_args, definition, + start_value, end_value): + statement, args = self.db._format_range_query( + definition, start_value, end_value) + self.assertEqual(exp_statement, statement) + self.assertEqual(exp_args, args) + + def test__format_range_query(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value >= ? AND d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c', 'key1', 'p', 'key2', 'q', + 'key3', 'r'], + ["key1", "key2", "key3"], ["a", "b", "c"], ["p", "q", "r"]) + + def test__format_range_query_no_start(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value <= ? 
AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c'], + ["key1", "key2", "key3"], None, ["a", "b", "c"]) + + def test__format_range_query_no_end(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value >= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c'], + ["key1", "key2", "key3"], ["a", "b", "c"], None) + + def test__format_range_query_wildcard(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value NOT NULL AND d.doc_id = d0.doc_id AND d0.field_name = ? ' + 'AND d0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? ' + 'AND (d1.value < ? OR d1.value GLOB ?) AND d.doc_id = d2.doc_id ' + 'AND d2.field_name = ? 
AND d2.value NOT NULL GROUP BY d.doc_id, ' + 'd.doc_rev, d.content ORDER BY d0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'key1', 'p', 'key2', 'q', 'q*', + 'key3'], + ["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"]) + -- cgit v1.2.3 From a12b80b23695dd1db8ac5edeb4b79e6ff8e527c2 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 15:03:12 -0200 Subject: Fix SQLCipherDatabase and add tests. --- src/leap/soledad/backends/sqlcipher.py | 5 +- src/leap/soledad/tests/__init__.py | 55 +++++++++++++++++++++ src/leap/soledad/tests/test_sqlcipher.py | 84 +++++++++++++++++--------------- 3 files changed, 102 insertions(+), 42 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/backends/sqlcipher.py b/src/leap/soledad/backends/sqlcipher.py index fcdab251..301d4a7f 100644 --- a/src/leap/soledad/backends/sqlcipher.py +++ b/src/leap/soledad/backends/sqlcipher.py @@ -60,7 +60,8 @@ def open(path, create, document_factory=None, password=None): class SQLCipherDatabase(SQLitePartialExpandDatabase): """A U1DB implementation that uses SQLCipher as its persistence layer.""" - _sqlite_registry = {} + _index_storage_value = 'expand referenced encrypted' + @classmethod def set_pragma_key(cls, db_handle, key): @@ -113,7 +114,7 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase): raise if backend_cls is None: # default is SQLCipherPartialExpandDatabase - backend_cls = SQLCipherPartialExpandDatabase + backend_cls = SQLCipherDatabase return backend_cls(sqlite_file, document_factory=document_factory, password=password) diff --git a/src/leap/soledad/tests/__init__.py b/src/leap/soledad/tests/__init__.py index e69de29b..7918b265 100644 --- a/src/leap/soledad/tests/__init__.py +++ b/src/leap/soledad/tests/__init__.py @@ -0,0 +1,55 @@ +import unittest2 as unittest +import tempfile +import shutil + +class TestCase(unittest.TestCase): + + def createTempDir(self, prefix='u1db-tmp-'): + """Create a temporary directory to do some work in. 
+ + This directory will be scheduled for cleanup when the test ends. + """ + tempdir = tempfile.mkdtemp(prefix=prefix) + self.addCleanup(shutil.rmtree, tempdir) + return tempdir + + def make_document(self, doc_id, doc_rev, content, has_conflicts=False): + return self.make_document_for_test( + self, doc_id, doc_rev, content, has_conflicts) + + def make_document_for_test(self, test, doc_id, doc_rev, content, + has_conflicts): + return make_document_for_test( + test, doc_id, doc_rev, content, has_conflicts) + + def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id)) + + def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content, + has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True)) + + def assertGetDocConflicts(self, db, doc_id, conflicts): + """Assert what conflicts are stored for a given doc_id. + + :param conflicts: A list of (doc_rev, content) pairs. + The first item must match the first item returned from the + database, however the rest can be returned in any order. 
+ """ + if conflicts: + conflicts = [(rev, (json.loads(cont) if isinstance(cont, basestring) + else cont)) for (rev, cont) in conflicts] + conflicts = conflicts[:1] + sorted(conflicts[1:]) + actual = db.get_doc_conflicts(doc_id) + if actual: + actual = [(doc.rev, (json.loads(doc.get_json()) + if doc.get_json() is not None else None)) for doc in actual] + actual = actual[:1] + sorted(actual[1:]) + self.assertEqual(conflicts, actual) + diff --git a/src/leap/soledad/tests/test_sqlcipher.py b/src/leap/soledad/tests/test_sqlcipher.py index 46f27f73..e35a6d90 100644 --- a/src/leap/soledad/tests/test_sqlcipher.py +++ b/src/leap/soledad/tests/test_sqlcipher.py @@ -19,16 +19,17 @@ import os import time import threading +import unittest2 as unittest from sqlite3 import dbapi2 from u1db import ( errors, - tests, query_parser, ) -from u1db.backends import sqlite_backend -from u1db.tests.test_backends import TestAlternativeDocument +from soledad.backends import sqlcipher +from soledad.backends.leap import LeapDocument +from soledad import tests simple_doc = '{"key": "value"}' @@ -43,7 +44,7 @@ class TestSQLiteDatabase(tests.TestCase): t2 = None # will be a thread - class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase): _index_storage_value = "testing" def __init__(self, dbname, ntry): @@ -84,7 +85,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): def setUp(self): super(TestSQLitePartialExpandDatabase, self).setUp() - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:') self.db._set_replica_uid('test') def test_create_database(self): @@ -92,7 +93,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): self.assertNotEqual(None, raw_db) def test_default_replica_uid(self): - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:') self.assertIsNot(None, self.db._replica_uid) 
self.assertEqual(32, len(self.db._replica_uid)) int(self.db._replica_uid, 16) @@ -109,7 +110,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): c.execute("SELECT * FROM u1db_config") config = dict([(r[0], r[1]) for r in c.fetchall()]) self.assertEqual({'sql_schema': '0', 'replica_uid': 'test', - 'index_storage': 'expand referenced'}, config) + 'index_storage': 'expand referenced encrypted'}, config) # These tables must exist, though we don't care what is in them yet c.execute("SELECT * FROM transaction_log") @@ -120,13 +121,13 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): c.execute("SELECT * FROM index_definitions") def test__parse_index(self): - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:') g = self.db._parse_index_definition('fieldname') self.assertIsInstance(g, query_parser.ExtractField) self.assertEqual(['fieldname'], g.field) def test__update_indexes(self): - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:') g = self.db._parse_index_definition('fieldname') c = self.db._get_sqlite_handle().cursor() self.db._update_indexes('doc-id', {'fieldname': 'val'}, @@ -137,7 +138,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): def test__set_replica_uid(self): # Start from scratch, so that replica_uid isn't set. 
- self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:') self.assertIsNot(None, self.db._real_replica_uid) self.assertIsNot(None, self.db._replica_uid) self.db._set_replica_uid('foo') @@ -239,7 +240,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): path = temp_dir + '/rollback.db' class SQLitePartialExpandDbTesting( - sqlite_backend.SQLitePartialExpandDatabase): + sqlcipher.SQLCipherDatabase): def _set_replica_uid_in_transaction(self, uid): super(SQLitePartialExpandDbTesting, @@ -257,34 +258,34 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): def test__open_database(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/test.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = sqlite_backend.SQLiteDatabase._open_database(path) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + sqlcipher.SQLCipherDatabase(path) + db2 = sqlcipher.SQLCipherDatabase._open_database(path) + self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) def test__open_database_with_factory(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/test.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = sqlite_backend.SQLiteDatabase._open_database( - path, document_factory=TestAlternativeDocument) - self.assertEqual(TestAlternativeDocument, db2._factory) + sqlcipher.SQLCipherDatabase(path) + db2 = sqlcipher.SQLCipherDatabase._open_database( + path, document_factory=LeapDocument) + self.assertEqual(LeapDocument, db2._factory) def test__open_database_non_existent(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/non-existent.sqlite' self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase._open_database, path) + sqlcipher.SQLCipherDatabase._open_database, path) def test__open_database_during_init(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + 
'/initialised.db' - db = sqlite_backend.SQLitePartialExpandDatabase.__new__( - sqlite_backend.SQLitePartialExpandDatabase) + db = sqlcipher.SQLCipherDatabase.__new__( + sqlcipher.SQLCipherDatabase) db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed self.addCleanup(db.close) observed = [] - class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase): WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 @classmethod @@ -296,13 +297,13 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): db2 = SQLiteDatabaseTesting._open_database(path) self.addCleanup(db2.close) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) self.assertEqual([None, - sqlite_backend.SQLitePartialExpandDatabase._index_storage_value], + sqlcipher.SQLCipherDatabase._index_storage_value], observed) def test__open_database_invalid(self): - class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase): WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 temp_dir = self.createTempDir(prefix='u1db-test-') path1 = temp_dir + '/invalid1.db' @@ -318,47 +319,47 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): def test_open_database_existing(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/existing.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + sqlcipher.SQLCipherDatabase(path) + db2 = sqlcipher.SQLCipherDatabase.open_database(path, create=False) + self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) def test_open_database_with_factory(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/existing.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = 
sqlite_backend.SQLiteDatabase.open_database( - path, create=False, document_factory=TestAlternativeDocument) - self.assertEqual(TestAlternativeDocument, db2._factory) + sqlcipher.SQLCipherDatabase(path) + db2 = sqlcipher.SQLCipherDatabase.open_database( + path, create=False, document_factory=LeapDocument) + self.assertEqual(LeapDocument, db2._factory) def test_open_database_create(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/new.sqlite' - sqlite_backend.SQLiteDatabase.open_database(path, create=True) - db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + sqlcipher.SQLCipherDatabase.open_database(path, create=True) + db2 = sqlcipher.SQLCipherDatabase.open_database(path, create=False) + self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) def test_open_database_non_existent(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/non-existent.sqlite' self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase.open_database, path, + sqlcipher.SQLCipherDatabase.open_database, path, create=False) def test_delete_database_existent(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/new.sqlite' - db = sqlite_backend.SQLiteDatabase.open_database(path, create=True) + db = sqlcipher.SQLCipherDatabase.open_database(path, create=True) db.close() - sqlite_backend.SQLiteDatabase.delete_database(path) + sqlcipher.SQLCipherDatabase.delete_database(path) self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase.open_database, path, + sqlcipher.SQLCipherDatabase.open_database, path, create=False) def test_delete_database_nonexistent(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/non-existent.sqlite' self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase.delete_database, path) + 
sqlcipher.SQLCipherDatabase.delete_database, path) def test__get_indexed_fields(self): self.db.create_index('idx1', 'a', 'b') @@ -492,3 +493,6 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): 'key3'], ["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"]) + +if __name__ == '__main__': + unittest.main() -- cgit v1.2.3 From a14d5ae150c52c3419764443409b7d146c43cb09 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 16:34:40 -0200 Subject: Fix gnupg prefix path. --- src/leap/soledad/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index d07567b5..45034561 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -2,6 +2,7 @@ """A U1DB implementation for using Object Stores as its persistence layer.""" +import os import gnupg class GPGWrapper(): @@ -10,7 +11,7 @@ class GPGWrapper(): replaced by a more general class used throughout the project. """ - GNUPG_HOME = "~/.config/leap/gnupg" + GNUPG_HOME = os.environ['HOME'] + "/.config/leap/gnupg" GNUPG_BINARY = "/usr/bin/gpg" # this has to be changed based on OS def __init__(self, gpghome=GNUPG_HOME, gpgbinary=GNUPG_BINARY): -- cgit v1.2.3 From 19ee861b5c5dca236800ffcb944b4299561d841d Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 13 Dec 2012 13:29:17 -0200 Subject: Change name of cyphertext field to something more meaningful. 
--- src/leap/soledad/backends/leap.py | 6 +- src/leap/soledad/tests/test_couch.py | 280 +++++++++++++++++++++++++++++++++ src/leap/soledad/tests/test_couchdb.py | 280 --------------------------------- 3 files changed, 284 insertions(+), 282 deletions(-) create mode 100644 src/leap/soledad/tests/test_couch.py delete mode 100644 src/leap/soledad/tests/test_couchdb.py (limited to 'src/leap') diff --git a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py index ce00c8f3..c113f5c2 100644 --- a/src/leap/soledad/backends/leap.py +++ b/src/leap/soledad/backends/leap.py @@ -43,13 +43,13 @@ class LeapDocument(Document): self._default_key, always_trust = True) # TODO: always trust? - return json.dumps({'cyphertext' : str(cyphertext)}) + return json.dumps({'_encrypted_json' : str(cyphertext)}) def set_encrypted_json(self, encrypted_json): """ Set document's content based on encrypted version of json string. """ - cyphertext = json.loads(encrypted_json)['cyphertext'] + cyphertext = json.loads(encrypted_json)['_encrypted_json'] plaintext = str(self._gpg.decrypt(cyphertext)) return self.set_json(plaintext) @@ -97,6 +97,7 @@ class LeapSyncTarget(HTTPSyncTarget): raise BrokenSyncStream line, comma = utils.check_and_strip_comma(entry) entry = json.loads(line) + # decrypt after receiving from server. doc = LeapDocument(entry['id'], entry['rev'], encrypted_json=entry['content']) return_doc_cb(doc, entry['gen'], entry['trans_id']) @@ -142,6 +143,7 @@ class LeapSyncTarget(HTTPSyncTarget): ensure=ensure_callback is not None) comma = ',' for doc, gen, trans_id in docs_by_generations: + # encrypt before sending to server. 
size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_encrypted_json(), gen=gen, trans_id=trans_id) diff --git a/src/leap/soledad/tests/test_couch.py b/src/leap/soledad/tests/test_couch.py new file mode 100644 index 00000000..4468ae04 --- /dev/null +++ b/src/leap/soledad/tests/test_couch.py @@ -0,0 +1,280 @@ +import unittest2 +from soledad.backends.couch import CouchDatabase +from soledad.backends.leap import LeapDocument +from u1db import errors, vectorclock + +try: + import simplejson as json +except ImportError: + import json # noqa + +simple_doc = '{"key": "value"}' +nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}' + +def make_document_for_test(test, doc_id, rev, content, has_conflicts=False): + return LeapDocument(doc_id, rev, content, has_conflicts=has_conflicts) + +class CouchTestCase(unittest2.TestCase): + + def setUp(self): + self.db = CouchDatabase('http://localhost:5984', 'u1db_tests') + + def make_document(self, doc_id, doc_rev, content, has_conflicts=False): + return self.make_document_for_test( + self, doc_id, doc_rev, content, has_conflicts) + + def make_document_for_test(self, test, doc_id, doc_rev, content, + has_conflicts): + return make_document_for_test( + test, doc_id, doc_rev, content, has_conflicts) + + def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id)) + + def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content, + has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True)) + + + def test_create_doc_allocating_doc_id(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertNotEqual(None, doc.doc_id) + 
self.assertNotEqual(None, doc.rev) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + + def test_create_doc_different_ids_same_db(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertNotEqual(doc1.doc_id, doc2.doc_id) + + def test_create_doc_with_id(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my-id') + self.assertEqual('my-id', doc.doc_id) + self.assertNotEqual(None, doc.rev) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + + def test_create_doc_existing_id(self): + doc = self.db.create_doc_from_json(simple_doc) + new_content = '{"something": "else"}' + self.assertRaises( + errors.RevisionConflict, self.db.create_doc_from_json, + new_content, doc.doc_id) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + + def test_put_doc_creating_initial(self): + doc = self.make_document('my_doc_id', None, simple_doc) + new_rev = self.db.put_doc(doc) + self.assertIsNot(None, new_rev) + self.assertGetDoc(self.db, 'my_doc_id', new_rev, simple_doc, False) + + def test_put_doc_space_in_id(self): + doc = self.make_document('my doc id', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_update(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + orig_rev = doc.rev + doc.set_json('{"updated": "stuff"}') + new_rev = self.db.put_doc(doc) + self.assertNotEqual(new_rev, orig_rev) + self.assertGetDoc(self.db, 'my_doc_id', new_rev, + '{"updated": "stuff"}', False) + self.assertEqual(doc.rev, new_rev) + + def test_put_non_ascii_key(self): + content = json.dumps({u'key\xe5': u'val'}) + doc = self.db.create_doc_from_json(content, doc_id='my_doc') + self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) + + def test_put_non_ascii_value(self): + content = json.dumps({'key': u'\xe5'}) + doc = self.db.create_doc_from_json(content, doc_id='my_doc') + 
self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) + + def test_put_doc_refuses_no_id(self): + doc = self.make_document(None, None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + doc = self.make_document("", None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_refuses_slashes(self): + doc = self.make_document('a/b', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + doc = self.make_document(r'\b', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_url_quoting_is_fine(self): + doc_id = "%2F%2Ffoo%2Fbar" + doc = self.make_document(doc_id, None, simple_doc) + new_rev = self.db.put_doc(doc) + self.assertGetDoc(self.db, doc_id, new_rev, simple_doc, False) + + def test_put_doc_refuses_non_existing_old_rev(self): + doc = self.make_document('doc-id', 'test:4', simple_doc) + self.assertRaises(errors.RevisionConflict, self.db.put_doc, doc) + + def test_put_doc_refuses_non_ascii_doc_id(self): + doc = self.make_document('d\xc3\xa5c-id', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_fails_with_bad_old_rev(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + old_rev = doc.rev + bad_doc = self.make_document(doc.doc_id, 'other:1', + '{"something": "else"}') + self.assertRaises(errors.RevisionConflict, self.db.put_doc, bad_doc) + self.assertGetDoc(self.db, 'my_doc_id', old_rev, simple_doc, False) + + def test_create_succeeds_after_delete(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) + deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) + new_doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.assertGetDoc(self.db, 'my_doc_id', new_doc.rev, simple_doc, False) + new_vc = 
vectorclock.VectorClockRev(new_doc.rev) + self.assertTrue( + new_vc.is_newer(deleted_vc), + "%s does not supersede %s" % (new_doc.rev, deleted_doc.rev)) + + def test_put_succeeds_after_delete(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) + deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) + doc2 = self.make_document('my_doc_id', None, simple_doc) + self.db.put_doc(doc2) + self.assertGetDoc(self.db, 'my_doc_id', doc2.rev, simple_doc, False) + new_vc = vectorclock.VectorClockRev(doc2.rev) + self.assertTrue( + new_vc.is_newer(deleted_vc), + "%s does not supersede %s" % (doc2.rev, deleted_doc.rev)) + + def test_get_doc_after_put(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.assertGetDoc(self.db, 'my_doc_id', doc.rev, simple_doc, False) + + def test_get_doc_nonexisting(self): + self.assertIs(None, self.db.get_doc('non-existing')) + + def test_get_doc_deleted(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + self.assertIs(None, self.db.get_doc('my_doc_id')) + + def test_get_doc_include_deleted(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + + def test_get_docs(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual([doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + + def test_get_docs_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc1) + self.assertEqual([doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + + def test_get_docs_include_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = 
self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc1) + self.assertEqual( + [doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id], + include_deleted=True))) + + def test_get_docs_request_ordered(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual([doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + self.assertEqual([doc2, doc1], + list(self.db.get_docs([doc2.doc_id, doc1.doc_id]))) + + def test_get_docs_empty_list(self): + self.assertEqual([], list(self.db.get_docs([]))) + + def test_handles_nested_content(self): + doc = self.db.create_doc_from_json(nested_doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) + + def test_handles_doc_with_null(self): + doc = self.db.create_doc_from_json('{"key": null}') + self.assertGetDoc(self.db, doc.doc_id, doc.rev, '{"key": null}', False) + + def test_delete_doc(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + orig_rev = doc.rev + self.db.delete_doc(doc) + self.assertNotEqual(orig_rev, doc.rev) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + self.assertIs(None, self.db.get_doc(doc.doc_id)) + + def test_delete_doc_non_existent(self): + doc = self.make_document('non-existing', 'other:1', simple_doc) + self.assertRaises(errors.DocumentDoesNotExist, self.db.delete_doc, doc) + + def test_delete_doc_already_deleted(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertRaises(errors.DocumentAlreadyDeleted, + self.db.delete_doc, doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + + def test_delete_doc_bad_rev(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + doc2 = self.make_document(doc1.doc_id, 'other:1', simple_doc) + 
self.assertRaises(errors.RevisionConflict, self.db.delete_doc, doc2) + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + + def test_delete_doc_sets_content_to_None(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertIs(None, doc.get_json()) + + def test_delete_doc_rev_supersedes(self): + doc = self.db.create_doc_from_json(simple_doc) + doc.set_json(nested_doc) + self.db.put_doc(doc) + doc.set_json('{"fishy": "content"}') + self.db.put_doc(doc) + old_rev = doc.rev + self.db.delete_doc(doc) + cur_vc = vectorclock.VectorClockRev(old_rev) + deleted_vc = vectorclock.VectorClockRev(doc.rev) + self.assertTrue(deleted_vc.is_newer(cur_vc), + "%s does not supersede %s" % (doc.rev, old_rev)) + + def test_delete_then_put(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + doc.set_json(nested_doc) + self.db.put_doc(doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) + + + + def tearDown(self): + self.db._server.delete('u1db_tests') + +if __name__ == '__main__': + unittest2.main() diff --git a/src/leap/soledad/tests/test_couchdb.py b/src/leap/soledad/tests/test_couchdb.py deleted file mode 100644 index 4468ae04..00000000 --- a/src/leap/soledad/tests/test_couchdb.py +++ /dev/null @@ -1,280 +0,0 @@ -import unittest2 -from soledad.backends.couch import CouchDatabase -from soledad.backends.leap import LeapDocument -from u1db import errors, vectorclock - -try: - import simplejson as json -except ImportError: - import json # noqa - -simple_doc = '{"key": "value"}' -nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}' - -def make_document_for_test(test, doc_id, rev, content, has_conflicts=False): - return LeapDocument(doc_id, rev, content, has_conflicts=has_conflicts) - -class CouchTestCase(unittest2.TestCase): - - def setUp(self): - self.db = CouchDatabase('http://localhost:5984', 
'u1db_tests') - - def make_document(self, doc_id, doc_rev, content, has_conflicts=False): - return self.make_document_for_test( - self, doc_id, doc_rev, content, has_conflicts) - - def make_document_for_test(self, test, doc_id, doc_rev, content, - has_conflicts): - return make_document_for_test( - test, doc_id, doc_rev, content, has_conflicts) - - def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts): - """Assert that the document in the database looks correct.""" - exp_doc = self.make_document(doc_id, doc_rev, content, - has_conflicts=has_conflicts) - self.assertEqual(exp_doc, db.get_doc(doc_id)) - - def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content, - has_conflicts): - """Assert that the document in the database looks correct.""" - exp_doc = self.make_document(doc_id, doc_rev, content, - has_conflicts=has_conflicts) - self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True)) - - - def test_create_doc_allocating_doc_id(self): - doc = self.db.create_doc_from_json(simple_doc) - self.assertNotEqual(None, doc.doc_id) - self.assertNotEqual(None, doc.rev) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) - - def test_create_doc_different_ids_same_db(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.assertNotEqual(doc1.doc_id, doc2.doc_id) - - def test_create_doc_with_id(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my-id') - self.assertEqual('my-id', doc.doc_id) - self.assertNotEqual(None, doc.rev) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) - - def test_create_doc_existing_id(self): - doc = self.db.create_doc_from_json(simple_doc) - new_content = '{"something": "else"}' - self.assertRaises( - errors.RevisionConflict, self.db.create_doc_from_json, - new_content, doc.doc_id) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) - - def test_put_doc_creating_initial(self): - doc = 
self.make_document('my_doc_id', None, simple_doc) - new_rev = self.db.put_doc(doc) - self.assertIsNot(None, new_rev) - self.assertGetDoc(self.db, 'my_doc_id', new_rev, simple_doc, False) - - def test_put_doc_space_in_id(self): - doc = self.make_document('my doc id', None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - - def test_put_doc_update(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - orig_rev = doc.rev - doc.set_json('{"updated": "stuff"}') - new_rev = self.db.put_doc(doc) - self.assertNotEqual(new_rev, orig_rev) - self.assertGetDoc(self.db, 'my_doc_id', new_rev, - '{"updated": "stuff"}', False) - self.assertEqual(doc.rev, new_rev) - - def test_put_non_ascii_key(self): - content = json.dumps({u'key\xe5': u'val'}) - doc = self.db.create_doc_from_json(content, doc_id='my_doc') - self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) - - def test_put_non_ascii_value(self): - content = json.dumps({'key': u'\xe5'}) - doc = self.db.create_doc_from_json(content, doc_id='my_doc') - self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) - - def test_put_doc_refuses_no_id(self): - doc = self.make_document(None, None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - doc = self.make_document("", None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - - def test_put_doc_refuses_slashes(self): - doc = self.make_document('a/b', None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - doc = self.make_document(r'\b', None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - - def test_put_doc_url_quoting_is_fine(self): - doc_id = "%2F%2Ffoo%2Fbar" - doc = self.make_document(doc_id, None, simple_doc) - new_rev = self.db.put_doc(doc) - self.assertGetDoc(self.db, doc_id, new_rev, simple_doc, False) - - def test_put_doc_refuses_non_existing_old_rev(self): - doc = self.make_document('doc-id', 'test:4', 
simple_doc) - self.assertRaises(errors.RevisionConflict, self.db.put_doc, doc) - - def test_put_doc_refuses_non_ascii_doc_id(self): - doc = self.make_document('d\xc3\xa5c-id', None, simple_doc) - self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) - - def test_put_fails_with_bad_old_rev(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - old_rev = doc.rev - bad_doc = self.make_document(doc.doc_id, 'other:1', - '{"something": "else"}') - self.assertRaises(errors.RevisionConflict, self.db.put_doc, bad_doc) - self.assertGetDoc(self.db, 'my_doc_id', old_rev, simple_doc, False) - - def test_create_succeeds_after_delete(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.db.delete_doc(doc) - deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) - deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) - new_doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.assertGetDoc(self.db, 'my_doc_id', new_doc.rev, simple_doc, False) - new_vc = vectorclock.VectorClockRev(new_doc.rev) - self.assertTrue( - new_vc.is_newer(deleted_vc), - "%s does not supersede %s" % (new_doc.rev, deleted_doc.rev)) - - def test_put_succeeds_after_delete(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.db.delete_doc(doc) - deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) - deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) - doc2 = self.make_document('my_doc_id', None, simple_doc) - self.db.put_doc(doc2) - self.assertGetDoc(self.db, 'my_doc_id', doc2.rev, simple_doc, False) - new_vc = vectorclock.VectorClockRev(doc2.rev) - self.assertTrue( - new_vc.is_newer(deleted_vc), - "%s does not supersede %s" % (doc2.rev, deleted_doc.rev)) - - def test_get_doc_after_put(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.assertGetDoc(self.db, 'my_doc_id', doc.rev, simple_doc, False) - - def test_get_doc_nonexisting(self): - 
self.assertIs(None, self.db.get_doc('non-existing')) - - def test_get_doc_deleted(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.db.delete_doc(doc) - self.assertIs(None, self.db.get_doc('my_doc_id')) - - def test_get_doc_include_deleted(self): - doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') - self.db.delete_doc(doc) - self.assertGetDocIncludeDeleted( - self.db, doc.doc_id, doc.rev, None, False) - - def test_get_docs(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.assertEqual([doc1, doc2], - list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) - - def test_get_docs_deleted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.db.delete_doc(doc1) - self.assertEqual([doc2], - list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) - - def test_get_docs_include_deleted(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.db.delete_doc(doc1) - self.assertEqual( - [doc1, doc2], - list(self.db.get_docs([doc1.doc_id, doc2.doc_id], - include_deleted=True))) - - def test_get_docs_request_ordered(self): - doc1 = self.db.create_doc_from_json(simple_doc) - doc2 = self.db.create_doc_from_json(nested_doc) - self.assertEqual([doc1, doc2], - list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) - self.assertEqual([doc2, doc1], - list(self.db.get_docs([doc2.doc_id, doc1.doc_id]))) - - def test_get_docs_empty_list(self): - self.assertEqual([], list(self.db.get_docs([]))) - - def test_handles_nested_content(self): - doc = self.db.create_doc_from_json(nested_doc) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) - - def test_handles_doc_with_null(self): - doc = self.db.create_doc_from_json('{"key": null}') - self.assertGetDoc(self.db, doc.doc_id, doc.rev, '{"key": null}', False) - - def test_delete_doc(self): - doc = 
self.db.create_doc_from_json(simple_doc) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) - orig_rev = doc.rev - self.db.delete_doc(doc) - self.assertNotEqual(orig_rev, doc.rev) - self.assertGetDocIncludeDeleted( - self.db, doc.doc_id, doc.rev, None, False) - self.assertIs(None, self.db.get_doc(doc.doc_id)) - - def test_delete_doc_non_existent(self): - doc = self.make_document('non-existing', 'other:1', simple_doc) - self.assertRaises(errors.DocumentDoesNotExist, self.db.delete_doc, doc) - - def test_delete_doc_already_deleted(self): - doc = self.db.create_doc_from_json(simple_doc) - self.db.delete_doc(doc) - self.assertRaises(errors.DocumentAlreadyDeleted, - self.db.delete_doc, doc) - self.assertGetDocIncludeDeleted( - self.db, doc.doc_id, doc.rev, None, False) - - def test_delete_doc_bad_rev(self): - doc1 = self.db.create_doc_from_json(simple_doc) - self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) - doc2 = self.make_document(doc1.doc_id, 'other:1', simple_doc) - self.assertRaises(errors.RevisionConflict, self.db.delete_doc, doc2) - self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) - - def test_delete_doc_sets_content_to_None(self): - doc = self.db.create_doc_from_json(simple_doc) - self.db.delete_doc(doc) - self.assertIs(None, doc.get_json()) - - def test_delete_doc_rev_supersedes(self): - doc = self.db.create_doc_from_json(simple_doc) - doc.set_json(nested_doc) - self.db.put_doc(doc) - doc.set_json('{"fishy": "content"}') - self.db.put_doc(doc) - old_rev = doc.rev - self.db.delete_doc(doc) - cur_vc = vectorclock.VectorClockRev(old_rev) - deleted_vc = vectorclock.VectorClockRev(doc.rev) - self.assertTrue(deleted_vc.is_newer(cur_vc), - "%s does not supersede %s" % (doc.rev, old_rev)) - - def test_delete_then_put(self): - doc = self.db.create_doc_from_json(simple_doc) - self.db.delete_doc(doc) - self.assertGetDocIncludeDeleted( - self.db, doc.doc_id, doc.rev, None, False) - doc.set_json(nested_doc) - 
self.db.put_doc(doc) - self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) - - - - def tearDown(self): - self.db._server.delete('u1db_tests') - -if __name__ == '__main__': - unittest2.main() -- cgit v1.2.3 From ece9f7c2116fa961cafabcc6a5790206412c95ae Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 13 Dec 2012 13:46:27 -0200 Subject: Enforce password on SQLCipher backend. --- src/leap/soledad/backends/sqlcipher.py | 27 +++++------ src/leap/soledad/tests/test_sqlcipher.py | 79 +++++++++++++++++--------------- 2 files changed, 54 insertions(+), 52 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/backends/sqlcipher.py b/src/leap/soledad/backends/sqlcipher.py index 301d4a7f..6fd6e619 100644 --- a/src/leap/soledad/backends/sqlcipher.py +++ b/src/leap/soledad/backends/sqlcipher.py @@ -54,7 +54,7 @@ def open(path, create, document_factory=None, password=None): """ from u1db.backends import sqlite_backend return sqlite_backend.SQLCipherDatabase.open_database( - path, create=create, document_factory=document_factory, password=password) + path, password, create=create, document_factory=document_factory) class SQLCipherDatabase(SQLitePartialExpandDatabase): @@ -67,17 +67,16 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase): def set_pragma_key(cls, db_handle, key): db_handle.cursor().execute("PRAGMA key = '%s'" % key) - def __init__(self, sqlite_file, document_factory=None, password=None): + def __init__(self, sqlite_file, password, document_factory=None): """Create a new sqlite file.""" self._db_handle = dbapi2.connect(sqlite_file) - if password: - SQLiteDatabase.set_pragma_key(self._db_handle, password) + SQLCipherDatabase.set_pragma_key(self._db_handle, password) self._real_replica_uid = None self._ensure_schema() self._factory = document_factory or Document @classmethod - def _open_database(cls, sqlite_file, document_factory=None, password=None): + def _open_database(cls, sqlite_file, password, document_factory=None): if not 
os.path.isfile(sqlite_file): raise errors.DatabaseDoesNotExist() tries = 2 @@ -86,8 +85,7 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase): # where without re-opening the database on Windows, it # doesn't see the transaction that was just committed db_handle = dbapi2.connect(sqlite_file) - if password: - SQLiteDatabase.set_pragma_key(db_handle, password) + SQLCipherDatabase.set_pragma_key(db_handle, password) c = db_handle.cursor() v, err = cls._which_index_storage(c) db_handle.close() @@ -100,23 +98,22 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase): tries -= 1 time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL) return SQLCipherDatabase._sqlite_registry[v]( - sqlite_file, document_factory=document_factory) + sqlite_file, password, document_factory=document_factory) @classmethod - def open_database(cls, sqlite_file, create, backend_cls=None, - document_factory=None, password=None): + def open_database(cls, sqlite_file, password, create, backend_cls=None, + document_factory=None): try: - return cls._open_database(sqlite_file, - document_factory=document_factory, - password=password) + return cls._open_database(sqlite_file, password, + document_factory=document_factory) except errors.DatabaseDoesNotExist: if not create: raise if backend_cls is None: # default is SQLCipherPartialExpandDatabase backend_cls = SQLCipherDatabase - return backend_cls(sqlite_file, document_factory=document_factory, - password=password) + return backend_cls(sqlite_file, password, + document_factory=document_factory) @staticmethod def register_implementation(klass): diff --git a/src/leap/soledad/tests/test_sqlcipher.py b/src/leap/soledad/tests/test_sqlcipher.py index e35a6d90..f9e9f681 100644 --- a/src/leap/soledad/tests/test_sqlcipher.py +++ b/src/leap/soledad/tests/test_sqlcipher.py @@ -36,7 +36,7 @@ simple_doc = '{"key": "value"}' nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}' -class TestSQLiteDatabase(tests.TestCase): +class 
TestSQLCipherDatabase(tests.TestCase): def test_atomic_initialize(self): tmpdir = self.createTempDir() @@ -44,16 +44,17 @@ class TestSQLiteDatabase(tests.TestCase): t2 = None # will be a thread - class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase): + class SQLCipherDatabaseTesting(sqlcipher.SQLCipherDatabase): _index_storage_value = "testing" def __init__(self, dbname, ntry): self._try = ntry self._is_initialized_invocations = 0 - super(SQLiteDatabaseTesting, self).__init__(dbname) + password = '123456' + super(SQLCipherDatabaseTesting, self).__init__(dbname, password) def _is_initialized(self, c): - res = super(SQLiteDatabaseTesting, self)._is_initialized(c) + res = super(SQLCipherDatabaseTesting, self)._is_initialized(c) if self._try == 1: self._is_initialized_invocations += 1 if self._is_initialized_invocations == 2: @@ -66,26 +67,29 @@ class TestSQLiteDatabase(tests.TestCase): def second_try(): try: - db2 = SQLiteDatabaseTesting(dbname, 2) + db2 = SQLCipherDatabaseTesting(dbname, 2) except Exception, e: outcome2.append(e) else: outcome2.append(db2) t2 = threading.Thread(target=second_try) - db1 = SQLiteDatabaseTesting(dbname, 1) + db1 = SQLCipherDatabaseTesting(dbname, 1) t2.join() - self.assertIsInstance(outcome2[0], SQLiteDatabaseTesting) + self.assertIsInstance(outcome2[0], SQLCipherDatabaseTesting) db2 = outcome2[0] self.assertTrue(db2._is_initialized(db1._get_sqlite_handle().cursor())) -class TestSQLitePartialExpandDatabase(tests.TestCase): +_password = '123456' + + +class TestSQLCipherPartialExpandDatabase(tests.TestCase): def setUp(self): - super(TestSQLitePartialExpandDatabase, self).setUp() - self.db = sqlcipher.SQLCipherDatabase(':memory:') + super(TestSQLCipherPartialExpandDatabase, self).setUp() + self.db = sqlcipher.SQLCipherDatabase(':memory:', _password) self.db._set_replica_uid('test') def test_create_database(self): @@ -93,7 +97,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): self.assertNotEqual(None, raw_db) def 
test_default_replica_uid(self): - self.db = sqlcipher.SQLCipherDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:', _password) self.assertIsNot(None, self.db._replica_uid) self.assertEqual(32, len(self.db._replica_uid)) int(self.db._replica_uid, 16) @@ -121,13 +125,13 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): c.execute("SELECT * FROM index_definitions") def test__parse_index(self): - self.db = sqlcipher.SQLCipherDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:', _password) g = self.db._parse_index_definition('fieldname') self.assertIsInstance(g, query_parser.ExtractField) self.assertEqual(['fieldname'], g.field) def test__update_indexes(self): - self.db = sqlcipher.SQLCipherDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:', _password) g = self.db._parse_index_definition('fieldname') c = self.db._get_sqlite_handle().cursor() self.db._update_indexes('doc-id', {'fieldname': 'val'}, @@ -138,7 +142,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): def test__set_replica_uid(self): # Start from scratch, so that replica_uid isn't set. 
- self.db = sqlcipher.SQLCipherDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:', _password) self.assertIsNot(None, self.db._real_replica_uid) self.assertIsNot(None, self.db._replica_uid) self.db._set_replica_uid('foo') @@ -239,16 +243,16 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/rollback.db' - class SQLitePartialExpandDbTesting( + class SQLCipherPartialExpandDbTesting( sqlcipher.SQLCipherDatabase): def _set_replica_uid_in_transaction(self, uid): - super(SQLitePartialExpandDbTesting, + super(SQLCipherPartialExpandDbTesting, self)._set_replica_uid_in_transaction(uid) if fail: raise Exception() - db = SQLitePartialExpandDbTesting.__new__(SQLitePartialExpandDbTesting) + db = SQLCipherPartialExpandDbTesting.__new__(SQLCipherPartialExpandDbTesting) db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed fail = True self.assertRaises(Exception, db._ensure_schema) @@ -258,23 +262,23 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): def test__open_database(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/test.sqlite' - sqlcipher.SQLCipherDatabase(path) - db2 = sqlcipher.SQLCipherDatabase._open_database(path) + sqlcipher.SQLCipherDatabase(path, _password) + db2 = sqlcipher.SQLCipherDatabase._open_database(path, _password) self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) def test__open_database_with_factory(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/test.sqlite' - sqlcipher.SQLCipherDatabase(path) + sqlcipher.SQLCipherDatabase(path, _password) db2 = sqlcipher.SQLCipherDatabase._open_database( - path, document_factory=LeapDocument) + path, _password, document_factory=LeapDocument) self.assertEqual(LeapDocument, db2._factory) def test__open_database_non_existent(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/non-existent.sqlite' 
self.assertRaises(errors.DatabaseDoesNotExist, - sqlcipher.SQLCipherDatabase._open_database, path) + sqlcipher.SQLCipherDatabase._open_database, path, _password) def test__open_database_during_init(self): temp_dir = self.createTempDir(prefix='u1db-test-') @@ -285,17 +289,17 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): self.addCleanup(db.close) observed = [] - class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase): + class SQLCipherDatabaseTesting(sqlcipher.SQLCipherDatabase): WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 @classmethod def _which_index_storage(cls, c): - res = super(SQLiteDatabaseTesting, cls)._which_index_storage(c) + res = super(SQLCipherDatabaseTesting, cls)._which_index_storage(c) db._ensure_schema() # init db observed.append(res[0]) return res - db2 = SQLiteDatabaseTesting._open_database(path) + db2 = SQLCipherDatabaseTesting._open_database(path, _password) self.addCleanup(db2.close) self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) self.assertEqual([None, @@ -303,39 +307,40 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): observed) def test__open_database_invalid(self): - class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase): + class SQLCipherDatabaseTesting(sqlcipher.SQLCipherDatabase): WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 temp_dir = self.createTempDir(prefix='u1db-test-') path1 = temp_dir + '/invalid1.db' with open(path1, 'wb') as f: f.write("") self.assertRaises(dbapi2.OperationalError, - SQLiteDatabaseTesting._open_database, path1) + SQLCipherDatabaseTesting._open_database, path1, _password) with open(path1, 'wb') as f: f.write("invalid") self.assertRaises(dbapi2.DatabaseError, - SQLiteDatabaseTesting._open_database, path1) + SQLCipherDatabaseTesting._open_database, path1, _password) def test_open_database_existing(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/existing.sqlite' - sqlcipher.SQLCipherDatabase(path) - db2 = sqlcipher.SQLCipherDatabase.open_database(path, 
create=False) + sqlcipher.SQLCipherDatabase(path, _password) + db2 = sqlcipher.SQLCipherDatabase.open_database(path, _password, + create=False) self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) def test_open_database_with_factory(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/existing.sqlite' - sqlcipher.SQLCipherDatabase(path) + sqlcipher.SQLCipherDatabase(path, _password) db2 = sqlcipher.SQLCipherDatabase.open_database( - path, create=False, document_factory=LeapDocument) + path, _password, create=False, document_factory=LeapDocument) self.assertEqual(LeapDocument, db2._factory) def test_open_database_create(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/new.sqlite' - sqlcipher.SQLCipherDatabase.open_database(path, create=True) - db2 = sqlcipher.SQLCipherDatabase.open_database(path, create=False) + sqlcipher.SQLCipherDatabase.open_database(path, _password, create=True) + db2 = sqlcipher.SQLCipherDatabase.open_database(path, _password, create=False) self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) def test_open_database_non_existent(self): @@ -343,17 +348,17 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): path = temp_dir + '/non-existent.sqlite' self.assertRaises(errors.DatabaseDoesNotExist, sqlcipher.SQLCipherDatabase.open_database, path, - create=False) + _password, create=False) def test_delete_database_existent(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/new.sqlite' - db = sqlcipher.SQLCipherDatabase.open_database(path, create=True) + db = sqlcipher.SQLCipherDatabase.open_database(path, _password, create=True) db.close() sqlcipher.SQLCipherDatabase.delete_database(path) self.assertRaises(errors.DatabaseDoesNotExist, sqlcipher.SQLCipherDatabase.open_database, path, - create=False) + _password, create=False) def test_delete_database_nonexistent(self): temp_dir = self.createTempDir(prefix='u1db-test-') -- cgit v1.2.3 From 
7a67c36efd95d86dea04ab0741c68f5307a95c09 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 18 Dec 2012 18:51:01 -0200 Subject: Refactor and symmetric encryption --- src/leap/soledad/__init__.py | 245 ++++++++++--------------------- src/leap/soledad/backends/leap.py | 53 ++++--- src/leap/soledad/backends/objectstore.py | 7 +- src/leap/soledad/tests/test_encrypted.py | 15 +- src/leap/soledad/tests/test_logs.py | 2 +- src/leap/soledad/util.py | 170 +++++++++++++++++++++ 6 files changed, 294 insertions(+), 198 deletions(-) create mode 100644 src/leap/soledad/util.py (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index 45034561..835111a5 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -3,170 +3,81 @@ """A U1DB implementation for using Object Stores as its persistence layer.""" import os -import gnupg - -class GPGWrapper(): - """ - This is a temporary class for handling GPG requests, and should be - replaced by a more general class used throughout the project. - """ - - GNUPG_HOME = os.environ['HOME'] + "/.config/leap/gnupg" - GNUPG_BINARY = "/usr/bin/gpg" # this has to be changed based on OS - - def __init__(self, gpghome=GNUPG_HOME, gpgbinary=GNUPG_BINARY): - self.gpg = gnupg.GPG(gnupghome=gpghome, gpgbinary=gpgbinary) - - def find_key(self, email): - """ - Find user's key based on their email. - """ - for key in self.gpg.list_keys(): - for uid in key['uids']: - if re.search(email, uid): - return key - raise LookupError("GnuPG public key for %s not found!" 
% email) - - def encrypt(self, data, recipient, sign=None, always_trust=False, - passphrase=None, symmetric=False): - return self.gpg.encrypt(data, recipient, sign=sign, - always_trust=always_trust, - passphrase=passphrase, symmetric=symmetric) - - def decrypt(self, data, always_trust=False, passphrase=None): - return self.gpg.decrypt(data, always_trust=always_trust, - passphrase=passphrase) - - def import_keys(self, data): - return self.gpg.import_keys(data) - - -#---------------------------------------------------------------------------- -# u1db Transaction and Sync logs. -#---------------------------------------------------------------------------- - -class SimpleLog(object): - def __init__(self): - self._log = [] - - def _set_log(self, log): - self._log = log - - def _get_log(self): - return self._log - - log = property( - _get_log, _set_log, doc="Log contents.") - - def append(self, msg): - self._log.append(msg) - - def reduce(self, func, initializer=None): - return reduce(func, self.log, initializer) - - def map(self, func): - return map(func, self.log) - - def filter(self, func): - return filter(func, self.log) - - -class TransactionLog(SimpleLog): - """ - An ordered list of (generation, doc_id, transaction_id) tuples. - """ - - def _set_log(self, log): - self._log = log - - def _get_log(self): - return sorted(self._log, reverse=True) - - log = property( - _get_log, _set_log, doc="Log contents.") - - def get_generation(self): - """ - Return the current generation. - """ - gens = self.map(lambda x: x[0]) - if not gens: - return 0 - return max(gens) - - def get_generation_info(self): - """ - Return the current generation and transaction id. - """ - if not self._log: - return(0, '') - info = self.map(lambda x: (x[0], x[2])) - return reduce(lambda x, y: x if (x[0] > y[0]) else y, info) - - def get_trans_id_for_gen(self, gen): - """ - Get the transaction id corresponding to a particular generation. 
- """ - log = self.reduce(lambda x, y: y if y[0] == gen else x) - if log is None: - return None - return log[2] - - def whats_changed(self, old_generation): - """ - Return a list of documents that have changed since old_generation. - """ - results = self.filter(lambda x: x[0] > old_generation) - seen = set() - changes = [] - newest_trans_id = '' - for generation, doc_id, trans_id in results: - if doc_id not in seen: - changes.append((doc_id, generation, trans_id)) - seen.add(doc_id) - if changes: - cur_gen = changes[0][1] # max generation - newest_trans_id = changes[0][2] - changes.reverse() - else: - results = self.log - if not results: - cur_gen = 0 - newest_trans_id = '' - else: - cur_gen, _, newest_trans_id = results[0] - - return cur_gen, newest_trans_id, changes - - - -class SyncLog(SimpleLog): - """ - A list of (replica_id, generation, transaction_id) tuples. - """ - - def find_by_replica_uid(self, replica_uid): - if not self.log: - return () - return self.reduce(lambda x, y: y if y[0] == replica_uid else x) - - def get_replica_gen_and_trans_id(self, other_replica_uid): - """ - Return the last known generation and transaction id for the other db - replica. - """ - info = self.find_by_replica_uid(other_replica_uid) - if not info: - return (0, '') - return (info[1], info[2]) - - def set_replica_gen_and_trans_id(self, other_replica_uid, - other_generation, other_transaction_id): - """ - Set the last-known generation and transaction id for the other - database replica. 
- """ - self.log = self.filter(lambda x: x[0] != other_replica_uid) - self.append((other_replica_uid, other_generation, - other_transaction_id)) - +import string +import random +import cStringIO +from soledad.util import GPGWrapper + +class Soledad(object): + + PREFIX = os.environ['HOME'] + '/.config/leap/soledad' + SECRET_PATH = PREFIX + '/secret.gpg' + GNUPG_HOME = PREFIX + '/gnupg' + SECRET_LENGTH = 50 + + def __init__(self, user_email, gpghome=None): + self._user_email = user_email + if not os.path.isdir(self.PREFIX): + os.makedirs(self.PREFIX) + if not gpghome: + gpghome = self.GNUPG_HOME + self._gpg = GPGWrapper(gpghome=gpghome) + # load OpenPGP keypair + if not self._has_openpgp_keypair(): + self._gen_openpgp_keypair() + self._load_openpgp_keypair() + # load secret + if not self._has_secret(): + self._gen_secret() + self._load_secret() + + def _has_secret(self): + if os.path.isfile(self.SECRET_PATH): + return True + return False + + def _load_secret(self): + try: + with open(self.SECRET_PATH) as f: + self._secret = self._gpg.decrypt(f.read()) + except IOError as e: + raise IOError('Failed to open secret file %s.' 
% self.SECRET_PATH) + + def _gen_secret(self): + self._secret = ''.join(random.choice(string.ascii_uppercase + string.digits) for x in range(self.SECRET_LENGTH)) + cyphertext = self._gpg.encrypt(self._secret, self._fingerprint, self._fingerprint) + f = open(self.SECRET_PATH, 'w') + f.write(str(cyphertext)) + f.close() + + + def _has_openpgp_keypair(self): + if self._gpg.find_key(self._user_email): + return True + return False + + def _gen_openpgp_keypair(self): + params = self._gpg.gen_key_input( + key_type='RSA', + key_length=4096, + name_real=self._user_email, + name_email=self._user_email, + name_comment='Generated by LEAP Soledad.') + self._gpg.gen_key(params) + + def _load_openpgp_keypair(self): + self._fingerprint = self._gpg.find_key(self._user_email)['fingerprint'] + + def encrypt(self, data, sign=None, passphrase=None, symmetric=False): + return str(self._gpg.encrypt(data, self._fingerprint, sign=sign, + passphrase=passphrase, symmetric=symmetric)) + + def encrypt_symmetric(self, data, sign=None): + return self.encrypt(data, sign=sign, passphrase=self._secret, + symmetric=True) + + def decrypt(self, data, passphrase=None, symmetric=False): + return str(self._gpg.decrypt(data, passphrase=passphrase)) + + def decrypt_symmetric(self, data): + return self.decrypt(data, passphrase=self._secret) diff --git a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py index ce00c8f3..4a496d3e 100644 --- a/src/leap/soledad/backends/leap.py +++ b/src/leap/soledad/backends/leap.py @@ -7,12 +7,15 @@ from u1db import Document from u1db.remote.http_target import HTTPSyncTarget from u1db.remote.http_database import HTTPDatabase import base64 -from soledad import GPGWrapper +from soledad.util import GPGWrapper class NoDefaultKey(Exception): pass +class NoSoledadInstance(Exception): + pass + class LeapDocument(Document): """ @@ -22,41 +25,40 @@ class LeapDocument(Document): """ def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, - 
encrypted_json=None, default_key=None, gpg_wrapper=None): + encrypted_json=None, soledad=None): super(LeapDocument, self).__init__(doc_id, rev, json, has_conflicts) - # we might want to get already initialized wrappers for testing. - if gpg_wrapper is None: - self._gpg = GPGWrapper() - else: - self._gpg = gpg_wrapper + self._soledad = soledad if encrypted_json: self.set_encrypted_json(encrypted_json) - self._default_key = default_key def get_encrypted_json(self): """ Returns document's json serialization encrypted with user's public key. """ - if self._default_key is None: - raise NoDefaultKey() - cyphertext = self._gpg.encrypt(self.get_json(), - self._default_key, - always_trust = True) - # TODO: always trust? - return json.dumps({'cyphertext' : str(cyphertext)}) + if not self._soledad: + raise NoSoledadInstance() + cyphertext = self._soledad.encrypt_symmetric(self.get_json()) + return json.dumps({'_encrypted_json' : cyphertext}) def set_encrypted_json(self, encrypted_json): """ Set document's content based on encrypted version of json string. 
""" - cyphertext = json.loads(encrypted_json)['cyphertext'] - plaintext = str(self._gpg.decrypt(cyphertext)) + if not self._soledad: + raise NoSoledadInstance() + cyphertext = json.loads(encrypted_json)['_encrypted_json'] + plaintext = self._soledad.decrypt_symmetric(cyphertext) return self.set_json(plaintext) class LeapDatabase(HTTPDatabase): """Implement the HTTP remote database API to a Leap server.""" + def __init__(self, url, document_factory=None, creds=None, soledad=None): + super(LeapDatabase, self).__init__(url, creds=creds) + self._soledad = soledad + self._factory = LeapDocument + @staticmethod def open_database(url, create): db = LeapDatabase(url) @@ -74,9 +76,21 @@ class LeapDatabase(HTTPDatabase): st._creds = self._creds return st + def create_doc_from_json(self, content, doc_id=None): + if doc_id is None: + doc_id = self._allocate_doc_id() + res, headers = self._request_json('PUT', ['doc', doc_id], {}, + content, 'application/json') + new_doc = self._factory(doc_id, res['rev'], content, soledad=self._soledad) + return new_doc + class LeapSyncTarget(HTTPSyncTarget): + def __init__(self, url, creds=None, soledad=None): + super(LeapSyncTarget, self).__init__(url, creds) + self._soledad = soledad + def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): """ Does the same as parent's method but ensures incoming content will be @@ -97,8 +111,10 @@ class LeapSyncTarget(HTTPSyncTarget): raise BrokenSyncStream line, comma = utils.check_and_strip_comma(entry) entry = json.loads(line) + # decrypt after receiving from server. doc = LeapDocument(entry['id'], entry['rev'], - encrypted_json=entry['content']) + encrypted_json=entry['content'], + soledad=self._soledad) return_doc_cb(doc, entry['gen'], entry['trans_id']) if parts[-1] != ']': try: @@ -142,6 +158,7 @@ class LeapSyncTarget(HTTPSyncTarget): ensure=ensure_callback is not None) comma = ',' for doc, gen, trans_id in docs_by_generations: + # encrypt before sending to server. 
size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_encrypted_json(), gen=gen, trans_id=trans_id) diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index 298bdda3..a8e139f7 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -1,8 +1,7 @@ import uuid from u1db.backends import CommonBackend -from u1db import errors -from soledad import SyncLog, TransactionLog -from soledad.backends.leap import LeapDocument +from u1db import errors, Document +from soledad.util import SyncLog, TransactionLog class ObjectStore(CommonBackend): @@ -11,7 +10,7 @@ class ObjectStore(CommonBackend): # This initialization method should be called after the connection # with the database is established, so it can ensure that u1db data is # configured and up-to-date. - self.set_document_factory(LeapDocument) + self.set_document_factory(Document) self._sync_log = SyncLog() self._transaction_log = TransactionLog() self._ensure_u1db_data() diff --git a/src/leap/soledad/tests/test_encrypted.py b/src/leap/soledad/tests/test_encrypted.py index 2333fc41..eafd258e 100644 --- a/src/leap/soledad/tests/test_encrypted.py +++ b/src/leap/soledad/tests/test_encrypted.py @@ -7,7 +7,7 @@ import unittest2 as unittest import os import u1db -from soledad import GPGWrapper +from soledad import Soledad from soledad.backends.leap import LeapDocument @@ -17,28 +17,27 @@ class EncryptedSyncTestCase(unittest.TestCase): GNUPG_HOME = "%s/gnupg" % PREFIX DB1_FILE = "%s/db1.u1db" % PREFIX DB2_FILE = "%s/db2.u1db" % PREFIX + EMAIL = 'leap@leap.se' def setUp(self): self.db1 = u1db.open(self.DB1_FILE, create=True, document_factory=LeapDocument) self.db2 = u1db.open(self.DB2_FILE, create=True, document_factory=LeapDocument) - self.gpg = GPGWrapper(gpghome=self.GNUPG_HOME) - self.gpg.import_keys(PUBLIC_KEY) - self.gpg.import_keys(PRIVATE_KEY) + self.soledad = Soledad(self.EMAIL, gpghome=self.GNUPG_HOME) + 
self.soledad._gpg.import_keys(PUBLIC_KEY) + self.soledad._gpg.import_keys(PRIVATE_KEY) def tearDown(self): os.unlink(self.DB1_FILE) os.unlink(self.DB2_FILE) def test_get_set_encrypted(self): - doc1 = LeapDocument(gpg_wrapper = self.gpg, - default_key = KEY_FINGERPRINT) + doc1 = LeapDocument(soledad=self.soledad) doc1.content = { 'key' : 'val' } doc2 = LeapDocument(doc_id=doc1.doc_id, encrypted_json=doc1.get_encrypted_json(), - gpg_wrapper=self.gpg, - default_key = KEY_FINGERPRINT) + soledad=self.soledad) res1 = doc1.get_json() res2 = doc2.get_json() self.assertEqual(res1, res2, 'incorrect document encryption') diff --git a/src/leap/soledad/tests/test_logs.py b/src/leap/soledad/tests/test_logs.py index a68e0262..d61700f2 100644 --- a/src/leap/soledad/tests/test_logs.py +++ b/src/leap/soledad/tests/test_logs.py @@ -1,5 +1,5 @@ import unittest2 as unittest -from soledad import TransactionLog, SyncLog +from soledad.util import TransactionLog, SyncLog class LogTestCase(unittest.TestCase): diff --git a/src/leap/soledad/util.py b/src/leap/soledad/util.py new file mode 100644 index 00000000..1485fce1 --- /dev/null +++ b/src/leap/soledad/util.py @@ -0,0 +1,170 @@ +import os +import gnupg +import re + +class GPGWrapper(): + """ + This is a temporary class for handling GPG requests, and should be + replaced by a more general class used throughout the project. + """ + + GNUPG_HOME = os.environ['HOME'] + "/.config/leap/gnupg" + GNUPG_BINARY = "/usr/bin/gpg" # this has to be changed based on OS + + def __init__(self, gpghome=GNUPG_HOME, gpgbinary=GNUPG_BINARY): + self.gpg = gnupg.GPG(gnupghome=gpghome, gpgbinary=gpgbinary) + + def find_key(self, email): + """ + Find user's key based on their email. + """ + for key in self.gpg.list_keys(): + for uid in key['uids']: + if re.search(email, uid): + return key + raise LookupError("GnuPG public key for %s not found!" 
% email) + + def encrypt(self, data, recipient, sign=None, always_trust=True, + passphrase=None, symmetric=False): + return self.gpg.encrypt(data, recipient, sign=sign, + always_trust=always_trust, + passphrase=passphrase, symmetric=symmetric) + + def decrypt(self, data, always_trust=True, passphrase=None): + result = self.gpg.decrypt(data, always_trust=always_trust, + passphrase=passphrase) + return result + + def import_keys(self, data): + return self.gpg.import_keys(data) + + +#---------------------------------------------------------------------------- +# u1db Transaction and Sync logs. +#---------------------------------------------------------------------------- + +class SimpleLog(object): + def __init__(self): + self._log = [] + + def _set_log(self, log): + self._log = log + + def _get_log(self): + return self._log + + log = property( + _get_log, _set_log, doc="Log contents.") + + def append(self, msg): + self._log.append(msg) + + def reduce(self, func, initializer=None): + return reduce(func, self.log, initializer) + + def map(self, func): + return map(func, self.log) + + def filter(self, func): + return filter(func, self.log) + + +class TransactionLog(SimpleLog): + """ + An ordered list of (generation, doc_id, transaction_id) tuples. + """ + + def _set_log(self, log): + self._log = log + + def _get_log(self): + return sorted(self._log, reverse=True) + + log = property( + _get_log, _set_log, doc="Log contents.") + + def get_generation(self): + """ + Return the current generation. + """ + gens = self.map(lambda x: x[0]) + if not gens: + return 0 + return max(gens) + + def get_generation_info(self): + """ + Return the current generation and transaction id. + """ + if not self._log: + return(0, '') + info = self.map(lambda x: (x[0], x[2])) + return reduce(lambda x, y: x if (x[0] > y[0]) else y, info) + + def get_trans_id_for_gen(self, gen): + """ + Get the transaction id corresponding to a particular generation. 
+ """ + log = self.reduce(lambda x, y: y if y[0] == gen else x) + if log is None: + return None + return log[2] + + def whats_changed(self, old_generation): + """ + Return a list of documents that have changed since old_generation. + """ + results = self.filter(lambda x: x[0] > old_generation) + seen = set() + changes = [] + newest_trans_id = '' + for generation, doc_id, trans_id in results: + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + if changes: + cur_gen = changes[0][1] # max generation + newest_trans_id = changes[0][2] + changes.reverse() + else: + results = self.log + if not results: + cur_gen = 0 + newest_trans_id = '' + else: + cur_gen, _, newest_trans_id = results[0] + + return cur_gen, newest_trans_id, changes + + + +class SyncLog(SimpleLog): + """ + A list of (replica_id, generation, transaction_id) tuples. + """ + + def find_by_replica_uid(self, replica_uid): + if not self.log: + return () + return self.reduce(lambda x, y: y if y[0] == replica_uid else x) + + def get_replica_gen_and_trans_id(self, other_replica_uid): + """ + Return the last known generation and transaction id for the other db + replica. + """ + info = self.find_by_replica_uid(other_replica_uid) + if not info: + return (0, '') + return (info[1], info[2]) + + def set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + """ + Set the last-known generation and transaction id for the other + database replica. 
+ """ + self.log = self.filter(lambda x: x[0] != other_replica_uid) + self.append((other_replica_uid, other_generation, + other_transaction_id)) + -- cgit v1.2.3 From 4cd81148ec25cd6f1a9498345c7405a4d37a4012 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 18 Dec 2012 18:57:01 -0200 Subject: Correct typ0 --- src/leap/soledad/__init__.py | 4 ++-- src/leap/soledad/backends/leap.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index 835111a5..4325d773 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -45,9 +45,9 @@ class Soledad(object): def _gen_secret(self): self._secret = ''.join(random.choice(string.ascii_uppercase + string.digits) for x in range(self.SECRET_LENGTH)) - cyphertext = self._gpg.encrypt(self._secret, self._fingerprint, self._fingerprint) + ciphertext = self._gpg.encrypt(self._secret, self._fingerprint, self._fingerprint) f = open(self.SECRET_PATH, 'w') - f.write(str(cyphertext)) + f.write(str(ciphertext)) f.close() diff --git a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py index 4a496d3e..c019ed3f 100644 --- a/src/leap/soledad/backends/leap.py +++ b/src/leap/soledad/backends/leap.py @@ -37,8 +37,8 @@ class LeapDocument(Document): """ if not self._soledad: raise NoSoledadInstance() - cyphertext = self._soledad.encrypt_symmetric(self.get_json()) - return json.dumps({'_encrypted_json' : cyphertext}) + ciphertext = self._soledad.encrypt_symmetric(self.get_json()) + return json.dumps({'_encrypted_json' : ciphertext}) def set_encrypted_json(self, encrypted_json): """ @@ -46,8 +46,8 @@ class LeapDocument(Document): """ if not self._soledad: raise NoSoledadInstance() - cyphertext = json.loads(encrypted_json)['_encrypted_json'] - plaintext = self._soledad.decrypt_symmetric(cyphertext) + ciphertext = json.loads(encrypted_json)['_encrypted_json'] + plaintext = 
self._soledad.decrypt_symmetric(ciphertext) return self.set_json(plaintext) -- cgit v1.2.3 From 7161784fc65698e2603cf53e797dbd13711689e0 Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 20 Dec 2012 11:35:19 -0200 Subject: Use doc_id with HMAC for symmetric encryption --- src/leap/soledad/__init__.py | 14 ++++++++------ src/leap/soledad/backends/leap.py | 4 ++-- 2 files changed, 10 insertions(+), 8 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index 4325d773..9f5d6e22 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -6,6 +6,7 @@ import os import string import random import cStringIO +import hmac from soledad.util import GPGWrapper class Soledad(object): @@ -39,7 +40,7 @@ class Soledad(object): def _load_secret(self): try: with open(self.SECRET_PATH) as f: - self._secret = self._gpg.decrypt(f.read()) + self._secret = str(self._gpg.decrypt(f.read())) except IOError as e: raise IOError('Failed to open secret file %s.' 
% self.SECRET_PATH) @@ -72,12 +73,13 @@ class Soledad(object): return str(self._gpg.encrypt(data, self._fingerprint, sign=sign, passphrase=passphrase, symmetric=symmetric)) - def encrypt_symmetric(self, data, sign=None): - return self.encrypt(data, sign=sign, passphrase=self._secret, - symmetric=True) + def encrypt_symmetric(self, doc_id, data, sign=None): + h = hmac.new(self._secret, doc_id).hexdigest() + return self.encrypt(data, sign=sign, passphrase=h, symmetric=True) def decrypt(self, data, passphrase=None, symmetric=False): return str(self._gpg.decrypt(data, passphrase=passphrase)) - def decrypt_symmetric(self, data): - return self.decrypt(data, passphrase=self._secret) + def decrypt_symmetric(self, doc_id, data): + h = hmac.new(self._secret, doc_id).hexdigest() + return self.decrypt(data, passphrase=h) diff --git a/src/leap/soledad/backends/leap.py b/src/leap/soledad/backends/leap.py index c019ed3f..9fbd49fe 100644 --- a/src/leap/soledad/backends/leap.py +++ b/src/leap/soledad/backends/leap.py @@ -37,7 +37,7 @@ class LeapDocument(Document): """ if not self._soledad: raise NoSoledadInstance() - ciphertext = self._soledad.encrypt_symmetric(self.get_json()) + ciphertext = self._soledad.encrypt_symmetric(self.doc_id, self.get_json()) return json.dumps({'_encrypted_json' : ciphertext}) def set_encrypted_json(self, encrypted_json): @@ -47,7 +47,7 @@ class LeapDocument(Document): if not self._soledad: raise NoSoledadInstance() ciphertext = json.loads(encrypted_json)['_encrypted_json'] - plaintext = self._soledad.decrypt_symmetric(ciphertext) + plaintext = self._soledad.decrypt_symmetric(self.doc_id, ciphertext) return self.set_json(plaintext) -- cgit v1.2.3 From 940bdd0e06b22fc07faeb3e9a6c9d2963cf69fbb Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 20 Dec 2012 12:41:49 -0200 Subject: Add info about hmac module in README --- src/leap/soledad/README | 2 ++ 1 file changed, 2 insertions(+) (limited to 'src/leap') diff --git a/src/leap/soledad/README 
b/src/leap/soledad/README index 97976b01..b59d4184 100644 --- a/src/leap/soledad/README +++ b/src/leap/soledad/README @@ -12,8 +12,10 @@ Soledad depends on the following python libraries: * python-swiftclient 1.2.0 [2] * python-gnupg 0.3.1 [3] * CouchDB 0.8 [4] + * hmac 20101005 [5] [1] http://pypi.python.org/pypi/u1db/0.1.4 [2] http://pypi.python.org/pypi/python-swiftclient/1.2.0 [3] http://pypi.python.org/pypi/python-gnupg/0.3.1 [4] http://pypi.python.org/pypi/CouchDB/0.8 +[5] http://pypi.python.org/pypi/hmac/20101005 -- cgit v1.2.3 From 8ec2353d688a6064e5c2cd69745e246c12707b95 Mon Sep 17 00:00:00 2001 From: drebs Date: Thu, 20 Dec 2012 12:42:34 -0200 Subject: Fix OpenPGP key generation. --- src/leap/soledad/__init__.py | 6 ++++-- src/leap/soledad/util.py | 6 ++++++ 2 files changed, 10 insertions(+), 2 deletions(-) (limited to 'src/leap') diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py index 9f5d6e22..6a3707ea 100644 --- a/src/leap/soledad/__init__.py +++ b/src/leap/soledad/__init__.py @@ -53,9 +53,11 @@ class Soledad(object): def _has_openpgp_keypair(self): - if self._gpg.find_key(self._user_email): + try: + self._gpg.find_key(self._user_email) return True - return False + except LookupError: + return False def _gen_openpgp_keypair(self): params = self._gpg.gen_key_input( diff --git a/src/leap/soledad/util.py b/src/leap/soledad/util.py index 1485fce1..41fd4548 100644 --- a/src/leap/soledad/util.py +++ b/src/leap/soledad/util.py @@ -38,6 +38,12 @@ class GPGWrapper(): def import_keys(self, data): return self.gpg.import_keys(data) + def gen_key_input(self, **kwargs): + return self.gpg.gen_key_input(**kwargs) + + def gen_key(self, input): + return self.gpg.gen_key(input) + #---------------------------------------------------------------------------- # u1db Transaction and Sync logs. 
-- cgit v1.2.3 From 277f17aa7b7bbcc48583149a3d72d8621f83c0ff Mon Sep 17 00:00:00 2001 From: drebs Date: Mon, 24 Dec 2012 10:13:12 -0200 Subject: Document ObjectStore --- src/leap/soledad/backends/objectstore.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'src/leap') diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py index a8e139f7..61445a1f 100644 --- a/src/leap/soledad/backends/objectstore.py +++ b/src/leap/soledad/backends/objectstore.py @@ -5,6 +5,9 @@ from soledad.util import SyncLog, TransactionLog class ObjectStore(CommonBackend): + """ + A backend for storing u1db data in an object store. + """ def __init__(self): # This initialization method should be called after the connection @@ -153,9 +156,13 @@ class ObjectStore(CommonBackend): raise errors.InvalidGeneration return trans_id + #------------------------------------------------------------------------- + # methods specific for object stores + #------------------------------------------------------------------------- + def _ensure_u1db_data(self): """ - Guarantee that u1db data exists in store. + Guarantee that u1db data (logs and replica info) exists in store. """ if not self._is_initialized(): self._initialize() -- cgit v1.2.3