diff options
| author | drebs <drebs@leap.se> | 2012-12-11 15:03:12 -0200 | 
|---|---|---|
| committer | drebs <drebs@leap.se> | 2012-12-11 15:03:12 -0200 | 
| commit | a12b80b23695dd1db8ac5edeb4b79e6ff8e527c2 (patch) | |
| tree | 24369aacaf3c305fcbe9745907579b5493a1c498 | |
| parent | 7823990656ac65982a1322ea049298350fb2185e (diff) | |
Fix SQLCipherDatabase and add tests.
| -rw-r--r-- | src/leap/soledad/backends/sqlcipher.py | 5 | ||||
| -rw-r--r-- | src/leap/soledad/tests/__init__.py | 55 | ||||
| -rw-r--r-- | src/leap/soledad/tests/test_sqlcipher.py | 84 | 
3 files changed, 102 insertions, 42 deletions
| diff --git a/src/leap/soledad/backends/sqlcipher.py b/src/leap/soledad/backends/sqlcipher.py index fcdab251..301d4a7f 100644 --- a/src/leap/soledad/backends/sqlcipher.py +++ b/src/leap/soledad/backends/sqlcipher.py @@ -60,7 +60,8 @@ def open(path, create, document_factory=None, password=None):  class SQLCipherDatabase(SQLitePartialExpandDatabase):      """A U1DB implementation that uses SQLCipher as its persistence layer.""" -    _sqlite_registry = {} +    _index_storage_value = 'expand referenced encrypted' +      @classmethod      def set_pragma_key(cls, db_handle, key): @@ -113,7 +114,7 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase):                  raise              if backend_cls is None:                  # default is SQLCipherPartialExpandDatabase -                backend_cls = SQLCipherPartialExpandDatabase +                backend_cls = SQLCipherDatabase              return backend_cls(sqlite_file, document_factory=document_factory,                                 password=password) diff --git a/src/leap/soledad/tests/__init__.py b/src/leap/soledad/tests/__init__.py index e69de29b..7918b265 100644 --- a/src/leap/soledad/tests/__init__.py +++ b/src/leap/soledad/tests/__init__.py @@ -0,0 +1,55 @@ +import unittest2 as unittest +import tempfile +import shutil + +class TestCase(unittest.TestCase): + +    def createTempDir(self, prefix='u1db-tmp-'): +        """Create a temporary directory to do some work in. + +        This directory will be scheduled for cleanup when the test ends. +        """ +        tempdir = tempfile.mkdtemp(prefix=prefix) +        self.addCleanup(shutil.rmtree, tempdir) +        return tempdir + +    def make_document(self, doc_id, doc_rev, content, has_conflicts=False): +        return self.make_document_for_test( +            self, doc_id, doc_rev, content, has_conflicts) + +    def make_document_for_test(self, test, doc_id, doc_rev, content, +                               has_conflicts): +        return make_document_for_test( +            test, doc_id, doc_rev, content, has_conflicts) + +    def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts): +        """Assert that the document in the database looks correct.""" +        exp_doc = self.make_document(doc_id, doc_rev, content, +                                     has_conflicts=has_conflicts) +        self.assertEqual(exp_doc, db.get_doc(doc_id)) + +    def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content, +                                   has_conflicts): +        """Assert that the document in the database looks correct.""" +        exp_doc = self.make_document(doc_id, doc_rev, content, +                                     has_conflicts=has_conflicts) +        self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True)) + +    def assertGetDocConflicts(self, db, doc_id, conflicts): +        """Assert what conflicts are stored for a given doc_id. + +        :param conflicts: A list of (doc_rev, content) pairs. +            The first item must match the first item returned from the +            database, however the rest can be returned in any order. +        """ +        if conflicts: +            conflicts = [(rev, (json.loads(cont) if isinstance(cont, basestring) +                           else cont)) for (rev, cont) in conflicts] +            conflicts = conflicts[:1] + sorted(conflicts[1:]) +        actual = db.get_doc_conflicts(doc_id) +        if actual: +            actual = [(doc.rev, (json.loads(doc.get_json()) +                   if doc.get_json() is not None else None)) for doc in actual] +            actual = actual[:1] + sorted(actual[1:]) +        self.assertEqual(conflicts, actual) + diff --git a/src/leap/soledad/tests/test_sqlcipher.py b/src/leap/soledad/tests/test_sqlcipher.py index 46f27f73..e35a6d90 100644 --- a/src/leap/soledad/tests/test_sqlcipher.py +++ b/src/leap/soledad/tests/test_sqlcipher.py @@ -19,16 +19,17 @@  import os  import time  import threading +import unittest2 as unittest  from sqlite3 import dbapi2  from u1db import (      errors, -    tests,      query_parser,      ) -from u1db.backends import sqlite_backend -from u1db.tests.test_backends import TestAlternativeDocument +from soledad.backends import sqlcipher +from soledad.backends.leap import LeapDocument +from soledad import tests  simple_doc = '{"key": "value"}' @@ -43,7 +44,7 @@ class TestSQLiteDatabase(tests.TestCase):          t2 = None  # will be a thread -        class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): +        class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase):              _index_storage_value = "testing"              def __init__(self, dbname, ntry): @@ -84,7 +85,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):      def setUp(self):          super(TestSQLitePartialExpandDatabase, self).setUp() -        self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') +        self.db = sqlcipher.SQLCipherDatabase(':memory:')          self.db._set_replica_uid('test')      def test_create_database(self): @@ -92,7 +93,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):          self.assertNotEqual(None, raw_db)      def test_default_replica_uid(self): -        self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') +        self.db = sqlcipher.SQLCipherDatabase(':memory:')          self.assertIsNot(None, self.db._replica_uid)          self.assertEqual(32, len(self.db._replica_uid))          int(self.db._replica_uid, 16) @@ -109,7 +110,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):          c.execute("SELECT * FROM u1db_config")          config = dict([(r[0], r[1]) for r in c.fetchall()])          self.assertEqual({'sql_schema': '0', 'replica_uid': 'test', -                          'index_storage': 'expand referenced'}, config) +                          'index_storage': 'expand referenced encrypted'}, config)          # These tables must exist, though we don't care what is in them yet          c.execute("SELECT * FROM transaction_log") @@ -120,13 +121,13 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):          c.execute("SELECT * FROM index_definitions")      def test__parse_index(self): -        self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') +        self.db = sqlcipher.SQLCipherDatabase(':memory:')          g = self.db._parse_index_definition('fieldname')          self.assertIsInstance(g, query_parser.ExtractField)          self.assertEqual(['fieldname'], g.field)      def test__update_indexes(self): -        self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') +        self.db = sqlcipher.SQLCipherDatabase(':memory:')          g = self.db._parse_index_definition('fieldname')          c = self.db._get_sqlite_handle().cursor()          self.db._update_indexes('doc-id', {'fieldname': 'val'}, @@ -137,7 +138,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):      def test__set_replica_uid(self):          # Start from scratch, so that replica_uid isn't set. -        self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') +        self.db = sqlcipher.SQLCipherDatabase(':memory:')          self.assertIsNot(None, self.db._real_replica_uid)          self.assertIsNot(None, self.db._replica_uid)          self.db._set_replica_uid('foo') @@ -239,7 +240,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):          path = temp_dir + '/rollback.db'          class SQLitePartialExpandDbTesting( -            sqlite_backend.SQLitePartialExpandDatabase): +            sqlcipher.SQLCipherDatabase):              def _set_replica_uid_in_transaction(self, uid):                  super(SQLitePartialExpandDbTesting, @@ -257,34 +258,34 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):      def test__open_database(self):          temp_dir = self.createTempDir(prefix='u1db-test-')          path = temp_dir + '/test.sqlite' -        sqlite_backend.SQLitePartialExpandDatabase(path) -        db2 = sqlite_backend.SQLiteDatabase._open_database(path) -        self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) +        sqlcipher.SQLCipherDatabase(path) +        db2 = sqlcipher.SQLCipherDatabase._open_database(path) +        self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase)      def test__open_database_with_factory(self):          temp_dir = self.createTempDir(prefix='u1db-test-')          path = temp_dir + '/test.sqlite' -        sqlite_backend.SQLitePartialExpandDatabase(path) -        db2 = sqlite_backend.SQLiteDatabase._open_database( -            path, document_factory=TestAlternativeDocument) -        self.assertEqual(TestAlternativeDocument, db2._factory) +        sqlcipher.SQLCipherDatabase(path) +        db2 = sqlcipher.SQLCipherDatabase._open_database( +            path, document_factory=LeapDocument) +        self.assertEqual(LeapDocument, db2._factory)      def test__open_database_non_existent(self):          temp_dir = self.createTempDir(prefix='u1db-test-')          path = temp_dir + '/non-existent.sqlite'          self.assertRaises(errors.DatabaseDoesNotExist, -                         sqlite_backend.SQLiteDatabase._open_database, path) +                         sqlcipher.SQLCipherDatabase._open_database, path)      def test__open_database_during_init(self):          temp_dir = self.createTempDir(prefix='u1db-test-')          path = temp_dir + '/initialised.db' -        db = sqlite_backend.SQLitePartialExpandDatabase.__new__( -                                    sqlite_backend.SQLitePartialExpandDatabase) +        db = sqlcipher.SQLCipherDatabase.__new__( +                                    sqlcipher.SQLCipherDatabase)          db._db_handle = dbapi2.connect(path)  # db is there but not yet init-ed          self.addCleanup(db.close)          observed = [] -        class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): +        class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase):              WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1              @classmethod @@ -296,13 +297,13 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):          db2 = SQLiteDatabaseTesting._open_database(path)          self.addCleanup(db2.close) -        self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) +        self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase)          self.assertEqual([None, -              sqlite_backend.SQLitePartialExpandDatabase._index_storage_value], +              sqlcipher.SQLCipherDatabase._index_storage_value],                           observed)      def test__open_database_invalid(self): -        class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): +        class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase):              WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1          temp_dir = self.createTempDir(prefix='u1db-test-')          path1 = temp_dir + '/invalid1.db' @@ -318,47 +319,47 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):      def test_open_database_existing(self):          temp_dir = self.createTempDir(prefix='u1db-test-')          path = temp_dir + '/existing.sqlite' -        sqlite_backend.SQLitePartialExpandDatabase(path) -        db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) -        self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) +        sqlcipher.SQLCipherDatabase(path) +        db2 = sqlcipher.SQLCipherDatabase.open_database(path, create=False) +        self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase)      def test_open_database_with_factory(self):          temp_dir = self.createTempDir(prefix='u1db-test-')          path = temp_dir + '/existing.sqlite' -        sqlite_backend.SQLitePartialExpandDatabase(path) -        db2 = sqlite_backend.SQLiteDatabase.open_database( -            path, create=False, document_factory=TestAlternativeDocument) -        self.assertEqual(TestAlternativeDocument, db2._factory) +        sqlcipher.SQLCipherDatabase(path) +        db2 = sqlcipher.SQLCipherDatabase.open_database( +            path, create=False, document_factory=LeapDocument) +        self.assertEqual(LeapDocument, db2._factory)      def test_open_database_create(self):          temp_dir = self.createTempDir(prefix='u1db-test-')          path = temp_dir + '/new.sqlite' -        sqlite_backend.SQLiteDatabase.open_database(path, create=True) -        db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) -        self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) +        sqlcipher.SQLCipherDatabase.open_database(path, create=True) +        db2 = sqlcipher.SQLCipherDatabase.open_database(path, create=False) +        self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase)      def test_open_database_non_existent(self):          temp_dir = self.createTempDir(prefix='u1db-test-')          path = temp_dir + '/non-existent.sqlite'          self.assertRaises(errors.DatabaseDoesNotExist, -                          sqlite_backend.SQLiteDatabase.open_database, path, +                          sqlcipher.SQLCipherDatabase.open_database, path,                            create=False)      def test_delete_database_existent(self):          temp_dir = self.createTempDir(prefix='u1db-test-')          path = temp_dir + '/new.sqlite' -        db = sqlite_backend.SQLiteDatabase.open_database(path, create=True) +        db = sqlcipher.SQLCipherDatabase.open_database(path, create=True)          db.close() -        sqlite_backend.SQLiteDatabase.delete_database(path) +        sqlcipher.SQLCipherDatabase.delete_database(path)          self.assertRaises(errors.DatabaseDoesNotExist, -                          sqlite_backend.SQLiteDatabase.open_database, path, +                          sqlcipher.SQLCipherDatabase.open_database, path,                            create=False)      def test_delete_database_nonexistent(self):          temp_dir = self.createTempDir(prefix='u1db-test-')          path = temp_dir + '/non-existent.sqlite'          self.assertRaises(errors.DatabaseDoesNotExist, -                          sqlite_backend.SQLiteDatabase.delete_database, path) +                          sqlcipher.SQLCipherDatabase.delete_database, path)      def test__get_indexed_fields(self):          self.db.create_index('idx1', 'a', 'b') @@ -492,3 +493,6 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):               'key3'],              ["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"]) + +if __name__ == '__main__': +    unittest.main() | 
