summaryrefslogtreecommitdiff
path: root/client/src/leap/soledad/client/adbapi.py
diff options
context:
space:
mode:
Diffstat (limited to 'client/src/leap/soledad/client/adbapi.py')
-rw-r--r--client/src/leap/soledad/client/adbapi.py146
1 files changed, 110 insertions, 36 deletions
diff --git a/client/src/leap/soledad/client/adbapi.py b/client/src/leap/soledad/client/adbapi.py
index 730999a3..3b15509b 100644
--- a/client/src/leap/soledad/client/adbapi.py
+++ b/client/src/leap/soledad/client/adbapi.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# sqlcipher.py
+# adbapi.py
# Copyright (C) 2013, 2014 LEAP
#
# This program is free software: you can redistribute it and/or modify
@@ -17,61 +17,135 @@
"""
An asyncrhonous interface to soledad using sqlcipher backend.
It uses twisted.enterprise.adbapi.
-
"""
+import re
import os
import sys
+from functools import partial
+
+import u1db
+from u1db.backends import sqlite_backend
+
from twisted.enterprise import adbapi
from twisted.python import log
+from leap.soledad.client.sqlcipher import set_init_pragmas
+
+
DEBUG_SQL = os.environ.get("LEAP_DEBUG_SQL")
if DEBUG_SQL:
log.startLogging(sys.stdout)
-def getConnectionPool(db=None, key=None):
- return SQLCipherConnectionPool(
- "pysqlcipher.dbapi2", database=db, key=key, check_same_thread=False)
+def getConnectionPool(opts, openfun=None, driver="pysqlcipher"):
+ if openfun is None and driver == "pysqlcipher":
+ openfun = partial(set_init_pragmas, opts=opts)
+ return U1DBConnectionPool(
+ "%s.dbapi2" % driver, database=opts.path,
+ check_same_thread=False, cp_openfun=openfun)
-class SQLCipherConnectionPool(adbapi.ConnectionPool):
+# XXX work in progress --------------------------------------------
- key = None
- def connect(self):
- """
- Return a database connection when one becomes available.
+class U1DBSqliteWrapper(sqlite_backend.SQLitePartialExpandDatabase):
+ """
+ A very simple wrapper around sqlcipher backend.
- This method blocks and should be run in a thread from the internal
- threadpool. Don't call this method directly from non-threaded code.
- Using this method outside the external threadpool may exceed the
- maximum number of connections in the pool.
+ Instead of initializing the database on the fly, it just uses an existing
+ connection that is passed to it in the initializer.
+ """
- :return: a database connection from the pool.
- """
- self.noisy = DEBUG_SQL
+ def __init__(self, conn):
+ self._db_handle = conn
+ self._real_replica_uid = None
+ self._ensure_schema()
+ self._factory = u1db.Document
- tid = self.threadID()
- conn = self.connections.get(tid)
- if self.key is None:
- self.key = self.connkw.pop('key', None)
+class U1DBConnection(adbapi.Connection):
- if conn is None:
- if self.noisy:
- log.msg('adbapi connecting: %s %s%s' % (self.dbapiName,
- self.connargs or '',
- self.connkw or ''))
- conn = self.dbapi.connect(*self.connargs, **self.connkw)
+ u1db_wrapper = U1DBSqliteWrapper
+
+ def __init__(self, pool, init_u1db=False):
+ self.init_u1db = init_u1db
+ adbapi.Connection.__init__(self, pool)
+
+ def reconnect(self):
+ if self._connection is not None:
+ self._pool.disconnect(self._connection)
+ self._connection = self._pool.connect()
+
+ if self.init_u1db:
+ self._u1db = self.u1db_wrapper(self._connection)
+
+ def __getattr__(self, name):
+ if name.startswith('u1db_'):
+ meth = re.sub('^u1db_', '', name)
+ return getattr(self._u1db, meth)
+ else:
+ return getattr(self._connection, name)
- # XXX we should hook here all OUR SOLEDAD pragmas -----
- conn.cursor().execute("PRAGMA key=%s" % self.key)
- conn.commit()
- # -----------------------------------------------------
- # XXX profit of openfun isntead???
- if self.openfun is not None:
- self.openfun(conn)
- self.connections[tid] = conn
- return conn
+class U1DBTransaction(adbapi.Transaction):
+
+ def __getattr__(self, name):
+ if name.startswith('u1db_'):
+ meth = re.sub('^u1db_', '', name)
+ return getattr(self._connection._u1db, meth)
+ else:
+ return getattr(self._cursor, name)
+
+
+class U1DBConnectionPool(adbapi.ConnectionPool):
+
+ connectionFactory = U1DBConnection
+ transactionFactory = U1DBTransaction
+
+ def __init__(self, *args, **kwargs):
+ adbapi.ConnectionPool.__init__(self, *args, **kwargs)
+ # all u1db connections, hashed by thread-id
+ self.u1dbconnections = {}
+
+ def runU1DBQuery(self, meth, *args, **kw):
+ meth = "u1db_%s" % meth
+ return self.runInteraction(self._runU1DBQuery, meth, *args, **kw)
+
+ def _runU1DBQuery(self, trans, meth, *args, **kw):
+ meth = getattr(trans, meth)
+ return meth(*args, **kw)
+
+ def _runInteraction(self, interaction, *args, **kw):
+ tid = self.threadID()
+ u1db = self.u1dbconnections.get(tid)
+ conn = self.connectionFactory(self, init_u1db=not bool(u1db))
+
+ if u1db is None:
+ self.u1dbconnections[tid] = conn._u1db
+ else:
+ conn._u1db = u1db
+
+ trans = self.transactionFactory(self, conn)
+ try:
+ result = interaction(trans, *args, **kw)
+ trans.close()
+ conn.commit()
+ return result
+ except:
+ excType, excValue, excTraceback = sys.exc_info()
+ try:
+ conn.rollback()
+ except:
+ log.err(None, "Rollback failed")
+ raise excType, excValue, excTraceback
+
+ def finalClose(self):
+ self.shutdownID = None
+ self.threadpool.stop()
+ self.running = False
+ for conn in self.connections.values():
+ self._close(conn)
+ for u1db in self.u1dbconnections.values():
+ self._close(u1db)
+ self.connections.clear()