summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--u1db/__init__.py697
-rw-r--r--u1db/backends/__init__.py211
-rw-r--r--u1db/backends/dbschema.sql42
-rw-r--r--u1db/backends/inmemory.py469
-rw-r--r--u1db/backends/sqlite_backend.py926
-rw-r--r--u1db/commandline/__init__.py15
-rw-r--r--u1db/commandline/client.py497
-rw-r--r--u1db/commandline/command.py80
-rw-r--r--u1db/commandline/serve.py34
-rw-r--r--u1db/errors.py189
-rw-r--r--u1db/query_parser.py370
-rw-r--r--u1db/remote/__init__.py15
-rw-r--r--u1db/remote/basic_auth_middleware.py68
-rw-r--r--u1db/remote/http_app.py629
-rw-r--r--u1db/remote/http_client.py218
-rw-r--r--u1db/remote/http_database.py143
-rw-r--r--u1db/remote/http_errors.py46
-rw-r--r--u1db/remote/http_target.py135
-rw-r--r--u1db/remote/oauth_middleware.py89
-rw-r--r--u1db/remote/server_state.py67
-rw-r--r--u1db/remote/ssl_match_hostname.py64
-rw-r--r--u1db/remote/utils.py23
-rw-r--r--u1db/sync.py304
-rw-r--r--u1db/tests/__init__.py463
-rw-r--r--u1db/tests/c_backend_wrapper.pyx1541
-rw-r--r--u1db/tests/commandline/__init__.py47
-rw-r--r--u1db/tests/commandline/test_client.py916
-rw-r--r--u1db/tests/commandline/test_command.py105
-rw-r--r--u1db/tests/commandline/test_serve.py101
-rw-r--r--u1db/tests/test_auth_middleware.py309
-rw-r--r--u1db/tests/test_backends.py1895
-rw-r--r--u1db/tests/test_c_backend.py634
-rw-r--r--u1db/tests/test_common_backend.py33
-rw-r--r--u1db/tests/test_document.py148
-rw-r--r--u1db/tests/test_errors.py61
-rw-r--r--u1db/tests/test_http_app.py1133
-rw-r--r--u1db/tests/test_http_client.py361
-rw-r--r--u1db/tests/test_http_database.py256
-rw-r--r--u1db/tests/test_https.py117
-rw-r--r--u1db/tests/test_inmemory.py128
-rw-r--r--u1db/tests/test_open.py69
-rw-r--r--u1db/tests/test_query_parser.py443
-rw-r--r--u1db/tests/test_remote_sync_target.py314
-rw-r--r--u1db/tests/test_remote_utils.py36
-rw-r--r--u1db/tests/test_server_state.py93
-rw-r--r--u1db/tests/test_sqlite_backend.py493
-rw-r--r--u1db/tests/test_sync.py1285
-rw-r--r--u1db/tests/test_test_infrastructure.py41
-rw-r--r--u1db/tests/test_vectorclock.py121
-rw-r--r--u1db/tests/testing-certs/Makefile35
-rw-r--r--u1db/tests/testing-certs/cacert.pem58
-rw-r--r--u1db/tests/testing-certs/testing.cert61
-rw-r--r--u1db/tests/testing-certs/testing.key16
-rw-r--r--u1db/vectorclock.py89
54 files changed, 16733 insertions, 0 deletions
diff --git a/u1db/__init__.py b/u1db/__init__.py
new file mode 100644
index 00000000..ed41bb03
--- /dev/null
+++ b/u1db/__init__.py
@@ -0,0 +1,697 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""U1DB"""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db.errors import InvalidJSON, InvalidContent
+
+__version_info__ = (0, 1, 4)
+__version__ = '.'.join(map(str, __version_info__))
+
+
+def open(path, create, document_factory=None):
+ """Open a database at the given location.
+
+ Will raise u1db.errors.DatabaseDoesNotExist if create=False and the
+ database does not already exist.
+
+ :param path: The filesystem path for the database to open.
+ :param create: True/False, should the database be created if it doesn't
+ already exist?
+ :param document_factory: A function that will be called with the same
+ parameters as Document.__init__.
+ :return: An instance of Database.
+ """
+ from u1db.backends import sqlite_backend
+ return sqlite_backend.SQLiteDatabase.open_database(
+ path, create=create, document_factory=document_factory)
+
+
+# constraints on database names (relevant for remote access, as regex)
+DBNAME_CONSTRAINTS = r"[a-zA-Z0-9][a-zA-Z0-9.-]*"
+
+# constraints on doc ids (as regex)
+# (no slashes, and no characters outside the ascii range)
+DOC_ID_CONSTRAINTS = r"[a-zA-Z0-9.%_-]+"
+
+
+class Database(object):
+ """A JSON Document data store.
+
+ This data store can be synchronized with other u1db.Database instances.
+ """
+
+ def set_document_factory(self, factory):
+ """Set the document factory that will be used to create objects to be
+ returned as documents by the database.
+
+ :param factory: A function that returns an object which at minimum must
+ satisfy the same interface as does the class DocumentBase.
+ Subclassing that class is the easiest way to create such
+ a function.
+ """
+ raise NotImplementedError(self.set_document_factory)
+
+ def set_document_size_limit(self, limit):
+ """Set the maximum allowed document size for this database.
+
+ :param limit: Maximum allowed document size in bytes.
+ """
+ raise NotImplementedError(self.set_document_size_limit)
+
+ def whats_changed(self, old_generation=0):
+ """Return a list of documents that have changed since old_generation.
+ This allows APPS to only store a db generation before going
+ 'offline', and then when coming back online they can use this
+ data to update whatever extra data they are storing.
+
+ :param old_generation: The generation of the database in the old
+ state.
+ :return: (generation, trans_id, [(doc_id, generation, trans_id),...])
+ The current generation of the database, its associated transaction
+ id, and a list of of changed documents since old_generation,
+ represented by tuples with for each document its doc_id and the
+ generation and transaction id corresponding to the last intervening
+ change and sorted by generation (old changes first)
+ """
+ raise NotImplementedError(self.whats_changed)
+
+ def get_doc(self, doc_id, include_deleted=False):
+ """Get the JSON string for the given document.
+
+ :param doc_id: The unique document identifier
+ :param include_deleted: If set to True, deleted documents will be
+ returned with empty content. Otherwise asking for a deleted
+ document will return None.
+ :return: a Document object.
+ """
+ raise NotImplementedError(self.get_doc)
+
+ def get_docs(self, doc_ids, check_for_conflicts=True,
+ include_deleted=False):
+ """Get the JSON content for many documents.
+
+ :param doc_ids: A list of document identifiers.
+ :param check_for_conflicts: If set to False, then the conflict check
+ will be skipped, and 'None' will be returned instead of True/False.
+ :param include_deleted: If set to True, deleted documents will be
+ returned with empty content. Otherwise deleted documents will not
+ be included in the results.
+ :return: iterable giving the Document object for each document id
+ in matching doc_ids order.
+ """
+ raise NotImplementedError(self.get_docs)
+
+ def get_all_docs(self, include_deleted=False):
+ """Get the JSON content for all documents in the database.
+
+ :param include_deleted: If set to True, deleted documents will be
+ returned with empty content. Otherwise deleted documents will not
+ be included in the results.
+ :return: (generation, [Document])
+ The current generation of the database, followed by a list of all
+ the documents in the database.
+ """
+ raise NotImplementedError(self.get_all_docs)
+
+ def create_doc(self, content, doc_id=None):
+ """Create a new document.
+
+ You can optionally specify the document identifier, but the document
+ must not already exist. See 'put_doc' if you want to override an
+ existing document.
+ If the database specifies a maximum document size and the document
+ exceeds it, create will fail and raise a DocumentTooBig exception.
+
+ :param content: A Python dictionary.
+ :param doc_id: An optional identifier specifying the document id.
+ :return: Document
+ """
+ raise NotImplementedError(self.create_doc)
+
+ def create_doc_from_json(self, json, doc_id=None):
+ """Create a new document.
+
+ You can optionally specify the document identifier, but the document
+ must not already exist. See 'put_doc' if you want to override an
+ existing document.
+ If the database specifies a maximum document size and the document
+ exceeds it, create will fail and raise a DocumentTooBig exception.
+
+ :param json: The JSON document string
+ :param doc_id: An optional identifier specifying the document id.
+ :return: Document
+ """
+ raise NotImplementedError(self.create_doc_from_json)
+
+ def put_doc(self, doc):
+ """Update a document.
+ If the document currently has conflicts, put will fail.
+ If the database specifies a maximum document size and the document
+ exceeds it, put will fail and raise a DocumentTooBig exception.
+
+ :param doc: A Document with new content.
+ :return: new_doc_rev - The new revision identifier for the document.
+ The Document object will also be updated.
+ """
+ raise NotImplementedError(self.put_doc)
+
+ def delete_doc(self, doc):
+ """Mark a document as deleted.
+ Will abort if the current revision doesn't match doc.rev.
+ This will also set doc.content to None.
+ """
+ raise NotImplementedError(self.delete_doc)
+
+ def create_index(self, index_name, *index_expressions):
+ """Create an named index, which can then be queried for future lookups.
+ Creating an index which already exists is not an error, and is cheap.
+ Creating an index which does not match the index_expressions of the
+ existing index is an error.
+ Creating an index will block until the expressions have been evaluated
+ and the index generated.
+
+ :param index_name: A unique name which can be used as a key prefix
+ :param index_expressions: index expressions defining the index
+ information.
+
+ Examples:
+
+ "fieldname", or "fieldname.subfieldname" to index alphabetically
+ sorted on the contents of a field.
+
+ "number(fieldname, width)", "lower(fieldname)"
+ """
+ raise NotImplementedError(self.create_index)
+
+ def delete_index(self, index_name):
+ """Remove a named index.
+
+ :param index_name: The name of the index we are removing
+ """
+ raise NotImplementedError(self.delete_index)
+
+ def list_indexes(self):
+ """List the definitions of all known indexes.
+
+ :return: A list of [('index-name', ['field', 'field2'])] definitions.
+ """
+ raise NotImplementedError(self.list_indexes)
+
+ def get_from_index(self, index_name, *key_values):
+ """Return documents that match the keys supplied.
+
+ You must supply exactly the same number of values as have been defined
+ in the index. It is possible to do a prefix match by using '*' to
+ indicate a wildcard match. You can only supply '*' to trailing entries,
+ (eg 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.)
+ It is also possible to append a '*' to the last supplied value (eg
+ 'val*', '*', '*' or 'val', 'val*', '*', but not 'val*', 'val', '*')
+
+ :param index_name: The index to query
+ :param key_values: values to match. eg, if you have
+ an index with 3 fields then you would have:
+ get_from_index(index_name, val1, val2, val3)
+ :return: List of [Document]
+ """
+ raise NotImplementedError(self.get_from_index)
+
+ def get_range_from_index(self, index_name, start_value, end_value):
+ """Return documents that fall within the specified range.
+
+ Both ends of the range are inclusive. For both start_value and
+ end_value, one must supply exactly the same number of values as have
+ been defined in the index, or pass None. In case of a single column
+ index, a string is accepted as an alternative for a tuple with a single
+ value. It is possible to do a prefix match by using '*' to indicate
+ a wildcard match. You can only supply '*' to trailing entries, (eg
+ 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) It is also
+ possible to append a '*' to the last supplied value (eg 'val*', '*',
+ '*' or 'val', 'val*', '*', but not 'val*', 'val', '*')
+
+ :param index_name: The index to query
+ :param start_values: tuples of values that define the lower bound of
+ the range. eg, if you have an index with 3 fields then you would
+ have: (val1, val2, val3)
+ :param end_values: tuples of values that define the upper bound of the
+ range. eg, if you have an index with 3 fields then you would have:
+ (val1, val2, val3)
+ :return: List of [Document]
+ """
+ raise NotImplementedError(self.get_range_from_index)
+
+ def get_index_keys(self, index_name):
+ """Return all keys under which documents are indexed in this index.
+
+ :param index_name: The index to query
+ :return: [] A list of tuples of indexed keys.
+ """
+ raise NotImplementedError(self.get_index_keys)
+
+ def get_doc_conflicts(self, doc_id):
+ """Get the list of conflicts for the given document.
+
+ The order of the conflicts is such that the first entry is the value
+ that would be returned by "get_doc".
+
+ :return: [doc] A list of the Document entries that are conflicted.
+ """
+ raise NotImplementedError(self.get_doc_conflicts)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ """Mark a document as no longer conflicted.
+
+ We take the list of revisions that the client knows about that it is
+ superseding. This may be a different list from the actual current
+ conflicts, in which case only those are removed as conflicted. This
+ may fail if the conflict list is significantly different from the
+ supplied information. (sync could have happened in the background from
+ the time you GET_DOC_CONFLICTS until the point where you RESOLVE)
+
+ :param doc: A Document with the new content to be inserted.
+ :param conflicted_doc_revs: A list of revisions that the new content
+ supersedes.
+ """
+ raise NotImplementedError(self.resolve_doc)
+
+ def get_sync_target(self):
+ """Return a SyncTarget object, for another u1db to synchronize with.
+
+ :return: An instance of SyncTarget.
+ """
+ raise NotImplementedError(self.get_sync_target)
+
+ def close(self):
+ """Release any resources associated with this database."""
+ raise NotImplementedError(self.close)
+
+ def sync(self, url, creds=None, autocreate=True):
+ """Synchronize documents with remote replica exposed at url.
+
+ :param url: the url of the target replica to sync with.
+ :param creds: optional dictionary giving credentials
+ to authorize the operation with the server. For using OAuth
+ the form of creds is:
+ {'oauth': {
+ 'consumer_key': ...,
+ 'consumer_secret': ...,
+ 'token_key': ...,
+ 'token_secret': ...
+ }}
+ :param autocreate: ask the target to create the db if non-existent.
+ :return: local_gen_before_sync The local generation before the
+ synchronisation was performed. This is useful to pass into
+ whatschanged, if an application wants to know which documents were
+ affected by a synchronisation.
+ """
+ from u1db.sync import Synchronizer
+ from u1db.remote.http_target import HTTPSyncTarget
+ return Synchronizer(self, HTTPSyncTarget(url, creds=creds)).sync(
+ autocreate=autocreate)
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ """Return the last known generation and transaction id for the other db
+ replica.
+
+ When you do a synchronization with another replica, the Database keeps
+ track of what generation the other database replica was at, and what
+ the associated transaction id was. This is used to determine what data
+ needs to be sent, and if two databases are claiming to be the same
+ replica.
+
+ :param other_replica_uid: The identifier for the other replica.
+ :return: (gen, trans_id) The generation and transaction id we
+ encountered during synchronization. If we've never synchronized
+ with the replica, this is (0, '').
+ """
+ raise NotImplementedError(self._get_replica_gen_and_trans_id)
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ """Set the last-known generation and transaction id for the other
+ database replica.
+
+ We have just performed some synchronization, and we want to track what
+ generation the other replica was at. See also
+ _get_replica_gen_and_trans_id.
+ :param other_replica_uid: The U1DB identifier for the other replica.
+ :param other_generation: The generation number for the other replica.
+ :param other_transaction_id: The transaction id associated with the
+ generation.
+ """
+ raise NotImplementedError(self._set_replica_gen_and_trans_id)
+
+ def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen,
+ replica_trans_id=''):
+ """Insert/update document into the database with a given revision.
+
+ This api is used during synchronization operations.
+
+ If a document would conflict and save_conflict is set to True, the
+ content will be selected as the 'current' content for doc.doc_id,
+ even though doc.rev doesn't supersede the currently stored revision.
+ The currently stored document will be added to the list of conflict
+ alternatives for the given doc_id.
+
+ This forces the new content to be 'current' so that we get convergence
+ after synchronizing, even if people don't resolve conflicts. Users can
+ then notice that their content is out of date, update it, and
+ synchronize again. (The alternative is that users could synchronize and
+ think the data has propagated, but their local copy looks fine, and the
+ remote copy is never updated again.)
+
+ :param doc: A Document object
+ :param save_conflict: If this document is a conflict, do you want to
+ save it as a conflict, or just ignore it.
+ :param replica_uid: A unique replica identifier.
+ :param replica_gen: The generation of the replica corresponding to the
+ this document. The replica arguments are optional, but are used
+ during synchronization.
+ :param replica_trans_id: The transaction_id associated with the
+ generation.
+ :return: (state, at_gen) - If we don't have doc_id already,
+ or if doc_rev supersedes the existing document revision,
+ then the content will be inserted, and state is 'inserted'.
+ If doc_rev is less than or equal to the existing revision,
+ then the put is ignored and state is respecitvely 'superseded'
+ or 'converged'.
+ If doc_rev is not strictly superseded or supersedes, then
+ state is 'conflicted'. The document will not be inserted if
+ save_conflict is False.
+ For 'inserted' or 'converged', at_gen is the insertion/current
+ generation.
+ """
+ raise NotImplementedError(self._put_doc_if_newer)
+
+
+class DocumentBase(object):
+ """Container for handling a single document.
+
+ :ivar doc_id: Unique identifier for this document.
+ :ivar rev: The revision identifier of the document.
+ :ivar json_string: The JSON string for this document.
+ :ivar has_conflicts: Boolean indicating if this document has conflicts
+ """
+
+ def __init__(self, doc_id, rev, json_string, has_conflicts=False):
+ self.doc_id = doc_id
+ self.rev = rev
+ if json_string is not None:
+ try:
+ value = json.loads(json_string)
+ except json.JSONDecodeError:
+ raise InvalidJSON
+ if not isinstance(value, dict):
+ raise InvalidJSON
+ self._json = json_string
+ self.has_conflicts = has_conflicts
+
+ def same_content_as(self, other):
+ """Compare the content of two documents."""
+ if self._json:
+ c1 = json.loads(self._json)
+ else:
+ c1 = None
+ if other._json:
+ c2 = json.loads(other._json)
+ else:
+ c2 = None
+ return c1 == c2
+
+ def __repr__(self):
+ if self.has_conflicts:
+ extra = ', conflicted'
+ else:
+ extra = ''
+ return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id,
+ self.rev, extra, self.get_json())
+
+ def __hash__(self):
+ raise NotImplementedError(self.__hash__)
+
+ def __eq__(self, other):
+ if not isinstance(other, Document):
+ return NotImplemented
+ return (
+ self.doc_id == other.doc_id and self.rev == other.rev and
+ self.same_content_as(other) and self.has_conflicts ==
+ other.has_conflicts)
+
+ def __lt__(self, other):
+ """This is meant for testing, not part of the official api.
+
+ It is implemented so that sorted([Document, Document]) can be used.
+ It doesn't imply that users would want their documents to be sorted in
+ this order.
+ """
+ # Since this is just for testing, we don't worry about comparing
+ # against things that aren't a Document.
+ return ((self.doc_id, self.rev, self.get_json())
+ < (other.doc_id, other.rev, other.get_json()))
+
+ def get_json(self):
+ """Get the json serialization of this document."""
+ if self._json is not None:
+ return self._json
+ return None
+
+ def get_size(self):
+ """Calculate the total size of the document."""
+ size = 0
+ json = self.get_json()
+ if json:
+ size += len(json)
+ if self.rev:
+ size += len(self.rev)
+ if self.doc_id:
+ size += len(self.doc_id)
+ return size
+
+ def set_json(self, json_string):
+ """Set the json serialization of this document."""
+ if json_string is not None:
+ try:
+ value = json.loads(json_string)
+ except json.JSONDecodeError:
+ raise InvalidJSON
+ if not isinstance(value, dict):
+ raise InvalidJSON
+ self._json = json_string
+
+ def make_tombstone(self):
+ """Make this document into a tombstone."""
+ self._json = None
+
+ def is_tombstone(self):
+ """Return True if the document is a tombstone, False otherwise."""
+ if self._json is not None:
+ return False
+ return True
+
+
+class Document(DocumentBase):
+ """Container for handling a single document.
+
+ :ivar doc_id: Unique identifier for this document.
+ :ivar rev: The revision identifier of the document.
+ :ivar json: The JSON string for this document.
+ :ivar has_conflicts: Boolean indicating if this document has conflicts
+ """
+
+ # The following part of the API is optional: no implementation is forced to
+ # have it but if the language supports dictionaries/hashtables, it makes
+ # Documents a lot more user friendly.
+
+ def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False):
+ # TODO: We convert the json in the superclass to check its validity so
+ # we might as well set _content here directly since the price is
+ # already being paid.
+ super(Document, self).__init__(doc_id, rev, json, has_conflicts)
+ self._content = None
+
+ def same_content_as(self, other):
+ """Compare the content of two documents."""
+ if self._json:
+ c1 = json.loads(self._json)
+ else:
+ c1 = self._content
+ if other._json:
+ c2 = json.loads(other._json)
+ else:
+ c2 = other._content
+ return c1 == c2
+
+ def get_json(self):
+ """Get the json serialization of this document."""
+ json_string = super(Document, self).get_json()
+ if json_string is not None:
+ return json_string
+ if self._content is not None:
+ return json.dumps(self._content)
+ return None
+
+ def set_json(self, json):
+ """Set the json serialization of this document."""
+ self._content = None
+ super(Document, self).set_json(json)
+
+ def make_tombstone(self):
+ """Make this document into a tombstone."""
+ self._content = None
+ super(Document, self).make_tombstone()
+
+ def is_tombstone(self):
+ """Return True if the document is a tombstone, False otherwise."""
+ if self._content is not None:
+ return False
+ return super(Document, self).is_tombstone()
+
+ def _get_content(self):
+ """Get the dictionary representing this document."""
+ if self._json is not None:
+ self._content = json.loads(self._json)
+ self._json = None
+ if self._content is not None:
+ return self._content
+ return None
+
+ def _set_content(self, content):
+ """Set the dictionary representing this document."""
+ try:
+ tmp = json.dumps(content)
+ except TypeError:
+ raise InvalidContent(
+ "Can not be converted to JSON: %r" % (content,))
+ if not tmp.startswith('{'):
+ raise InvalidContent(
+ "Can not be converted to a JSON object: %r." % (content,))
+ # We might as well store the JSON at this point since we did the work
+ # of encoding it, and it doesn't lose any information.
+ self._json = tmp
+ self._content = None
+
+ content = property(
+ _get_content, _set_content, doc="Content of the Document.")
+
+ # End of optional part.
+
+
+class SyncTarget(object):
+ """Functionality for using a Database as a synchronization target."""
+
+ def get_sync_info(self, source_replica_uid):
+ """Return information about known state.
+
+ Return the replica_uid and the current database generation of this
+ database, and the last-seen database generation for source_replica_uid
+
+ :param source_replica_uid: Another replica which we might have
+ synchronized with in the past.
+ :return: (target_replica_uid, target_replica_generation,
+ target_trans_id, source_replica_last_known_generation,
+ source_replica_last_known_transaction_id)
+ """
+ raise NotImplementedError(self.get_sync_info)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_replica_transaction_id):
+ """Record tip information for another replica.
+
+ After sync_exchange has been processed, the caller will have
+ received new content from this replica. This call allows the
+ source replica instigating the sync to inform us what their
+ generation became after applying the documents we returned.
+
+ This is used to allow future sync operations to not need to repeat data
+ that we just talked about. It also means that if this is called at the
+ wrong time, there can be database records that will never be
+ synchronized.
+
+ :param source_replica_uid: The identifier for the source replica.
+ :param source_replica_generation:
+ The database generation for the source replica.
+ :param source_replica_transaction_id: The transaction id associated
+ with the source replica generation.
+ """
+ raise NotImplementedError(self.record_sync_info)
+
+ def sync_exchange(self, docs_by_generation, source_replica_uid,
+ last_known_generation, last_known_trans_id,
+ return_doc_cb, ensure_callback=None):
+ """Incorporate the documents sent from the source replica.
+
+ This is not meant to be called by client code directly, but is used as
+ part of sync().
+
+ This adds docs to the local store, and determines documents that need
+ to be returned to the source replica.
+
+ Documents must be supplied in docs_by_generation paired with
+ the generation of their latest change in order from the oldest
+ change to the newest, that means from the oldest generation to
+ the newest.
+
+ Documents are also returned paired with the generation of
+ their latest change in order from the oldest change to the
+ newest.
+
+ :param docs_by_generation: A list of [(Document, generation,
+ transaction_id)] tuples indicating documents which should be
+ updated on this replica paired with the generation and transaction
+ id of their latest change.
+ :param source_replica_uid: The source replica's identifier
+ :param last_known_generation: The last generation that the source
+ replica knows about this target replica
+ :param last_known_trans_id: The last transaction id that the source
+ replica knows about this target replica
+ :param: return_doc_cb(doc, gen): is a callback
+ used to return documents to the source replica, it will
+ be invoked in turn with Documents that have changed since
+ last_known_generation together with the generation of
+ their last change.
+ :param: ensure_callback(replica_uid): if set the target may create
+ the target db if not yet existent, the callback can then
+ be used to inform of the created db replica uid.
+ :return: new_generation - After applying docs_by_generation, this is
+ the current generation for this replica
+ """
+ raise NotImplementedError(self.sync_exchange)
+
+ def _set_trace_hook(self, cb):
+ """Set a callback that will be invoked to trace database actions.
+
+ The callback will be passed a string indicating the current state, and
+ the sync target object. Implementations do not have to implement this
+ api, it is used by the test suite.
+
+ :param cb: A callable that takes cb(state)
+ """
+ raise NotImplementedError(self._set_trace_hook)
+
+ def _set_trace_hook_shallow(self, cb):
+ """Set a callback that will be invoked to trace database actions.
+
+ Similar to _set_trace_hook, for implementations that don't offer
+ state changes from the inner working of sync_exchange().
+
+ :param cb: A callable that takes cb(state)
+ """
+ self._set_trace_hook(cb)
diff --git a/u1db/backends/__init__.py b/u1db/backends/__init__.py
new file mode 100644
index 00000000..c8e5adc6
--- /dev/null
+++ b/u1db/backends/__init__.py
@@ -0,0 +1,211 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Abstract classes and common implementations for the backends."""
+
+import re
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import uuid
+
+import u1db
+from u1db import (
+ errors,
+)
+import u1db.sync
+from u1db.vectorclock import VectorClockRev
+
+
+check_doc_id_re = re.compile("^" + u1db.DOC_ID_CONSTRAINTS + "$", re.UNICODE)
+
+
+class CommonSyncTarget(u1db.sync.LocalSyncTarget):
+ pass
+
+
+class CommonBackend(u1db.Database):
+
+ document_size_limit = 0
+
+ def _allocate_doc_id(self):
+ """Generate a unique identifier for this document."""
+ return 'D-' + uuid.uuid4().hex # 'D-' stands for document
+
+ def _allocate_transaction_id(self):
+ return 'T-' + uuid.uuid4().hex # 'T-' stands for transaction
+
+ def _allocate_doc_rev(self, old_doc_rev):
+ vcr = VectorClockRev(old_doc_rev)
+ vcr.increment(self._replica_uid)
+ return vcr.as_str()
+
+ def _check_doc_id(self, doc_id):
+ if not check_doc_id_re.match(doc_id):
+ raise errors.InvalidDocId()
+
+ def _check_doc_size(self, doc):
+ if not self.document_size_limit:
+ return
+ if doc.get_size() > self.document_size_limit:
+ raise errors.DocumentTooBig
+
+ def _get_generation(self):
+ """Return the current generation.
+
+ """
+ raise NotImplementedError(self._get_generation)
+
+ def _get_generation_info(self):
+ """Return the current generation and transaction id.
+
+ """
+ raise NotImplementedError(self._get_generation_info)
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """Extract the document from storage.
+
+ This can return None if the document doesn't exist.
+ """
+ raise NotImplementedError(self._get_doc)
+
+ def _has_conflicts(self, doc_id):
+ """Return True if the doc has conflicts, False otherwise."""
+ raise NotImplementedError(self._has_conflicts)
+
+ def create_doc(self, content, doc_id=None):
+ json_string = json.dumps(content)
+ if doc_id is None:
+ doc_id = self._allocate_doc_id()
+ doc = self._factory(doc_id, None, json_string)
+ self.put_doc(doc)
+ return doc
+
+ def create_doc_from_json(self, json, doc_id=None):
+ if doc_id is None:
+ doc_id = self._allocate_doc_id()
+ doc = self._factory(doc_id, None, json)
+ self.put_doc(doc)
+ return doc
+
+ def _get_transaction_log(self):
+ """This is only for the test suite, it is not part of the api."""
+ raise NotImplementedError(self._get_transaction_log)
+
+ def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content):
+ raise NotImplementedError(self._put_and_update_indexes)
+
+ def get_docs(self, doc_ids, check_for_conflicts=True,
+ include_deleted=False):
+ for doc_id in doc_ids:
+ doc = self._get_doc(
+ doc_id, check_for_conflicts=check_for_conflicts)
+ if doc.is_tombstone() and not include_deleted:
+ continue
+ yield doc
+
+ def _get_trans_id_for_gen(self, generation):
+ """Get the transaction id corresponding to a particular generation.
+
+ Raises an InvalidGeneration when the generation does not exist.
+
+ """
+ raise NotImplementedError(self._get_trans_id_for_gen)
+
+ def validate_gen_and_trans_id(self, generation, trans_id):
+ """Validate the generation and transaction id.
+
+ Raises an InvalidGeneration when the generation does not exist, and an
+ InvalidTransactionId when it does but with a different transaction id.
+
+ """
+ if generation == 0:
+ return
+ known_trans_id = self._get_trans_id_for_gen(generation)
+ if known_trans_id != trans_id:
+ raise errors.InvalidTransactionId
+
+ def _validate_source(self, other_replica_uid, other_generation,
+ other_transaction_id):
+ """Validate the new generation and transaction id.
+
+ other_generation must be greater than what we have stored for this
+ replica, *or* it must be the same and the transaction_id must be the
+ same as well.
+ """
+ (old_generation,
+ old_transaction_id) = self._get_replica_gen_and_trans_id(
+ other_replica_uid)
+ if other_generation < old_generation:
+ raise errors.InvalidGeneration
+ if other_generation > old_generation:
+ return
+ if other_transaction_id == old_transaction_id:
+ return
+ raise errors.InvalidTransactionId
+
+ def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen,
+ replica_trans_id=''):
+ cur_doc = self._get_doc(doc.doc_id)
+ doc_vcr = VectorClockRev(doc.rev)
+ if cur_doc is None:
+ cur_vcr = VectorClockRev(None)
+ else:
+ cur_vcr = VectorClockRev(cur_doc.rev)
+ self._validate_source(replica_uid, replica_gen, replica_trans_id)
+ if doc_vcr.is_newer(cur_vcr):
+ rev = doc.rev
+ self._prune_conflicts(doc, doc_vcr)
+ if doc.rev != rev:
+ # conflicts have been autoresolved
+ state = 'superseded'
+ else:
+ state = 'inserted'
+ self._put_and_update_indexes(cur_doc, doc)
+ elif doc.rev == cur_doc.rev:
+ # magical convergence
+ state = 'converged'
+ elif cur_vcr.is_newer(doc_vcr):
+ # Don't add this to seen_ids, because we have something newer,
+ # so we should send it back, and we should not generate a
+ # conflict
+ state = 'superseded'
+ elif cur_doc.same_content_as(doc):
+ # the documents have been edited to the same thing at both ends
+ doc_vcr.maximize(cur_vcr)
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ self._put_and_update_indexes(cur_doc, doc)
+ state = 'superseded'
+ else:
+ state = 'conflicted'
+ if save_conflict:
+ self._force_doc_sync_conflict(doc)
+ if replica_uid is not None and replica_gen is not None:
+ self._do_set_replica_gen_and_trans_id(
+ replica_uid, replica_gen, replica_trans_id)
+ return state, self._get_generation()
+
+ def _ensure_maximal_rev(self, cur_rev, extra_revs):
+ vcr = VectorClockRev(cur_rev)
+ for rev in extra_revs:
+ vcr.maximize(VectorClockRev(rev))
+ vcr.increment(self._replica_uid)
+ return vcr.as_str()
+
+ def set_document_size_limit(self, limit):
+ self.document_size_limit = limit
diff --git a/u1db/backends/dbschema.sql b/u1db/backends/dbschema.sql
new file mode 100644
index 00000000..ae027fc5
--- /dev/null
+++ b/u1db/backends/dbschema.sql
@@ -0,0 +1,42 @@
+-- Database schema
+CREATE TABLE transaction_log (
+ generation INTEGER PRIMARY KEY AUTOINCREMENT,
+ doc_id TEXT NOT NULL,
+ transaction_id TEXT NOT NULL
+);
+CREATE TABLE document (
+ doc_id TEXT PRIMARY KEY,
+ doc_rev TEXT NOT NULL,
+ content TEXT
+);
+CREATE TABLE document_fields (
+ doc_id TEXT NOT NULL,
+ field_name TEXT NOT NULL,
+ value TEXT
+);
+CREATE INDEX document_fields_field_value_doc_idx
+ ON document_fields(field_name, value, doc_id);
+
+CREATE TABLE sync_log (
+ replica_uid TEXT PRIMARY KEY,
+ known_generation INTEGER,
+ known_transaction_id TEXT
+);
+CREATE TABLE conflicts (
+ doc_id TEXT,
+ doc_rev TEXT,
+ content TEXT,
+ CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev)
+);
+CREATE TABLE index_definitions (
+ name TEXT,
+ offset INT,
+ field TEXT,
+ CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset)
+);
+create index index_definitions_field on index_definitions(field);
+CREATE TABLE u1db_config (
+ name TEXT PRIMARY KEY,
+ value TEXT
+);
+INSERT INTO u1db_config VALUES ('sql_schema', '0');
diff --git a/u1db/backends/inmemory.py b/u1db/backends/inmemory.py
new file mode 100644
index 00000000..a271bb37
--- /dev/null
+++ b/u1db/backends/inmemory.py
@@ -0,0 +1,469 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The in-memory Database class for U1DB."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ Document,
+ errors,
+ query_parser,
+ vectorclock,
+ )
+from u1db.backends import CommonBackend, CommonSyncTarget
+
+
+def get_prefix(value):
+ key_prefix = '\x01'.join(value)
+ return key_prefix.rstrip('*')
+
+
+class InMemoryDatabase(CommonBackend):
+ """A database that only stores the data internally."""
+
+ def __init__(self, replica_uid, document_factory=None):
+ self._transaction_log = []
+ self._docs = {}
+ # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner'
+ self._conflicts = {}
+ self._other_generations = {}
+ self._indexes = {}
+ self._replica_uid = replica_uid
+ self._factory = document_factory or Document
+
+ def _set_replica_uid(self, replica_uid):
+ """Force the replica_uid to be set."""
+ self._replica_uid = replica_uid
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def close(self):
+ # This is a no-op, We don't want to free the data because one client
+ # may be closing it, while another wants to inspect the results.
+ pass
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ return self._other_generations.get(other_replica_uid, (0, ''))
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ self._do_set_replica_gen_and_trans_id(
+ other_replica_uid, other_generation, other_transaction_id)
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ # TODO: to handle race conditions, we may want to check if the current
+ # value is greater than this new value.
+ self._other_generations[other_replica_uid] = (other_generation,
+ other_transaction_id)
+
+ def get_sync_target(self):
+ return InMemorySyncTarget(self)
+
+ def _get_transaction_log(self):
+ # snapshot!
+ return self._transaction_log[:]
+
+ def _get_generation(self):
+ return len(self._transaction_log)
+
+ def _get_generation_info(self):
+ if not self._transaction_log:
+ return 0, ''
+ return len(self._transaction_log), self._transaction_log[-1][1]
+
+ def _get_trans_id_for_gen(self, generation):
+ if generation == 0:
+ return ''
+ if generation > len(self._transaction_log):
+ raise errors.InvalidGeneration
+ return self._transaction_log[generation - 1][1]
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ self._check_doc_id(doc.doc_id)
+ self._check_doc_size(doc)
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc and old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ if old_doc and doc.rev is None and old_doc.is_tombstone():
+ new_rev = self._allocate_doc_rev(old_doc.rev)
+ else:
+ if old_doc is not None:
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ else:
+ if doc.rev is not None:
+ raise errors.RevisionConflict()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ for index in self._indexes.itervalues():
+ if old_doc is not None and not old_doc.is_tombstone():
+ index.remove_json(old_doc.doc_id, old_doc.get_json())
+ if not doc.is_tombstone():
+ index.add_json(doc.doc_id, doc.get_json())
+ trans_id = self._allocate_transaction_id()
+ self._docs[doc.doc_id] = (doc.rev, doc.get_json())
+ self._transaction_log.append((doc.doc_id, trans_id))
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ try:
+ doc_rev, content = self._docs[doc_id]
+ except KeyError:
+ return None
+ doc = self._factory(doc_id, doc_rev, content)
+ if check_for_conflicts:
+ doc.has_conflicts = (doc.doc_id in self._conflicts)
+ return doc
+
+ def _has_conflicts(self, doc_id):
+ return doc_id in self._conflicts
+
+ def get_doc(self, doc_id, include_deleted=False):
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc is None:
+ return None
+ if doc.is_tombstone() and not include_deleted:
+ return None
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Return all documents in the database."""
+ generation = self._get_generation()
+ results = []
+ for doc_id, (doc_rev, content) in self._docs.items():
+ if content is None and not include_deleted:
+ continue
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = self._has_conflicts(doc_id)
+ results.append(doc)
+ return (generation, results)
+
+ def get_doc_conflicts(self, doc_id):
+ if doc_id not in self._conflicts:
+ return []
+ result = [self._get_doc(doc_id)]
+ result[0].has_conflicts = True
+ result.extend([self._factory(doc_id, rev, content)
+ for rev, content in self._conflicts[doc_id]])
+ return result
+
+ def _replace_conflicts(self, doc, conflicts):
+ if not conflicts:
+ del self._conflicts[doc.doc_id]
+ else:
+ self._conflicts[doc.doc_id] = conflicts
+ doc.has_conflicts = bool(conflicts)
+
+ def _prune_conflicts(self, doc, doc_vcr):
+ if self._has_conflicts(doc.doc_id):
+ autoresolved = False
+ remaining_conflicts = []
+ cur_conflicts = self._conflicts[doc.doc_id]
+ for c_rev, c_doc in cur_conflicts:
+ c_vcr = vectorclock.VectorClockRev(c_rev)
+ if doc_vcr.is_newer(c_vcr):
+ continue
+ if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)):
+ doc_vcr.maximize(c_vcr)
+ autoresolved = True
+ continue
+ remaining_conflicts.append((c_rev, c_doc))
+ if autoresolved:
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ self._replace_conflicts(doc, remaining_conflicts)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ cur_doc = self._get_doc(doc.doc_id)
+ if cur_doc is None:
+ cur_rev = None
+ else:
+ cur_rev = cur_doc.rev
+ new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs)
+ superseded_revs = set(conflicted_doc_revs)
+ remaining_conflicts = []
+ cur_conflicts = self._conflicts[doc.doc_id]
+ for c_rev, c_doc in cur_conflicts:
+ if c_rev in superseded_revs:
+ continue
+ remaining_conflicts.append((c_rev, c_doc))
+ doc.rev = new_rev
+ if cur_rev in superseded_revs:
+ self._put_and_update_indexes(cur_doc, doc)
+ else:
+ remaining_conflicts.append((new_rev, doc.get_json()))
+ self._replace_conflicts(doc, remaining_conflicts)
+
+ def delete_doc(self, doc):
+ if doc.doc_id not in self._docs:
+ raise errors.DocumentDoesNotExist
+ if self._docs[doc.doc_id][1] in ('null', None):
+ raise errors.DocumentAlreadyDeleted
+ doc.make_tombstone()
+ self.put_doc(doc)
+
+ def create_index(self, index_name, *index_expressions):
+ if index_name in self._indexes:
+ if self._indexes[index_name]._definition == list(
+ index_expressions):
+ return
+ raise errors.IndexNameTakenError
+ index = InMemoryIndex(index_name, list(index_expressions))
+ for doc_id, (doc_rev, doc) in self._docs.iteritems():
+ if doc is not None:
+ index.add_json(doc_id, doc)
+ self._indexes[index_name] = index
+
+ def delete_index(self, index_name):
+ del self._indexes[index_name]
+
+ def list_indexes(self):
+ definitions = []
+ for idx in self._indexes.itervalues():
+ definitions.append((idx._name, idx._definition))
+ return definitions
+
+ def get_from_index(self, index_name, *key_values):
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ doc_ids = index.lookup(key_values)
+ result = []
+ for doc_id in doc_ids:
+ result.append(self._get_doc(doc_id, check_for_conflicts=True))
+ return result
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ """Return all documents with key values in the specified range."""
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ doc_ids = index.lookup_range(start_value, end_value)
+ result = []
+ for doc_id in doc_ids:
+ result.append(self._get_doc(doc_id, check_for_conflicts=True))
+ return result
+
+ def get_index_keys(self, index_name):
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ keys = index.keys()
+ # XXX inefficiency warning
+ return list(set([tuple(key.split('\x01')) for key in keys]))
+
+ def whats_changed(self, old_generation=0):
+ changes = []
+ relevant_tail = self._transaction_log[old_generation:]
+ # We don't use len(self._transaction_log) because _transaction_log may
+ # get mutated by a concurrent operation.
+ cur_generation = old_generation + len(relevant_tail)
+ last_trans_id = ''
+ if relevant_tail:
+ last_trans_id = relevant_tail[-1][1]
+ elif self._transaction_log:
+ last_trans_id = self._transaction_log[-1][1]
+ seen = set()
+ generation = cur_generation
+ for doc_id, trans_id in reversed(relevant_tail):
+ if doc_id not in seen:
+ changes.append((doc_id, generation, trans_id))
+ seen.add(doc_id)
+ generation -= 1
+ changes.reverse()
+ return (cur_generation, last_trans_id, changes)
+
+ def _force_doc_sync_conflict(self, doc):
+ my_doc = self._get_doc(doc.doc_id)
+ self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
+ self._conflicts.setdefault(doc.doc_id, []).append(
+ (my_doc.rev, my_doc.get_json()))
+ doc.has_conflicts = True
+ self._put_and_update_indexes(my_doc, doc)
+
+
+class InMemoryIndex(object):
+ """Interface for managing an Index."""
+
+ def __init__(self, index_name, index_definition):
+ self._name = index_name
+ self._definition = index_definition
+ self._values = {}
+ parser = query_parser.Parser()
+ self._getters = parser.parse_all(self._definition)
+
+ def evaluate_json(self, doc):
+ """Determine the 'key' after applying this index to the doc."""
+ raw = json.loads(doc)
+ return self.evaluate(raw)
+
+ def evaluate(self, obj):
+ """Evaluate a dict object, applying this definition."""
+ all_rows = [[]]
+ for getter in self._getters:
+ new_rows = []
+ keys = getter.get(obj)
+ if not keys:
+ return []
+ for key in keys:
+ new_rows.extend([row + [key] for row in all_rows])
+ all_rows = new_rows
+ all_rows = ['\x01'.join(row) for row in all_rows]
+ return all_rows
+
+ def add_json(self, doc_id, doc):
+ """Add this json doc to the index."""
+ keys = self.evaluate_json(doc)
+ if not keys:
+ return
+ for key in keys:
+ self._values.setdefault(key, []).append(doc_id)
+
+ def remove_json(self, doc_id, doc):
+ """Remove this json doc from the index."""
+ keys = self.evaluate_json(doc)
+ if keys:
+ for key in keys:
+ doc_ids = self._values[key]
+ doc_ids.remove(doc_id)
+ if not doc_ids:
+ del self._values[key]
+
+ def _find_non_wildcards(self, values):
+ """Check if this should be a wildcard match.
+
+ Further, this will raise an exception if the syntax is improperly
+ defined.
+
+ :return: The offset of the last value we need to match against.
+ """
+ if len(values) != len(self._definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ last = 0
+ for idx, val in enumerate(values):
+ if val.endswith('*'):
+ if val != '*':
+ # We have an 'x*' style wildcard
+ if is_wildcard:
+ # We were already in wildcard mode, so this is invalid
+ raise errors.InvalidGlobbing
+ last = idx + 1
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ # We were in wildcard mode, we can't follow that with
+ # non-wildcard
+ raise errors.InvalidGlobbing
+ last = idx + 1
+ if not is_wildcard:
+ return -1
+ return last
+
+ def lookup(self, values):
+ """Find docs that match the values."""
+ last = self._find_non_wildcards(values)
+ if last == -1:
+ return self._lookup_exact(values)
+ else:
+ return self._lookup_prefix(values[:last])
+
+ def lookup_range(self, start_values, end_values):
+ """Find docs within the range."""
+ # TODO: Wildly inefficient, which is unlikely to be a problem for the
+ # inmemory implementation.
+ if start_values:
+ self._find_non_wildcards(start_values)
+ start_values = get_prefix(start_values)
+ if end_values:
+ if self._find_non_wildcards(end_values) == -1:
+ exact = True
+ else:
+ exact = False
+ end_values = get_prefix(end_values)
+ found = []
+ for key, doc_ids in sorted(self._values.iteritems()):
+ if start_values and start_values > key:
+ continue
+ if end_values and end_values < key:
+ if exact:
+ break
+ else:
+ if not key.startswith(end_values):
+ break
+ found.extend(doc_ids)
+ return found
+
+ def keys(self):
+ """Find the indexed keys."""
+ return self._values.keys()
+
+ def _lookup_prefix(self, value):
+ """Find docs that match the prefix string in values."""
+ # TODO: We need a different data structure to make prefix style fast,
+ # some sort of sorted list would work, but a plain dict doesn't.
+ key_prefix = get_prefix(value)
+ all_doc_ids = []
+ for key, doc_ids in sorted(self._values.iteritems()):
+ if key.startswith(key_prefix):
+ all_doc_ids.extend(doc_ids)
+ return all_doc_ids
+
+ def _lookup_exact(self, value):
+ """Find docs that match exactly."""
+ key = '\x01'.join(value)
+ if key in self._values:
+ return self._values[key]
+ return ()
+
+
+class InMemorySyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_transaction_id)
diff --git a/u1db/backends/sqlite_backend.py b/u1db/backends/sqlite_backend.py
new file mode 100644
index 00000000..773213b5
--- /dev/null
+++ b/u1db/backends/sqlite_backend.py
@@ -0,0 +1,926 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""A U1DB implementation that uses SQLite as its persistence layer."""
+
+import errno
+import os
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from sqlite3 import dbapi2
+import sys
+import time
+import uuid
+
+import pkg_resources
+
+from u1db.backends import CommonBackend, CommonSyncTarget
+from u1db import (
+ Document,
+ errors,
+ query_parser,
+ vectorclock,
+ )
+
+
+class SQLiteDatabase(CommonBackend):
+ """A U1DB implementation that uses SQLite as its persistence layer."""
+
+ _sqlite_registry = {}
+
+ def __init__(self, sqlite_file, document_factory=None):
+ """Create a new sqlite file."""
+ self._db_handle = dbapi2.connect(sqlite_file)
+ self._real_replica_uid = None
+ self._ensure_schema()
+ self._factory = document_factory or Document
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def get_sync_target(self):
+ return SQLiteSyncTarget(self)
+
+ @classmethod
+ def _which_index_storage(cls, c):
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'index_storage'")
+ except dbapi2.OperationalError, e:
+ # The table does not exist yet
+ return None, e
+ else:
+ return c.fetchone()[0], None
+
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5
+
+ @classmethod
+ def _open_database(cls, sqlite_file, document_factory=None):
+ if not os.path.isfile(sqlite_file):
+ raise errors.DatabaseDoesNotExist()
+ tries = 2
+ while True:
+ # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6)
+ # where without re-opening the database on Windows, it
+ # doesn't see the transaction that was just committed
+ db_handle = dbapi2.connect(sqlite_file)
+ c = db_handle.cursor()
+ v, err = cls._which_index_storage(c)
+ db_handle.close()
+ if v is not None:
+ break
+ # possibly another process is initializing it, wait for it to be
+ # done
+ if tries == 0:
+ raise err # go for the richest error?
+ tries -= 1
+ time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL)
+ return SQLiteDatabase._sqlite_registry[v](
+ sqlite_file, document_factory=document_factory)
+
+ @classmethod
+ def open_database(cls, sqlite_file, create, backend_cls=None,
+ document_factory=None):
+ try:
+ return cls._open_database(
+ sqlite_file, document_factory=document_factory)
+ except errors.DatabaseDoesNotExist:
+ if not create:
+ raise
+ if backend_cls is None:
+ # default is SQLitePartialExpandDatabase
+ backend_cls = SQLitePartialExpandDatabase
+ return backend_cls(sqlite_file, document_factory=document_factory)
+
+ @staticmethod
+ def delete_database(sqlite_file):
+ try:
+ os.unlink(sqlite_file)
+ except OSError as ex:
+ if ex.errno == errno.ENOENT:
+ raise errors.DatabaseDoesNotExist()
+ raise
+
+ @staticmethod
+ def register_implementation(klass):
+ """Register that we implement an SQLiteDatabase.
+
+ The attribute _index_storage_value will be used as the lookup key.
+ """
+ SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass
+
+ def _get_sqlite_handle(self):
+ """Get access to the underlying sqlite database.
+
+ This should only be used by the test suite, etc, for examining the
+ state of the underlying database.
+ """
+ return self._db_handle
+
+ def _close_sqlite_handle(self):
+ """Release access to the underlying sqlite database."""
+ self._db_handle.close()
+
+ def close(self):
+ self._close_sqlite_handle()
+
+ def _is_initialized(self, c):
+ """Check if this database has been initialized."""
+ c.execute("PRAGMA case_sensitive_like=ON")
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'sql_schema'")
+ except dbapi2.OperationalError:
+ # The table does not exist yet
+ val = None
+ else:
+ val = c.fetchone()
+ if val is not None:
+ return True
+ return False
+
+ def _initialize(self, c):
+ """Create the schema in the database."""
+ #read the script with sql commands
+ # TODO: Change how we set up the dependency. Most likely use something
+ # like lp:dirspec to grab the file from a common resource
+ # directory. Doesn't specifically need to be handled until we get
+ # to the point of packaging this.
+ schema_content = pkg_resources.resource_string(
+ __name__, 'dbschema.sql')
+ # Note: We'd like to use c.executescript() here, but it seems that
+ # executescript always commits, even if you set
+ # isolation_level = None, so if we want to properly handle
+ # exclusive locking and rollbacks between processes, we need
+ # to execute it line-by-line
+ for line in schema_content.split(';'):
+ if not line:
+ continue
+ c.execute(line)
+ #add extra fields
+ self._extra_schema_init(c)
+ # A unique identifier should be set for this replica. Implementations
+ # don't have to strictly use uuid here, but we do want the uid to be
+ # unique amongst all databases that will sync with each other.
+ # We might extend this to using something with hostname for easier
+ # debugging.
+ self._set_replica_uid_in_transaction(uuid.uuid4().hex)
+ c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)",
+ (self._index_storage_value,))
+
+ def _ensure_schema(self):
+ """Ensure that the database schema has been created."""
+ old_isolation_level = self._db_handle.isolation_level
+ c = self._db_handle.cursor()
+ if self._is_initialized(c):
+ return
+ try:
+ # autocommit/own mgmt of transactions
+ self._db_handle.isolation_level = None
+ with self._db_handle:
+ # only one execution path should initialize the db
+ c.execute("begin exclusive")
+ if self._is_initialized(c):
+ return
+ self._initialize(c)
+ finally:
+ self._db_handle.isolation_level = old_isolation_level
+
+ def _extra_schema_init(self, c):
+ """Add any extra fields, etc to the basic table definitions."""
+
+ def _parse_index_definition(self, index_field):
+ """Parse a field definition for an index, returning a Getter."""
+ # Note: We may want to keep a Parser object around, and cache the
+ # Getter objects for a greater length of time. Specifically, if
+ # you create a bunch of indexes, and then insert 50k docs, you'll
+ # re-parse the indexes between puts. The time to insert the docs
+ # is still likely to dominate put_doc time, though.
+ parser = query_parser.Parser()
+ getter = parser.parse(index_field)
+ return getter
+
+ def _update_indexes(self, doc_id, raw_doc, getters, db_cursor):
+ """Update document_fields for a single document.
+
+ :param doc_id: Identifier for this document
+ :param raw_doc: The python dict representation of the document.
+ :param getters: A list of [(field_name, Getter)]. Getter.get will be
+ called to evaluate the index definition for this document, and the
+ results will be inserted into the db.
+ :param db_cursor: An sqlite Cursor.
+ :return: None
+ """
+ values = []
+ for field_name, getter in getters:
+ for idx_value in getter.get(raw_doc):
+ values.append((doc_id, field_name, idx_value))
+ if values:
+ db_cursor.executemany(
+ "INSERT INTO document_fields VALUES (?, ?, ?)", values)
+
+ def _set_replica_uid(self, replica_uid):
+ """Force the replica_uid to be set."""
+ with self._db_handle:
+ self._set_replica_uid_in_transaction(replica_uid)
+
+ def _set_replica_uid_in_transaction(self, replica_uid):
+ """Set the replica_uid. A transaction should already be held."""
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO u1db_config"
+ " VALUES ('replica_uid', ?)",
+ (replica_uid,))
+ self._real_replica_uid = replica_uid
+
+ def _get_replica_uid(self):
+ if self._real_replica_uid is not None:
+ return self._real_replica_uid
+ c = self._db_handle.cursor()
+ c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'")
+ val = c.fetchone()
+ if val is None:
+ return None
+ self._real_replica_uid = val[0]
+ return self._real_replica_uid
+
+ _replica_uid = property(_get_replica_uid)
+
+ def _get_generation(self):
+ c = self._db_handle.cursor()
+ c.execute('SELECT max(generation) FROM transaction_log')
+ val = c.fetchone()[0]
+ if val is None:
+ return 0
+ return val
+
+ def _get_generation_info(self):
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT max(generation), transaction_id FROM transaction_log ')
+ val = c.fetchone()
+ if val[0] is None:
+ return(0, '')
+ return val
+
+ def _get_trans_id_for_gen(self, generation):
+ if generation == 0:
+ return ''
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT transaction_id FROM transaction_log WHERE generation = ?',
+ (generation,))
+ val = c.fetchone()
+ if val is None:
+ raise errors.InvalidGeneration
+ return val[0]
+
+ def _get_transaction_log(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, transaction_id FROM transaction_log"
+ " ORDER BY generation")
+ return c.fetchall()
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """Get just the document content, without fancy handling."""
+ c = self._db_handle.cursor()
+ if check_for_conflicts:
+ c.execute(
+ "SELECT document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN "
+ "conflicts ON conflicts.doc_id = document.doc_id WHERE "
+ "document.doc_id = ? GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;", (doc_id,))
+ else:
+ c.execute(
+ "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return None
+ doc_rev, content, conflicts = val
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ return doc
+
+ def _has_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return False
+ else:
+ return True
+
+ def get_doc(self, doc_id, include_deleted=False):
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc is None:
+ return None
+ if doc.is_tombstone() and not include_deleted:
+ return None
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Get all documents from the database."""
+ generation = self._get_generation()
+ results = []
+ c = self._db_handle.cursor()
+ c.execute(
+ "SELECT document.doc_id, document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts "
+ "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;")
+ rows = c.fetchall()
+ for doc_id, doc_rev, content, conflicts in rows:
+ if content is None and not include_deleted:
+ continue
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ results.append(doc)
+ return (generation, results)
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ self._check_doc_id(doc.doc_id)
+ self._check_doc_size(doc)
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc and old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ if old_doc and doc.rev is None and old_doc.is_tombstone():
+ new_rev = self._allocate_doc_rev(old_doc.rev)
+ else:
+ if old_doc is not None:
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ else:
+ if doc.rev is not None:
+ raise errors.RevisionConflict()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none):
+ """Convert a dict representation into named fields.
+
+ So something like: {'key1': 'val1', 'key2': 'val2'}
+ gets converted into: [(doc_id, 'key1', 'val1', 0)
+ (doc_id, 'key2', 'val2', 0)]
+ :param doc_id: Just added to every record.
+ :param base_field: if set, these are nested keys, so each field should
+ be appropriately prefixed.
+ :param raw_doc: The python dictionary.
+ """
+ # TODO: Handle lists
+ values = []
+ for field_name, value in raw_doc.iteritems():
+ if value is None and not save_none:
+ continue
+ if base_field:
+ full_name = base_field + '.' + field_name
+ else:
+ full_name = field_name
+ if value is None or isinstance(value, (int, float, basestring)):
+ values.append((doc_id, full_name, value, len(values)))
+ else:
+ subvalues = self._expand_to_fields(doc_id, full_name, value,
+ save_none)
+ for _, subfield_name, val, _ in subvalues:
+ values.append((doc_id, subfield_name, val, len(values)))
+ return values
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ """Actually insert a document into the database.
+
+ This both updates the existing documents content, and any indexes that
+ refer to this document.
+ """
+ raise NotImplementedError(self._put_and_update_indexes)
+
+ def whats_changed(self, old_generation=0):
+ c = self._db_handle.cursor()
+ c.execute("SELECT generation, doc_id, transaction_id"
+ " FROM transaction_log"
+ " WHERE generation > ? ORDER BY generation DESC",
+ (old_generation,))
+ results = c.fetchall()
+ cur_gen = old_generation
+ seen = set()
+ changes = []
+ newest_trans_id = ''
+ for generation, doc_id, trans_id in results:
+ if doc_id not in seen:
+ changes.append((doc_id, generation, trans_id))
+ seen.add(doc_id)
+ if changes:
+ cur_gen = changes[0][1] # max generation
+ newest_trans_id = changes[0][2]
+ changes.reverse()
+ else:
+ c.execute("SELECT generation, transaction_id"
+ " FROM transaction_log ORDER BY generation DESC LIMIT 1")
+ results = c.fetchone()
+ if not results:
+ cur_gen = 0
+ newest_trans_id = ''
+ else:
+ cur_gen, newest_trans_id = results
+
+ return cur_gen, newest_trans_id, changes
+
+ def delete_doc(self, doc):
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc is None:
+ raise errors.DocumentDoesNotExist
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ if old_doc.is_tombstone():
+ raise errors.DocumentAlreadyDeleted
+ if old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ doc.make_tombstone()
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _get_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?",
+ (doc_id,))
+ return [self._factory(doc_id, doc_rev, content)
+ for doc_rev, content in c.fetchall()]
+
+ def get_doc_conflicts(self, doc_id):
+ with self._db_handle:
+ conflict_docs = self._get_conflicts(doc_id)
+ if not conflict_docs:
+ return []
+ this_doc = self._get_doc(doc_id)
+ this_doc.has_conflicts = True
+ return [this_doc] + conflict_docs
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ c = self._db_handle.cursor()
+ c.execute("SELECT known_generation, known_transaction_id FROM sync_log"
+ " WHERE replica_uid = ?",
+ (other_replica_uid,))
+ val = c.fetchone()
+ if val is None:
+ other_gen = 0
+ trans_id = ''
+ else:
+ other_gen = val[0]
+ trans_id = val[1]
+ return other_gen, trans_id
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ with self._db_handle:
+ self._do_set_replica_gen_and_trans_id(
+ other_replica_uid, other_generation, other_transaction_id)
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)",
+ (other_replica_uid, other_generation,
+ other_transaction_id))
+
+ def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None,
+ replica_gen=None, replica_trans_id=None):
+ with self._db_handle:
+ return super(SQLiteDatabase, self)._put_doc_if_newer(doc,
+ save_conflict=save_conflict,
+ replica_uid=replica_uid, replica_gen=replica_gen,
+ replica_trans_id=replica_trans_id)
+
+ def _add_conflict(self, c, doc_id, my_doc_rev, my_content):
+ c.execute("INSERT INTO conflicts VALUES (?, ?, ?)",
+ (doc_id, my_doc_rev, my_content))
+
+ def _delete_conflicts(self, c, doc, conflict_revs):
+ deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs]
+ c.executemany("DELETE FROM conflicts"
+ " WHERE doc_id=? AND doc_rev=?", deleting)
+ doc.has_conflicts = self._has_conflicts(doc.doc_id)
+
+ def _prune_conflicts(self, doc, doc_vcr):
+ if self._has_conflicts(doc.doc_id):
+ autoresolved = False
+ c_revs_to_prune = []
+ for c_doc in self._get_conflicts(doc.doc_id):
+ c_vcr = vectorclock.VectorClockRev(c_doc.rev)
+ if doc_vcr.is_newer(c_vcr):
+ c_revs_to_prune.append(c_doc.rev)
+ elif doc.same_content_as(c_doc):
+ c_revs_to_prune.append(c_doc.rev)
+ doc_vcr.maximize(c_vcr)
+ autoresolved = True
+ if autoresolved:
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ c = self._db_handle.cursor()
+ self._delete_conflicts(c, doc, c_revs_to_prune)
+
+ def _force_doc_sync_conflict(self, doc):
+ my_doc = self._get_doc(doc.doc_id)
+ c = self._db_handle.cursor()
+ self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
+ self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json())
+ doc.has_conflicts = True
+ self._put_and_update_indexes(my_doc, doc)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ with self._db_handle:
+ cur_doc = self._get_doc(doc.doc_id)
+ # TODO: https://bugs.launchpad.net/u1db/+bug/928274
+ # I think we have a logic bug in resolve_doc
+ # Specifically, cur_doc.rev is always in the final vector
+ # clock of revisions that we supersede, even if it wasn't in
+ # conflicted_doc_revs. We still add it as a conflict, but the
+ # fact that _put_doc_if_newer propagates resolutions means I
+ # think that conflict could accidentally be resolved. We need
+ # to add a test for this case first. (create a rev, create a
+ # conflict, create another conflict, resolve the first rev
+ # and first conflict, then make sure that the resolved
+ # rev doesn't supersede the second conflict rev.) It *might*
+ # not matter, because the superseding rev is in as a
+ # conflict, but it does seem incorrect
+ new_rev = self._ensure_maximal_rev(cur_doc.rev,
+ conflicted_doc_revs)
+ superseded_revs = set(conflicted_doc_revs)
+ c = self._db_handle.cursor()
+ doc.rev = new_rev
+ if cur_doc.rev in superseded_revs:
+ self._put_and_update_indexes(cur_doc, doc)
+ else:
+ self._add_conflict(c, doc.doc_id, new_rev, doc.get_json())
+ # TODO: Is there some way that we could construct a rev that would
+ # end up in superseded_revs, such that we add a conflict, and
+ # then immediately delete it?
+ self._delete_conflicts(c, doc, superseded_revs)
+
+ def list_indexes(self):
+ """Return the list of indexes and their definitions."""
+ c = self._db_handle.cursor()
+ # TODO: How do we test the ordering?
+ c.execute("SELECT name, field FROM index_definitions"
+ " ORDER BY name, offset")
+ definitions = []
+ cur_name = None
+ for name, field in c.fetchall():
+ if cur_name != name:
+ definitions.append((name, []))
+ cur_name = name
+ definitions[-1][-1].append(field)
+ return definitions
+
+ def _get_index_definition(self, index_name):
+ """Return the stored definition for a given index_name."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions"
+ " WHERE name = ? ORDER BY offset", (index_name,))
+ fields = [x[0] for x in c.fetchall()]
+ if not fields:
+ raise errors.IndexDoesNotExist
+ return fields
+
+ @staticmethod
+ def _strip_glob(value):
+ """Remove the trailing * from a value."""
+ assert value[-1] == '*'
+ return value[:-1]
+
+ def _format_query(self, definition, key_values):
+ # First, build the definition. We join the document_fields table
+ # against itself, as many times as the 'width' of our definition.
+ # We then do a query for each key_value, one-at-a-time.
+ # Note: All of these strings are static, we could cache them, etc.
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = ["d.doc_id = d%d.doc_id"
+ " AND d%d.field_name = ?"
+ % (i, i) for i in range(len(definition))]
+ wildcard_where = [novalue_where[i]
+ + (" AND d%d.value NOT NULL" % (i,))
+ for i in range(len(definition))]
+ exact_where = [novalue_where[i]
+ + (" AND d%d.value = ?" % (i,))
+ for i in range(len(definition))]
+ like_where = [novalue_where[i]
+ + (" AND d%d.value GLOB ?" % (i,))
+ for i in range(len(definition))]
+ is_wildcard = False
+ # Merge the lists together, so that:
+ # [field1, field2, field3], [val1, val2, val3]
+ # Becomes:
+ # (field1, val1, field2, val2, field3, val3)
+ args = []
+ where = []
+ for idx, (field, value) in enumerate(zip(definition, key_values)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(exact_where[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_from_index(self, index_name, *key_values):
+ definition = self._get_index_definition(index_name)
+ if len(key_values) != len(definition):
+ raise errors.InvalidValueForIndex()
+ statement, args = self._format_query(definition, key_values)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def _format_range_query(self, definition, start_value, end_value):
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ wildcard_where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ like_where = [
+ novalue_where[i] + (
+ " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in
+ range(len(definition))]
+ range_where_lower = [
+ novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in
+ range(len(definition))]
+ range_where_upper = [
+ novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in
+ range(len(definition))]
+ args = []
+ where = []
+ if start_value:
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ if len(start_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, start_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(self._strip_glob(value))
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(value)
+ if end_value:
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ if len(end_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, end_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(self._strip_glob(value))
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_upper[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ """Return all documents with key values in the specified range."""
+ definition = self._get_index_definition(index_name)
+ statement, args = self._format_range_query(
+ definition, start_value, end_value)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def get_index_keys(self, index_name):
+ c = self._db_handle.cursor()
+ definition = self._get_index_definition(index_name)
+ value_fields = ', '.join([
+ 'd%d.value' % i for i in range(len(definition))])
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ statement = (
+ "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % (
+ value_fields, ', '.join(tables), ' AND '.join(where),
+ value_fields))
+ try:
+ c.execute(statement, tuple(definition))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition)))
+ return c.fetchall()
+
+ def delete_index(self, index_name):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ c.execute("DELETE FROM index_definitions WHERE name = ?",
+ (index_name,))
+ c.execute(
+ "DELETE FROM document_fields WHERE document_fields.field_name "
+ " NOT IN (SELECT field from index_definitions)")
+
+
+class SQLiteSyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_replica_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_replica_transaction_id)
+
+
+class SQLitePartialExpandDatabase(SQLiteDatabase):
+ """An SQLite Backend that expands documents into a document_field table.
+
+ It stores the original document text in document.doc. For fields that are
+ indexed, the data goes into document_fields.
+ """
+
+ _index_storage_value = 'expand referenced'
+
+ def _get_indexed_fields(self):
+ """Determine what fields are indexed."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions")
+ return set([x[0] for x in c.fetchall()])
+
+ def _evaluate_index(self, raw_doc, field):
+ parser = query_parser.Parser()
+ getter = parser.parse(field)
+ return getter.get(raw_doc)
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ c = self._db_handle.cursor()
+ if doc and not doc.is_tombstone():
+ raw_doc = json.loads(doc.get_json())
+ else:
+ raw_doc = {}
+ if old_doc is not None:
+ c.execute("UPDATE document SET doc_rev=?, content=?"
+ " WHERE doc_id = ?",
+ (doc.rev, doc.get_json(), doc.doc_id))
+ c.execute("DELETE FROM document_fields WHERE doc_id = ?",
+ (doc.doc_id,))
+ else:
+ c.execute("INSERT INTO document (doc_id, doc_rev, content)"
+ " VALUES (?, ?, ?)",
+ (doc.doc_id, doc.rev, doc.get_json()))
+ indexed_fields = self._get_indexed_fields()
+ if indexed_fields:
+ # It is expected that len(indexed_fields) is shorter than
+ # len(raw_doc)
+ getters = [(field, self._parse_index_definition(field))
+ for field in indexed_fields]
+ self._update_indexes(doc.doc_id, raw_doc, getters, c)
+ trans_id = self._allocate_transaction_id()
+ c.execute("INSERT INTO transaction_log(doc_id, transaction_id)"
+ " VALUES (?, ?)", (doc.doc_id, trans_id))
+
+ def create_index(self, index_name, *index_expressions):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ cur_fields = self._get_indexed_fields()
+ definition = [(index_name, idx, field)
+ for idx, field in enumerate(index_expressions)]
+ try:
+ c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
+ definition)
+ except dbapi2.IntegrityError as e:
+ stored_def = self._get_index_definition(index_name)
+ if stored_def == [x[-1] for x in definition]:
+ return
+ raise errors.IndexNameTakenError, e, sys.exc_info()[2]
+ new_fields = set(
+ [f for f in index_expressions if f not in cur_fields])
+ if new_fields:
+ self._update_all_indexes(new_fields)
+
+ def _iter_all_docs(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, content FROM document")
+ while True:
+ next_rows = c.fetchmany()
+ if not next_rows:
+ break
+ for row in next_rows:
+ yield row
+
+ def _update_all_indexes(self, new_fields):
+ """Iterate all the documents, and add content to document_fields.
+
+ :param new_fields: The index definitions that need to be added.
+ """
+ getters = [(field, self._parse_index_definition(field))
+ for field in new_fields]
+ c = self._db_handle.cursor()
+ for doc_id, doc in self._iter_all_docs():
+ if doc is None:
+ continue
+ raw_doc = json.loads(doc)
+ self._update_indexes(doc_id, raw_doc, getters, c)
+
+SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase)
diff --git a/u1db/commandline/__init__.py b/u1db/commandline/__init__.py
new file mode 100644
index 00000000..3f32e381
--- /dev/null
+++ b/u1db/commandline/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
diff --git a/u1db/commandline/client.py b/u1db/commandline/client.py
new file mode 100644
index 00000000..15bf8561
--- /dev/null
+++ b/u1db/commandline/client.py
@@ -0,0 +1,497 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Commandline bindings for the u1db-client program."""
+
+import argparse
+import os
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import sys
+
+from u1db import (
+ Document,
+ open as u1db_open,
+ sync,
+ errors,
+ )
+from u1db.commandline import command
+from u1db.remote import (
+ http_database,
+ http_target,
+ )
+
+
+client_commands = command.CommandGroup()
+
+
+def set_oauth_credentials(client):
+ keys = os.environ.get('OAUTH_CREDENTIALS', None)
+ if keys is not None:
+ consumer_key, consumer_secret, \
+ token_key, token_secret = keys.split(":")
+ client.set_oauth_credentials(consumer_key, consumer_secret,
+ token_key, token_secret)
+
+
+class OneDbCmd(command.Command):
+ """Base class for commands operating on one local or remote database."""
+
+ def _open(self, database, create):
+ if database.startswith(('http://', 'https://')):
+ db = http_database.HTTPDatabase(database)
+ set_oauth_credentials(db)
+ db.open(create)
+ return db
+ else:
+ return u1db_open(database, create)
+
+
+class CmdCreate(OneDbCmd):
+ """Create a new document from scratch"""
+
+ name = 'create'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to update',
+ metavar='database-path-or-url')
+ parser.add_argument('infile', nargs='?', default=None,
+ help='The file to read content from.')
+ parser.add_argument('--id', dest='doc_id', default=None,
+ help='Set the document identifier')
+
+ def run(self, database, infile, doc_id):
+ if infile is None:
+ infile = self.stdin
+ db = self._open(database, create=False)
+ doc = db.create_doc_from_json(infile.read(), doc_id=doc_id)
+ self.stderr.write('id: %s\nrev: %s\n' % (doc.doc_id, doc.rev))
+
+client_commands.register(CmdCreate)
+
+
+class CmdDelete(OneDbCmd):
+ """Delete a document from the database"""
+
+ name = 'delete'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to update',
+ metavar='database-path-or-url')
+ parser.add_argument('doc_id', help='The document id to retrieve')
+ parser.add_argument('doc_rev',
+ help='The revision of the document (which is being superseded.)')
+
+ def run(self, database, doc_id, doc_rev):
+ db = self._open(database, create=False)
+ doc = Document(doc_id, doc_rev, None)
+ db.delete_doc(doc)
+ self.stderr.write('rev: %s\n' % (doc.rev,))
+
+client_commands.register(CmdDelete)
+
+
+class CmdGet(OneDbCmd):
+ """Extract a document from the database"""
+
+ name = 'get'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to query',
+ metavar='database-path-or-url')
+ parser.add_argument('doc_id', help='The document id to retrieve.')
+ parser.add_argument('outfile', nargs='?', default=None,
+ help='The file to write the document to',
+ type=argparse.FileType('wb'))
+
+ def run(self, database, doc_id, outfile):
+ if outfile is None:
+ outfile = self.stdout
+ try:
+ db = self._open(database, create=False)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ doc = db.get_doc(doc_id)
+ if doc is None:
+ self.stderr.write('Document not found (id: %s)\n' % (doc_id,))
+ return 1 # failed
+ if doc.is_tombstone():
+ outfile.write('[document deleted]\n')
+ else:
+ outfile.write(doc.get_json() + '\n')
+ self.stderr.write('rev: %s\n' % (doc.rev,))
+ if doc.has_conflicts:
+ self.stderr.write("Document has conflicts.\n")
+
+client_commands.register(CmdGet)
+
+
+class CmdGetDocConflicts(OneDbCmd):
+ """Get the conflicts from a document"""
+
+ name = 'get-doc-conflicts'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local database to query',
+ metavar='database-path')
+ parser.add_argument('doc_id', help='The document id to retrieve.')
+
+ def run(self, database, doc_id):
+ try:
+ db = self._open(database, False)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ conflicts = db.get_doc_conflicts(doc_id)
+ if not conflicts:
+ if db.get_doc(doc_id) is None:
+ self.stderr.write("Document does not exist.\n")
+ return 1
+ self.stdout.write("[")
+ for i, doc in enumerate(conflicts):
+ if i:
+ self.stdout.write(",")
+ self.stdout.write(
+ json.dumps(dict(rev=doc.rev, content=doc.content), indent=4))
+ self.stdout.write("]\n")
+
+client_commands.register(CmdGetDocConflicts)
+
+
+class CmdInitDB(OneDbCmd):
+ """Create a new database"""
+
+ name = 'init-db'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to create',
+ metavar='database-path-or-url')
+ parser.add_argument('--replica-uid', default=None,
+ help='The unique identifier for this database (not for remote)')
+
+ def run(self, database, replica_uid):
+ db = self._open(database, create=True)
+ if replica_uid is not None:
+ db._set_replica_uid(replica_uid)
+
+client_commands.register(CmdInitDB)
+
+
+class CmdPut(OneDbCmd):
+ """Add a document to the database"""
+
+ name = 'put'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to update',
+ metavar='database-path-or-url'),
+ parser.add_argument('doc_id', help='The document id to retrieve')
+ parser.add_argument('doc_rev',
+ help='The revision of the document (which is being superseded.)')
+ parser.add_argument('infile', nargs='?', default=None,
+ help='The filename of the document that will be used for content',
+ type=argparse.FileType('rb'))
+
+ def run(self, database, doc_id, doc_rev, infile):
+ if infile is None:
+ infile = self.stdin
+ try:
+ db = self._open(database, create=False)
+ doc = Document(doc_id, doc_rev, infile.read())
+ doc_rev = db.put_doc(doc)
+ self.stderr.write('rev: %s\n' % (doc_rev,))
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ except errors.RevisionConflict:
+ if db.get_doc(doc_id) is None:
+ self.stderr.write("Document does not exist.\n")
+ else:
+ self.stderr.write("Given revision is not current.\n")
+ except errors.ConflictedDoc:
+ self.stderr.write(
+ "Document has conflicts.\n"
+ "Inspect with get-doc-conflicts, then resolve.\n")
+ else:
+ return
+ return 1
+
+client_commands.register(CmdPut)
+
+
+class CmdResolve(OneDbCmd):
+ """Resolve a conflicted document"""
+
+ name = 'resolve-doc'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to update',
+ metavar='database-path-or-url'),
+ parser.add_argument('doc_id', help='The conflicted document id')
+ parser.add_argument('doc_revs', metavar="doc-rev", nargs="+",
+ help='The revisions that the new content supersedes')
+ parser.add_argument('--infile', nargs='?', default=None,
+ help='The filename of the document that will be used for content',
+ type=argparse.FileType('rb'))
+
+ def run(self, database, doc_id, doc_revs, infile):
+ if infile is None:
+ infile = self.stdin
+ try:
+ db = self._open(database, create=False)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ doc = db.get_doc(doc_id)
+ if doc is None:
+ self.stderr.write("Document does not exist.\n")
+ return 1
+ doc.set_json(infile.read())
+ db.resolve_doc(doc, doc_revs)
+ self.stderr.write("rev: %s\n" % db.get_doc(doc_id).rev)
+ if doc.has_conflicts:
+ self.stderr.write("Document still has conflicts.\n")
+
+client_commands.register(CmdResolve)
+
+
+class CmdSync(command.Command):
+ """Synchronize two databases"""
+
+ name = 'sync'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('source', help='database to sync from')
+ parser.add_argument('target', help='database to sync to')
+
+ def _open_target(self, target):
+ if target.startswith(('http://', 'https://')):
+ st = http_target.HTTPSyncTarget.connect(target)
+ set_oauth_credentials(st)
+ else:
+ db = u1db_open(target, create=True)
+ st = db.get_sync_target()
+ return st
+
+ def run(self, source, target):
+ """Start a Sync request."""
+ source_db = u1db_open(source, create=False)
+ st = self._open_target(target)
+ syncer = sync.Synchronizer(source_db, st)
+ syncer.sync()
+ source_db.close()
+
+client_commands.register(CmdSync)
+
+
+class CmdCreateIndex(OneDbCmd):
+ """Create an index"""
+
+ name = "create-index"
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database', help='The local database to update',
+ metavar='database-path')
+ parser.add_argument('index', help='the name of the index')
+ parser.add_argument('expression', help='an index expression',
+ nargs='+')
+
+ def run(self, database, index, expression):
+ try:
+ db = self._open(database, create=False)
+ db.create_index(index, *expression)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ except errors.IndexNameTakenError:
+ self.stderr.write("There is already a different index named %r.\n"
+ % (index,))
+ return 1
+ except errors.IndexDefinitionParseError:
+ self.stderr.write("Bad index expression.\n")
+ return 1
+
+client_commands.register(CmdCreateIndex)
+
+
+class CmdListIndexes(OneDbCmd):
+ """List existing indexes"""
+
+ name = "list-indexes"
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database', help='The local database to query',
+ metavar='database-path')
+
+ def run(self, database):
+ try:
+ db = self._open(database, create=False)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ for (index, expression) in db.list_indexes():
+ self.stdout.write("%s: %s\n" % (index, ", ".join(expression)))
+
+client_commands.register(CmdListIndexes)
+
+
+class CmdDeleteIndex(OneDbCmd):
+ """Delete an index"""
+
+ name = "delete-index"
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database', help='The local database to update',
+ metavar='database-path')
+ parser.add_argument('index', help='the name of the index')
+
+ def run(self, database, index):
+ try:
+ db = self._open(database, create=False)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ db.delete_index(index)
+
+client_commands.register(CmdDeleteIndex)
+
+
+class CmdGetIndexKeys(OneDbCmd):
+ """Get the index's keys"""
+
+ name = "get-index-keys"
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database', help='The local database to query',
+ metavar='database-path')
+ parser.add_argument('index', help='the name of the index')
+
+ def run(self, database, index):
+ try:
+ db = self._open(database, create=False)
+ for key in db.get_index_keys(index):
+ self.stdout.write("%s\n" % (", ".join(
+ [i.encode('utf-8') for i in key],)))
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ except errors.IndexDoesNotExist:
+ self.stderr.write("Index does not exist.\n")
+ else:
+ return
+ return 1
+
+client_commands.register(CmdGetIndexKeys)
+
+
+class CmdGetFromIndex(OneDbCmd):
+ """Find documents by searching an index"""
+
+ name = "get-from-index"
+ argv = None
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database', help='The local database to query',
+ metavar='database-path')
+ parser.add_argument('index', help='the name of the index')
+ parser.add_argument('values', metavar="value",
+ help='the value to look up (one per index column)',
+ nargs="+")
+
+ def run(self, database, index, values):
+ try:
+ db = self._open(database, create=False)
+ docs = db.get_from_index(index, *values)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ except errors.IndexDoesNotExist:
+ self.stderr.write("Index does not exist.\n")
+ except errors.InvalidValueForIndex:
+ index_def = db._get_index_definition(index)
+ len_diff = len(index_def) - len(values)
+ if len_diff == 0:
+ # can't happen (HAH)
+ raise
+ argv = self.argv if self.argv is not None else sys.argv
+ self.stderr.write(
+ "Invalid query: "
+ "index %r requires %d query expression%s%s.\n"
+ "For example, the following would be valid:\n"
+ " %s %s %r %r %s\n"
+ % (index,
+ len(index_def),
+ "s" if len(index_def) > 1 else "",
+ ", not %d" % len(values) if len(values) else "",
+ argv[0], argv[1], database, index,
+ " ".join(map(repr,
+ values[:len(index_def)]
+ + ["*" for i in range(len_diff)])),
+ ))
+ except errors.InvalidGlobbing:
+ argv = self.argv if self.argv is not None else sys.argv
+ fixed = []
+ for (i, v) in enumerate(values):
+ fixed.append(v)
+ if v.endswith('*'):
+ break
+ # values has at least one element, so i is defined
+ fixed.extend('*' * (len(values) - i - 1))
+ self.stderr.write(
+ "Invalid query: a star can only be followed by stars.\n"
+ "For example, the following would be valid:\n"
+ " %s %s %r %r %s\n"
+ % (argv[0], argv[1], database, index,
+ " ".join(map(repr, fixed))))
+
+ else:
+ self.stdout.write("[")
+ for i, doc in enumerate(docs):
+ if i:
+ self.stdout.write(",")
+ self.stdout.write(
+ json.dumps(
+ dict(id=doc.doc_id, rev=doc.rev, content=doc.content),
+ indent=4))
+ self.stdout.write("]\n")
+ return
+ return 1
+
+client_commands.register(CmdGetFromIndex)
+
+
+def main(args):
+ return client_commands.run_argv(args, sys.stdin, sys.stdout, sys.stderr)
diff --git a/u1db/commandline/command.py b/u1db/commandline/command.py
new file mode 100644
index 00000000..eace0560
--- /dev/null
+++ b/u1db/commandline/command.py
@@ -0,0 +1,80 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Command infrastructure for u1db"""
+
+import argparse
+import inspect
+
+
+class CommandGroup(object):
+ """A collection of commands."""
+
+ def __init__(self, description=None):
+ self.commands = {}
+ self.description = description
+
+ def register(self, cmd):
+ """Register a new command to be incorporated with this group."""
+ self.commands[cmd.name] = cmd
+
+ def make_argparser(self):
+ """Create an argparse.ArgumentParser"""
+ parser = argparse.ArgumentParser(description=self.description)
+ subs = parser.add_subparsers(title='commands')
+ for name, cmd in sorted(self.commands.iteritems()):
+ sub = subs.add_parser(name, help=cmd.__doc__)
+ sub.set_defaults(subcommand=cmd)
+ cmd._populate_subparser(sub)
+ return parser
+
+ def run_argv(self, argv, stdin, stdout, stderr):
+ """Run a command, from a sys.argv[1:] style input."""
+ parser = self.make_argparser()
+ args = parser.parse_args(argv)
+ cmd = args.subcommand(stdin, stdout, stderr)
+ params, _, _, _ = inspect.getargspec(cmd.run)
+ vals = []
+ for param in params[1:]:
+ vals.append(getattr(args, param))
+ return cmd.run(*vals)
+
+
+class Command(object):
+ """Definition of a Command that can be run.
+
+ :cvar name: The name of the command, so that you can run
+ 'u1db-client <name>'.
+ """
+
+ name = None
+
+ def __init__(self, stdin, stdout, stderr):
+ self.stdin = stdin
+ self.stdout = stdout
+ self.stderr = stderr
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ """Child classes should override this to provide their arguments."""
+ raise NotImplementedError(cls._populate_subparser)
+
+ def run(self, *args):
+ """This is where the magic happens.
+
+ Subclasses should implement this, requesting their specific arguments.
+ """
+ raise NotImplementedError(self.run)
diff --git a/u1db/commandline/serve.py b/u1db/commandline/serve.py
new file mode 100644
index 00000000..0bb0e641
--- /dev/null
+++ b/u1db/commandline/serve.py
@@ -0,0 +1,34 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Build server for u1db-serve."""
+
+from paste import httpserver
+
+from u1db.remote import (
+ http_app,
+ server_state,
+ )
+
+
+def make_server(host, port, working_dir):
+ """Make a server on host and port exposing dbs living in working_dir."""
+ state = server_state.ServerState()
+ state.set_workingdir(working_dir)
+ application = http_app.HTTPApp(state)
+ server = httpserver.WSGIServer(application, (host, port),
+ httpserver.WSGIHandler)
+ return server
diff --git a/u1db/errors.py b/u1db/errors.py
new file mode 100644
index 00000000..967c7c38
--- /dev/null
+++ b/u1db/errors.py
@@ -0,0 +1,189 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""A list of errors that u1db can raise."""
+
+
+class U1DBError(Exception):
+ """Generic base class for U1DB errors."""
+
+ # description/tag for identifying the error during transmission (http,...)
+ wire_description = "error"
+
+ def __init__(self, message=None):
+ self.message = message
+
+
+class RevisionConflict(U1DBError):
+ """The document revisions supplied does not match the current version."""
+
+ wire_description = "revision conflict"
+
+
+class InvalidJSON(U1DBError):
+ """Content was not valid json."""
+
+
+class InvalidContent(U1DBError):
+ """Content was not a python dictionary."""
+
+
+class InvalidDocId(U1DBError):
+ """A document was requested with an invalid document identifier."""
+
+ wire_description = "invalid document id"
+
+
+class MissingDocIds(U1DBError):
+ """Needs document ids."""
+
+ wire_description = "missing document ids"
+
+
+class DocumentTooBig(U1DBError):
+ """Document exceeds the maximum document size for this database."""
+
+ wire_description = "document too big"
+
+
+class UserQuotaExceeded(U1DBError):
+ """Document exceeds the maximum document size for this database."""
+
+ wire_description = "user quota exceeded"
+
+
+class SubscriptionNeeded(U1DBError):
+ """User needs a subscription to be able to use this replica.."""
+
+ wire_description = "user needs subscription"
+
+
+class InvalidTransactionId(U1DBError):
+ """Invalid transaction for generation."""
+
+ wire_description = "invalid transaction id"
+
+
+class InvalidGeneration(U1DBError):
+ """Generation was previously synced with a different transaction id."""
+
+ wire_description = "invalid generation"
+
+
+class ConflictedDoc(U1DBError):
+ """The document is conflicted, you must call resolve before put()"""
+
+
+class InvalidValueForIndex(U1DBError):
+ """The values supplied does not match the index definition."""
+
+
+class InvalidGlobbing(U1DBError):
+ """Raised if wildcard matches are not strictly at the tail of the request.
+ """
+
+
+class DocumentDoesNotExist(U1DBError):
+ """The document does not exist."""
+
+ wire_description = "document does not exist"
+
+
+class DocumentAlreadyDeleted(U1DBError):
+ """The document was already deleted."""
+
+ wire_description = "document already deleted"
+
+
+class DatabaseDoesNotExist(U1DBError):
+ """The database does not exist."""
+
+ wire_description = "database does not exist"
+
+
+class IndexNameTakenError(U1DBError):
+ """The given index name is already taken."""
+
+
+class IndexDefinitionParseError(U1DBError):
+ """The index definition cannot be parsed."""
+
+
+class IndexDoesNotExist(U1DBError):
+ """No index of that name exists."""
+
+
+class Unauthorized(U1DBError):
+ """Request wasn't authorized properly."""
+
+ wire_description = "unauthorized"
+
+
+class HTTPError(U1DBError):
+ """Unspecific HTTP errror."""
+
+ wire_description = None
+
+ def __init__(self, status, message=None, headers={}):
+ self.status = status
+ self.message = message
+ self.headers = headers
+
+ def __str__(self):
+ if not self.message:
+ return "HTTPError(%d)" % self.status
+ else:
+ return "HTTPError(%d, %r)" % (self.status, self.message)
+
+
+class Unavailable(HTTPError):
+ """Server not available not serve request."""
+
+ wire_description = "unavailable"
+
+ def __init__(self, message=None, headers={}):
+ super(Unavailable, self).__init__(503, message, headers)
+
+ def __str__(self):
+ if not self.message:
+ return "Unavailable()"
+ else:
+ return "Unavailable(%r)" % self.message
+
+
+class BrokenSyncStream(U1DBError):
+ """Unterminated or otherwise broken sync exchange stream."""
+
+ wire_description = None
+
+
+class UnknownAuthMethod(U1DBError):
+ """Unknown auhorization method."""
+
+ wire_description = None
+
+
+# mapping wire (transimission) descriptions/tags for errors to the exceptions
+wire_description_to_exc = dict(
+ (x.wire_description, x) for x in globals().values()
+ if getattr(x, 'wire_description', None) not in (None, "error")
+)
+wire_description_to_exc["error"] = U1DBError
+
+
+#
+# wire error descriptions not corresponding to an exception
+DOCUMENT_DELETED = "document deleted"
diff --git a/u1db/query_parser.py b/u1db/query_parser.py
new file mode 100644
index 00000000..f564821f
--- /dev/null
+++ b/u1db/query_parser.py
@@ -0,0 +1,370 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Code for parsing Index definitions."""
+
+import re
+from u1db import (
+ errors,
+ )
+
+
+class Getter(object):
+ """Get values from a document based on a specification."""
+
+ def get(self, raw_doc):
+ """Get a value from the document.
+
+ :param raw_doc: a python dictionary to get the value from.
+ :return: A list of values that match the description.
+ """
+ raise NotImplementedError(self.get)
+
+
+class StaticGetter(Getter):
+ """A getter that returns a defined value (independent of the doc)."""
+
+ def __init__(self, value):
+ """Create a StaticGetter.
+
+ :param value: the value to return when get is called.
+ """
+ if value is None:
+ self.value = []
+ elif isinstance(value, list):
+ self.value = value
+ else:
+ self.value = [value]
+
+ def get(self, raw_doc):
+ return self.value
+
+
+def extract_field(raw_doc, subfields, index=0):
+ if not isinstance(raw_doc, dict):
+ return []
+ val = raw_doc.get(subfields[index])
+ if val is None:
+ return []
+ if index < len(subfields) - 1:
+ if isinstance(val, list):
+ results = []
+ for item in val:
+ results.extend(extract_field(item, subfields, index + 1))
+ return results
+ if isinstance(val, dict):
+ return extract_field(val, subfields, index + 1)
+ return []
+ if isinstance(val, dict):
+ return []
+ if isinstance(val, list):
+ # Strip anything in the list that isn't a simple type
+ return [v for v in val if not isinstance(v, (dict, list))]
+ return [val]
+
+
+class ExtractField(Getter):
+ """Extract a field from the document."""
+
+ def __init__(self, field):
+ """Create an ExtractField object.
+
+ When a document is passed to get() this will return a value
+ from the document based on the field specifier passed to
+ the constructor.
+
+ None will be returned if the field is nonexistant, or refers to an
+ object, rather than a simple type or list of simple types.
+
+ :param field: a specifier for the field to return.
+ This is either a field name, or a dotted field name.
+ """
+ self.field = field.split('.')
+
+ def get(self, raw_doc):
+ return extract_field(raw_doc, self.field)
+
+
+class Transformation(Getter):
+ """A transformation on a value from another Getter."""
+
+ name = None
+ arity = 1
+ args = ['expression']
+
+ def __init__(self, inner):
+ """Create a transformation.
+
+ :param inner: the argument(s) to the transformation.
+ """
+ self.inner = inner
+
+ def get(self, raw_doc):
+ inner_values = self.inner.get(raw_doc)
+ assert isinstance(inner_values, list),\
+ 'get() should always return a list'
+ return self.transform(inner_values)
+
+ def transform(self, values):
+ """Transform the values.
+
+ This should be implemented by subclasses to transform the
+ value when get() is called.
+
+ :param values: the values from the other Getter
+ :return: the transformed values.
+ """
+ raise NotImplementedError(self.transform)
+
+
+class Lower(Transformation):
+ """Lowercase a string.
+
+ This transformation will return None for non-string inputs. However,
+ it will lowercase any strings in a list, dropping any elements
+ that are not strings.
+ """
+
+ name = "lower"
+
+ def _can_transform(self, val):
+ return isinstance(val, basestring)
+
+ def transform(self, values):
+ if not values:
+ return []
+ return [val.lower() for val in values if self._can_transform(val)]
+
+
+class Number(Transformation):
+ """Convert an integer to a zero padded string.
+
+ This transformation will return None for non-integer inputs. However, it
+ will transform any integers in a list, dropping any elements that are not
+ integers.
+ """
+
+ name = 'number'
+ arity = 2
+ args = ['expression', int]
+
+ def __init__(self, inner, number):
+ super(Number, self).__init__(inner)
+ self.padding = "%%0%sd" % number
+
+ def _can_transform(self, val):
+ return isinstance(val, int) and not isinstance(val, bool)
+
+ def transform(self, values):
+ """Transform any integers in values into zero padded strings."""
+ if not values:
+ return []
+ return [self.padding % (v,) for v in values if self._can_transform(v)]
+
+
+class Bool(Transformation):
+ """Convert bool to string."""
+
+ name = "bool"
+ args = ['expression']
+
+ def _can_transform(self, val):
+ return isinstance(val, bool)
+
+ def transform(self, values):
+ """Transform any booleans in values into strings."""
+ if not values:
+ return []
+ return [('1' if v else '0') for v in values if self._can_transform(v)]
+
+
+class SplitWords(Transformation):
+ """Split a string on whitespace.
+
+ This Getter will return [] for non-string inputs. It will however
+ split any strings in an input list, discarding any elements that
+ are not strings.
+ """
+
+ name = "split_words"
+
+ def _can_transform(self, val):
+ return isinstance(val, basestring)
+
+ def transform(self, values):
+ if not values:
+ return []
+ result = set()
+ for value in values:
+ if self._can_transform(value):
+ for word in value.split():
+ result.add(word)
+ return list(result)
+
+
+class Combine(Transformation):
+ """Combine multiple expressions into a single index."""
+
+ name = "combine"
+ # variable number of args
+ arity = -1
+
+ def __init__(self, *inner):
+ super(Combine, self).__init__(inner)
+
+ def get(self, raw_doc):
+ inner_values = []
+ for inner in self.inner:
+ inner_values.extend(inner.get(raw_doc))
+ return self.transform(inner_values)
+
+ def transform(self, values):
+ return values
+
+
+class IsNull(Transformation):
+ """Indicate whether the input is None.
+
+ This Getter returns a bool indicating whether the input is nil.
+ """
+
+ name = "is_null"
+
+ def transform(self, values):
+ return [len(values) == 0]
+
+
+def check_fieldname(fieldname):
+ if fieldname.endswith('.'):
+ raise errors.IndexDefinitionParseError(
+ "Fieldname cannot end in '.':%s^" % (fieldname,))
+
+
+class Parser(object):
+ """Parse an index expression into a sequence of transformations."""
+
+ _transformations = {}
+ _delimiters = re.compile("\(|\)|,")
+
+ def __init__(self):
+ self._tokens = []
+
+ def _set_expression(self, expression):
+ self._open_parens = 0
+ self._tokens = []
+ expression = expression.strip()
+ while expression:
+ delimiter = self._delimiters.search(expression)
+ if delimiter:
+ idx = delimiter.start()
+ if idx == 0:
+ result, expression = (expression[:1], expression[1:])
+ self._tokens.append(result)
+ else:
+ result, expression = (expression[:idx], expression[idx:])
+ result = result.strip()
+ if result:
+ self._tokens.append(result)
+ else:
+ expression = expression.strip()
+ if expression:
+ self._tokens.append(expression)
+ expression = None
+
+ def _get_token(self):
+ if self._tokens:
+ return self._tokens.pop(0)
+
+ def _peek_token(self):
+ if self._tokens:
+ return self._tokens[0]
+
+ @staticmethod
+ def _to_getter(term):
+ if isinstance(term, Getter):
+ return term
+ check_fieldname(term)
+ return ExtractField(term)
+
+ def _parse_op(self, op_name):
+ self._get_token() # '('
+ op = self._transformations.get(op_name, None)
+ if op is None:
+ raise errors.IndexDefinitionParseError(
+ "Unknown operation: %s" % op_name)
+ args = []
+ while True:
+ args.append(self._parse_term())
+ sep = self._get_token()
+ if sep == ')':
+ break
+ if sep != ',':
+ raise errors.IndexDefinitionParseError(
+ "Unexpected token '%s' in parentheses." % (sep,))
+ parsed = []
+ for i, arg in enumerate(args):
+ arg_type = op.args[i % len(op.args)]
+ if arg_type == 'expression':
+ inner = self._to_getter(arg)
+ else:
+ try:
+ inner = arg_type(arg)
+ except ValueError, e:
+ raise errors.IndexDefinitionParseError(
+ "Invalid value %r for argument type %r "
+ "(%r)." % (arg, arg_type, e))
+ parsed.append(inner)
+ return op(*parsed)
+
+ def _parse_term(self):
+ term = self._get_token()
+ if term is None:
+ raise errors.IndexDefinitionParseError(
+ "Unexpected end of index definition.")
+ if term in (',', ')', '('):
+ raise errors.IndexDefinitionParseError(
+ "Unexpected token '%s' at start of expression." % (term,))
+ next_token = self._peek_token()
+ if next_token == '(':
+ return self._parse_op(term)
+ return term
+
+ def parse(self, expression):
+ self._set_expression(expression)
+ term = self._to_getter(self._parse_term())
+ if self._peek_token():
+ raise errors.IndexDefinitionParseError(
+ "Unexpected token '%s' after end of expression."
+ % (self._peek_token(),))
+ return term
+
+ def parse_all(self, fields):
+ return [self.parse(field) for field in fields]
+
+ @classmethod
+ def register_transormation(cls, transform):
+ assert transform.name not in cls._transformations, (
+ "Transform %s already registered for %s"
+ % (transform.name, cls._transformations[transform.name]))
+ cls._transformations[transform.name] = transform
+
+
+Parser.register_transormation(SplitWords)
+Parser.register_transormation(Lower)
+Parser.register_transormation(Number)
+Parser.register_transormation(Bool)
+Parser.register_transormation(IsNull)
+Parser.register_transormation(Combine)
diff --git a/u1db/remote/__init__.py b/u1db/remote/__init__.py
new file mode 100644
index 00000000..3f32e381
--- /dev/null
+++ b/u1db/remote/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
diff --git a/u1db/remote/basic_auth_middleware.py b/u1db/remote/basic_auth_middleware.py
new file mode 100644
index 00000000..a2cbff62
--- /dev/null
+++ b/u1db/remote/basic_auth_middleware.py
@@ -0,0 +1,68 @@
+# Copyright 2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+"""U1DB Basic Auth authorisation WSGI middleware."""
+import httplib
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from wsgiref.util import shift_path_info
+
+
+class Unauthorized(Exception):
+ """User authorization failed."""
+
+
+class BasicAuthMiddleware(object):
+ """U1DB Basic Auth Authorisation WSGI middleware."""
+
+ def __init__(self, app, prefix):
+ self.app = app
+ self.prefix = prefix
+
+ def _error(self, start_response, status, description, message=None):
+ start_response("%d %s" % (status, httplib.responses[status]),
+ [('content-type', 'application/json')])
+ err = {"error": description}
+ if message:
+ err['message'] = message
+ return [json.dumps(err)]
+
+ def __call__(self, environ, start_response):
+ if self.prefix and not environ['PATH_INFO'].startswith(self.prefix):
+ return self._error(start_response, 400, "bad request")
+ auth = environ.get('HTTP_AUTHORIZATION')
+ if not auth:
+ return self._error(start_response, 401, "unauthorized",
+ "Missing Basic Authentication.")
+ scheme, encoded = auth.split(None, 1)
+ if scheme.lower() != 'basic':
+ return self._error(
+ start_response, 401, "unauthorized",
+ "Missing Basic Authentication")
+ user, password = encoded.decode('base64').split(':', 1)
+ try:
+ self.verify_user(environ, user, password)
+ except Unauthorized:
+ return self._error(
+ start_response, 401, "unauthorized",
+ "Incorrect password or login.")
+ del environ['HTTP_AUTHORIZATION']
+ shift_path_info(environ)
+ return self.app(environ, start_response)
+
+ def verify_user(self, environ, username, password):
+ raise NotImplementedError(self.verify_user)
diff --git a/u1db/remote/http_app.py b/u1db/remote/http_app.py
new file mode 100644
index 00000000..3d7d4248
--- /dev/null
+++ b/u1db/remote/http_app.py
@@ -0,0 +1,629 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""HTTP Application exposing U1DB."""
+
+import functools
+import httplib
+import inspect
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import sys
+import urlparse
+
+import routes.mapper
+
+from u1db import (
+ __version__ as _u1db_version,
+ DBNAME_CONSTRAINTS,
+ Document,
+ errors,
+ sync,
+ )
+from u1db.remote import (
+ http_errors,
+ utils,
+ )
+
+
+def parse_bool(expression):
+ """Parse boolean querystring parameter."""
+ if expression == 'true':
+ return True
+ return False
+
+
+def parse_list(expression):
+ if expression is None:
+ return []
+ return [t.strip() for t in expression.split(',')]
+
+
+def none_or_str(expression):
+ if expression is None:
+ return None
+ return str(expression)
+
+
+class BadRequest(Exception):
+ """Bad request."""
+
+
+class _FencedReader(object):
+ """Read and get lines from a file but not past a given length."""
+
+ MAXCHUNK = 8192
+
+ def __init__(self, rfile, total, max_entry_size):
+ self.rfile = rfile
+ self.remaining = total
+ self.max_entry_size = max_entry_size
+ self._kept = None
+
+ def read_chunk(self, atmost):
+ if self._kept is not None:
+ # ignore atmost, kept data should be a subchunk anyway
+ kept, self._kept = self._kept, None
+ return kept
+ if self.remaining == 0:
+ return ''
+ data = self.rfile.read(min(self.remaining, atmost))
+ self.remaining -= len(data)
+ return data
+
+ def getline(self):
+ line_parts = []
+ size = 0
+ while True:
+ chunk = self.read_chunk(self.MAXCHUNK)
+ if chunk == '':
+ break
+ nl = chunk.find("\n")
+ if nl != -1:
+ size += nl + 1
+ if size > self.max_entry_size:
+ raise BadRequest
+ line_parts.append(chunk[:nl + 1])
+ rest = chunk[nl + 1:]
+ self._kept = rest or None
+ break
+ else:
+ size += len(chunk)
+ if size > self.max_entry_size:
+ raise BadRequest
+ line_parts.append(chunk)
+ return ''.join(line_parts)
+
+
+def http_method(**control):
+ """Decoration for handling of query arguments and content for a HTTP
+ method.
+
+ args and content here are the query arguments and body of the incoming
+ HTTP requests.
+
+ Match query arguments to python method arguments:
+ w = http_method()(f)
+ w(self, args, content) => args["content"]=content;
+ f(self, **args)
+
+ JSON deserialize content to arguments:
+ w = http_method(content_as_args=True,...)(f)
+ w(self, args, content) => args.update(json.loads(content));
+ f(self, **args)
+
+ Support conversions (e.g int):
+ w = http_method(Arg=Conv,...)(f)
+ w(self, args, content) => args["Arg"]=Conv(args["Arg"]);
+ f(self, **args)
+
+ Enforce no use of query arguments:
+ w = http_method(no_query=True,...)(f)
+ w(self, args, content) raises BadRequest if args is not empty
+
+ Argument mismatches, deserialisation failures produce BadRequest.
+ """
+ content_as_args = control.pop('content_as_args', False)
+ no_query = control.pop('no_query', False)
+ conversions = control.items()
+
+ def wrap(f):
+ argspec = inspect.getargspec(f)
+ assert argspec.args[0] == "self"
+ nargs = len(argspec.args)
+ ndefaults = len(argspec.defaults or ())
+ required_args = set(argspec.args[1:nargs - ndefaults])
+ all_args = set(argspec.args)
+
+ @functools.wraps(f)
+ def wrapper(self, args, content):
+ if no_query and args:
+ raise BadRequest()
+ if content is not None:
+ if content_as_args:
+ try:
+ args.update(json.loads(content))
+ except ValueError:
+ raise BadRequest()
+ else:
+ args["content"] = content
+ if not (required_args <= set(args) <= all_args):
+ raise BadRequest("Missing required arguments.")
+ for name, conv in conversions:
+ if name not in args:
+ continue
+ try:
+ args[name] = conv(args[name])
+ except ValueError:
+ raise BadRequest()
+ return f(self, **args)
+
+ return wrapper
+
+ return wrap
+
+
+class URLToResource(object):
+ """Mappings from URLs to resources."""
+
+ def __init__(self):
+ self._map = routes.mapper.Mapper(controller_scan=None)
+
+ def register(self, resource_cls):
+ # register
+ self._map.connect(None, resource_cls.url_pattern,
+ resource_cls=resource_cls,
+ requirements={"dbname": DBNAME_CONSTRAINTS})
+ self._map.create_regs()
+ return resource_cls
+
+ def match(self, path):
+ params = self._map.match(path)
+ if params is None:
+ return None, None
+ resource_cls = params.pop('resource_cls')
+ return resource_cls, params
+
+url_to_resource = URLToResource()
+
+
+@url_to_resource.register
+class GlobalResource(object):
+ """Global (root) resource."""
+
+ url_pattern = "/"
+
+ def __init__(self, state, responder):
+ self.responder = responder
+
+ @http_method()
+ def get(self):
+ self.responder.send_response_json(version=_u1db_version)
+
+
+@url_to_resource.register
+class DatabaseResource(object):
+ """Database resource."""
+
+ url_pattern = "/{dbname}"
+
+ def __init__(self, dbname, state, responder):
+ self.dbname = dbname
+ self.state = state
+ self.responder = responder
+
+ @http_method()
+ def get(self):
+ self.state.check_database(self.dbname)
+ self.responder.send_response_json(200)
+
+ @http_method(content_as_args=True)
+ def put(self):
+ self.state.ensure_database(self.dbname)
+ self.responder.send_response_json(200, ok=True)
+
+ @http_method()
+ def delete(self):
+ self.state.delete_database(self.dbname)
+ self.responder.send_response_json(200, ok=True)
+
+
+@url_to_resource.register
+class DocsResource(object):
+ """Documents resource."""
+
+ url_pattern = "/{dbname}/docs"
+
+ def __init__(self, dbname, state, responder):
+ self.responder = responder
+ self.db = state.open_database(dbname)
+
+ @http_method(doc_ids=parse_list, check_for_conflicts=parse_bool,
+ include_deleted=parse_bool)
+ def get(self, doc_ids=None, check_for_conflicts=True,
+ include_deleted=False):
+ if doc_ids is None:
+ raise errors.MissingDocIds
+ docs = self.db.get_docs(doc_ids, include_deleted=include_deleted)
+ self.responder.content_type = 'application/json'
+ self.responder.start_response(200)
+ self.responder.start_stream(),
+ for doc in docs:
+ entry = dict(
+ doc_id=doc.doc_id, doc_rev=doc.rev, content=doc.get_json(),
+ has_conflicts=doc.has_conflicts)
+ self.responder.stream_entry(entry)
+ self.responder.end_stream()
+ self.responder.finish_response()
+
+
+@url_to_resource.register
+class DocResource(object):
+ """Document resource."""
+
+ url_pattern = "/{dbname}/doc/{id:.*}"
+
+ def __init__(self, dbname, id, state, responder):
+ self.id = id
+ self.responder = responder
+ self.db = state.open_database(dbname)
+
+ @http_method(old_rev=str)
+ def put(self, content, old_rev=None):
+ doc = Document(self.id, old_rev, content)
+ doc_rev = self.db.put_doc(doc)
+ if old_rev is None:
+ status = 201 # created
+ else:
+ status = 200
+ self.responder.send_response_json(status, rev=doc_rev)
+
+ @http_method(old_rev=str)
+ def delete(self, old_rev=None):
+ doc = Document(self.id, old_rev, None)
+ self.db.delete_doc(doc)
+ self.responder.send_response_json(200, rev=doc.rev)
+
+ @http_method(include_deleted=parse_bool)
+ def get(self, include_deleted=False):
+ doc = self.db.get_doc(self.id, include_deleted=include_deleted)
+ if doc is None:
+ wire_descr = errors.DocumentDoesNotExist.wire_description
+ self.responder.send_response_json(
+ http_errors.wire_description_to_status[wire_descr],
+ error=wire_descr,
+ headers={
+ 'x-u1db-rev': '',
+ 'x-u1db-has-conflicts': 'false'
+ })
+ return
+ headers = {
+ 'x-u1db-rev': doc.rev,
+ 'x-u1db-has-conflicts': json.dumps(doc.has_conflicts)
+ }
+ if doc.is_tombstone():
+ self.responder.send_response_json(
+ http_errors.wire_description_to_status[
+ errors.DOCUMENT_DELETED],
+ error=errors.DOCUMENT_DELETED,
+ headers=headers)
+ else:
+ self.responder.send_response_content(
+ doc.get_json(), headers=headers)
+
+
+@url_to_resource.register
+class SyncResource(object):
+ """Sync endpoint resource."""
+
+ # maximum allowed request body size
+ max_request_size = 15 * 1024 * 1024 # 15Mb
+ # maximum allowed entry/line size in request body
+ max_entry_size = 10 * 1024 * 1024 # 10Mb
+
+ url_pattern = "/{dbname}/sync-from/{source_replica_uid}"
+
+ # pluggable
+ sync_exchange_class = sync.SyncExchange
+
+ def __init__(self, dbname, source_replica_uid, state, responder):
+ self.source_replica_uid = source_replica_uid
+ self.responder = responder
+ self.state = state
+ self.dbname = dbname
+ self.replica_uid = None
+
+ def get_target(self):
+ return self.state.open_database(self.dbname).get_sync_target()
+
+ @http_method()
+ def get(self):
+ result = self.get_target().get_sync_info(self.source_replica_uid)
+ self.responder.send_response_json(
+ target_replica_uid=result[0], target_replica_generation=result[1],
+ target_replica_transaction_id=result[2],
+ source_replica_uid=self.source_replica_uid,
+ source_replica_generation=result[3],
+ source_transaction_id=result[4])
+
+ @http_method(generation=int,
+ content_as_args=True, no_query=True)
+ def put(self, generation, transaction_id):
+ self.get_target().record_sync_info(self.source_replica_uid,
+ generation,
+ transaction_id)
+ self.responder.send_response_json(ok=True)
+
+ # Implements the same logic as LocalSyncTarget.sync_exchange
+
+ @http_method(last_known_generation=int, last_known_trans_id=none_or_str,
+ content_as_args=True)
+ def post_args(self, last_known_generation, last_known_trans_id=None,
+ ensure=False):
+ if ensure:
+ db, self.replica_uid = self.state.ensure_database(self.dbname)
+ else:
+ db = self.state.open_database(self.dbname)
+ db.validate_gen_and_trans_id(
+ last_known_generation, last_known_trans_id)
+ self.sync_exch = self.sync_exchange_class(
+ db, self.source_replica_uid, last_known_generation)
+
+ @http_method(content_as_args=True)
+ def post_stream_entry(self, id, rev, content, gen, trans_id):
+ doc = Document(id, rev, content)
+ self.sync_exch.insert_doc_from_source(doc, gen, trans_id)
+
+ def post_end(self):
+
+ def send_doc(doc, gen, trans_id):
+ entry = dict(id=doc.doc_id, rev=doc.rev, content=doc.get_json(),
+ gen=gen, trans_id=trans_id)
+ self.responder.stream_entry(entry)
+
+ new_gen = self.sync_exch.find_changes_to_return()
+ self.responder.content_type = 'application/x-u1db-sync-stream'
+ self.responder.start_response(200)
+ self.responder.start_stream(),
+ header = {"new_generation": new_gen,
+ "new_transaction_id": self.sync_exch.new_trans_id}
+ if self.replica_uid is not None:
+ header['replica_uid'] = self.replica_uid
+ self.responder.stream_entry(header)
+ self.sync_exch.return_docs(send_doc)
+ self.responder.end_stream()
+ self.responder.finish_response()
+
+
+class HTTPResponder(object):
+ """Encode responses from the server back to the client."""
+
+ # a multi document response will put args and documents
+ # each on one line of the response body
+
+ def __init__(self, start_response):
+ self._started = False
+ self._stream_state = -1
+ self._no_initial_obj = True
+ self.sent_response = False
+ self._start_response = start_response
+ self._write = None
+ self.content_type = 'application/json'
+ self.content = []
+
+ def start_response(self, status, obj_dic=None, headers={}):
+ """start sending response with optional first json object."""
+ if self._started:
+ return
+ self._started = True
+ status_text = httplib.responses[status]
+ self._write = self._start_response('%d %s' % (status, status_text),
+ [('content-type', self.content_type),
+ ('cache-control', 'no-cache')] +
+ headers.items())
+ # xxx version in headers
+ if obj_dic is not None:
+ self._no_initial_obj = False
+ self._write(json.dumps(obj_dic) + "\r\n")
+
+ def finish_response(self):
+ """finish sending response."""
+ self.sent_response = True
+
+ def send_response_json(self, status=200, headers={}, **kwargs):
+ """send and finish response with json object body from keyword args."""
+ content = json.dumps(kwargs) + "\r\n"
+ self.send_response_content(content, headers=headers, status=status)
+
+ def send_response_content(self, content, status=200, headers={}):
+ """send and finish response with content"""
+ headers['content-length'] = str(len(content))
+ self.start_response(status, headers=headers)
+ if self._stream_state == 1:
+ self.content = [',\r\n', content]
+ else:
+ self.content = [content]
+ self.finish_response()
+
+ def start_stream(self):
+ "start stream (array) as part of the response."
+ assert self._started and self._no_initial_obj
+ self._stream_state = 0
+ self._write("[")
+
+ def stream_entry(self, entry):
+ "send stream entry as part of the response."
+ assert self._stream_state != -1
+ if self._stream_state == 0:
+ self._stream_state = 1
+ self._write('\r\n')
+ else:
+ self._write(',\r\n')
+ self._write(json.dumps(entry))
+
+ def end_stream(self):
+ "end stream (array)."
+ assert self._stream_state != -1
+ self._write("\r\n]\r\n")
+
+
+class HTTPInvocationByMethodWithBody(object):
+ """Invoke methods on a resource."""
+
+ def __init__(self, resource, environ, parameters):
+ self.resource = resource
+ self.environ = environ
+ self.max_request_size = getattr(
+ resource, 'max_request_size', parameters.max_request_size)
+ self.max_entry_size = getattr(
+ resource, 'max_entry_size', parameters.max_entry_size)
+
+ def _lookup(self, method):
+ try:
+ return getattr(self.resource, method)
+ except AttributeError:
+ raise BadRequest()
+
+ def __call__(self):
+ args = urlparse.parse_qsl(self.environ['QUERY_STRING'],
+ strict_parsing=False)
+ try:
+ args = dict(
+ (k.decode('utf-8'), v.decode('utf-8')) for k, v in args)
+ except ValueError:
+ raise BadRequest()
+ method = self.environ['REQUEST_METHOD'].lower()
+ if method in ('get', 'delete'):
+ meth = self._lookup(method)
+ return meth(args, None)
+ else:
+ # we expect content-length > 0, reconsider if we move
+ # to support chunked enconding
+ try:
+ content_length = int(self.environ['CONTENT_LENGTH'])
+ except (ValueError, KeyError):
+ raise BadRequest
+ if content_length <= 0:
+ raise BadRequest
+ if content_length > self.max_request_size:
+ raise BadRequest
+ reader = _FencedReader(self.environ['wsgi.input'], content_length,
+ self.max_entry_size)
+ content_type = self.environ.get('CONTENT_TYPE')
+ if content_type == 'application/json':
+ meth = self._lookup(method)
+ body = reader.read_chunk(sys.maxint)
+ return meth(args, body)
+ elif content_type == 'application/x-u1db-sync-stream':
+ meth_args = self._lookup('%s_args' % method)
+ meth_entry = self._lookup('%s_stream_entry' % method)
+ meth_end = self._lookup('%s_end' % method)
+ body_getline = reader.getline
+ if body_getline().strip() != '[':
+ raise BadRequest()
+ line = body_getline()
+ line, comma = utils.check_and_strip_comma(line.strip())
+ meth_args(args, line)
+ while True:
+ line = body_getline()
+ entry = line.strip()
+ if entry == ']':
+ break
+ if not entry or not comma: # empty or no prec comma
+ raise BadRequest
+ entry, comma = utils.check_and_strip_comma(entry)
+ meth_entry({}, entry)
+ if comma or body_getline(): # extra comma or data
+ raise BadRequest
+ return meth_end()
+ else:
+ raise BadRequest()
+
+
+class HTTPApp(object):
+
+ # maximum allowed request body size
+ max_request_size = 15 * 1024 * 1024 # 15Mb
+ # maximum allowed entry/line size in request body
+ max_entry_size = 10 * 1024 * 1024 # 10Mb
+
+ def __init__(self, state):
+ self.state = state
+
+ def _lookup_resource(self, environ, responder):
+ resource_cls, params = url_to_resource.match(environ['PATH_INFO'])
+ if resource_cls is None:
+ raise BadRequest # 404 instead?
+ resource = resource_cls(
+ state=self.state, responder=responder, **params)
+ return resource
+
+ def __call__(self, environ, start_response):
+ responder = HTTPResponder(start_response)
+ self.request_begin(environ)
+ try:
+ resource = self._lookup_resource(environ, responder)
+ HTTPInvocationByMethodWithBody(resource, environ, self)()
+ except errors.U1DBError, e:
+ self.request_u1db_error(environ, e)
+ status = http_errors.wire_description_to_status.get(
+ e.wire_description, 500)
+ responder.send_response_json(status, error=e.wire_description)
+ except BadRequest:
+ self.request_bad_request(environ)
+ responder.send_response_json(400, error="bad request")
+ except KeyboardInterrupt:
+ raise
+ except:
+ self.request_failed(environ)
+ raise
+ else:
+ self.request_done(environ)
+ return responder.content
+
+ # hooks for tracing requests
+
+ def request_begin(self, environ):
+ """Hook called at the beginning of processing a request."""
+ pass
+
+ def request_done(self, environ):
+ """Hook called when done processing a request."""
+ pass
+
+ def request_u1db_error(self, environ, exc):
+ """Hook called when processing a request resulted in a U1DBError.
+
+ U1DBError passed as exc.
+ """
+ pass
+
+ def request_bad_request(self, environ):
+ """Hook called when processing a bad request.
+
+ No actual processing was done.
+ """
+ pass
+
+ def request_failed(self, environ):
+ """Hook called when processing a request failed unexpectedly.
+
+ Invoked from an except block, so there's interpreter exception
+ information available.
+ """
+ pass
diff --git a/u1db/remote/http_client.py b/u1db/remote/http_client.py
new file mode 100644
index 00000000..decddda3
--- /dev/null
+++ b/u1db/remote/http_client.py
@@ -0,0 +1,218 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Base class to make requests to a remote HTTP server."""
+
+import httplib
+from oauth import oauth
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import socket
+import ssl
+import sys
+import urlparse
+import urllib
+
+from time import sleep
+from u1db import (
+ errors,
+ )
+from u1db.remote import (
+ http_errors,
+ )
+
+from u1db.remote.ssl_match_hostname import ( # noqa
+ CertificateError,
+ match_hostname,
+ )
+
+# Ubuntu/debian
+# XXX other...
+CA_CERTS = "/etc/ssl/certs/ca-certificates.crt"
+
+
+def _encode_query_parameter(value):
+ """Encode query parameter."""
+ if isinstance(value, bool):
+ if value:
+ value = 'true'
+ else:
+ value = 'false'
+ return unicode(value).encode('utf-8')
+
+
+class _VerifiedHTTPSConnection(httplib.HTTPSConnection):
+ """HTTPSConnection verifying server side certificates."""
+ # derived from httplib.py
+
+ def connect(self):
+ "Connect to a host on a given (SSL) port."
+
+ sock = socket.create_connection((self.host, self.port),
+ self.timeout, self.source_address)
+ if self._tunnel_host:
+ self.sock = sock
+ self._tunnel()
+ if sys.platform.startswith('linux'):
+ cert_opts = {
+ 'cert_reqs': ssl.CERT_REQUIRED,
+ 'ca_certs': CA_CERTS
+ }
+ else:
+ # XXX no cert verification implemented elsewhere for now
+ cert_opts = {}
+ self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file,
+ ssl_version=ssl.PROTOCOL_SSLv3,
+ **cert_opts
+ )
+ if cert_opts:
+ match_hostname(self.sock.getpeercert(), self.host)
+
+
+class HTTPClientBase(object):
+ """Base class to make requests to a remote HTTP server."""
+
+ # by default use HMAC-SHA1 OAuth signature method to not disclose
+ # tokens
+ # NB: given that the content bodies are not covered by the
+ # signatures though, to achieve security (against man-in-the-middle
+ # attacks for example) one would need HTTPS
+ oauth_signature_method = oauth.OAuthSignatureMethod_HMAC_SHA1()
+
+ # Will use these delays to retry on 503 befor finally giving up. The final
+ # 0 is there to not wait after the final try fails.
+ _delays = (1, 1, 2, 4, 0)
+
+ def __init__(self, url, creds=None):
+ self._url = urlparse.urlsplit(url)
+ self._conn = None
+ self._creds = {}
+ if creds is not None:
+ if len(creds) != 1:
+ raise errors.UnknownAuthMethod()
+ auth_meth, credentials = creds.items()[0]
+ try:
+ set_creds = getattr(self, 'set_%s_credentials' % auth_meth)
+ except AttributeError:
+ raise errors.UnknownAuthMethod(auth_meth)
+ set_creds(**credentials)
+
+ def set_oauth_credentials(self, consumer_key, consumer_secret,
+ token_key, token_secret):
+ self._creds = {'oauth': (
+ oauth.OAuthConsumer(consumer_key, consumer_secret),
+ oauth.OAuthToken(token_key, token_secret))}
+
+ def _ensure_connection(self):
+ if self._conn is not None:
+ return
+ if self._url.scheme == 'https':
+ connClass = _VerifiedHTTPSConnection
+ else:
+ connClass = httplib.HTTPConnection
+ self._conn = connClass(self._url.hostname, self._url.port)
+
+ def close(self):
+ if self._conn:
+ self._conn.close()
+ self._conn = None
+
+ # xxx retry mechanism?
+
+ def _error(self, respdic):
+ descr = respdic.get("error")
+ exc_cls = errors.wire_description_to_exc.get(descr)
+ if exc_cls is not None:
+ message = respdic.get("message")
+ raise exc_cls(message)
+
+ def _response(self):
+ resp = self._conn.getresponse()
+ body = resp.read()
+ headers = dict(resp.getheaders())
+ if resp.status in (200, 201):
+ return body, headers
+ elif resp.status in http_errors.ERROR_STATUSES:
+ try:
+ respdic = json.loads(body)
+ except ValueError:
+ pass
+ else:
+ self._error(respdic)
+ # special case
+ if resp.status == 503:
+ raise errors.Unavailable(body, headers)
+ raise errors.HTTPError(resp.status, body, headers)
+
+ def _sign_request(self, method, url_query, params):
+ if 'oauth' in self._creds:
+ consumer, token = self._creds['oauth']
+ full_url = "%s://%s%s" % (self._url.scheme, self._url.netloc,
+ url_query)
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ consumer, token,
+ http_method=method,
+ parameters=params,
+ http_url=full_url
+ )
+ oauth_req.sign_request(
+ self.oauth_signature_method, consumer, token)
+ # Authorization: OAuth ...
+ return oauth_req.to_header().items()
+ else:
+ return []
+
+ def _request(self, method, url_parts, params=None, body=None,
+ content_type=None):
+ self._ensure_connection()
+ unquoted_url = url_query = self._url.path
+ if url_parts:
+ if not url_query.endswith('/'):
+ url_query += '/'
+ unquoted_url = url_query
+ url_query += '/'.join(urllib.quote(part, safe='')
+ for part in url_parts)
+ # oauth performs its own quoting
+ unquoted_url += '/'.join(url_parts)
+ encoded_params = {}
+ if params:
+ for key, value in params.items():
+ key = unicode(key).encode('utf-8')
+ encoded_params[key] = _encode_query_parameter(value)
+ url_query += ('?' + urllib.urlencode(encoded_params))
+ if body is not None and not isinstance(body, basestring):
+ body = json.dumps(body)
+ content_type = 'application/json'
+ headers = {}
+ if content_type:
+ headers['content-type'] = content_type
+ headers.update(
+ self._sign_request(method, unquoted_url, encoded_params))
+ for delay in self._delays:
+ try:
+ self._conn.request(method, url_query, body, headers)
+ return self._response()
+ except errors.Unavailable, e:
+ sleep(delay)
+ raise e
+
+ def _request_json(self, method, url_parts, params=None, body=None,
+ content_type=None):
+ res, headers = self._request(method, url_parts, params, body,
+ content_type)
+ return json.loads(res), headers
diff --git a/u1db/remote/http_database.py b/u1db/remote/http_database.py
new file mode 100644
index 00000000..6901baad
--- /dev/null
+++ b/u1db/remote/http_database.py
@@ -0,0 +1,143 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""HTTPDatabase to access a remote db over the HTTP API."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import uuid
+
+from u1db import (
+ Database,
+ Document,
+ errors,
+ )
+from u1db.remote import (
+ http_client,
+ http_errors,
+ http_target,
+ )
+
+
+DOCUMENT_DELETED_STATUS = http_errors.wire_description_to_status[
+ errors.DOCUMENT_DELETED]
+
+
+class HTTPDatabase(http_client.HTTPClientBase, Database):
+ """Implement the Database API to a remote HTTP server."""
+
+ def __init__(self, url, document_factory=None, creds=None):
+ super(HTTPDatabase, self).__init__(url, creds=creds)
+ self._factory = document_factory or Document
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ @staticmethod
+ def open_database(url, create):
+ db = HTTPDatabase(url)
+ db.open(create)
+ return db
+
+ @staticmethod
+ def delete_database(url):
+ db = HTTPDatabase(url)
+ db._delete()
+ db.close()
+
+ def open(self, create):
+ if create:
+ self._ensure()
+ else:
+ self._check()
+
+ def _check(self):
+ return self._request_json('GET', [])[0]
+
+ def _ensure(self):
+ self._request_json('PUT', [], {}, {})
+
+ def _delete(self):
+ self._request_json('DELETE', [], {}, {})
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ params = {}
+ if doc.rev is not None:
+ params['old_rev'] = doc.rev
+ res, headers = self._request_json('PUT', ['doc', doc.doc_id], params,
+ doc.get_json(), 'application/json')
+ doc.rev = res['rev']
+ return res['rev']
+
+ def get_doc(self, doc_id, include_deleted=False):
+ try:
+ res, headers = self._request(
+ 'GET', ['doc', doc_id], {"include_deleted": include_deleted})
+ except errors.DocumentDoesNotExist:
+ return None
+ except errors.HTTPError, e:
+ if (e.status == DOCUMENT_DELETED_STATUS and
+ 'x-u1db-rev' in e.headers):
+ res = None
+ headers = e.headers
+ else:
+ raise
+ doc_rev = headers['x-u1db-rev']
+ has_conflicts = json.loads(headers['x-u1db-has-conflicts'])
+ doc = self._factory(doc_id, doc_rev, res)
+ doc.has_conflicts = has_conflicts
+ return doc
+
+ def get_docs(self, doc_ids, check_for_conflicts=True,
+ include_deleted=False):
+ if not doc_ids:
+ return
+ doc_ids = ','.join(doc_ids)
+ res, headers = self._request(
+ 'GET', ['docs'], {
+ "doc_ids": doc_ids, "include_deleted": include_deleted,
+ "check_for_conflicts": check_for_conflicts})
+ for doc_dict in json.loads(res):
+ doc = self._factory(
+ doc_dict['doc_id'], doc_dict['doc_rev'], doc_dict['content'])
+ doc.has_conflicts = doc_dict['has_conflicts']
+ yield doc
+
+ def create_doc_from_json(self, content, doc_id=None):
+ if doc_id is None:
+ doc_id = 'D-%s' % (uuid.uuid4().hex,)
+ res, headers = self._request_json('PUT', ['doc', doc_id], {},
+ content, 'application/json')
+ new_doc = self._factory(doc_id, res['rev'], content)
+ return new_doc
+
+ def delete_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ params = {'old_rev': doc.rev}
+ res, headers = self._request_json('DELETE',
+ ['doc', doc.doc_id], params)
+ doc.make_tombstone()
+ doc.rev = res['rev']
+
+ def get_sync_target(self):
+ st = http_target.HTTPSyncTarget(self._url.geturl())
+ st._creds = self._creds
+ return st
diff --git a/u1db/remote/http_errors.py b/u1db/remote/http_errors.py
new file mode 100644
index 00000000..2039c5b2
--- /dev/null
+++ b/u1db/remote/http_errors.py
@@ -0,0 +1,46 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Information about the encoding of errors over HTTP."""
+
+from u1db import (
+ errors,
+ )
+
+
+# error wire descriptions mapping to HTTP status codes
+wire_description_to_status = dict([
+ (errors.InvalidDocId.wire_description, 400),
+ (errors.MissingDocIds.wire_description, 400),
+ (errors.Unauthorized.wire_description, 401),
+ (errors.DocumentTooBig.wire_description, 403),
+ (errors.UserQuotaExceeded.wire_description, 403),
+ (errors.SubscriptionNeeded.wire_description, 403),
+ (errors.DatabaseDoesNotExist.wire_description, 404),
+ (errors.DocumentDoesNotExist.wire_description, 404),
+ (errors.DocumentAlreadyDeleted.wire_description, 404),
+ (errors.RevisionConflict.wire_description, 409),
+ (errors.InvalidGeneration.wire_description, 409),
+ (errors.InvalidTransactionId.wire_description, 409),
+ (errors.Unavailable.wire_description, 503),
+# without matching exception
+ (errors.DOCUMENT_DELETED, 404)
+])
+
+
+ERROR_STATUSES = set(wire_description_to_status.values())
+# 400 included explicitly for tests
+ERROR_STATUSES.add(400)
diff --git a/u1db/remote/http_target.py b/u1db/remote/http_target.py
new file mode 100644
index 00000000..1028963e
--- /dev/null
+++ b/u1db/remote/http_target.py
@@ -0,0 +1,135 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""SyncTarget API implementation to a remote HTTP server."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ Document,
+ SyncTarget,
+ )
+from u1db.errors import (
+ BrokenSyncStream,
+ )
+from u1db.remote import (
+ http_client,
+ utils,
+ )
+
+
+class HTTPSyncTarget(http_client.HTTPClientBase, SyncTarget):
+ """Implement the SyncTarget api to a remote HTTP server."""
+
+ @staticmethod
+ def connect(url):
+ return HTTPSyncTarget(url)
+
+ def get_sync_info(self, source_replica_uid):
+ self._ensure_connection()
+ res, _ = self._request_json('GET', ['sync-from', source_replica_uid])
+ return (res['target_replica_uid'], res['target_replica_generation'],
+ res['target_replica_transaction_id'],
+ res['source_replica_generation'], res['source_transaction_id'])
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_transaction_id):
+ self._ensure_connection()
+ if self._trace_hook: # for tests
+ self._trace_hook('record_sync_info')
+ self._request_json('PUT', ['sync-from', source_replica_uid], {},
+ {'generation': source_replica_generation,
+ 'transaction_id': source_transaction_id})
+
+ def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None):
+ parts = data.splitlines() # one at a time
+ if not parts or parts[0] != '[':
+ raise BrokenSyncStream
+ data = parts[1:-1]
+ comma = False
+ if data:
+ line, comma = utils.check_and_strip_comma(data[0])
+ res = json.loads(line)
+ if ensure_callback and 'replica_uid' in res:
+ ensure_callback(res['replica_uid'])
+ for entry in data[1:]:
+ if not comma: # missing in between comma
+ raise BrokenSyncStream
+ line, comma = utils.check_and_strip_comma(entry)
+ entry = json.loads(line)
+ doc = Document(entry['id'], entry['rev'], entry['content'])
+ return_doc_cb(doc, entry['gen'], entry['trans_id'])
+ if parts[-1] != ']':
+ try:
+ partdic = json.loads(parts[-1])
+ except ValueError:
+ pass
+ else:
+ if isinstance(partdic, dict):
+ self._error(partdic)
+ raise BrokenSyncStream
+ if not data or comma: # no entries or bad extra comma
+ raise BrokenSyncStream
+ return res
+
+ def sync_exchange(self, docs_by_generations, source_replica_uid,
+ last_known_generation, last_known_trans_id,
+ return_doc_cb, ensure_callback=None):
+ self._ensure_connection()
+ if self._trace_hook: # for tests
+ self._trace_hook('sync_exchange')
+ url = '%s/sync-from/%s' % (self._url.path, source_replica_uid)
+ self._conn.putrequest('POST', url)
+ self._conn.putheader('content-type', 'application/x-u1db-sync-stream')
+ for header_name, header_value in self._sign_request('POST', url, {}):
+ self._conn.putheader(header_name, header_value)
+ entries = ['[']
+ size = 1
+
+ def prepare(**dic):
+ entry = comma + '\r\n' + json.dumps(dic)
+ entries.append(entry)
+ return len(entry)
+
+ comma = ''
+ size += prepare(
+ last_known_generation=last_known_generation,
+ last_known_trans_id=last_known_trans_id,
+ ensure=ensure_callback is not None)
+ comma = ','
+ for doc, gen, trans_id in docs_by_generations:
+ size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_json(),
+ gen=gen, trans_id=trans_id)
+ entries.append('\r\n]')
+ size += len(entries[-1])
+ self._conn.putheader('content-length', str(size))
+ self._conn.endheaders()
+ for entry in entries:
+ self._conn.send(entry)
+ entries = None
+ data, _ = self._response()
+ res = self._parse_sync_stream(data, return_doc_cb, ensure_callback)
+ data = None
+ return res['new_generation'], res['new_transaction_id']
+
+ # for tests
+ _trace_hook = None
+
+ def _set_trace_hook_shallow(self, cb):
+ self._trace_hook = cb
diff --git a/u1db/remote/oauth_middleware.py b/u1db/remote/oauth_middleware.py
new file mode 100644
index 00000000..5772580a
--- /dev/null
+++ b/u1db/remote/oauth_middleware.py
@@ -0,0 +1,89 @@
+# Copyright 2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+"""U1DB OAuth authorisation WSGI middleware."""
+import httplib
+from oauth import oauth
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from urllib import quote
+from wsgiref.util import shift_path_info
+
+
+sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1()
+sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT()
+
+
+class OAuthMiddleware(object):
+ """U1DB OAuth Authorisation WSGI middleware."""
+
+ # max seconds the request timestamp is allowed to be shifted
+ # from arrival time
+ timestamp_threshold = 300
+
+ def __init__(self, app, base_url, prefix='/~/'):
+ self.app = app
+ self.base_url = base_url
+ self.prefix = prefix
+
+ def get_oauth_data_store(self):
+ """Provide a oauth.OAuthDataStore."""
+ raise NotImplementedError(self.get_oauth_data_store)
+
+ def _error(self, start_response, status, description, message=None):
+ start_response("%d %s" % (status, httplib.responses[status]),
+ [('content-type', 'application/json')])
+ err = {"error": description}
+ if message:
+ err['message'] = message
+ return [json.dumps(err)]
+
+ def __call__(self, environ, start_response):
+ if self.prefix and not environ['PATH_INFO'].startswith(self.prefix):
+ return self._error(start_response, 400, "bad request")
+ headers = {}
+ if 'HTTP_AUTHORIZATION' in environ:
+ headers['Authorization'] = environ['HTTP_AUTHORIZATION']
+ oauth_req = oauth.OAuthRequest.from_request(
+ http_method=environ['REQUEST_METHOD'],
+ http_url=self.base_url + environ['PATH_INFO'],
+ headers=headers,
+ query_string=environ['QUERY_STRING']
+ )
+ if oauth_req is None:
+ return self._error(start_response, 401, "unauthorized",
+ "Missing OAuth.")
+ try:
+ self.verify(environ, oauth_req)
+ except oauth.OAuthError, e:
+ return self._error(start_response, 401, "unauthorized",
+ e.message)
+ shift_path_info(environ)
+ return self.app(environ, start_response)
+
+ def verify(self, environ, oauth_req):
+ """Verify OAuth request, put user_id in the environ."""
+ oauth_server = oauth.OAuthServer(self.get_oauth_data_store())
+ oauth_server.timestamp_threshold = self.timestamp_threshold
+ oauth_server.add_signature_method(sign_meth_HMAC_SHA1)
+ oauth_server.add_signature_method(sign_meth_PLAINTEXT)
+ consumer, token, parameters = oauth_server.verify_request(oauth_req)
+ # filter out oauth bits
+ environ['QUERY_STRING'] = '&'.join("%s=%s" % (quote(k, safe=''),
+ quote(v, safe=''))
+ for k, v in parameters.iteritems())
+ return consumer, token
diff --git a/u1db/remote/server_state.py b/u1db/remote/server_state.py
new file mode 100644
index 00000000..96581359
--- /dev/null
+++ b/u1db/remote/server_state.py
@@ -0,0 +1,67 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""State for servers exposing a set of U1DB databases."""
+import os
+import errno
+
+class ServerState(object):
+ """Passed to a Request when it is instantiated.
+
+ This is used to track server-side state, such as working-directory, open
+ databases, etc.
+ """
+
+ def __init__(self):
+ self._workingdir = None
+
+ def set_workingdir(self, path):
+ self._workingdir = path
+
+ def _relpath(self, relpath):
+ # Note: We don't want to allow absolute paths here, because we
+ # don't want to expose the filesystem. We should also check that
+ # relpath doesn't have '..' in it, etc.
+ return self._workingdir + '/' + relpath
+
+ def open_database(self, path):
+ """Open a database at the given location."""
+ from u1db.backends import sqlite_backend
+ full_path = self._relpath(path)
+ return sqlite_backend.SQLiteDatabase.open_database(full_path,
+ create=False)
+
+ def check_database(self, path):
+ """Check if the database at the given location exists.
+
+ Simply returns if it does or raises DatabaseDoesNotExist.
+ """
+ db = self.open_database(path)
+ db.close()
+
+ def ensure_database(self, path):
+ """Ensure database at the given location."""
+ from u1db.backends import sqlite_backend
+ full_path = self._relpath(path)
+ db = sqlite_backend.SQLiteDatabase.open_database(full_path,
+ create=True)
+ return db, db._replica_uid
+
+ def delete_database(self, path):
+ """Delete database at the given location."""
+ from u1db.backends import sqlite_backend
+ full_path = self._relpath(path)
+ sqlite_backend.SQLiteDatabase.delete_database(full_path)
diff --git a/u1db/remote/ssl_match_hostname.py b/u1db/remote/ssl_match_hostname.py
new file mode 100644
index 00000000..fbabc177
--- /dev/null
+++ b/u1db/remote/ssl_match_hostname.py
@@ -0,0 +1,64 @@
+"""The match_hostname() function from Python 3.2, essential when using SSL."""
+# XXX put it here until it's packaged
+
+import re
+
+__version__ = '3.2a3'
+
+
+class CertificateError(ValueError):
+ pass
+
+
+def _dnsname_to_pat(dn):
+ pats = []
+ for frag in dn.split(r'.'):
+ if frag == '*':
+ # When '*' is a fragment by itself, it matches a non-empty dotless
+ # fragment.
+ pats.append('[^.]+')
+ else:
+ # Otherwise, '*' matches any dotless fragment.
+ frag = re.escape(frag)
+ pats.append(frag.replace(r'\*', '[^.]*'))
+ return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
+
+
+def match_hostname(cert, hostname):
+ """Verify that *cert* (in decoded format as returned by
+ SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules
+ are mostly followed, but IP addresses are not accepted for *hostname*.
+
+ CertificateError is raised on failure. On success, the function
+ returns nothing.
+ """
+ if not cert:
+ raise ValueError("empty or no certificate")
+ dnsnames = []
+ san = cert.get('subjectAltName', ())
+ for key, value in san:
+ if key == 'DNS':
+ if _dnsname_to_pat(value).match(hostname):
+ return
+ dnsnames.append(value)
+ if not san:
+ # The subject is only checked when subjectAltName is empty
+ for sub in cert.get('subject', ()):
+ for key, value in sub:
+ # XXX according to RFC 2818, the most specific Common Name
+ # must be used.
+ if key == 'commonName':
+ if _dnsname_to_pat(value).match(hostname):
+ return
+ dnsnames.append(value)
+ if len(dnsnames) > 1:
+ raise CertificateError("hostname %r "
+ "doesn't match either of %s"
+ % (hostname, ', '.join(map(repr, dnsnames))))
+ elif len(dnsnames) == 1:
+ raise CertificateError("hostname %r "
+ "doesn't match %r"
+ % (hostname, dnsnames[0]))
+ else:
+ raise CertificateError("no appropriate commonName or "
+ "subjectAltName fields were found")
diff --git a/u1db/remote/utils.py b/u1db/remote/utils.py
new file mode 100644
index 00000000..14cedea9
--- /dev/null
+++ b/u1db/remote/utils.py
@@ -0,0 +1,23 @@
+# Copyright 2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Utilities for details of the procotol."""
+
+
+def check_and_strip_comma(line):
+ if line and line[-1] == ',':
+ return line[:-1], True
+ return line, False
diff --git a/u1db/sync.py b/u1db/sync.py
new file mode 100644
index 00000000..3375d097
--- /dev/null
+++ b/u1db/sync.py
@@ -0,0 +1,304 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The synchronization utilities for U1DB."""
+from itertools import izip
+
+import u1db
+from u1db import errors
+
+
+class Synchronizer(object):
+ """Collect the state around synchronizing 2 U1DB replicas.
+
+ Synchronization is bi-directional, in that new items in the source are sent
+ to the target, and new items in the target are returned to the source.
+ However, it still recognizes that one side is initiating the request. Also,
+ at the moment, conflicts are only created in the source.
+ """
+
+ def __init__(self, source, sync_target):
+ """Create a new Synchronization object.
+
+ :param source: A Database
+ :param sync_target: A SyncTarget
+ """
+ self.source = source
+ self.sync_target = sync_target
+ self.target_replica_uid = None
+ self.num_inserted = 0
+
+ def _insert_doc_from_target(self, doc, replica_gen, trans_id):
+ """Try to insert synced document from target.
+
+ Implements TAKE OTHER semantics: any document from the target
+ that is in conflict will be taken as the new official value,
+ while the current conflicting value will be stored alongside
+ as a conflict. In the process indexes will be updated etc.
+
+ :return: None
+ """
+ # Increases self.num_inserted depending whether the document
+ # was effectively inserted.
+ state, _ = self.source._put_doc_if_newer(doc, save_conflict=True,
+ replica_uid=self.target_replica_uid, replica_gen=replica_gen,
+ replica_trans_id=trans_id)
+ if state == 'inserted':
+ self.num_inserted += 1
+ elif state == 'converged':
+ # magical convergence
+ pass
+ elif state == 'superseded':
+ # we have something newer, will be taken care of at the next sync
+ pass
+ else:
+ assert state == 'conflicted'
+ # The doc was saved as a conflict, so the database was updated
+ self.num_inserted += 1
+
+ def _record_sync_info_with_the_target(self, start_generation):
+ """Record our new after sync generation with the target if gapless.
+
+ Any documents received from the target will cause the local
+ database to increment its generation. We do not want to send
+ them back to the target in a future sync. However, there could
+ also be concurrent updates from another process doing eg
+ 'put_doc' while the sync was running. And we do want to
+ synchronize those documents. We can tell if there was a
+ concurrent update by comparing our new generation number
+ versus the generation we started, and how many documents we
+ inserted from the target. If it matches exactly, then we can
+ record with the target that they are fully up to date with our
+ new generation.
+ """
+ cur_gen, trans_id = self.source._get_generation_info()
+ if (cur_gen == start_generation + self.num_inserted
+ and self.num_inserted > 0):
+ self.sync_target.record_sync_info(
+ self.source._replica_uid, cur_gen, trans_id)
+
+ def sync(self, callback=None, autocreate=False):
+ """Synchronize documents between source and target."""
+ sync_target = self.sync_target
+ # get target identifier, its current generation,
+ # and its last-seen database generation for this source
+ try:
+ (self.target_replica_uid, target_gen, target_trans_id,
+ target_my_gen, target_my_trans_id) = sync_target.get_sync_info(
+ self.source._replica_uid)
+ except errors.DatabaseDoesNotExist:
+ if not autocreate:
+ raise
+ # will try to ask sync_exchange() to create the db
+ self.target_replica_uid = None
+ target_gen, target_trans_id = 0, ''
+ target_my_gen, target_my_trans_id = 0, ''
+ def ensure_callback(replica_uid):
+ self.target_replica_uid = replica_uid
+ else:
+ ensure_callback = None
+ # validate the generation and transaction id the target knows about us
+ self.source.validate_gen_and_trans_id(
+ target_my_gen, target_my_trans_id)
+ # what's changed since that generation and this current gen
+ my_gen, _, changes = self.source.whats_changed(target_my_gen)
+
+ # this source last-seen database generation for the target
+ if self.target_replica_uid is None:
+ target_last_known_gen, target_last_known_trans_id = 0, ''
+ else:
+ target_last_known_gen, target_last_known_trans_id = \
+ self.source._get_replica_gen_and_trans_id(self.target_replica_uid)
+ if not changes and target_last_known_gen == target_gen:
+ if target_trans_id != target_last_known_trans_id:
+ raise errors.InvalidTransactionId
+ return my_gen
+ changed_doc_ids = [doc_id for doc_id, _, _ in changes]
+ # prepare to send all the changed docs
+ docs_to_send = self.source.get_docs(changed_doc_ids,
+ check_for_conflicts=False, include_deleted=True)
+ # TODO: there must be a way to not iterate twice
+ docs_by_generation = zip(
+ docs_to_send, (gen for _, gen, _ in changes),
+ (trans for _, _, trans in changes))
+
+ # exchange documents and try to insert the returned ones with
+ # the target, return target synced-up-to gen
+ new_gen, new_trans_id = sync_target.sync_exchange(
+ docs_by_generation, self.source._replica_uid,
+ target_last_known_gen, target_last_known_trans_id,
+ self._insert_doc_from_target, ensure_callback=ensure_callback)
+ # record target synced-up-to generation including applying what we sent
+ self.source._set_replica_gen_and_trans_id(
+ self.target_replica_uid, new_gen, new_trans_id)
+
+ # if gapless record current reached generation with target
+ self._record_sync_info_with_the_target(my_gen)
+
+ return my_gen
+
+
+class SyncExchange(object):
+ """Steps and state for carrying through a sync exchange on a target."""
+
+ def __init__(self, db, source_replica_uid, last_known_generation):
+ self._db = db
+ self.source_replica_uid = source_replica_uid
+ self.source_last_known_generation = last_known_generation
+ self.seen_ids = {} # incoming ids not superseded
+ self.changes_to_return = None
+ self.new_gen = None
+ self.new_trans_id = None
+ # for tests
+ self._incoming_trace = []
+ self._trace_hook = None
+ self._db._last_exchange_log = {
+ 'receive': {'docs': self._incoming_trace},
+ 'return': None
+ }
+
+ def _set_trace_hook(self, cb):
+ self._trace_hook = cb
+
+ def _trace(self, state):
+ if not self._trace_hook:
+ return
+ self._trace_hook(state)
+
+ def insert_doc_from_source(self, doc, source_gen, trans_id):
+ """Try to insert synced document from source.
+
+ Conflicting documents are not inserted but will be sent over
+ to the sync source.
+
+ It keeps track of progress by storing the document source
+ generation as well.
+
+ The 1st step of a sync exchange is to call this repeatedly to
+ try insert all incoming documents from the source.
+
+ :param doc: A Document object.
+ :param source_gen: The source generation of doc.
+ :return: None
+ """
+ state, at_gen = self._db._put_doc_if_newer(doc, save_conflict=False,
+ replica_uid=self.source_replica_uid, replica_gen=source_gen,
+ replica_trans_id=trans_id)
+ if state == 'inserted':
+ self.seen_ids[doc.doc_id] = at_gen
+ elif state == 'converged':
+ # magical convergence
+ self.seen_ids[doc.doc_id] = at_gen
+ elif state == 'superseded':
+ # we have something newer that we will return
+ pass
+ else:
+ # conflict that we will returne
+ assert state == 'conflicted'
+ # for tests
+ self._incoming_trace.append((doc.doc_id, doc.rev))
+ self._db._last_exchange_log['receive'].update({
+ 'source_uid': self.source_replica_uid,
+ 'source_gen': source_gen
+ })
+
+ def find_changes_to_return(self):
+ """Find changes to return.
+
+ Find changes since last_known_generation in db generation
+ order using whats_changed. It excludes documents ids that have
+ already been considered (superseded by the sender, etc).
+
+ :return: new_generation - the generation of this database
+ which the caller can consider themselves to be synchronized after
+ processing the returned documents.
+ """
+ self._db._last_exchange_log['receive'].update({ # for tests
+ 'last_known_gen': self.source_last_known_generation
+ })
+ self._trace('before whats_changed')
+ gen, trans_id, changes = self._db.whats_changed(
+ self.source_last_known_generation)
+ self._trace('after whats_changed')
+ self.new_gen = gen
+ self.new_trans_id = trans_id
+ seen_ids = self.seen_ids
+ # changed docs that weren't superseded by or converged with
+ self.changes_to_return = [
+ (doc_id, gen, trans_id) for (doc_id, gen, trans_id) in changes
+ # there was a subsequent update
+ if doc_id not in seen_ids or seen_ids.get(doc_id) < gen]
+ return self.new_gen
+
+ def return_docs(self, return_doc_cb):
+ """Return the changed documents and their last change generation
+ repeatedly invoking the callback return_doc_cb.
+
+ The final step of a sync exchange.
+
+ :param: return_doc_cb(doc, gen, trans_id): is a callback
+ used to return the documents with their last change generation
+ to the target replica.
+ :return: None
+ """
+ changes_to_return = self.changes_to_return
+ # return docs, including conflicts
+ changed_doc_ids = [doc_id for doc_id, _, _ in changes_to_return]
+ self._trace('before get_docs')
+ docs = self._db.get_docs(
+ changed_doc_ids, check_for_conflicts=False, include_deleted=True)
+
+ docs_by_gen = izip(
+ docs, (gen for _, gen, _ in changes_to_return),
+ (trans_id for _, _, trans_id in changes_to_return))
+ _outgoing_trace = [] # for tests
+ for doc, gen, trans_id in docs_by_gen:
+ return_doc_cb(doc, gen, trans_id)
+ _outgoing_trace.append((doc.doc_id, doc.rev))
+ # for tests
+ self._db._last_exchange_log['return'] = {
+ 'docs': _outgoing_trace,
+ 'last_gen': self.new_gen
+ }
+
+
+class LocalSyncTarget(u1db.SyncTarget):
+ """Common sync target implementation logic for all local sync targets."""
+
+ def __init__(self, db):
+ self._db = db
+ self._trace_hook = None
+
+ def sync_exchange(self, docs_by_generations, source_replica_uid,
+ last_known_generation, last_known_trans_id,
+ return_doc_cb, ensure_callback=None):
+ self._db.validate_gen_and_trans_id(
+ last_known_generation, last_known_trans_id)
+ sync_exch = SyncExchange(
+ self._db, source_replica_uid, last_known_generation)
+ if self._trace_hook:
+ sync_exch._set_trace_hook(self._trace_hook)
+ # 1st step: try to insert incoming docs and record progress
+ for doc, doc_gen, trans_id in docs_by_generations:
+ sync_exch.insert_doc_from_source(doc, doc_gen, trans_id)
+ # 2nd step: find changed documents (including conflicts) to return
+ new_gen = sync_exch.find_changes_to_return()
+ # final step: return docs and record source replica sync point
+ sync_exch.return_docs(return_doc_cb)
+ return new_gen, sync_exch.new_trans_id
+
+ def _set_trace_hook(self, cb):
+ self._trace_hook = cb
diff --git a/u1db/tests/__init__.py b/u1db/tests/__init__.py
new file mode 100644
index 00000000..b8e16b15
--- /dev/null
+++ b/u1db/tests/__init__.py
@@ -0,0 +1,463 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test infrastructure for U1DB"""
+
+import copy
+import shutil
+import socket
+import tempfile
+import threading
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from wsgiref import simple_server
+
+from oauth import oauth
+from sqlite3 import dbapi2
+from StringIO import StringIO
+
+import testscenarios
+import testtools
+
+from u1db import (
+ errors,
+ Document,
+ )
+from u1db.backends import (
+ inmemory,
+ sqlite_backend,
+ )
+from u1db.remote import (
+ server_state,
+ )
+
+try:
+ from u1db.tests import c_backend_wrapper
+ c_backend_error = None
+except ImportError, e:
+ c_backend_wrapper = None # noqa
+ c_backend_error = e
+
+# Setting this means that failing assertions will not include this module in
+# their traceback. However testtools doesn't seem to set it, and we don't want
+# this level to be omitted, but the lower levels to be shown.
+# __unittest = 1
+
+
+class TestCase(testtools.TestCase):
+
+ def createTempDir(self, prefix='u1db-tmp-'):
+ """Create a temporary directory to do some work in.
+
+ This directory will be scheduled for cleanup when the test ends.
+ """
+ tempdir = tempfile.mkdtemp(prefix=prefix)
+ self.addCleanup(shutil.rmtree, tempdir)
+ return tempdir
+
+ def make_document(self, doc_id, doc_rev, content, has_conflicts=False):
+ return self.make_document_for_test(
+ self, doc_id, doc_rev, content, has_conflicts)
+
+ def make_document_for_test(self, test, doc_id, doc_rev, content,
+ has_conflicts):
+ return make_document_for_test(
+ test, doc_id, doc_rev, content, has_conflicts)
+
+ def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts):
+ """Assert that the document in the database looks correct."""
+ exp_doc = self.make_document(doc_id, doc_rev, content,
+ has_conflicts=has_conflicts)
+ self.assertEqual(exp_doc, db.get_doc(doc_id))
+
+ def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content,
+ has_conflicts):
+ """Assert that the document in the database looks correct."""
+ exp_doc = self.make_document(doc_id, doc_rev, content,
+ has_conflicts=has_conflicts)
+ self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True))
+
+ def assertGetDocConflicts(self, db, doc_id, conflicts):
+ """Assert what conflicts are stored for a given doc_id.
+
+ :param conflicts: A list of (doc_rev, content) pairs.
+ The first item must match the first item returned from the
+ database, however the rest can be returned in any order.
+ """
+ if conflicts:
+ conflicts = [(rev, (json.loads(cont) if isinstance(cont, basestring)
+ else cont)) for (rev, cont) in conflicts]
+ conflicts = conflicts[:1] + sorted(conflicts[1:])
+ actual = db.get_doc_conflicts(doc_id)
+ if actual:
+ actual = [(doc.rev, (json.loads(doc.get_json())
+ if doc.get_json() is not None else None)) for doc in actual]
+ actual = actual[:1] + sorted(actual[1:])
+ self.assertEqual(conflicts, actual)
+
+
+def multiply_scenarios(a_scenarios, b_scenarios):
+ """Create the cross-product of scenarios."""
+
+ all_scenarios = []
+ for a_name, a_attrs in a_scenarios:
+ for b_name, b_attrs in b_scenarios:
+ name = '%s,%s' % (a_name, b_name)
+ attrs = dict(a_attrs)
+ attrs.update(b_attrs)
+ all_scenarios.append((name, attrs))
+ return all_scenarios
+
+
+simple_doc = '{"key": "value"}'
+nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}'
+
+
+def make_memory_database_for_test(test, replica_uid):
+ return inmemory.InMemoryDatabase(replica_uid)
+
+
+def copy_memory_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ new_db = inmemory.InMemoryDatabase(db._replica_uid)
+ new_db._transaction_log = db._transaction_log[:]
+ new_db._docs = copy.deepcopy(db._docs)
+ new_db._conflicts = copy.deepcopy(db._conflicts)
+ new_db._indexes = copy.deepcopy(db._indexes)
+ new_db._factory = db._factory
+ return new_db
+
+
+def make_sqlite_partial_expanded_for_test(test, replica_uid):
+ db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ db._set_replica_uid(replica_uid)
+ return db
+
+
+def copy_sqlite_partial_expanded_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ new_db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ tmpfile = StringIO()
+ for line in db._db_handle.iterdump():
+ if not 'sqlite_sequence' in line: # work around bug in iterdump
+ tmpfile.write('%s\n' % line)
+ tmpfile.seek(0)
+ new_db._db_handle = dbapi2.connect(':memory:')
+ new_db._db_handle.cursor().executescript(tmpfile.read())
+ new_db._db_handle.commit()
+ new_db._set_replica_uid(db._replica_uid)
+ new_db._factory = db._factory
+ return new_db
+
+
+def make_document_for_test(test, doc_id, rev, content, has_conflicts=False):
+ return Document(doc_id, rev, content, has_conflicts=has_conflicts)
+
+
+def make_c_database_for_test(test, replica_uid):
+ if c_backend_wrapper is None:
+ test.skipTest('c_backend_wrapper is not available')
+ db = c_backend_wrapper.CDatabase(':memory:')
+ db._set_replica_uid(replica_uid)
+ return db
+
+
+def copy_c_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ if c_backend_wrapper is None:
+ test.skipTest('c_backend_wrapper is not available')
+ new_db = db._copy(db)
+ return new_db
+
+
+def make_c_document_for_test(test, doc_id, rev, content, has_conflicts=False):
+ if c_backend_wrapper is None:
+ test.skipTest('c_backend_wrapper is not available')
+ return c_backend_wrapper.make_document(
+ doc_id, rev, content, has_conflicts=has_conflicts)
+
+
+LOCAL_DATABASES_SCENARIOS = [
+ ('mem', {'make_database_for_test': make_memory_database_for_test,
+ 'copy_database_for_test': copy_memory_database_for_test,
+ 'make_document_for_test': make_document_for_test}),
+ ('sql', {'make_database_for_test':
+ make_sqlite_partial_expanded_for_test,
+ 'copy_database_for_test':
+ copy_sqlite_partial_expanded_for_test,
+ 'make_document_for_test': make_document_for_test}),
+ ]
+
+
+C_DATABASE_SCENARIOS = [
+ ('c', {'make_database_for_test': make_c_database_for_test,
+ 'copy_database_for_test': copy_c_database_for_test,
+ 'make_document_for_test': make_c_document_for_test})]
+
+
+class DatabaseBaseTests(TestCase):
+
+ accept_fixed_trans_id = False # set to True assertTransactionLog
+ # is happy with all trans ids = ''
+
+ scenarios = LOCAL_DATABASES_SCENARIOS
+
+ def create_database(self, replica_uid):
+ return self.make_database_for_test(self, replica_uid)
+
+ def copy_database(self, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
+ # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
+ # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
+ # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
+ # NINJA TO YOUR HOUSE.
+ return self.copy_database_for_test(self, db)
+
+ def setUp(self):
+ super(DatabaseBaseTests, self).setUp()
+ self.db = self.create_database('test')
+
+ def tearDown(self):
+ # TODO: Add close_database parameterization
+ # self.close_database(self.db)
+ super(DatabaseBaseTests, self).tearDown()
+
+ def assertTransactionLog(self, doc_ids, db):
+ """Assert that the given docs are in the transaction log."""
+ log = db._get_transaction_log()
+ just_ids = []
+ seen_transactions = set()
+ for doc_id, transaction_id in log:
+ just_ids.append(doc_id)
+ self.assertIsNot(None, transaction_id,
+ "Transaction id should not be None")
+ if transaction_id == '' and self.accept_fixed_trans_id:
+ continue
+ self.assertNotEqual('', transaction_id,
+ "Transaction id should be a unique string")
+ self.assertTrue(transaction_id.startswith('T-'))
+ self.assertNotIn(transaction_id, seen_transactions)
+ seen_transactions.add(transaction_id)
+ self.assertEqual(doc_ids, just_ids)
+
+ def getLastTransId(self, db):
+ """Return the transaction id for the last database update."""
+ return self.db._get_transaction_log()[-1][-1]
+
+
+class ServerStateForTests(server_state.ServerState):
+ """Used in the test suite, so we don't have to touch disk, etc."""
+
+ def __init__(self):
+ super(ServerStateForTests, self).__init__()
+ self._dbs = {}
+
+ def open_database(self, path):
+ try:
+ return self._dbs[path]
+ except KeyError:
+ raise errors.DatabaseDoesNotExist
+
+ def check_database(self, path):
+ # cares only about the possible exception
+ self.open_database(path)
+
+ def ensure_database(self, path):
+ try:
+ db = self.open_database(path)
+ except errors.DatabaseDoesNotExist:
+ db = self._create_database(path)
+ return db, db._replica_uid
+
+ def _copy_database(self, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
+ # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
+ # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
+ # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
+ # NINJA TO YOUR HOUSE.
+ new_db = copy_memory_database_for_test(None, db)
+ path = db._replica_uid
+ while path in self._dbs:
+ path += 'copy'
+ self._dbs[path] = new_db
+ return new_db
+
+ def _create_database(self, path):
+ db = inmemory.InMemoryDatabase(path)
+ self._dbs[path] = db
+ return db
+
+ def delete_database(self, path):
+ del self._dbs[path]
+
+
+class ResponderForTests(object):
+ """Responder for tests."""
+ _started = False
+ sent_response = False
+ status = None
+
+ def start_response(self, status='success', **kwargs):
+ self._started = True
+ self.status = status
+ self.kwargs = kwargs
+
+ def send_response(self, status='success', **kwargs):
+ self.start_response(status, **kwargs)
+ self.finish_response()
+
+ def finish_response(self):
+ self.sent_response = True
+
+
+class TestCaseWithServer(TestCase):
+
+ @staticmethod
+ def server_def():
+ # hook point
+ # should return (ServerClass, "shutdown method name", "url_scheme")
+ class _RequestHandler(simple_server.WSGIRequestHandler):
+ def log_request(*args):
+ pass # suppress
+
+ def make_server(host_port, application):
+ assert application, "forgot to override make_app(_with_state)?"
+ srv = simple_server.WSGIServer(host_port, _RequestHandler)
+ # patch the value in if it's None
+ if getattr(application, 'base_url', 1) is None:
+ application.base_url = "http://%s:%s" % srv.server_address
+ srv.set_app(application)
+ return srv
+
+ return make_server, "shutdown", "http"
+
+ @staticmethod
+ def make_app_with_state(state):
+ # hook point
+ return None
+
+ def make_app(self):
+ # potential hook point
+ self.request_state = ServerStateForTests()
+ return self.make_app_with_state(self.request_state)
+
+ def setUp(self):
+ super(TestCaseWithServer, self).setUp()
+ self.server = self.server_thread = None
+
+ @property
+ def url_scheme(self):
+ return self.server_def()[-1]
+
+ def startServer(self):
+ server_def = self.server_def()
+ server_class, shutdown_meth, _ = server_def
+ application = self.make_app()
+ self.server = server_class(('127.0.0.1', 0), application)
+ self.server_thread = threading.Thread(target=self.server.serve_forever,
+ kwargs=dict(poll_interval=0.01))
+ self.server_thread.start()
+ self.addCleanup(self.server_thread.join)
+ self.addCleanup(getattr(self.server, shutdown_meth))
+
+ def getURL(self, path=None):
+ host, port = self.server.server_address
+ if path is None:
+ path = ''
+ return '%s://%s:%s/%s' % (self.url_scheme, host, port, path)
+
+
+def socket_pair():
+ """Return a pair of TCP sockets connected to each other.
+
+ Unlike socket.socketpair, this should work on Windows.
+ """
+ sock_pair = getattr(socket, 'socket_pair', None)
+ if sock_pair:
+ return sock_pair(socket.AF_INET, socket.SOCK_STREAM)
+ listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ listen_sock.bind(('127.0.0.1', 0))
+ listen_sock.listen(1)
+ client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ client_sock.connect(listen_sock.getsockname())
+ server_sock, addr = listen_sock.accept()
+ listen_sock.close()
+ return server_sock, client_sock
+
+
+# OAuth related testing
+
+consumer1 = oauth.OAuthConsumer('K1', 'S1')
+token1 = oauth.OAuthToken('kkkk1', 'XYZ')
+consumer2 = oauth.OAuthConsumer('K2', 'S2')
+token2 = oauth.OAuthToken('kkkk2', 'ZYX')
+token3 = oauth.OAuthToken('kkkk3', 'ZYX')
+
+
+class TestingOAuthDataStore(oauth.OAuthDataStore):
+ """In memory predefined OAuthDataStore for testing."""
+
+ consumers = {
+ consumer1.key: consumer1,
+ consumer2.key: consumer2,
+ }
+
+ tokens = {
+ token1.key: token1,
+ token2.key: token2
+ }
+
+ def lookup_consumer(self, key):
+ return self.consumers.get(key)
+
+ def lookup_token(self, token_type, token_token):
+ return self.tokens.get(token_token)
+
+ def lookup_nonce(self, oauth_consumer, oauth_token, nonce):
+ return None
+
+testingOAuthStore = TestingOAuthDataStore()
+
+sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1()
+sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT()
+
+
+def load_with_scenarios(loader, standard_tests, pattern):
+ """Load the tests in a given module.
+
+ This just applies testscenarios.generate_scenarios to all the tests that
+ are present. We do it at load time rather than at run time, because it
+ plays nicer with various tools.
+ """
+ suite = loader.suiteClass()
+ suite.addTests(testscenarios.generate_scenarios(standard_tests))
+ return suite
diff --git a/u1db/tests/c_backend_wrapper.pyx b/u1db/tests/c_backend_wrapper.pyx
new file mode 100644
index 00000000..8a4b600d
--- /dev/null
+++ b/u1db/tests/c_backend_wrapper.pyx
@@ -0,0 +1,1541 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+#
+"""A Cython wrapper around the C implementation of U1DB Database backend."""
+
+cdef extern from "Python.h":
+ object PyString_FromStringAndSize(char *s, Py_ssize_t n)
+ int PyString_AsStringAndSize(object o, char **buf, Py_ssize_t *length
+ ) except -1
+ char *PyString_AsString(object) except NULL
+ char *PyString_AS_STRING(object)
+ char *strdup(char *)
+ void *calloc(size_t, size_t)
+ void free(void *)
+ ctypedef struct FILE:
+ pass
+ fprintf(FILE *, char *, ...)
+ FILE *stderr
+ size_t strlen(char *)
+
+cdef extern from "stdarg.h":
+ ctypedef struct va_list:
+ pass
+ void va_start(va_list, void*)
+ void va_start_int "va_start" (va_list, int)
+ void va_end(va_list)
+
+cdef extern from "u1db/u1db.h":
+ ctypedef struct u1database:
+ pass
+ ctypedef struct u1db_document:
+ char *doc_id
+ size_t doc_id_len
+ char *doc_rev
+ size_t doc_rev_len
+ char *json
+ size_t json_len
+ int has_conflicts
+ # Note: u1query is actually defined in u1db_internal.h, and in u1db.h it is
+ # just an opaque pointer. However, older versions of Cython don't let
+ # you have a forward declaration and a full declaration, so we just
+ # expose the whole thing here.
+ ctypedef struct u1query:
+ char *index_name
+ int num_fields
+ char **fields
+ cdef struct u1db_oauth_creds:
+ int auth_kind
+ char *consumer_key
+ char *consumer_secret
+ char *token_key
+ char *token_secret
+ ctypedef union u1db_creds
+ ctypedef u1db_creds* const_u1db_creds_ptr "const u1db_creds *"
+
+ ctypedef char* const_char_ptr "const char*"
+ ctypedef int (*u1db_doc_callback)(void *context, u1db_document *doc)
+ ctypedef int (*u1db_key_callback)(void *context, int num_fields,
+ const_char_ptr *key)
+ ctypedef int (*u1db_doc_gen_callback)(void *context,
+ u1db_document *doc, int gen, const_char_ptr trans_id)
+ ctypedef int (*u1db_trans_info_callback)(void *context,
+ const_char_ptr doc_id, int gen, const_char_ptr trans_id)
+
+ u1database * u1db_open(char *fname)
+ void u1db_free(u1database **)
+ int u1db_set_replica_uid(u1database *, char *replica_uid)
+ int u1db_set_document_size_limit(u1database *, int limit)
+ int u1db_get_replica_uid(u1database *, const_char_ptr *replica_uid)
+ int u1db_create_doc_from_json(u1database *db, char *json, char *doc_id,
+ u1db_document **doc)
+ int u1db_delete_doc(u1database *db, u1db_document *doc)
+ int u1db_get_doc(u1database *db, char *doc_id, int include_deleted,
+ u1db_document **doc)
+ int u1db_get_docs(u1database *db, int n_doc_ids, const_char_ptr *doc_ids,
+ int check_for_conflicts, int include_deleted,
+ void *context, u1db_doc_callback cb)
+ int u1db_get_all_docs(u1database *db, int include_deleted, int *generation,
+ void *context, u1db_doc_callback cb)
+ int u1db_put_doc(u1database *db, u1db_document *doc)
+ int u1db__validate_source(u1database *db, const_char_ptr replica_uid,
+ int replica_gen, const_char_ptr replica_trans_id)
+ int u1db__put_doc_if_newer(u1database *db, u1db_document *doc,
+ int save_conflict, char *replica_uid,
+ int replica_gen, char *replica_trans_id,
+ int *state, int *at_gen)
+ int u1db_resolve_doc(u1database *db, u1db_document *doc,
+ int n_revs, const_char_ptr *revs)
+ int u1db_delete_doc(u1database *db, u1db_document *doc)
+ int u1db_whats_changed(u1database *db, int *gen, char **trans_id,
+ void *context, u1db_trans_info_callback cb)
+ int u1db__get_transaction_log(u1database *db, void *context,
+ u1db_trans_info_callback cb)
+ int u1db_get_doc_conflicts(u1database *db, char *doc_id, void *context,
+ u1db_doc_callback cb)
+ int u1db_sync(u1database *db, const_char_ptr url,
+ const_u1db_creds_ptr creds, int *local_gen) nogil
+ int u1db_create_index_list(u1database *db, char *index_name,
+ int n_expressions, const_char_ptr *expressions)
+ int u1db_create_index(u1database *db, char *index_name, int n_expressions,
+ ...)
+ int u1db_get_from_index_list(u1database *db, u1query *query, void *context,
+ u1db_doc_callback cb, int n_values,
+ const_char_ptr *values)
+ int u1db_get_from_index(u1database *db, u1query *query, void *context,
+ u1db_doc_callback cb, int n_values, char *val0,
+ ...)
+ int u1db_get_range_from_index(u1database *db, u1query *query,
+ void *context, u1db_doc_callback cb,
+ int n_values, const_char_ptr *start_values,
+ const_char_ptr *end_values)
+ int u1db_delete_index(u1database *db, char *index_name)
+ int u1db_list_indexes(u1database *db, void *context,
+ int (*cb)(void *context, const_char_ptr index_name,
+ int n_expressions, const_char_ptr *expressions))
+ int u1db_get_index_keys(u1database *db, char *index_name, void *context,
+ u1db_key_callback cb)
+ int u1db_simple_lookup1(u1database *db, char *index_name, char *val1,
+ void *context, u1db_doc_callback cb)
+ int u1db_query_init(u1database *db, char *index_name, u1query **query)
+ void u1db_free_query(u1query **query)
+
+ int U1DB_OK
+ int U1DB_INVALID_PARAMETER
+ int U1DB_REVISION_CONFLICT
+ int U1DB_INVALID_DOC_ID
+ int U1DB_DOCUMENT_ALREADY_DELETED
+ int U1DB_DOCUMENT_DOES_NOT_EXIST
+ int U1DB_NOT_IMPLEMENTED
+ int U1DB_INVALID_JSON
+ int U1DB_DOCUMENT_TOO_BIG
+ int U1DB_USER_QUOTA_EXCEEDED
+ int U1DB_INVALID_VALUE_FOR_INDEX
+ int U1DB_INVALID_FIELD_SPECIFIER
+ int U1DB_INVALID_GLOBBING
+ int U1DB_BROKEN_SYNC_STREAM
+ int U1DB_DUPLICATE_INDEX_NAME
+ int U1DB_INDEX_DOES_NOT_EXIST
+ int U1DB_INVALID_GENERATION
+ int U1DB_INVALID_TRANSACTION_ID
+ int U1DB_INVALID_TRANSFORMATION_FUNCTION
+ int U1DB_UNKNOWN_OPERATION
+ int U1DB_INTERNAL_ERROR
+ int U1DB_TARGET_UNAVAILABLE
+
+ int U1DB_INSERTED
+ int U1DB_SUPERSEDED
+ int U1DB_CONVERGED
+ int U1DB_CONFLICTED
+
+ int U1DB_OAUTH_AUTH
+
+ void u1db_free_doc(u1db_document **doc)
+ int u1db_doc_set_json(u1db_document *doc, char *json)
+ int u1db_doc_get_size(u1db_document *doc)
+
+
+cdef extern from "u1db/u1db_internal.h":
+ ctypedef struct u1db_row:
+ u1db_row *next
+ int num_columns
+ int *column_sizes
+ unsigned char **columns
+
+ ctypedef struct u1db_table:
+ int status
+ u1db_row *first_row
+
+ ctypedef struct u1db_record:
+ u1db_record *next
+ char *doc_id
+ char *doc_rev
+ char *doc
+
+ ctypedef struct u1db_sync_exchange:
+ int target_gen
+ int num_doc_ids
+ char **doc_ids_to_return
+ int *gen_for_doc_ids
+ const_char_ptr *trans_ids_for_doc_ids
+
+ ctypedef int (*u1db__trace_callback)(void *context, const_char_ptr state)
+ ctypedef struct u1db_sync_target:
+ int (*get_sync_info)(u1db_sync_target *st, char *source_replica_uid,
+ const_char_ptr *st_replica_uid, int *st_gen,
+ char **st_trans_id, int *source_gen,
+ char **source_trans_id) nogil
+ int (*record_sync_info)(u1db_sync_target *st,
+ char *source_replica_uid, int source_gen, char *trans_id) nogil
+ int (*sync_exchange)(u1db_sync_target *st,
+ char *source_replica_uid, int n_docs,
+ u1db_document **docs, int *generations,
+ const_char_ptr *trans_ids,
+ int *target_gen, char **target_trans_id,
+ void *context, u1db_doc_gen_callback cb,
+ void *ensure_callback) nogil
+ int (*sync_exchange_doc_ids)(u1db_sync_target *st,
+ u1database *source_db, int n_doc_ids,
+ const_char_ptr *doc_ids, int *generations,
+ const_char_ptr *trans_ids,
+ int *target_gen, char **target_trans_id,
+ void *context,
+ u1db_doc_gen_callback cb,
+ void *ensure_callback) nogil
+ int (*get_sync_exchange)(u1db_sync_target *st,
+ char *source_replica_uid,
+ int last_known_source_gen,
+ u1db_sync_exchange **exchange) nogil
+ void (*finalize_sync_exchange)(u1db_sync_target *st,
+ u1db_sync_exchange **exchange) nogil
+ int (*_set_trace_hook)(u1db_sync_target *st,
+ void *context, u1db__trace_callback cb) nogil
+
+
+ void u1db__set_zero_delays()
+ int u1db__get_generation(u1database *, int *db_rev)
+ int u1db__get_document_size_limit(u1database *, int *limit)
+ int u1db__get_generation_info(u1database *, int *db_rev, char **trans_id)
+ int u1db__get_trans_id_for_gen(u1database *, int db_rev, char **trans_id)
+ int u1db_validate_gen_and_trans_id(u1database *, int db_rev,
+ const_char_ptr trans_id)
+ char *u1db__allocate_doc_id(u1database *)
+ int u1db__sql_close(u1database *)
+ u1database *u1db__copy(u1database *)
+ int u1db__sql_is_open(u1database *)
+ u1db_table *u1db__sql_run(u1database *, char *sql, size_t n)
+ void u1db__free_table(u1db_table **table)
+ u1db_record *u1db__create_record(char *doc_id, char *doc_rev, char *doc)
+ void u1db__free_records(u1db_record **)
+
+ int u1db__allocate_document(char *doc_id, char *revision, char *content,
+ int has_conflicts, u1db_document **result)
+ int u1db__generate_hex_uuid(char *)
+
+ int u1db__get_replica_gen_and_trans_id(u1database *db, char *replica_uid,
+ int *generation, char **trans_id)
+ int u1db__set_replica_gen_and_trans_id(u1database *db, char *replica_uid,
+ int generation, char *trans_id)
+ int u1db__sync_get_machine_info(u1database *db, char *other_replica_uid,
+ int *other_db_rev, char **my_replica_uid,
+ int *my_db_rev)
+ int u1db__sync_record_machine_info(u1database *db, char *replica_uid,
+ int db_rev)
+ int u1db__sync_exchange_seen_ids(u1db_sync_exchange *se, int *n_ids,
+ const_char_ptr **doc_ids)
+ int u1db__format_query(int n_fields, const_char_ptr *values, char **buf,
+ int *wildcard)
+ int u1db__get_sync_target(u1database *db, u1db_sync_target **sync_target)
+ int u1db__free_sync_target(u1db_sync_target **sync_target)
+ int u1db__sync_db_to_target(u1database *db, u1db_sync_target *target,
+ int *local_gen_before_sync) nogil
+
+ int u1db__sync_exchange_insert_doc_from_source(u1db_sync_exchange *se,
+ u1db_document *doc, int source_gen, const_char_ptr trans_id)
+ int u1db__sync_exchange_find_doc_ids_to_return(u1db_sync_exchange *se)
+ int u1db__sync_exchange_return_docs(u1db_sync_exchange *se, void *context,
+ int (*cb)(void *context,
+ u1db_document *doc, int gen,
+ const_char_ptr trans_id))
+ int u1db__create_http_sync_target(char *url, u1db_sync_target **target)
+ int u1db__create_oauth_http_sync_target(char *url,
+ char *consumer_key, char *consumer_secret,
+ char *token_key, char *token_secret,
+ u1db_sync_target **target)
+
+cdef extern from "u1db/u1db_http_internal.h":
+ int u1db__format_sync_url(u1db_sync_target *st,
+ const_char_ptr source_replica_uid, char **sync_url)
+ int u1db__get_oauth_authorization(u1db_sync_target *st,
+ char *http_method, char *url,
+ char **oauth_authorization)
+
+
+cdef extern from "u1db/u1db_vectorclock.h":
+ ctypedef struct u1db_vectorclock_item:
+ char *replica_uid
+ int generation
+
+ ctypedef struct u1db_vectorclock:
+ int num_items
+ u1db_vectorclock_item *items
+
+ u1db_vectorclock *u1db__vectorclock_from_str(char *s)
+ void u1db__free_vectorclock(u1db_vectorclock **clock)
+ int u1db__vectorclock_increment(u1db_vectorclock *clock, char *replica_uid)
+ int u1db__vectorclock_maximize(u1db_vectorclock *clock,
+ u1db_vectorclock *other)
+ int u1db__vectorclock_as_str(u1db_vectorclock *clock, char **result)
+ int u1db__vectorclock_is_newer(u1db_vectorclock *maybe_newer,
+ u1db_vectorclock *older)
+
+from u1db import errors
+from sqlite3 import dbapi2
+
+
+cdef int _append_trans_info_to_list(void *context, const_char_ptr doc_id,
+ int generation,
+ const_char_ptr trans_id) with gil:
+ a_list = <object>(context)
+ doc = doc_id
+ a_list.append((doc, generation, trans_id))
+ return 0
+
+
+cdef int _append_doc_to_list(void *context, u1db_document *doc) with gil:
+ a_list = <object>context
+ pydoc = CDocument()
+ pydoc._doc = doc
+ a_list.append(pydoc)
+ return 0
+
+cdef int _append_key_to_list(void *context, int num_fields,
+ const_char_ptr *key) with gil:
+ a_list = <object>(context)
+ field_list = []
+ for i from 0 <= i < num_fields:
+ field = key[i]
+ field_list.append(field.decode('utf-8'))
+ a_list.append(tuple(field_list))
+ return 0
+
+cdef _list_to_array(lst, const_char_ptr **res, int *count):
+ cdef const_char_ptr *tmp
+ count[0] = len(lst)
+ tmp = <const_char_ptr*>calloc(sizeof(char*), count[0])
+ for idx, x in enumerate(lst):
+ tmp[idx] = x
+ res[0] = tmp
+
+cdef _list_to_str_array(lst, const_char_ptr **res, int *count):
+ cdef const_char_ptr *tmp
+ count[0] = len(lst)
+ tmp = <const_char_ptr*>calloc(sizeof(char*), count[0])
+ new_objs = []
+ for idx, x in enumerate(lst):
+ if isinstance(x, unicode):
+ x = x.encode('utf-8')
+ new_objs.append(x)
+ tmp[idx] = x
+ res[0] = tmp
+ return new_objs
+
+
+cdef int _append_index_definition_to_list(void *context,
+ const_char_ptr index_name, int n_expressions,
+ const_char_ptr *expressions) with gil:
+ cdef int i
+
+ a_list = <object>(context)
+ exp_list = []
+ for i from 0 <= i < n_expressions:
+ s = expressions[i]
+ exp_list.append(s.decode('utf-8'))
+ a_list.append((index_name, exp_list))
+ return 0
+
+
+cdef int return_doc_cb_wrapper(void *context, u1db_document *doc,
+ int gen, const_char_ptr trans_id) with gil:
+ cdef CDocument pydoc
+ user_cb = <object>context
+ pydoc = CDocument()
+ pydoc._doc = doc
+ try:
+ user_cb(pydoc, gen, trans_id)
+ except Exception, e:
+ # We suppress the exception here, because intermediating through the C
+ # layer gets a bit crazy
+ return U1DB_INVALID_PARAMETER
+ return U1DB_OK
+
+
+cdef int _trace_hook(void *context, const_char_ptr state) with gil:
+ if context == NULL:
+ return U1DB_INVALID_PARAMETER
+ ctx = <object>context
+ try:
+ ctx(state)
+ except:
+ # Note: It would be nice if we could map the Python exception into
+ # something in C
+ return U1DB_INTERNAL_ERROR
+ return U1DB_OK
+
+
+cdef char *_ensure_str(object obj, object extra_objs) except NULL:
+ """Ensure that we have the UTF-8 representation of a parameter.
+
+ :param obj: A Unicode or String object.
+ :param extra_objs: This should be a Python list. If we have to convert obj
+ from being a Unicode object, this will hold the PyString object so that
+ we know the char* lifetime will be correct.
+ :return: A C pointer to the UTF-8 representation.
+ """
+ if isinstance(obj, unicode):
+ obj = obj.encode('utf-8')
+ extra_objs.append(obj)
+ return PyString_AsString(obj)
+
+
+def _format_query(fields):
+ """Wrapper around u1db__format_query for testing."""
+ cdef int status
+ cdef char *buf
+ cdef int wildcard[10]
+ cdef const_char_ptr *values
+ cdef int n_values
+
+ # keep a reference to new_objs so that the pointers in expressions
+ # remain valid.
+ new_objs = _list_to_str_array(fields, &values, &n_values)
+ try:
+ status = u1db__format_query(n_values, values, &buf, wildcard)
+ finally:
+ free(<void*>values)
+ handle_status("format_query", status)
+ if buf == NULL:
+ res = None
+ else:
+ res = buf
+ free(buf)
+ w = []
+ for i in range(len(fields)):
+ w.append(wildcard[i])
+ return res, w
+
+
+def make_document(doc_id, rev, content, has_conflicts=False):
+ cdef u1db_document *doc
+ cdef char *c_content = NULL, *c_rev = NULL, *c_doc_id = NULL
+ cdef int conflict
+
+ if has_conflicts:
+ conflict = 1
+ else:
+ conflict = 0
+ if doc_id is None:
+ c_doc_id = NULL
+ else:
+ c_doc_id = doc_id
+ if content is None:
+ c_content = NULL
+ else:
+ c_content = content
+ if rev is None:
+ c_rev = NULL
+ else:
+ c_rev = rev
+ handle_status(
+ "make_document",
+ u1db__allocate_document(c_doc_id, c_rev, c_content, conflict, &doc))
+ pydoc = CDocument()
+ pydoc._doc = doc
+ return pydoc
+
+
+def generate_hex_uuid():
+ uuid = PyString_FromStringAndSize(NULL, 32)
+ handle_status(
+ "Failed to generate uuid",
+ u1db__generate_hex_uuid(PyString_AS_STRING(uuid)))
+ return uuid
+
+
+cdef class CDocument(object):
+ """A thin wrapper around the C Document struct."""
+
+ cdef u1db_document *_doc
+
+ def __init__(self):
+ self._doc = NULL
+
+ def __dealloc__(self):
+ u1db_free_doc(&self._doc)
+
+ property doc_id:
+ def __get__(self):
+ if self._doc.doc_id == NULL:
+ return None
+ return PyString_FromStringAndSize(
+ self._doc.doc_id, self._doc.doc_id_len)
+
+ property rev:
+ def __get__(self):
+ if self._doc.doc_rev == NULL:
+ return None
+ return PyString_FromStringAndSize(
+ self._doc.doc_rev, self._doc.doc_rev_len)
+
+ def get_json(self):
+ if self._doc.json == NULL:
+ return None
+ return PyString_FromStringAndSize(
+ self._doc.json, self._doc.json_len)
+
+ def set_json(self, val):
+ u1db_doc_set_json(self._doc, val)
+
+ def get_size(self):
+ return u1db_doc_get_size(self._doc)
+
+ property has_conflicts:
+ def __get__(self):
+ if self._doc.has_conflicts:
+ return True
+ return False
+
+ def __repr__(self):
+ if self._doc.has_conflicts:
+ extra = ', conflicted'
+ else:
+ extra = ''
+ return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id,
+ self.rev, extra, self.get_json())
+
+ def __hash__(self):
+ raise NotImplementedError(self.__hash__)
+
+ def __richcmp__(self, other, int t):
+ try:
+ if t == 0: # Py_LT <
+ return ((self.doc_id, self.rev, self.get_json())
+ < (other.doc_id, other.rev, other.get_json()))
+ elif t == 2: # Py_EQ ==
+ return (self.doc_id == other.doc_id
+ and self.rev == other.rev
+ and self.get_json() == other.get_json()
+ and self.has_conflicts == other.has_conflicts)
+ except AttributeError:
+ # Fall through to NotImplemented
+ pass
+
+ return NotImplemented
+
+
+cdef object safe_str(const_char_ptr s):
+ if s == NULL:
+ return None
+ return s
+
+
+cdef class CQuery:
+
+ cdef u1query *_query
+
+ def __init__(self):
+ self._query = NULL
+
+ def __dealloc__(self):
+ u1db_free_query(&self._query)
+
+ def _check(self):
+ if self._query == NULL:
+ raise RuntimeError("No valid _query.")
+
+ property index_name:
+ def __get__(self):
+ self._check()
+ return safe_str(self._query.index_name)
+
+ property num_fields:
+ def __get__(self):
+ self._check()
+ return self._query.num_fields
+
+ property fields:
+ def __get__(self):
+ cdef int i
+ self._check()
+ fields = []
+ for i from 0 <= i < self._query.num_fields:
+ fields.append(safe_str(self._query.fields[i]))
+ return fields
+
+
+cdef handle_status(context, int status):
+ if status == U1DB_OK:
+ return
+ if status == U1DB_REVISION_CONFLICT:
+ raise errors.RevisionConflict()
+ if status == U1DB_INVALID_DOC_ID:
+ raise errors.InvalidDocId()
+ if status == U1DB_DOCUMENT_ALREADY_DELETED:
+ raise errors.DocumentAlreadyDeleted()
+ if status == U1DB_DOCUMENT_DOES_NOT_EXIST:
+ raise errors.DocumentDoesNotExist()
+ if status == U1DB_INVALID_PARAMETER:
+ raise RuntimeError('Bad parameters supplied')
+ if status == U1DB_NOT_IMPLEMENTED:
+ raise NotImplementedError("Functionality not implemented yet: %s"
+ % (context,))
+ if status == U1DB_INVALID_VALUE_FOR_INDEX:
+ raise errors.InvalidValueForIndex()
+ if status == U1DB_INVALID_GLOBBING:
+ raise errors.InvalidGlobbing()
+ if status == U1DB_INTERNAL_ERROR:
+ raise errors.U1DBError("internal error")
+ if status == U1DB_BROKEN_SYNC_STREAM:
+ raise errors.BrokenSyncStream()
+ if status == U1DB_CONFLICTED:
+ raise errors.ConflictedDoc()
+ if status == U1DB_DUPLICATE_INDEX_NAME:
+ raise errors.IndexNameTakenError()
+ if status == U1DB_INDEX_DOES_NOT_EXIST:
+ raise errors.IndexDoesNotExist
+ if status == U1DB_INVALID_GENERATION:
+ raise errors.InvalidGeneration
+ if status == U1DB_INVALID_TRANSACTION_ID:
+ raise errors.InvalidTransactionId
+ if status == U1DB_TARGET_UNAVAILABLE:
+ raise errors.Unavailable
+ if status == U1DB_INVALID_JSON:
+ raise errors.InvalidJSON
+ if status == U1DB_DOCUMENT_TOO_BIG:
+ raise errors.DocumentTooBig
+ if status == U1DB_USER_QUOTA_EXCEEDED:
+ raise errors.UserQuotaExceeded
+ if status == U1DB_INVALID_TRANSFORMATION_FUNCTION:
+ raise errors.IndexDefinitionParseError
+ if status == U1DB_UNKNOWN_OPERATION:
+ raise errors.IndexDefinitionParseError
+ if status == U1DB_INVALID_FIELD_SPECIFIER:
+ raise errors.IndexDefinitionParseError()
+ raise RuntimeError('%s (status: %s)' % (context, status))
+
+
+cdef class CDatabase
+cdef class CSyncTarget
+
+cdef class CSyncExchange(object):
+
+ cdef u1db_sync_exchange *_exchange
+ cdef CSyncTarget _target
+
+ def __init__(self, CSyncTarget target, source_replica_uid, source_gen):
+ self._target = target
+ assert self._target._st.get_sync_exchange != NULL, \
+ "get_sync_exchange is NULL?"
+ handle_status("get_sync_exchange",
+ self._target._st.get_sync_exchange(self._target._st,
+ source_replica_uid, source_gen, &self._exchange))
+
+ def __dealloc__(self):
+ if self._target is not None and self._target._st != NULL:
+ self._target._st.finalize_sync_exchange(self._target._st,
+ &self._exchange)
+
+ def _check(self):
+ if self._exchange == NULL:
+ raise RuntimeError("self._exchange is NULL")
+
+ property target_gen:
+ def __get__(self):
+ self._check()
+ return self._exchange.target_gen
+
+ def insert_doc_from_source(self, CDocument doc, source_gen,
+ source_trans_id):
+ self._check()
+ handle_status("insert_doc_from_source",
+ u1db__sync_exchange_insert_doc_from_source(self._exchange,
+ doc._doc, source_gen, source_trans_id))
+
+ def find_doc_ids_to_return(self):
+ self._check()
+ handle_status("find_doc_ids_to_return",
+ u1db__sync_exchange_find_doc_ids_to_return(self._exchange))
+
+ def return_docs(self, return_doc_cb):
+ self._check()
+ handle_status("return_docs",
+ u1db__sync_exchange_return_docs(self._exchange,
+ <void *>return_doc_cb, &return_doc_cb_wrapper))
+
+ def get_seen_ids(self):
+ cdef const_char_ptr *seen_ids
+ cdef int i, n_ids
+ self._check()
+ handle_status("sync_exchange_seen_ids",
+ u1db__sync_exchange_seen_ids(self._exchange, &n_ids, &seen_ids))
+ res = []
+ for i from 0 <= i < n_ids:
+ res.append(seen_ids[i])
+ if (seen_ids != NULL):
+ free(<void*>seen_ids)
+ return res
+
+ def get_doc_ids_to_return(self):
+ self._check()
+ res = []
+ if (self._exchange.num_doc_ids > 0
+ and self._exchange.doc_ids_to_return != NULL):
+ for i from 0 <= i < self._exchange.num_doc_ids:
+ res.append(
+ (self._exchange.doc_ids_to_return[i],
+ self._exchange.gen_for_doc_ids[i],
+ self._exchange.trans_ids_for_doc_ids[i]))
+ return res
+
+
+cdef class CSyncTarget(object):
+
+ cdef u1db_sync_target *_st
+ cdef CDatabase _db
+
+ def __init__(self):
+ self._db = None
+ self._st = NULL
+ u1db__set_zero_delays()
+
+ def __dealloc__(self):
+ u1db__free_sync_target(&self._st)
+
+ def _check(self):
+ if self._st == NULL:
+ raise RuntimeError("self._st is NULL")
+
+ def get_sync_info(self, source_replica_uid):
+ cdef const_char_ptr st_replica_uid = NULL
+ cdef int st_gen = 0, source_gen = 0, status
+ cdef char *trans_id = NULL
+ cdef char *st_trans_id = NULL
+ cdef char *c_source_replica_uid = NULL
+
+ self._check()
+ assert self._st.get_sync_info != NULL, "get_sync_info is NULL?"
+ c_source_replica_uid = source_replica_uid
+ with nogil:
+ status = self._st.get_sync_info(self._st, c_source_replica_uid,
+ &st_replica_uid, &st_gen, &st_trans_id, &source_gen, &trans_id)
+ handle_status("get_sync_info", status)
+ res_trans_id = None
+ res_st_trans_id = None
+ if trans_id != NULL:
+ res_trans_id = trans_id
+ free(trans_id)
+ if st_trans_id != NULL:
+ res_st_trans_id = st_trans_id
+ free(st_trans_id)
+ return (
+ safe_str(st_replica_uid), st_gen, res_st_trans_id, source_gen,
+ res_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_gen, source_trans_id):
+ cdef int status
+ cdef int c_source_gen
+ cdef char *c_source_replica_uid = NULL
+ cdef char *c_source_trans_id = NULL
+
+ self._check()
+ assert self._st.record_sync_info != NULL, "record_sync_info is NULL?"
+ c_source_replica_uid = source_replica_uid
+ c_source_gen = source_gen
+ c_source_trans_id = source_trans_id
+ with nogil:
+ status = self._st.record_sync_info(
+ self._st, c_source_replica_uid, c_source_gen,
+ c_source_trans_id)
+ handle_status("record_sync_info", status)
+
+ def _get_sync_exchange(self, source_replica_uid, source_gen):
+ self._check()
+ return CSyncExchange(self, source_replica_uid, source_gen)
+
+ def sync_exchange_doc_ids(self, source_db, doc_id_generations,
+ last_known_generation, last_known_trans_id,
+ return_doc_cb):
+ cdef const_char_ptr *doc_ids
+ cdef int *generations
+ cdef int num_doc_ids
+ cdef int target_gen
+ cdef char *target_trans_id = NULL
+ cdef int status
+ cdef CDatabase sdb
+
+ self._check()
+ assert self._st.sync_exchange_doc_ids != NULL, "sync_exchange_doc_ids is NULL?"
+ sdb = source_db
+ num_doc_ids = len(doc_id_generations)
+ doc_ids = <const_char_ptr *>calloc(num_doc_ids, sizeof(char *))
+ if doc_ids == NULL:
+ raise MemoryError
+ generations = <int *>calloc(num_doc_ids, sizeof(int))
+ if generations == NULL:
+ free(<void *>doc_ids)
+ raise MemoryError
+ trans_ids = <const_char_ptr*>calloc(num_doc_ids, sizeof(char *))
+ if trans_ids == NULL:
+ raise MemoryError
+ res_trans_id = ''
+ try:
+ for i, (doc_id, gen, trans_id) in enumerate(doc_id_generations):
+ doc_ids[i] = PyString_AsString(doc_id)
+ generations[i] = gen
+ trans_ids[i] = trans_id
+ target_gen = last_known_generation
+ if last_known_trans_id is not None:
+ target_trans_id = last_known_trans_id
+ with nogil:
+ status = self._st.sync_exchange_doc_ids(self._st, sdb._db,
+ num_doc_ids, doc_ids, generations, trans_ids,
+ &target_gen, &target_trans_id,
+ <void*>return_doc_cb, return_doc_cb_wrapper, NULL)
+ handle_status("sync_exchange_doc_ids", status)
+ if target_trans_id != NULL:
+ res_trans_id = target_trans_id
+ finally:
+ if target_trans_id != NULL:
+ free(target_trans_id)
+ if doc_ids != NULL:
+ free(<void *>doc_ids)
+ if generations != NULL:
+ free(generations)
+ if trans_ids != NULL:
+ free(trans_ids)
+ return target_gen, res_trans_id
+
+ def sync_exchange(self, docs_by_generations, source_replica_uid,
+ last_known_generation, last_known_trans_id,
+ return_doc_cb, ensure_callback=None):
+ cdef CDocument cur_doc
+ cdef u1db_document **docs = NULL
+ cdef int *generations = NULL
+ cdef const_char_ptr *trans_ids = NULL
+ cdef char *target_trans_id = NULL
+ cdef char *c_source_replica_uid = NULL
+ cdef int i, count, status, target_gen
+ assert ensure_callback is None # interface difference
+
+ self._check()
+ assert self._st.sync_exchange != NULL, "sync_exchange is NULL?"
+ count = len(docs_by_generations)
+ res_trans_id = ''
+ try:
+ docs = <u1db_document **>calloc(count, sizeof(u1db_document*))
+ if docs == NULL:
+ raise MemoryError
+ generations = <int*>calloc(count, sizeof(int))
+ if generations == NULL:
+ raise MemoryError
+ trans_ids = <const_char_ptr*>calloc(count, sizeof(char*))
+ if trans_ids == NULL:
+ raise MemoryError
+ for i from 0 <= i < count:
+ cur_doc = docs_by_generations[i][0]
+ generations[i] = docs_by_generations[i][1]
+ trans_ids[i] = docs_by_generations[i][2]
+ docs[i] = cur_doc._doc
+ target_gen = last_known_generation
+ if last_known_trans_id is not None:
+ target_trans_id = last_known_trans_id
+ c_source_replica_uid = source_replica_uid
+ with nogil:
+ status = self._st.sync_exchange(
+ self._st, c_source_replica_uid, count, docs, generations,
+ trans_ids, &target_gen, &target_trans_id,
+ <void *>return_doc_cb, return_doc_cb_wrapper, NULL)
+ handle_status("sync_exchange", status)
+ finally:
+ if docs != NULL:
+ free(docs)
+ if generations != NULL:
+ free(generations)
+ if trans_ids != NULL:
+ free(trans_ids)
+ if target_trans_id != NULL:
+ res_trans_id = target_trans_id
+ free(target_trans_id)
+ return target_gen, res_trans_id
+
+ def _set_trace_hook(self, cb):
+ self._check()
+ assert self._st._set_trace_hook != NULL, "_set_trace_hook is NULL?"
+ handle_status("_set_trace_hook",
+ self._st._set_trace_hook(self._st, <void*>cb, _trace_hook))
+
+ _set_trace_hook_shallow = _set_trace_hook
+
+
+cdef class CDatabase(object):
+ """A thin wrapper/shim to interact with the C implementation.
+
+ Functionality should not be written here. It is only provided as a way to
+ expose the C API to the python test suite.
+ """
+
+ cdef public object _filename
+ cdef u1database *_db
+ cdef public object _supports_indexes
+
+ def __init__(self, filename):
+ self._supports_indexes = False
+ self._filename = filename
+ self._db = u1db_open(self._filename)
+
+ def __dealloc__(self):
+ u1db_free(&self._db)
+
+ def close(self):
+ return u1db__sql_close(self._db)
+
+ def _copy(self, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ new_db = CDatabase(':memory:')
+ u1db_free(&new_db._db)
+ new_db._db = u1db__copy(self._db)
+ return new_db
+
+ def _sql_is_open(self):
+ if self._db == NULL:
+ return True
+ return u1db__sql_is_open(self._db)
+
+ property _replica_uid:
+ def __get__(self):
+ cdef const_char_ptr val
+ cdef int status
+ status = u1db_get_replica_uid(self._db, &val)
+ if status != 0:
+ if val != NULL:
+ err = str(val)
+ else:
+ err = "<unknown>"
+ raise RuntimeError("Failed to get_replica_uid: %d %s"
+ % (status, err))
+ if val == NULL:
+ return None
+ return str(val)
+
+ def _set_replica_uid(self, replica_uid):
+ cdef int status
+ status = u1db_set_replica_uid(self._db, replica_uid)
+ if status != 0:
+ raise RuntimeError('replica_uid could not be set to %s, error: %d'
+ % (replica_uid, status))
+
+ property document_size_limit:
+ def __get__(self):
+ cdef int limit
+ handle_status("document_size_limit",
+ u1db__get_document_size_limit(self._db, &limit))
+ return limit
+
+ def set_document_size_limit(self, limit):
+ cdef int status
+ status = u1db_set_document_size_limit(self._db, limit)
+ if status != 0:
+ raise RuntimeError(
+ "document_size_limit could not be set to %d, error: %d",
+ (limit, status))
+
+ def _allocate_doc_id(self):
+ cdef char *val
+ val = u1db__allocate_doc_id(self._db)
+ if val == NULL:
+ raise RuntimeError("Failed to allocate document id")
+ s = str(val)
+ free(val)
+ return s
+
+ def _run_sql(self, sql):
+ cdef u1db_table *tbl
+ cdef u1db_row *cur_row
+ cdef size_t n
+ cdef int i
+
+ if self._db == NULL:
+ raise RuntimeError("called _run_sql with a NULL pointer.")
+ tbl = u1db__sql_run(self._db, sql, len(sql))
+ if tbl == NULL:
+ raise MemoryError("Failed to allocate table memory.")
+ try:
+ if tbl.status != 0:
+ raise RuntimeError("Status was not 0: %d" % (tbl.status,))
+ # Now convert the table into python
+ res = []
+ cur_row = tbl.first_row
+ while cur_row != NULL:
+ row = []
+ for i from 0 <= i < cur_row.num_columns:
+ row.append(PyString_FromStringAndSize(
+ <char*>(cur_row.columns[i]), cur_row.column_sizes[i]))
+ res.append(tuple(row))
+ cur_row = cur_row.next
+ return res
+ finally:
+ u1db__free_table(&tbl)
+
+ def create_doc_from_json(self, json, doc_id=None):
+ cdef u1db_document *doc = NULL
+ cdef char *c_doc_id
+
+ if doc_id is None:
+ c_doc_id = NULL
+ else:
+ c_doc_id = doc_id
+ handle_status('Failed to create_doc',
+ u1db_create_doc_from_json(self._db, json, c_doc_id, &doc))
+ pydoc = CDocument()
+ pydoc._doc = doc
+ return pydoc
+
+ def put_doc(self, CDocument doc):
+ handle_status("Failed to put_doc",
+ u1db_put_doc(self._db, doc._doc))
+ return doc.rev
+
+ def _validate_source(self, replica_uid, replica_gen, replica_trans_id):
+ cdef const_char_ptr c_uid, c_trans_id
+ cdef int c_gen = 0
+
+ c_uid = replica_uid
+ c_trans_id = replica_trans_id
+ c_gen = replica_gen
+ handle_status(
+ "invalid generation or transaction id",
+ u1db__validate_source(self._db, c_uid, c_gen, c_trans_id))
+
+ def _put_doc_if_newer(self, CDocument doc, save_conflict, replica_uid=None,
+ replica_gen=None, replica_trans_id=None):
+ cdef char *c_uid, *c_trans_id
+ cdef int gen, state = 0, at_gen = -1
+
+ if replica_uid is None:
+ c_uid = NULL
+ else:
+ c_uid = replica_uid
+ if replica_trans_id is None:
+ c_trans_id = NULL
+ else:
+ c_trans_id = replica_trans_id
+ if replica_gen is None:
+ gen = 0
+ else:
+ gen = replica_gen
+ handle_status("Failed to _put_doc_if_newer",
+ u1db__put_doc_if_newer(self._db, doc._doc, save_conflict,
+ c_uid, gen, c_trans_id, &state, &at_gen))
+ if state == U1DB_INSERTED:
+ return 'inserted', at_gen
+ elif state == U1DB_SUPERSEDED:
+ return 'superseded', at_gen
+ elif state == U1DB_CONVERGED:
+ return 'converged', at_gen
+ elif state == U1DB_CONFLICTED:
+ return 'conflicted', at_gen
+ else:
+ raise RuntimeError("Unknown _put_doc_if_newer state: %d" % (state,))
+
+ def get_doc(self, doc_id, include_deleted=False):
+ cdef u1db_document *doc = NULL
+ deleted = 1 if include_deleted else 0
+ handle_status("get_doc failed",
+ u1db_get_doc(self._db, doc_id, deleted, &doc))
+ if doc == NULL:
+ return None
+ pydoc = CDocument()
+ pydoc._doc = doc
+ return pydoc
+
+ def get_docs(self, doc_ids, check_for_conflicts=True,
+ include_deleted=False):
+ cdef int n_doc_ids, conflicts
+ cdef const_char_ptr *c_doc_ids
+
+ _list_to_array(doc_ids, &c_doc_ids, &n_doc_ids)
+ deleted = 1 if include_deleted else 0
+ conflicts = 1 if check_for_conflicts else 0
+ a_list = []
+ handle_status("get_docs",
+ u1db_get_docs(self._db, n_doc_ids, c_doc_ids,
+ conflicts, deleted, <void*>a_list, _append_doc_to_list))
+ free(<void*>c_doc_ids)
+ return a_list
+
+ def get_all_docs(self, include_deleted=False):
+ cdef int c_generation
+
+ a_list = []
+ deleted = 1 if include_deleted else 0
+ generation = 0
+ c_generation = generation
+ handle_status(
+ "get_all_docs", u1db_get_all_docs(
+ self._db, deleted, &c_generation, <void*>a_list,
+ _append_doc_to_list))
+ return (c_generation, a_list)
+
+ def resolve_doc(self, CDocument doc, conflicted_doc_revs):
+ cdef const_char_ptr *revs
+ cdef int n_revs
+
+ _list_to_array(conflicted_doc_revs, &revs, &n_revs)
+ handle_status("resolve_doc",
+ u1db_resolve_doc(self._db, doc._doc, n_revs, revs))
+ free(<void*>revs)
+
+ def get_doc_conflicts(self, doc_id):
+ conflict_docs = []
+ handle_status("get_doc_conflicts",
+ u1db_get_doc_conflicts(self._db, doc_id, <void*>conflict_docs,
+ _append_doc_to_list))
+ return conflict_docs
+
+ def delete_doc(self, CDocument doc):
+ handle_status(
+ "Failed to delete %s" % (doc,),
+ u1db_delete_doc(self._db, doc._doc))
+
+ def whats_changed(self, generation=0):
+ cdef int c_generation
+ cdef int status
+ cdef char *trans_id = NULL
+
+ a_list = []
+ c_generation = generation
+ res_trans_id = ''
+ status = u1db_whats_changed(self._db, &c_generation, &trans_id,
+ <void*>a_list, _append_trans_info_to_list)
+ try:
+ handle_status("whats_changed", status)
+ finally:
+ if trans_id != NULL:
+ res_trans_id = trans_id
+ free(trans_id)
+ return c_generation, res_trans_id, a_list
+
+ def _get_transaction_log(self):
+ a_list = []
+ handle_status("_get_transaction_log",
+ u1db__get_transaction_log(self._db, <void*>a_list,
+ _append_trans_info_to_list))
+ return [(doc_id, trans_id) for doc_id, gen, trans_id in a_list]
+
+ def _get_generation(self):
+ cdef int generation
+ handle_status("get_generation",
+ u1db__get_generation(self._db, &generation))
+ return generation
+
+ def _get_generation_info(self):
+ cdef int generation
+ cdef char *trans_id
+ handle_status("get_generation_info",
+ u1db__get_generation_info(self._db, &generation, &trans_id))
+ raw_trans_id = None
+ if trans_id != NULL:
+ raw_trans_id = trans_id
+ free(trans_id)
+ return generation, raw_trans_id
+
+ def validate_gen_and_trans_id(self, generation, trans_id):
+ handle_status(
+ "validate_gen_and_trans_id",
+ u1db_validate_gen_and_trans_id(self._db, generation, trans_id))
+
+ def _get_trans_id_for_gen(self, generation):
+ cdef char *trans_id = NULL
+
+ handle_status(
+ "_get_trans_id_for_gen",
+ u1db__get_trans_id_for_gen(self._db, generation, &trans_id))
+ raw_trans_id = None
+ if trans_id != NULL:
+ raw_trans_id = trans_id
+ free(trans_id)
+ return raw_trans_id
+
+ def _get_replica_gen_and_trans_id(self, replica_uid):
+ cdef int generation, status
+ cdef char *trans_id = NULL
+
+ status = u1db__get_replica_gen_and_trans_id(
+ self._db, replica_uid, &generation, &trans_id)
+ handle_status("_get_replica_gen_and_trans_id", status)
+ raw_trans_id = None
+ if trans_id != NULL:
+ raw_trans_id = trans_id
+ free(trans_id)
+ return generation, raw_trans_id
+
+ def _set_replica_gen_and_trans_id(self, replica_uid, generation, trans_id):
+ handle_status("_set_replica_gen_and_trans_id",
+ u1db__set_replica_gen_and_trans_id(
+ self._db, replica_uid, generation, trans_id))
+
+ def create_index_list(self, index_name, index_expressions):
+ cdef const_char_ptr *expressions
+ cdef int n_expressions
+
+ # keep a reference to new_objs so that the pointers in expressions
+ # remain valid.
+ new_objs = _list_to_str_array(
+ index_expressions, &expressions, &n_expressions)
+ try:
+ status = u1db_create_index_list(
+ self._db, index_name, n_expressions, expressions)
+ finally:
+ free(<void*>expressions)
+ handle_status("create_index", status)
+
+ def create_index(self, index_name, *index_expressions):
+ extra = []
+ if len(index_expressions) == 0:
+ status = u1db_create_index(self._db, index_name, 0, NULL)
+ elif len(index_expressions) == 1:
+ status = u1db_create_index(
+ self._db, index_name, 1,
+ _ensure_str(index_expressions[0], extra))
+ elif len(index_expressions) == 2:
+ status = u1db_create_index(
+ self._db, index_name, 2,
+ _ensure_str(index_expressions[0], extra),
+ _ensure_str(index_expressions[1], extra))
+ elif len(index_expressions) == 3:
+ status = u1db_create_index(
+ self._db, index_name, 3,
+ _ensure_str(index_expressions[0], extra),
+ _ensure_str(index_expressions[1], extra),
+ _ensure_str(index_expressions[2], extra))
+ elif len(index_expressions) == 4:
+ status = u1db_create_index(
+ self._db, index_name, 4,
+ _ensure_str(index_expressions[0], extra),
+ _ensure_str(index_expressions[1], extra),
+ _ensure_str(index_expressions[2], extra),
+ _ensure_str(index_expressions[3], extra))
+ else:
+ status = U1DB_NOT_IMPLEMENTED
+ handle_status("create_index", status)
+
+ def sync(self, url, creds=None):
+ cdef const_char_ptr c_url
+ cdef int local_gen = 0
+ cdef u1db_oauth_creds _oauth_creds
+ cdef u1db_creds *_creds = NULL
+ c_url = url
+ if creds is not None:
+ _oauth_creds.auth_kind = U1DB_OAUTH_AUTH
+ _oauth_creds.consumer_key = creds['oauth']['consumer_key']
+ _oauth_creds.consumer_secret = creds['oauth']['consumer_secret']
+ _oauth_creds.token_key = creds['oauth']['token_key']
+ _oauth_creds.token_secret = creds['oauth']['token_secret']
+ _creds = <u1db_creds *>&_oauth_creds
+ with nogil:
+ status = u1db_sync(self._db, c_url, _creds, &local_gen)
+ handle_status("sync", status)
+ return local_gen
+
+ def list_indexes(self):
+ a_list = []
+ handle_status("list_indexes",
+ u1db_list_indexes(self._db, <void *>a_list,
+ _append_index_definition_to_list))
+ return a_list
+
+ def delete_index(self, index_name):
+ handle_status("delete_index",
+ u1db_delete_index(self._db, index_name))
+
+ def get_from_index_list(self, index_name, key_values):
+ cdef const_char_ptr *values
+ cdef int n_values
+ cdef CQuery query
+
+ query = self._query_init(index_name)
+ res = []
+ # keep a reference to new_objs so that the pointers in expressions
+ # remain valid.
+ new_objs = _list_to_str_array(key_values, &values, &n_values)
+ try:
+ handle_status(
+ "get_from_index", u1db_get_from_index_list(
+ self._db, query._query, <void*>res, _append_doc_to_list,
+ n_values, values))
+ finally:
+ free(<void*>values)
+ return res
+
+ def get_from_index(self, index_name, *key_values):
+ cdef CQuery query
+ cdef int status
+
+ extra = []
+ query = self._query_init(index_name)
+ res = []
+ status = U1DB_OK
+ if len(key_values) == 0:
+ status = u1db_get_from_index(self._db, query._query,
+ <void*>res, _append_doc_to_list, 0, NULL)
+ elif len(key_values) == 1:
+ status = u1db_get_from_index(self._db, query._query,
+ <void*>res, _append_doc_to_list, 1,
+ _ensure_str(key_values[0], extra))
+ elif len(key_values) == 2:
+ status = u1db_get_from_index(self._db, query._query,
+ <void*>res, _append_doc_to_list, 2,
+ _ensure_str(key_values[0], extra),
+ _ensure_str(key_values[1], extra))
+ elif len(key_values) == 3:
+ status = u1db_get_from_index(self._db, query._query,
+ <void*>res, _append_doc_to_list, 3,
+ _ensure_str(key_values[0], extra),
+ _ensure_str(key_values[1], extra),
+ _ensure_str(key_values[2], extra))
+ elif len(key_values) == 4:
+ status = u1db_get_from_index(self._db, query._query,
+ <void*>res, _append_doc_to_list, 4,
+ _ensure_str(key_values[0], extra),
+ _ensure_str(key_values[1], extra),
+ _ensure_str(key_values[2], extra),
+ _ensure_str(key_values[3], extra))
+ else:
+ status = U1DB_NOT_IMPLEMENTED
+ handle_status("get_from_index", status)
+ return res
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ cdef CQuery query
+ cdef const_char_ptr *start_values
+ cdef int n_values
+ cdef const_char_ptr *end_values
+
+ if start_value is not None:
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ new_objs_1 = _list_to_str_array(
+ start_value, &start_values, &n_values)
+ else:
+ n_values = 0
+ start_values = NULL
+ if end_value is not None:
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ new_objs_2 = _list_to_str_array(
+ end_value, &end_values, &n_values)
+ else:
+ end_values = NULL
+ query = self._query_init(index_name)
+ res = []
+ try:
+ handle_status("get_range_from_index",
+ u1db_get_range_from_index(
+ self._db, query._query, <void*>res, _append_doc_to_list,
+ n_values, start_values, end_values))
+ finally:
+ if start_values != NULL:
+ free(<void*>start_values)
+ if end_values != NULL:
+ free(<void*>end_values)
+ return res
+
+ def get_index_keys(self, index_name):
+ cdef int status
+ keys = []
+ status = U1DB_OK
+ status = u1db_get_index_keys(
+ self._db, index_name, <void*>keys, _append_key_to_list)
+ handle_status("get_index_keys", status)
+ return keys
+
+ def _query_init(self, index_name):
+ cdef CQuery query
+ query = CQuery()
+ handle_status("query_init",
+ u1db_query_init(self._db, index_name, &query._query))
+ return query
+
+ def get_sync_target(self):
+ cdef CSyncTarget target
+ target = CSyncTarget()
+ target._db = self
+ handle_status("get_sync_target",
+ u1db__get_sync_target(target._db._db, &target._st))
+ return target
+
+
+cdef class VectorClockRev:
+
+ cdef u1db_vectorclock *_clock
+
+ def __init__(self, s):
+ if s is None:
+ self._clock = u1db__vectorclock_from_str(NULL)
+ else:
+ self._clock = u1db__vectorclock_from_str(s)
+
+ def __dealloc__(self):
+ u1db__free_vectorclock(&self._clock)
+
+ def __repr__(self):
+ cdef int status
+ cdef char *res
+ if self._clock == NULL:
+ return '%s(None)' % (self.__class__.__name__,)
+ status = u1db__vectorclock_as_str(self._clock, &res)
+ if status != U1DB_OK:
+ return '%s(<failure: %d>)' % (status,)
+ if res == NULL:
+ val = '%s(NULL)' % (self.__class__.__name__,)
+ else:
+ val = '%s(%s)' % (self.__class__.__name__, res)
+ free(res)
+ return val
+
+ def as_dict(self):
+ cdef u1db_vectorclock *cur
+ cdef int i
+ cdef int gen
+ if self._clock == NULL:
+ return None
+ res = {}
+ for i from 0 <= i < self._clock.num_items:
+ gen = self._clock.items[i].generation
+ res[self._clock.items[i].replica_uid] = gen
+ return res
+
+ def as_str(self):
+ cdef int status
+ cdef char *res
+
+ status = u1db__vectorclock_as_str(self._clock, &res)
+ if status != U1DB_OK:
+ raise RuntimeError("Failed to VectorClockRev.as_str(): %d" % (status,))
+ if res == NULL:
+ s = None
+ else:
+ s = res
+ free(res)
+ return s
+
+ def increment(self, replica_uid):
+ cdef int status
+
+ status = u1db__vectorclock_increment(self._clock, replica_uid)
+ if status != U1DB_OK:
+ raise RuntimeError("Failed to increment: %d" % (status,))
+
+ def maximize(self, vcr):
+ cdef int status
+ cdef VectorClockRev other
+
+ other = vcr
+ status = u1db__vectorclock_maximize(self._clock, other._clock)
+ if status != U1DB_OK:
+ raise RuntimeError("Failed to maximize: %d" % (status,))
+
+ def is_newer(self, vcr):
+ cdef int is_newer
+ cdef VectorClockRev other
+
+ other = vcr
+ is_newer = u1db__vectorclock_is_newer(self._clock, other._clock)
+ if is_newer == 0:
+ return False
+ elif is_newer == 1:
+ return True
+ else:
+ raise RuntimeError("Failed to is_newer: %d" % (is_newer,))
+
+
+def sync_db_to_target(db, target):
+ """Sync the data between a CDatabase and a CSyncTarget"""
+ cdef CDatabase cdb
+ cdef CSyncTarget ctarget
+ cdef int local_gen = 0, status
+
+ cdb = db
+ ctarget = target
+ with nogil:
+ status = u1db__sync_db_to_target(cdb._db, ctarget._st, &local_gen)
+ handle_status("sync_db_to_target", status)
+ return local_gen
+
+
+def create_http_sync_target(url):
+ cdef CSyncTarget target
+
+ target = CSyncTarget()
+ handle_status("create_http_sync_target",
+ u1db__create_http_sync_target(url, &target._st))
+ return target
+
+
+def create_oauth_http_sync_target(url, consumer_key, consumer_secret,
+ token_key, token_secret):
+ cdef CSyncTarget target
+
+ target = CSyncTarget()
+ handle_status("create_http_sync_target",
+ u1db__create_oauth_http_sync_target(url, consumer_key, consumer_secret,
+ token_key, token_secret,
+ &target._st))
+ return target
+
+
+def _format_sync_url(target, source_replica_uid):
+ cdef CSyncTarget st
+ cdef char *sync_url = NULL
+ cdef object res
+ st = target
+ handle_status("format_sync_url",
+ u1db__format_sync_url(st._st, source_replica_uid, &sync_url))
+ if sync_url == NULL:
+ res = None
+ else:
+ res = sync_url
+ free(sync_url)
+ return res
+
+
+def _get_oauth_authorization(target, method, url):
+ cdef CSyncTarget st
+ cdef char *auth = NULL
+
+ st = target
+ handle_status("get_oauth_authorization",
+ u1db__get_oauth_authorization(st._st, method, url, &auth))
+ res = None
+ if auth != NULL:
+ res = auth
+ free(auth)
+ return res
diff --git a/u1db/tests/commandline/__init__.py b/u1db/tests/commandline/__init__.py
new file mode 100644
index 00000000..007cecd3
--- /dev/null
+++ b/u1db/tests/commandline/__init__.py
@@ -0,0 +1,47 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+import errno
+import time
+
+
+def safe_close(process, timeout=0.1):
+ """Shutdown the process in the nicest fashion you can manage.
+
+ :param process: A subprocess.Popen object.
+ :param timeout: We'll try to send 'SIGTERM' but if the process is alive
+ longer that 'timeout', we'll send SIGKILL.
+ """
+ if process.poll() is not None:
+ return
+ try:
+ process.terminate()
+ except OSError, e:
+ if e.errno in (errno.ESRCH,):
+ # Process has exited
+ return
+ tend = time.time() + timeout
+ while time.time() < tend:
+ if process.poll() is not None:
+ return
+ time.sleep(0.01)
+ try:
+ process.kill()
+ except OSError, e:
+ if e.errno in (errno.ESRCH,):
+ # Process has exited
+ return
+ process.wait()
diff --git a/u1db/tests/commandline/test_client.py b/u1db/tests/commandline/test_client.py
new file mode 100644
index 00000000..78ca21eb
--- /dev/null
+++ b/u1db/tests/commandline/test_client.py
@@ -0,0 +1,916 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+import cStringIO
+import os
+import sys
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import subprocess
+
+from u1db import (
+ errors,
+ open as u1db_open,
+ tests,
+ vectorclock,
+ )
+from u1db.commandline import (
+ client,
+ serve,
+ )
+from u1db.tests.commandline import safe_close
+from u1db.tests import test_remote_sync_target
+
+
+class TestArgs(tests.TestCase):
+ """These tests are meant to test just the argument parsing.
+
+ Each Command should have at least one test, possibly more if it allows
+ optional arguments, etc.
+ """
+
+ def setUp(self):
+ super(TestArgs, self).setUp()
+ self.parser = client.client_commands.make_argparser()
+
+ def parse_args(self, args):
+ # ArgumentParser.parse_args doesn't play very nicely with a test suite,
+ # so we trap SystemExit in case something is wrong with the args we're
+ # parsing.
+ try:
+ return self.parser.parse_args(args)
+ except SystemExit:
+ raise AssertionError('got SystemExit')
+
+ def test_create(self):
+ args = self.parse_args(['create', 'test.db'])
+ self.assertEqual(client.CmdCreate, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual(None, args.doc_id)
+ self.assertEqual(None, args.infile)
+
+ def test_create_custom_doc_id(self):
+ args = self.parse_args(['create', '--id', 'xyz', 'test.db'])
+ self.assertEqual(client.CmdCreate, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('xyz', args.doc_id)
+ self.assertEqual(None, args.infile)
+
+ def test_delete(self):
+ args = self.parse_args(['delete', 'test.db', 'doc-id', 'doc-rev'])
+ self.assertEqual(client.CmdDelete, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+ self.assertEqual('doc-rev', args.doc_rev)
+
+ def test_get(self):
+ args = self.parse_args(['get', 'test.db', 'doc-id'])
+ self.assertEqual(client.CmdGet, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+ self.assertEqual(None, args.outfile)
+
+ def test_get_dash(self):
+ args = self.parse_args(['get', 'test.db', 'doc-id', '-'])
+ self.assertEqual(client.CmdGet, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+ self.assertEqual(sys.stdout, args.outfile)
+
+ def test_init_db(self):
+ args = self.parse_args(
+ ['init-db', 'test.db', '--replica-uid=replica-uid'])
+ self.assertEqual(client.CmdInitDB, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('replica-uid', args.replica_uid)
+
+ def test_init_db_no_replica(self):
+ args = self.parse_args(['init-db', 'test.db'])
+ self.assertEqual(client.CmdInitDB, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertIs(None, args.replica_uid)
+
+ def test_put(self):
+ args = self.parse_args(['put', 'test.db', 'doc-id', 'old-doc-rev'])
+ self.assertEqual(client.CmdPut, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+ self.assertEqual('old-doc-rev', args.doc_rev)
+ self.assertEqual(None, args.infile)
+
+ def test_sync(self):
+ args = self.parse_args(['sync', 'source', 'target'])
+ self.assertEqual(client.CmdSync, args.subcommand)
+ self.assertEqual('source', args.source)
+ self.assertEqual('target', args.target)
+
+ def test_create_index(self):
+ args = self.parse_args(['create-index', 'db', 'index', 'expression'])
+ self.assertEqual(client.CmdCreateIndex, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('index', args.index)
+ self.assertEqual(['expression'], args.expression)
+
+ def test_create_index_multi_expression(self):
+ args = self.parse_args(['create-index', 'db', 'index', 'e1', 'e2'])
+ self.assertEqual(client.CmdCreateIndex, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('index', args.index)
+ self.assertEqual(['e1', 'e2'], args.expression)
+
+ def test_list_indexes(self):
+ args = self.parse_args(['list-indexes', 'db'])
+ self.assertEqual(client.CmdListIndexes, args.subcommand)
+ self.assertEqual('db', args.database)
+
+ def test_delete_index(self):
+ args = self.parse_args(['delete-index', 'db', 'index'])
+ self.assertEqual(client.CmdDeleteIndex, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('index', args.index)
+
+ def test_get_index_keys(self):
+ args = self.parse_args(['get-index-keys', 'db', 'index'])
+ self.assertEqual(client.CmdGetIndexKeys, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('index', args.index)
+
+ def test_get_from_index(self):
+ args = self.parse_args(['get-from-index', 'db', 'index', 'foo'])
+ self.assertEqual(client.CmdGetFromIndex, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('index', args.index)
+ self.assertEqual(['foo'], args.values)
+
+ def test_get_doc_conflicts(self):
+ args = self.parse_args(['get-doc-conflicts', 'db', 'doc-id'])
+ self.assertEqual(client.CmdGetDocConflicts, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+
+ def test_resolve(self):
+ args = self.parse_args(
+ ['resolve-doc', 'db', 'doc-id', 'rev:1', 'other:1'])
+ self.assertEqual(client.CmdResolve, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+ self.assertEqual(['rev:1', 'other:1'], args.doc_revs)
+ self.assertEqual(None, args.infile)
+
+
+class TestCaseWithDB(tests.TestCase):
+ """These next tests are meant to have one class per Command.
+
+ It is meant to test the inner workings of each command. The detailed
+ testing should happen in these classes. Stuff like how it handles errors,
+ etc. should be done here.
+ """
+
+ def setUp(self):
+ super(TestCaseWithDB, self).setUp()
+ self.working_dir = self.createTempDir()
+ self.db_path = self.working_dir + '/test.db'
+ self.db = u1db_open(self.db_path, create=True)
+ self.db._set_replica_uid('test')
+ self.addCleanup(self.db.close)
+
+ def make_command(self, cls, stdin_content=''):
+ inf = cStringIO.StringIO(stdin_content)
+ out = cStringIO.StringIO()
+ err = cStringIO.StringIO()
+ return cls(inf, out, err)
+
+
+class TestCmdCreate(TestCaseWithDB):
+
+ def test_create(self):
+ cmd = self.make_command(client.CmdCreate)
+ inf = cStringIO.StringIO(tests.simple_doc)
+ cmd.run(self.db_path, inf, 'test-id')
+ doc = self.db.get_doc('test-id')
+ self.assertEqual(tests.simple_doc, doc.get_json())
+ self.assertFalse(doc.has_conflicts)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('id: test-id\nrev: %s\n' % (doc.rev,),
+ cmd.stderr.getvalue())
+
+
+class TestCmdDelete(TestCaseWithDB):
+
+ def test_delete(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ cmd = self.make_command(client.CmdDelete)
+ cmd.run(self.db_path, doc.doc_id, doc.rev)
+ doc2 = self.db.get_doc(doc.doc_id, include_deleted=True)
+ self.assertEqual(doc.doc_id, doc2.doc_id)
+ self.assertNotEqual(doc.rev, doc2.rev)
+ self.assertIs(None, doc2.get_json())
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('rev: %s\n' % (doc2.rev,), cmd.stderr.getvalue())
+
+ def test_delete_fails_if_nonexistent(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ db2_path = self.db_path + '.typo'
+ cmd = self.make_command(client.CmdDelete)
+ # TODO: We should really not be showing a traceback here. But we need
+ # to teach the commandline infrastructure how to handle
+ # exceptions.
+ # However, we *do* want to test that the db doesn't get created
+ # by accident.
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ cmd.run, db2_path, doc.doc_id, doc.rev)
+ self.assertFalse(os.path.exists(db2_path))
+
+ def test_delete_no_such_doc(self):
+ cmd = self.make_command(client.CmdDelete)
+ # TODO: We should really not be showing a traceback here. But we need
+ # to teach the commandline infrastructure how to handle
+ # exceptions.
+ self.assertRaises(errors.DocumentDoesNotExist,
+ cmd.run, self.db_path, 'no-doc-id', 'no-rev')
+
+ def test_delete_bad_rev(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ cmd = self.make_command(client.CmdDelete)
+ self.assertRaises(errors.RevisionConflict,
+ cmd.run, self.db_path, doc.doc_id, 'not-the-actual-doc-rev:1')
+ # TODO: Test that we get a pretty output.
+
+
+class TestCmdGet(TestCaseWithDB):
+
+ def setUp(self):
+ super(TestCmdGet, self).setUp()
+ self.doc = self.db.create_doc_from_json(
+ tests.simple_doc, doc_id='my-test-doc')
+
+ def test_get_simple(self):
+ cmd = self.make_command(client.CmdGet)
+ cmd.run(self.db_path, 'my-test-doc', None)
+ self.assertEqual(tests.simple_doc + "\n", cmd.stdout.getvalue())
+ self.assertEqual('rev: %s\n' % (self.doc.rev,),
+ cmd.stderr.getvalue())
+
+ def test_get_conflict(self):
+ doc = self.make_document('my-test-doc', 'other:1', '{}', False)
+ self.db._put_doc_if_newer(
+ doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ cmd = self.make_command(client.CmdGet)
+ cmd.run(self.db_path, 'my-test-doc', None)
+ self.assertEqual('{}\n', cmd.stdout.getvalue())
+ self.assertEqual('rev: %s\nDocument has conflicts.\n' % (doc.rev,),
+ cmd.stderr.getvalue())
+
+ def test_get_fail(self):
+ cmd = self.make_command(client.CmdGet)
+ result = cmd.run(self.db_path, 'doc-not-there', None)
+ self.assertEqual(1, result)
+ self.assertEqual("", cmd.stdout.getvalue())
+ self.assertTrue("not found" in cmd.stderr.getvalue())
+
+ def test_get_no_database(self):
+ cmd = self.make_command(client.CmdGet)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc", None)
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+
+class TestCmdGetDocConflicts(TestCaseWithDB):
+
+ def setUp(self):
+ super(TestCmdGetDocConflicts, self).setUp()
+ self.doc1 = self.db.create_doc_from_json(
+ tests.simple_doc, doc_id='my-doc')
+ self.doc2 = self.make_document('my-doc', 'other:1', '{}', False)
+ self.db._put_doc_if_newer(
+ self.doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+
+ def test_get_doc_conflicts_none(self):
+ self.db.create_doc_from_json(tests.simple_doc, doc_id='a-doc')
+ cmd = self.make_command(client.CmdGetDocConflicts)
+ cmd.run(self.db_path, 'a-doc')
+ self.assertEqual([], json.loads(cmd.stdout.getvalue()))
+ self.assertEqual('', cmd.stderr.getvalue())
+
+ def test_get_doc_conflicts_simple(self):
+ cmd = self.make_command(client.CmdGetDocConflicts)
+ cmd.run(self.db_path, 'my-doc')
+ self.assertEqual(
+ [dict(rev=self.doc2.rev, content=self.doc2.content),
+ dict(rev=self.doc1.rev, content=self.doc1.content)],
+ json.loads(cmd.stdout.getvalue()))
+ self.assertEqual('', cmd.stderr.getvalue())
+
+ def test_get_doc_conflicts_no_db(self):
+ cmd = self.make_command(client.CmdGetDocConflicts)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_get_doc_conflicts_no_doc(self):
+ cmd = self.make_command(client.CmdGetDocConflicts)
+ retval = cmd.run(self.db_path, "some-doc")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Document does not exist.\n')
+
+
+class TestCmdInit(TestCaseWithDB):
+
+ def test_init_new(self):
+ path = self.working_dir + '/test2.db'
+ self.assertFalse(os.path.exists(path))
+ cmd = self.make_command(client.CmdInitDB)
+ cmd.run(path, 'test-uid')
+ self.assertTrue(os.path.exists(path))
+ db = u1db_open(path, create=False)
+ self.assertEqual('test-uid', db._replica_uid)
+
+ def test_init_no_uid(self):
+ path = self.working_dir + '/test2.db'
+ cmd = self.make_command(client.CmdInitDB)
+ cmd.run(path, None)
+ self.assertTrue(os.path.exists(path))
+ db = u1db_open(path, create=False)
+ self.assertIsNot(None, db._replica_uid)
+
+
+class TestCmdPut(TestCaseWithDB):
+
+ def setUp(self):
+ super(TestCmdPut, self).setUp()
+ self.doc = self.db.create_doc_from_json(
+ tests.simple_doc, doc_id='my-test-doc')
+
+ def test_put_simple(self):
+ cmd = self.make_command(client.CmdPut)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ cmd.run(self.db_path, 'my-test-doc', self.doc.rev, inf)
+ doc = self.db.get_doc('my-test-doc')
+ self.assertNotEqual(self.doc.rev, doc.rev)
+ self.assertGetDoc(self.db, 'my-test-doc', doc.rev,
+ tests.nested_doc, False)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('rev: %s\n' % (doc.rev,),
+ cmd.stderr.getvalue())
+
+ def test_put_no_db(self):
+ cmd = self.make_command(client.CmdPut)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST",
+ 'my-test-doc', self.doc.rev, inf)
+ self.assertEqual(retval, 1)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('Database does not exist.\n', cmd.stderr.getvalue())
+
+ def test_put_no_doc(self):
+ cmd = self.make_command(client.CmdPut)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ retval = cmd.run(self.db_path, 'no-such-doc', 'wut:1', inf)
+ self.assertEqual(1, retval)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('Document does not exist.\n', cmd.stderr.getvalue())
+
+ def test_put_doc_old_rev(self):
+ rev = self.doc.rev
+ doc = self.make_document('my-test-doc', rev, '{}', False)
+ self.db.put_doc(doc)
+ cmd = self.make_command(client.CmdPut)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ retval = cmd.run(self.db_path, 'my-test-doc', rev, inf)
+ self.assertEqual(1, retval)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('Given revision is not current.\n',
+ cmd.stderr.getvalue())
+
+ def test_put_doc_w_conflicts(self):
+ doc = self.make_document('my-test-doc', 'other:1', '{}', False)
+ self.db._put_doc_if_newer(
+ doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ cmd = self.make_command(client.CmdPut)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ retval = cmd.run(self.db_path, 'my-test-doc', 'other:1', inf)
+ self.assertEqual(1, retval)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('Document has conflicts.\n'
+ 'Inspect with get-doc-conflicts, then resolve.\n',
+ cmd.stderr.getvalue())
+
+
+class TestCmdResolve(TestCaseWithDB):
+
+ def setUp(self):
+ super(TestCmdResolve, self).setUp()
+ self.doc1 = self.db.create_doc_from_json(
+ tests.simple_doc, doc_id='my-doc')
+ self.doc2 = self.make_document('my-doc', 'other:1', '{}', False)
+ self.db._put_doc_if_newer(
+ self.doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+
+ def test_resolve_simple(self):
+ self.assertTrue(self.db.get_doc('my-doc').has_conflicts)
+ cmd = self.make_command(client.CmdResolve)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ cmd.run(self.db_path, 'my-doc', [self.doc1.rev, self.doc2.rev], inf)
+ doc = self.db.get_doc('my-doc')
+ vec = vectorclock.VectorClockRev(doc.rev)
+ self.assertTrue(
+ vec.is_newer(vectorclock.VectorClockRev(self.doc1.rev)))
+ self.assertTrue(
+ vec.is_newer(vectorclock.VectorClockRev(self.doc2.rev)))
+ self.assertGetDoc(self.db, 'my-doc', doc.rev, tests.nested_doc, False)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('rev: %s\n' % (doc.rev,),
+ cmd.stderr.getvalue())
+
+ def test_resolve_double(self):
+ moar = '{"x": 42}'
+ doc3 = self.make_document('my-doc', 'third:1', moar, False)
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ cmd = self.make_command(client.CmdResolve)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ cmd.run(self.db_path, 'my-doc', [self.doc1.rev, self.doc2.rev], inf)
+ doc = self.db.get_doc('my-doc')
+ self.assertGetDoc(self.db, 'my-doc', doc.rev, moar, True)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual(
+ 'rev: %s\nDocument still has conflicts.\n' % (doc.rev,),
+ cmd.stderr.getvalue())
+
+ def test_resolve_no_db(self):
+ cmd = self.make_command(client.CmdResolve)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc", [], None)
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_resolve_no_doc(self):
+ cmd = self.make_command(client.CmdResolve)
+ retval = cmd.run(self.db_path, "foo", [], None)
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Document does not exist.\n')
+
+
+class TestCmdSync(TestCaseWithDB):
+
+ def setUp(self):
+ super(TestCmdSync, self).setUp()
+ self.db2_path = self.working_dir + '/test2.db'
+ self.db2 = u1db_open(self.db2_path, create=True)
+ self.addCleanup(self.db2.close)
+ self.db2._set_replica_uid('test2')
+ self.doc = self.db.create_doc_from_json(
+ tests.simple_doc, doc_id='test-id')
+ self.doc2 = self.db2.create_doc_from_json(
+ tests.nested_doc, doc_id='my-test-id')
+
+ def test_sync(self):
+ cmd = self.make_command(client.CmdSync)
+ cmd.run(self.db_path, self.db2_path)
+ self.assertGetDoc(self.db2, 'test-id', self.doc.rev, tests.simple_doc,
+ False)
+ self.assertGetDoc(self.db, 'my-test-id', self.doc2.rev,
+ tests.nested_doc, False)
+
+
+class TestCmdSyncRemote(tests.TestCaseWithServer, TestCaseWithDB):
+
+ make_app_with_state = \
+ staticmethod(test_remote_sync_target.make_http_app)
+
+ def setUp(self):
+ super(TestCmdSyncRemote, self).setUp()
+ self.startServer()
+ self.db2 = self.request_state._create_database('test2.db')
+
+ def test_sync_remote(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ doc2 = self.db2.create_doc_from_json(tests.nested_doc)
+ db2_url = self.getURL('test2.db')
+ self.assertTrue(db2_url.startswith('http://'))
+ self.assertTrue(db2_url.endswith('/test2.db'))
+ cmd = self.make_command(client.CmdSync)
+ cmd.run(self.db_path, db2_url)
+ self.assertGetDoc(self.db2, doc1.doc_id, doc1.rev, tests.simple_doc,
+ False)
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc,
+ False)
+
+
+class TestCmdCreateIndex(TestCaseWithDB):
+
+ def test_create_index(self):
+ cmd = self.make_command(client.CmdCreateIndex)
+ retval = cmd.run(self.db_path, "foo", ["bar", "baz"])
+ self.assertEqual(self.db.list_indexes(), [('foo', ['bar', "baz"])])
+ self.assertEqual(retval, None) # conveniently mapped to 0
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_create_index_no_db(self):
+ cmd = self.make_command(client.CmdCreateIndex)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo", ["bar"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_create_dupe_index(self):
+ self.db.create_index("foo", "bar")
+ cmd = self.make_command(client.CmdCreateIndex)
+ retval = cmd.run(self.db_path, "foo", ["bar"])
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_create_dupe_index_different_expression(self):
+ self.db.create_index("foo", "bar")
+ cmd = self.make_command(client.CmdCreateIndex)
+ retval = cmd.run(self.db_path, "foo", ["baz"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(),
+ "There is already a different index named 'foo'.\n")
+
+ def test_create_index_bad_expression(self):
+ cmd = self.make_command(client.CmdCreateIndex)
+ retval = cmd.run(self.db_path, "foo", ["WAT()"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(),
+ 'Bad index expression.\n')
+
+
+class TestCmdListIndexes(TestCaseWithDB):
+
+ def test_list_no_indexes(self):
+ cmd = self.make_command(client.CmdListIndexes)
+ retval = cmd.run(self.db_path)
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_list_indexes(self):
+ self.db.create_index("foo", "bar", "baz")
+ cmd = self.make_command(client.CmdListIndexes)
+ retval = cmd.run(self.db_path)
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), 'foo: bar, baz\n')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_list_several_indexes(self):
+ self.db.create_index("foo", "bar", "baz")
+ self.db.create_index("bar", "baz", "foo")
+ self.db.create_index("baz", "foo", "bar")
+ cmd = self.make_command(client.CmdListIndexes)
+ retval = cmd.run(self.db_path)
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(),
+ 'bar: baz, foo\n'
+ 'baz: foo, bar\n'
+ 'foo: bar, baz\n'
+ )
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_list_indexes_no_db(self):
+ cmd = self.make_command(client.CmdListIndexes)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+
+class TestCmdDeleteIndex(TestCaseWithDB):
+
+ def test_delete_index(self):
+ self.db.create_index("foo", "bar", "baz")
+ cmd = self.make_command(client.CmdDeleteIndex)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+ self.assertEqual([], self.db.list_indexes())
+
+ def test_delete_index_no_db(self):
+ cmd = self.make_command(client.CmdDeleteIndex)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_delete_index_no_index(self):
+ cmd = self.make_command(client.CmdDeleteIndex)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+
+class TestCmdGetIndexKeys(TestCaseWithDB):
+
+ def test_get_index_keys(self):
+ self.db.create_index("foo", "bar")
+ self.db.create_doc_from_json('{"bar": 42}')
+ cmd = self.make_command(client.CmdGetIndexKeys)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '42\n')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_get_index_keys_nonascii(self):
+ self.db.create_index("foo", "bar")
+ self.db.create_doc_from_json('{"bar": "\u00a4"}')
+ cmd = self.make_command(client.CmdGetIndexKeys)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '\xc2\xa4\n')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_get_index_keys_empty(self):
+ self.db.create_index("foo", "bar")
+ cmd = self.make_command(client.CmdGetIndexKeys)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_get_index_keys_no_db(self):
+ cmd = self.make_command(client.CmdGetIndexKeys)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_get_index_keys_no_index(self):
+ cmd = self.make_command(client.CmdGetIndexKeys)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Index does not exist.\n')
+
+
+class TestCmdGetFromIndex(TestCaseWithDB):
+
+ def test_get_from_index(self):
+ self.db.create_index("index", "key")
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ doc2 = self.db.create_doc_from_json(tests.nested_doc)
+ cmd = self.make_command(client.CmdGetFromIndex)
+ retval = cmd.run(self.db_path, "index", ["value"])
+ self.assertEqual(retval, None)
+ self.assertEqual(sorted(json.loads(cmd.stdout.getvalue())),
+ sorted([dict(id=doc1.doc_id,
+ rev=doc1.rev,
+ content=doc1.content),
+ dict(id=doc2.doc_id,
+ rev=doc2.rev,
+ content=doc2.content),
+ ]))
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_get_from_index_empty(self):
+ self.db.create_index("index", "key")
+ cmd = self.make_command(client.CmdGetFromIndex)
+ retval = cmd.run(self.db_path, "index", ["value"])
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '[]\n')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_get_from_index_no_db(self):
+ cmd = self.make_command(client.CmdGetFromIndex)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo", [])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_get_from_index_no_index(self):
+ cmd = self.make_command(client.CmdGetFromIndex)
+ retval = cmd.run(self.db_path, "foo", [])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Index does not exist.\n')
+
+ def test_get_from_index_two_expr_instead_of_one(self):
+ self.db.create_index("index", "key1")
+ cmd = self.make_command(client.CmdGetFromIndex)
+ cmd.argv = ["XX", "YY"]
+ retval = cmd.run(self.db_path, "index", ["value1", "value2"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual("Invalid query: index 'index' requires"
+ " 1 query expression, not 2.\n"
+ "For example, the following would be valid:\n"
+ " XX YY %r 'index' 'value1'\n"
+ % self.db_path, cmd.stderr.getvalue())
+
+ def test_get_from_index_three_expr_instead_of_two(self):
+ self.db.create_index("index", "key1", "key2")
+ cmd = self.make_command(client.CmdGetFromIndex)
+ cmd.argv = ["XX", "YY"]
+ retval = cmd.run(self.db_path, "index", ["value1", "value2", "value3"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual("Invalid query: index 'index' requires"
+ " 2 query expressions, not 3.\n"
+ "For example, the following would be valid:\n"
+ " XX YY %r 'index' 'value1' 'value2'\n"
+ % self.db_path, cmd.stderr.getvalue())
+
+ def test_get_from_index_one_expr_instead_of_two(self):
+ self.db.create_index("index", "key1", "key2")
+ cmd = self.make_command(client.CmdGetFromIndex)
+ cmd.argv = ["XX", "YY"]
+ retval = cmd.run(self.db_path, "index", ["value1"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual("Invalid query: index 'index' requires"
+ " 2 query expressions, not 1.\n"
+ "For example, the following would be valid:\n"
+ " XX YY %r 'index' 'value1' '*'\n"
+ % self.db_path, cmd.stderr.getvalue())
+
+ def test_get_from_index_cant_bad_glob(self):
+ self.db.create_index("index", "key1", "key2")
+ cmd = self.make_command(client.CmdGetFromIndex)
+ cmd.argv = ["XX", "YY"]
+ retval = cmd.run(self.db_path, "index", ["value1*", "value2"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual("Invalid query:"
+ " a star can only be followed by stars.\n"
+ "For example, the following would be valid:\n"
+ " XX YY %r 'index' 'value1*' '*'\n"
+ % self.db_path, cmd.stderr.getvalue())
+
+
+class RunMainHelper(object):
+
+ def run_main(self, args, stdin=None):
+ if stdin is not None:
+ self.patch(sys, 'stdin', cStringIO.StringIO(stdin))
+ stdout = cStringIO.StringIO()
+ stderr = cStringIO.StringIO()
+ self.patch(sys, 'stdout', stdout)
+ self.patch(sys, 'stderr', stderr)
+ try:
+ ret = client.main(args)
+ except SystemExit, e:
+ self.fail("Intercepted SystemExit: %s" % (e,))
+ if ret is None:
+ ret = 0
+ return ret, stdout.getvalue(), stderr.getvalue()
+
+
+class TestCommandLine(TestCaseWithDB, RunMainHelper):
+ """These are meant to test that the infrastructure is fully connected.
+
+ Each command is likely to only have one test here. Something that ensures
+ 'main()' knows about and can run the command correctly. Most logic-level
+ testing of the Command should go into its own test class above.
+ """
+
+ def _get_u1db_client_path(self):
+ from u1db import __path__ as u1db_path
+ u1db_parent_dir = os.path.dirname(u1db_path[0])
+ return os.path.join(u1db_parent_dir, 'u1db-client')
+
+ def runU1DBClient(self, args):
+ command = [sys.executable, self._get_u1db_client_path()]
+ command.extend(args)
+ p = subprocess.Popen(command, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ self.addCleanup(safe_close, p)
+ return p
+
+ def test_create_subprocess(self):
+ p = self.runU1DBClient(['create', '--id', 'test-id', self.db_path])
+ stdout, stderr = p.communicate(tests.simple_doc)
+ self.assertEqual(0, p.returncode)
+ self.assertEqual('', stdout)
+ doc = self.db.get_doc('test-id')
+ self.assertEqual(tests.simple_doc, doc.get_json())
+ self.assertFalse(doc.has_conflicts)
+ expected = 'id: test-id\nrev: %s\n' % (doc.rev,)
+ stripped = stderr.replace('\r\n', '\n')
+ if expected != stripped:
+ # When run under python-dbg, it prints out the refs after the
+ # actual content, so match it if we need to.
+ expected_re = expected + '\[\d+ refs\]\n'
+ self.assertRegexpMatches(stripped, expected_re)
+
+ def test_get(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id')
+ ret, stdout, stderr = self.run_main(['get', self.db_path, 'test-id'])
+ self.assertEqual(0, ret)
+ self.assertEqual(tests.simple_doc + "\n", stdout)
+ self.assertEqual('rev: %s\n' % (doc.rev,), stderr)
+ ret, stdout, stderr = self.run_main(['get', self.db_path, 'not-there'])
+ self.assertEqual(1, ret)
+
+ def test_delete(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id')
+ ret, stdout, stderr = self.run_main(
+ ['delete', self.db_path, 'test-id', doc.rev])
+ doc = self.db.get_doc('test-id', include_deleted=True)
+ self.assertEqual(0, ret)
+ self.assertEqual('', stdout)
+ self.assertEqual('rev: %s\n' % (doc.rev,), stderr)
+
+ def test_init_db(self):
+ path = self.working_dir + '/test2.db'
+ ret, stdout, stderr = self.run_main(['init-db', path])
+ u1db_open(path, create=False)
+
+ def test_put(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id')
+ ret, stdout, stderr = self.run_main(
+ ['put', self.db_path, 'test-id', doc.rev],
+ stdin=tests.nested_doc)
+ doc = self.db.get_doc('test-id')
+ self.assertFalse(doc.has_conflicts)
+ self.assertEqual(tests.nested_doc, doc.get_json())
+ self.assertEqual(0, ret)
+ self.assertEqual('', stdout)
+ self.assertEqual('rev: %s\n' % (doc.rev,), stderr)
+
+ def test_sync(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id')
+ self.db2_path = self.working_dir + '/test2.db'
+ self.db2 = u1db_open(self.db2_path, create=True)
+ self.addCleanup(self.db2.close)
+ ret, stdout, stderr = self.run_main(
+ ['sync', self.db_path, self.db2_path])
+ self.assertEqual(0, ret)
+ self.assertEqual('', stdout)
+ self.assertEqual('', stderr)
+ self.assertGetDoc(
+ self.db2, 'test-id', doc.rev, tests.simple_doc, False)
+
+
+class TestHTTPIntegration(tests.TestCaseWithServer, RunMainHelper):
+ """Meant to test the cases where commands operate over http."""
+
+ def server_def(self):
+ def make_server(host_port, _application):
+ return serve.make_server(host_port[0], host_port[1],
+ self.working_dir)
+ return make_server, "shutdown", "http"
+
+ def setUp(self):
+ super(TestHTTPIntegration, self).setUp()
+ self.working_dir = self.createTempDir(prefix='u1db-http-server-')
+ self.startServer()
+
+ def getPath(self, dbname):
+ return os.path.join(self.working_dir, dbname)
+
+ def test_init_db(self):
+ url = self.getURL('new.db')
+ ret, stdout, stderr = self.run_main(['init-db', url])
+ u1db_open(self.getPath('new.db'), create=False)
+
+ def test_create_get_put_delete(self):
+ db = u1db_open(self.getPath('test.db'), create=True)
+ url = self.getURL('test.db')
+ doc_id = '%abcd'
+ ret, stdout, stderr = self.run_main(['create', url, '--id', doc_id],
+ stdin=tests.simple_doc)
+ self.assertEqual(0, ret)
+ ret, stdout, stderr = self.run_main(['get', url, doc_id])
+ self.assertEqual(0, ret)
+ self.assertTrue(stderr.startswith('rev: '))
+ doc_rev = stderr[len('rev: '):].rstrip()
+ ret, stdout, stderr = self.run_main(['put', url, doc_id, doc_rev],
+ stdin=tests.nested_doc)
+ self.assertEqual(0, ret)
+ self.assertTrue(stderr.startswith('rev: '))
+ doc_rev1 = stderr[len('rev: '):].rstrip()
+ self.assertGetDoc(db, doc_id, doc_rev1, tests.nested_doc, False)
+ ret, stdout, stderr = self.run_main(['delete', url, doc_id, doc_rev1])
+ self.assertEqual(0, ret)
+ self.assertTrue(stderr.startswith('rev: '))
+ doc_rev2 = stderr[len('rev: '):].rstrip()
+ self.assertGetDocIncludeDeleted(db, doc_id, doc_rev2, None, False)
diff --git a/u1db/tests/commandline/test_command.py b/u1db/tests/commandline/test_command.py
new file mode 100644
index 00000000..43580f23
--- /dev/null
+++ b/u1db/tests/commandline/test_command.py
@@ -0,0 +1,105 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+import cStringIO
+import argparse
+
+from u1db import (
+ tests,
+ )
+from u1db.commandline import (
+ command,
+ )
+
+
+class MyTestCommand(command.Command):
+ """Help String"""
+
+ name = 'mycmd'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('foo')
+ parser.add_argument('--bar', dest='nbar', type=int)
+
+ def run(self, foo, nbar):
+ self.stdout.write('foo: %s nbar: %d' % (foo, nbar))
+ return 0
+
+
+def make_stdin_out_err():
+ return cStringIO.StringIO(), cStringIO.StringIO(), cStringIO.StringIO()
+
+
+class TestCommandGroup(tests.TestCase):
+
+ def trap_system_exit(self, func, *args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except SystemExit, e:
+ self.fail('Got SystemExit trying to run: %s' % (func,))
+
+ def parse_args(self, parser, args):
+ return self.trap_system_exit(parser.parse_args, args)
+
+ def test_register(self):
+ group = command.CommandGroup()
+ self.assertEqual({}, group.commands)
+ group.register(MyTestCommand)
+ self.assertEqual({'mycmd': MyTestCommand},
+ group.commands)
+
+ def test_make_argparser(self):
+ group = command.CommandGroup(description='test-foo')
+ parser = group.make_argparser()
+ self.assertIsInstance(parser, argparse.ArgumentParser)
+
+ def test_make_argparser_with_command(self):
+ group = command.CommandGroup(description='test-foo')
+ group.register(MyTestCommand)
+ parser = group.make_argparser()
+ args = self.parse_args(parser, ['mycmd', 'foozizle', '--bar=10'])
+ self.assertEqual('foozizle', args.foo)
+ self.assertEqual(10, args.nbar)
+ self.assertEqual(MyTestCommand, args.subcommand)
+
+ def test_run_argv(self):
+ group = command.CommandGroup()
+ group.register(MyTestCommand)
+ stdin, stdout, stderr = make_stdin_out_err()
+ ret = self.trap_system_exit(group.run_argv,
+ ['mycmd', 'foozizle', '--bar=10'],
+ stdin, stdout, stderr)
+ self.assertEqual(0, ret)
+
+
+class TestCommand(tests.TestCase):
+
+ def make_command(self):
+ stdin, stdout, stderr = make_stdin_out_err()
+ return command.Command(stdin, stdout, stderr)
+
+ def test__init__(self):
+ cmd = self.make_command()
+ self.assertIsNot(None, cmd.stdin)
+ self.assertIsNot(None, cmd.stdout)
+ self.assertIsNot(None, cmd.stderr)
+
+ def test_run_args(self):
+ stdin, stdout, stderr = make_stdin_out_err()
+ cmd = MyTestCommand(stdin, stdout, stderr)
+ res = cmd.run(foo='foozizle', nbar=10)
+ self.assertEqual('foo: foozizle nbar: 10', stdout.getvalue())
diff --git a/u1db/tests/commandline/test_serve.py b/u1db/tests/commandline/test_serve.py
new file mode 100644
index 00000000..6397eabe
--- /dev/null
+++ b/u1db/tests/commandline/test_serve.py
@@ -0,0 +1,101 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+import os
+import socket
+import subprocess
+import sys
+
+from u1db import (
+ __version__ as _u1db_version,
+ open as u1db_open,
+ tests,
+ )
+from u1db.remote import http_client
+from u1db.tests.commandline import safe_close
+
+
+class TestU1DBServe(tests.TestCase):
+
+ def _get_u1db_serve_path(self):
+ from u1db import __path__ as u1db_path
+ u1db_parent_dir = os.path.dirname(u1db_path[0])
+ return os.path.join(u1db_parent_dir, 'u1db-serve')
+
+ def startU1DBServe(self, args):
+ command = [sys.executable, self._get_u1db_serve_path()]
+ command.extend(args)
+ p = subprocess.Popen(command, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ self.addCleanup(safe_close, p)
+ return p
+
+ def test_help(self):
+ p = self.startU1DBServe(['--help'])
+ stdout, stderr = p.communicate()
+ if stderr != '':
+ # stderr should normally be empty, but if we are running under
+ # python-dbg, it contains the following string
+ self.assertRegexpMatches(stderr, r'\[\d+ refs\]')
+ self.assertEqual(0, p.returncode)
+ self.assertIn('Run the U1DB server', stdout)
+
+ def test_bind_to_port(self):
+ p = self.startU1DBServe([])
+ starts = 'listening on:'
+ x = p.stdout.readline()
+ self.assertTrue(x.startswith(starts))
+ port = int(x[len(starts):].split(":")[1])
+ url = "http://127.0.0.1:%s/" % port
+ c = http_client.HTTPClientBase(url)
+ self.addCleanup(c.close)
+ res, _ = c._request_json('GET', [])
+ self.assertEqual({'version': _u1db_version}, res)
+
+ def test_supply_port(self):
+ s = socket.socket()
+ s.bind(('127.0.0.1', 0))
+ host, port = s.getsockname()
+ s.close()
+ p = self.startU1DBServe(['--port', str(port)])
+ x = p.stdout.readline().strip()
+ self.assertEqual('listening on: 127.0.0.1:%s' % (port,), x)
+ url = "http://127.0.0.1:%s/" % port
+ c = http_client.HTTPClientBase(url)
+ self.addCleanup(c.close)
+ res, _ = c._request_json('GET', [])
+ self.assertEqual({'version': _u1db_version}, res)
+
+ def test_bind_to_host(self):
+ p = self.startU1DBServe(["--host", "localhost"])
+ starts = 'listening on: 127.0.0.1:'
+ x = p.stdout.readline()
+ self.assertTrue(x.startswith(starts))
+
+ def test_supply_working_dir(self):
+ tmp_dir = self.createTempDir('u1db-serve-test')
+ db = u1db_open(os.path.join(tmp_dir, 'landmark.db'), create=True)
+ db.close()
+ p = self.startU1DBServe(['--working-dir', tmp_dir])
+ starts = 'listening on:'
+ x = p.stdout.readline()
+ self.assertTrue(x.startswith(starts))
+ port = int(x[len(starts):].split(":")[1])
+ url = "http://127.0.0.1:%s/landmark.db" % port
+ c = http_client.HTTPClientBase(url)
+ self.addCleanup(c.close)
+ res, _ = c._request_json('GET', [])
+ self.assertEqual({}, res)
diff --git a/u1db/tests/test_auth_middleware.py b/u1db/tests/test_auth_middleware.py
new file mode 100644
index 00000000..e765f8a7
--- /dev/null
+++ b/u1db/tests/test_auth_middleware.py
@@ -0,0 +1,309 @@
+# Copyright 2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test OAuth wsgi middleware"""
+import paste.fixture
+from oauth import oauth
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import time
+
+from u1db import tests
+
+from u1db.remote.oauth_middleware import OAuthMiddleware
+from u1db.remote.basic_auth_middleware import BasicAuthMiddleware, Unauthorized
+
+
+BASE_URL = 'https://example.net'
+
+
+class TestBasicAuthMiddleware(tests.TestCase):
+
+ def setUp(self):
+ super(TestBasicAuthMiddleware, self).setUp()
+ self.got = []
+
+ def witness_app(environ, start_response):
+ start_response("200 OK", [("content-type", "text/plain")])
+ self.got.append((
+ environ['user_id'], environ['PATH_INFO'],
+ environ['QUERY_STRING']))
+ return ["ok"]
+
+ class MyAuthMiddleware(BasicAuthMiddleware):
+
+ def verify_user(self, environ, user, password):
+ if user != "correct_user":
+ raise Unauthorized
+ if password != "correct_password":
+ raise Unauthorized
+ environ['user_id'] = user
+
+ self.auth_midw = MyAuthMiddleware(witness_app, prefix="/pfx/")
+ self.app = paste.fixture.TestApp(self.auth_midw)
+
+ def test_expect_prefix(self):
+ url = BASE_URL + '/foo/doc/doc-id'
+ resp = self.app.delete(url, expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual('{"error": "bad request"}', resp.body)
+
+ def test_missing_auth(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ resp = self.app.delete(url, expect_errors=True)
+ self.assertEqual(401, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "unauthorized",
+ "message": "Missing Basic Authentication."},
+ json.loads(resp.body))
+
+ def test_correct_auth(self):
+ user = "correct_user"
+ password = "correct_password"
+ params = {'old_rev': 'old-rev'}
+ url = BASE_URL + '/pfx/foo/doc/doc-id?%s' % (
+ '&'.join("%s=%s" % (k, v) for k, v in params.items()))
+ auth = '%s:%s' % (user, password)
+ headers = {
+ 'Authorization': 'Basic %s' % (auth.encode('base64'),)}
+ resp = self.app.delete(url, headers=headers)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ [('correct_user', '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+ def test_incorrect_auth(self):
+ user = "correct_user"
+ password = "incorrect_password"
+ params = {'old_rev': 'old-rev'}
+ url = BASE_URL + '/pfx/foo/doc/doc-id?%s' % (
+ '&'.join("%s=%s" % (k, v) for k, v in params.items()))
+ auth = '%s:%s' % (user, password)
+ headers = {
+ 'Authorization': 'Basic %s' % (auth.encode('base64'),)}
+ resp = self.app.delete(url, headers=headers, expect_errors=True)
+ self.assertEqual(401, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "unauthorized",
+ "message": "Incorrect password or login."},
+ json.loads(resp.body))
+
+
+class TestOAuthMiddlewareDefaultPrefix(tests.TestCase):
+ def setUp(self):
+
+ super(TestOAuthMiddlewareDefaultPrefix, self).setUp()
+ self.got = []
+
+ def witness_app(environ, start_response):
+ start_response("200 OK", [("content-type", "text/plain")])
+ self.got.append((environ['token_key'], environ['PATH_INFO'],
+ environ['QUERY_STRING']))
+ return ["ok"]
+
+ class MyOAuthMiddleware(OAuthMiddleware):
+ get_oauth_data_store = lambda self: tests.testingOAuthStore
+
+ def verify(self, environ, oauth_req):
+ consumer, token = super(MyOAuthMiddleware, self).verify(
+ environ, oauth_req)
+ environ['token_key'] = token.key
+
+ self.oauth_midw = MyOAuthMiddleware(witness_app, BASE_URL)
+ self.app = paste.fixture.TestApp(self.oauth_midw)
+
+ def test_expect_tilde(self):
+ url = BASE_URL + '/foo/doc/doc-id'
+ resp = self.app.delete(url, expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual('{"error": "bad request"}', resp.body)
+
+ def test_oauth_in_header(self):
+ url = BASE_URL + '/~/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer2,
+ tests.token2,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ url = oauth_req.get_normalized_http_url() + '?' + (
+ '&'.join("%s=%s" % (k, v) for k, v in params.items()))
+ oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
+ tests.consumer2, tests.token2)
+ resp = self.app.delete(url, headers=oauth_req.to_header())
+ self.assertEqual(200, resp.status)
+ self.assertEqual([(tests.token2.key,
+ '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+ def test_oauth_in_query_string(self):
+ url = BASE_URL + '/~/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer1,
+ tests.token1,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
+ tests.consumer1, tests.token1)
+ resp = self.app.delete(oauth_req.to_url())
+ self.assertEqual(200, resp.status)
+ self.assertEqual([(tests.token1.key,
+ '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+
+class TestOAuthMiddleware(tests.TestCase):
+
+ def setUp(self):
+ super(TestOAuthMiddleware, self).setUp()
+ self.got = []
+
+ def witness_app(environ, start_response):
+ start_response("200 OK", [("content-type", "text/plain")])
+ self.got.append((environ['token_key'], environ['PATH_INFO'],
+ environ['QUERY_STRING']))
+ return ["ok"]
+
+ class MyOAuthMiddleware(OAuthMiddleware):
+ get_oauth_data_store = lambda self: tests.testingOAuthStore
+
+ def verify(self, environ, oauth_req):
+ consumer, token = super(MyOAuthMiddleware, self).verify(
+ environ, oauth_req)
+ environ['token_key'] = token.key
+
+ self.oauth_midw = MyOAuthMiddleware(
+ witness_app, BASE_URL, prefix='/pfx/')
+ self.app = paste.fixture.TestApp(self.oauth_midw)
+
+ def test_expect_prefix(self):
+ url = BASE_URL + '/foo/doc/doc-id'
+ resp = self.app.delete(url, expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual('{"error": "bad request"}', resp.body)
+
+ def test_missing_oauth(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ resp = self.app.delete(url, expect_errors=True)
+ self.assertEqual(401, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "unauthorized", "message": "Missing OAuth."},
+ json.loads(resp.body))
+
+ def test_oauth_in_query_string(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer1,
+ tests.token1,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
+ tests.consumer1, tests.token1)
+ resp = self.app.delete(oauth_req.to_url())
+ self.assertEqual(200, resp.status)
+ self.assertEqual([(tests.token1.key,
+ '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+ def test_oauth_invalid(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer1,
+ tests.token3,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
+ tests.consumer1, tests.token3)
+ resp = self.app.delete(oauth_req.to_url(),
+ expect_errors=True)
+ self.assertEqual(401, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ err = json.loads(resp.body)
+ self.assertEqual({"error": "unauthorized",
+ "message": err['message']},
+ err)
+
+ def test_oauth_in_header(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer2,
+ tests.token2,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ url = oauth_req.get_normalized_http_url() + '?' + (
+ '&'.join("%s=%s" % (k, v) for k, v in params.items()))
+ oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
+ tests.consumer2, tests.token2)
+ resp = self.app.delete(url, headers=oauth_req.to_header())
+ self.assertEqual(200, resp.status)
+ self.assertEqual([(tests.token2.key,
+ '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+ def test_oauth_plain_text(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer1,
+ tests.token1,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ oauth_req.sign_request(tests.sign_meth_PLAINTEXT,
+ tests.consumer1, tests.token1)
+ resp = self.app.delete(oauth_req.to_url())
+ self.assertEqual(200, resp.status)
+ self.assertEqual([(tests.token1.key,
+ '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+ def test_oauth_timestamp_threshold(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer1,
+ tests.token1,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ oauth_req.set_parameter('oauth_timestamp', int(time.time()) - 5)
+ oauth_req.sign_request(tests.sign_meth_PLAINTEXT,
+ tests.consumer1, tests.token1)
+ # tweak threshold
+ self.oauth_midw.timestamp_threshold = 1
+ resp = self.app.delete(oauth_req.to_url(), expect_errors=True)
+ self.assertEqual(401, resp.status)
+ err = json.loads(resp.body)
+ self.assertIn('Expired timestamp', err['message'])
+ self.assertIn('threshold 1', err['message'])
diff --git a/u1db/tests/test_backends.py b/u1db/tests/test_backends.py
new file mode 100644
index 00000000..7a3c9e5c
--- /dev/null
+++ b/u1db/tests/test_backends.py
@@ -0,0 +1,1895 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The backend class for U1DB. This deals with hiding storage details."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from u1db import (
+ DocumentBase,
+ errors,
+ tests,
+ vectorclock,
+ )
+
+simple_doc = tests.simple_doc
+nested_doc = tests.nested_doc
+
+from u1db.tests.test_remote_sync_target import (
+ make_http_app,
+ make_oauth_http_app,
+)
+
+from u1db.remote import (
+ http_database,
+ )
+
+try:
+ from u1db.tests import c_backend_wrapper
+except ImportError:
+ c_backend_wrapper = None # noqa
+
+
+def make_http_database_for_test(test, replica_uid, path='test'):
+ test.startServer()
+ test.request_state._create_database(replica_uid)
+ return http_database.HTTPDatabase(test.getURL(path))
+
+
+def copy_http_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ return test.request_state._copy_database(db)
+
+
+def make_oauth_http_database_for_test(test, replica_uid):
+ http_db = make_http_database_for_test(test, replica_uid, '~/test')
+ http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return http_db
+
+
+def copy_oauth_http_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ http_db = test.request_state._copy_database(db)
+ http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return http_db
+
+
+class TestAlternativeDocument(DocumentBase):
+ """A (not very) alternative implementation of Document."""
+
+
+class AllDatabaseTests(tests.DatabaseBaseTests, tests.TestCaseWithServer):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + [
+ ('http', {'make_database_for_test': make_http_database_for_test,
+ 'copy_database_for_test': copy_http_database_for_test,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'make_app_with_state': make_http_app}),
+ ('oauth_http', {'make_database_for_test':
+ make_oauth_http_database_for_test,
+ 'copy_database_for_test':
+ copy_oauth_http_database_for_test,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'make_app_with_state': make_oauth_http_app})
+ ] + tests.C_DATABASE_SCENARIOS
+
+ def test_close(self):
+ self.db.close()
+
+ def test_create_doc_allocating_doc_id(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertNotEqual(None, doc.doc_id)
+ self.assertNotEqual(None, doc.rev)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+
+ def test_create_doc_different_ids_same_db(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertNotEqual(doc1.doc_id, doc2.doc_id)
+
+ def test_create_doc_with_id(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my-id')
+ self.assertEqual('my-id', doc.doc_id)
+ self.assertNotEqual(None, doc.rev)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+
+ def test_create_doc_existing_id(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ new_content = '{"something": "else"}'
+ self.assertRaises(
+ errors.RevisionConflict, self.db.create_doc_from_json,
+ new_content, doc.doc_id)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+
+ def test_put_doc_creating_initial(self):
+ doc = self.make_document('my_doc_id', None, simple_doc)
+ new_rev = self.db.put_doc(doc)
+ self.assertIsNot(None, new_rev)
+ self.assertGetDoc(self.db, 'my_doc_id', new_rev, simple_doc, False)
+
+ def test_put_doc_space_in_id(self):
+ doc = self.make_document('my doc id', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_doc_update(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ orig_rev = doc.rev
+ doc.set_json('{"updated": "stuff"}')
+ new_rev = self.db.put_doc(doc)
+ self.assertNotEqual(new_rev, orig_rev)
+ self.assertGetDoc(self.db, 'my_doc_id', new_rev,
+ '{"updated": "stuff"}', False)
+ self.assertEqual(doc.rev, new_rev)
+
+ def test_put_non_ascii_key(self):
+ content = json.dumps({u'key\xe5': u'val'})
+ doc = self.db.create_doc_from_json(content, doc_id='my_doc')
+ self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False)
+
+ def test_put_non_ascii_value(self):
+ content = json.dumps({'key': u'\xe5'})
+ doc = self.db.create_doc_from_json(content, doc_id='my_doc')
+ self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False)
+
+ def test_put_doc_refuses_no_id(self):
+ doc = self.make_document(None, None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+ doc = self.make_document("", None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_doc_refuses_slashes(self):
+ doc = self.make_document('a/b', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+ doc = self.make_document(r'\b', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_doc_url_quoting_is_fine(self):
+ doc_id = "%2F%2Ffoo%2Fbar"
+ doc = self.make_document(doc_id, None, simple_doc)
+ new_rev = self.db.put_doc(doc)
+ self.assertGetDoc(self.db, doc_id, new_rev, simple_doc, False)
+
+ def test_put_doc_refuses_non_existing_old_rev(self):
+ doc = self.make_document('doc-id', 'test:4', simple_doc)
+ self.assertRaises(errors.RevisionConflict, self.db.put_doc, doc)
+
+ def test_put_doc_refuses_non_ascii_doc_id(self):
+ doc = self.make_document('d\xc3\xa5c-id', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_fails_with_bad_old_rev(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ old_rev = doc.rev
+ bad_doc = self.make_document(doc.doc_id, 'other:1',
+ '{"something": "else"}')
+ self.assertRaises(errors.RevisionConflict, self.db.put_doc, bad_doc)
+ self.assertGetDoc(self.db, 'my_doc_id', old_rev, simple_doc, False)
+
+ def test_create_succeeds_after_delete(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True)
+ deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev)
+ new_doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.assertGetDoc(self.db, 'my_doc_id', new_doc.rev, simple_doc, False)
+ new_vc = vectorclock.VectorClockRev(new_doc.rev)
+ self.assertTrue(
+ new_vc.is_newer(deleted_vc),
+ "%s does not supersede %s" % (new_doc.rev, deleted_doc.rev))
+
+ def test_put_succeeds_after_delete(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True)
+ deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev)
+ doc2 = self.make_document('my_doc_id', None, simple_doc)
+ self.db.put_doc(doc2)
+ self.assertGetDoc(self.db, 'my_doc_id', doc2.rev, simple_doc, False)
+ new_vc = vectorclock.VectorClockRev(doc2.rev)
+ self.assertTrue(
+ new_vc.is_newer(deleted_vc),
+ "%s does not supersede %s" % (doc2.rev, deleted_doc.rev))
+
+ def test_get_doc_after_put(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.assertGetDoc(self.db, 'my_doc_id', doc.rev, simple_doc, False)
+
+ def test_get_doc_nonexisting(self):
+ self.assertIs(None, self.db.get_doc('non-existing'))
+
+ def test_get_doc_deleted(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ self.assertIs(None, self.db.get_doc('my_doc_id'))
+
+ def test_get_doc_include_deleted(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+
+ def test_get_docs(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual([doc1, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id])))
+
+ def test_get_docs_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc1)
+ self.assertEqual([doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id])))
+
+ def test_get_docs_include_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc1)
+ self.assertEqual(
+ [doc1, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id],
+ include_deleted=True)))
+
+ def test_get_docs_request_ordered(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual([doc1, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id])))
+ self.assertEqual([doc2, doc1],
+ list(self.db.get_docs([doc2.doc_id, doc1.doc_id])))
+
+ def test_get_docs_empty_list(self):
+ self.assertEqual([], list(self.db.get_docs([])))
+
+ def test_handles_nested_content(self):
+ doc = self.db.create_doc_from_json(nested_doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False)
+
+ def test_handles_doc_with_null(self):
+ doc = self.db.create_doc_from_json('{"key": null}')
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, '{"key": null}', False)
+
+ def test_delete_doc(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+ orig_rev = doc.rev
+ self.db.delete_doc(doc)
+ self.assertNotEqual(orig_rev, doc.rev)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+ self.assertIs(None, self.db.get_doc(doc.doc_id))
+
+ def test_delete_doc_non_existent(self):
+ doc = self.make_document('non-existing', 'other:1', simple_doc)
+ self.assertRaises(errors.DocumentDoesNotExist, self.db.delete_doc, doc)
+
+ def test_delete_doc_already_deleted(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertRaises(errors.DocumentAlreadyDeleted,
+ self.db.delete_doc, doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+
+ def test_delete_doc_bad_rev(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+ doc2 = self.make_document(doc1.doc_id, 'other:1', simple_doc)
+ self.assertRaises(errors.RevisionConflict, self.db.delete_doc, doc2)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+
+ def test_delete_doc_sets_content_to_None(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertIs(None, doc.get_json())
+
+ def test_delete_doc_rev_supersedes(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc.set_json(nested_doc)
+ self.db.put_doc(doc)
+ doc.set_json('{"fishy": "content"}')
+ self.db.put_doc(doc)
+ old_rev = doc.rev
+ self.db.delete_doc(doc)
+ cur_vc = vectorclock.VectorClockRev(old_rev)
+ deleted_vc = vectorclock.VectorClockRev(doc.rev)
+ self.assertTrue(deleted_vc.is_newer(cur_vc),
+ "%s does not supersede %s" % (doc.rev, old_rev))
+
+ def test_delete_then_put(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+ doc.set_json(nested_doc)
+ self.db.put_doc(doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False)
+
+
+class DocumentSizeTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def test_put_doc_refuses_oversized_documents(self):
+ self.db.set_document_size_limit(1)
+ doc = self.make_document('doc-id', None, simple_doc)
+ self.assertRaises(errors.DocumentTooBig, self.db.put_doc, doc)
+
+ def test_create_doc_refuses_oversized_documents(self):
+ self.db.set_document_size_limit(1)
+ self.assertRaises(
+ errors.DocumentTooBig, self.db.create_doc_from_json, simple_doc,
+ doc_id='my_doc_id')
+
+ def test_set_document_size_limit_zero(self):
+ self.db.set_document_size_limit(0)
+ self.assertEqual(0, self.db.document_size_limit)
+
+ def test_set_document_size_limit(self):
+ self.db.set_document_size_limit(1000000)
+ self.assertEqual(1000000, self.db.document_size_limit)
+
+
+class LocalDatabaseTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def test_create_doc_different_ids_diff_db(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ db2 = self.create_database('other-uid')
+ doc2 = db2.create_doc_from_json(simple_doc)
+ self.assertNotEqual(doc1.doc_id, doc2.doc_id)
+
+ def test_put_doc_refuses_slashes_picky(self):
+ doc = self.make_document('/a', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_get_all_docs_empty(self):
+ self.assertEqual([], list(self.db.get_all_docs()[1]))
+
+ def test_get_all_docs(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual(
+ sorted([doc1, doc2]), sorted(list(self.db.get_all_docs()[1])))
+
+ def test_get_all_docs_exclude_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc2)
+ self.assertEqual([doc1], list(self.db.get_all_docs()[1]))
+
+ def test_get_all_docs_include_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc2)
+ self.assertEqual(
+ sorted([doc1, doc2]),
+ sorted(list(self.db.get_all_docs(include_deleted=True)[1])))
+
+ def test_get_all_docs_generation(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_doc_from_json(nested_doc)
+ self.assertEqual(2, self.db.get_all_docs()[0])
+
+ def test_simple_put_doc_if_newer(self):
+ doc = self.make_document('my-doc-id', 'test:1', simple_doc)
+ state_at_gen = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(('inserted', 1), state_at_gen)
+ self.assertGetDoc(self.db, 'my-doc-id', 'test:1', simple_doc, False)
+
+ def test_simple_put_doc_if_newer_deleted(self):
+ self.db.create_doc_from_json('{}', doc_id='my-doc-id')
+ doc = self.make_document('my-doc-id', 'test:2', None)
+ state_at_gen = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(('inserted', 2), state_at_gen)
+ self.assertGetDocIncludeDeleted(
+ self.db, 'my-doc-id', 'test:2', None, False)
+
+ def test_put_doc_if_newer_already_superseded(self):
+ orig_doc = '{"new": "doc"}'
+ doc1 = self.db.create_doc_from_json(orig_doc)
+ doc1_rev1 = doc1.rev
+ doc1.set_json(simple_doc)
+ self.db.put_doc(doc1)
+ doc1_rev2 = doc1.rev
+ # Nothing is inserted, because the document is already superseded
+ doc = self.make_document(doc1.doc_id, doc1_rev1, orig_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('superseded', state)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1_rev2, simple_doc, False)
+
+ def test_put_doc_if_newer_autoresolve(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ rev = doc1.rev
+ doc = self.make_document(doc1.doc_id, "whatever:1", doc1.get_json())
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('superseded', state)
+ doc2 = self.db.get_doc(doc1.doc_id)
+ v2 = vectorclock.VectorClockRev(doc2.rev)
+ self.assertTrue(v2.is_newer(vectorclock.VectorClockRev("whatever:1")))
+ self.assertTrue(v2.is_newer(vectorclock.VectorClockRev(rev)))
+ # strictly newer locally
+ self.assertTrue(rev not in doc2.rev)
+
+ def test_put_doc_if_newer_already_converged(self):
+ orig_doc = '{"new": "doc"}'
+ doc1 = self.db.create_doc_from_json(orig_doc)
+ state_at_gen = self.db._put_doc_if_newer(
+ doc1, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(('converged', 1), state_at_gen)
+
+ def test_put_doc_if_newer_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ # Nothing is inserted, the document id is returned as would-conflict
+ alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ state, _ = self.db._put_doc_if_newer(
+ alt_doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('conflicted', state)
+ # The database wasn't altered
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+
+ def test_put_doc_if_newer_newer_generation(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc = self.make_document('doc_id', 'other:2', simple_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='other', replica_gen=2,
+ replica_trans_id='T-irrelevant')
+ self.assertEqual('inserted', state)
+
+ def test_put_doc_if_newer_same_generation_same_txid(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.make_document(doc.doc_id, 'other:1', simple_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='other', replica_gen=1,
+ replica_trans_id='T-sid')
+ self.assertEqual('converged', state)
+
+ def test_put_doc_if_newer_wrong_transaction_id(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc = self.make_document('doc_id', 'other:1', simple_doc)
+ self.assertRaises(
+ errors.InvalidTransactionId,
+ self.db._put_doc_if_newer, doc, save_conflict=False,
+ replica_uid='other', replica_gen=1, replica_trans_id='T-sad')
+
+ def test_put_doc_if_newer_old_generation_older_doc(self):
+ orig_doc = '{"new": "doc"}'
+ doc = self.db.create_doc_from_json(orig_doc)
+ doc_rev1 = doc.rev
+ doc.set_json(simple_doc)
+ self.db.put_doc(doc)
+ self.db._set_replica_gen_and_trans_id('other', 3, 'T-sid')
+ older_doc = self.make_document(doc.doc_id, doc_rev1, simple_doc)
+ state, _ = self.db._put_doc_if_newer(
+ older_doc, save_conflict=False, replica_uid='other', replica_gen=8,
+ replica_trans_id='T-irrelevant')
+ self.assertEqual('superseded', state)
+
+ def test_put_doc_if_newer_old_generation_newer_doc(self):
+ self.db._set_replica_gen_and_trans_id('other', 5, 'T-sid')
+ doc = self.make_document('doc_id', 'other:1', simple_doc)
+ self.assertRaises(
+ errors.InvalidGeneration,
+ self.db._put_doc_if_newer, doc, save_conflict=False,
+ replica_uid='other', replica_gen=1, replica_trans_id='T-sad')
+
+ def test_put_doc_if_newer_replica_uid(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1',
+ nested_doc)
+ self.assertEqual('inserted',
+ self.db._put_doc_if_newer(doc2, save_conflict=False,
+ replica_uid='other', replica_gen=2,
+ replica_trans_id='T-id2')[0])
+ self.assertEqual((2, 'T-id2'), self.db._get_replica_gen_and_trans_id(
+ 'other'))
+ # Compare to the old rev, should be superseded
+ doc2 = self.make_document(doc1.doc_id, doc1.rev, nested_doc)
+ self.assertEqual('superseded',
+ self.db._put_doc_if_newer(doc2, save_conflict=False,
+ replica_uid='other', replica_gen=3,
+ replica_trans_id='T-id3')[0])
+ self.assertEqual(
+ (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other'))
+ # A conflict that isn't saved still records the sync gen, because we
+ # don't need to see it again
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|fourth:1',
+ '{}')
+ self.assertEqual('conflicted',
+ self.db._put_doc_if_newer(doc2, save_conflict=False,
+ replica_uid='other', replica_gen=4,
+ replica_trans_id='T-id4')[0])
+ self.assertEqual(
+ (4, 'T-id4'), self.db._get_replica_gen_and_trans_id('other'))
+
+ def test__get_replica_gen_and_trans_id(self):
+ self.assertEqual(
+ (0, ''), self.db._get_replica_gen_and_trans_id('other-db'))
+ self.db._set_replica_gen_and_trans_id('other-db', 2, 'T-transaction')
+ self.assertEqual(
+ (2, 'T-transaction'),
+ self.db._get_replica_gen_and_trans_id('other-db'))
+
+ def test_put_updates_transaction_log(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ doc.set_json('{"something": "else"}')
+ self.db.put_doc(doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]),
+ self.db.whats_changed())
+
+ def test_delete_updates_transaction_log(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ db_gen, _, _ = self.db.whats_changed()
+ self.db.delete_doc(doc)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]),
+ self.db.whats_changed(db_gen))
+
+ def test_whats_changed_initial_database(self):
+ self.assertEqual((0, '', []), self.db.whats_changed())
+
+ def test_whats_changed_returns_one_id_for_multiple_changes(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc.set_json('{"new": "contents"}')
+ self.db.put_doc(doc)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]),
+ self.db.whats_changed())
+ self.assertEqual((2, last_trans_id, []), self.db.whats_changed(2))
+
+ def test_whats_changed_returns_last_edits_ascending(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc.set_json('{"new": "contents"}')
+ self.db.delete_doc(doc1)
+ delete_trans_id = self.getLastTransId(self.db)
+ self.db.put_doc(doc)
+ put_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((4, put_trans_id,
+ [(doc1.doc_id, 3, delete_trans_id),
+ (doc.doc_id, 4, put_trans_id)]),
+ self.db.whats_changed())
+
+ def test_whats_changed_doesnt_include_old_gen(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((3, last_trans_id, [(doc2.doc_id, 3, last_trans_id)]),
+ self.db.whats_changed(2))
+
+
+class LocalDatabaseValidateGenNTransIdTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def test_validate_gen_and_trans_id(self):
+ self.db.create_doc_from_json(simple_doc)
+ gen, trans_id = self.db._get_generation_info()
+ self.db.validate_gen_and_trans_id(gen, trans_id)
+
+ def test_validate_gen_and_trans_id_invalid_txid(self):
+ self.db.create_doc_from_json(simple_doc)
+ gen, _ = self.db._get_generation_info()
+ self.assertRaises(
+ errors.InvalidTransactionId,
+ self.db.validate_gen_and_trans_id, gen, 'wrong')
+
+ def test_validate_gen_and_trans_id_invalid_gen(self):
+ self.db.create_doc_from_json(simple_doc)
+ gen, trans_id = self.db._get_generation_info()
+ self.assertRaises(
+ errors.InvalidGeneration,
+ self.db.validate_gen_and_trans_id, gen + 1, trans_id)
+
+
+class LocalDatabaseValidateSourceGenTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def test_validate_source_gen_and_trans_id_same(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ self.db._validate_source('other', 1, 'T-sid')
+
+ def test_validate_source_gen_newer(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ self.db._validate_source('other', 2, 'T-whatevs')
+
+ def test_validate_source_wrong_txid(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ self.assertRaises(
+ errors.InvalidTransactionId,
+ self.db._validate_source, 'other', 1, 'T-sad')
+
+
+class LocalDatabaseWithConflictsTests(tests.DatabaseBaseTests):
+ # test supporting/functionality around storing conflicts
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def test_get_docs_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual([doc2], list(self.db.get_docs([doc1.doc_id])))
+
+ def test_get_docs_conflicts_ignored(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ no_conflict_doc = self.make_document(doc1.doc_id, 'alternate:1',
+ nested_doc)
+ self.assertEqual([no_conflict_doc, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id],
+ check_for_conflicts=False)))
+
+ def test_get_doc_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual([alt_doc, doc],
+ self.db.get_doc_conflicts(doc.doc_id))
+
+ def test_get_all_docs_sees_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ _, docs = self.db.get_all_docs()
+ self.assertTrue(list(docs)[0].has_conflicts)
+
+ def test_get_doc_conflicts_unconflicted(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertEqual([], self.db.get_doc_conflicts(doc.doc_id))
+
+ def test_get_doc_conflicts_no_such_id(self):
+ self.assertEqual([], self.db.get_doc_conflicts('doc-id'))
+
+ def test_resolve_doc(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc.doc_id,
+ [('alternate:1', nested_doc), (doc.rev, simple_doc)])
+ orig_rev = doc.rev
+ self.db.resolve_doc(doc, [alt_doc.rev, doc.rev])
+ self.assertNotEqual(orig_rev, doc.rev)
+ self.assertFalse(doc.has_conflicts)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertGetDocConflicts(self.db, doc.doc_id, [])
+
+ def test_resolve_doc_picks_biggest_vcr(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, simple_doc)])
+ orig_doc1_rev = doc1.rev
+ self.db.resolve_doc(doc1, [doc2.rev, doc1.rev])
+ self.assertFalse(doc1.has_conflicts)
+ self.assertNotEqual(orig_doc1_rev, doc1.rev)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ vcr_1 = vectorclock.VectorClockRev(orig_doc1_rev)
+ vcr_2 = vectorclock.VectorClockRev(doc2.rev)
+ vcr_new = vectorclock.VectorClockRev(doc1.rev)
+ self.assertTrue(vcr_new.is_newer(vcr_1))
+ self.assertTrue(vcr_new.is_newer(vcr_2))
+
+ def test_resolve_doc_partial_not_winning(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, simple_doc)])
+ content3 = '{"key": "valin3"}'
+ doc3 = self.make_document(doc1.doc_id, 'third:1', content3)
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='bar')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc3.rev, content3),
+ (doc1.rev, simple_doc),
+ (doc2.rev, nested_doc)])
+ self.db.resolve_doc(doc1, [doc2.rev, doc1.rev])
+ self.assertTrue(doc1.has_conflicts)
+ self.assertGetDoc(self.db, doc1.doc_id, doc3.rev, content3, True)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc3.rev, content3),
+ (doc1.rev, simple_doc)])
+
+ def test_resolve_doc_partial_winning(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ content3 = '{"key": "valin3"}'
+ doc3 = self.make_document(doc1.doc_id, 'third:1', content3)
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='bar')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc3.rev, content3),
+ (doc1.rev, simple_doc),
+ (doc2.rev, nested_doc)])
+ self.db.resolve_doc(doc1, [doc3.rev, doc1.rev])
+ self.assertTrue(doc1.has_conflicts)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc1.rev, simple_doc),
+ (doc2.rev, nested_doc)])
+
+ def test_resolve_doc_with_delete_conflict(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc1)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, None)])
+ self.db.resolve_doc(doc2, [doc1.rev, doc2.rev])
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, False)
+
+ def test_resolve_doc_with_delete_to_delete(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc1)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, None)])
+ self.db.resolve_doc(doc1, [doc1.rev, doc2.rev])
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ self.assertGetDocIncludeDeleted(
+ self.db, doc1.doc_id, doc1.rev, None, False)
+
+ def test_put_doc_if_newer_save_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ # Document is inserted as a conflict
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('conflicted', state)
+ # The database was updated
+ self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, nested_doc, True)
+
+ def test_force_doc_conflict_supersedes_properly(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', '{"b": 1}')
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ doc3 = self.make_document(doc1.doc_id, 'altalt:1', '{"c": 1}')
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='bar')
+ doc22 = self.make_document(doc1.doc_id, 'alternate:2', '{"b": 2}')
+ self.db._put_doc_if_newer(
+ doc22, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='zed')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:2', doc22.get_json()),
+ ('altalt:1', doc3.get_json()),
+ (doc1.rev, simple_doc)])
+
+ def test_put_doc_if_newer_save_conflict_was_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc1)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertTrue(doc2.has_conflicts)
+ self.assertGetDoc(
+ self.db, doc1.doc_id, 'alternate:1', nested_doc, True)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:1', nested_doc), (doc1.rev, None)])
+
+ def test_put_doc_if_newer_propagates_full_resolution(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ resolved_vcr = vectorclock.VectorClockRev(doc1.rev)
+ vcr_2 = vectorclock.VectorClockRev(doc2.rev)
+ resolved_vcr.maximize(vcr_2)
+ resolved_vcr.increment('alternate')
+ doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(),
+ '{"good": 1}')
+ state, _ = self.db._put_doc_if_newer(
+ doc_resolved, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual('inserted', state)
+ self.assertFalse(doc_resolved.has_conflicts)
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ doc3 = self.db.get_doc(doc1.doc_id)
+ self.assertFalse(doc3.has_conflicts)
+
+ def test_put_doc_if_newer_propagates_partial_resolution(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'altalt:1', '{}')
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ doc3 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:1', nested_doc), ('test:1', simple_doc),
+ ('altalt:1', '{}')])
+ resolved_vcr = vectorclock.VectorClockRev(doc1.rev)
+ vcr_3 = vectorclock.VectorClockRev(doc3.rev)
+ resolved_vcr.maximize(vcr_3)
+ resolved_vcr.increment('alternate')
+ doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(),
+ '{"good": 1}')
+ state, _ = self.db._put_doc_if_newer(
+ doc_resolved, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual('inserted', state)
+ self.assertTrue(doc_resolved.has_conflicts)
+ doc4 = self.db.get_doc(doc1.doc_id)
+ self.assertTrue(doc4.has_conflicts)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:2|test:1', '{"good": 1}'), ('altalt:1', '{}')])
+
+ def test_put_doc_if_newer_replica_uid(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-id')
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1',
+ nested_doc)
+ self.db._put_doc_if_newer(doc2, save_conflict=True,
+ replica_uid='other', replica_gen=2,
+ replica_trans_id='T-id2')
+ # Conflict vs the current update
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|third:3',
+ '{}')
+ self.assertEqual('conflicted',
+ self.db._put_doc_if_newer(doc2, save_conflict=True,
+ replica_uid='other', replica_gen=3,
+ replica_trans_id='T-id3')[0])
+ self.assertEqual(
+ (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other'))
+
+ def test_put_doc_if_newer_autoresolve_2(self):
+ # this is an ordering variant of _3, but that already works
+ # adding the test explicitly to catch the regression easily
+ doc_a1 = self.db.create_doc_from_json(simple_doc)
+ doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', "{}")
+ doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1',
+ '{"a":"42"}')
+ doc_a3 = self.make_document(doc_a1.doc_id, 'test:2|other:1', "{}")
+ state, _ = self.db._put_doc_if_newer(
+ doc_a2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(state, 'inserted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual(state, 'conflicted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a3, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual(state, 'inserted')
+ self.assertFalse(self.db.get_doc(doc_a1.doc_id).has_conflicts)
+
+ def test_put_doc_if_newer_autoresolve_3(self):
+ doc_a1 = self.db.create_doc_from_json(simple_doc)
+ doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', "{}")
+ doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}')
+ doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', "{}")
+ state, _ = self.db._put_doc_if_newer(
+ doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(state, 'inserted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a2, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual(state, 'conflicted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a3, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual(state, 'superseded')
+ doc = self.db.get_doc(doc_a1.doc_id, True)
+ self.assertFalse(doc.has_conflicts)
+ rev = vectorclock.VectorClockRev(doc.rev)
+ rev_a3 = vectorclock.VectorClockRev('test:3')
+ rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1')
+ self.assertTrue(rev.is_newer(rev_a3))
+ self.assertTrue('test:4' in doc.rev) # locally increased
+ self.assertTrue(rev.is_newer(rev_a1b1))
+
+ def test_put_doc_if_newer_autoresolve_4(self):
+ doc_a1 = self.db.create_doc_from_json(simple_doc)
+ doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', None)
+ doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}')
+ doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', None)
+ state, _ = self.db._put_doc_if_newer(
+ doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(state, 'inserted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a2, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual(state, 'conflicted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a3, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual(state, 'superseded')
+ doc = self.db.get_doc(doc_a1.doc_id, True)
+ self.assertFalse(doc.has_conflicts)
+ rev = vectorclock.VectorClockRev(doc.rev)
+ rev_a3 = vectorclock.VectorClockRev('test:3')
+ rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1')
+ self.assertTrue(rev.is_newer(rev_a3))
+ self.assertTrue('test:4' in doc.rev) # locally increased
+ self.assertTrue(rev.is_newer(rev_a1b1))
+
+ def test_put_refuses_to_update_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ content2 = '{"key": "altval"}'
+ doc2 = self.make_document(doc1.doc_id, 'altrev:1', content2)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, content2, True)
+ content3 = '{"key": "local"}'
+ doc2.set_json(content3)
+ self.assertRaises(errors.ConflictedDoc, self.db.put_doc, doc2)
+
+ def test_delete_refuses_for_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'altrev:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, True)
+ self.assertRaises(errors.ConflictedDoc, self.db.delete_doc, doc2)
+
+
+class DatabaseIndexTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def assertParseError(self, definition):
+ self.db.create_doc_from_json(nested_doc)
+ self.assertRaises(
+ errors.IndexDefinitionParseError, self.db.create_index, 'idx',
+ definition)
+
+ def assertIndexCreatable(self, definition):
+ name = "idx"
+ self.db.create_doc_from_json(nested_doc)
+ self.db.create_index(name, definition)
+ self.assertEqual(
+ [(name, [definition])], self.db.list_indexes())
+
+ def test_create_index(self):
+ self.db.create_index('test-idx', 'name')
+ self.assertEqual([('test-idx', ['name'])],
+ self.db.list_indexes())
+
+ def test_create_index_on_non_ascii_field_name(self):
+ doc = self.db.create_doc_from_json(json.dumps({u'\xe5': 'value'}))
+ self.db.create_index('test-idx', u'\xe5')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_list_indexes_with_non_ascii_field_names(self):
+ self.db.create_index('test-idx', u'\xe5')
+ self.assertEqual(
+ [('test-idx', [u'\xe5'])], self.db.list_indexes())
+
+ def test_create_index_evaluates_it(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_wildcard_matches_unicode_value(self):
+ doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"}))
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', '*'))
+
+ def test_retrieve_unicode_value_from_index(self):
+ doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"}))
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', u"valu\xe5"))
+
+ def test_create_index_fails_if_name_taken(self):
+ self.db.create_index('test-idx', 'key')
+ self.assertRaises(errors.IndexNameTakenError,
+ self.db.create_index,
+ 'test-idx', 'stuff')
+
+ def test_create_index_does_not_fail_if_name_taken_with_same_index(self):
+ self.db.create_index('test-idx', 'key')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([('test-idx', ['key'])], self.db.list_indexes())
+
+ def test_create_index_does_not_duplicate_indexed_fields(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.db.delete_index('test-idx')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(1, len(self.db.get_from_index('test-idx', 'value')))
+
+ def test_delete_index_does_not_remove_fields_from_other_indexes(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.db.create_index('test-idx2', 'key')
+ self.db.delete_index('test-idx')
+ self.assertEqual(1, len(self.db.get_from_index('test-idx2', 'value')))
+
+ def test_create_index_after_deleting_document(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc2)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_delete_index(self):
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([('test-idx', ['key'])], self.db.list_indexes())
+ self.db.delete_index('test-idx')
+ self.assertEqual([], self.db.list_indexes())
+
+ def test_create_adds_to_index(self):
+ self.db.create_index('test-idx', 'key')
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_get_from_index_unmatched(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([], self.db.get_from_index('test-idx', 'novalue'))
+
+ def test_create_index_multiple_exact_matches(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ sorted([doc, doc2]),
+ sorted(self.db.get_from_index('test-idx', 'value')))
+
+ def test_get_from_index(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_get_from_index_multi(self):
+ content = '{"key": "value", "key2": "value2"}'
+ doc = self.db.create_doc_from_json(content)
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2'))
+
+ def test_get_from_index_multi_list(self):
+ doc = self.db.create_doc_from_json(
+ '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2-1'))
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2-2'))
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2-3'))
+ self.assertEqual(
+ [('value', 'value2-1'), ('value', 'value2-2'),
+ ('value', 'value2-3')],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_get_from_index_sees_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key', 'key2')
+ alt_doc = self.make_document(
+ doc.doc_id, 'alternate:1',
+ '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}')
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ docs = self.db.get_from_index('test-idx', 'value', 'value2-1')
+ self.assertTrue(docs[0].has_conflicts)
+
+ def test_get_index_keys_multi_list_list(self):
+ self.db.create_doc_from_json(
+ '{"key": "value1-1 value1-2 value1-3", '
+ '"key2": ["value2-1", "value2-2", "value2-3"]}')
+ self.db.create_index('test-idx', 'split_words(key)', 'key2')
+ self.assertEqual(
+ [(u'value1-1', u'value2-1'), (u'value1-1', u'value2-2'),
+ (u'value1-1', u'value2-3'), (u'value1-2', u'value2-1'),
+ (u'value1-2', u'value2-2'), (u'value1-2', u'value2-3'),
+ (u'value1-3', u'value2-1'), (u'value1-3', u'value2-2'),
+ (u'value1-3', u'value2-3')],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_get_from_index_multi_ordered(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2, doc1],
+ self.db.get_from_index('test-idx', 'v*', '*'))
+
+ def test_get_range_from_index_start_end(self):
+ doc1 = self.db.create_doc_from_json('{"key": "value3"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value2"}')
+ self.db.create_doc_from_json('{"key": "value4"}')
+ self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc2, doc1],
+ self.db.get_range_from_index('test-idx', 'value2', 'value3'))
+
+ def test_get_range_from_index_start(self):
+ doc1 = self.db.create_doc_from_json('{"key": "value3"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value2"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value4"}')
+ self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc2, doc1, doc3],
+ self.db.get_range_from_index('test-idx', 'value2'))
+
+ def test_get_range_from_index_sees_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ alt_doc = self.make_document(
+ doc.doc_id, 'alternate:1', '{"key": "valuedepalue"}')
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ docs = self.db.get_range_from_index('test-idx', 'a')
+ self.assertTrue(docs[0].has_conflicts)
+
+ def test_get_range_from_index_end(self):
+ self.db.create_doc_from_json('{"key": "value3"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value2"}')
+ self.db.create_doc_from_json('{"key": "value4"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc4, doc2],
+ self.db.get_range_from_index('test-idx', None, 'value2'))
+
+ def test_get_wildcard_range_from_index_start(self):
+ doc1 = self.db.create_doc_from_json('{"key": "value4"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value23"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value2"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value22"}')
+ self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc3, doc4, doc2, doc1],
+ self.db.get_range_from_index('test-idx', 'value2*'))
+
+ def test_get_wildcard_range_from_index_end(self):
+ self.db.create_doc_from_json('{"key": "value4"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value23"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value2"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value22"}')
+ doc5 = self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc5, doc3, doc4, doc2],
+ self.db.get_range_from_index('test-idx', None, 'value2*'))
+
+ def test_get_wildcard_range_from_index_start_end(self):
+ self.db.create_doc_from_json('{"key": "a"}')
+ self.db.create_doc_from_json('{"key": "boo3"}')
+ doc3 = self.db.create_doc_from_json('{"key": "catalyst"}')
+ doc4 = self.db.create_doc_from_json('{"key": "whaever"}')
+ self.db.create_doc_from_json('{"key": "zerg"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc3, doc4],
+ self.db.get_range_from_index('test-idx', 'cat*', 'zap*'))
+
+ def test_get_range_from_index_multi_column_start_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc3, doc2],
+ self.db.get_range_from_index(
+ 'test-idx', ('value2', 'value2'), ('value2', 'value3')))
+
+ def test_get_range_from_index_multi_column_start(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ self.db.create_doc_from_json('{"key": "value2", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc2, doc1],
+ self.db.get_range_from_index('test-idx', ('value2', 'value3')))
+
+ def test_get_range_from_index_multi_column_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2],
+ self.db.get_range_from_index(
+ 'test-idx', None, ('value2', 'value3')))
+
+ def test_get_wildcard_range_from_index_multi_column_start(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc3, doc2, doc1],
+ self.db.get_range_from_index('test-idx', ('value2', 'value2*')))
+
+ def test_get_wildcard_range_from_index_multi_column_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2],
+ self.db.get_range_from_index(
+ 'test-idx', None, ('value2', 'value2*')))
+
+ def test_get_glob_range_from_index_multi_column_start(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc2, doc1],
+ self.db.get_range_from_index('test-idx', ('value2', '*')))
+
+ def test_get_glob_range_from_index_multi_column_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2],
+ self.db.get_range_from_index('test-idx', None, ('value2', '*')))
+
+ def test_get_range_from_index_illegal_wildcard_order(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', ('*', 'v2'))
+
+ def test_get_range_from_index_illegal_glob_after_wildcard(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', ('*', 'v*'))
+
+ def test_get_range_from_index_illegal_wildcard_order_end(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', None, ('*', 'v2'))
+
+ def test_get_range_from_index_illegal_glob_after_wildcard_end(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', None, ('*', 'v*'))
+
+ def test_get_from_index_fails_if_no_index(self):
+ self.assertRaises(
+ errors.IndexDoesNotExist, self.db.get_from_index, 'foo')
+
+ def test_get_index_keys_fails_if_no_index(self):
+ self.assertRaises(errors.IndexDoesNotExist,
+ self.db.get_index_keys,
+ 'foo')
+
+ def test_get_index_keys_works_if_no_docs(self):
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([], self.db.get_index_keys('test-idx'))
+
+ def test_put_updates_index(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ new_content = '{"key": "altval"}'
+ doc.set_json(new_content)
+ self.db.put_doc(doc)
+ self.assertEqual([], self.db.get_from_index('test-idx', 'value'))
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'altval'))
+
+ def test_delete_updates_index(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ sorted([doc, doc2]),
+ sorted(self.db.get_from_index('test-idx', 'value')))
+ self.db.delete_doc(doc)
+ self.assertEqual([doc2], self.db.get_from_index('test-idx', 'value'))
+
+ def test_get_from_index_illegal_number_of_entries(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidValueForIndex, self.db.get_from_index, 'test-idx')
+ self.assertRaises(
+ errors.InvalidValueForIndex,
+ self.db.get_from_index, 'test-idx', 'v1')
+ self.assertRaises(
+ errors.InvalidValueForIndex,
+ self.db.get_from_index, 'test-idx', 'v1', 'v2', 'v3')
+
+ def test_get_from_index_illegal_wildcard_order(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', '*', 'v2')
+
+ def test_get_from_index_illegal_glob_after_wildcard(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', '*', 'v*')
+
+ def test_get_all_from_index(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ # This one should not be in the index
+ self.db.create_doc_from_json('{"no": "key"}')
+ diff_value_doc = '{"key": "diff value"}'
+ doc4 = self.db.create_doc_from_json(diff_value_doc)
+ # This is essentially a 'prefix' match, but we match every entry.
+ self.assertEqual(
+ sorted([doc1, doc2, doc4]),
+ sorted(self.db.get_from_index('test-idx', '*')))
+
+ def test_get_all_from_index_ordered(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json('{"key": "value x"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value b"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value a"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value m"}')
+ # This is essentially a 'prefix' match, but we match every entry.
+ self.assertEqual(
+ [doc3, doc2, doc4, doc1], self.db.get_from_index('test-idx', '*'))
+
+ def test_put_updates_when_adding_key(self):
+ doc = self.db.create_doc_from_json("{}")
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([], self.db.get_from_index('test-idx', '*'))
+ doc.set_json(simple_doc)
+ self.db.put_doc(doc)
+ self.assertEqual([doc], self.db.get_from_index('test-idx', '*'))
+
+ def test_get_from_index_empty_string(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ content2 = '{"key": ""}'
+ doc2 = self.db.create_doc_from_json(content2)
+ self.assertEqual([doc2], self.db.get_from_index('test-idx', ''))
+ # Empty string matches the wildcard.
+ self.assertEqual(
+ sorted([doc1, doc2]),
+ sorted(self.db.get_from_index('test-idx', '*')))
+
+ def test_get_from_index_not_null(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.create_doc_from_json('{"key": null}')
+ self.assertEqual([doc1], self.db.get_from_index('test-idx', '*'))
+
+ def test_get_partial_from_index(self):
+ content1 = '{"k1": "v1", "k2": "v2"}'
+ content2 = '{"k1": "v1", "k2": "x2"}'
+ content3 = '{"k1": "v1", "k2": "y2"}'
+ # doc4 has a different k1 value, so it doesn't match the prefix.
+ content4 = '{"k1": "NN", "k2": "v2"}'
+ doc1 = self.db.create_doc_from_json(content1)
+ doc2 = self.db.create_doc_from_json(content2)
+ doc3 = self.db.create_doc_from_json(content3)
+ self.db.create_doc_from_json(content4)
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertEqual(
+ sorted([doc1, doc2, doc3]),
+ sorted(self.db.get_from_index('test-idx', "v1", "*")))
+
+ def test_get_glob_match(self):
+ # Note: the exact glob syntax is probably subject to change
+ content1 = '{"k1": "v1", "k2": "v1"}'
+ content2 = '{"k1": "v1", "k2": "v2"}'
+ content3 = '{"k1": "v1", "k2": "v3"}'
+ # doc4 has a different k2 prefix value, so it doesn't match
+ content4 = '{"k1": "v1", "k2": "ZZ"}'
+ self.db.create_index('test-idx', 'k1', 'k2')
+ doc1 = self.db.create_doc_from_json(content1)
+ doc2 = self.db.create_doc_from_json(content2)
+ doc3 = self.db.create_doc_from_json(content3)
+ self.db.create_doc_from_json(content4)
+ self.assertEqual(
+ sorted([doc1, doc2, doc3]),
+ sorted(self.db.get_from_index('test-idx', "v1", "v*")))
+
+ def test_nested_index(self):
+ doc = self.db.create_doc_from_json(nested_doc)
+ self.db.create_index('test-idx', 'sub.doc')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'underneath'))
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual(
+ sorted([doc, doc2]),
+ sorted(self.db.get_from_index('test-idx', 'underneath')))
+
+ def test_nested_nonexistent(self):
+ self.db.create_doc_from_json(nested_doc)
+ # sub exists, but sub.foo does not:
+ self.db.create_index('test-idx', 'sub.foo')
+ self.assertEqual([], self.db.get_from_index('test-idx', '*'))
+
+ def test_nested_nonexistent2(self):
+ self.db.create_doc_from_json(nested_doc)
+ self.db.create_index('test-idx', 'sub.foo.bar.baz.qux.fnord')
+ self.assertEqual([], self.db.get_from_index('test-idx', '*'))
+
+ def test_nested_traverses_lists(self):
+ # subpath finds dicts in list
+ doc = self.db.create_doc_from_json(
+ '{"foo": [{"zap": "bar"}, {"zap": "baz"}]}')
+ # subpath only finds dicts in list
+ self.db.create_doc_from_json('{"foo": ["zap", "baz"]}')
+ self.db.create_index('test-idx', 'foo.zap')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'bar'))
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'baz'))
+
+ def test_nested_list_traversal(self):
+ # subpath finds dicts in list
+ doc = self.db.create_doc_from_json(
+ '{"foo": [{"zap": [{"qux": "fnord"}, {"qux": "zombo"}]},'
+ '{"zap": "baz"}]}')
+ # subpath only finds dicts in list
+ self.db.create_index('test-idx', 'foo.zap.qux')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'fnord'))
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'zombo'))
+
+ def test_index_list1(self):
+ self.db.create_index("index", "name")
+ content = '{"name": ["foo", "bar"]}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_index_list2(self):
+ self.db.create_index("index", "name")
+ content = '{"name": ["foo", "bar"]}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_case_sensitive(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.assertEqual([], self.db.get_from_index('test-idx', 'V*'))
+ self.assertEqual([doc1], self.db.get_from_index('test-idx', 'v*'))
+
+ def test_get_from_index_illegal_glob_before_value(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', 'v*', 'v2')
+
+ def test_get_from_index_illegal_glob_after_glob(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', 'v*', 'v*')
+
+ def test_get_from_index_with_sql_wildcards(self):
+ self.db.create_index('test-idx', 'key')
+ content1 = '{"key": "va%lue"}'
+ content2 = '{"key": "value"}'
+ content3 = '{"key": "va_lue"}'
+ doc1 = self.db.create_doc_from_json(content1)
+ self.db.create_doc_from_json(content2)
+ doc3 = self.db.create_doc_from_json(content3)
+ # The '%' in the search should be treated literally, not as a sql
+ # globbing character.
+ self.assertEqual([doc1], self.db.get_from_index('test-idx', 'va%*'))
+ # Same for '_'
+ self.assertEqual([doc3], self.db.get_from_index('test-idx', 'va_*'))
+
+ def test_get_from_index_with_lower(self):
+ self.db.create_index("index", "lower(name)")
+ content = '{"name": "Foo"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_lower_matches_same_case(self):
+ self.db.create_index("index", "lower(name)")
+ content = '{"name": "foo"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_lower_doesnt_match_different_case(self):
+ self.db.create_index("index", "lower(name)")
+ content = '{"name": "Foo"}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "Foo")
+ self.assertEqual([], rows)
+
+ def test_index_lower_doesnt_match_other_index(self):
+ self.db.create_index("index", "lower(name)")
+ self.db.create_index("other_index", "name")
+ content = '{"name": "Foo"}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "Foo")
+ self.assertEqual(0, len(rows))
+
+ def test_index_split_words_match_first(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_match_second(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_match_both(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo foo"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_double_space(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_leading_space(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": " foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_trailing_space(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar "}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_number(self):
+ self.db.create_index("index", "number(foo, 5)")
+ content = '{"foo": 12}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "00012")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_number_bigger_than_padding(self):
+ self.db.create_index("index", "number(foo, 5)")
+ content = '{"foo": 123456}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "123456")
+ self.assertEqual([doc], rows)
+
+ def test_number_mapping_ignores_non_numbers(self):
+ self.db.create_index("index", "number(foo, 5)")
+ content = '{"foo": 56}'
+ doc1 = self.db.create_doc_from_json(content)
+ content = '{"foo": "this is not a maigret painting"}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "*")
+ self.assertEqual([doc1], rows)
+
+ def test_get_from_index_with_bool(self):
+ self.db.create_index("index", "bool(foo)")
+ content = '{"foo": true}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "1")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_bool_false(self):
+ self.db.create_index("index", "bool(foo)")
+ content = '{"foo": false}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "0")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_non_bool(self):
+ self.db.create_index("index", "bool(foo)")
+ content = '{"foo": 42}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "*")
+ self.assertEqual([], rows)
+
+ def test_get_from_index_with_combine(self):
+ self.db.create_index("index", "combine(foo, bar)")
+ content = '{"foo": "value1", "bar": "value2"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "value1")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "value2")
+ self.assertEqual([doc], rows)
+
+ def test_get_complex_combine(self):
+ self.db.create_index(
+ "index", "combine(number(foo, 5), lower(bar), split_words(baz))")
+ content = '{"foo": 12, "bar": "ALLCAPS", "baz": "qux nox"}'
+ doc = self.db.create_doc_from_json(content)
+ content = '{"foo": "not a number", "bar": "something"}'
+ doc2 = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "00012")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "allcaps")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "nox")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "something")
+ self.assertEqual([doc2], rows)
+
+ def test_get_index_keys_from_index(self):
+ self.db.create_index('test-idx', 'key')
+ content1 = '{"key": "value1"}'
+ content2 = '{"key": "value2"}'
+ content3 = '{"key": "value2"}'
+ self.db.create_doc_from_json(content1)
+ self.db.create_doc_from_json(content2)
+ self.db.create_doc_from_json(content3)
+ self.assertEqual(
+ [('value1',), ('value2',)],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_get_index_keys_from_multicolumn_index(self):
+ self.db.create_index('test-idx', 'key1', 'key2')
+ content1 = '{"key1": "value1", "key2": "val2-1"}'
+ content2 = '{"key1": "value2", "key2": "val2-2"}'
+ content3 = '{"key1": "value2", "key2": "val2-2"}'
+ content4 = '{"key1": "value2", "key2": "val3"}'
+ self.db.create_doc_from_json(content1)
+ self.db.create_doc_from_json(content2)
+ self.db.create_doc_from_json(content3)
+ self.db.create_doc_from_json(content4)
+ self.assertEqual([
+ ('value1', 'val2-1'),
+ ('value2', 'val2-2'),
+ ('value2', 'val3')],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_empty_expr(self):
+ self.assertParseError('')
+
+ def test_nested_unknown_operation(self):
+ self.assertParseError('unknown_operation(field1)')
+
+ def test_parse_missing_close_paren(self):
+ self.assertParseError("lower(a")
+
+ def test_parse_trailing_close_paren(self):
+ self.assertParseError("lower(ab))")
+
+ def test_parse_trailing_chars(self):
+ self.assertParseError("lower(ab)adsf")
+
+ def test_parse_empty_op(self):
+ self.assertParseError("(ab)")
+
+ def test_parse_top_level_commas(self):
+ self.assertParseError("a, b")
+
+ def test_invalid_field_name(self):
+ self.assertParseError("a.")
+
+ def test_invalid_inner_field_name(self):
+ self.assertParseError("lower(a.)")
+
+ def test_gobbledigook(self):
+ self.assertParseError("(@#@cc @#!*DFJSXV(()jccd")
+
+ def test_leading_space(self):
+ self.assertIndexCreatable(" lower(a)")
+
+ def test_trailing_space(self):
+ self.assertIndexCreatable("lower(a) ")
+
+ def test_spaces_before_open_paren(self):
+ self.assertIndexCreatable("lower (a)")
+
+ def test_spaces_after_open_paren(self):
+ self.assertIndexCreatable("lower( a)")
+
+ def test_spaces_before_close_paren(self):
+ self.assertIndexCreatable("lower(a )")
+
+ def test_spaces_before_comma(self):
+ self.assertIndexCreatable("combine(a , b , c)")
+
+ def test_spaces_after_comma(self):
+ self.assertIndexCreatable("combine(a, b, c)")
+
+ def test_all_together_now(self):
+ self.assertParseError(' (a) ')
+
+ def test_all_together_now2(self):
+ self.assertParseError('combine(lower(x)x,foo)')
+
+
+class PythonBackendTests(tests.DatabaseBaseTests):
+
+ def setUp(self):
+ super(PythonBackendTests, self).setUp()
+ self.simple_doc = json.loads(simple_doc)
+
+ def test_create_doc_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id')
+ self.assertTrue(isinstance(doc, TestAlternativeDocument))
+
+ def test_get_doc_after_put_with_factory(self):
+ doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id')
+ self.db.set_document_factory(TestAlternativeDocument)
+ result = self.db.get_doc('my_doc_id')
+ self.assertTrue(isinstance(result, TestAlternativeDocument))
+ self.assertEqual(doc.doc_id, result.doc_id)
+ self.assertEqual(doc.rev, result.rev)
+ self.assertEqual(doc.get_json(), result.get_json())
+ self.assertEqual(False, result.has_conflicts)
+
+ def test_get_doc_nonexisting_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ self.assertIs(None, self.db.get_doc('non-existing'))
+
+ def test_get_all_docs_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ self.db.create_doc(self.simple_doc)
+ self.assertTrue(isinstance(
+ list(self.db.get_all_docs()[1])[0], TestAlternativeDocument))
+
+ def test_get_docs_conflicted_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ doc1 = self.db.create_doc(self.simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertTrue(
+ isinstance(
+ list(self.db.get_docs([doc1.doc_id]))[0],
+ TestAlternativeDocument))
+
+ def test_get_from_index_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ self.db.create_doc(self.simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertTrue(
+ isinstance(
+ self.db.get_from_index('test-idx', 'value')[0],
+ TestAlternativeDocument))
+
+ def test_sync_exchange_updates_indexes(self):
+ doc = self.db.create_doc(self.simple_doc)
+ self.db.create_index('test-idx', 'key')
+ new_content = '{"key": "altval"}'
+ other_rev = 'test:1|z:2'
+ st = self.db.get_sync_target()
+
+ def ignore(doc_id, doc_rev, doc):
+ pass
+
+ doc_other = self.make_document(doc.doc_id, other_rev, new_content)
+ docs_by_gen = [(doc_other, 10, 'T-sid')]
+ st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=ignore)
+ self.assertGetDoc(self.db, doc.doc_id, other_rev, new_content, False)
+ self.assertEqual(
+ [doc_other], self.db.get_from_index('test-idx', 'altval'))
+ self.assertEqual([], self.db.get_from_index('test-idx', 'value'))
+
+
+# Use a custom loader to apply the scenarios at load time.
+load_tests = tests.load_with_scenarios
diff --git a/u1db/tests/test_c_backend.py b/u1db/tests/test_c_backend.py
new file mode 100644
index 00000000..bdd2aec7
--- /dev/null
+++ b/u1db/tests/test_c_backend.py
@@ -0,0 +1,634 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from u1db import (
+ Document,
+ errors,
+ tests,
+ )
+from u1db.tests import c_backend_wrapper, c_backend_error
+from u1db.tests.test_remote_sync_target import (
+ make_http_app,
+ make_oauth_http_app
+ )
+
+
+class TestCDatabaseExists(tests.TestCase):
+
+ def test_c_backend_compiled(self):
+ if c_backend_wrapper is None:
+ self.fail("Could not import the c_backend_wrapper module."
+ " Was it compiled properly?\n%s" % (c_backend_error,))
+
+
+# Rather than lots of failing tests, we have the above check to test that the
+# module exists, and all these tests just get skipped
+class BackendTests(tests.TestCase):
+
+ def setUp(self):
+ super(BackendTests, self).setUp()
+ if c_backend_wrapper is None:
+ self.skipTest("The c_backend_wrapper could not be imported")
+
+
+class TestCDatabase(BackendTests):
+
+ def test_exists(self):
+ if c_backend_wrapper is None:
+ self.fail("Could not import the c_backend_wrapper module."
+ " Was it compiled properly?")
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertEqual(':memory:', db._filename)
+
+ def test__is_closed(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertTrue(db._sql_is_open())
+ db.close()
+ self.assertFalse(db._sql_is_open())
+
+ def test__run_sql(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertTrue(db._sql_is_open())
+ self.assertEqual([], db._run_sql('CREATE TABLE test (id INTEGER)'))
+ self.assertEqual([], db._run_sql('INSERT INTO test VALUES (1)'))
+ self.assertEqual([('1',)], db._run_sql('SELECT * FROM test'))
+
+ def test__get_generation(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertEqual(0, db._get_generation())
+ db.create_doc_from_json(tests.simple_doc)
+ self.assertEqual(1, db._get_generation())
+
+ def test__get_generation_info(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertEqual((0, ''), db._get_generation_info())
+ db.create_doc_from_json(tests.simple_doc)
+ info = db._get_generation_info()
+ self.assertEqual(1, info[0])
+ self.assertTrue(info[1].startswith('T-'))
+
+ def test__set_replica_uid(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertIsNot(None, db._replica_uid)
+ db._set_replica_uid('foo')
+ self.assertEqual([('foo',)], db._run_sql(
+ "SELECT value FROM u1db_config WHERE name='replica_uid'"))
+
+ def test_default_replica_uid(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertIsNot(None, self.db._replica_uid)
+ self.assertEqual(32, len(self.db._replica_uid))
+ # casting to an int from the uid *is* the check for correct behavior.
+ int(self.db._replica_uid, 16)
+
+ def test_get_conflicts_with_borked_data(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ # We add an entry to conflicts, but not to documents, which is an
+ # invalid situation
+ self.db._run_sql("INSERT INTO conflicts"
+ " VALUES ('doc-id', 'doc-rev', '{}')")
+ self.assertRaises(Exception, self.db.get_doc_conflicts, 'doc-id')
+
+ def test_create_index_list(self):
+ # We manually poke data into the DB, so that we test just the "get_doc"
+ # code, rather than also testing the index management code.
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ self.db.create_index_list("key-idx", ["key"])
+ docs = self.db.get_from_index('key-idx', 'value')
+ self.assertEqual([doc], docs)
+
+ def test_create_index_list_on_non_ascii_field_name(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(json.dumps({u'\xe5': 'value'}))
+ self.db.create_index_list('test-idx', [u'\xe5'])
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_list_indexes_with_non_ascii_field_names(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_index_list('test-idx', [u'\xe5'])
+ self.assertEqual(
+ [('test-idx', [u'\xe5'])], self.db.list_indexes())
+
+ def test_create_index_evaluates_it(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ self.db.create_index_list('test-idx', ['key'])
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_wildcard_matches_unicode_value(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"}))
+ self.db.create_index_list('test-idx', ['key'])
+ self.assertEqual([doc], self.db.get_from_index('test-idx', '*'))
+
+ def test_create_index_fails_if_name_taken(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_index_list('test-idx', ['key'])
+ self.assertRaises(errors.IndexNameTakenError,
+ self.db.create_index_list,
+ 'test-idx', ['stuff'])
+
+ def test_create_index_does_not_fail_if_name_taken_with_same_index(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_index_list('test-idx', ['key'])
+ self.db.create_index_list('test-idx', ['key'])
+ self.assertEqual([('test-idx', ['key'])], self.db.list_indexes())
+
+ def test_create_index_after_deleting_document(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ doc2 = self.db.create_doc_from_json(tests.simple_doc)
+ self.db.delete_doc(doc2)
+ self.db.create_index_list('test-idx', ['key'])
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_get_from_index(self):
+ # We manually poke data into the DB, so that we test just the "get_doc"
+ # code, rather than also testing the index management code.
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ self.db.create_index("key-idx", "key")
+ docs = self.db.get_from_index('key-idx', 'value')
+ self.assertEqual([doc], docs)
+
+ def test_get_from_index_list(self):
+ # We manually poke data into the DB, so that we test just the "get_doc"
+ # code, rather than also testing the index management code.
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ self.db.create_index("key-idx", "key")
+ docs = self.db.get_from_index_list('key-idx', ['value'])
+ self.assertEqual([doc], docs)
+
+ def test_get_from_index_list_multi(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ content = '{"key": "value", "key2": "value2"}'
+ doc = self.db.create_doc_from_json(content)
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc],
+ self.db.get_from_index_list('test-idx', ['value', 'value2']))
+
+ def test_get_from_index_list_multi_ordered(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2, doc1],
+ self.db.get_from_index_list('test-idx', ['v*', '*']))
+
+ def test_get_from_index_2(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.nested_doc)
+ self.db.create_index("multi-idx", "key", "sub.doc")
+ docs = self.db.get_from_index('multi-idx', 'value', 'underneath')
+ self.assertEqual([doc], docs)
+
+ def test_get_index_keys(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_doc_from_json(tests.simple_doc)
+ self.db.create_index("key-idx", "key")
+ keys = self.db.get_index_keys('key-idx')
+ self.assertEqual([("value",)], keys)
+
+ def test__query_init_one_field(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_index("key-idx", "key")
+ query = self.db._query_init("key-idx")
+ self.assertEqual("key-idx", query.index_name)
+ self.assertEqual(1, query.num_fields)
+ self.assertEqual(["key"], query.fields)
+
+ def test__query_init_two_fields(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_index("two-idx", "key", "key2")
+ query = self.db._query_init("two-idx")
+ self.assertEqual("two-idx", query.index_name)
+ self.assertEqual(2, query.num_fields)
+ self.assertEqual(["key", "key2"], query.fields)
+
+ def assertFormatQueryEquals(self, expected, wildcards, fields):
+ val, w = c_backend_wrapper._format_query(fields)
+ self.assertEqual(expected, val)
+ self.assertEqual(wildcards, w)
+
+ def test__format_query(self):
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id FROM document_fields d0"
+ " WHERE d0.field_name = ? AND d0.value = ? ORDER BY d0.value",
+ [0], ["1"])
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id"
+ " FROM document_fields d0, document_fields d1"
+ " WHERE d0.field_name = ? AND d0.value = ?"
+ " AND d0.doc_id = d1.doc_id"
+ " AND d1.field_name = ? AND d1.value = ?"
+ " ORDER BY d0.value, d1.value",
+ [0, 0], ["1", "2"])
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id"
+ " FROM document_fields d0, document_fields d1, document_fields d2"
+ " WHERE d0.field_name = ? AND d0.value = ?"
+ " AND d0.doc_id = d1.doc_id"
+ " AND d1.field_name = ? AND d1.value = ?"
+ " AND d0.doc_id = d2.doc_id"
+ " AND d2.field_name = ? AND d2.value = ?"
+ " ORDER BY d0.value, d1.value, d2.value",
+ [0, 0, 0], ["1", "2", "3"])
+
+ def test__format_query_wildcard(self):
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id FROM document_fields d0"
+ " WHERE d0.field_name = ? AND d0.value NOT NULL ORDER BY d0.value",
+ [1], ["*"])
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id"
+ " FROM document_fields d0, document_fields d1"
+ " WHERE d0.field_name = ? AND d0.value = ?"
+ " AND d0.doc_id = d1.doc_id"
+ " AND d1.field_name = ? AND d1.value NOT NULL"
+ " ORDER BY d0.value, d1.value",
+ [0, 1], ["1", "*"])
+
+ def test__format_query_glob(self):
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id FROM document_fields d0"
+ " WHERE d0.field_name = ? AND d0.value GLOB ? ORDER BY d0.value",
+ [2], ["1*"])
+
+
+class TestCSyncTarget(BackendTests):
+
+ def setUp(self):
+ super(TestCSyncTarget, self).setUp()
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.st = self.db.get_sync_target()
+
+ def test_attached_to_db(self):
+ self.assertEqual(
+ self.db._replica_uid, self.st.get_sync_info("misc")[0])
+
+ def test_get_sync_exchange(self):
+ exc = self.st._get_sync_exchange("source-uid", 10)
+ self.assertIsNot(None, exc)
+
+ def test_sync_exchange_insert_doc_from_source(self):
+ exc = self.st._get_sync_exchange("source-uid", 5)
+ doc = c_backend_wrapper.make_document('doc-id', 'replica:1',
+ tests.simple_doc)
+ self.assertEqual([], exc.get_seen_ids())
+ exc.insert_doc_from_source(doc, 10, 'T-sid')
+ self.assertGetDoc(self.db, 'doc-id', 'replica:1', tests.simple_doc,
+ False)
+ self.assertEqual(
+ (10, 'T-sid'), self.db._get_replica_gen_and_trans_id('source-uid'))
+ self.assertEqual(['doc-id'], exc.get_seen_ids())
+
+ def test_sync_exchange_conflicted_doc(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ exc = self.st._get_sync_exchange("source-uid", 5)
+ doc2 = c_backend_wrapper.make_document(doc.doc_id, 'replica:1',
+ tests.nested_doc)
+ self.assertEqual([], exc.get_seen_ids())
+ # The insert should be rejected and the doc_id not considered 'seen'
+ exc.insert_doc_from_source(doc2, 10, 'T-sid')
+ self.assertGetDoc(
+ self.db, doc.doc_id, doc.rev, tests.simple_doc, False)
+ self.assertEqual([], exc.get_seen_ids())
+
+ def test_sync_exchange_find_doc_ids(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ exc = self.st._get_sync_exchange("source-uid", 0)
+ self.assertEqual(0, exc.target_gen)
+ exc.find_doc_ids_to_return()
+ doc_id = exc.get_doc_ids_to_return()[0]
+ self.assertEqual(
+ (doc.doc_id, 1), doc_id[:-1])
+ self.assertTrue(doc_id[-1].startswith('T-'))
+ self.assertEqual(1, exc.target_gen)
+
+ def test_sync_exchange_find_doc_ids_not_including_recently_inserted(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ doc2 = self.db.create_doc_from_json(tests.nested_doc)
+ exc = self.st._get_sync_exchange("source-uid", 0)
+ doc3 = c_backend_wrapper.make_document(doc1.doc_id,
+ doc1.rev + "|zreplica:2", tests.simple_doc)
+ exc.insert_doc_from_source(doc3, 10, 'T-sid')
+ exc.find_doc_ids_to_return()
+ self.assertEqual(
+ (doc2.doc_id, 2), exc.get_doc_ids_to_return()[0][:-1])
+ self.assertEqual(3, exc.target_gen)
+
+ def test_sync_exchange_return_docs(self):
+ returned = []
+
+ def return_doc_cb(doc, gen, trans_id):
+ returned.append((doc, gen, trans_id))
+
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ exc = self.st._get_sync_exchange("source-uid", 0)
+ exc.find_doc_ids_to_return()
+ exc.return_docs(return_doc_cb)
+ self.assertEqual((doc1, 1), returned[0][:-1])
+
+ def test_sync_exchange_doc_ids(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc, doc_id='doc-1')
+ db2 = c_backend_wrapper.CDatabase(':memory:')
+ doc2 = db2.create_doc_from_json(tests.nested_doc, doc_id='doc-2')
+ returned = []
+
+ def return_doc_cb(doc, gen, trans_id):
+ returned.append((doc, gen, trans_id))
+
+ val = self.st.sync_exchange_doc_ids(
+ db2, [(doc2.doc_id, 1, 'T-sid')], 0, None, return_doc_cb)
+ last_trans_id = self.db._get_transaction_log()[-1][1]
+ self.assertEqual(2, self.db._get_generation())
+ self.assertEqual((2, last_trans_id), val)
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc,
+ False)
+ self.assertEqual((doc1, 1), returned[0][:-1])
+
+
+class TestCHTTPSyncTarget(BackendTests):
+
+ def test_format_sync_url(self):
+ target = c_backend_wrapper.create_http_sync_target("http://base_url")
+ self.assertEqual("http://base_url/sync-from/replica-uid",
+ c_backend_wrapper._format_sync_url(target, "replica-uid"))
+
+ def test_format_sync_url_escapes(self):
+ # The base_url should not get munged (we assume it is already a
+ # properly formed URL), but the replica-uid should get properly escaped
+ target = c_backend_wrapper.create_http_sync_target(
+ "http://host/base%2Ctest/")
+ self.assertEqual("http://host/base%2Ctest/sync-from/replica%2Cuid",
+ c_backend_wrapper._format_sync_url(target, "replica,uid"))
+
+ def test_format_refuses_non_http(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ target = db.get_sync_target()
+ self.assertRaises(RuntimeError,
+ c_backend_wrapper._format_sync_url, target, 'replica,uid')
+
+ def test_oauth_credentials(self):
+ target = c_backend_wrapper.create_oauth_http_sync_target(
+ "http://host/base%2Ctest/",
+ 'consumer-key', 'consumer-secret', 'token-key', 'token-secret')
+ auth = c_backend_wrapper._get_oauth_authorization(target,
+ "GET", "http://host/base%2Ctest/sync-from/abcd-efg")
+ self.assertIsNot(None, auth)
+ self.assertTrue(auth.startswith('Authorization: OAuth realm="", '))
+ self.assertNotIn('http://host/base', auth)
+ self.assertIn('oauth_nonce="', auth)
+ self.assertIn('oauth_timestamp="', auth)
+ self.assertIn('oauth_consumer_key="consumer-key"', auth)
+ self.assertIn('oauth_signature_method="HMAC-SHA1"', auth)
+ self.assertIn('oauth_version="1.0"', auth)
+ self.assertIn('oauth_token="token-key"', auth)
+ self.assertIn('oauth_signature="', auth)
+
+
+class TestSyncCtoHTTPViaC(tests.TestCaseWithServer):
+
+ make_app_with_state = staticmethod(make_http_app)
+
+ def setUp(self):
+ super(TestSyncCtoHTTPViaC, self).setUp()
+ if c_backend_wrapper is None:
+ self.skipTest("The c_backend_wrapper could not be imported")
+ self.startServer()
+
+ def test_trivial_sync(self):
+ mem_db = self.request_state._create_database('test.db')
+ mem_doc = mem_db.create_doc_from_json(tests.nested_doc)
+ url = self.getURL('test.db')
+ target = c_backend_wrapper.create_http_sync_target(url)
+ db = c_backend_wrapper.CDatabase(':memory:')
+ doc = db.create_doc_from_json(tests.simple_doc)
+ c_backend_wrapper.sync_db_to_target(db, target)
+ self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False)
+ self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(),
+ False)
+
+ def test_unavailable(self):
+ mem_db = self.request_state._create_database('test.db')
+ mem_db.create_doc_from_json(tests.nested_doc)
+ tries = []
+
+ def wrapper(instance, *args, **kwargs):
+ tries.append(None)
+ raise errors.Unavailable
+
+ mem_db.whats_changed = wrapper
+ url = self.getURL('test.db')
+ target = c_backend_wrapper.create_http_sync_target(url)
+ db = c_backend_wrapper.CDatabase(':memory:')
+ db.create_doc_from_json(tests.simple_doc)
+ self.assertRaises(
+ errors.Unavailable, c_backend_wrapper.sync_db_to_target, db,
+ target)
+ self.assertEqual(5, len(tries))
+
+ def test_unavailable_then_available(self):
+ mem_db = self.request_state._create_database('test.db')
+ mem_doc = mem_db.create_doc_from_json(tests.nested_doc)
+ orig_whatschanged = mem_db.whats_changed
+ tries = []
+
+ def wrapper(instance, *args, **kwargs):
+ if len(tries) < 1:
+ tries.append(None)
+ raise errors.Unavailable
+ return orig_whatschanged(instance, *args, **kwargs)
+
+ mem_db.whats_changed = wrapper
+ url = self.getURL('test.db')
+ target = c_backend_wrapper.create_http_sync_target(url)
+ db = c_backend_wrapper.CDatabase(':memory:')
+ doc = db.create_doc_from_json(tests.simple_doc)
+ c_backend_wrapper.sync_db_to_target(db, target)
+ self.assertEqual(1, len(tries))
+ self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False)
+ self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(),
+ False)
+
+ def test_db_sync(self):
+ mem_db = self.request_state._create_database('test.db')
+ mem_doc = mem_db.create_doc_from_json(tests.nested_doc)
+ url = self.getURL('test.db')
+ db = c_backend_wrapper.CDatabase(':memory:')
+ doc = db.create_doc_from_json(tests.simple_doc)
+ local_gen_before_sync = db.sync(url)
+ gen, _, changes = db.whats_changed(local_gen_before_sync)
+ self.assertEqual(1, len(changes))
+ self.assertEqual(mem_doc.doc_id, changes[0][0])
+ self.assertEqual(1, gen - local_gen_before_sync)
+ self.assertEqual(1, local_gen_before_sync)
+ self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False)
+ self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(),
+ False)
+
+
+class TestSyncCtoOAuthHTTPViaC(tests.TestCaseWithServer):
+
+ make_app_with_state = staticmethod(make_oauth_http_app)
+
+ def setUp(self):
+ super(TestSyncCtoOAuthHTTPViaC, self).setUp()
+ if c_backend_wrapper is None:
+ self.skipTest("The c_backend_wrapper could not be imported")
+ self.startServer()
+
+ def test_trivial_sync(self):
+ mem_db = self.request_state._create_database('test.db')
+ mem_doc = mem_db.create_doc_from_json(tests.nested_doc)
+ url = self.getURL('~/test.db')
+ target = c_backend_wrapper.create_oauth_http_sync_target(url,
+ tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ db = c_backend_wrapper.CDatabase(':memory:')
+ doc = db.create_doc_from_json(tests.simple_doc)
+ c_backend_wrapper.sync_db_to_target(db, target)
+ self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False)
+ self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(),
+ False)
+
+
+class TestVectorClock(BackendTests):
+
+ def create_vcr(self, rev):
+ return c_backend_wrapper.VectorClockRev(rev)
+
+ def test_parse_empty(self):
+ self.assertEqual('VectorClockRev()',
+ repr(self.create_vcr('')))
+
+ def test_parse_invalid(self):
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('x')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('x:a')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:a')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('x:a|y:1')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:2a')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1||')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:2|')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:2|:')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:2|m:')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:|m:3')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|:|m:3')))
+
+ def test_parse_single(self):
+ self.assertEqual('VectorClockRev(test:1)',
+ repr(self.create_vcr('test:1')))
+
+ def test_parse_multi(self):
+ self.assertEqual('VectorClockRev(test:1|z:2)',
+ repr(self.create_vcr('test:1|z:2')))
+ self.assertEqual('VectorClockRev(ab:1|bc:2|cd:3|de:4|ef:5)',
+ repr(self.create_vcr('ab:1|bc:2|cd:3|de:4|ef:5')))
+ self.assertEqual('VectorClockRev(a:2|b:1)',
+ repr(self.create_vcr('b:1|a:2')))
+
+
+class TestCDocument(BackendTests):
+
+ def make_document(self, *args, **kwargs):
+ return c_backend_wrapper.make_document(*args, **kwargs)
+
+ def test_create(self):
+ self.make_document('doc-id', 'uid:1', tests.simple_doc)
+
+ def assertPyDocEqualCDoc(self, *args, **kwargs):
+ cdoc = self.make_document(*args, **kwargs)
+ pydoc = Document(*args, **kwargs)
+ self.assertEqual(pydoc, cdoc)
+ self.assertEqual(cdoc, pydoc)
+
+ def test_cmp_to_pydoc_equal(self):
+ self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc)
+ self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc,
+ has_conflicts=False)
+ self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc,
+ has_conflicts=True)
+
+ def test_cmp_to_pydoc_not_equal_conflicts(self):
+ cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ pydoc = Document('doc-id', 'uid:1', tests.simple_doc,
+ has_conflicts=True)
+ self.assertNotEqual(cdoc, pydoc)
+ self.assertNotEqual(pydoc, cdoc)
+
+ def test_cmp_to_pydoc_not_equal_doc_id(self):
+ cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ pydoc = Document('doc2-id', 'uid:1', tests.simple_doc)
+ self.assertNotEqual(cdoc, pydoc)
+ self.assertNotEqual(pydoc, cdoc)
+
+ def test_cmp_to_pydoc_not_equal_doc_rev(self):
+ cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ pydoc = Document('doc-id', 'uid:2', tests.simple_doc)
+ self.assertNotEqual(cdoc, pydoc)
+ self.assertNotEqual(pydoc, cdoc)
+
+ def test_cmp_to_pydoc_not_equal_content(self):
+ cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ pydoc = Document('doc-id', 'uid:1', tests.nested_doc)
+ self.assertNotEqual(cdoc, pydoc)
+ self.assertNotEqual(pydoc, cdoc)
+
+
+class TestUUID(BackendTests):
+
+ def test_uuid4_conformance(self):
+ uuids = set()
+ for i in range(20):
+ uuid = c_backend_wrapper.generate_hex_uuid()
+ self.assertIsInstance(uuid, str)
+ self.assertEqual(32, len(uuid))
+ # This will raise ValueError if it isn't a valid hex string
+ long(uuid, 16)
+ # Version 4 uuids have 2 other requirements, the high 4 bits of the
+ # seventh byte are always '0x4', and the middle bits of byte 9 are
+ # always set
+ self.assertEqual('4', uuid[12])
+ self.assertTrue(uuid[16] in '89ab')
+ self.assertTrue(uuid not in uuids)
+ uuids.add(uuid)
diff --git a/u1db/tests/test_common_backend.py b/u1db/tests/test_common_backend.py
new file mode 100644
index 00000000..8c7c7ed9
--- /dev/null
+++ b/u1db/tests/test_common_backend.py
@@ -0,0 +1,33 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test common backend bits."""
+
+from u1db import (
+ backends,
+ tests,
+ )
+
+
+class TestCommonBackendImpl(tests.TestCase):
+
+ def test__allocate_doc_id(self):
+ db = backends.CommonBackend()
+ doc_id1 = db._allocate_doc_id()
+ self.assertTrue(doc_id1.startswith('D-'))
+ self.assertEqual(34, len(doc_id1))
+ int(doc_id1[len('D-'):], 16)
+ self.assertNotEqual(doc_id1, db._allocate_doc_id())
diff --git a/u1db/tests/test_document.py b/u1db/tests/test_document.py
new file mode 100644
index 00000000..20f254b9
--- /dev/null
+++ b/u1db/tests/test_document.py
@@ -0,0 +1,148 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+
+from u1db import errors, tests
+
+
+class TestDocument(tests.TestCase):
+
+ scenarios = ([(
+ 'py', {'make_document_for_test': tests.make_document_for_test})] +
+ tests.C_DATABASE_SCENARIOS)
+
+ def test_create_doc(self):
+ doc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ self.assertEqual('doc-id', doc.doc_id)
+ self.assertEqual('uid:1', doc.rev)
+ self.assertEqual(tests.simple_doc, doc.get_json())
+ self.assertFalse(doc.has_conflicts)
+
+ def test__repr__(self):
+ doc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ self.assertEqual(
+ '%s(doc-id, uid:1, \'{"key": "value"}\')'
+ % (doc.__class__.__name__,),
+ repr(doc))
+
+ def test__repr__conflicted(self):
+ doc = self.make_document('doc-id', 'uid:1', tests.simple_doc,
+ has_conflicts=True)
+ self.assertEqual(
+ '%s(doc-id, uid:1, conflicted, \'{"key": "value"}\')'
+ % (doc.__class__.__name__,),
+ repr(doc))
+
+ def test__lt__(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ doc_b = self.make_document('b', 'b', '{}')
+ self.assertTrue(doc_a < doc_b)
+ self.assertTrue(doc_b > doc_a)
+ doc_aa = self.make_document('a', 'a', '{}')
+ self.assertTrue(doc_aa < doc_a)
+
+ def test__eq__(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ doc_b = self.make_document('a', 'b', '{}')
+ self.assertTrue(doc_a == doc_b)
+ doc_b = self.make_document('a', 'b', '{}', has_conflicts=True)
+ self.assertFalse(doc_a == doc_b)
+
+ def test_non_json_dict(self):
+ self.assertRaises(
+ errors.InvalidJSON, self.make_document, 'id', 'uid:1',
+ '"not a json dictionary"')
+
+ def test_non_json(self):
+ self.assertRaises(
+ errors.InvalidJSON, self.make_document, 'id', 'uid:1',
+ 'not a json dictionary')
+
+ def test_get_size(self):
+ doc_a = self.make_document('a', 'b', '{"some": "content"}')
+ self.assertEqual(
+ len('a' + 'b' + '{"some": "content"}'), doc_a.get_size())
+
+ def test_get_size_empty_document(self):
+ doc_a = self.make_document('a', 'b', None)
+ self.assertEqual(len('a' + 'b'), doc_a.get_size())
+
+
+class TestPyDocument(tests.TestCase):
+
+ scenarios = ([(
+ 'py', {'make_document_for_test': tests.make_document_for_test})])
+
+ def test_get_content(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertEqual({"content": ""}, doc.content)
+ doc.set_json('{"content": "new"}')
+ self.assertEqual({"content": "new"}, doc.content)
+
+ def test_set_content(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ doc.content = {"content": "new"}
+ self.assertEqual('{"content": "new"}', doc.get_json())
+
+ def test_set_bad_content(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertRaises(
+ errors.InvalidContent, setattr, doc, 'content',
+ '{"content": "new"}')
+
+ def test_is_tombstone(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ self.assertFalse(doc_a.is_tombstone())
+ doc_a.set_json(None)
+ self.assertTrue(doc_a.is_tombstone())
+
+ def test_make_tombstone(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ self.assertFalse(doc_a.is_tombstone())
+ doc_a.make_tombstone()
+ self.assertTrue(doc_a.is_tombstone())
+
+ def test_same_content_as(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ doc_b = self.make_document('d', 'e', '{}')
+ self.assertTrue(doc_a.same_content_as(doc_b))
+ doc_b = self.make_document('p', 'q', '{}', has_conflicts=True)
+ self.assertTrue(doc_a.same_content_as(doc_b))
+ doc_b.content['key'] = 'value'
+ self.assertFalse(doc_a.same_content_as(doc_b))
+
+ def test_same_content_as_json_order(self):
+ doc_a = self.make_document(
+ 'a', 'b', '{"key1": "val1", "key2": "val2"}')
+ doc_b = self.make_document(
+ 'c', 'd', '{"key2": "val2", "key1": "val1"}')
+ self.assertTrue(doc_a.same_content_as(doc_b))
+
+ def test_set_json(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ doc.set_json('{"content": "new"}')
+ self.assertEqual('{"content": "new"}', doc.get_json())
+
+ def test_set_json_non_dict(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertRaises(errors.InvalidJSON, doc.set_json, '"is not a dict"')
+
+ def test_set_json_error(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertRaises(errors.InvalidJSON, doc.set_json, 'is not json')
+
+
+load_tests = tests.load_with_scenarios
diff --git a/u1db/tests/test_errors.py b/u1db/tests/test_errors.py
new file mode 100644
index 00000000..0e089ede
--- /dev/null
+++ b/u1db/tests/test_errors.py
@@ -0,0 +1,61 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests error infrastructure."""
+
+from u1db import (
+ errors,
+ tests,
+ )
+
+
+class TestError(tests.TestCase):
+
+ def test_error_base(self):
+ err = errors.U1DBError()
+ self.assertEqual("error", err.wire_description)
+ self.assertIs(None, err.message)
+
+ err = errors.U1DBError("Message.")
+ self.assertEqual("error", err.wire_description)
+ self.assertEqual("Message.", err.message)
+
+ def test_HTTPError(self):
+ err = errors.HTTPError(500)
+ self.assertEqual(500, err.status)
+ self.assertIs(None, err.wire_description)
+ self.assertIs(None, err.message)
+
+ err = errors.HTTPError(500, "Crash.")
+ self.assertEqual(500, err.status)
+ self.assertIs(None, err.wire_description)
+ self.assertEqual("Crash.", err.message)
+
+ def test_HTTPError_str(self):
+ err = errors.HTTPError(500)
+ self.assertEqual("HTTPError(500)", str(err))
+
+ err = errors.HTTPError(500, "ERROR")
+ self.assertEqual("HTTPError(500, 'ERROR')", str(err))
+
+ def test_Unvailable(self):
+ err = errors.Unavailable()
+ self.assertEqual(503, err.status)
+ self.assertEqual("Unavailable()", str(err))
+
+ err = errors.Unavailable("DOWN")
+ self.assertEqual("DOWN", err.message)
+ self.assertEqual("Unavailable('DOWN')", str(err))
diff --git a/u1db/tests/test_http_app.py b/u1db/tests/test_http_app.py
new file mode 100644
index 00000000..13522693
--- /dev/null
+++ b/u1db/tests/test_http_app.py
@@ -0,0 +1,1133 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test the WSGI app."""
+
+import paste.fixture
+import sys
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import StringIO
+
+from u1db import (
+ __version__ as _u1db_version,
+ errors,
+ sync,
+ tests,
+ )
+
+from u1db.remote import (
+ http_app,
+ http_errors,
+ )
+
+
+class TestFencedReader(tests.TestCase):
+
+ def test_init(self):
+ reader = http_app._FencedReader(StringIO.StringIO(""), 25, 100)
+ self.assertEqual(25, reader.remaining)
+
+ def test_read_chunk(self):
+ inp = StringIO.StringIO("abcdef")
+ reader = http_app._FencedReader(inp, 5, 10)
+ data = reader.read_chunk(2)
+ self.assertEqual("ab", data)
+ self.assertEqual(2, inp.tell())
+ self.assertEqual(3, reader.remaining)
+
+ def test_read_chunk_remaining(self):
+ inp = StringIO.StringIO("abcdef")
+ reader = http_app._FencedReader(inp, 4, 10)
+ data = reader.read_chunk(9999)
+ self.assertEqual("abcd", data)
+ self.assertEqual(4, inp.tell())
+ self.assertEqual(0, reader.remaining)
+
+ def test_read_chunk_nothing_left(self):
+ inp = StringIO.StringIO("abc")
+ reader = http_app._FencedReader(inp, 2, 10)
+ reader.read_chunk(2)
+ self.assertEqual(2, inp.tell())
+ self.assertEqual(0, reader.remaining)
+ data = reader.read_chunk(2)
+ self.assertEqual("", data)
+ self.assertEqual(2, inp.tell())
+ self.assertEqual(0, reader.remaining)
+
+ def test_read_chunk_kept(self):
+ inp = StringIO.StringIO("abcde")
+ reader = http_app._FencedReader(inp, 4, 10)
+ reader._kept = "xyz"
+ data = reader.read_chunk(2) # atmost ignored
+ self.assertEqual("xyz", data)
+ self.assertEqual(0, inp.tell())
+ self.assertEqual(4, reader.remaining)
+ self.assertIsNone(reader._kept)
+
+ def test_getline(self):
+ inp = StringIO.StringIO("abc\r\nde")
+ reader = http_app._FencedReader(inp, 6, 10)
+ reader.MAXCHUNK = 6
+ line = reader.getline()
+ self.assertEqual("abc\r\n", line)
+ self.assertEqual("d", reader._kept)
+
+ def test_getline_exact(self):
+ inp = StringIO.StringIO("abcd\r\nef")
+ reader = http_app._FencedReader(inp, 6, 10)
+ reader.MAXCHUNK = 6
+ line = reader.getline()
+ self.assertEqual("abcd\r\n", line)
+ self.assertIs(None, reader._kept)
+
+ def test_getline_no_newline(self):
+ inp = StringIO.StringIO("abcd")
+ reader = http_app._FencedReader(inp, 4, 10)
+ reader.MAXCHUNK = 6
+ line = reader.getline()
+ self.assertEqual("abcd", line)
+
+ def test_getline_many_chunks(self):
+ inp = StringIO.StringIO("abcde\r\nf")
+ reader = http_app._FencedReader(inp, 8, 10)
+ reader.MAXCHUNK = 4
+ line = reader.getline()
+ self.assertEqual("abcde\r\n", line)
+ self.assertEqual("f", reader._kept)
+ line = reader.getline()
+ self.assertEqual("f", line)
+
+ def test_getline_empty(self):
+ inp = StringIO.StringIO("")
+ reader = http_app._FencedReader(inp, 0, 10)
+ reader.MAXCHUNK = 4
+ line = reader.getline()
+ self.assertEqual("", line)
+ line = reader.getline()
+ self.assertEqual("", line)
+
+ def test_getline_just_newline(self):
+ inp = StringIO.StringIO("\r\n")
+ reader = http_app._FencedReader(inp, 2, 10)
+ reader.MAXCHUNK = 4
+ line = reader.getline()
+ self.assertEqual("\r\n", line)
+ line = reader.getline()
+ self.assertEqual("", line)
+
+ def test_getline_too_large(self):
+ inp = StringIO.StringIO("x" * 50)
+ reader = http_app._FencedReader(inp, 50, 25)
+ reader.MAXCHUNK = 4
+ self.assertRaises(http_app.BadRequest, reader.getline)
+
+ def test_getline_too_large_complete(self):
+ inp = StringIO.StringIO("x" * 25 + "\r\n")
+ reader = http_app._FencedReader(inp, 50, 25)
+ reader.MAXCHUNK = 4
+ self.assertRaises(http_app.BadRequest, reader.getline)
+
+
+class TestHTTPMethodDecorator(tests.TestCase):
+
+ def test_args(self):
+ @http_app.http_method()
+ def f(self, a, b):
+ return self, a, b
+ res = f("self", {"a": "x", "b": "y"}, None)
+ self.assertEqual(("self", "x", "y"), res)
+
+ def test_args_missing(self):
+ @http_app.http_method()
+ def f(self, a, b):
+ return a, b
+ self.assertRaises(http_app.BadRequest, f, "self", {"a": "x"}, None)
+
+ def test_args_unexpected(self):
+ @http_app.http_method()
+ def f(self, a):
+ return a
+ self.assertRaises(http_app.BadRequest, f, "self",
+ {"a": "x", "c": "z"}, None)
+
+ def test_args_default(self):
+ @http_app.http_method()
+ def f(self, a, b="z"):
+ return a, b
+ res = f("self", {"a": "x"}, None)
+ self.assertEqual(("x", "z"), res)
+
+ def test_args_conversion(self):
+ @http_app.http_method(b=int)
+ def f(self, a, b):
+ return self, a, b
+ res = f("self", {"a": "x", "b": "2"}, None)
+ self.assertEqual(("self", "x", 2), res)
+
+ self.assertRaises(http_app.BadRequest, f, "self",
+ {"a": "x", "b": "foo"}, None)
+
+ def test_args_conversion_with_default(self):
+ @http_app.http_method(b=str)
+ def f(self, a, b=None):
+ return self, a, b
+ res = f("self", {"a": "x"}, None)
+ self.assertEqual(("self", "x", None), res)
+
+ def test_args_content(self):
+ @http_app.http_method()
+ def f(self, a, content):
+ return a, content
+ res = f(self, {"a": "x"}, "CONTENT")
+ self.assertEqual(("x", "CONTENT"), res)
+
+ def test_args_content_as_args(self):
+ @http_app.http_method(b=int, content_as_args=True)
+ def f(self, a, b):
+ return self, a, b
+ res = f("self", {"a": "x"}, '{"b": "2"}')
+ self.assertEqual(("self", "x", 2), res)
+
+ self.assertRaises(http_app.BadRequest, f, "self", {}, 'not-json')
+
+ def test_args_content_no_query(self):
+ @http_app.http_method(no_query=True,
+ content_as_args=True)
+ def f(self, a='a', b='b'):
+ return a, b
+ res = f("self", {}, '{"b": "y"}')
+ self.assertEqual(('a', 'y'), res)
+
+ self.assertRaises(http_app.BadRequest, f, "self", {'a': 'x'},
+ '{"b": "y"}')
+
+
+class TestResource(object):
+
+ @http_app.http_method()
+ def get(self, a, b):
+ self.args = dict(a=a, b=b)
+ return 'Get'
+
+ @http_app.http_method()
+ def put(self, a, content):
+ self.args = dict(a=a)
+ self.content = content
+ return 'Put'
+
+ @http_app.http_method(content_as_args=True)
+ def put_args(self, a, b):
+ self.args = dict(a=a, b=b)
+ self.order = ['a']
+ self.entries = []
+
+ @http_app.http_method()
+ def put_stream_entry(self, content):
+ self.entries.append(content)
+ self.order.append('s')
+
+ def put_end(self):
+ self.order.append('e')
+ return "Put/end"
+
+
+class parameters:
+ max_request_size = 200000
+ max_entry_size = 100000
+
+
+class TestHTTPInvocationByMethodWithBody(tests.TestCase):
+
+ def test_get(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'GET'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ res = invoke()
+ self.assertEqual('Get', res)
+ self.assertEqual({'a': '1', 'b': '2'}, resource.args)
+
+ def test_put_json(self):
+ resource = TestResource()
+ body = '{"body": true}'
+ environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ res = invoke()
+ self.assertEqual('Put', res)
+ self.assertEqual({'a': '1'}, resource.args)
+ self.assertEqual('{"body": true}', resource.content)
+
+ def test_put_sync_stream(self):
+ resource = TestResource()
+ body = (
+ '[\r\n'
+ '{"b": 2},\r\n' # args
+ '{"entry": "x"},\r\n' # stream entry
+ '{"entry": "y"}\r\n' # stream entry
+ ']'
+ )
+ environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/x-u1db-sync-stream'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ res = invoke()
+ self.assertEqual('Put/end', res)
+ self.assertEqual({'a': '1', 'b': 2}, resource.args)
+ self.assertEqual(
+ ['{"entry": "x"}', '{"entry": "y"}'], resource.entries)
+ self.assertEqual(['a', 's', 's', 'e'], resource.order)
+
+ def _put_sync_stream(self, body):
+ resource = TestResource()
+ environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/x-u1db-sync-stream'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ invoke()
+
+ def test_put_sync_stream_wrong_start(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "{}\r\n]")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "\r\n{}\r\n]")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "")
+
+ def test_put_sync_stream_wrong_end(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{}")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{}\r\n]\r\n...")
+
+ def test_put_sync_stream_missing_comma(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{}\r\n{}\r\n]")
+
+ def test_put_sync_stream_extra_comma(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{},\r\n]")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{},\r\n{},\r\n]")
+
+ def test_bad_request_decode_failure(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': 'a=\xff', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '2',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_content_type(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '2',
+ 'CONTENT_TYPE': 'text/plain'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_content_length_too_large(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '10000',
+ 'CONTENT_TYPE': 'text/plain'}
+
+ resource.max_request_size = 5000
+ resource.max_entry_size = sys.maxint # we don't get to use this
+
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_no_content_length(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('a'),
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_invalid_content_length(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('abc'),
+ 'CONTENT_LENGTH': '1unk',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_empty_body(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(''),
+ 'CONTENT_LENGTH': '0',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_method_get_like(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'DELETE'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_method_put_like(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '2',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_method_put_like_multi_json(self):
+ resource = TestResource()
+ body = '{}\r\n{}\r\n'
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'POST',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/x-u1db-multi-json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+
+class TestHTTPResponder(tests.TestCase):
+
+ def start_response(self, status, headers):
+ self.status = status
+ self.headers = dict(headers)
+ self.response_body = []
+
+ def write(data):
+ self.response_body.append(data)
+
+ return write
+
+ def test_send_response_content_w_headers(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.send_response_content('foo', headers={'x-a': '1'})
+ self.assertEqual('200 OK', self.status)
+ self.assertEqual({'content-type': 'application/json',
+ 'cache-control': 'no-cache',
+ 'x-a': '1', 'content-length': '3'}, self.headers)
+ self.assertEqual([], self.response_body)
+ self.assertEqual(['foo'], responder.content)
+
+ def test_send_response_json(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.send_response_json(value='success')
+ self.assertEqual('200 OK', self.status)
+ expected_body = '{"value": "success"}\r\n'
+ self.assertEqual({'content-type': 'application/json',
+ 'content-length': str(len(expected_body)),
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual([], self.response_body)
+ self.assertEqual([expected_body], responder.content)
+
+ def test_send_response_json_status_fail(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.send_response_json(400)
+ self.assertEqual('400 Bad Request', self.status)
+ expected_body = '{}\r\n'
+ self.assertEqual({'content-type': 'application/json',
+ 'content-length': str(len(expected_body)),
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual([], self.response_body)
+ self.assertEqual([expected_body], responder.content)
+
+ def test_start_finish_response_status_fail(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.start_response(404, {'error': 'not found'})
+ responder.finish_response()
+ self.assertEqual('404 Not Found', self.status)
+ self.assertEqual({'content-type': 'application/json',
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual(['{"error": "not found"}\r\n'], self.response_body)
+ self.assertEqual([], responder.content)
+
+ def test_send_stream_entry(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.content_type = "application/x-u1db-multi-json"
+ responder.start_response(200)
+ responder.start_stream()
+ responder.stream_entry({'entry': 1})
+ responder.stream_entry({'entry': 2})
+ responder.end_stream()
+ responder.finish_response()
+ self.assertEqual('200 OK', self.status)
+ self.assertEqual({'content-type': 'application/x-u1db-multi-json',
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual(['[',
+ '\r\n', '{"entry": 1}',
+ ',\r\n', '{"entry": 2}',
+ '\r\n]\r\n'], self.response_body)
+ self.assertEqual([], responder.content)
+
+ def test_send_stream_w_error(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.content_type = "application/x-u1db-multi-json"
+ responder.start_response(200)
+ responder.start_stream()
+ responder.stream_entry({'entry': 1})
+ responder.send_response_json(503, error="unavailable")
+ self.assertEqual('200 OK', self.status)
+ self.assertEqual({'content-type': 'application/x-u1db-multi-json',
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual(['[',
+ '\r\n', '{"entry": 1}'], self.response_body)
+ self.assertEqual([',\r\n', '{"error": "unavailable"}\r\n'],
+ responder.content)
+
+
+class TestHTTPApp(tests.TestCase):
+
+ def setUp(self):
+ super(TestHTTPApp, self).setUp()
+ self.state = tests.ServerStateForTests()
+ self.http_app = http_app.HTTPApp(self.state)
+ self.app = paste.fixture.TestApp(self.http_app)
+ self.db0 = self.state._create_database('db0')
+
+ def test_bad_request_broken(self):
+ resp = self.app.put('/db0/doc/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/foo'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_bad_request_dispatch(self):
+ resp = self.app.put('/db0/foo/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_version(self):
+ resp = self.app.get('/')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"version": _u1db_version}, json.loads(resp.body))
+
+ def test_create_database(self):
+ resp = self.app.put('/db1', params='{}',
+ headers={'content-type': 'application/json'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'ok': True}, json.loads(resp.body))
+
+ resp = self.app.put('/db1', params='{}',
+ headers={'content-type': 'application/json'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'ok': True}, json.loads(resp.body))
+
+ def test_delete_database(self):
+ resp = self.app.delete('/db0')
+ self.assertEqual(200, resp.status)
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.state.check_database, 'db0')
+
+ def test_get_database(self):
+ resp = self.app.get('/db0')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({}, json.loads(resp.body))
+
+ def test_valid_database_names(self):
+ resp = self.app.get('/a-database', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/db1', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/0', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/0-0', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/org.future', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ def test_invalid_database_names(self):
+ resp = self.app.get('/.a', expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ resp = self.app.get('/-a', expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ resp = self.app.get('/_a', expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_put_doc_create(self):
+ resp = self.app.put('/db0/doc/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/json'})
+ doc = self.db0.get_doc('doc1')
+ self.assertEqual(201, resp.status) # created
+ self.assertEqual('{"x": 1}', doc.get_json())
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'rev': doc.rev}, json.loads(resp.body))
+
+ def test_put_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev,
+ params='{"x": 2}',
+ headers={'content-type': 'application/json'})
+ doc = self.db0.get_doc('doc1')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('{"x": 2}', doc.get_json())
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'rev': doc.rev}, json.loads(resp.body))
+
+ def test_put_doc_too_large(self):
+ self.http_app.max_request_size = 15000
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev,
+ params='{"%s": 2}' % ('z' * 16000),
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_delete_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.delete('/db0/doc/doc1?old_rev=%s' % doc.rev)
+ doc = self.db0.get_doc('doc1', include_deleted=True)
+ self.assertEqual(None, doc.content)
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'rev': doc.rev}, json.loads(resp.body))
+
+ def test_get_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.get('/db0/doc/%s' % doc.doc_id)
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual('{"x": 1}', resp.body)
+ self.assertEqual(doc.rev, resp.header('x-u1db-rev'))
+ self.assertEqual('false', resp.header('x-u1db-has-conflicts'))
+
+ def test_get_doc_non_existing(self):
+ resp = self.app.get('/db0/doc/not-there', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "document does not exist"}, json.loads(resp.body))
+ self.assertEqual('', resp.header('x-u1db-rev'))
+ self.assertEqual('false', resp.header('x-u1db-has-conflicts'))
+
+ def test_get_doc_deleted(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.db0.delete_doc(doc)
+ resp = self.app.get('/db0/doc/doc1', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": errors.DocumentDoesNotExist.wire_description},
+ json.loads(resp.body))
+
+ def test_get_doc_deleted_explicit_exclude(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.db0.delete_doc(doc)
+ resp = self.app.get(
+ '/db0/doc/doc1?include_deleted=false', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": errors.DocumentDoesNotExist.wire_description},
+ json.loads(resp.body))
+
+ def test_get_deleted_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.db0.delete_doc(doc)
+ resp = self.app.get(
+ '/db0/doc/doc1?include_deleted=true', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": errors.DOCUMENT_DELETED}, json.loads(resp.body))
+ self.assertEqual(doc.rev, resp.header('x-u1db-rev'))
+ self.assertEqual('false', resp.header('x-u1db-has-conflicts'))
+
+ def test_get_doc_non_existing_dabase(self):
+ resp = self.app.get('/not-there/doc/doc1', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "database does not exist"}, json.loads(resp.body))
+
+ def test_get_docs(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1",
+ "has_conflicts": False},
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_docs_missing_doc_ids(self):
+ resp = self.app.get('/db0/docs', expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "missing document ids"}, json.loads(resp.body))
+
+ def test_get_docs_empty_doc_ids(self):
+ resp = self.app.get('/db0/docs?doc_ids=', expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "missing document ids"}, json.loads(resp.body))
+
+ def test_get_docs_percent(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc%1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc%1",
+ "has_conflicts": False},
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_docs_deleted(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ self.db0.delete_doc(doc2)
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_docs_include_deleted(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ self.db0.delete_doc(doc2)
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s&include_deleted=true' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1",
+ "has_conflicts": False},
+ {"content": None, "doc_rev": "db0:2", "doc_id": "doc2",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_sync_info(self):
+ self.db0._set_replica_gen_and_trans_id('other-id', 1, 'T-transid')
+ resp = self.app.get('/db0/sync-from/other-id')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(dict(target_replica_uid='db0',
+ target_replica_generation=0,
+ target_replica_transaction_id='',
+ source_replica_uid='other-id',
+ source_replica_generation=1,
+ source_transaction_id='T-transid'),
+ json.loads(resp.body))
+
+ def test_record_sync_info(self):
+ resp = self.app.put('/db0/sync-from/other-id',
+ params='{"generation": 2, "transaction_id": "T-transid"}',
+ headers={'content-type': 'application/json'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'ok': True}, json.loads(resp.body))
+ self.assertEqual(
+ (2, 'T-transid'),
+ self.db0._get_replica_gen_and_trans_id('other-id'))
+
+ def test_sync_exchange_send(self):
+ entries = {
+ 10: {'id': 'doc-here', 'rev': 'replica:1', 'content':
+ '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'},
+ 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content':
+ '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'}
+ }
+
+ gens = []
+ _do_set_replica_gen_and_trans_id = \
+ self.db0._do_set_replica_gen_and_trans_id
+
+ def set_sync_generation_witness(other_uid, other_gen, other_trans_id):
+ gens.append((other_uid, other_gen))
+ _do_set_replica_gen_and_trans_id(
+ other_uid, other_gen, other_trans_id)
+ self.assertGetDoc(self.db0, entries[other_gen]['id'],
+ entries[other_gen]['rev'],
+ entries[other_gen]['content'], False)
+
+ self.patch(
+ self.db0, '_do_set_replica_gen_and_trans_id',
+ set_sync_generation_witness)
+
+ args = dict(last_known_generation=0)
+ body = ("[\r\n" +
+ "%s,\r\n" % json.dumps(args) +
+ "%s,\r\n" % json.dumps(entries[10]) +
+ "%s\r\n" % json.dumps(entries[11]) +
+ "]\r\n")
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ bits = resp.body.split('\r\n')
+ self.assertEqual('[', bits[0])
+ last_trans_id = self.db0._get_transaction_log()[-1][1]
+ self.assertEqual({'new_generation': 2,
+ 'new_transaction_id': last_trans_id},
+ json.loads(bits[1]))
+ self.assertEqual(']', bits[2])
+ self.assertEqual('', bits[3])
+ self.assertEqual([('replica', 10), ('replica', 11)], gens)
+
+ def test_sync_exchange_send_ensure(self):
+ entries = {
+ 10: {'id': 'doc-here', 'rev': 'replica:1', 'content':
+ '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'},
+ 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content':
+ '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'}
+ }
+
+ args = dict(last_known_generation=0, ensure=True)
+ body = ("[\r\n" +
+ "%s,\r\n" % json.dumps(args) +
+ "%s,\r\n" % json.dumps(entries[10]) +
+ "%s\r\n" % json.dumps(entries[11]) +
+ "]\r\n")
+ resp = self.app.post('/dbnew/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ bits = resp.body.split('\r\n')
+ self.assertEqual('[', bits[0])
+ dbnew = self.state.open_database("dbnew")
+ last_trans_id = dbnew._get_transaction_log()[-1][1]
+ self.assertEqual({'new_generation': 2,
+ 'new_transaction_id': last_trans_id,
+ 'replica_uid': dbnew._replica_uid},
+ json.loads(bits[1]))
+ self.assertEqual(']', bits[2])
+ self.assertEqual('', bits[3])
+
+ def test_sync_exchange_send_entry_too_large(self):
+ self.patch(http_app.SyncResource, 'max_request_size', 20000)
+ self.patch(http_app.SyncResource, 'max_entry_size', 10000)
+ entries = {
+ 10: {'id': 'doc-here', 'rev': 'replica:1', 'content':
+ '{"value": "%s"}' % ('H' * 11000), 'gen': 10},
+ }
+ args = dict(last_known_generation=0)
+ body = ("[\r\n" +
+ "%s,\r\n" % json.dumps(args) +
+ "%s\r\n" % json.dumps(entries[10]) +
+ "]\r\n")
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_sync_exchange_receive(self):
+ doc = self.db0.create_doc_from_json('{"value": "there"}')
+ doc2 = self.db0.create_doc_from_json('{"value": "there2"}')
+ args = dict(last_known_generation=0)
+ body = "[\r\n%s\r\n]" % json.dumps(args)
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ parts = resp.body.splitlines()
+ self.assertEqual(5, len(parts))
+ self.assertEqual('[', parts[0])
+ last_trans_id = self.db0._get_transaction_log()[-1][1]
+ self.assertEqual({'new_generation': 2,
+ 'new_transaction_id': last_trans_id},
+ json.loads(parts[1].rstrip(",")))
+ part2 = json.loads(parts[2].rstrip(","))
+ self.assertTrue(part2['trans_id'].startswith('T-'))
+ self.assertEqual('{"value": "there"}', part2['content'])
+ self.assertEqual(doc.rev, part2['rev'])
+ self.assertEqual(doc.doc_id, part2['id'])
+ self.assertEqual(1, part2['gen'])
+ part3 = json.loads(parts[3].rstrip(","))
+ self.assertTrue(part3['trans_id'].startswith('T-'))
+ self.assertEqual('{"value": "there2"}', part3['content'])
+ self.assertEqual(doc2.rev, part3['rev'])
+ self.assertEqual(doc2.doc_id, part3['id'])
+ self.assertEqual(2, part3['gen'])
+ self.assertEqual(']', parts[4])
+
+ def test_sync_exchange_error_in_stream(self):
+ args = dict(last_known_generation=0)
+ body = "[\r\n%s\r\n]" % json.dumps(args)
+
+ def boom(self, return_doc_cb):
+ raise errors.Unavailable
+
+ self.patch(sync.SyncExchange, 'return_docs',
+ boom)
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ parts = resp.body.splitlines()
+ self.assertEqual(3, len(parts))
+ self.assertEqual('[', parts[0])
+ self.assertEqual({'new_generation': 0, 'new_transaction_id': ''},
+ json.loads(parts[1].rstrip(",")))
+ self.assertEqual({'error': 'unavailable'}, json.loads(parts[2]))
+
+
+class TestRequestHooks(tests.TestCase):
+
+ def setUp(self):
+ super(TestRequestHooks, self).setUp()
+ self.state = tests.ServerStateForTests()
+ self.http_app = http_app.HTTPApp(self.state)
+ self.app = paste.fixture.TestApp(self.http_app)
+ self.db0 = self.state._create_database('db0')
+
+ def test_begin_and_done(self):
+ calls = []
+
+ def begin(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('begin')
+
+ def done(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('done')
+
+ self.http_app.request_begin = begin
+ self.http_app.request_done = done
+
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.app.get('/db0/doc/%s' % doc.doc_id)
+
+ self.assertEqual(['begin', 'done'], calls)
+
+ def test_bad_request(self):
+ calls = []
+
+ def begin(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('begin')
+
+ def bad_request(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('bad-request')
+
+ self.http_app.request_begin = begin
+ self.http_app.request_bad_request = bad_request
+ # shouldn't be called
+ self.http_app.request_done = lambda env: 1 / 0
+
+ resp = self.app.put('/db0/foo/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual(['begin', 'bad-request'], calls)
+
+
+class TestHTTPErrors(tests.TestCase):
+
+ def test_wire_description_to_status(self):
+ self.assertNotIn("error", http_errors.wire_description_to_status)
+
+
+class TestHTTPAppErrorHandling(tests.TestCase):
+
+ def setUp(self):
+ super(TestHTTPAppErrorHandling, self).setUp()
+ self.exc = None
+ self.state = tests.ServerStateForTests()
+
+ class ErroringResource(object):
+
+ def post(_, args, content):
+ raise self.exc
+
+ def lookup_resource(environ, responder):
+ return ErroringResource()
+
+ self.http_app = http_app.HTTPApp(self.state)
+ self.http_app._lookup_resource = lookup_resource
+ self.app = paste.fixture.TestApp(self.http_app)
+
+ def test_RevisionConflict_etc(self):
+ self.exc = errors.RevisionConflict()
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(409, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"error": "revision conflict"},
+ json.loads(resp.body))
+
+ def test_Unavailable(self):
+ self.exc = errors.Unavailable
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(503, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"error": "unavailable"},
+ json.loads(resp.body))
+
+ def test_generic_u1db_errors(self):
+ self.exc = errors.U1DBError()
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(500, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"error": "error"},
+ json.loads(resp.body))
+
+ def test_generic_u1db_errors_hooks(self):
+ calls = []
+
+ def begin(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('begin')
+
+ def u1db_error(environ, exc):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append(('error', exc))
+
+ self.http_app.request_begin = begin
+ self.http_app.request_u1db_error = u1db_error
+ # shouldn't be called
+ self.http_app.request_done = lambda env: 1 / 0
+
+ self.exc = errors.U1DBError()
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(500, resp.status)
+ self.assertEqual(['begin', ('error', self.exc)], calls)
+
+ def test_failure(self):
+ class Failure(Exception):
+ pass
+ self.exc = Failure()
+ self.assertRaises(Failure, self.app.post, '/req', params='{}',
+ headers={'content-type': 'application/json'})
+
+ def test_failure_hooks(self):
+ class Failure(Exception):
+ pass
+ calls = []
+
+ def begin(environ):
+ calls.append('begin')
+
+ def failed(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append(('failed', sys.exc_info()))
+
+ self.http_app.request_begin = begin
+ self.http_app.request_failed = failed
+ # shouldn't be called
+ self.http_app.request_done = lambda env: 1 / 0
+
+ self.exc = Failure()
+ self.assertRaises(Failure, self.app.post, '/req', params='{}',
+ headers={'content-type': 'application/json'})
+
+ self.assertEqual(2, len(calls))
+ self.assertEqual('begin', calls[0])
+ marker, (exc_type, exc, tb) = calls[1]
+ self.assertEqual('failed', marker)
+ self.assertEqual(self.exc, exc)
+
+
+class TestPluggableSyncExchange(tests.TestCase):
+
+ def setUp(self):
+ super(TestPluggableSyncExchange, self).setUp()
+ self.state = tests.ServerStateForTests()
+ self.state.ensure_database('foo')
+
+ def test_plugging(self):
+
+ class MySyncExchange(object):
+ def __init__(self, db, source_replica_uid, last_known_generation):
+ pass
+
+ class MySyncResource(http_app.SyncResource):
+ sync_exchange_class = MySyncExchange
+
+ sync_res = MySyncResource('foo', 'src', self.state, None)
+ sync_res.post_args(
+ {'last_known_generation': 0, 'last_known_trans_id': None}, '{}')
+ self.assertIsInstance(sync_res.sync_exch, MySyncExchange)
diff --git a/u1db/tests/test_http_client.py b/u1db/tests/test_http_client.py
new file mode 100644
index 00000000..115c8aaa
--- /dev/null
+++ b/u1db/tests/test_http_client.py
@@ -0,0 +1,361 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for HTTPDatabase"""
+
+from oauth import oauth
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ errors,
+ tests,
+ )
+from u1db.remote import (
+ http_client,
+ )
+
+
+class TestEncoder(tests.TestCase):
+
+ def test_encode_string(self):
+ self.assertEqual("foo", http_client._encode_query_parameter("foo"))
+
+ def test_encode_true(self):
+ self.assertEqual("true", http_client._encode_query_parameter(True))
+
+ def test_encode_false(self):
+ self.assertEqual("false", http_client._encode_query_parameter(False))
+
+
+class TestHTTPClientBase(tests.TestCaseWithServer):
+
+ def setUp(self):
+ super(TestHTTPClientBase, self).setUp()
+ self.errors = 0
+
+ def app(self, environ, start_response):
+ if environ['PATH_INFO'].endswith('echo'):
+ start_response("200 OK", [('Content-Type', 'application/json')])
+ ret = {}
+ for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'):
+ ret[name] = environ[name]
+ if environ['REQUEST_METHOD'] in ('PUT', 'POST'):
+ ret['CONTENT_TYPE'] = environ['CONTENT_TYPE']
+ content_length = int(environ['CONTENT_LENGTH'])
+ ret['body'] = environ['wsgi.input'].read(content_length)
+ return [json.dumps(ret)]
+ elif environ['PATH_INFO'].endswith('error_then_accept'):
+ if self.errors >= 3:
+ start_response(
+ "200 OK", [('Content-Type', 'application/json')])
+ ret = {}
+ for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'):
+ ret[name] = environ[name]
+ if environ['REQUEST_METHOD'] in ('PUT', 'POST'):
+ ret['CONTENT_TYPE'] = environ['CONTENT_TYPE']
+ content_length = int(environ['CONTENT_LENGTH'])
+ ret['body'] = '{"oki": "doki"}'
+ return [json.dumps(ret)]
+ self.errors += 1
+ content_length = int(environ['CONTENT_LENGTH'])
+ error = json.loads(
+ environ['wsgi.input'].read(content_length))
+ response = error['response']
+ # In debug mode, wsgiref has an assertion that the status parameter
+ # is a 'str' object. However error['status'] returns a unicode
+ # object.
+ status = str(error['status'])
+ if isinstance(response, unicode):
+ response = str(response)
+ if isinstance(response, str):
+ start_response(status, [('Content-Type', 'text/plain')])
+ return [str(response)]
+ else:
+ start_response(status, [('Content-Type', 'application/json')])
+ return [json.dumps(response)]
+ elif environ['PATH_INFO'].endswith('error'):
+ self.errors += 1
+ content_length = int(environ['CONTENT_LENGTH'])
+ error = json.loads(
+ environ['wsgi.input'].read(content_length))
+ response = error['response']
+ # In debug mode, wsgiref has an assertion that the status parameter
+ # is a 'str' object. However error['status'] returns a unicode
+ # object.
+ status = str(error['status'])
+ if isinstance(response, unicode):
+ response = str(response)
+ if isinstance(response, str):
+ start_response(status, [('Content-Type', 'text/plain')])
+ return [str(response)]
+ else:
+ start_response(status, [('Content-Type', 'application/json')])
+ return [json.dumps(response)]
+ elif '/oauth' in environ['PATH_INFO']:
+ base_url = self.getURL('').rstrip('/')
+ oauth_req = oauth.OAuthRequest.from_request(
+ http_method=environ['REQUEST_METHOD'],
+ http_url=base_url + environ['PATH_INFO'],
+ headers={'Authorization': environ['HTTP_AUTHORIZATION']},
+ query_string=environ['QUERY_STRING']
+ )
+ oauth_server = oauth.OAuthServer(tests.testingOAuthStore)
+ oauth_server.add_signature_method(tests.sign_meth_HMAC_SHA1)
+ try:
+ consumer, token, params = oauth_server.verify_request(
+ oauth_req)
+ except oauth.OAuthError, e:
+ start_response("401 Unauthorized",
+ [('Content-Type', 'application/json')])
+ return [json.dumps({"error": "unauthorized",
+ "message": e.message})]
+ start_response("200 OK", [('Content-Type', 'application/json')])
+ return [json.dumps([environ['PATH_INFO'], token.key, params])]
+
+ def make_app(self):
+ return self.app
+
+ def getClient(self, **kwds):
+ self.startServer()
+ return http_client.HTTPClientBase(self.getURL('dbase'), **kwds)
+
+ def test_construct(self):
+ self.startServer()
+ url = self.getURL()
+ cli = http_client.HTTPClientBase(url)
+ self.assertEqual(url, cli._url.geturl())
+ self.assertIs(None, cli._conn)
+
+ def test_parse_url(self):
+ cli = http_client.HTTPClientBase(
+ '%s://127.0.0.1:12345/' % self.url_scheme)
+ self.assertEqual(self.url_scheme, cli._url.scheme)
+ self.assertEqual('127.0.0.1', cli._url.hostname)
+ self.assertEqual(12345, cli._url.port)
+ self.assertEqual('/', cli._url.path)
+
+ def test__ensure_connection(self):
+ cli = self.getClient()
+ self.assertIs(None, cli._conn)
+ cli._ensure_connection()
+ self.assertIsNot(None, cli._conn)
+ conn = cli._conn
+ cli._ensure_connection()
+ self.assertIs(conn, cli._conn)
+
+ def test_close(self):
+ cli = self.getClient()
+ cli._ensure_connection()
+ cli.close()
+ self.assertIs(None, cli._conn)
+
+ def test__request(self):
+ cli = self.getClient()
+ res, headers = cli._request('PUT', ['echo'], {}, {})
+ self.assertEqual({'CONTENT_TYPE': 'application/json',
+ 'PATH_INFO': '/dbase/echo',
+ 'QUERY_STRING': '',
+ 'body': '{}',
+ 'REQUEST_METHOD': 'PUT'}, json.loads(res))
+
+ res, headers = cli._request('GET', ['doc', 'echo'], {'a': 1})
+ self.assertEqual({'PATH_INFO': '/dbase/doc/echo',
+ 'QUERY_STRING': 'a=1',
+ 'REQUEST_METHOD': 'GET'}, json.loads(res))
+
+ res, headers = cli._request('GET', ['doc', '%FFFF', 'echo'], {'a': 1})
+ self.assertEqual({'PATH_INFO': '/dbase/doc/%FFFF/echo',
+ 'QUERY_STRING': 'a=1',
+ 'REQUEST_METHOD': 'GET'}, json.loads(res))
+
+ res, headers = cli._request('POST', ['echo'], {'b': 2}, 'Body',
+ 'application/x-test')
+ self.assertEqual({'CONTENT_TYPE': 'application/x-test',
+ 'PATH_INFO': '/dbase/echo',
+ 'QUERY_STRING': 'b=2',
+ 'body': 'Body',
+ 'REQUEST_METHOD': 'POST'}, json.loads(res))
+
+ def test__request_json(self):
+ cli = self.getClient()
+ res, headers = cli._request_json(
+ 'POST', ['echo'], {'b': 2}, {'a': 'x'})
+ self.assertEqual('application/json', headers['content-type'])
+ self.assertEqual({'CONTENT_TYPE': 'application/json',
+ 'PATH_INFO': '/dbase/echo',
+ 'QUERY_STRING': 'b=2',
+ 'body': '{"a": "x"}',
+ 'REQUEST_METHOD': 'POST'}, res)
+
+ def test_unspecified_http_error(self):
+ cli = self.getClient()
+ self.assertRaises(errors.HTTPError,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "500 Internal Error",
+ 'response': "Crash."})
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "500 Internal Error",
+ 'response': "Fail."})
+ except errors.HTTPError, e:
+ pass
+
+ self.assertEqual(500, e.status)
+ self.assertEqual("Fail.", e.message)
+ self.assertTrue("content-type" in e.headers)
+
+ def test_revision_conflict(self):
+ cli = self.getClient()
+ self.assertRaises(errors.RevisionConflict,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "409 Conflict",
+ 'response': {"error": "revision conflict"}})
+
+ def test_unavailable_proper(self):
+ cli = self.getClient()
+ cli._delays = (0, 0, 0, 0, 0)
+ self.assertRaises(errors.Unavailable,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "503 Service Unavailable",
+ 'response': {"error": "unavailable"}})
+ self.assertEqual(5, self.errors)
+
+ def test_unavailable_then_available(self):
+ cli = self.getClient()
+ cli._delays = (0, 0, 0, 0, 0)
+ res, headers = cli._request_json(
+ 'POST', ['error_then_accept'], {'b': 2},
+ {'status': "503 Service Unavailable",
+ 'response': {"error": "unavailable"}})
+ self.assertEqual('application/json', headers['content-type'])
+ self.assertEqual({'CONTENT_TYPE': 'application/json',
+ 'PATH_INFO': '/dbase/error_then_accept',
+ 'QUERY_STRING': 'b=2',
+ 'body': '{"oki": "doki"}',
+ 'REQUEST_METHOD': 'POST'}, res)
+ self.assertEqual(3, self.errors)
+
+ def test_unavailable_random_source(self):
+ cli = self.getClient()
+ cli._delays = (0, 0, 0, 0, 0)
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "503 Service Unavailable",
+ 'response': "random unavailable."})
+ except errors.Unavailable, e:
+ pass
+
+ self.assertEqual(503, e.status)
+ self.assertEqual("random unavailable.", e.message)
+ self.assertTrue("content-type" in e.headers)
+ self.assertEqual(5, self.errors)
+
+ def test_document_too_big(self):
+ cli = self.getClient()
+ self.assertRaises(errors.DocumentTooBig,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "403 Forbidden",
+ 'response': {"error": "document too big"}})
+
+ def test_user_quota_exceeded(self):
+ cli = self.getClient()
+ self.assertRaises(errors.UserQuotaExceeded,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "403 Forbidden",
+ 'response': {"error": "user quota exceeded"}})
+
+ def test_user_needs_subscription(self):
+ cli = self.getClient()
+ self.assertRaises(errors.SubscriptionNeeded,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "403 Forbidden",
+ 'response': {"error": "user needs subscription"}})
+
+ def test_generic_u1db_error(self):
+ cli = self.getClient()
+ self.assertRaises(errors.U1DBError,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': {"error": "error"}})
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': {"error": "error"}})
+ except errors.U1DBError, e:
+ pass
+ self.assertIs(e.__class__, errors.U1DBError)
+
+ def test_unspecified_bad_request(self):
+ cli = self.getClient()
+ self.assertRaises(errors.HTTPError,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': "<Bad Request>"})
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': "<Bad Request>"})
+ except errors.HTTPError, e:
+ pass
+
+ self.assertEqual(400, e.status)
+ self.assertEqual("<Bad Request>", e.message)
+ self.assertTrue("content-type" in e.headers)
+
+ def test_oauth(self):
+ cli = self.getClient()
+ cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ params = {'x': u'\xf0', 'y': "foo"}
+ res, headers = cli._request('GET', ['doc', 'oauth'], params)
+ self.assertEqual(
+ ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res))
+
+ # oauth does its own internal quoting
+ params = {'x': u'\xf0', 'y': "foo"}
+ res, headers = cli._request('GET', ['doc', 'oauth', 'foo bar'], params)
+ self.assertEqual(
+ ['/dbase/doc/oauth/foo bar', tests.token1.key, params],
+ json.loads(res))
+
+ def test_oauth_ctr_creds(self):
+ cli = self.getClient(creds={'oauth': {
+ 'consumer_key': tests.consumer1.key,
+ 'consumer_secret': tests.consumer1.secret,
+ 'token_key': tests.token1.key,
+ 'token_secret': tests.token1.secret,
+ }})
+ params = {'x': u'\xf0', 'y': "foo"}
+ res, headers = cli._request('GET', ['doc', 'oauth'], params)
+ self.assertEqual(
+ ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res))
+
+ def test_unknown_creds(self):
+ self.assertRaises(errors.UnknownAuthMethod,
+ self.getClient, creds={'foo': {}})
+ self.assertRaises(errors.UnknownAuthMethod,
+ self.getClient, creds={})
+
+ def test_oauth_Unauthorized(self):
+ cli = self.getClient()
+ cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, "WRONG")
+ params = {'y': 'foo'}
+ self.assertRaises(errors.Unauthorized, cli._request, 'GET',
+ ['doc', 'oauth'], params)
diff --git a/u1db/tests/test_http_database.py b/u1db/tests/test_http_database.py
new file mode 100644
index 00000000..c8e7eb76
--- /dev/null
+++ b/u1db/tests/test_http_database.py
@@ -0,0 +1,256 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for HTTPDatabase"""
+
+import inspect
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ errors,
+ Document,
+ tests,
+ )
+from u1db.remote import (
+ http_database,
+ http_target,
+ )
+from u1db.tests.test_remote_sync_target import (
+ make_http_app,
+)
+
+
+class TestHTTPDatabaseSimpleOperations(tests.TestCase):
+
+ def setUp(self):
+ super(TestHTTPDatabaseSimpleOperations, self).setUp()
+ self.db = http_database.HTTPDatabase('dbase')
+ self.db._conn = object() # crash if used
+ self.got = None
+ self.response_val = None
+
+ def _request(method, url_parts, params=None, body=None,
+ content_type=None):
+ self.got = method, url_parts, params, body, content_type
+ if isinstance(self.response_val, Exception):
+ raise self.response_val
+ return self.response_val
+
+ def _request_json(method, url_parts, params=None, body=None,
+ content_type=None):
+ self.got = method, url_parts, params, body, content_type
+ if isinstance(self.response_val, Exception):
+ raise self.response_val
+ return self.response_val
+
+ self.db._request = _request
+ self.db._request_json = _request_json
+
+ def test__sanity_same_signature(self):
+ my_request_sig = inspect.getargspec(self.db._request)
+ my_request_sig = (['self'] + my_request_sig[0],) + my_request_sig[1:]
+ self.assertEqual(my_request_sig,
+ inspect.getargspec(http_database.HTTPDatabase._request))
+ my_request_json_sig = inspect.getargspec(self.db._request_json)
+ my_request_json_sig = ((['self'] + my_request_json_sig[0],) +
+ my_request_json_sig[1:])
+ self.assertEqual(my_request_json_sig,
+ inspect.getargspec(http_database.HTTPDatabase._request_json))
+
+ def test__ensure(self):
+ self.response_val = {'ok': True}, {}
+ self.db._ensure()
+ self.assertEqual(('PUT', [], {}, {}, None), self.got)
+
+ def test__delete(self):
+ self.response_val = {'ok': True}, {}
+ self.db._delete()
+ self.assertEqual(('DELETE', [], {}, {}, None), self.got)
+
+ def test__check(self):
+ self.response_val = {}, {}
+ res = self.db._check()
+ self.assertEqual({}, res)
+ self.assertEqual(('GET', [], None, None, None), self.got)
+
+ def test_put_doc(self):
+ self.response_val = {'rev': 'doc-rev'}, {}
+ doc = Document('doc-id', None, '{"v": 1}')
+ res = self.db.put_doc(doc)
+ self.assertEqual('doc-rev', res)
+ self.assertEqual('doc-rev', doc.rev)
+ self.assertEqual(('PUT', ['doc', 'doc-id'], {},
+ '{"v": 1}', 'application/json'), self.got)
+
+ self.response_val = {'rev': 'doc-rev-2'}, {}
+ doc.content = {"v": 2}
+ res = self.db.put_doc(doc)
+ self.assertEqual('doc-rev-2', res)
+ self.assertEqual('doc-rev-2', doc.rev)
+ self.assertEqual(('PUT', ['doc', 'doc-id'], {'old_rev': 'doc-rev'},
+ '{"v": 2}', 'application/json'), self.got)
+
+ def test_get_doc(self):
+ self.response_val = '{"v": 2}', {'x-u1db-rev': 'doc-rev',
+ 'x-u1db-has-conflicts': 'false'}
+ self.assertGetDoc(self.db, 'doc-id', 'doc-rev', '{"v": 2}', False)
+ self.assertEqual(
+ ('GET', ['doc', 'doc-id'], {'include_deleted': False}, None, None),
+ self.got)
+
+ def test_get_doc_non_existing(self):
+ self.response_val = errors.DocumentDoesNotExist()
+ self.assertIs(None, self.db.get_doc('not-there'))
+ self.assertEqual(
+ ('GET', ['doc', 'not-there'], {'include_deleted': False}, None,
+ None), self.got)
+
+ def test_get_doc_deleted(self):
+ self.response_val = errors.DocumentDoesNotExist()
+ self.assertIs(None, self.db.get_doc('deleted'))
+ self.assertEqual(
+ ('GET', ['doc', 'deleted'], {'include_deleted': False}, None,
+ None), self.got)
+
+ def test_get_doc_deleted_include_deleted(self):
+ self.response_val = errors.HTTPError(404,
+ json.dumps(
+ {"error": errors.DOCUMENT_DELETED}
+ ),
+ {'x-u1db-rev': 'doc-rev-gone',
+ 'x-u1db-has-conflicts': 'false'})
+ doc = self.db.get_doc('deleted', include_deleted=True)
+ self.assertEqual('deleted', doc.doc_id)
+ self.assertEqual('doc-rev-gone', doc.rev)
+ self.assertIs(None, doc.content)
+ self.assertEqual(
+ ('GET', ['doc', 'deleted'], {'include_deleted': True}, None, None),
+ self.got)
+
+ def test_get_doc_pass_through_errors(self):
+ self.response_val = errors.HTTPError(500, 'Crash.')
+ self.assertRaises(errors.HTTPError,
+ self.db.get_doc, 'something-something')
+
+ def test_create_doc_with_id(self):
+ self.response_val = {'rev': 'doc-rev'}, {}
+ new_doc = self.db.create_doc_from_json('{"v": 1}', doc_id='doc-id')
+ self.assertEqual('doc-rev', new_doc.rev)
+ self.assertEqual('doc-id', new_doc.doc_id)
+ self.assertEqual('{"v": 1}', new_doc.get_json())
+ self.assertEqual(('PUT', ['doc', 'doc-id'], {},
+ '{"v": 1}', 'application/json'), self.got)
+
+ def test_create_doc_without_id(self):
+ self.response_val = {'rev': 'doc-rev-2'}, {}
+ new_doc = self.db.create_doc_from_json('{"v": 3}')
+ self.assertEqual('D-', new_doc.doc_id[:2])
+ self.assertEqual('doc-rev-2', new_doc.rev)
+ self.assertEqual('{"v": 3}', new_doc.get_json())
+ self.assertEqual(('PUT', ['doc', new_doc.doc_id], {},
+ '{"v": 3}', 'application/json'), self.got)
+
+ def test_delete_doc(self):
+ self.response_val = {'rev': 'doc-rev-gone'}, {}
+ doc = Document('doc-id', 'doc-rev', None)
+ self.db.delete_doc(doc)
+ self.assertEqual('doc-rev-gone', doc.rev)
+ self.assertEqual(('DELETE', ['doc', 'doc-id'], {'old_rev': 'doc-rev'},
+ None, None), self.got)
+
+ def test_get_sync_target(self):
+ st = self.db.get_sync_target()
+ self.assertIsInstance(st, http_target.HTTPSyncTarget)
+ self.assertEqual(st._url, self.db._url)
+
+ def test_get_sync_target_inherits_oauth_credentials(self):
+ self.db.set_oauth_credentials(tests.consumer1.key,
+ tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ st = self.db.get_sync_target()
+ self.assertEqual(self.db._creds, st._creds)
+
+
+class TestHTTPDatabaseCtrWithCreds(tests.TestCase):
+
+ def test_ctr_with_creds(self):
+ db1 = http_database.HTTPDatabase('http://dbs/db', creds={'oauth': {
+ 'consumer_key': tests.consumer1.key,
+ 'consumer_secret': tests.consumer1.secret,
+ 'token_key': tests.token1.key,
+ 'token_secret': tests.token1.secret
+ }})
+ self.assertIn('oauth', db1._creds)
+
+
+class TestHTTPDatabaseIntegration(tests.TestCaseWithServer):
+
+ make_app_with_state = staticmethod(make_http_app)
+
+ def setUp(self):
+ super(TestHTTPDatabaseIntegration, self).setUp()
+ self.startServer()
+
+ def test_non_existing_db(self):
+ db = http_database.HTTPDatabase(self.getURL('not-there'))
+ self.assertRaises(errors.DatabaseDoesNotExist, db.get_doc, 'doc1')
+
+ def test__ensure(self):
+ db = http_database.HTTPDatabase(self.getURL('new'))
+ db._ensure()
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test__delete(self):
+ self.request_state._create_database('db0')
+ db = http_database.HTTPDatabase(self.getURL('db0'))
+ db._delete()
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.request_state.check_database, 'db0')
+
+ def test_open_database_existing(self):
+ self.request_state._create_database('db0')
+ db = http_database.HTTPDatabase.open_database(self.getURL('db0'),
+ create=False)
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test_open_database_non_existing(self):
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ http_database.HTTPDatabase.open_database,
+ self.getURL('not-there'),
+ create=False)
+
+ def test_open_database_create(self):
+ db = http_database.HTTPDatabase.open_database(self.getURL('new'),
+ create=True)
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test_delete_database_existing(self):
+ self.request_state._create_database('db0')
+ http_database.HTTPDatabase.delete_database(self.getURL('db0'))
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.request_state.check_database, 'db0')
+
+ def test_doc_ids_needing_quoting(self):
+ db0 = self.request_state._create_database('db0')
+ db = http_database.HTTPDatabase.open_database(self.getURL('db0'),
+ create=False)
+ doc = Document('%fff', None, '{}')
+ db.put_doc(doc)
+ self.assertGetDoc(db0, '%fff', doc.rev, '{}', False)
+ self.assertGetDoc(db, '%fff', doc.rev, '{}', False)
diff --git a/u1db/tests/test_https.py b/u1db/tests/test_https.py
new file mode 100644
index 00000000..67681c8a
--- /dev/null
+++ b/u1db/tests/test_https.py
@@ -0,0 +1,117 @@
+"""Test support for client-side https support."""
+
+import os
+import ssl
+import sys
+
+from paste import httpserver
+
+from u1db import (
+ tests,
+ )
+from u1db.remote import (
+ http_client,
+ http_target,
+ )
+
+from u1db.tests.test_remote_sync_target import (
+ make_oauth_http_app,
+ )
+
+
+def https_server_def():
+ def make_server(host_port, application):
+ from OpenSSL import SSL
+ cert_file = os.path.join(os.path.dirname(__file__), 'testing-certs',
+ 'testing.cert')
+ key_file = os.path.join(os.path.dirname(__file__), 'testing-certs',
+ 'testing.key')
+ ssl_context = SSL.Context(SSL.SSLv23_METHOD)
+ ssl_context.use_privatekey_file(key_file)
+ ssl_context.use_certificate_chain_file(cert_file)
+ srv = httpserver.WSGIServerBase(application, host_port,
+ httpserver.WSGIHandler,
+ ssl_context=ssl_context
+ )
+
+ def shutdown_request(req):
+ req.shutdown()
+ srv.close_request(req)
+
+ srv.shutdown_request = shutdown_request
+ application.base_url = "https://localhost:%s" % srv.server_address[1]
+ return srv
+ return make_server, "shutdown", "https"
+
+
+def oauth_https_sync_target(test, host, path):
+ _, port = test.server.server_address
+ st = http_target.HTTPSyncTarget('https://%s:%d/~/%s' % (host, port, path))
+ st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return st
+
+
+class TestHttpSyncTargetHttpsSupport(tests.TestCaseWithServer):
+
+ scenarios = [
+ ('oauth_https', {'server_def': https_server_def,
+ 'make_app_with_state': make_oauth_http_app,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'sync_target': oauth_https_sync_target
+ }),
+ ]
+
+ def setUp(self):
+ try:
+ import OpenSSL # noqa
+ except ImportError:
+ self.skipTest("Requires pyOpenSSL")
+ self.cacert_pem = os.path.join(os.path.dirname(__file__),
+ 'testing-certs', 'cacert.pem')
+ super(TestHttpSyncTargetHttpsSupport, self).setUp()
+
+ def getSyncTarget(self, host, path=None):
+ if self.server is None:
+ self.startServer()
+ return self.sync_target(self, host, path)
+
+ def test_working(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ self.patch(http_client, 'CA_CERTS', self.cacert_pem)
+ remote_target = self.getSyncTarget('localhost', 'test')
+ remote_target.record_sync_info('other-id', 2, 'T-id')
+ self.assertEqual(
+ (2, 'T-id'), db._get_replica_gen_and_trans_id('other-id'))
+
+ def test_cannot_verify_cert(self):
+ if not sys.platform.startswith('linux'):
+ self.skipTest(
+ "XXX certificate verification happens on linux only for now")
+ self.startServer()
+ # don't print expected traceback server-side
+ self.server.handle_error = lambda req, cli_addr: None
+ self.request_state._create_database('test')
+ remote_target = self.getSyncTarget('localhost', 'test')
+ try:
+ remote_target.record_sync_info('other-id', 2, 'T-id')
+ except ssl.SSLError, e:
+ self.assertIn("certificate verify failed", str(e))
+ else:
+ self.fail("certificate verification should have failed.")
+
+ def test_host_mismatch(self):
+ if not sys.platform.startswith('linux'):
+ self.skipTest(
+ "XXX certificate verification happens on linux only for now")
+ self.startServer()
+ self.request_state._create_database('test')
+ self.patch(http_client, 'CA_CERTS', self.cacert_pem)
+ remote_target = self.getSyncTarget('127.0.0.1', 'test')
+ self.assertRaises(
+ http_client.CertificateError, remote_target.record_sync_info,
+ 'other-id', 2, 'T-id')
+
+
+load_tests = tests.load_with_scenarios
diff --git a/u1db/tests/test_inmemory.py b/u1db/tests/test_inmemory.py
new file mode 100644
index 00000000..255a1e08
--- /dev/null
+++ b/u1db/tests/test_inmemory.py
@@ -0,0 +1,128 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test in-memory backend internals."""
+
+from u1db import (
+ errors,
+ tests,
+ )
+from u1db.backends import inmemory
+
+
+simple_doc = '{"key": "value"}'
+
+
+class TestInMemoryDatabaseInternals(tests.TestCase):
+
+ def setUp(self):
+ super(TestInMemoryDatabaseInternals, self).setUp()
+ self.db = inmemory.InMemoryDatabase('test')
+
+ def test__allocate_doc_rev_from_None(self):
+ self.assertEqual('test:1', self.db._allocate_doc_rev(None))
+
+ def test__allocate_doc_rev_incremental(self):
+ self.assertEqual('test:2', self.db._allocate_doc_rev('test:1'))
+
+ def test__allocate_doc_rev_other(self):
+ self.assertEqual('replica:1|test:1',
+ self.db._allocate_doc_rev('replica:1'))
+
+ def test__get_replica_uid(self):
+ self.assertEqual('test', self.db._replica_uid)
+
+
+class TestInMemoryIndex(tests.TestCase):
+
+ def test_has_name_and_definition(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ self.assertEqual('idx-name', idx._name)
+ self.assertEqual(['key'], idx._definition)
+
+ def test_evaluate_json(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ self.assertEqual(['value'], idx.evaluate_json(simple_doc))
+
+ def test_evaluate_json_field_None(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['missing'])
+ self.assertEqual([], idx.evaluate_json(simple_doc))
+
+ def test_evaluate_json_subfield_None(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key', 'missing'])
+ self.assertEqual([], idx.evaluate_json(simple_doc))
+
+ def test_evaluate_multi_index(self):
+ doc = '{"key": "value", "key2": "value2"}'
+ idx = inmemory.InMemoryIndex('idx-name', ['key', 'key2'])
+ self.assertEqual(['value\x01value2'],
+ idx.evaluate_json(doc))
+
+ def test_update_ignores_None(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['nokey'])
+ idx.add_json('doc-id', simple_doc)
+ self.assertEqual({}, idx._values)
+
+ def test_update_adds_entry(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ self.assertEqual({'value': ['doc-id']}, idx._values)
+
+ def test_remove_json(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ self.assertEqual({'value': ['doc-id']}, idx._values)
+ idx.remove_json('doc-id', simple_doc)
+ self.assertEqual({}, idx._values)
+
+ def test_remove_json_multiple(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ idx.add_json('doc2-id', simple_doc)
+ self.assertEqual({'value': ['doc-id', 'doc2-id']}, idx._values)
+ idx.remove_json('doc-id', simple_doc)
+ self.assertEqual({'value': ['doc2-id']}, idx._values)
+
+ def test_keys(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ self.assertEqual(['value'], idx.keys())
+
+ def test_lookup(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ self.assertEqual(['doc-id'], idx.lookup(['value']))
+
+ def test_lookup_multi(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ idx.add_json('doc2-id', simple_doc)
+ self.assertEqual(['doc-id', 'doc2-id'], idx.lookup(['value']))
+
+ def test__find_non_wildcards(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['k1', 'k2', 'k3'])
+ self.assertEqual(-1, idx._find_non_wildcards(('a', 'b', 'c')))
+ self.assertEqual(2, idx._find_non_wildcards(('a', 'b', '*')))
+ self.assertEqual(3, idx._find_non_wildcards(('a', 'b', 'c*')))
+ self.assertEqual(2, idx._find_non_wildcards(('a', 'b*', '*')))
+ self.assertEqual(0, idx._find_non_wildcards(('*', '*', '*')))
+ self.assertEqual(1, idx._find_non_wildcards(('a*', '*', '*')))
+ self.assertRaises(errors.InvalidValueForIndex,
+ idx._find_non_wildcards, ('a', 'b'))
+ self.assertRaises(errors.InvalidValueForIndex,
+ idx._find_non_wildcards, ('a', 'b', 'c', 'd'))
+ self.assertRaises(errors.InvalidGlobbing,
+ idx._find_non_wildcards, ('*', 'b', 'c'))
diff --git a/u1db/tests/test_open.py b/u1db/tests/test_open.py
new file mode 100644
index 00000000..fbeb0cfd
--- /dev/null
+++ b/u1db/tests/test_open.py
@@ -0,0 +1,69 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test u1db.open"""
+
+import os
+
+from u1db import (
+ errors,
+ open as u1db_open,
+ tests,
+ )
+from u1db.backends import sqlite_backend
+from u1db.tests.test_backends import TestAlternativeDocument
+
+
+class TestU1DBOpen(tests.TestCase):
+
+ def setUp(self):
+ super(TestU1DBOpen, self).setUp()
+ tmpdir = self.createTempDir()
+ self.db_path = tmpdir + '/test.db'
+
+ def test_open_no_create(self):
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ u1db_open, self.db_path, create=False)
+ self.assertFalse(os.path.exists(self.db_path))
+
+ def test_open_create(self):
+ db = u1db_open(self.db_path, create=True)
+ self.addCleanup(db.close)
+ self.assertTrue(os.path.exists(self.db_path))
+ self.assertIsInstance(db, sqlite_backend.SQLiteDatabase)
+
+ def test_open_with_factory(self):
+ db = u1db_open(self.db_path, create=True,
+ document_factory=TestAlternativeDocument)
+ self.addCleanup(db.close)
+ self.assertEqual(TestAlternativeDocument, db._factory)
+
+ def test_open_existing(self):
+ db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path)
+ self.addCleanup(db.close)
+ doc = db.create_doc_from_json(tests.simple_doc)
+ # Even though create=True, we shouldn't wipe the db
+ db2 = u1db_open(self.db_path, create=True)
+ self.addCleanup(db2.close)
+ doc2 = db2.get_doc(doc.doc_id)
+ self.assertEqual(doc, doc2)
+
+ def test_open_existing_no_create(self):
+ db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path)
+ self.addCleanup(db.close)
+ db2 = u1db_open(self.db_path, create=False)
+ self.addCleanup(db2.close)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
diff --git a/u1db/tests/test_query_parser.py b/u1db/tests/test_query_parser.py
new file mode 100644
index 00000000..ee374267
--- /dev/null
+++ b/u1db/tests/test_query_parser.py
@@ -0,0 +1,443 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+from u1db import (
+ errors,
+ query_parser,
+ tests,
+ )
+
+
+trivial_raw_doc = {}
+
+
+class TestFieldName(tests.TestCase):
+
+ def test_check_fieldname_valid(self):
+ self.assertIsNone(query_parser.check_fieldname("foo"))
+
+ def test_check_fieldname_invalid(self):
+ self.assertRaises(
+ errors.IndexDefinitionParseError, query_parser.check_fieldname,
+ "foo.")
+
+
+class TestMakeTree(tests.TestCase):
+
+ def setUp(self):
+ super(TestMakeTree, self).setUp()
+ self.parser = query_parser.Parser()
+
+ def assertParseError(self, definition):
+ self.assertRaises(
+ errors.IndexDefinitionParseError, self.parser.parse,
+ definition)
+
+ def test_single_field(self):
+ self.assertIsInstance(
+ self.parser.parse('f'), query_parser.ExtractField)
+
+ def test_single_mapping(self):
+ self.assertIsInstance(
+ self.parser.parse('bool(field1)'), query_parser.Bool)
+
+ def test_nested_mapping(self):
+ self.assertIsInstance(
+ self.parser.parse('lower(split_words(field1))'),
+ query_parser.Lower)
+
+ def test_nested_branching_mapping(self):
+ self.assertIsInstance(
+ self.parser.parse(
+ 'combine(lower(field1), split_words(field2), '
+ 'number(field3, 5))'), query_parser.Combine)
+
+ def test_single_mapping_multiple_fields(self):
+ self.assertIsInstance(
+ self.parser.parse('number(field1, 5)'), query_parser.Number)
+
+ def test_unknown_mapping(self):
+ self.assertParseError('mapping(whatever)')
+
+ def test_parse_missing_close_paren(self):
+ self.assertParseError("lower(a")
+
+ def test_parse_trailing_chars(self):
+ self.assertParseError("lower(ab))")
+
+ def test_parse_empty_op(self):
+ self.assertParseError("(ab)")
+
+ def test_parse_top_level_commas(self):
+ self.assertParseError("a, b")
+
+ def test_invalid_field_name(self):
+ self.assertParseError("a.")
+
+ def test_invalid_inner_field_name(self):
+ self.assertParseError("lower(a.)")
+
+ def test_gobbledigook(self):
+ self.assertParseError("(@#@cc @#!*DFJSXV(()jccd")
+
+ def test_leading_space(self):
+ self.assertIsInstance(
+ self.parser.parse(" lower(a)"), query_parser.Lower)
+
+ def test_trailing_space(self):
+ self.assertIsInstance(
+ self.parser.parse("lower(a) "), query_parser.Lower)
+
+ def test_spaces_before_open_paren(self):
+ self.assertIsInstance(
+ self.parser.parse("lower (a)"), query_parser.Lower)
+
+ def test_spaces_after_open_paren(self):
+ self.assertIsInstance(
+ self.parser.parse("lower( a)"), query_parser.Lower)
+
+ def test_spaces_before_close_paren(self):
+ self.assertIsInstance(
+ self.parser.parse("lower(a )"), query_parser.Lower)
+
+ def test_spaces_before_comma(self):
+ self.assertIsInstance(
+ self.parser.parse("number(a , 5)"), query_parser.Number)
+
+ def test_spaces_after_comma(self):
+ self.assertIsInstance(
+ self.parser.parse("number(a, 5)"), query_parser.Number)
+
+
+class TestStaticGetter(tests.TestCase):
+
+ def test_returns_string(self):
+ getter = query_parser.StaticGetter('foo')
+ self.assertEqual(['foo'], getter.get(trivial_raw_doc))
+
+ def test_returns_int(self):
+ getter = query_parser.StaticGetter(9)
+ self.assertEqual([9], getter.get(trivial_raw_doc))
+
+ def test_returns_float(self):
+ getter = query_parser.StaticGetter(9.2)
+ self.assertEqual([9.2], getter.get(trivial_raw_doc))
+
+ def test_returns_None(self):
+ getter = query_parser.StaticGetter(None)
+ self.assertEqual([], getter.get(trivial_raw_doc))
+
+ def test_returns_list(self):
+ getter = query_parser.StaticGetter(['a', 'b'])
+ self.assertEqual(['a', 'b'], getter.get(trivial_raw_doc))
+
+
+class TestExtractField(tests.TestCase):
+
+ def assertExtractField(self, expected, field_name, raw_doc):
+ getter = query_parser.ExtractField(field_name)
+ self.assertEqual(expected, getter.get(raw_doc))
+
+ def test_get_value(self):
+ self.assertExtractField(['bar'], 'foo', {'foo': 'bar'})
+
+ def test_get_value_None(self):
+ self.assertExtractField([], 'foo', {'foo': None})
+
+ def test_get_value_missing_key(self):
+ self.assertExtractField([], 'foo', {})
+
+ def test_get_value_subfield(self):
+ self.assertExtractField(['bar'], 'foo.baz', {'foo': {'baz': 'bar'}})
+
+ def test_get_value_subfield_missing(self):
+ self.assertExtractField([], 'foo.baz', {'foo': 'bar'})
+
+ def test_get_value_dict(self):
+ self.assertExtractField([], 'foo', {'foo': {'baz': 'bar'}})
+
+ def test_get_value_list(self):
+ self.assertExtractField(['bar', 'zap'], 'foo', {'foo': ['bar', 'zap']})
+
+ def test_get_value_mixed_list(self):
+ self.assertExtractField(['bar', 'zap'], 'foo',
+ {'foo': ['bar', ['baa'], 'zap', {'bing': 9}]})
+
+ def test_get_value_list_of_dicts(self):
+ self.assertExtractField([], 'foo', {'foo': [{'zap': 'bar'}]})
+
+ def test_get_value_list_of_dicts2(self):
+ self.assertExtractField(
+ ['bar', 'baz'], 'foo.zap',
+ {'foo': [{'zap': 'bar'}, {'zap': 'baz'}]})
+
+ def test_get_value_int(self):
+ self.assertExtractField([9], 'foo', {'foo': 9})
+
+ def test_get_value_float(self):
+ self.assertExtractField([9.2], 'foo', {'foo': 9.2})
+
+ def test_get_value_bool(self):
+ self.assertExtractField([True], 'foo', {'foo': True})
+ self.assertExtractField([False], 'foo', {'foo': False})
+
+
+class TestLower(tests.TestCase):
+
+ def assertLowerGets(self, expected, input_val):
+ getter = query_parser.Lower(query_parser.StaticGetter(input_val))
+ out_val = getter.get(trivial_raw_doc)
+ self.assertEqual(sorted(expected), sorted(out_val))
+
+ def test_inner_returns_None(self):
+ self.assertLowerGets([], None)
+
+ def test_inner_returns_string(self):
+ self.assertLowerGets(['foo'], 'fOo')
+
+ def test_inner_returns_list(self):
+ self.assertLowerGets(['foo', 'bar'], ['fOo', 'bAr'])
+
+ def test_inner_returns_int(self):
+ self.assertLowerGets([], 9)
+
+ def test_inner_returns_float(self):
+ self.assertLowerGets([], 9.0)
+
+ def test_inner_returns_bool(self):
+ self.assertLowerGets([], True)
+
+ def test_inner_returns_list_containing_int(self):
+ self.assertLowerGets(['foo', 'bar'], ['fOo', 9, 'bAr'])
+
+ def test_inner_returns_list_containing_float(self):
+ self.assertLowerGets(['foo', 'bar'], ['fOo', 9.2, 'bAr'])
+
+ def test_inner_returns_list_containing_bool(self):
+ self.assertLowerGets(['foo', 'bar'], ['fOo', True, 'bAr'])
+
+ def test_inner_returns_list_containing_list(self):
+ # TODO: Should this be unfolding the inner list?
+ self.assertLowerGets(['foo', 'bar'], ['fOo', ['bAa'], 'bAr'])
+
+ def test_inner_returns_list_containing_dict(self):
+ self.assertLowerGets(['foo', 'bar'], ['fOo', {'baa': 'xam'}, 'bAr'])
+
+
+class TestSplitWords(tests.TestCase):
+
+ def assertSplitWords(self, expected, value):
+ getter = query_parser.SplitWords(query_parser.StaticGetter(value))
+ self.assertEqual(sorted(expected), sorted(getter.get(trivial_raw_doc)))
+
+ def test_inner_returns_None(self):
+ self.assertSplitWords([], None)
+
+ def test_inner_returns_string(self):
+ self.assertSplitWords(['foo', 'bar'], 'foo bar')
+
+ def test_inner_returns_list(self):
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', 'bar sux'])
+
+ def test_deduplicates(self):
+ self.assertSplitWords(['bar'], ['bar', 'bar', 'bar'])
+
+ def test_inner_returns_int(self):
+ self.assertSplitWords([], 9)
+
+ def test_inner_returns_float(self):
+ self.assertSplitWords([], 9.2)
+
+ def test_inner_returns_bool(self):
+ self.assertSplitWords([], True)
+
+ def test_inner_returns_list_containing_int(self):
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', 9, 'bar sux'])
+
+ def test_inner_returns_list_containing_float(self):
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', 9.2, 'bar sux'])
+
+ def test_inner_returns_list_containing_bool(self):
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', True, 'bar sux'])
+
+ def test_inner_returns_list_containing_list(self):
+ # TODO: Expand sub-lists?
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', ['baa'], 'bar sux'])
+
+ def test_inner_returns_list_containing_dict(self):
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', {'baa': 'xam'}, 'bar sux'])
+
+
+class TestNumber(tests.TestCase):
+
+ def assertNumber(self, expected, value, padding=5):
+ """Assert number transformation produced expected values."""
+ getter = query_parser.Number(query_parser.StaticGetter(value), padding)
+ self.assertEqual(expected, getter.get(trivial_raw_doc))
+
+ def test_inner_returns_None(self):
+ """None is thrown away."""
+ self.assertNumber([], None)
+
+ def test_inner_returns_int(self):
+ """A single integer is converted to zero padded strings."""
+ self.assertNumber(['00009'], 9)
+
+ def test_inner_returns_list(self):
+ """Integers are converted to zero padded strings."""
+ self.assertNumber(['00009', '00235'], [9, 235])
+
+ def test_inner_returns_string(self):
+ """A string is thrown away."""
+ self.assertNumber([], 'foo bar')
+
+ def test_inner_returns_float(self):
+ """A float is thrown away."""
+ self.assertNumber([], 9.2)
+
+ def test_inner_returns_bool(self):
+ """A boolean is thrown away."""
+ self.assertNumber([], True)
+
+ def test_inner_returns_list_containing_strings(self):
+ """Strings in a list are thrown away."""
+ self.assertNumber(['00009'], ['foo baz', 9, 'bar sux'])
+
+ def test_inner_returns_list_containing_float(self):
+ """Floats in a list are thrown away."""
+ self.assertNumber(
+ ['00083', '00073'], [83, 9.2, 73])
+
+ def test_inner_returns_list_containing_bool(self):
+ """Booleans in a list are thrown away."""
+ self.assertNumber(
+ ['00083', '00073'], [83, True, 73])
+
+ def test_inner_returns_list_containing_list(self):
+ """Lists in a list are thrown away."""
+ # TODO: Expand sub-lists?
+ self.assertNumber(
+ ['00012', '03333'], [12, [29], 3333])
+
+ def test_inner_returns_list_containing_dict(self):
+ """Dicts in a list are thrown away."""
+ self.assertNumber(
+ ['00012', '00001'], [12, {54: 89}, 1])
+
+
+class TestIsNull(tests.TestCase):
+
+ def assertIsNull(self, value):
+ getter = query_parser.IsNull(query_parser.StaticGetter(value))
+ self.assertEqual([True], getter.get(trivial_raw_doc))
+
+ def assertIsNotNull(self, value):
+ getter = query_parser.IsNull(query_parser.StaticGetter(value))
+ self.assertEqual([False], getter.get(trivial_raw_doc))
+
+ def test_inner_returns_None(self):
+ self.assertIsNull(None)
+
+ def test_inner_returns_string(self):
+ self.assertIsNotNull('foo')
+
+ def test_inner_returns_list(self):
+ self.assertIsNotNull(['foo', 'bar'])
+
+ def test_inner_returns_empty_list(self):
+ # TODO: is this the behavior we want?
+ self.assertIsNull([])
+
+ def test_inner_returns_int(self):
+ self.assertIsNotNull(9)
+
+ def test_inner_returns_float(self):
+ self.assertIsNotNull(9.2)
+
+ def test_inner_returns_bool(self):
+ self.assertIsNotNull(True)
+
+ # TODO: What about a dict? Inner is likely to return None, even though the
+ # attribute does exist...
+
+
+class TestParser(tests.TestCase):
+
+ def parse(self, spec):
+ parser = query_parser.Parser()
+ return parser.parse(spec)
+
+ def parse_all(self, specs):
+ parser = query_parser.Parser()
+ return parser.parse_all(specs)
+
+ def assertParseError(self, definition):
+ self.assertRaises(errors.IndexDefinitionParseError, self.parse,
+ definition)
+
+ def test_parse_empty_string(self):
+ self.assertRaises(errors.IndexDefinitionParseError, self.parse, "")
+
+ def test_parse_field(self):
+ getter = self.parse("a")
+ self.assertIsInstance(getter, query_parser.ExtractField)
+ self.assertEqual(["a"], getter.field)
+
+ def test_parse_dotted_field(self):
+ getter = self.parse("a.b")
+ self.assertIsInstance(getter, query_parser.ExtractField)
+ self.assertEqual(["a", "b"], getter.field)
+
+ def test_parse_dotted_field_nothing_after_dot(self):
+ self.assertParseError("a.")
+
+ def test_parse_missing_close_on_transformation(self):
+ self.assertParseError("lower(a")
+
+ def test_parse_missing_field_in_transformation(self):
+ self.assertParseError("lower()")
+
+ def test_parse_trailing_chars(self):
+ self.assertParseError("lower(ab))")
+
+ def test_parse_empty_op(self):
+ self.assertParseError("(ab)")
+
+ def test_parse_unknown_op(self):
+ self.assertParseError("no_such_operation(field)")
+
+ def test_parse_wrong_arg_type(self):
+ self.assertParseError("number(field, fnord)")
+
+ def test_parse_transformation(self):
+ getter = self.parse("lower(a)")
+ self.assertIsInstance(getter, query_parser.Lower)
+ self.assertIsInstance(getter.inner, query_parser.ExtractField)
+ self.assertEqual(["a"], getter.inner.field)
+
+ def test_parse_all(self):
+ getters = self.parse_all(["a", "b"])
+ self.assertEqual(2, len(getters))
+ self.assertIsInstance(getters[0], query_parser.ExtractField)
+ self.assertEqual(["a"], getters[0].field)
+ self.assertIsInstance(getters[1], query_parser.ExtractField)
+ self.assertEqual(["b"], getters[1].field)
diff --git a/u1db/tests/test_remote_sync_target.py b/u1db/tests/test_remote_sync_target.py
new file mode 100644
index 00000000..3e0d8995
--- /dev/null
+++ b/u1db/tests/test_remote_sync_target.py
@@ -0,0 +1,314 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for the remote sync targets"""
+
+import cStringIO
+
+from u1db import (
+ errors,
+ tests,
+ )
+from u1db.remote import (
+ http_app,
+ http_target,
+ oauth_middleware,
+ )
+
+
+class TestHTTPSyncTargetBasics(tests.TestCase):
+
+ def test_parse_url(self):
+ remote_target = http_target.HTTPSyncTarget('http://127.0.0.1:12345/')
+ self.assertEqual('http', remote_target._url.scheme)
+ self.assertEqual('127.0.0.1', remote_target._url.hostname)
+ self.assertEqual(12345, remote_target._url.port)
+ self.assertEqual('/', remote_target._url.path)
+
+
+class TestParsingSyncStream(tests.TestCase):
+
+ def test_wrong_start(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "{}\r\n]", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "\r\n{}\r\n]", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "", None)
+
+ def test_wrong_end(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n{}", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n", None)
+
+ def test_missing_comma(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream,
+ '[\r\n{}\r\n{"id": "i", "rev": "r", '
+ '"content": "c", "gen": 3}\r\n]', None)
+
+ def test_no_entries(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n]", None)
+
+ def test_extra_comma(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n{},\r\n]", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream,
+ '[\r\n{},\r\n{"id": "i", "rev": "r", '
+ '"content": "{}", "gen": 3, "trans_id": "T-sid"}'
+ ',\r\n]',
+ lambda doc, gen, trans_id: None)
+
+ def test_error_in_stream(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.Unavailable,
+ tgt._parse_sync_stream,
+ '[\r\n{"new_generation": 0},'
+ '\r\n{"error": "unavailable"}\r\n', None)
+
+ self.assertRaises(errors.Unavailable,
+ tgt._parse_sync_stream,
+ '[\r\n{"error": "unavailable"}\r\n', None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream,
+ '[\r\n{"error": "?"}\r\n', None)
+
+
+def make_http_app(state):
+ return http_app.HTTPApp(state)
+
+
+def http_sync_target(test, path):
+ return http_target.HTTPSyncTarget(test.getURL(path))
+
+
+def make_oauth_http_app(state):
+ app = http_app.HTTPApp(state)
+ application = oauth_middleware.OAuthMiddleware(app, None, prefix='/~/')
+ application.get_oauth_data_store = lambda: tests.testingOAuthStore
+ return application
+
+
+def oauth_http_sync_target(test, path):
+ st = http_sync_target(test, '~/' + path)
+ st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return st
+
+
+class TestRemoteSyncTargets(tests.TestCaseWithServer):
+
+ scenarios = [
+ ('http', {'make_app_with_state': make_http_app,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'sync_target': http_sync_target}),
+ ('oauth_http', {'make_app_with_state': make_oauth_http_app,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'sync_target': oauth_http_sync_target}),
+ ]
+
+ def getSyncTarget(self, path=None):
+ if self.server is None:
+ self.startServer()
+ return self.sync_target(self, path)
+
+ def test_get_sync_info(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ db._set_replica_gen_and_trans_id('other-id', 1, 'T-transid')
+ remote_target = self.getSyncTarget('test')
+ self.assertEqual(('test', 0, '', 1, 'T-transid'),
+ remote_target.get_sync_info('other-id'))
+
+ def test_record_sync_info(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ remote_target = self.getSyncTarget('test')
+ remote_target.record_sync_info('other-id', 2, 'T-transid')
+ self.assertEqual(
+ (2, 'T-transid'), db._get_replica_gen_and_trans_id('other-id'))
+
+ def test_sync_exchange_send(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ remote_target = self.getSyncTarget('test')
+ other_docs = []
+
+ def receive_doc(doc):
+ other_docs.append((doc.doc_id, doc.rev, doc.get_json()))
+
+ doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}')
+ new_gen, trans_id = remote_target.sync_exchange(
+ [(doc, 10, 'T-sid')], 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=receive_doc)
+ self.assertEqual(1, new_gen)
+ self.assertGetDoc(
+ db, 'doc-here', 'replica:1', '{"value": "here"}', False)
+
+ def test_sync_exchange_send_failure_and_retry_scenario(self):
+ self.startServer()
+
+ def blackhole_getstderr(inst):
+ return cStringIO.StringIO()
+
+ self.patch(self.server.RequestHandlerClass, 'get_stderr',
+ blackhole_getstderr)
+ db = self.request_state._create_database('test')
+ _put_doc_if_newer = db._put_doc_if_newer
+ trigger_ids = ['doc-here2']
+
+ def bomb_put_doc_if_newer(doc, save_conflict,
+ replica_uid=None, replica_gen=None,
+ replica_trans_id=None):
+ if doc.doc_id in trigger_ids:
+ raise Exception
+ return _put_doc_if_newer(doc, save_conflict=save_conflict,
+ replica_uid=replica_uid, replica_gen=replica_gen,
+ replica_trans_id=replica_trans_id)
+ self.patch(db, '_put_doc_if_newer', bomb_put_doc_if_newer)
+ remote_target = self.getSyncTarget('test')
+ other_changes = []
+
+ def receive_doc(doc, gen, trans_id):
+ other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ doc1 = self.make_document('doc-here', 'replica:1', '{"value": "here"}')
+ doc2 = self.make_document('doc-here2', 'replica:1',
+ '{"value": "here2"}')
+ self.assertRaises(
+ errors.HTTPError,
+ remote_target.sync_exchange,
+ [(doc1, 10, 'T-sid'), (doc2, 11, 'T-sud')],
+ 'replica', last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=receive_doc)
+ self.assertGetDoc(db, 'doc-here', 'replica:1', '{"value": "here"}',
+ False)
+ self.assertEqual(
+ (10, 'T-sid'), db._get_replica_gen_and_trans_id('replica'))
+ self.assertEqual([], other_changes)
+ # retry
+ trigger_ids = []
+ new_gen, trans_id = remote_target.sync_exchange(
+ [(doc2, 11, 'T-sud')], 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=receive_doc)
+ self.assertGetDoc(db, 'doc-here2', 'replica:1', '{"value": "here2"}',
+ False)
+ self.assertEqual(
+ (11, 'T-sud'), db._get_replica_gen_and_trans_id('replica'))
+ self.assertEqual(2, new_gen)
+ # bounced back to us
+ self.assertEqual(
+ ('doc-here', 'replica:1', '{"value": "here"}', 1),
+ other_changes[0][:-1])
+
+ def test_sync_exchange_in_stream_error(self):
+ self.startServer()
+
+ def blackhole_getstderr(inst):
+ return cStringIO.StringIO()
+
+ self.patch(self.server.RequestHandlerClass, 'get_stderr',
+ blackhole_getstderr)
+ db = self.request_state._create_database('test')
+ doc = db.create_doc_from_json('{"value": "there"}')
+
+ def bomb_get_docs(doc_ids, check_for_conflicts=None,
+ include_deleted=False):
+ yield doc
+ # delayed failure case
+ raise errors.Unavailable
+
+ self.patch(db, 'get_docs', bomb_get_docs)
+ remote_target = self.getSyncTarget('test')
+ other_changes = []
+
+ def receive_doc(doc, gen, trans_id):
+ other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ self.assertRaises(
+ errors.Unavailable, remote_target.sync_exchange, [], 'replica',
+ last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=receive_doc)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, '{"value": "there"}', 1),
+ other_changes[0][:-1])
+
+ def test_sync_exchange_receive(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ doc = db.create_doc_from_json('{"value": "there"}')
+ remote_target = self.getSyncTarget('test')
+ other_changes = []
+
+ def receive_doc(doc, gen, trans_id):
+ other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ new_gen, trans_id = remote_target.sync_exchange(
+ [], 'replica', last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=receive_doc)
+ self.assertEqual(1, new_gen)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, '{"value": "there"}', 1),
+ other_changes[0][:-1])
+
+ def test_sync_exchange_send_ensure_callback(self):
+ self.startServer()
+ remote_target = self.getSyncTarget('test')
+ other_docs = []
+ replica_uid_box = []
+
+ def receive_doc(doc):
+ other_docs.append((doc.doc_id, doc.rev, doc.get_json()))
+
+ def ensure_cb(replica_uid):
+ replica_uid_box.append(replica_uid)
+
+ doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}')
+ new_gen, trans_id = remote_target.sync_exchange(
+ [(doc, 10, 'T-sid')], 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=receive_doc,
+ ensure_callback=ensure_cb)
+ self.assertEqual(1, new_gen)
+ db = self.request_state.open_database('test')
+ self.assertEqual(1, len(replica_uid_box))
+ self.assertEqual(db._replica_uid, replica_uid_box[0])
+ self.assertGetDoc(
+ db, 'doc-here', 'replica:1', '{"value": "here"}', False)
+
+
+load_tests = tests.load_with_scenarios
diff --git a/u1db/tests/test_remote_utils.py b/u1db/tests/test_remote_utils.py
new file mode 100644
index 00000000..959cd882
--- /dev/null
+++ b/u1db/tests/test_remote_utils.py
@@ -0,0 +1,36 @@
+# Copyright 2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for protocol details utils."""
+
+from u1db.tests import TestCase
+from u1db.remote import utils
+
+
+class TestUtils(TestCase):
+
+ def test_check_and_strip_comma(self):
+ line, comma = utils.check_and_strip_comma("abc,")
+ self.assertTrue(comma)
+ self.assertEqual("abc", line)
+
+ line, comma = utils.check_and_strip_comma("abc")
+ self.assertFalse(comma)
+ self.assertEqual("abc", line)
+
+ line, comma = utils.check_and_strip_comma("")
+ self.assertFalse(comma)
+ self.assertEqual("", line)
diff --git a/u1db/tests/test_server_state.py b/u1db/tests/test_server_state.py
new file mode 100644
index 00000000..fc3f1282
--- /dev/null
+++ b/u1db/tests/test_server_state.py
@@ -0,0 +1,93 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for server state object."""
+
+import os
+
+from u1db import (
+ errors,
+ tests,
+ )
+from u1db.remote import (
+ server_state,
+ )
+from u1db.backends import sqlite_backend
+
+
+class TestServerState(tests.TestCase):
+
+ def setUp(self):
+ super(TestServerState, self).setUp()
+ self.state = server_state.ServerState()
+
+ def test_set_workingdir(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ self.assertTrue(self.state._relpath('path').startswith(tempdir))
+
+ def test_open_database(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ path = tempdir + '/test.db'
+ self.assertFalse(os.path.exists(path))
+ # Create the db, but don't do anything with it
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db = self.state.open_database('test.db')
+ self.assertIsInstance(db, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test_check_database(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ path = tempdir + '/test.db'
+ self.assertFalse(os.path.exists(path))
+
+ # doesn't exist => raises
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.state.check_database, 'test.db')
+
+ # Create the db, but don't do anything with it
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ # exists => returns
+ res = self.state.check_database('test.db')
+ self.assertIsNone(res)
+
+ def test_ensure_database(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ path = tempdir + '/test.db'
+ self.assertFalse(os.path.exists(path))
+ db, replica_uid = self.state.ensure_database('test.db')
+ self.assertIsInstance(db, sqlite_backend.SQLitePartialExpandDatabase)
+ self.assertEqual(db._replica_uid, replica_uid)
+ self.assertTrue(os.path.exists(path))
+ db2 = self.state.open_database('test.db')
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test_delete_database(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ path = tempdir + '/test.db'
+ db, _ = self.state.ensure_database('test.db')
+ db.close()
+ self.state.delete_database('test.db')
+ self.assertFalse(os.path.exists(path))
+
+ def test_delete_database_DoesNotExist(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.state.delete_database, 'test.db')
diff --git a/u1db/tests/test_sqlite_backend.py b/u1db/tests/test_sqlite_backend.py
new file mode 100644
index 00000000..73330789
--- /dev/null
+++ b/u1db/tests/test_sqlite_backend.py
@@ -0,0 +1,493 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test sqlite backend internals."""
+
+import os
+import time
+import threading
+
+from sqlite3 import dbapi2
+
+from u1db import (
+ errors,
+ tests,
+ query_parser,
+ )
+from u1db.backends import sqlite_backend
+from u1db.tests.test_backends import TestAlternativeDocument
+
+
+simple_doc = '{"key": "value"}'
+nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}'
+
+
+class TestSQLiteDatabase(tests.TestCase):
+
+ def test_atomic_initialize(self):
+ tmpdir = self.createTempDir()
+ dbname = os.path.join(tmpdir, 'atomic.db')
+
+ t2 = None # will be a thread
+
+ class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ _index_storage_value = "testing"
+
+ def __init__(self, dbname, ntry):
+ self._try = ntry
+ self._is_initialized_invocations = 0
+ super(SQLiteDatabaseTesting, self).__init__(dbname)
+
+ def _is_initialized(self, c):
+ res = super(SQLiteDatabaseTesting, self)._is_initialized(c)
+ if self._try == 1:
+ self._is_initialized_invocations += 1
+ if self._is_initialized_invocations == 2:
+ t2.start()
+ # hard to do better and have a generic test
+ time.sleep(0.05)
+ return res
+
+ outcome2 = []
+
+ def second_try():
+ try:
+ db2 = SQLiteDatabaseTesting(dbname, 2)
+ except Exception, e:
+ outcome2.append(e)
+ else:
+ outcome2.append(db2)
+
+ t2 = threading.Thread(target=second_try)
+ db1 = SQLiteDatabaseTesting(dbname, 1)
+ t2.join()
+
+ self.assertIsInstance(outcome2[0], SQLiteDatabaseTesting)
+ db2 = outcome2[0]
+ self.assertTrue(db2._is_initialized(db1._get_sqlite_handle().cursor()))
+
+
+class TestSQLitePartialExpandDatabase(tests.TestCase):
+
+ def setUp(self):
+ super(TestSQLitePartialExpandDatabase, self).setUp()
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.db._set_replica_uid('test')
+
+ def test_create_database(self):
+ raw_db = self.db._get_sqlite_handle()
+ self.assertNotEqual(None, raw_db)
+
+ def test_default_replica_uid(self):
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.assertIsNot(None, self.db._replica_uid)
+ self.assertEqual(32, len(self.db._replica_uid))
+ int(self.db._replica_uid, 16)
+
+ def test__close_sqlite_handle(self):
+ raw_db = self.db._get_sqlite_handle()
+ self.db._close_sqlite_handle()
+ self.assertRaises(dbapi2.ProgrammingError,
+ raw_db.cursor)
+
+ def test_create_database_initializes_schema(self):
+ raw_db = self.db._get_sqlite_handle()
+ c = raw_db.cursor()
+ c.execute("SELECT * FROM u1db_config")
+ config = dict([(r[0], r[1]) for r in c.fetchall()])
+ self.assertEqual({'sql_schema': '0', 'replica_uid': 'test',
+ 'index_storage': 'expand referenced'}, config)
+
+ # These tables must exist, though we don't care what is in them yet
+ c.execute("SELECT * FROM transaction_log")
+ c.execute("SELECT * FROM document")
+ c.execute("SELECT * FROM document_fields")
+ c.execute("SELECT * FROM sync_log")
+ c.execute("SELECT * FROM conflicts")
+ c.execute("SELECT * FROM index_definitions")
+
+ def test__parse_index(self):
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ g = self.db._parse_index_definition('fieldname')
+ self.assertIsInstance(g, query_parser.ExtractField)
+ self.assertEqual(['fieldname'], g.field)
+
+ def test__update_indexes(self):
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ g = self.db._parse_index_definition('fieldname')
+ c = self.db._get_sqlite_handle().cursor()
+ self.db._update_indexes('doc-id', {'fieldname': 'val'},
+ [('fieldname', g)], c)
+ c.execute('SELECT doc_id, field_name, value FROM document_fields')
+ self.assertEqual([('doc-id', 'fieldname', 'val')],
+ c.fetchall())
+
+ def test__set_replica_uid(self):
+ # Start from scratch, so that replica_uid isn't set.
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.assertIsNot(None, self.db._real_replica_uid)
+ self.assertIsNot(None, self.db._replica_uid)
+ self.db._set_replica_uid('foo')
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT value FROM u1db_config WHERE name='replica_uid'")
+ self.assertEqual(('foo',), c.fetchone())
+ self.assertEqual('foo', self.db._real_replica_uid)
+ self.assertEqual('foo', self.db._replica_uid)
+ self.db._close_sqlite_handle()
+ self.assertEqual('foo', self.db._replica_uid)
+
+ def test__get_generation(self):
+ self.assertEqual(0, self.db._get_generation())
+
+ def test__get_generation_info(self):
+ self.assertEqual((0, ''), self.db._get_generation_info())
+
+ def test_create_index(self):
+ self.db.create_index('test-idx', "key")
+ self.assertEqual([('test-idx', ["key"])], self.db.list_indexes())
+
+ def test_create_index_multiple_fields(self):
+ self.db.create_index('test-idx', "key", "key2")
+ self.assertEqual([('test-idx', ["key", "key2"])],
+ self.db.list_indexes())
+
+ def test__get_index_definition(self):
+ self.db.create_index('test-idx', "key", "key2")
+ # TODO: How would you test that an index is getting used for an SQL
+ # request?
+ self.assertEqual(["key", "key2"],
+ self.db._get_index_definition('test-idx'))
+
+ def test_list_index_mixed(self):
+ # Make sure that we properly order the output
+ c = self.db._get_sqlite_handle().cursor()
+ # We intentionally insert the data in weird ordering, to make sure the
+ # query still gets it back correctly.
+ c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
+ [('idx-1', 0, 'key10'),
+ ('idx-2', 2, 'key22'),
+ ('idx-1', 1, 'key11'),
+ ('idx-2', 0, 'key20'),
+ ('idx-2', 1, 'key21')])
+ self.assertEqual([('idx-1', ['key10', 'key11']),
+ ('idx-2', ['key20', 'key21', 'key22'])],
+ self.db.list_indexes())
+
+ def test_no_indexes_no_document_fields(self):
+ self.db.create_doc_from_json(
+ '{"key1": "val1", "key2": "val2"}')
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([], c.fetchall())
+
+ def test_create_extracts_fields(self):
+ doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}')
+ doc2 = self.db.create_doc_from_json('{"key1": "valx", "key2": "valy"}')
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([], c.fetchall())
+ self.db.create_index('test', 'key1', 'key2')
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual(sorted(
+ [(doc1.doc_id, "key1", "val1"),
+ (doc1.doc_id, "key2", "val2"),
+ (doc2.doc_id, "key1", "valx"),
+ (doc2.doc_id, "key2", "valy"),
+ ]), sorted(c.fetchall()))
+
+ def test_put_updates_fields(self):
+ self.db.create_index('test', 'key1', 'key2')
+ doc1 = self.db.create_doc_from_json(
+ '{"key1": "val1", "key2": "val2"}')
+ doc1.content = {"key1": "val1", "key2": "valy"}
+ self.db.put_doc(doc1)
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, "key1", "val1"),
+ (doc1.doc_id, "key2", "valy"),
+ ], c.fetchall())
+
+ def test_put_updates_nested_fields(self):
+ self.db.create_index('test', 'key', 'sub.doc')
+ doc1 = self.db.create_doc_from_json(nested_doc)
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, "key", "value"),
+ (doc1.doc_id, "sub.doc", "underneath"),
+ ], c.fetchall())
+
+ def test__ensure_schema_rollback(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/rollback.db'
+
+ class SQLitePartialExpandDbTesting(
+ sqlite_backend.SQLitePartialExpandDatabase):
+
+ def _set_replica_uid_in_transaction(self, uid):
+ super(SQLitePartialExpandDbTesting,
+ self)._set_replica_uid_in_transaction(uid)
+ if fail:
+ raise Exception()
+
+ db = SQLitePartialExpandDbTesting.__new__(SQLitePartialExpandDbTesting)
+ db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed
+ fail = True
+ self.assertRaises(Exception, db._ensure_schema)
+ fail = False
+ db._initialize(db._db_handle.cursor())
+
+ def test__open_database(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/test.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase._open_database(path)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test__open_database_with_factory(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/test.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase._open_database(
+ path, document_factory=TestAlternativeDocument)
+ self.assertEqual(TestAlternativeDocument, db2._factory)
+
+ def test__open_database_non_existent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/non-existent.sqlite'
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase._open_database, path)
+
+ def test__open_database_during_init(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/initialised.db'
+ db = sqlite_backend.SQLitePartialExpandDatabase.__new__(
+ sqlite_backend.SQLitePartialExpandDatabase)
+ db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed
+ self.addCleanup(db.close)
+ observed = []
+
+ class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1
+
+ @classmethod
+ def _which_index_storage(cls, c):
+ res = super(SQLiteDatabaseTesting, cls)._which_index_storage(c)
+ db._ensure_schema() # init db
+ observed.append(res[0])
+ return res
+
+ db2 = SQLiteDatabaseTesting._open_database(path)
+ self.addCleanup(db2.close)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+ self.assertEqual([None,
+ sqlite_backend.SQLitePartialExpandDatabase._index_storage_value],
+ observed)
+
+ def test__open_database_invalid(self):
+ class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path1 = temp_dir + '/invalid1.db'
+ with open(path1, 'wb') as f:
+ f.write("")
+ self.assertRaises(dbapi2.OperationalError,
+ SQLiteDatabaseTesting._open_database, path1)
+ with open(path1, 'wb') as f:
+ f.write("invalid")
+ self.assertRaises(dbapi2.DatabaseError,
+ SQLiteDatabaseTesting._open_database, path1)
+
+ def test_open_database_existing(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/existing.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test_open_database_with_factory(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/existing.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase.open_database(
+ path, create=False, document_factory=TestAlternativeDocument)
+ self.assertEqual(TestAlternativeDocument, db2._factory)
+
+ def test_open_database_create(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/new.sqlite'
+ sqlite_backend.SQLiteDatabase.open_database(path, create=True)
+ db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test_open_database_non_existent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/non-existent.sqlite'
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase.open_database, path,
+ create=False)
+
+ def test_delete_database_existent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/new.sqlite'
+ db = sqlite_backend.SQLiteDatabase.open_database(path, create=True)
+ db.close()
+ sqlite_backend.SQLiteDatabase.delete_database(path)
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase.open_database, path,
+ create=False)
+
+ def test_delete_database_nonexistent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/non-existent.sqlite'
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase.delete_database, path)
+
+ def test__get_indexed_fields(self):
+ self.db.create_index('idx1', 'a', 'b')
+ self.assertEqual(set(['a', 'b']), self.db._get_indexed_fields())
+ self.db.create_index('idx2', 'b', 'c')
+ self.assertEqual(set(['a', 'b', 'c']), self.db._get_indexed_fields())
+
+ def test_indexed_fields_expanded(self):
+ self.db.create_index('idx1', 'key1')
+ doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}')
+ self.assertEqual(set(['key1']), self.db._get_indexed_fields())
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall())
+
+ def test_create_index_updates_fields(self):
+ doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}')
+ self.db.create_index('idx1', 'key1')
+ self.assertEqual(set(['key1']), self.db._get_indexed_fields())
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall())
+
+ def assertFormatQueryEquals(self, exp_statement, exp_args, definition,
+ values):
+ statement, args = self.db._format_query(definition, values)
+ self.assertEqual(exp_statement, statement)
+ self.assertEqual(exp_args, args)
+
+ def test__format_query(self):
+ self.assertFormatQueryEquals(
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, document_fields d0 LEFT OUTER JOIN conflicts c ON "
+ "c.doc_id = d.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name "
+ "= ? AND d0.value = ? GROUP BY d.doc_id, d.doc_rev, d.content "
+ "ORDER BY d0.value;", ["key1", "a"],
+ ["key1"], ["a"])
+
+ def test__format_query2(self):
+ self.assertFormatQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value = ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value = ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ["key1", "a", "key2", "b", "key3", "c"],
+ ["key1", "key2", "key3"], ["a", "b", "c"])
+
+ def test__format_query_wildcard(self):
+ self.assertFormatQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value GLOB ? AND d.doc_id = d2.doc_id AND d2.field_name = ? '
+ 'AND d2.value NOT NULL GROUP BY d.doc_id, d.doc_rev, d.content '
+ 'ORDER BY d0.value, d1.value, d2.value;',
+ ["key1", "a", "key2", "b*", "key3"], ["key1", "key2", "key3"],
+ ["a", "b*", "*"])
+
+ def assertFormatRangeQueryEquals(self, exp_statement, exp_args, definition,
+ start_value, end_value):
+ statement, args = self.db._format_range_query(
+ definition, start_value, end_value)
+ self.assertEqual(exp_statement, statement)
+ self.assertEqual(exp_args, args)
+
+ def test__format_range_query(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value >= ? AND d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'c', 'key1', 'p', 'key2', 'q',
+ 'key3', 'r'],
+ ["key1", "key2", "key3"], ["a", "b", "c"], ["p", "q", "r"])
+
+ def test__format_range_query_no_start(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'c'],
+ ["key1", "key2", "key3"], None, ["a", "b", "c"])
+
+ def test__format_range_query_no_end(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value >= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'c'],
+ ["key1", "key2", "key3"], ["a", "b", "c"], None)
+
+ def test__format_range_query_wildcard(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value NOT NULL AND d.doc_id = d0.doc_id AND d0.field_name = ? '
+ 'AND d0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? '
+ 'AND (d1.value < ? OR d1.value GLOB ?) AND d.doc_id = d2.doc_id '
+ 'AND d2.field_name = ? AND d2.value NOT NULL GROUP BY d.doc_id, '
+ 'd.doc_rev, d.content ORDER BY d0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'key1', 'p', 'key2', 'q', 'q*',
+ 'key3'],
+ ["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"])
diff --git a/u1db/tests/test_sync.py b/u1db/tests/test_sync.py
new file mode 100644
index 00000000..f2a925f0
--- /dev/null
+++ b/u1db/tests/test_sync.py
@@ -0,0 +1,1285 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The Synchronization class for U1DB."""
+
+import os
+from wsgiref import simple_server
+
+from u1db import (
+ errors,
+ sync,
+ tests,
+ vectorclock,
+ SyncTarget,
+ )
+from u1db.backends import (
+ inmemory,
+ )
+from u1db.remote import (
+ http_target,
+ )
+
+from u1db.tests.test_remote_sync_target import (
+ make_http_app,
+ make_oauth_http_app,
+ )
+
+simple_doc = tests.simple_doc
+nested_doc = tests.nested_doc
+
+
+def _make_local_db_and_target(test):
+ db = test.create_database('test')
+ st = db.get_sync_target()
+ return db, st
+
+
+def _make_local_db_and_http_target(test, path='test'):
+ test.startServer()
+ db = test.request_state._create_database(os.path.basename(path))
+ st = http_target.HTTPSyncTarget.connect(test.getURL(path))
+ return db, st
+
+
+def _make_c_db_and_c_http_target(test, path='test'):
+ test.startServer()
+ db = test.request_state._create_database(os.path.basename(path))
+ url = test.getURL(path)
+ st = tests.c_backend_wrapper.create_http_sync_target(url)
+ return db, st
+
+
+def _make_local_db_and_oauth_http_target(test):
+ db, st = _make_local_db_and_http_target(test, '~/test')
+ st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return db, st
+
+
+def _make_c_db_and_oauth_http_target(test, path='~/test'):
+ test.startServer()
+ db = test.request_state._create_database(os.path.basename(path))
+ url = test.getURL(path)
+ st = tests.c_backend_wrapper.create_oauth_http_sync_target(url,
+ tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return db, st
+
+
+target_scenarios = [
+ ('local', {'create_db_and_target': _make_local_db_and_target}),
+ ('http', {'create_db_and_target': _make_local_db_and_http_target,
+ 'make_app_with_state': make_http_app}),
+ ('oauth_http', {'create_db_and_target':
+ _make_local_db_and_oauth_http_target,
+ 'make_app_with_state': make_oauth_http_app}),
+ ]
+
+c_db_scenarios = [
+ ('local,c', {'create_db_and_target': _make_local_db_and_target,
+ 'make_database_for_test': tests.make_c_database_for_test,
+ 'copy_database_for_test': tests.copy_c_database_for_test,
+ 'make_document_for_test': tests.make_c_document_for_test,
+ 'whitebox': False}),
+ ('http,c', {'create_db_and_target': _make_c_db_and_c_http_target,
+ 'make_database_for_test': tests.make_c_database_for_test,
+ 'copy_database_for_test': tests.copy_c_database_for_test,
+ 'make_document_for_test': tests.make_c_document_for_test,
+ 'make_app_with_state': make_http_app,
+ 'whitebox': False}),
+ ('oauth_http,c', {'create_db_and_target': _make_c_db_and_oauth_http_target,
+ 'make_database_for_test': tests.make_c_database_for_test,
+ 'copy_database_for_test': tests.copy_c_database_for_test,
+ 'make_document_for_test': tests.make_c_document_for_test,
+ 'make_app_with_state': make_oauth_http_app,
+ 'whitebox': False}),
+ ]
+
+
+class DatabaseSyncTargetTests(tests.DatabaseBaseTests,
+ tests.TestCaseWithServer):
+
+ scenarios = (tests.multiply_scenarios(tests.DatabaseBaseTests.scenarios,
+ target_scenarios)
+ + c_db_scenarios)
+ # whitebox true means self.db is the actual local db object
+ # against which the sync is performed
+ whitebox = True
+
+ def setUp(self):
+ super(DatabaseSyncTargetTests, self).setUp()
+ self.db, self.st = self.create_db_and_target(self)
+ self.other_changes = []
+
+ def tearDown(self):
+ # We delete them explicitly, so that connections are cleanly closed
+ del self.st
+ self.db.close()
+ del self.db
+ super(DatabaseSyncTargetTests, self).tearDown()
+
+ def receive_doc(self, doc, gen, trans_id):
+ self.other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ def set_trace_hook(self, callback, shallow=False):
+ setter = (self.st._set_trace_hook if not shallow else
+ self.st._set_trace_hook_shallow)
+ try:
+ setter(callback)
+ except NotImplementedError:
+ self.skipTest("%s does not implement _set_trace_hook"
+ % (self.st.__class__.__name__,))
+
+ def test_get_sync_target(self):
+ self.assertIsNot(None, self.st)
+
+ def test_get_sync_info(self):
+ self.assertEqual(
+ ('test', 0, '', 0, ''), self.st.get_sync_info('other'))
+
+ def test_create_doc_updates_sync_info(self):
+ self.assertEqual(
+ ('test', 0, '', 0, ''), self.st.get_sync_info('other'))
+ self.db.create_doc_from_json(simple_doc)
+ self.assertEqual(1, self.st.get_sync_info('other')[1])
+
+ def test_record_sync_info(self):
+ self.st.record_sync_info('replica', 10, 'T-transid')
+ self.assertEqual(
+ ('test', 0, '', 10, 'T-transid'), self.st.get_sync_info('replica'))
+
+ def test_sync_exchange(self):
+ docs_by_gen = [
+ (self.make_document('doc-id', 'replica:1', simple_doc), 10,
+ 'T-sid')]
+ new_gen, trans_id = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False)
+ self.assertTransactionLog(['doc-id'], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 1, last_trans_id),
+ (self.other_changes, new_gen, last_trans_id))
+ self.assertEqual(10, self.st.get_sync_info('replica')[3])
+
+ def test_sync_exchange_deleted(self):
+ doc = self.db.create_doc_from_json('{}')
+ edit_rev = 'replica:1|' + doc.rev
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, edit_rev, None), 10, 'T-sid')]
+ new_gen, trans_id = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, edit_rev, None, False)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 2, last_trans_id),
+ (self.other_changes, new_gen, trans_id))
+ self.assertEqual(10, self.st.get_sync_info('replica')[3])
+
+ def test_sync_exchange_push_many(self):
+ docs_by_gen = [
+ (self.make_document('doc-id', 'replica:1', simple_doc), 10, 'T-1'),
+ (self.make_document('doc-id2', 'replica:1', nested_doc), 11,
+ 'T-2')]
+ new_gen, trans_id = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False)
+ self.assertGetDoc(self.db, 'doc-id2', 'replica:1', nested_doc, False)
+ self.assertTransactionLog(['doc-id', 'doc-id2'], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 2, last_trans_id),
+ (self.other_changes, new_gen, trans_id))
+ self.assertEqual(11, self.st.get_sync_info('replica')[3])
+
+ def test_sync_exchange_refuses_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'replica:1', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1])
+ self.assertEqual(1, new_gen)
+ if self.whitebox:
+ self.assertEqual(self.db._last_exchange_log['return'],
+ {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]})
+
+ def test_sync_exchange_ignores_convergence(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ gen, txid = self.db._get_generation_info()
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, doc.rev, simple_doc), 10, 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=gen,
+ last_known_trans_id=txid, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertEqual(([], 1), (self.other_changes, new_gen))
+
+ def test_sync_exchange_returns_new_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_gen, _ = self.st.sync_exchange(
+ [], 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1])
+ self.assertEqual(1, new_gen)
+ if self.whitebox:
+ self.assertEqual(self.db._last_exchange_log['return'],
+ {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]})
+
+ def test_sync_exchange_returns_deleted_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ new_gen, _ = self.st.sync_exchange(
+ [], 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, None, 2), self.other_changes[0][:-1])
+ self.assertEqual(2, new_gen)
+ if self.whitebox:
+ self.assertEqual(self.db._last_exchange_log['return'],
+ {'last_gen': 2, 'docs': [(doc.doc_id, doc.rev)]})
+
+ def test_sync_exchange_returns_many_new_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db)
+ new_gen, _ = self.st.sync_exchange(
+ [], 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db)
+ self.assertEqual(2, new_gen)
+ self.assertEqual(
+ [(doc.doc_id, doc.rev, simple_doc, 1),
+ (doc2.doc_id, doc2.rev, nested_doc, 2)],
+ [c[:-1] for c in self.other_changes])
+ if self.whitebox:
+ self.assertEqual(
+ self.db._last_exchange_log['return'],
+ {'last_gen': 2, 'docs':
+ [(doc.doc_id, doc.rev), (doc2.doc_id, doc2.rev)]})
+
+ def test_sync_exchange_getting_newer_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ self.assertEqual(([], 2), (self.other_changes, new_gen))
+
+ def test_sync_exchange_with_concurrent_updates_of_synced_doc(self):
+ expected = []
+
+ def before_whatschanged_cb(state):
+ if state != 'before whats_changed':
+ return
+ cont = '{"key": "cuncurrent"}'
+ conc_rev = self.db.put_doc(
+ self.make_document(doc.doc_id, 'test:1|z:2', cont))
+ expected.append((doc.doc_id, conc_rev, cont, 3))
+
+ self.set_trace_hook(before_whatschanged_cb)
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertEqual(expected, [c[:-1] for c in self.other_changes])
+ self.assertEqual(3, new_gen)
+
+ def test_sync_exchange_with_concurrent_updates(self):
+
+ def after_whatschanged_cb(state):
+ if state != 'after whats_changed':
+ return
+ self.db.create_doc_from_json('{"new": "doc"}')
+
+ self.set_trace_hook(after_whatschanged_cb)
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertEqual(([], 2), (self.other_changes, new_gen))
+
+ def test_sync_exchange_converged_handling(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ docs_by_gen = [
+ (self.make_document('new', 'other:1', '{}'), 4, 'T-foo'),
+ (self.make_document(doc.doc_id, doc.rev, doc.get_json()), 5,
+ 'T-bar')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertEqual(([], 2), (self.other_changes, new_gen))
+
+ def test_sync_exchange_detect_incomplete_exchange(self):
+ def before_get_docs_explode(state):
+ if state != 'before get_docs':
+ return
+ raise errors.U1DBError("fail")
+ self.set_trace_hook(before_get_docs_explode)
+ # suppress traceback printing in the wsgiref server
+ self.patch(simple_server.ServerHandler,
+ 'log_exception', lambda h, exc_info: None)
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertRaises(
+ (errors.U1DBError, errors.BrokenSyncStream),
+ self.st.sync_exchange, [], 'other-replica',
+ last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=self.receive_doc)
+
+ def test_sync_exchange_doc_ids(self):
+ sync_exchange_doc_ids = getattr(self.st, 'sync_exchange_doc_ids', None)
+ if sync_exchange_doc_ids is None:
+ self.skipTest("sync_exchange_doc_ids not implemented")
+ db2 = self.create_database('test2')
+ doc = db2.create_doc_from_json(simple_doc)
+ new_gen, trans_id = sync_exchange_doc_ids(
+ db2, [(doc.doc_id, 10, 'T-sid')], 0, None,
+ return_doc_cb=self.receive_doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 1, last_trans_id),
+ (self.other_changes, new_gen, trans_id))
+ self.assertEqual(10, self.st.get_sync_info(db2._replica_uid)[3])
+
+ def test__set_trace_hook(self):
+ called = []
+
+ def cb(state):
+ called.append(state)
+
+ self.set_trace_hook(cb)
+ self.st.sync_exchange([], 'replica', 0, None, self.receive_doc)
+ self.st.record_sync_info('replica', 0, 'T-sid')
+ self.assertEqual(['before whats_changed',
+ 'after whats_changed',
+ 'before get_docs',
+ 'record_sync_info',
+ ],
+ called)
+
+ def test__set_trace_hook_shallow(self):
+ if (self.st._set_trace_hook_shallow == self.st._set_trace_hook
+ or self.st._set_trace_hook_shallow.im_func ==
+ SyncTarget._set_trace_hook_shallow.im_func):
+ # shallow same as full
+ expected = ['before whats_changed',
+ 'after whats_changed',
+ 'before get_docs',
+ 'record_sync_info',
+ ]
+ else:
+ expected = ['sync_exchange', 'record_sync_info']
+
+ called = []
+
+ def cb(state):
+ called.append(state)
+
+ self.set_trace_hook(cb, shallow=True)
+ self.st.sync_exchange([], 'replica', 0, None, self.receive_doc)
+ self.st.record_sync_info('replica', 0, 'T-sid')
+ self.assertEqual(expected, called)
+
+
+def sync_via_synchronizer(test, db_source, db_target, trace_hook=None,
+ trace_hook_shallow=None):
+ target = db_target.get_sync_target()
+ trace_hook = trace_hook or trace_hook_shallow
+ if trace_hook:
+ target._set_trace_hook(trace_hook)
+ return sync.Synchronizer(db_source, target).sync()
+
+
+sync_scenarios = []
+for name, scenario in tests.LOCAL_DATABASES_SCENARIOS:
+ scenario = dict(scenario)
+ scenario['do_sync'] = sync_via_synchronizer
+ sync_scenarios.append((name, scenario))
+ scenario = dict(scenario)
+
+
+def make_database_for_http_test(test, replica_uid):
+ if test.server is None:
+ test.startServer()
+ db = test.request_state._create_database(replica_uid)
+ try:
+ http_at = test._http_at
+ except AttributeError:
+ http_at = test._http_at = {}
+ http_at[db] = replica_uid
+ return db
+
+
+def copy_database_for_http_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR HOUSE.
+ if test.server is None:
+ test.startServer()
+ new_db = test.request_state._copy_database(db)
+ try:
+ http_at = test._http_at
+ except AttributeError:
+ http_at = test._http_at = {}
+ path = db._replica_uid
+ while path in http_at.values():
+ path += 'copy'
+ http_at[new_db] = path
+ return new_db
+
+
+def sync_via_synchronizer_and_http(test, db_source, db_target,
+ trace_hook=None, trace_hook_shallow=None):
+ if trace_hook:
+ test.skipTest("full trace hook unsupported over http")
+ path = test._http_at[db_target]
+ target = http_target.HTTPSyncTarget.connect(test.getURL(path))
+ if trace_hook_shallow:
+ target._set_trace_hook_shallow(trace_hook_shallow)
+ return sync.Synchronizer(db_source, target).sync()
+
+
+sync_scenarios.append(('pyhttp', {
+ 'make_database_for_test': make_database_for_http_test,
+ 'copy_database_for_test': copy_database_for_http_test,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'make_app_with_state': make_http_app,
+ 'do_sync': sync_via_synchronizer_and_http
+ }))
+
+
+if tests.c_backend_wrapper is not None:
+ # TODO: We should hook up sync tests with an HTTP target
+ def sync_via_c_sync(test, db_source, db_target, trace_hook=None,
+ trace_hook_shallow=None):
+ target = db_target.get_sync_target()
+ trace_hook = trace_hook or trace_hook_shallow
+ if trace_hook:
+ target._set_trace_hook(trace_hook)
+ return tests.c_backend_wrapper.sync_db_to_target(db_source, target)
+
+ for name, scenario in tests.C_DATABASE_SCENARIOS:
+ scenario = dict(scenario)
+ scenario['do_sync'] = sync_via_synchronizer
+ sync_scenarios.append((name + ',pysync', scenario))
+ scenario = dict(scenario)
+ scenario['do_sync'] = sync_via_c_sync
+ sync_scenarios.append((name + ',csync', scenario))
+
+
+class DatabaseSyncTests(tests.DatabaseBaseTests,
+ tests.TestCaseWithServer):
+
+ scenarios = sync_scenarios
+ do_sync = None # set by scenarios
+
+ def create_database(self, replica_uid, sync_role=None):
+ if replica_uid == 'test' and sync_role is None:
+ # created up the chain by base class but unused
+ return None
+ db = self.create_database_for_role(replica_uid, sync_role)
+ if sync_role:
+ self._use_tracking[db] = (replica_uid, sync_role)
+ return db
+
+ def create_database_for_role(self, replica_uid, sync_role):
+ # hook point for reuse
+ return super(DatabaseSyncTests, self).create_database(replica_uid)
+
+ def copy_database(self, db, sync_role=None):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
+ # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
+ # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
+ # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
+ # NINJA TO YOUR HOUSE.
+ db_copy = super(DatabaseSyncTests, self).copy_database(db)
+ name, orig_sync_role = self._use_tracking[db]
+ self._use_tracking[db_copy] = (name + '(copy)', sync_role
+ or orig_sync_role)
+ return db_copy
+
+ def sync(self, db_from, db_to, trace_hook=None,
+ trace_hook_shallow=None):
+ from_name, from_sync_role = self._use_tracking[db_from]
+ to_name, to_sync_role = self._use_tracking[db_to]
+ if from_sync_role not in ('source', 'both'):
+ raise Exception("%s marked for %s use but used as source" %
+ (from_name, from_sync_role))
+ if to_sync_role not in ('target', 'both'):
+ raise Exception("%s marked for %s use but used as target" %
+ (to_name, to_sync_role))
+ return self.do_sync(self, db_from, db_to, trace_hook,
+ trace_hook_shallow)
+
+ def setUp(self):
+ self._use_tracking = {}
+ super(DatabaseSyncTests, self).setUp()
+
+ def assertLastExchangeLog(self, db, expected):
+ log = getattr(db, '_last_exchange_log', None)
+ if log is None:
+ return
+ self.assertEqual(expected, log)
+
+ def test_sync_tracks_db_generation_of_other(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.assertEqual(0, self.sync(self.db1, self.db2))
+ self.assertEqual(
+ (0, ''), self.db1._get_replica_gen_and_trans_id('test2'))
+ self.assertEqual(
+ (0, ''), self.db2._get_replica_gen_and_trans_id('test1'))
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [], 'last_known_gen': 0},
+ 'return': {'docs': [], 'last_gen': 0}})
+
+ def test_sync_autoresolves(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc1 = self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ rev1 = doc1.rev
+ doc2 = self.db2.create_doc_from_json(simple_doc, doc_id='doc')
+ rev2 = doc2.rev
+ self.sync(self.db1, self.db2)
+ doc = self.db1.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ self.assertEqual(doc.rev, self.db2.get_doc('doc').rev)
+ v = vectorclock.VectorClockRev(doc.rev)
+ self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev1)))
+ self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev2)))
+
+ def test_sync_autoresolves_moar(self):
+ # here we test that when a database that has a conflicted document is
+ # the source of a sync, and the target database has a revision of the
+ # conflicted document that is newer than the source database's, and
+ # that target's database's document's content is the same as the
+ # source's document's conflict's, the source's document's conflict gets
+ # autoresolved, and the source's document's revision bumped.
+ #
+ # idea is as follows:
+ # A B
+ # a1 -
+ # `------->
+ # a1 a1
+ # v v
+ # a2 a1b1
+ # `------->
+ # a1b1+a2 a1b1
+ # v
+ # a1b1+a2 a1b2 (a1b2 has same content as a2)
+ # `------->
+ # a3b2 a1b2 (autoresolved)
+ # `------->
+ # a3b2 a3b2
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ self.sync(self.db1, self.db2)
+ for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]:
+ doc = db.get_doc('doc')
+ doc.set_json(content)
+ db.put_doc(doc)
+ self.sync(self.db1, self.db2)
+ # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict
+ doc = self.db1.get_doc('doc')
+ rev1 = doc.rev
+ self.assertTrue(doc.has_conflicts)
+ # set db2 to have a doc of {} (same as db1 before the conflict)
+ doc = self.db2.get_doc('doc')
+ doc.set_json('{}')
+ self.db2.put_doc(doc)
+ rev2 = doc.rev
+ # sync it across
+ self.sync(self.db1, self.db2)
+ # tadaa!
+ doc = self.db1.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ vec1 = vectorclock.VectorClockRev(rev1)
+ vec2 = vectorclock.VectorClockRev(rev2)
+ vec3 = vectorclock.VectorClockRev(doc.rev)
+ self.assertTrue(vec3.is_newer(vec1))
+ self.assertTrue(vec3.is_newer(vec2))
+ # because the conflict is on the source, sync it another time
+ self.sync(self.db1, self.db2)
+ # make sure db2 now has the exact same thing
+ self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc'))
+
+ def test_sync_autoresolves_moar_backwards(self):
+ # here we test that when a database that has a conflicted document is
+ # the target of a sync, and the source database has a revision of the
+ # conflicted document that is newer than the target database's, and
+ # that source's database's document's content is the same as the
+ # target's document's conflict's, the target's document's conflict gets
+ # autoresolved, and the document's revision bumped.
+ #
+ # idea is as follows:
+ # A B
+ # a1 -
+ # `------->
+ # a1 a1
+ # v v
+ # a2 a1b1
+ # `------->
+ # a1b1+a2 a1b1
+ # v
+ # a1b1+a2 a1b2 (a1b2 has same content as a2)
+ # <-------'
+ # a3b2 a3b2 (autoresolved and propagated)
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'both')
+ self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ self.sync(self.db1, self.db2)
+ for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]:
+ doc = db.get_doc('doc')
+ doc.set_json(content)
+ db.put_doc(doc)
+ self.sync(self.db1, self.db2)
+ # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict
+ doc = self.db1.get_doc('doc')
+ rev1 = doc.rev
+ self.assertTrue(doc.has_conflicts)
+ revc = self.db1.get_doc_conflicts('doc')[-1].rev
+ # set db2 to have a doc of {} (same as db1 before the conflict)
+ doc = self.db2.get_doc('doc')
+ doc.set_json('{}')
+ self.db2.put_doc(doc)
+ rev2 = doc.rev
+ # sync it across
+ self.sync(self.db2, self.db1)
+ # tadaa!
+ doc = self.db1.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ vec1 = vectorclock.VectorClockRev(rev1)
+ vec2 = vectorclock.VectorClockRev(rev2)
+ vec3 = vectorclock.VectorClockRev(doc.rev)
+ vecc = vectorclock.VectorClockRev(revc)
+ self.assertTrue(vec3.is_newer(vec1))
+ self.assertTrue(vec3.is_newer(vec2))
+ self.assertTrue(vec3.is_newer(vecc))
+ # make sure db2 now has the exact same thing
+ self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc'))
+
+ def test_sync_autoresolves_moar_backwards_three(self):
+ # same as autoresolves_moar_backwards, but with three databases (note
+ # all the syncs go in the same direction -- this is a more natural
+ # scenario):
+ #
+ # A B C
+ # a1 - -
+ # `------->
+ # a1 a1 -
+ # `------->
+ # a1 a1 a1
+ # v v
+ # a2 a1b1 a1
+ # `------------------->
+ # a2 a1b1 a2
+ # `------->
+ # a2+a1b1 a2
+ # v
+ # a2 a2+a1b1 a2c1 (same as a1b1)
+ # `------------------->
+ # a2c1 a2+a1b1 a2c1
+ # `------->
+ # a2b2c1 a2b2c1 a2c1
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'both')
+ self.db3 = self.create_database('test3', 'target')
+ self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ self.sync(self.db1, self.db2)
+ self.sync(self.db2, self.db3)
+ for db, content in [(self.db2, '{"hi": 42}'),
+ (self.db1, '{}'),
+ ]:
+ doc = db.get_doc('doc')
+ doc.set_json(content)
+ db.put_doc(doc)
+ self.sync(self.db1, self.db3)
+ self.sync(self.db2, self.db3)
+ # db2 and db3 now both have a doc of {}, but db2 has a
+ # conflict
+ doc = self.db2.get_doc('doc')
+ self.assertTrue(doc.has_conflicts)
+ revc = self.db2.get_doc_conflicts('doc')[-1].rev
+ self.assertEqual('{}', doc.get_json())
+ self.assertEqual(self.db3.get_doc('doc').get_json(), doc.get_json())
+ self.assertEqual(self.db3.get_doc('doc').rev, doc.rev)
+ # set db3 to have a doc of {hi:42} (same as db2 before the conflict)
+ doc = self.db3.get_doc('doc')
+ doc.set_json('{"hi": 42}')
+ self.db3.put_doc(doc)
+ rev3 = doc.rev
+ # sync it across to db1
+ self.sync(self.db1, self.db3)
+ # db1 now has hi:42, with a rev that is newer than db2's doc
+ doc = self.db1.get_doc('doc')
+ rev1 = doc.rev
+ self.assertFalse(doc.has_conflicts)
+ self.assertEqual('{"hi": 42}', doc.get_json())
+ VCR = vectorclock.VectorClockRev
+ self.assertTrue(VCR(rev1).is_newer(VCR(self.db2.get_doc('doc').rev)))
+ # so sync it to db2
+ self.sync(self.db1, self.db2)
+ # tadaa!
+ doc = self.db2.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ # db2's revision of the document is strictly newer than db1's before
+ # the sync, and db3's before that sync way back when
+ self.assertTrue(VCR(doc.rev).is_newer(VCR(rev1)))
+ self.assertTrue(VCR(doc.rev).is_newer(VCR(rev3)))
+ self.assertTrue(VCR(doc.rev).is_newer(VCR(revc)))
+ # make sure both dbs now have the exact same thing
+ self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc'))
+
+ def test_sync_puts_changes(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ self.assertEqual(1, self.sync(self.db1, self.db2))
+ self.assertGetDoc(self.db2, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0])
+ self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0])
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [(doc.doc_id, doc.rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 1, 'last_known_gen': 0},
+ 'return': {'docs': [], 'last_gen': 1}})
+
+ def test_sync_pulls_changes(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db2.create_doc_from_json(simple_doc)
+ self.db1.create_index('test-idx', 'key')
+ self.assertEqual(0, self.sync(self.db1, self.db2))
+ self.assertGetDoc(self.db1, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0])
+ self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0])
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [], 'last_known_gen': 0},
+ 'return': {'docs': [(doc.doc_id, doc.rev)],
+ 'last_gen': 1}})
+ self.assertEqual([doc], self.db1.get_from_index('test-idx', 'value'))
+
+ def test_sync_pulling_doesnt_update_other_if_changed(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db2.create_doc_from_json(simple_doc)
+ # After the local side has sent its list of docs, before we start
+ # receiving the "targets" response, we update the local database with a
+ # new record.
+ # When we finish synchronizing, we can notice that something locally
+ # was updated, and we cannot tell c2 our new updated generation
+
+ def before_get_docs(state):
+ if state != 'before get_docs':
+ return
+ self.db1.create_doc_from_json(simple_doc)
+
+ self.assertEqual(0, self.sync(self.db1, self.db2,
+ trace_hook=before_get_docs))
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [], 'last_known_gen': 0},
+ 'return': {'docs': [(doc.doc_id, doc.rev)],
+ 'last_gen': 1}})
+ self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0])
+ # c2 should not have gotten a '_record_sync_info' call, because the
+ # local database had been updated more than just by the messages
+ # returned from c2.
+ self.assertEqual(
+ (0, ''), self.db2._get_replica_gen_and_trans_id('test1'))
+
+ def test_sync_doesnt_update_other_if_nothing_pulled(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(simple_doc)
+
+ def no_record_sync_info(state):
+ if state != 'record_sync_info':
+ return
+ self.fail('SyncTarget.record_sync_info was called')
+ self.assertEqual(1, self.sync(self.db1, self.db2,
+ trace_hook_shallow=no_record_sync_info))
+ self.assertEqual(
+ 1,
+ self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid)[0])
+
+ def test_sync_ignores_convergence(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'both')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ self.db3 = self.create_database('test3', 'target')
+ self.assertEqual(1, self.sync(self.db1, self.db3))
+ self.assertEqual(0, self.sync(self.db2, self.db3))
+ self.assertEqual(1, self.sync(self.db1, self.db2))
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [(doc.doc_id, doc.rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 1, 'last_known_gen': 0},
+ 'return': {'docs': [], 'last_gen': 1}})
+
+ def test_sync_ignores_superseded(self):
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'both')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ doc_rev1 = doc.rev
+ self.db3 = self.create_database('test3', 'target')
+ self.sync(self.db1, self.db3)
+ self.sync(self.db2, self.db3)
+ new_content = '{"key": "altval"}'
+ doc.set_json(new_content)
+ self.db1.put_doc(doc)
+ doc_rev2 = doc.rev
+ self.sync(self.db2, self.db1)
+ self.assertLastExchangeLog(self.db1,
+ {'receive': {'docs': [(doc.doc_id, doc_rev1)],
+ 'source_uid': 'test2',
+ 'source_gen': 1, 'last_known_gen': 0},
+ 'return': {'docs': [(doc.doc_id, doc_rev2)],
+ 'last_gen': 2}})
+ self.assertGetDoc(self.db1, doc.doc_id, doc_rev2, new_content, False)
+
+ def test_sync_sees_remote_conflicted(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc1 = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc1.doc_id
+ doc1_rev = doc1.rev
+ self.db1.create_index('test-idx', 'key')
+ new_doc = '{"key": "altval"}'
+ doc2 = self.db2.create_doc_from_json(new_doc, doc_id=doc_id)
+ doc2_rev = doc2.rev
+ self.assertTransactionLog([doc1.doc_id], self.db1)
+ self.sync(self.db1, self.db2)
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [(doc_id, doc1_rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 1, 'last_known_gen': 0},
+ 'return': {'docs': [(doc_id, doc2_rev)],
+ 'last_gen': 1}})
+ self.assertTransactionLog([doc_id, doc_id], self.db1)
+ self.assertGetDoc(self.db1, doc_id, doc2_rev, new_doc, True)
+ self.assertGetDoc(self.db2, doc_id, doc2_rev, new_doc, False)
+ from_idx = self.db1.get_from_index('test-idx', 'altval')[0]
+ self.assertEqual(doc2.doc_id, from_idx.doc_id)
+ self.assertEqual(doc2.rev, from_idx.rev)
+ self.assertTrue(from_idx.has_conflicts)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+
+ def test_sync_sees_remote_delete_conflicted(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc1 = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc1.doc_id
+ self.db1.create_index('test-idx', 'key')
+ self.sync(self.db1, self.db2)
+ doc2 = self.make_document(doc1.doc_id, doc1.rev, doc1.get_json())
+ new_doc = '{"key": "altval"}'
+ doc1.set_json(new_doc)
+ self.db1.put_doc(doc1)
+ self.db2.delete_doc(doc2)
+ self.assertTransactionLog([doc_id, doc_id], self.db1)
+ self.sync(self.db1, self.db2)
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [(doc_id, doc1.rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 2, 'last_known_gen': 1},
+ 'return': {'docs': [(doc_id, doc2.rev)],
+ 'last_gen': 2}})
+ self.assertTransactionLog([doc_id, doc_id, doc_id], self.db1)
+ self.assertGetDocIncludeDeleted(self.db1, doc_id, doc2.rev, None, True)
+ self.assertGetDocIncludeDeleted(
+ self.db2, doc_id, doc2.rev, None, False)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+
+ def test_sync_local_race_conflicted(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc.doc_id
+ doc1_rev = doc.rev
+ self.db1.create_index('test-idx', 'key')
+ self.sync(self.db1, self.db2)
+ content1 = '{"key": "localval"}'
+ content2 = '{"key": "altval"}'
+ doc.set_json(content2)
+ self.db2.put_doc(doc)
+ doc2_rev2 = doc.rev
+ triggered = []
+
+ def after_whatschanged(state):
+ if state != 'after whats_changed':
+ return
+ triggered.append(True)
+ doc = self.make_document(doc_id, doc1_rev, content1)
+ self.db1.put_doc(doc)
+
+ self.sync(self.db1, self.db2, trace_hook=after_whatschanged)
+ self.assertEqual([True], triggered)
+ self.assertGetDoc(self.db1, doc_id, doc2_rev2, content2, True)
+ from_idx = self.db1.get_from_index('test-idx', 'altval')[0]
+ self.assertEqual(doc.doc_id, from_idx.doc_id)
+ self.assertEqual(doc.rev, from_idx.rev)
+ self.assertTrue(from_idx.has_conflicts)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'localval'))
+
+ def test_sync_propagates_deletes(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'both')
+ doc1 = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc1.doc_id
+ self.db1.create_index('test-idx', 'key')
+ self.sync(self.db1, self.db2)
+ self.db2.create_index('test-idx', 'key')
+ self.db3 = self.create_database('test3', 'target')
+ self.sync(self.db1, self.db3)
+ self.db1.delete_doc(doc1)
+ deleted_rev = doc1.rev
+ self.sync(self.db1, self.db2)
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [(doc_id, deleted_rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 2, 'last_known_gen': 1},
+ 'return': {'docs': [], 'last_gen': 2}})
+ self.assertGetDocIncludeDeleted(
+ self.db1, doc_id, deleted_rev, None, False)
+ self.assertGetDocIncludeDeleted(
+ self.db2, doc_id, deleted_rev, None, False)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+ self.assertEqual([], self.db2.get_from_index('test-idx', 'value'))
+ self.sync(self.db2, self.db3)
+ self.assertLastExchangeLog(self.db3,
+ {'receive': {'docs': [(doc_id, deleted_rev)],
+ 'source_uid': 'test2',
+ 'source_gen': 2, 'last_known_gen': 0},
+ 'return': {'docs': [], 'last_gen': 2}})
+ self.assertGetDocIncludeDeleted(
+ self.db3, doc_id, deleted_rev, None, False)
+
+ def test_sync_propagates_resolution(self):
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'both')
+ doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc')
+ db3 = self.create_database('test3', 'both')
+ self.sync(self.db2, self.db1)
+ self.assertEqual(
+ self.db1._get_generation_info(),
+ self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid))
+ self.assertEqual(
+ self.db2._get_generation_info(),
+ self.db1._get_replica_gen_and_trans_id(self.db2._replica_uid))
+ self.sync(db3, self.db1)
+ # update on 2
+ doc2 = self.make_document('the-doc', doc1.rev, '{"a": 2}')
+ self.db2.put_doc(doc2)
+ self.sync(self.db2, db3)
+ self.assertEqual(db3.get_doc('the-doc').rev, doc2.rev)
+ # update on 1
+ doc1.set_json('{"a": 3}')
+ self.db1.put_doc(doc1)
+ # conflicts
+ self.sync(self.db2, self.db1)
+ self.sync(db3, self.db1)
+ self.assertTrue(self.db2.get_doc('the-doc').has_conflicts)
+ self.assertTrue(db3.get_doc('the-doc').has_conflicts)
+ # resolve
+ conflicts = self.db2.get_doc_conflicts('the-doc')
+ doc4 = self.make_document('the-doc', None, '{"a": 4}')
+ revs = [doc.rev for doc in conflicts]
+ self.db2.resolve_doc(doc4, revs)
+ doc2 = self.db2.get_doc('the-doc')
+ self.assertEqual(doc4.get_json(), doc2.get_json())
+ self.assertFalse(doc2.has_conflicts)
+ self.sync(self.db2, db3)
+ doc3 = db3.get_doc('the-doc')
+ self.assertEqual(doc4.get_json(), doc3.get_json())
+ self.assertFalse(doc3.has_conflicts)
+
+ def test_sync_supersedes_conflicts(self):
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'target')
+ db3 = self.create_database('test3', 'both')
+ doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc')
+ self.db2.create_doc_from_json('{"b": 1}', doc_id='the-doc')
+ db3.create_doc_from_json('{"c": 1}', doc_id='the-doc')
+ self.sync(db3, self.db1)
+ self.assertEqual(
+ self.db1._get_generation_info(),
+ db3._get_replica_gen_and_trans_id(self.db1._replica_uid))
+ self.assertEqual(
+ db3._get_generation_info(),
+ self.db1._get_replica_gen_and_trans_id(db3._replica_uid))
+ self.sync(db3, self.db2)
+ self.assertEqual(
+ self.db2._get_generation_info(),
+ db3._get_replica_gen_and_trans_id(self.db2._replica_uid))
+ self.assertEqual(
+ db3._get_generation_info(),
+ self.db2._get_replica_gen_and_trans_id(db3._replica_uid))
+ self.assertEqual(3, len(db3.get_doc_conflicts('the-doc')))
+ doc1.set_json('{"a": 2}')
+ self.db1.put_doc(doc1)
+ self.sync(db3, self.db1)
+ # original doc1 should have been removed from conflicts
+ self.assertEqual(3, len(db3.get_doc_conflicts('the-doc')))
+
+ def test_sync_stops_after_get_sync_info(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc)
+ self.sync(self.db1, self.db2)
+
+ def put_hook(state):
+ self.fail("Tracehook triggered for %s" % (state,))
+
+ self.sync(self.db1, self.db2, trace_hook_shallow=put_hook)
+
+ def test_sync_detects_rollback_in_source(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1')
+ self.sync(self.db1, self.db2)
+ db1_copy = self.copy_database(self.db1)
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidGeneration, self.sync, db1_copy, self.db2)
+
+ def test_sync_detects_rollback_in_target(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ db2_copy = self.copy_database(self.db2)
+ self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidGeneration, self.sync, self.db1, db2_copy)
+
+ def test_sync_detects_diverged_source(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ db3 = self.copy_database(self.db1)
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ db3.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, db3, self.db2)
+
+ def test_sync_detects_diverged_target(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ db3 = self.copy_database(self.db2)
+ db3.create_doc_from_json(tests.nested_doc, doc_id="divergent")
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, self.db1, db3)
+
+ def test_sync_detects_rollback_and_divergence_in_source(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1')
+ self.sync(self.db1, self.db2)
+ db1_copy = self.copy_database(self.db1)
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.sync(self.db1, self.db2)
+ db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, db1_copy, self.db2)
+
+ def test_sync_detects_rollback_and_divergence_in_target(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ db2_copy = self.copy_database(self.db2)
+ self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.sync(self.db1, self.db2)
+ db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, self.db1, db2_copy)
+
+
+class TestDbSync(tests.TestCaseWithServer):
+ """Test db.sync remote sync shortcut"""
+
+ scenarios = [
+ ('py-http', {
+ 'make_app_with_state': make_http_app,
+ 'make_database_for_test': tests.make_memory_database_for_test,
+ }),
+ ('c-http', {
+ 'make_app_with_state': make_http_app,
+ 'make_database_for_test': tests.make_c_database_for_test
+ }),
+ ('py-oauth-http', {
+ 'make_app_with_state': make_oauth_http_app,
+ 'make_database_for_test': tests.make_memory_database_for_test,
+ 'oauth': True
+ }),
+ ('c-oauth-http', {
+ 'make_app_with_state': make_oauth_http_app,
+ 'make_database_for_test': tests.make_c_database_for_test,
+ 'oauth': True
+ }),
+ ]
+
+ oauth = False
+
+ def do_sync(self, target_name):
+ if self.oauth:
+ path = '~/' + target_name
+ extra = dict(creds={'oauth': {
+ 'consumer_key': tests.consumer1.key,
+ 'consumer_secret': tests.consumer1.secret,
+ 'token_key': tests.token1.key,
+ 'token_secret': tests.token1.secret
+ }})
+ else:
+ path = target_name
+ extra = {}
+ target_url = self.getURL(path)
+ return self.db.sync(target_url, **extra)
+
+ def setUp(self):
+ super(TestDbSync, self).setUp()
+ self.startServer()
+ self.db = self.make_database_for_test(self, 'test1')
+ self.db2 = self.request_state._create_database('test2.db')
+
+ def test_db_sync(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ doc2 = self.db2.create_doc_from_json(tests.nested_doc)
+ local_gen_before_sync = self.do_sync('test2.db')
+ gen, _, changes = self.db.whats_changed(local_gen_before_sync)
+ self.assertEqual(1, len(changes))
+ self.assertEqual(doc2.doc_id, changes[0][0])
+ self.assertEqual(1, gen - local_gen_before_sync)
+ self.assertGetDoc(self.db2, doc1.doc_id, doc1.rev, tests.simple_doc,
+ False)
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc,
+ False)
+
+ def test_db_sync_autocreate(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ local_gen_before_sync = self.do_sync('test3.db')
+ gen, _, changes = self.db.whats_changed(local_gen_before_sync)
+ self.assertEqual(0, gen - local_gen_before_sync)
+ db3 = self.request_state.open_database('test3.db')
+ gen, _, changes = db3.whats_changed()
+ self.assertEqual(1, len(changes))
+ self.assertEqual(doc1.doc_id, changes[0][0])
+ self.assertGetDoc(db3, doc1.doc_id, doc1.rev, tests.simple_doc,
+ False)
+ t_gen, _ = self.db._get_replica_gen_and_trans_id('test3.db')
+ s_gen, _ = db3._get_replica_gen_and_trans_id('test1')
+ self.assertEqual(1, t_gen)
+ self.assertEqual(1, s_gen)
+
+
+class TestRemoteSyncIntegration(tests.TestCaseWithServer):
+ """Integration tests for the most common sync scenario local -> remote"""
+
+ make_app_with_state = staticmethod(make_http_app)
+
+ def setUp(self):
+ super(TestRemoteSyncIntegration, self).setUp()
+ self.startServer()
+ self.db1 = inmemory.InMemoryDatabase('test1')
+ self.db2 = self.request_state._create_database('test2')
+
+ def test_sync_tracks_generations_incrementally(self):
+ doc11 = self.db1.create_doc_from_json('{"a": 1}')
+ doc12 = self.db1.create_doc_from_json('{"a": 2}')
+ doc21 = self.db2.create_doc_from_json('{"b": 1}')
+ doc22 = self.db2.create_doc_from_json('{"b": 2}')
+ #sanity
+ self.assertEqual(2, len(self.db1._get_transaction_log()))
+ self.assertEqual(2, len(self.db2._get_transaction_log()))
+ progress1 = []
+ progress2 = []
+ _do_set_replica_gen_and_trans_id = \
+ self.db1._do_set_replica_gen_and_trans_id
+
+ def set_sync_generation_witness1(other_uid, other_gen, trans_id):
+ progress1.append((other_uid, other_gen,
+ [d for d, t in self.db1._get_transaction_log()[2:]]))
+ _do_set_replica_gen_and_trans_id(other_uid, other_gen, trans_id)
+ self.patch(self.db1, '_do_set_replica_gen_and_trans_id',
+ set_sync_generation_witness1)
+ _do_set_replica_gen_and_trans_id2 = \
+ self.db2._do_set_replica_gen_and_trans_id
+
+ def set_sync_generation_witness2(other_uid, other_gen, trans_id):
+ progress2.append((other_uid, other_gen,
+ [d for d, t in self.db2._get_transaction_log()[2:]]))
+ _do_set_replica_gen_and_trans_id2(other_uid, other_gen, trans_id)
+ self.patch(self.db2, '_do_set_replica_gen_and_trans_id',
+ set_sync_generation_witness2)
+
+ db2_url = self.getURL('test2')
+ self.db1.sync(db2_url)
+
+ self.assertEqual([('test2', 1, [doc21.doc_id]),
+ ('test2', 2, [doc21.doc_id, doc22.doc_id]),
+ ('test2', 4, [doc21.doc_id, doc22.doc_id])],
+ progress1)
+ self.assertEqual([('test1', 1, [doc11.doc_id]),
+ ('test1', 2, [doc11.doc_id, doc12.doc_id]),
+ ('test1', 4, [doc11.doc_id, doc12.doc_id])],
+ progress2)
+
+
+load_tests = tests.load_with_scenarios
diff --git a/u1db/tests/test_test_infrastructure.py b/u1db/tests/test_test_infrastructure.py
new file mode 100644
index 00000000..b79e0516
--- /dev/null
+++ b/u1db/tests/test_test_infrastructure.py
@@ -0,0 +1,41 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for test infrastructure bits"""
+
+from wsgiref import simple_server
+
+from u1db import (
+ tests,
+ )
+
+
+class TestTestCaseWithServer(tests.TestCaseWithServer):
+
+ def make_app(self):
+ return "app"
+
+ @staticmethod
+ def server_def():
+ def make_server(host_port, application):
+ assert application == "app"
+ return simple_server.WSGIServer(host_port, None)
+ return (make_server, "shutdown", "http")
+
+ def test_getURL(self):
+ self.startServer()
+ url = self.getURL()
+ self.assertTrue(url.startswith('http://127.0.0.1:'))
diff --git a/u1db/tests/test_vectorclock.py b/u1db/tests/test_vectorclock.py
new file mode 100644
index 00000000..72baf246
--- /dev/null
+++ b/u1db/tests/test_vectorclock.py
@@ -0,0 +1,121 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""VectorClockRev helper class tests."""
+
+from u1db import tests, vectorclock
+
+try:
+ from u1db.tests import c_backend_wrapper
+except ImportError:
+ c_backend_wrapper = None
+
+
+c_vectorclock_scenarios = []
+if c_backend_wrapper is not None:
+ c_vectorclock_scenarios.append(
+ ('c', {'create_vcr': c_backend_wrapper.VectorClockRev}))
+
+
+class TestVectorClockRev(tests.TestCase):
+
+ scenarios = [('py', {'create_vcr': vectorclock.VectorClockRev})
+ ] + c_vectorclock_scenarios
+
+ def assertIsNewer(self, newer_rev, older_rev):
+ new_vcr = self.create_vcr(newer_rev)
+ old_vcr = self.create_vcr(older_rev)
+ self.assertTrue(new_vcr.is_newer(old_vcr))
+ self.assertFalse(old_vcr.is_newer(new_vcr))
+
+ def assertIsConflicted(self, rev_a, rev_b):
+ vcr_a = self.create_vcr(rev_a)
+ vcr_b = self.create_vcr(rev_b)
+ self.assertFalse(vcr_a.is_newer(vcr_b))
+ self.assertFalse(vcr_b.is_newer(vcr_a))
+
+ def assertRoundTrips(self, rev):
+ self.assertEqual(rev, self.create_vcr(rev).as_str())
+
+ def test__is_newer_doc_rev(self):
+ self.assertIsNewer('test:1', None)
+ self.assertIsNewer('test:2', 'test:1')
+ self.assertIsNewer('other:2|test:1', 'other:1|test:1')
+ self.assertIsNewer('other:1|test:1', 'other:1')
+ self.assertIsNewer('a:2|b:1', 'b:1')
+ self.assertIsNewer('a:1|b:2', 'a:1')
+ self.assertIsConflicted('other:2|test:1', 'other:1|test:2')
+ self.assertIsConflicted('other:1|test:1', 'other:2')
+ self.assertIsConflicted('test:1', 'test:1')
+
+ def test_None(self):
+ vcr = self.create_vcr(None)
+ self.assertEqual('', vcr.as_str())
+
+ def test_round_trips(self):
+ self.assertRoundTrips('test:1')
+ self.assertRoundTrips('a:1|b:2')
+ self.assertRoundTrips('alternate:2|test:1')
+
+ def test_handles_sort_order(self):
+ self.assertEqual('a:1|b:2', self.create_vcr('b:2|a:1').as_str())
+ # Last one out of place
+ self.assertEqual('a:1|b:2|c:3|d:4|e:5|f:6',
+ self.create_vcr('f:6|a:1|b:2|c:3|d:4|e:5').as_str())
+ # Fully reversed
+ self.assertEqual('a:1|b:2|c:3|d:4|e:5|f:6',
+ self.create_vcr('f:6|e:5|d:4|c:3|b:2|a:1').as_str())
+
+ def assertIncrement(self, original, replica_uid, after_increment):
+ vcr = self.create_vcr(original)
+ vcr.increment(replica_uid)
+ self.assertEqual(after_increment, vcr.as_str())
+
+ def test_increment(self):
+ self.assertIncrement(None, 'test', 'test:1')
+ self.assertIncrement('test:1', 'test', 'test:2')
+
+ def test_increment_adds_uid(self):
+ self.assertIncrement('other:1', 'test', 'other:1|test:1')
+ self.assertIncrement('a:1|ab:2', 'aa', 'a:1|aa:1|ab:2')
+
+ def test_increment_update_partial(self):
+ self.assertIncrement('a:1|ab:2', 'a', 'a:2|ab:2')
+ self.assertIncrement('a:2|ab:2', 'ab', 'a:2|ab:3')
+
+ def test_increment_appends_uid(self):
+ self.assertIncrement('b:2', 'c', 'b:2|c:1')
+
+ def assertMaximize(self, rev1, rev2, maximized):
+ vcr1 = self.create_vcr(rev1)
+ vcr2 = self.create_vcr(rev2)
+ vcr1.maximize(vcr2)
+ self.assertEqual(maximized, vcr1.as_str())
+ # reset vcr1 to maximize the other way
+ vcr1 = self.create_vcr(rev1)
+ vcr2.maximize(vcr1)
+ self.assertEqual(maximized, vcr2.as_str())
+
+ def test_maximize(self):
+ self.assertMaximize(None, None, '')
+ self.assertMaximize(None, 'x:1', 'x:1')
+ self.assertMaximize('x:1', 'y:1', 'x:1|y:1')
+ self.assertMaximize('x:2', 'x:1', 'x:2')
+ self.assertMaximize('x:2', 'x:1|y:2', 'x:2|y:2')
+ self.assertMaximize('a:1|c:2|e:3', 'b:3|d:4|f:5',
+ 'a:1|b:3|c:2|d:4|e:3|f:5')
+
+load_tests = tests.load_with_scenarios
diff --git a/u1db/tests/testing-certs/Makefile b/u1db/tests/testing-certs/Makefile
new file mode 100644
index 00000000..2385e75b
--- /dev/null
+++ b/u1db/tests/testing-certs/Makefile
@@ -0,0 +1,35 @@
+CATOP=./demoCA
+ORIG_CONF=/usr/lib/ssl/openssl.cnf
+ELEVEN_YEARS=-days 4015
+
+init:
+ cp $(ORIG_CONF) ca.conf
+ install -d $(CATOP)
+ install -d $(CATOP)/certs
+ install -d $(CATOP)/crl
+ install -d $(CATOP)/newcerts
+ install -d $(CATOP)/private
+ touch $(CATOP)/index.txt
+ echo 01>$(CATOP)/crlnumber
+ @echo '**** Making CA certificate ...'
+ openssl req -nodes -new \
+ -newkey rsa -keyout $(CATOP)/private/cakey.pem \
+ -out $(CATOP)/careq.pem \
+ -multivalue-rdn \
+ -subj "/C=UK/ST=-/O=u1db LOCAL TESTING ONLY, DO NO TRUST/CN=u1db testing CA"
+ openssl ca -config ./ca.conf -create_serial \
+ -out $(CATOP)/cacert.pem $(ELEVEN_YEARS) -batch \
+ -keyfile $(CATOP)/private/cakey.pem -selfsign \
+ -extensions v3_ca -infiles $(CATOP)/careq.pem
+
+pems:
+ cp ./demoCA/cacert.pem .
+ openssl req -new -config ca.conf \
+ -multivalue-rdn \
+ -subj "/O=u1db LOCAL TESTING ONLY, DO NOT TRUST/CN=localhost" \
+ -nodes -keyout testing.key -out newreq.pem $(ELEVEN_YEARS)
+ openssl ca -batch -config ./ca.conf $(ELEVEN_YEARS) \
+ -policy policy_anything \
+ -out testing.cert -infiles newreq.pem
+
+.PHONY: init pems
diff --git a/u1db/tests/testing-certs/cacert.pem b/u1db/tests/testing-certs/cacert.pem
new file mode 100644
index 00000000..c019a730
--- /dev/null
+++ b/u1db/tests/testing-certs/cacert.pem
@@ -0,0 +1,58 @@
+Certificate:
+ Data:
+ Version: 3 (0x2)
+ Serial Number:
+ e4:de:01:76:c4:78:78:7e
+ Signature Algorithm: sha1WithRSAEncryption
+ Issuer: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA
+ Validity
+ Not Before: May 3 11:11:11 2012 GMT
+ Not After : May 1 11:11:11 2023 GMT
+ Subject: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (1024 bit)
+ Modulus:
+ 00:bc:91:a5:7f:7d:37:f7:06:c7:db:5b:83:6a:6b:
+ 63:c3:8b:5c:f7:84:4d:97:6d:d4:be:bf:e7:79:a8:
+ c1:03:57:ec:90:d4:20:e7:02:95:d9:a6:49:e3:f9:
+ 9a:ea:37:b9:b2:02:62:ab:40:d3:42:bb:4a:4e:a2:
+ 47:71:0f:1d:a2:c5:94:a1:cf:35:d3:23:32:42:c0:
+ 1e:8d:cb:08:58:fb:8a:5c:3e:ea:eb:d5:2c:ed:d6:
+ aa:09:b4:b5:7d:e3:45:c9:ae:c2:82:b2:ae:c0:81:
+ bc:24:06:65:a9:e7:e0:61:ac:25:ee:53:d3:d7:be:
+ 22:f7:00:a2:ad:c6:0e:3a:39
+ Exponent: 65537 (0x10001)
+ X509v3 extensions:
+ X509v3 Subject Key Identifier:
+ DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D
+ X509v3 Authority Key Identifier:
+ keyid:DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D
+
+ X509v3 Basic Constraints:
+ CA:TRUE
+ Signature Algorithm: sha1WithRSAEncryption
+ 72:9b:c1:f7:07:65:83:36:25:4e:01:2f:b7:4a:f2:a4:00:28:
+ 80:c7:56:2c:32:39:90:13:61:4b:bb:12:c5:44:9d:42:57:85:
+ 28:19:70:69:e1:43:c8:bd:11:f6:94:df:91:2d:c3:ea:82:8d:
+ b4:8f:5d:47:a3:00:99:53:29:93:27:6c:c5:da:c1:20:6f:ab:
+ ec:4a:be:34:f3:8f:02:e5:0c:c0:03:ac:2b:33:41:71:4f:0a:
+ 72:5a:b4:26:1a:7f:81:bc:c0:95:8a:06:87:a8:11:9f:5c:73:
+ 38:df:5a:69:40:21:29:ad:46:23:56:75:e1:e9:8b:10:18:4c:
+ 7b:54
+-----BEGIN CERTIFICATE-----
+MIICkjCCAfugAwIBAgIJAOTeAXbEeHh+MA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV
+BAYTAlVLMQowCAYDVQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcg
+T05MWSwgRE8gTk8gVFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTAeFw0x
+MjA1MDMxMTExMTFaFw0yMzA1MDExMTExMTFaMGIxCzAJBgNVBAYTAlVLMQowCAYD
+VQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcgT05MWSwgRE8gTk8g
+VFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTCBnzANBgkqhkiG9w0BAQEF
+AAOBjQAwgYkCgYEAvJGlf3039wbH21uDamtjw4tc94RNl23Uvr/neajBA1fskNQg
+5wKV2aZJ4/ma6je5sgJiq0DTQrtKTqJHcQ8dosWUoc810yMyQsAejcsIWPuKXD7q
+69Us7daqCbS1feNFya7CgrKuwIG8JAZlqefgYawl7lPT174i9wCircYOOjkCAwEA
+AaNQME4wHQYDVR0OBBYEFNs9k1FsMhVUjxBQ/ElPNhUou5VtMB8GA1UdIwQYMBaA
+FNs9k1FsMhVUjxBQ/ElPNhUou5VtMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF
+BQADgYEAcpvB9wdlgzYlTgEvt0rypAAogMdWLDI5kBNhS7sSxUSdQleFKBlwaeFD
+yL0R9pTfkS3D6oKNtI9dR6MAmVMpkydsxdrBIG+r7Eq+NPOPAuUMwAOsKzNBcU8K
+clq0Jhp/gbzAlYoGh6gRn1xzON9aaUAhKa1GI1Z14emLEBhMe1Q=
+-----END CERTIFICATE-----
diff --git a/u1db/tests/testing-certs/testing.cert b/u1db/tests/testing-certs/testing.cert
new file mode 100644
index 00000000..985684fb
--- /dev/null
+++ b/u1db/tests/testing-certs/testing.cert
@@ -0,0 +1,61 @@
+Certificate:
+ Data:
+ Version: 3 (0x2)
+ Serial Number:
+ e4:de:01:76:c4:78:78:7f
+ Signature Algorithm: sha1WithRSAEncryption
+ Issuer: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA
+ Validity
+ Not Before: May 3 11:11:14 2012 GMT
+ Not After : May 1 11:11:14 2023 GMT
+ Subject: O=u1db LOCAL TESTING ONLY, DO NOT TRUST, CN=localhost
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (1024 bit)
+ Modulus:
+ 00:c6:1d:72:d3:c5:e4:fc:d1:4c:d9:e4:08:3e:90:
+ 10:ce:3f:1f:87:4a:1d:4f:7f:2a:5a:52:c9:65:4f:
+ d9:2c:bf:69:75:18:1a:b5:c9:09:32:00:47:f5:60:
+ aa:c6:dd:3a:87:37:5f:16:be:de:29:b5:ea:fc:41:
+ 7e:eb:77:bb:df:63:c3:06:1e:ed:e9:a0:67:1a:f1:
+ ec:e1:9d:f7:9c:8f:1c:fa:c3:66:7b:39:dc:70:ae:
+ 09:1b:9c:c0:9a:c4:90:77:45:8e:39:95:a9:2f:92:
+ 43:bd:27:07:5a:99:51:6e:76:a0:af:dd:b1:2c:8f:
+ ca:8b:8c:47:0d:f6:6e:fc:69
+ Exponent: 65537 (0x10001)
+ X509v3 extensions:
+ X509v3 Basic Constraints:
+ CA:FALSE
+ Netscape Comment:
+ OpenSSL Generated Certificate
+ X509v3 Subject Key Identifier:
+ 1C:63:85:E1:1D:F3:89:2E:6C:4E:3F:FB:D0:10:64:5A:C1:22:6A:2A
+ X509v3 Authority Key Identifier:
+ keyid:DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D
+
+ Signature Algorithm: sha1WithRSAEncryption
+ 1d:6d:3e:bd:93:fd:bd:3e:17:b8:9f:f0:99:7f:db:50:5c:b2:
+ 01:42:03:b5:d5:94:05:d3:f6:8e:80:82:55:47:1f:58:f2:18:
+ 6c:ab:ef:43:2c:2f:10:e1:7c:c4:5c:cc:ac:50:50:22:42:aa:
+ 35:33:f5:b9:f3:a6:66:55:d9:36:f4:f2:e4:d4:d9:b5:2c:52:
+ 66:d4:21:17:97:22:b8:9b:d7:0e:7c:3d:ce:85:19:ca:c4:d2:
+ 58:62:31:c6:18:3e:44:fc:f4:30:b6:95:87:ee:21:4a:08:f0:
+ af:3c:8f:c4:ba:5e:a1:5c:37:1a:7d:7b:fe:66:ae:62:50:17:
+ 31:ca
+-----BEGIN CERTIFICATE-----
+MIICnzCCAgigAwIBAgIJAOTeAXbEeHh/MA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV
+BAYTAlVLMQowCAYDVQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcg
+T05MWSwgRE8gTk8gVFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTAeFw0x
+MjA1MDMxMTExMTRaFw0yMzA1MDExMTExMTRaMEQxLjAsBgNVBAoMJXUxZGIgTE9D
+QUwgVEVTVElORyBPTkxZLCBETyBOT1QgVFJVU1QxEjAQBgNVBAMMCWxvY2FsaG9z
+dDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAxh1y08Xk/NFM2eQIPpAQzj8f
+h0odT38qWlLJZU/ZLL9pdRgatckJMgBH9WCqxt06hzdfFr7eKbXq/EF+63e732PD
+Bh7t6aBnGvHs4Z33nI8c+sNmeznccK4JG5zAmsSQd0WOOZWpL5JDvScHWplRbnag
+r92xLI/Ki4xHDfZu/GkCAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0E
+HxYdT3BlblNTTCBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFBxjheEd
+84kubE4/+9AQZFrBImoqMB8GA1UdIwQYMBaAFNs9k1FsMhVUjxBQ/ElPNhUou5Vt
+MA0GCSqGSIb3DQEBBQUAA4GBAB1tPr2T/b0+F7if8Jl/21BcsgFCA7XVlAXT9o6A
+glVHH1jyGGyr70MsLxDhfMRczKxQUCJCqjUz9bnzpmZV2Tb08uTU2bUsUmbUIReX
+Irib1w58Pc6FGcrE0lhiMcYYPkT89DC2lYfuIUoI8K88j8S6XqFcNxp9e/5mrmJQ
+FzHK
+-----END CERTIFICATE-----
diff --git a/u1db/tests/testing-certs/testing.key b/u1db/tests/testing-certs/testing.key
new file mode 100644
index 00000000..d83d4920
--- /dev/null
+++ b/u1db/tests/testing-certs/testing.key
@@ -0,0 +1,16 @@
+-----BEGIN PRIVATE KEY-----
+MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMYdctPF5PzRTNnk
+CD6QEM4/H4dKHU9/KlpSyWVP2Sy/aXUYGrXJCTIAR/VgqsbdOoc3Xxa+3im16vxB
+fut3u99jwwYe7emgZxrx7OGd95yPHPrDZns53HCuCRucwJrEkHdFjjmVqS+SQ70n
+B1qZUW52oK/dsSyPyouMRw32bvxpAgMBAAECgYBs3lXxhjg1rhabTjIxnx19GTcM
+M3Az9V+izweZQu3HJ1CeZiaXauhAr+LbNsniCkRVddotN6oCJdQB10QVxXBZc9Jz
+HPJ4zxtZfRZlNMTMmG7eLWrfxpgWnb/BUjDb40yy1nhr9yhDUnI/8RoHDRHnAEHZ
+/CnHGUrqcVcrY5zJAQJBAPLhBJg9W88JVmcOKdWxRgs7dLHnZb999Kv1V5mczmAi
+jvGvbUmucqOqke6pTUHNYyNHqU6pySzGUi2cH+BAkFECQQDQ0VoAOysg6FVoT15v
+tGh57t5sTiCZZ7PS8jwvtThsgA+vcf6c16XWzXgjGXSap4r2QDOY2rI5lsWLaQ8T
++fyZAkAfyFJRmbXp4c7srW3MCOahkaYzoZQu+syJtBFCiMJ40gzik5I5khpuUGPI
+V19EvRu8AiSlppIsycb3MPb64XgBAkEAy7DrUf5le5wmc7G4NM6OeyJ+5LbxJbL6
+vnJ8My1a9LuWkVVpQCU7J+UVo2dZTuLPspW9vwTVhUeFOxAoHRxlQQJAFem93f7m
+el2BkB2EFqU3onPejkZ5UrDmfmeOQR1axMQNSXqSxcJxqa16Ru1BWV2gcWRbwajQ
+oc+kuJThu/r/Ug==
+-----END PRIVATE KEY-----
diff --git a/u1db/vectorclock.py b/u1db/vectorclock.py
new file mode 100644
index 00000000..42bceaa8
--- /dev/null
+++ b/u1db/vectorclock.py
@@ -0,0 +1,89 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""VectorClockRev helper class."""
+
+
+class VectorClockRev(object):
+ """Track vector clocks for multiple replica ids.
+
+ This allows simple comparison to determine if one VectorClockRev is
+ newer/older/in-conflict-with another VectorClockRev without having to
+ examine history. Every replica has a strictly increasing revision. When
+ creating a new revision, they include all revisions for all other replicas
+ which the new revision dominates, and increment their own revision to
+ something greater than the current value.
+ """
+
+ def __init__(self, value):
+ self._values = self._expand(value)
+
+ def __repr__(self):
+ s = self.as_str()
+ return '%s(%s)' % (self.__class__.__name__, s)
+
+ def as_str(self):
+ s = '|'.join(['%s:%d' % (m, r) for m, r
+ in sorted(self._values.items())])
+ return s
+
+ def _expand(self, value):
+ result = {}
+ if value is None:
+ return result
+ for replica_info in value.split('|'):
+ replica_uid, counter = replica_info.split(':')
+ counter = int(counter)
+ result[replica_uid] = counter
+ return result
+
+ def is_newer(self, other):
+ """Is this VectorClockRev strictly newer than other.
+ """
+ if not self._values:
+ return False
+ if not other._values:
+ return True
+ this_is_newer = False
+ other_expand = dict(other._values)
+ for key, value in self._values.iteritems():
+ if key in other_expand:
+ other_value = other_expand.pop(key)
+ if other_value > value:
+ return False
+ elif other_value < value:
+ this_is_newer = True
+ else:
+ this_is_newer = True
+ if other_expand:
+ return False
+ return this_is_newer
+
+ def increment(self, replica_uid):
+ """Increase the 'replica_uid' section of this vector clock.
+
+ :return: A string representing the new vector clock value
+ """
+ self._values[replica_uid] = self._values.get(replica_uid, 0) + 1
+
+ def maximize(self, other_vcr):
+ for replica_uid, counter in other_vcr._values.iteritems():
+ if replica_uid not in self._values:
+ self._values[replica_uid] = counter
+ else:
+ this_counter = self._values[replica_uid]
+ if this_counter < counter:
+ self._values[replica_uid] = counter