summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordrebs <drebs@leap.se>2012-12-11 15:03:12 -0200
committerdrebs <drebs@leap.se>2012-12-11 15:03:12 -0200
commit435728098255a8bdc89a9191d2b774a911fa59a0 (patch)
tree9b36f3d623f1b26e46c0c004372e76ba8ba97233
parentd94469a11b3dbfd4ed73c38be50434095683c8bc (diff)
Fix SQLCipherDatabase and add tests.
-rw-r--r--backends/sqlcipher.py5
-rw-r--r--tests/__init__.py55
-rw-r--r--tests/test_sqlcipher.py84
3 files changed, 102 insertions, 42 deletions
diff --git a/backends/sqlcipher.py b/backends/sqlcipher.py
index fcdab251..301d4a7f 100644
--- a/backends/sqlcipher.py
+++ b/backends/sqlcipher.py
@@ -60,7 +60,8 @@ def open(path, create, document_factory=None, password=None):
class SQLCipherDatabase(SQLitePartialExpandDatabase):
"""A U1DB implementation that uses SQLCipher as its persistence layer."""
- _sqlite_registry = {}
+ _index_storage_value = 'expand referenced encrypted'
+
@classmethod
def set_pragma_key(cls, db_handle, key):
@@ -113,7 +114,7 @@ class SQLCipherDatabase(SQLitePartialExpandDatabase):
raise
if backend_cls is None:
# default is SQLCipherPartialExpandDatabase
- backend_cls = SQLCipherPartialExpandDatabase
+ backend_cls = SQLCipherDatabase
return backend_cls(sqlite_file, document_factory=document_factory,
password=password)
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29b..7918b265 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1,55 @@
+import unittest2 as unittest
+import tempfile
+import shutil
+
+class TestCase(unittest.TestCase):
+
+ def createTempDir(self, prefix='u1db-tmp-'):
+ """Create a temporary directory to do some work in.
+
+ This directory will be scheduled for cleanup when the test ends.
+ """
+ tempdir = tempfile.mkdtemp(prefix=prefix)
+ self.addCleanup(shutil.rmtree, tempdir)
+ return tempdir
+
+ def make_document(self, doc_id, doc_rev, content, has_conflicts=False):
+ return self.make_document_for_test(
+ self, doc_id, doc_rev, content, has_conflicts)
+
+ def make_document_for_test(self, test, doc_id, doc_rev, content,
+ has_conflicts):
+ return make_document_for_test(
+ test, doc_id, doc_rev, content, has_conflicts)
+
+ def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts):
+ """Assert that the document in the database looks correct."""
+ exp_doc = self.make_document(doc_id, doc_rev, content,
+ has_conflicts=has_conflicts)
+ self.assertEqual(exp_doc, db.get_doc(doc_id))
+
+ def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content,
+ has_conflicts):
+ """Assert that the document in the database looks correct."""
+ exp_doc = self.make_document(doc_id, doc_rev, content,
+ has_conflicts=has_conflicts)
+ self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True))
+
+ def assertGetDocConflicts(self, db, doc_id, conflicts):
+ """Assert what conflicts are stored for a given doc_id.
+
+ :param conflicts: A list of (doc_rev, content) pairs.
+ The first item must match the first item returned from the
+ database, however the rest can be returned in any order.
+ """
+ if conflicts:
+ conflicts = [(rev, (json.loads(cont) if isinstance(cont, basestring)
+ else cont)) for (rev, cont) in conflicts]
+ conflicts = conflicts[:1] + sorted(conflicts[1:])
+ actual = db.get_doc_conflicts(doc_id)
+ if actual:
+ actual = [(doc.rev, (json.loads(doc.get_json())
+ if doc.get_json() is not None else None)) for doc in actual]
+ actual = actual[:1] + sorted(actual[1:])
+ self.assertEqual(conflicts, actual)
+
diff --git a/tests/test_sqlcipher.py b/tests/test_sqlcipher.py
index 46f27f73..e35a6d90 100644
--- a/tests/test_sqlcipher.py
+++ b/tests/test_sqlcipher.py
@@ -19,16 +19,17 @@
import os
import time
import threading
+import unittest2 as unittest
from sqlite3 import dbapi2
from u1db import (
errors,
- tests,
query_parser,
)
-from u1db.backends import sqlite_backend
-from u1db.tests.test_backends import TestAlternativeDocument
+from soledad.backends import sqlcipher
+from soledad.backends.leap import LeapDocument
+from soledad import tests
simple_doc = '{"key": "value"}'
@@ -43,7 +44,7 @@ class TestSQLiteDatabase(tests.TestCase):
t2 = None # will be a thread
- class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase):
_index_storage_value = "testing"
def __init__(self, dbname, ntry):
@@ -84,7 +85,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):
def setUp(self):
super(TestSQLitePartialExpandDatabase, self).setUp()
- self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.db = sqlcipher.SQLCipherDatabase(':memory:')
self.db._set_replica_uid('test')
def test_create_database(self):
@@ -92,7 +93,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):
self.assertNotEqual(None, raw_db)
def test_default_replica_uid(self):
- self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.db = sqlcipher.SQLCipherDatabase(':memory:')
self.assertIsNot(None, self.db._replica_uid)
self.assertEqual(32, len(self.db._replica_uid))
int(self.db._replica_uid, 16)
@@ -109,7 +110,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):
c.execute("SELECT * FROM u1db_config")
config = dict([(r[0], r[1]) for r in c.fetchall()])
self.assertEqual({'sql_schema': '0', 'replica_uid': 'test',
- 'index_storage': 'expand referenced'}, config)
+ 'index_storage': 'expand referenced encrypted'}, config)
# These tables must exist, though we don't care what is in them yet
c.execute("SELECT * FROM transaction_log")
@@ -120,13 +121,13 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):
c.execute("SELECT * FROM index_definitions")
def test__parse_index(self):
- self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.db = sqlcipher.SQLCipherDatabase(':memory:')
g = self.db._parse_index_definition('fieldname')
self.assertIsInstance(g, query_parser.ExtractField)
self.assertEqual(['fieldname'], g.field)
def test__update_indexes(self):
- self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.db = sqlcipher.SQLCipherDatabase(':memory:')
g = self.db._parse_index_definition('fieldname')
c = self.db._get_sqlite_handle().cursor()
self.db._update_indexes('doc-id', {'fieldname': 'val'},
@@ -137,7 +138,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):
def test__set_replica_uid(self):
# Start from scratch, so that replica_uid isn't set.
- self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.db = sqlcipher.SQLCipherDatabase(':memory:')
self.assertIsNot(None, self.db._real_replica_uid)
self.assertIsNot(None, self.db._replica_uid)
self.db._set_replica_uid('foo')
@@ -239,7 +240,7 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):
path = temp_dir + '/rollback.db'
class SQLitePartialExpandDbTesting(
- sqlite_backend.SQLitePartialExpandDatabase):
+ sqlcipher.SQLCipherDatabase):
def _set_replica_uid_in_transaction(self, uid):
super(SQLitePartialExpandDbTesting,
@@ -257,34 +258,34 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):
def test__open_database(self):
temp_dir = self.createTempDir(prefix='u1db-test-')
path = temp_dir + '/test.sqlite'
- sqlite_backend.SQLitePartialExpandDatabase(path)
- db2 = sqlite_backend.SQLiteDatabase._open_database(path)
- self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+ sqlcipher.SQLCipherDatabase(path)
+ db2 = sqlcipher.SQLCipherDatabase._open_database(path)
+ self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase)
def test__open_database_with_factory(self):
temp_dir = self.createTempDir(prefix='u1db-test-')
path = temp_dir + '/test.sqlite'
- sqlite_backend.SQLitePartialExpandDatabase(path)
- db2 = sqlite_backend.SQLiteDatabase._open_database(
- path, document_factory=TestAlternativeDocument)
- self.assertEqual(TestAlternativeDocument, db2._factory)
+ sqlcipher.SQLCipherDatabase(path)
+ db2 = sqlcipher.SQLCipherDatabase._open_database(
+ path, document_factory=LeapDocument)
+ self.assertEqual(LeapDocument, db2._factory)
def test__open_database_non_existent(self):
temp_dir = self.createTempDir(prefix='u1db-test-')
path = temp_dir + '/non-existent.sqlite'
self.assertRaises(errors.DatabaseDoesNotExist,
- sqlite_backend.SQLiteDatabase._open_database, path)
+ sqlcipher.SQLCipherDatabase._open_database, path)
def test__open_database_during_init(self):
temp_dir = self.createTempDir(prefix='u1db-test-')
path = temp_dir + '/initialised.db'
- db = sqlite_backend.SQLitePartialExpandDatabase.__new__(
- sqlite_backend.SQLitePartialExpandDatabase)
+ db = sqlcipher.SQLCipherDatabase.__new__(
+ sqlcipher.SQLCipherDatabase)
db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed
self.addCleanup(db.close)
observed = []
- class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase):
WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1
@classmethod
@@ -296,13 +297,13 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):
db2 = SQLiteDatabaseTesting._open_database(path)
self.addCleanup(db2.close)
- self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+ self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase)
self.assertEqual([None,
- sqlite_backend.SQLitePartialExpandDatabase._index_storage_value],
+ sqlcipher.SQLCipherDatabase._index_storage_value],
observed)
def test__open_database_invalid(self):
- class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ class SQLiteDatabaseTesting(sqlcipher.SQLCipherDatabase):
WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1
temp_dir = self.createTempDir(prefix='u1db-test-')
path1 = temp_dir + '/invalid1.db'
@@ -318,47 +319,47 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):
def test_open_database_existing(self):
temp_dir = self.createTempDir(prefix='u1db-test-')
path = temp_dir + '/existing.sqlite'
- sqlite_backend.SQLitePartialExpandDatabase(path)
- db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False)
- self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+ sqlcipher.SQLCipherDatabase(path)
+ db2 = sqlcipher.SQLCipherDatabase.open_database(path, create=False)
+ self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase)
def test_open_database_with_factory(self):
temp_dir = self.createTempDir(prefix='u1db-test-')
path = temp_dir + '/existing.sqlite'
- sqlite_backend.SQLitePartialExpandDatabase(path)
- db2 = sqlite_backend.SQLiteDatabase.open_database(
- path, create=False, document_factory=TestAlternativeDocument)
- self.assertEqual(TestAlternativeDocument, db2._factory)
+ sqlcipher.SQLCipherDatabase(path)
+ db2 = sqlcipher.SQLCipherDatabase.open_database(
+ path, create=False, document_factory=LeapDocument)
+ self.assertEqual(LeapDocument, db2._factory)
def test_open_database_create(self):
temp_dir = self.createTempDir(prefix='u1db-test-')
path = temp_dir + '/new.sqlite'
- sqlite_backend.SQLiteDatabase.open_database(path, create=True)
- db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False)
- self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+ sqlcipher.SQLCipherDatabase.open_database(path, create=True)
+ db2 = sqlcipher.SQLCipherDatabase.open_database(path, create=False)
+ self.assertIsInstance(db2, sqlcipher.SQLCipherDatabase)
def test_open_database_non_existent(self):
temp_dir = self.createTempDir(prefix='u1db-test-')
path = temp_dir + '/non-existent.sqlite'
self.assertRaises(errors.DatabaseDoesNotExist,
- sqlite_backend.SQLiteDatabase.open_database, path,
+ sqlcipher.SQLCipherDatabase.open_database, path,
create=False)
def test_delete_database_existent(self):
temp_dir = self.createTempDir(prefix='u1db-test-')
path = temp_dir + '/new.sqlite'
- db = sqlite_backend.SQLiteDatabase.open_database(path, create=True)
+ db = sqlcipher.SQLCipherDatabase.open_database(path, create=True)
db.close()
- sqlite_backend.SQLiteDatabase.delete_database(path)
+ sqlcipher.SQLCipherDatabase.delete_database(path)
self.assertRaises(errors.DatabaseDoesNotExist,
- sqlite_backend.SQLiteDatabase.open_database, path,
+ sqlcipher.SQLCipherDatabase.open_database, path,
create=False)
def test_delete_database_nonexistent(self):
temp_dir = self.createTempDir(prefix='u1db-test-')
path = temp_dir + '/non-existent.sqlite'
self.assertRaises(errors.DatabaseDoesNotExist,
- sqlite_backend.SQLiteDatabase.delete_database, path)
+ sqlcipher.SQLCipherDatabase.delete_database, path)
def test__get_indexed_fields(self):
self.db.create_index('idx1', 'a', 'b')
@@ -492,3 +493,6 @@ class TestSQLitePartialExpandDatabase(tests.TestCase):
'key3'],
["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"])
+
+if __name__ == '__main__':
+ unittest.main()