summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--openstack.py114
-rw-r--r--tests/__init__.py46
2 files changed, 151 insertions, 9 deletions
diff --git a/openstack.py b/openstack.py
index 25f1a404..22a2d067 100644
--- a/openstack.py
+++ b/openstack.py
@@ -1,5 +1,6 @@
-from u1db.backends import CommonBackend
from leap import *
+from u1db import errors
+from u1db.backends import CommonBackend
from u1db.remote.http_target import HTTPSyncTarget
from swiftclient import client
@@ -96,21 +97,26 @@ class OpenStackDatabase(CommonBackend):
raise NotImplementedError(self.close)
def _get_replica_gen_and_trans_id(self, other_replica_uid):
- raise NotImplementedError(self._get_replica_gen_and_trans_id)
+ self._update_u1db_data()
+ return self._sync_log.get_replica_gen_and_trans_id(other_replica_uid)
def _set_replica_gen_and_trans_id(self, other_replica_uid,
other_generation, other_transaction_id):
- raise NotImplementedError(self._set_replica_gen_and_trans_id)
+ self._update_u1db_data()
+ return self._sync_log.set_replica_gen_and_trans_id(other_replica_uid,
+ other_generation, other_transaction_id)
#-------------------------------------------------------------------------
# implemented methods from CommonBackend
#-------------------------------------------------------------------------
def _get_generation(self):
- raise NotImplementedError(self._get_generation)
+ self._update_u1db_data()
+ return self._transaction_log.get_generation()
def _get_generation_info(self):
- raise NotImplementedError(self._get_generation_info)
+ self._update_u1db_data()
+ return self._transaction_log.get_generation_info()
def _get_doc(self, doc_id, check_for_conflicts=False):
"""Get just the document content, without fancy handling."""
@@ -119,15 +125,16 @@ class OpenStackDatabase(CommonBackend):
def _has_conflicts(self, doc_id):
raise NotImplementedError(self._has_conflicts)
- def _get_transaction_log(self):
- raise NotImplementedError(self._get_transaction_log)
-
def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content):
raise NotImplementedError(self._put_and_update_indexes)
def _get_trans_id_for_gen(self, generation):
- raise NotImplementedError(self._get_trans_id_for_gen)
+ self._update_u1db_data()
+ trans_id = self._transaction_log.get_trans_id_for_gen(generation)
+ if trans_id is None:
+ raise errors.InvalidGeneration
+ return trans_id
#-------------------------------------------------------------------------
# OpenStack specific methods
@@ -143,6 +150,11 @@ class OpenStackDatabase(CommonBackend):
self._url, self._auth_token = self._connection.get_auth()
return self._url, self.auth_token
+ def _update_u1db_data(self):
+ data = self.get_doc('u1db_data').content
+ self._transaction_log = data['transaction_log']
+ self._sync_log = data['sync_log']
+
class OpenStackSyncTarget(HTTPSyncTarget):
@@ -152,3 +164,87 @@ class OpenStackSyncTarget(HTTPSyncTarget):
def record_sync_info(self, source_replica_uid, source_replica_generation,
source_replica_transaction_id):
raise NotImplementedError(self.record_sync_info)
+
+
+class SimpleLog(object):
+ def __init__(self, log=None):
+ self._log = []
+ if log:
+ self._log = log
+
+ def append(self, msg):
+ self._log.append(msg)
+
+ def reduce(self, func, initializer=None):
+ return reduce(func, self._log, initializer)
+
+ def map(self, func):
+ return map(func, self._log)
+
+
+class TransactionLog(SimpleLog):
+ """
+ A list of (generation, doc_id, transaction_id) tuples.
+ """
+
+ def get_generation(self):
+ """
+ Return the current generation.
+ """
+ gens = self.map(lambda x: x[0])
+ if not gens:
+ return 0
+ return max(gens)
+
+ def get_generation_info(self):
+ """
+ Return the current generation and transaction id.
+ """
+ if not self._log:
+ return(0, '')
+ info = self.map(lambda x: (x[0], x[2]))
+ return reduce(lambda x, y: x if (x[0] > y[0]) else y, info)
+
+ def get_trans_id_for_gen(self, gen):
+ """
+ Get the transaction id corresponding to a particular generation.
+ """
+ log = self.reduce(lambda x, y: y if y[0] == gen else x)
+ if log is None:
+ return None
+ return log[2]
+
+class SyncLog(SimpleLog):
+ """
+ A list of (replica_id, generation, transaction_id) tuples.
+ """
+
+ def find_by_replica_uid(self, replica_uid):
+ if not self._log:
+ return ()
+ return self.reduce(lambda x, y: y if y[0] == replica_uid else x)
+
+ def get_replica_gen_and_trans_id(self, other_replica_uid):
+ """
+ Return the last known generation and transaction id for the other db
+ replica.
+ """
+ info = self.find_by_replica_uid(other_replica_uid)
+ if not info:
+ return (0, '')
+ return (info[1], info[2])
+
+ def set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ """
+ Set the last-known generation and transaction id for the other
+ database replica.
+ """
+ old_log = self._log
+ self._log = []
+ for log in old_log:
+ if log[0] != other_replica_uid:
+ self.append(log)
+ self.append((other_replica_uid, other_generation,
+ other_transaction_id))
+
diff --git a/tests/__init__.py b/tests/__init__.py
index 0d7ae2b4..50c99dd4 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -8,6 +8,7 @@ import os
import u1db
from soledad import leap, GPGWrapper
+from soledad.openstack import SimpleLog, TransactionLog, SyncLog
class EncryptedSyncTestCase(unittest.TestCase):
@@ -43,6 +44,51 @@ class EncryptedSyncTestCase(unittest.TestCase):
self.assertEqual(res1, res2, 'incorrect document encryption')
+class LogTestCase(unittest.TestCase):
+
+
+ def test_transaction_log(self):
+ data = [
+ (2, "doc_3", "tran_3"),
+ (3, "doc_2", "tran_2"),
+ (1, "doc_1", "tran_1")
+ ]
+ log = TransactionLog(data)
+ self.assertEqual(log.get_generation(), 3, 'error getting generation')
+ self.assertEqual(log.get_generation_info(), (3, 'tran_2'),
+ 'error getting generation info')
+ self.assertEqual(log.get_trans_id_for_gen(1), 'tran_1',
+ 'error getting trans_id for gen')
+ self.assertEqual(log.get_trans_id_for_gen(2), 'tran_3',
+ 'error getting trans_id for gen')
+ self.assertEqual(log.get_trans_id_for_gen(3), 'tran_2',
+ 'error getting trans_id for gen')
+
+ def test_sync_log(self):
+ data = [
+ ("replica_3", 3, "tran_3"),
+ ("replica_2", 2, "tran_2"),
+ ("replica_1", 1, "tran_1")
+ ]
+ log = SyncLog(data)
+ # test getting
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'),
+ (3, 'tran_3'), 'error getting replica gen and trans id')
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_2'),
+ (2, 'tran_2'), 'error getting replica gen and trans id')
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_1'),
+ (1, 'tran_1'), 'error getting replica gen and trans id')
+ # test setting
+ log.set_replica_gen_and_trans_id('replica_1', 2, 'tran_12')
+ self.assertEqual(len(log._log), 3, 'error in log size after setting')
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_1'),
+ (2, 'tran_12'), 'error setting replica gen and trans id')
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_2'),
+ (2, 'tran_2'), 'error setting replica gen and trans id')
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'),
+ (3, 'tran_3'), 'error setting replica gen and trans id')
+
+
# Key material for testing
KEY_FINGERPRINT = "E36E738D69173C13D709E44F2F455E2824D18DDF"
PUBLIC_KEY = """