# Copyright 2011-2012 Canonical Ltd.
#
# This file is part of u1db.
#
# u1db is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# u1db is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with u1db.  If not, see <http://www.gnu.org/licenses/>.

"""Test infrastructure for U1DB"""

import copy
import shutil
import socket
import tempfile
import threading

try:
    import simplejson as json
except ImportError:
    import json  # noqa

from wsgiref import simple_server

from oauth import oauth
from pysqlcipher import dbapi2
from StringIO import StringIO

import testscenarios
import testtools

from u1db import (
    errors,
    Document,
    )
from u1db.backends import (
    inmemory,
    sqlite_backend,
    )
from u1db.remote import (
    server_state,
    )


class TestCase(testtools.TestCase):
    """Base test case with temp-dir and document assertion helpers."""

    def createTempDir(self, prefix='u1db-tmp-'):
        """Create a temporary directory to do some work in.

        This directory will be scheduled for cleanup when the test ends.

        :param prefix: Prefix passed to tempfile.mkdtemp.
        :return: The path of the newly created directory.
        """
        tempdir = tempfile.mkdtemp(prefix=prefix)
        self.addCleanup(shutil.rmtree, tempdir)
        return tempdir

    def make_document(self, doc_id, doc_rev, content, has_conflicts=False):
        """Build a document via the make_document_for_test hook.

        The test case itself is passed as the hook's ``test`` argument.
        """
        return self.make_document_for_test(
            self, doc_id, doc_rev, content, has_conflicts)

    def make_document_for_test(self, test, doc_id, doc_rev, content,
                               has_conflicts):
        """Default hook: delegate to the module-level factory."""
        return make_document_for_test(
            test, doc_id, doc_rev, content, has_conflicts)

    def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts):
        """Assert that the document in the database looks correct."""
        exp_doc = self.make_document(doc_id, doc_rev, content,
                                     has_conflicts=has_conflicts)
        self.assertEqual(exp_doc, db.get_doc(doc_id))

    def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content,
                                   has_conflicts):
        """Assert that the document in the database looks correct.

        Same as assertGetDoc, but the document is fetched with
        include_deleted=True.
        """
        exp_doc = self.make_document(doc_id, doc_rev, content,
                                     has_conflicts=has_conflicts)
        self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True))

    def assertGetDocConflicts(self, db, doc_id, conflicts):
        """Assert what conflicts are stored for a given doc_id.

        :param conflicts: A list of (doc_rev, content) pairs.
            The first item must match the first item returned from the
            database, however the rest can be returned in any order.
        """
        if conflicts:
            # Content given as a JSON string is decoded so that it is
            # compared structurally rather than textually.
            conflicts = [(rev, (json.loads(cont)
                                if isinstance(cont, basestring) else cont))
                         for (rev, cont) in conflicts]
            conflicts = conflicts[:1] + sorted(conflicts[1:])
        actual = db.get_doc_conflicts(doc_id)
        if actual:
            loaded = []
            for doc in actual:
                raw = doc.get_json()
                loaded.append(
                    (doc.rev, json.loads(raw) if raw is not None else None))
            actual = loaded[:1] + sorted(loaded[1:])
        self.assertEqual(conflicts, actual)


def multiply_scenarios(a_scenarios, b_scenarios):
    """Create the cross-product of scenarios.

    Each resulting scenario is named '<a_name>,<b_name>'; its attributes are
    those of the ``a`` scenario updated with those of the ``b`` scenario.
    """
    all_scenarios = []
    for a_name, a_attrs in a_scenarios:
        for b_name, b_attrs in b_scenarios:
            name = '%s,%s' % (a_name, b_name)
            attrs = dict(a_attrs)
            attrs.update(b_attrs)
            all_scenarios.append((name, attrs))
    return all_scenarios


simple_doc = '{"key": "value"}'
nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}'


def make_memory_database_for_test(test, replica_uid):
    """Return a new in-memory database with the given replica uid.

    ``test`` is not used by this implementation.
    """
    return inmemory.InMemoryDatabase(replica_uid)


def copy_memory_database_for_test(test, db):
    """Return a copy of the in-memory database ``db`` (tests only)."""
    # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
    # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
    # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
    # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
    # HOUSE.
    new_db = inmemory.InMemoryDatabase(db._replica_uid)
    new_db._transaction_log = db._transaction_log[:]
    new_db._docs = copy.deepcopy(db._docs)
    new_db._conflicts = copy.deepcopy(db._conflicts)
    new_db._indexes = copy.deepcopy(db._indexes)
    new_db._factory = db._factory
    return new_db


def make_sqlite_partial_expanded_for_test(test, replica_uid):
    """Return a new ':memory:' SQLitePartialExpandDatabase.

    ``test`` is not used by this implementation.
    """
    db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
    db._set_replica_uid(replica_uid)
    return db


def copy_sqlite_partial_expanded_for_test(test, db):
    """Return a copy of the SQLite database ``db`` (tests only).

    The copy is made by replaying the SQL dump of ``db`` into a fresh
    ':memory:' connection.
    """
    # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
    # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
    # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
    # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
    # HOUSE.
    new_db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
    tmpfile = StringIO()
    for line in db._db_handle.iterdump():
        if 'sqlite_sequence' not in line:  # work around bug in iterdump
            tmpfile.write('%s\n' % line)
    tmpfile.seek(0)
    new_db._db_handle = dbapi2.connect(':memory:')
    new_db._db_handle.cursor().executescript(tmpfile.read())
    new_db._db_handle.commit()
    new_db._set_replica_uid(db._replica_uid)
    new_db._factory = db._factory
    return new_db


def make_document_for_test(test, doc_id, rev, content, has_conflicts=False):
    """Return a Document built from the given fields.

    ``test`` is not used by this implementation.
    """
    return Document(doc_id, rev, content, has_conflicts=has_conflicts)


LOCAL_DATABASES_SCENARIOS = [
    ('mem', {'make_database_for_test': make_memory_database_for_test,
             'copy_database_for_test': copy_memory_database_for_test,
             'make_document_for_test': make_document_for_test}),
    ('sql', {'make_database_for_test':
             make_sqlite_partial_expanded_for_test,
             'copy_database_for_test':
             copy_sqlite_partial_expanded_for_test,
             'make_document_for_test': make_document_for_test}),
    ]


class DatabaseBaseTests(TestCase):
    """Base class for tests that run against each local database scenario.

    setUp creates ``self.db`` through the make_database_for_test hook.
    """

    # Set to True to make assertTransactionLog accept '' as a transaction
    # id; such entries then skip the 'T-' prefix and uniqueness checks.
    accept_fixed_trans_id = False

    scenarios = LOCAL_DATABASES_SCENARIOS

    def create_database(self, replica_uid):
        """Create a database through the make_database_for_test hook."""
        return self.make_database_for_test(self, replica_uid)

    def copy_database(self, db):
        """Copy ``db`` through the copy_database_for_test hook."""
        # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
        # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
        # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
        # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
        # NINJA TO YOUR HOUSE.
        return self.copy_database_for_test(self, db)

    def setUp(self):
        super(DatabaseBaseTests, self).setUp()
        self.db = self.create_database('test')

    def tearDown(self):
        # TODO: Add close_database parameterization
        # self.close_database(self.db)
        super(DatabaseBaseTests, self).tearDown()

    def assertTransactionLog(self, doc_ids, db):
        """Assert that the given docs are in the transaction log.

        :param doc_ids: The expected doc ids, in transaction-log order.
        :param db: The database whose transaction log is inspected.
        """
        log = db._get_transaction_log()
        just_ids = []
        seen_transactions = set()
        for doc_id, transaction_id in log:
            just_ids.append(doc_id)
            self.assertIsNot(None, transaction_id,
                             "Transaction id should not be None")
            if transaction_id == '' and self.accept_fixed_trans_id:
                continue
            self.assertNotEqual('', transaction_id,
                                "Transaction id should be a unique string")
            self.assertTrue(transaction_id.startswith('T-'))
            self.assertNotIn(transaction_id, seen_transactions)
            seen_transactions.add(transaction_id)
        self.assertEqual(doc_ids, just_ids)

    def getLastTransId(self, db):
        """Return the transaction id for the last database update.

        :param db: The database whose transaction log is read.
        :raises IndexError: If the transaction log of ``db`` is empty.
        """
        # Read the log of the database that was passed in, not self.db, so
        # the helper also works for copies and other databases.
        return db._get_transaction_log()[-1][-1]


class ServerStateForTests(server_state.ServerState):
    """Used in the test suite, so we don't have to touch disk, etc."""

    def __init__(self):
        super(ServerStateForTests, self).__init__()
        # Maps path -> in-memory database.
        self._dbs = {}

    def open_database(self, path):
        """Return the database stored under ``path``.

        :raises errors.DatabaseDoesNotExist: If no database is stored there.
        """
        try:
            return self._dbs[path]
        except KeyError:
            raise errors.DatabaseDoesNotExist

    def check_database(self, path):
        """Raise errors.DatabaseDoesNotExist if ``path`` is unknown."""
        # cares only about the possible exception
        self.open_database(path)

    def ensure_database(self, path):
        """Open the database at ``path``, creating it if it is missing.

        :return: A (db, replica_uid) tuple.
        """
        try:
            db = self.open_database(path)
        except errors.DatabaseDoesNotExist:
            db = self._create_database(path)
        return db, db._replica_uid

    def _copy_database(self, db):
        """Copy ``db`` and store the copy under a fresh key.

        The key starts as the replica uid of ``db`` and gets 'copy' appended
        until it does not collide with an existing entry.
        """
        # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
        # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
        # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
        # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
        # NINJA TO YOUR HOUSE.
        new_db = copy_memory_database_for_test(None, db)
        path = db._replica_uid
        while path in self._dbs:
            path += 'copy'
        self._dbs[path] = new_db
        return new_db

    def _create_database(self, path):
        """Create an in-memory database and store it under ``path``.

        ``path`` is also used as the replica uid of the new database.
        """
        db = inmemory.InMemoryDatabase(path)
        self._dbs[path] = db
        return db

    def delete_database(self, path):
        """Forget the database stored under ``path``.

        :raises KeyError: If no database is stored under ``path``.
        """
        del self._dbs[path]


class ResponderForTests(object):
    """Responder for tests.

    Records the status and keyword arguments of the started response, and
    whether the response was finished.
    """

    _started = False
    sent_response = False
    status = None

    def start_response(self, status='success', **kwargs):
        self._started = True
        self.status = status
        self.kwargs = kwargs

    def send_response(self, status='success', **kwargs):
        self.start_response(status, **kwargs)
        self.finish_response()

    def finish_response(self):
        self.sent_response = True


class TestCaseWithServer(TestCase):
    """Test case that can serve a WSGI application from a local thread.

    Subclasses are expected to override make_app_with_state (or make_app);
    the default make_app_with_state returns None, which the default server
    factory rejects.
    """

    @staticmethod
    def server_def():
        # hook point
        # should return (ServerClass, "shutdown method name", "url_scheme")
        class _RequestHandler(simple_server.WSGIRequestHandler):
            def log_request(*args):
                pass  # suppress

        def make_server(host_port, application):
            assert application, "forgot to override make_app(_with_state)?"
            srv = simple_server.WSGIServer(host_port, _RequestHandler)
            # patch the value in if it's None
            if getattr(application, 'base_url', 1) is None:
                application.base_url = "http://%s:%s" % srv.server_address
            srv.set_app(application)
            return srv

        return make_server, "shutdown", "http"

    @staticmethod
    def make_app_with_state(state):
        # hook point
        return None

    def make_app(self):
        # potential hook point
        self.request_state = ServerStateForTests()
        return self.make_app_with_state(self.request_state)

    def setUp(self):
        super(TestCaseWithServer, self).setUp()
        self.server = self.server_thread = None

    @property
    def url_scheme(self):
        """The URL scheme, i.e. the last item returned by server_def()."""
        return self.server_def()[-1]

    def startServer(self):
        """Start the server from server_def() on 127.0.0.1, in a thread.

        Port 0 is requested, so the actual port has to be read back from
        the server (see getURL).
        """
        server_def = self.server_def()
        server_class, shutdown_meth, _ = server_def
        application = self.make_app()
        self.server = server_class(('127.0.0.1', 0), application)
        self.server_thread = threading.Thread(
            target=self.server.serve_forever,
            kwargs=dict(poll_interval=0.01))
        self.server_thread.start()
        # Cleanups run in reverse order of registration: the server is shut
        # down first, then the serving thread is joined.
        self.addCleanup(self.server_thread.join)
        self.addCleanup(getattr(self.server, shutdown_meth))

    def getURL(self, path=None):
        """Return the URL of ``path`` on the started server.

        startServer() must have been called first.
        """
        host, port = self.server.server_address
        if path is None:
            path = ''
        return '%s://%s:%s/%s' % (self.url_scheme, host, port, path)


def socket_pair():
    """Return a pair of TCP sockets connected to each other.

    Unlike socket.socketpair, this should work on Windows.

    :return: A (server_sock, client_sock) tuple.
    """
    # NOTE(review): the stdlib function is named ``socket.socketpair``;
    # ``socket_pair`` is not a stdlib attribute, so this branch is normally
    # skipped. Confirm the intent before renaming it.
    sock_pair = getattr(socket, 'socket_pair', None)
    if sock_pair:
        return sock_pair(socket.AF_INET, socket.SOCK_STREAM)
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        listen_sock.bind(('127.0.0.1', 0))
        listen_sock.listen(1)
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            client_sock.connect(listen_sock.getsockname())
            server_sock, _ = listen_sock.accept()
        except Exception:
            # Do not leak the client socket if the pair cannot be completed.
            client_sock.close()
            raise
    finally:
        # The listening socket is only needed to establish the connection.
        listen_sock.close()
    return server_sock, client_sock


# OAuth related testing

consumer1 = oauth.OAuthConsumer('K1', 'S1')
token1 = oauth.OAuthToken('kkkk1', 'XYZ')
consumer2 = oauth.OAuthConsumer('K2', 'S2')
token2 = oauth.OAuthToken('kkkk2', 'ZYX')
token3 = oauth.OAuthToken('kkkk3', 'ZYX')


class TestingOAuthDataStore(oauth.OAuthDataStore):
    """In memory predefined OAuthDataStore for testing."""

    consumers = {
        consumer1.key: consumer1,
        consumer2.key: consumer2,
        }

    # token3 is deliberately absent here, so looking it up returns None.
    tokens = {
        token1.key: token1,
        token2.key: token2
        }

    def lookup_consumer(self, key):
        return self.consumers.get(key)

    def lookup_token(self, token_type, token_token):
        return self.tokens.get(token_token)

    def lookup_nonce(self, oauth_consumer, oauth_token, nonce):
        return None


testingOAuthStore = TestingOAuthDataStore()

sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1()
sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT()


def load_with_scenarios(loader, standard_tests, pattern):
    """Load the tests in a given module.

    This just applies testscenarios.generate_scenarios to all the tests that
    are present. We do it at load time rather than at run time, because it
    plays nicer with various tools.
    """
    suite = loader.suiteClass()
    suite.addTests(testscenarios.generate_scenarios(standard_tests))
    return suite