summaryrefslogtreecommitdiff
path: root/tests/u1db_tests/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/u1db_tests/__init__.py')
-rw-r--r--tests/u1db_tests/__init__.py421
1 files changed, 0 insertions, 421 deletions
diff --git a/tests/u1db_tests/__init__.py b/tests/u1db_tests/__init__.py
deleted file mode 100644
index 43304b43..00000000
--- a/tests/u1db_tests/__init__.py
+++ /dev/null
@@ -1,421 +0,0 @@
-# Copyright 2011-2012 Canonical Ltd.
-#
-# This file is part of u1db.
-#
-# u1db is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License version 3
-# as published by the Free Software Foundation.
-#
-# u1db is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Lesser General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with u1db. If not, see <http://www.gnu.org/licenses/>.
-
-"""Test infrastructure for U1DB"""
-
-import copy
-import shutil
-import socket
-import tempfile
-import threading
-
-try:
- import simplejson as json
-except ImportError:
- import json # noqa
-
-from wsgiref import simple_server
-
-from oauth import oauth
-from pysqlcipher import dbapi2
-from StringIO import StringIO
-
-import testscenarios
-import testtools
-
-from u1db import (
- errors,
- Document,
-)
-from u1db.backends import (
- inmemory,
- sqlite_backend,
-)
-from u1db.remote import (
- server_state,
-)
-
-
-class TestCase(testtools.TestCase):
-
- def createTempDir(self, prefix='u1db-tmp-'):
- """Create a temporary directory to do some work in.
-
- This directory will be scheduled for cleanup when the test ends.
- """
- tempdir = tempfile.mkdtemp(prefix=prefix)
- self.addCleanup(shutil.rmtree, tempdir)
- return tempdir
-
- def make_document(self, doc_id, doc_rev, content, has_conflicts=False):
- return self.make_document_for_test(
- self, doc_id, doc_rev, content, has_conflicts)
-
- def make_document_for_test(self, test, doc_id, doc_rev, content,
- has_conflicts):
- return make_document_for_test(
- test, doc_id, doc_rev, content, has_conflicts)
-
- def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts):
- """Assert that the document in the database looks correct."""
- exp_doc = self.make_document(doc_id, doc_rev, content,
- has_conflicts=has_conflicts)
- self.assertEqual(exp_doc, db.get_doc(doc_id))
-
- def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content,
- has_conflicts):
- """Assert that the document in the database looks correct."""
- exp_doc = self.make_document(doc_id, doc_rev, content,
- has_conflicts=has_conflicts)
- self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True))
-
- def assertGetDocConflicts(self, db, doc_id, conflicts):
- """Assert what conflicts are stored for a given doc_id.
-
- :param conflicts: A list of (doc_rev, content) pairs.
- The first item must match the first item returned from the
- database, however the rest can be returned in any order.
- """
- if conflicts:
- conflicts = [(rev,
- (json.loads(cont) if isinstance(cont, basestring)
- else cont)) for (rev, cont) in conflicts]
- conflicts = conflicts[:1] + sorted(conflicts[1:])
- actual = db.get_doc_conflicts(doc_id)
- if actual:
- actual = [
- (doc.rev, (json.loads(doc.get_json())
- if doc.get_json() is not None else None))
- for doc in actual]
- actual = actual[:1] + sorted(actual[1:])
- self.assertEqual(conflicts, actual)
-
-
-def multiply_scenarios(a_scenarios, b_scenarios):
- """Create the cross-product of scenarios."""
-
- all_scenarios = []
- for a_name, a_attrs in a_scenarios:
- for b_name, b_attrs in b_scenarios:
- name = '%s,%s' % (a_name, b_name)
- attrs = dict(a_attrs)
- attrs.update(b_attrs)
- all_scenarios.append((name, attrs))
- return all_scenarios
-
-
-simple_doc = '{"key": "value"}'
-nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}'
-
-
-def make_memory_database_for_test(test, replica_uid):
- return inmemory.InMemoryDatabase(replica_uid)
-
-
-def copy_memory_database_for_test(test, db):
- # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
- # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
- # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
- # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
- # HOUSE.
- new_db = inmemory.InMemoryDatabase(db._replica_uid)
- new_db._transaction_log = db._transaction_log[:]
- new_db._docs = copy.deepcopy(db._docs)
- new_db._conflicts = copy.deepcopy(db._conflicts)
- new_db._indexes = copy.deepcopy(db._indexes)
- new_db._factory = db._factory
- return new_db
-
-
-def make_sqlite_partial_expanded_for_test(test, replica_uid):
- db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
- db._set_replica_uid(replica_uid)
- return db
-
-
-def copy_sqlite_partial_expanded_for_test(test, db):
- # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
- # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
- # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
- # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
- # HOUSE.
- new_db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
- tmpfile = StringIO()
- for line in db._db_handle.iterdump():
- if not 'sqlite_sequence' in line: # work around bug in iterdump
- tmpfile.write('%s\n' % line)
- tmpfile.seek(0)
- new_db._db_handle = dbapi2.connect(':memory:')
- new_db._db_handle.cursor().executescript(tmpfile.read())
- new_db._db_handle.commit()
- new_db._set_replica_uid(db._replica_uid)
- new_db._factory = db._factory
- return new_db
-
-
-def make_document_for_test(test, doc_id, rev, content, has_conflicts=False):
- return Document(doc_id, rev, content, has_conflicts=has_conflicts)
-
-
-LOCAL_DATABASES_SCENARIOS = [
- ('mem', {'make_database_for_test': make_memory_database_for_test,
- 'copy_database_for_test': copy_memory_database_for_test,
- 'make_document_for_test': make_document_for_test}),
- ('sql', {'make_database_for_test':
- make_sqlite_partial_expanded_for_test,
- 'copy_database_for_test':
- copy_sqlite_partial_expanded_for_test,
- 'make_document_for_test': make_document_for_test}),
-]
-
-
-class DatabaseBaseTests(TestCase):
-
- accept_fixed_trans_id = False # set to True assertTransactionLog
- # is happy with all trans ids = ''
-
- scenarios = LOCAL_DATABASES_SCENARIOS
-
- def create_database(self, replica_uid):
- return self.make_database_for_test(self, replica_uid)
-
- def copy_database(self, db):
- # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
- # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
- # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
- # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
- # NINJA TO YOUR HOUSE.
- return self.copy_database_for_test(self, db)
-
- def setUp(self):
- super(DatabaseBaseTests, self).setUp()
- self.db = self.create_database('test')
-
- def tearDown(self):
- # TODO: Add close_database parameterization
- # self.close_database(self.db)
- super(DatabaseBaseTests, self).tearDown()
-
- def assertTransactionLog(self, doc_ids, db):
- """Assert that the given docs are in the transaction log."""
- log = db._get_transaction_log()
- just_ids = []
- seen_transactions = set()
- for doc_id, transaction_id in log:
- just_ids.append(doc_id)
- self.assertIsNot(None, transaction_id,
- "Transaction id should not be None")
- if transaction_id == '' and self.accept_fixed_trans_id:
- continue
- self.assertNotEqual('', transaction_id,
- "Transaction id should be a unique string")
- self.assertTrue(transaction_id.startswith('T-'))
- self.assertNotIn(transaction_id, seen_transactions)
- seen_transactions.add(transaction_id)
- self.assertEqual(doc_ids, just_ids)
-
- def getLastTransId(self, db):
- """Return the transaction id for the last database update."""
- return self.db._get_transaction_log()[-1][-1]
-
-
-class ServerStateForTests(server_state.ServerState):
- """Used in the test suite, so we don't have to touch disk, etc."""
-
- def __init__(self):
- super(ServerStateForTests, self).__init__()
- self._dbs = {}
-
- def open_database(self, path):
- try:
- return self._dbs[path]
- except KeyError:
- raise errors.DatabaseDoesNotExist
-
- def check_database(self, path):
- # cares only about the possible exception
- self.open_database(path)
-
- def ensure_database(self, path):
- try:
- db = self.open_database(path)
- except errors.DatabaseDoesNotExist:
- db = self._create_database(path)
- return db, db._replica_uid
-
- def _copy_database(self, db):
- # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
- # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
- # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
- # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
- # NINJA TO YOUR HOUSE.
- new_db = copy_memory_database_for_test(None, db)
- path = db._replica_uid
- while path in self._dbs:
- path += 'copy'
- self._dbs[path] = new_db
- return new_db
-
- def _create_database(self, path):
- db = inmemory.InMemoryDatabase(path)
- self._dbs[path] = db
- return db
-
- def delete_database(self, path):
- del self._dbs[path]
-
-
-class ResponderForTests(object):
- """Responder for tests."""
- _started = False
- sent_response = False
- status = None
-
- def start_response(self, status='success', **kwargs):
- self._started = True
- self.status = status
- self.kwargs = kwargs
-
- def send_response(self, status='success', **kwargs):
- self.start_response(status, **kwargs)
- self.finish_response()
-
- def finish_response(self):
- self.sent_response = True
-
-
-class TestCaseWithServer(TestCase):
-
- @staticmethod
- def server_def():
- # hook point
- # should return (ServerClass, "shutdown method name", "url_scheme")
- class _RequestHandler(simple_server.WSGIRequestHandler):
- def log_request(*args):
- pass # suppress
-
- def make_server(host_port, application):
- assert application, "forgot to override make_app(_with_state)?"
- srv = simple_server.WSGIServer(host_port, _RequestHandler)
- # patch the value in if it's None
- if getattr(application, 'base_url', 1) is None:
- application.base_url = "http://%s:%s" % srv.server_address
- srv.set_app(application)
- return srv
-
- return make_server, "shutdown", "http"
-
- @staticmethod
- def make_app_with_state(state):
- # hook point
- return None
-
- def make_app(self):
- # potential hook point
- self.request_state = ServerStateForTests()
- return self.make_app_with_state(self.request_state)
-
- def setUp(self):
- super(TestCaseWithServer, self).setUp()
- self.server = self.server_thread = None
-
- @property
- def url_scheme(self):
- return self.server_def()[-1]
-
- def startServer(self):
- server_def = self.server_def()
- server_class, shutdown_meth, _ = server_def
- application = self.make_app()
- self.server = server_class(('127.0.0.1', 0), application)
- self.server_thread = threading.Thread(target=self.server.serve_forever,
- kwargs=dict(poll_interval=0.01))
- self.server_thread.start()
- self.addCleanup(self.server_thread.join)
- self.addCleanup(getattr(self.server, shutdown_meth))
-
- def getURL(self, path=None):
- host, port = self.server.server_address
- if path is None:
- path = ''
- return '%s://%s:%s/%s' % (self.url_scheme, host, port, path)
-
-
-def socket_pair():
- """Return a pair of TCP sockets connected to each other.
-
- Unlike socket.socketpair, this should work on Windows.
- """
- sock_pair = getattr(socket, 'socket_pair', None)
- if sock_pair:
- return sock_pair(socket.AF_INET, socket.SOCK_STREAM)
- listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- listen_sock.bind(('127.0.0.1', 0))
- listen_sock.listen(1)
- client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- client_sock.connect(listen_sock.getsockname())
- server_sock, addr = listen_sock.accept()
- listen_sock.close()
- return server_sock, client_sock
-
-
-# OAuth related testing
-
-consumer1 = oauth.OAuthConsumer('K1', 'S1')
-token1 = oauth.OAuthToken('kkkk1', 'XYZ')
-consumer2 = oauth.OAuthConsumer('K2', 'S2')
-token2 = oauth.OAuthToken('kkkk2', 'ZYX')
-token3 = oauth.OAuthToken('kkkk3', 'ZYX')
-
-
-class TestingOAuthDataStore(oauth.OAuthDataStore):
- """In memory predefined OAuthDataStore for testing."""
-
- consumers = {
- consumer1.key: consumer1,
- consumer2.key: consumer2,
- }
-
- tokens = {
- token1.key: token1,
- token2.key: token2
- }
-
- def lookup_consumer(self, key):
- return self.consumers.get(key)
-
- def lookup_token(self, token_type, token_token):
- return self.tokens.get(token_token)
-
- def lookup_nonce(self, oauth_consumer, oauth_token, nonce):
- return None
-
-testingOAuthStore = TestingOAuthDataStore()
-
-sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1()
-sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT()
-
-
-def load_with_scenarios(loader, standard_tests, pattern):
- """Load the tests in a given module.
-
- This just applies testscenarios.generate_scenarios to all the tests that
- are present. We do it at load time rather than at run time, because it
- plays nicer with various tools.
- """
- suite = loader.suiteClass()
- suite.addTests(testscenarios.generate_scenarios(standard_tests))
- return suite