diff options
Diffstat (limited to 'client/src/leap/soledad/client/adbapi.py')
-rw-r--r-- | client/src/leap/soledad/client/adbapi.py | 146 |
1 files changed, 110 insertions, 36 deletions
diff --git a/client/src/leap/soledad/client/adbapi.py b/client/src/leap/soledad/client/adbapi.py index 730999a3..3b15509b 100644 --- a/client/src/leap/soledad/client/adbapi.py +++ b/client/src/leap/soledad/client/adbapi.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# sqlcipher.py +# adbapi.py # Copyright (C) 2013, 2014 LEAP # # This program is free software: you can redistribute it and/or modify @@ -17,61 +17,135 @@ """ An asyncrhonous interface to soledad using sqlcipher backend. It uses twisted.enterprise.adbapi. - """ +import re import os import sys +from functools import partial + +import u1db +from u1db.backends import sqlite_backend + from twisted.enterprise import adbapi from twisted.python import log +from leap.soledad.client.sqlcipher import set_init_pragmas + + DEBUG_SQL = os.environ.get("LEAP_DEBUG_SQL") if DEBUG_SQL: log.startLogging(sys.stdout) -def getConnectionPool(db=None, key=None): - return SQLCipherConnectionPool( - "pysqlcipher.dbapi2", database=db, key=key, check_same_thread=False) +def getConnectionPool(opts, openfun=None, driver="pysqlcipher"): + if openfun is None and driver == "pysqlcipher": + openfun = partial(set_init_pragmas, opts=opts) + return U1DBConnectionPool( + "%s.dbapi2" % driver, database=opts.path, + check_same_thread=False, cp_openfun=openfun) -class SQLCipherConnectionPool(adbapi.ConnectionPool): +# XXX work in progress -------------------------------------------- - key = None - def connect(self): - """ - Return a database connection when one becomes available. +class U1DBSqliteWrapper(sqlite_backend.SQLitePartialExpandDatabase): + """ + A very simple wrapper around sqlcipher backend. - This method blocks and should be run in a thread from the internal - threadpool. Don't call this method directly from non-threaded code. - Using this method outside the external threadpool may exceed the - maximum number of connections in the pool. + Instead of initializing the database on the fly, it just uses an existing + connection that is passed to it in the initializer. + """ - :return: a database connection from the pool. - """ - self.noisy = DEBUG_SQL + def __init__(self, conn): + self._db_handle = conn + self._real_replica_uid = None + self._ensure_schema() + self._factory = u1db.Document - tid = self.threadID() - conn = self.connections.get(tid) - if self.key is None: - self.key = self.connkw.pop('key', None) +class U1DBConnection(adbapi.Connection): - if conn is None: - if self.noisy: - log.msg('adbapi connecting: %s %s%s' % (self.dbapiName, - self.connargs or '', - self.connkw or '')) - conn = self.dbapi.connect(*self.connargs, **self.connkw) + u1db_wrapper = U1DBSqliteWrapper + + def __init__(self, pool, init_u1db=False): + self.init_u1db = init_u1db + adbapi.Connection.__init__(self, pool) + + def reconnect(self): + if self._connection is not None: + self._pool.disconnect(self._connection) + self._connection = self._pool.connect() + + if self.init_u1db: + self._u1db = self.u1db_wrapper(self._connection) + + def __getattr__(self, name): + if name.startswith('u1db_'): + meth = re.sub('^u1db_', '', name) + return getattr(self._u1db, meth) + else: + return getattr(self._connection, name) - # XXX we should hook here all OUR SOLEDAD pragmas ----- - conn.cursor().execute("PRAGMA key=%s" % self.key) - conn.commit() - # ----------------------------------------------------- - # XXX profit of openfun isntead??? - if self.openfun is not None: - self.openfun(conn) - self.connections[tid] = conn - return conn +class U1DBTransaction(adbapi.Transaction): + + def __getattr__(self, name): + if name.startswith('u1db_'): + meth = re.sub('^u1db_', '', name) + return getattr(self._connection._u1db, meth) + else: + return getattr(self._cursor, name) + + +class U1DBConnectionPool(adbapi.ConnectionPool): + + connectionFactory = U1DBConnection + transactionFactory = U1DBTransaction + + def __init__(self, *args, **kwargs): + adbapi.ConnectionPool.__init__(self, *args, **kwargs) + # all u1db connections, hashed by thread-id + self.u1dbconnections = {} + + def runU1DBQuery(self, meth, *args, **kw): + meth = "u1db_%s" % meth + return self.runInteraction(self._runU1DBQuery, meth, *args, **kw) + + def _runU1DBQuery(self, trans, meth, *args, **kw): + meth = getattr(trans, meth) + return meth(*args, **kw) + + def _runInteraction(self, interaction, *args, **kw): + tid = self.threadID() + u1db = self.u1dbconnections.get(tid) + conn = self.connectionFactory(self, init_u1db=not bool(u1db)) + + if u1db is None: + self.u1dbconnections[tid] = conn._u1db + else: + conn._u1db = u1db + + trans = self.transactionFactory(self, conn) + try: + result = interaction(trans, *args, **kw) + trans.close() + conn.commit() + return result + except: + excType, excValue, excTraceback = sys.exc_info() + try: + conn.rollback() + except: + log.err(None, "Rollback failed") + raise excType, excValue, excTraceback + + def finalClose(self): + self.shutdownID = None + self.threadpool.stop() + self.running = False + for conn in self.connections.values(): + self._close(conn) + for u1db in self.u1dbconnections.values(): + self._close(u1db) + self.connections.clear() |