From 435728098255a8bdc89a9191d2b774a911fa59a0 Mon Sep 17 00:00:00 2001 From: drebs Date: Tue, 11 Dec 2012 15:03:12 -0200 Subject: Fix SQLCipherDatabase and add tests. --- backends/sqlcipher.py | 5 +-- tests/__init__.py | 55 ++++++++++++++++++++++++++++++++ tests/test_sqlcipher.py | 84 ++++++++++++++++++++++++++----------------------- 3 files changed, 102 insertions(+), 42 deletions(-) diff --git a/backends/sqlcipher.py b/backends/sqlcipher.py index fcdab251..301d4a7f 100644 --- a/backends/sqlcipher.py +++ b/backends/sqlcipher.py @@ -60,7 +60,8 @@ def open(path, create, document_factory=None, password=None): class SQLCipherDatabase(SQLitePartialExpandDatabase): """A U1DB implementation that uses SQLCipher as its persistence layer.""" - _sqlite_registry = {} + _index_storage_value = 'expand referenced encrypted' + @classmethod def set_pragma_key(cls, db_handle, key): @@ -113,7 +114,7 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase): raise if backend_cls is None: # default is SQLCipherPartialExpandDatabase - backend_cls = SQLCipherPartialExpandDatabase + backend_cls = SQLCipherDatabase return backend_cls(sqlite_file, document_factory=document_factory, password=password) diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..7918b265 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,55 @@ +import unittest2 as unittest +import tempfile +import shutil + +class TestCase(unittest.TestCase): + + def createTempDir(self, prefix='u1db-tmp-'): + """Create a temporary directory to do some work in. + + This directory will be scheduled for cleanup when the test ends. + """ + tempdir = tempfile.mkdtemp(prefix=prefix) + self.addCleanup(shutil.rmtree, tempdir) + return tempdir + + def make_document(self, doc_id, doc_rev, content, has_conflicts=False): + return self.make_document_for_test( + self, doc_id, doc_rev, content, has_conflicts) + + def make_document_for_test(self, test, doc_id, doc_rev, content, + has_conflicts): + return make_document_for_test( + test, doc_id, doc_rev, content, has_conflicts) + + def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id)) + + def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content, + has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True)) + + def assertGetDocConflicts(self, db, doc_id, conflicts): + """Assert what conflicts are stored for a given doc_id. + + :param conflicts: A list of (doc_rev, content) pairs. + The first item must match the first item returned from the + database, however the rest can be returned in any order. + """ + if conflicts: + conflicts = [(rev, (json.loads(cont) if isinstance(cont, basestring) + else cont)) for (rev, cont) in conflicts] + conflicts = conflicts[:1] + sorted(conflicts[1:]) + actual = db.get_doc_conflicts(doc_id) + if actual: + actual = [(doc.rev, (json.loads(doc.get_json()) + if doc.get_json() is not None else None)) for doc in actual] + actual = actual[:1] + sorted(actual[1:]) + self.assertEqual(conflicts, actual) + diff --git a/tests/test_sqlcipher.py b/tests/test_sqlcipher.py index 46f27f73..e35a6d90 100644 --- a/tests/test_sqlcipher.py +++ b/tests/test_sqlcipher.py @@ -19,16 +19,17 @@ import os import time import threading +import unittest2 as unittest from sqlite3 import dbapi2 from u1db import ( errors, - tests, query_parser, ) -from u1db.backends import sqlite_backend -from u1db.tests.test_backends import TestAlternativeDocument +from soledad.backends import sqlcipher +from soledad.backends.leap import LeapDocument +from soledad import tests simple_doc = '{"key": "value"}' @@ -43,7 +44,7 @@ class TestSQLiteDatabase(tests.TestCase): t2 = None # will be a thread - class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase): _index_storage_value = "testing" def __init__(self, dbname, ntry): @@ -84,7 +85,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): def setUp(self): super(TestSQLitePartialExpandDatabase, self).setUp() - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:') self.db._set_replica_uid('test') def test_create_database(self): @@ -92,7 +93,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): self.assertNotEqual(None, raw_db) def test_default_replica_uid(self): - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:') self.assertIsNot(None, self.db._replica_uid) self.assertEqual(32, len(self.db._replica_uid)) int(self.db._replica_uid, 16) @@ -109,7 +110,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): c.execute("SELECT * FROM u1db_config") config = dict([(r[0], r[1]) for r in c.fetchall()]) self.assertEqual({'sql_schema': '0', 'replica_uid': 'test', - 'index_storage': 'expand referenced'}, config) + 'index_storage': 'expand referenced encrypted'}, config) # These tables must exist, though we don't care what is in them yet c.execute("SELECT * FROM transaction_log") @@ -120,13 +121,13 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): c.execute("SELECT * FROM index_definitions") def test__parse_index(self): - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:') g = self.db._parse_index_definition('fieldname') self.assertIsInstance(g, query_parser.ExtractField) self.assertEqual(['fieldname'], g.field) def test__update_indexes(self): - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:') g = self.db._parse_index_definition('fieldname') c = self.db._get_sqlite_handle().cursor() self.db._update_indexes('doc-id', {'fieldname': 'val'}, @@ -137,7 +138,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): def test__set_replica_uid(self): # Start from scratch, so that replica_uid isn't set. - self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db = sqlcipher.SQLCipherDatabase(':memory:') self.assertIsNot(None, self.db._real_replica_uid) self.assertIsNot(None, self.db._replica_uid) self.db._set_replica_uid('foo') @@ -239,7 +240,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): path = temp_dir + '/rollback.db' class SQLitePartialExpandDbTesting( - sqlite_backend.SQLitePartialExpandDatabase): + sqlcipher.SQLCipherDatabase): def _set_replica_uid_in_transaction(self, uid): super(SQLitePartialExpandDbTesting, @@ -257,34 +258,34 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): def test__open_database(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/test.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = sqlite_backend.SQLiteDatabase._open_database(path) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + sqlcipher.SQLCipherDatabase(path) + db2 = sqlcipher.SQLCipherDatabase._open_database(path) + self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) def test__open_database_with_factory(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/test.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = sqlite_backend.SQLiteDatabase._open_database( - path, document_factory=TestAlternativeDocument) - self.assertEqual(TestAlternativeDocument, db2._factory) + sqlcipher.SQLCipherDatabase(path) + db2 = sqlcipher.SQLCipherDatabase._open_database( + path, document_factory=LeapDocument) + self.assertEqual(LeapDocument, db2._factory) def test__open_database_non_existent(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/non-existent.sqlite' self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase._open_database, path) + sqlcipher.SQLCipherDatabase._open_database, path) def test__open_database_during_init(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/initialised.db' - db = sqlite_backend.SQLitePartialExpandDatabase.__new__( - sqlite_backend.SQLitePartialExpandDatabase) + db = sqlcipher.SQLCipherDatabase.__new__( + sqlcipher.SQLCipherDatabase) db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed self.addCleanup(db.close) observed = [] - class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase): WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 @classmethod @@ -296,13 +297,13 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): db2 = SQLiteDatabaseTesting._open_database(path) self.addCleanup(db2.close) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) self.assertEqual([None, - sqlite_backend.SQLitePartialExpandDatabase._index_storage_value], + sqlcipher.SQLCipherDatabase._index_storage_value], observed) def test__open_database_invalid(self): - class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase): WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 temp_dir = self.createTempDir(prefix='u1db-test-') path1 = temp_dir + '/invalid1.db' @@ -318,47 +319,47 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): def test_open_database_existing(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/existing.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + sqlcipher.SQLCipherDatabase(path) + db2 = sqlcipher.SQLCipherDatabase.open_database(path, create=False) + self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) def test_open_database_with_factory(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/existing.sqlite' - sqlite_backend.SQLitePartialExpandDatabase(path) - db2 = sqlite_backend.SQLiteDatabase.open_database( - path, create=False, document_factory=TestAlternativeDocument) - self.assertEqual(TestAlternativeDocument, db2._factory) + sqlcipher.SQLCipherDatabase(path) + db2 = sqlcipher.SQLCipherDatabase.open_database( + path, create=False, document_factory=LeapDocument) + self.assertEqual(LeapDocument, db2._factory) def test_open_database_create(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/new.sqlite' - sqlite_backend.SQLiteDatabase.open_database(path, create=True) - db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) - self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + sqlcipher.SQLCipherDatabase.open_database(path, create=True) + db2 = sqlcipher.SQLCipherDatabase.open_database(path, create=False) + self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase) def test_open_database_non_existent(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/non-existent.sqlite' self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase.open_database, path, + sqlcipher.SQLCipherDatabase.open_database, path, create=False) def test_delete_database_existent(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/new.sqlite' - db = sqlite_backend.SQLiteDatabase.open_database(path, create=True) + db = sqlcipher.SQLCipherDatabase.open_database(path, create=True) db.close() - sqlite_backend.SQLiteDatabase.delete_database(path) + sqlcipher.SQLCipherDatabase.delete_database(path) self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase.open_database, path, + sqlcipher.SQLCipherDatabase.open_database, path, create=False) def test_delete_database_nonexistent(self): temp_dir = self.createTempDir(prefix='u1db-test-') path = temp_dir + '/non-existent.sqlite' self.assertRaises(errors.DatabaseDoesNotExist, - sqlite_backend.SQLiteDatabase.delete_database, path) + sqlcipher.SQLCipherDatabase.delete_database, path) def test__get_indexed_fields(self): self.db.create_index('idx1', 'a', 'b') @@ -492,3 +493,6 @@ class TestSQLitePartialExpandDatabase(tests.TestCase): 'key3'], ["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"]) + +if __name__ == '__main__': + unittest.main() -- cgit v1.2.3