diff options
Diffstat (limited to 'tests/sqlcipher/test_backend.py')
-rw-r--r-- | tests/sqlcipher/test_backend.py | 677 |
1 files changed, 677 insertions, 0 deletions
diff --git a/tests/sqlcipher/test_backend.py b/tests/sqlcipher/test_backend.py new file mode 100644 index 00000000..68f8f9f2 --- /dev/null +++ b/tests/sqlcipher/test_backend.py @@ -0,0 +1,677 @@ +# -*- coding: utf-8 -*- +# test_sqlcipher.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Test sqlcipher backend internals. +""" +import os +import pytest +import time +import threading +import sys + +# l2db stuff. +from leap.soledad.common.l2db import errors +from leap.soledad.common.l2db import query_parser + +# soledad stuff. +from leap.soledad.common.document import SoledadDocument +from leap.soledad.client._db.sqlite import SQLitePartialExpandDatabase +from leap.soledad.client._db.sqlcipher import SQLCipherDatabase +from leap.soledad.client._db.sqlcipher import SQLCipherOptions +from leap.soledad.client._db.sqlcipher import DatabaseIsNotEncrypted + +# u1db tests stuff. +from test_soledad import u1db_tests as tests +from test_soledad.u1db_tests import test_backends +from test_soledad.u1db_tests import test_open +from test_soledad.util import SQLCIPHER_SCENARIOS +from test_soledad.util import PASSWORD +from test_soledad.util import BaseSoledadTest + +from testscenarios import TestWithScenarios + +if sys.version_info[0] < 3: + from pysqlcipher import dbapi2 +else: + from pysqlcipher3 import dbapi2 + + +def sqlcipher_open(path, passphrase, create=True, document_factory=None): + return SQLCipherDatabase( + SQLCipherOptions(path, passphrase, create=create)) + + +# ----------------------------------------------------------------------------- +# The following tests come from `u1db.tests.test_common_backend`. +# ----------------------------------------------------------------------------- + +class TestSQLCipherBackendImpl(tests.TestCase): + + def test__allocate_doc_id(self): + db = sqlcipher_open(':memory:', PASSWORD) + doc_id1 = db._allocate_doc_id() + self.assertTrue(doc_id1.startswith('D-')) + self.assertEqual(34, len(doc_id1)) + int(doc_id1[len('D-'):], 16) + self.assertNotEqual(doc_id1, db._allocate_doc_id()) + db.close() + + +# ----------------------------------------------------------------------------- +# The following tests come from `u1db.tests.test_backends`. +# ----------------------------------------------------------------------------- + +class SQLCipherTests(TestWithScenarios, test_backends.AllDatabaseTests): + scenarios = SQLCIPHER_SCENARIOS + + +class SQLCipherDatabaseTests(TestWithScenarios, + test_backends.LocalDatabaseTests): + scenarios = SQLCIPHER_SCENARIOS + + +class SQLCipherValidateGenNTransIdTests( + TestWithScenarios, + test_backends.LocalDatabaseValidateGenNTransIdTests): + scenarios = SQLCIPHER_SCENARIOS + + +class SQLCipherValidateSourceGenTests( + TestWithScenarios, + test_backends.LocalDatabaseValidateSourceGenTests): + scenarios = SQLCIPHER_SCENARIOS + + +class SQLCipherWithConflictsTests( + TestWithScenarios, + test_backends.LocalDatabaseWithConflictsTests): + scenarios = SQLCIPHER_SCENARIOS + + +class SQLCipherIndexTests( + TestWithScenarios, test_backends.DatabaseIndexTests): + scenarios = SQLCIPHER_SCENARIOS + + +# ----------------------------------------------------------------------------- +# The following tests come from `u1db.tests.test_sqlite_backend`. +# ----------------------------------------------------------------------------- + +@pytest.mark.usefixtures('method_tmpdir') +class TestSQLCipherDatabase(tests.TestCase): + """ + Tests from u1db.tests.test_sqlite_backend.TestSQLiteDatabase. + """ + + def test_atomic_initialize(self): + # This test was modified to ensure that db2.close() is called within + # the thread that created the database. + dbname = os.path.join(self.tempdir, 'atomic.db') + + t2 = None # will be a thread + + class SQLCipherDatabaseTesting(SQLCipherDatabase): + _index_storage_value = "testing" + + def __init__(self, dbname, ntry): + self._try = ntry + self._is_initialized_invocations = 0 + SQLCipherDatabase.__init__( + self, + SQLCipherOptions(dbname, PASSWORD)) + + def _is_initialized(self, c): + res = \ + SQLCipherDatabase._is_initialized(self, c) + if self._try == 1: + self._is_initialized_invocations += 1 + if self._is_initialized_invocations == 2: + t2.start() + # hard to do better and have a generic test + time.sleep(0.05) + return res + + class SecondTry(threading.Thread): + + outcome2 = [] + + def run(self): + try: + db2 = SQLCipherDatabaseTesting(dbname, 2) + except Exception as e: + SecondTry.outcome2.append(e) + else: + SecondTry.outcome2.append(db2) + + t2 = SecondTry() + db1 = SQLCipherDatabaseTesting(dbname, 1) + t2.join() + + self.assertIsInstance(SecondTry.outcome2[0], SQLCipherDatabaseTesting) + self.assertTrue(db1._is_initialized(db1._get_sqlite_handle().cursor())) + db1.close() + + +@pytest.mark.usefixtures('method_tmpdir') +class TestSQLCipherPartialExpandDatabase(tests.TestCase): + """ + Tests from u1db.tests.test_sqlite_backend.TestSQLitePartialExpandDatabase. + """ + + # The following tests had to be cloned from u1db because they all + # instantiate the backend directly, so we need to change that in order to + # our backend be instantiated in place. + + def setUp(self): + self.db = sqlcipher_open(':memory:', PASSWORD) + + def tearDown(self): + self.db.close() + + def test_default_replica_uid(self): + self.assertIsNot(None, self.db._replica_uid) + self.assertEqual(32, len(self.db._replica_uid)) + int(self.db._replica_uid, 16) + + def test__parse_index(self): + g = self.db._parse_index_definition('fieldname') + self.assertIsInstance(g, query_parser.ExtractField) + self.assertEqual(['fieldname'], g.field) + + def test__update_indexes(self): + g = self.db._parse_index_definition('fieldname') + c = self.db._get_sqlite_handle().cursor() + self.db._update_indexes('doc-id', {'fieldname': 'val'}, + [('fieldname', g)], c) + c.execute('SELECT doc_id, field_name, value FROM document_fields') + self.assertEqual([('doc-id', 'fieldname', 'val')], + c.fetchall()) + + def test_create_database(self): + raw_db = self.db._get_sqlite_handle() + self.assertNotEqual(None, raw_db) + + def test__set_replica_uid(self): + # Start from scratch, so that replica_uid isn't set. + self.assertIsNot(None, self.db._real_replica_uid) + self.assertIsNot(None, self.db._replica_uid) + self.db._set_replica_uid('foo') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT value FROM u1db_config WHERE name='replica_uid'") + self.assertEqual(('foo',), c.fetchone()) + self.assertEqual('foo', self.db._real_replica_uid) + self.assertEqual('foo', self.db._replica_uid) + self.db._close_sqlite_handle() + self.assertEqual('foo', self.db._replica_uid) + + def test__open_database(self): + # SQLCipherDatabase has no _open_database() method, so we just pass + # (and test for the same funcionality on test_open_database_existing() + # below). + pass + + def test__open_database_with_factory(self): + # SQLCipherDatabase has no _open_database() method. + pass + + def test__open_database_non_existent(self): + path = self.tempdir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + sqlcipher_open, + path, PASSWORD, create=False) + + def test__open_database_during_init(self): + # The purpose of this test is to ensure that _open_database() parallel + # db initialization behaviour is correct. As SQLCipherDatabase does + # not have an _open_database() method, we just do not implement this + # test. + pass + + def test__open_database_invalid(self): + # This test was modified to ensure that an empty database file will + # raise a DatabaseIsNotEncrypted exception instead of a + # dbapi2.OperationalError exception. + path1 = self.tempdir + '/invalid1.db' + with open(path1, 'wb') as f: + f.write("") + self.assertRaises(DatabaseIsNotEncrypted, + sqlcipher_open, path1, + PASSWORD) + with open(path1, 'wb') as f: + f.write("invalid") + self.assertRaises(dbapi2.DatabaseError, + sqlcipher_open, path1, + PASSWORD) + + def test_open_database_existing(self): + # In the context of SQLCipherDatabase, where no _open_database() + # method exists and thus there's no call to _which_index_storage(), + # this test tests for the same functionality as + # test_open_database_create() below. So, we just pass. + pass + + def test_open_database_with_factory(self): + # SQLCipherDatabase's constructor has no factory parameter. + pass + + def test_open_database_create(self): + # SQLCipherDatabas has no open_database() method, so we just test for + # the actual database constructor effects. + path = self.tempdir + '/new.sqlite' + db1 = sqlcipher_open(path, PASSWORD, create=True) + db2 = sqlcipher_open(path, PASSWORD, create=False) + self.assertIsInstance(db2, SQLCipherDatabase) + db1.close() + db2.close() + + def test_create_database_initializes_schema(self): + # This test had to be cloned because our implementation of SQLCipher + # backend is referenced with an index_storage_value that includes the + # word "encrypted". See u1db's sqlite_backend and our + # sqlcipher_backend for reference. + raw_db = self.db._get_sqlite_handle() + c = raw_db.cursor() + c.execute("SELECT * FROM u1db_config") + config = dict([(r[0], r[1]) for r in c.fetchall()]) + replica_uid = self.db._replica_uid + self.assertEqual({'sql_schema': '0', 'replica_uid': replica_uid, + 'index_storage': 'expand referenced encrypted'}, + config) + + def test_store_syncable(self): + doc = self.db.create_doc_from_json(tests.simple_doc) + # assert that docs are syncable by default + self.assertEqual(True, doc.syncable) + # assert that we can store syncable = False + doc.syncable = False + self.db.put_doc(doc) + self.assertEqual(False, self.db.get_doc(doc.doc_id).syncable) + # assert that we can store syncable = True + doc.syncable = True + self.db.put_doc(doc) + self.assertEqual(True, self.db.get_doc(doc.doc_id).syncable) + + def test__close_sqlite_handle(self): + raw_db = self.db._get_sqlite_handle() + self.db._close_sqlite_handle() + self.assertRaises(dbapi2.ProgrammingError, + raw_db.cursor) + + def test__get_generation(self): + self.assertEqual(0, self.db._get_generation()) + + def test__get_generation_info(self): + self.assertEqual((0, ''), self.db._get_generation_info()) + + def test_create_index(self): + self.db.create_index('test-idx', "key") + self.assertEqual([('test-idx', ["key"])], self.db.list_indexes()) + + def test_create_index_multiple_fields(self): + self.db.create_index('test-idx', "key", "key2") + self.assertEqual([('test-idx', ["key", "key2"])], + self.db.list_indexes()) + + def test__get_index_definition(self): + self.db.create_index('test-idx', "key", "key2") + # TODO: How would you test that an index is getting used for an SQL + # request? + self.assertEqual(["key", "key2"], + self.db._get_index_definition('test-idx')) + + def test_list_index_mixed(self): + # Make sure that we properly order the output + c = self.db._get_sqlite_handle().cursor() + # We intentionally insert the data in weird ordering, to make sure the + # query still gets it back correctly. + c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", + [('idx-1', 0, 'key10'), + ('idx-2', 2, 'key22'), + ('idx-1', 1, 'key11'), + ('idx-2', 0, 'key20'), + ('idx-2', 1, 'key21')]) + self.assertEqual([('idx-1', ['key10', 'key11']), + ('idx-2', ['key20', 'key21', 'key22'])], + self.db.list_indexes()) + + def test_no_indexes_no_document_fields(self): + self.db.create_doc_from_json( + '{"key1": "val1", "key2": "val2"}') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([], c.fetchall()) + + def test_create_extracts_fields(self): + doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + doc2 = self.db.create_doc_from_json('{"key1": "valx", "key2": "valy"}') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([], c.fetchall()) + self.db.create_index('test', 'key1', 'key2') + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual(sorted( + [(doc1.doc_id, "key1", "val1"), + (doc1.doc_id, "key2", "val2"), + (doc2.doc_id, "key1", "valx"), + (doc2.doc_id, "key2", "valy"), ]), sorted(c.fetchall())) + + def test_put_updates_fields(self): + self.db.create_index('test', 'key1', 'key2') + doc1 = self.db.create_doc_from_json( + '{"key1": "val1", "key2": "val2"}') + doc1.content = {"key1": "val1", "key2": "valy"} + self.db.put_doc(doc1) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, "key1", "val1"), + (doc1.doc_id, "key2", "valy"), ], c.fetchall()) + + def test_put_updates_nested_fields(self): + self.db.create_index('test', 'key', 'sub.doc') + doc1 = self.db.create_doc_from_json(tests.nested_doc) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, "key", "value"), + (doc1.doc_id, "sub.doc", "underneath"), ], + c.fetchall()) + + def test__ensure_schema_rollback(self): + path = self.tempdir + '/rollback.db' + + class SQLitePartialExpandDbTesting(SQLCipherDatabase): + + def _set_replica_uid_in_transaction(self, uid): + super(SQLitePartialExpandDbTesting, + self)._set_replica_uid_in_transaction(uid) + if fail: + raise Exception() + + db = SQLitePartialExpandDbTesting.__new__(SQLitePartialExpandDbTesting) + db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed + fail = True + self.assertRaises(Exception, db._ensure_schema) + fail = False + db._initialize(db._db_handle.cursor()) + + def test_open_database_non_existent(self): + path = self.tempdir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + sqlcipher_open, path, "123", + create=False) + + def test_delete_database_existent(self): + path = self.tempdir + '/new.sqlite' + db = sqlcipher_open(path, "123", create=True) + db.close() + SQLCipherDatabase.delete_database(path) + self.assertRaises(errors.DatabaseDoesNotExist, + sqlcipher_open, path, "123", + create=False) + + def test_delete_database_nonexistent(self): + path = self.tempdir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + SQLCipherDatabase.delete_database, path) + + def test__get_indexed_fields(self): + self.db.create_index('idx1', 'a', 'b') + self.assertEqual(set(['a', 'b']), self.db._get_indexed_fields()) + self.db.create_index('idx2', 'b', 'c') + self.assertEqual(set(['a', 'b', 'c']), self.db._get_indexed_fields()) + + def test_indexed_fields_expanded(self): + self.db.create_index('idx1', 'key1') + doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + self.assertEqual(set(['key1']), self.db._get_indexed_fields()) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall()) + + def test_create_index_updates_fields(self): + doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + self.db.create_index('idx1', 'key1') + self.assertEqual(set(['key1']), self.db._get_indexed_fields()) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall()) + + def assertFormatQueryEquals(self, exp_statement, exp_args, definition, + values): + statement, args = self.db._format_query(definition, values) + self.assertEqual(exp_statement, statement) + self.assertEqual(exp_args, args) + + def test__format_query(self): + self.assertFormatQueryEquals( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, document_fields d0 LEFT OUTER JOIN conflicts c ON " + "c.doc_id = d.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name " + "= ? AND d0.value = ? GROUP BY d.doc_id, d.doc_rev, d.content " + "ORDER BY d0.value;", ["key1", "a"], + ["key1"], ["a"]) + + def test__format_query2(self): + self.assertFormatQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value = ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value = ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ["key1", "a", "key2", "b", "key3", "c"], + ["key1", "key2", "key3"], ["a", "b", "c"]) + + def test__format_query_wildcard(self): + self.assertFormatQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value GLOB ? AND d.doc_id = d2.doc_id AND d2.field_name = ? ' + 'AND d2.value NOT NULL GROUP BY d.doc_id, d.doc_rev, d.content ' + 'ORDER BY d0.value, d1.value, d2.value;', + ["key1", "a", "key2", "b*", "key3"], ["key1", "key2", "key3"], + ["a", "b*", "*"]) + + def assertFormatRangeQueryEquals(self, exp_statement, exp_args, definition, + start_value, end_value): + statement, args = self.db._format_range_query( + definition, start_value, end_value) + self.assertEqual(exp_statement, statement) + self.assertEqual(exp_args, args) + + def test__format_range_query(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value >= ? AND d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c', 'key1', 'p', 'key2', 'q', + 'key3', 'r'], + ["key1", "key2", "key3"], ["a", "b", "c"], ["p", "q", "r"]) + + def test__format_range_query_no_start(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c'], + ["key1", "key2", "key3"], None, ["a", "b", "c"]) + + def test__format_range_query_no_end(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value >= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c'], + ["key1", "key2", "key3"], ["a", "b", "c"], None) + + def test__format_range_query_wildcard(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value NOT NULL AND d.doc_id = d0.doc_id AND d0.field_name = ? ' + 'AND d0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? ' + 'AND (d1.value < ? OR d1.value GLOB ?) AND d.doc_id = d2.doc_id ' + 'AND d2.field_name = ? AND d2.value NOT NULL GROUP BY d.doc_id, ' + 'd.doc_rev, d.content ORDER BY d0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'key1', 'p', 'key2', 'q', 'q*', + 'key3'], + ["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"]) + + +# ----------------------------------------------------------------------------- +# The following tests come from `u1db.tests.test_open`. +# ----------------------------------------------------------------------------- + + +class SQLCipherOpen(test_open.TestU1DBOpen): + + def test_open_no_create(self): + self.assertRaises(errors.DatabaseDoesNotExist, + sqlcipher_open, self.db_path, + PASSWORD, + create=False) + self.assertFalse(os.path.exists(self.db_path)) + + def test_open_create(self): + db = sqlcipher_open(self.db_path, PASSWORD, create=True) + self.addCleanup(db.close) + self.assertTrue(os.path.exists(self.db_path)) + self.assertIsInstance(db, SQLCipherDatabase) + + def test_open_with_factory(self): + db = sqlcipher_open(self.db_path, PASSWORD, create=True, + document_factory=SoledadDocument) + self.addCleanup(db.close) + doc = db.create_doc({}) + self.assertTrue(isinstance(doc, SoledadDocument)) + + def test_open_existing(self): + db = sqlcipher_open(self.db_path, PASSWORD) + self.addCleanup(db.close) + doc = db.create_doc_from_json(tests.simple_doc) + # Even though create=True, we shouldn't wipe the db + db2 = sqlcipher_open(self.db_path, PASSWORD, create=True) + self.addCleanup(db2.close) + doc2 = db2.get_doc(doc.doc_id) + self.assertEqual(doc, doc2) + + def test_open_existing_no_create(self): + db = sqlcipher_open(self.db_path, PASSWORD) + self.addCleanup(db.close) + db2 = sqlcipher_open(self.db_path, PASSWORD, create=False) + self.addCleanup(db2.close) + self.assertIsInstance(db2, SQLCipherDatabase) + + +# ----------------------------------------------------------------------------- +# Tests for actual encryption of the database +# ----------------------------------------------------------------------------- + +class SQLCipherEncryptionTests(BaseSoledadTest): + + """ + Tests to guarantee SQLCipher is indeed encrypting data when storing. + """ + + def _delete_dbfiles(self): + for dbfile in [self.DB_FILE]: + if os.path.exists(dbfile): + os.unlink(dbfile) + + def setUp(self): + BaseSoledadTest.setUp(self) + self.DB_FILE = os.path.join(self.tempdir, 'test.db') + self._delete_dbfiles() + + def tearDown(self): + self._delete_dbfiles() + BaseSoledadTest.tearDown(self) + + def test_try_to_open_encrypted_db_with_sqlite_backend(self): + """ + SQLite backend should not succeed to open SQLCipher databases. + """ + db = sqlcipher_open(self.DB_FILE, PASSWORD) + doc = db.create_doc_from_json(tests.simple_doc) + db.close() + try: + # trying to open an encrypted database with the regular u1db + # backend should raise a DatabaseError exception. + SQLitePartialExpandDatabase(self.DB_FILE, + document_factory=SoledadDocument) + raise DatabaseIsNotEncrypted() + except dbapi2.DatabaseError: + # at this point we know that the regular U1DB sqlcipher backend + # did not succeed on opening the database, so it was indeed + # encrypted. + db = sqlcipher_open(self.DB_FILE, PASSWORD) + doc = db.get_doc(doc.doc_id) + self.assertEqual(tests.simple_doc, doc.get_json(), + 'decrypted content mismatch') + db.close() + + def test_try_to_open_raw_db_with_sqlcipher_backend(self): + """ + SQLCipher backend should not succeed to open unencrypted databases. + """ + db = SQLitePartialExpandDatabase(self.DB_FILE, + document_factory=SoledadDocument) + db.create_doc_from_json(tests.simple_doc) + db.close() + try: + # trying to open the a non-encrypted database with sqlcipher + # backend should raise a DatabaseIsNotEncrypted exception. + db = sqlcipher_open(self.DB_FILE, PASSWORD) + db.close() + raise dbapi2.DatabaseError( + "SQLCipher backend should not be able to open non-encrypted " + "dbs.") + except DatabaseIsNotEncrypted: + pass |