From 83e4e5e65aca6307923d92a7915f3851f1206501 Mon Sep 17 00:00:00 2001
From: Victor Shyba
Date: Wed, 25 Oct 2017 12:49:06 -0300
Subject: [refactor] split blobs into modules

So we can have manager, sync, sql and errors each in its own place.

-- Related: #8970
---
 src/leap/soledad/client/_db/blobs/sql.py | 266 +++++++++++++++++++++++++++++++
 1 file changed, 266 insertions(+)
 create mode 100644 src/leap/soledad/client/_db/blobs/sql.py

diff --git a/src/leap/soledad/client/_db/blobs/sql.py b/src/leap/soledad/client/_db/blobs/sql.py
new file mode 100644
index 00000000..dcca1671
--- /dev/null
+++ b/src/leap/soledad/client/_db/blobs/sql.py
@@ -0,0 +1,266 @@
+# -*- coding: utf-8 -*-
+# sql.py
+# Copyright (C) 2017 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Local blobs backend on SQLCipher
+"""
+import os
+import binascii
+from functools import partial
+from twisted.internet import defer
+from twisted.logger import Logger
+from twisted.enterprise import adbapi
+from leap.common.files import mkdir_p
+from .. import sqlcipher
+from .. import pragmas
+from io import BytesIO
+logger = Logger()
+
+
+class SyncStatus:
+    SYNCED = 1
+    PENDING_UPLOAD = 2
+    PENDING_DOWNLOAD = 3
+    FAILED_UPLOAD = 4
+    FAILED_DOWNLOAD = 5
+    UNAVAILABLE_STATUSES = (PENDING_DOWNLOAD, FAILED_DOWNLOAD)
+
+
+class SQLiteBlobBackend(object):
+
+    concurrency_limit = 10
+
+    def __init__(self, path, key=None, user=None):
+        dbname = '%s_blobs.db' % (user or 'soledad')
+        self.path = os.path.abspath(
+            os.path.join(path, dbname))
+        mkdir_p(os.path.dirname(self.path))
+        if not key:
+            raise ValueError('key cannot be None')
+        backend = 'pysqlcipher.dbapi2'
+        opts = sqlcipher.SQLCipherOptions(
+            '/tmp/ignored', binascii.b2a_hex(key),
+            is_raw_key=True, create=True)
+        openfun = partial(pragmas.set_init_pragmas, opts=opts,
+                          schema_func=_init_blob_table)
+
+        self.dbpool = ConnectionPool(
+            backend, self.path, check_same_thread=False, timeout=5,
+            cp_openfun=openfun, cp_min=2, cp_max=2, cp_name='blob_pool')
+
+    def close(self):
+        from twisted._threads import AlreadyQuit
+        try:
+            self.dbpool.close()
+        except AlreadyQuit:
+            pass
+
+    @defer.inlineCallbacks
+    def put(self, blob_id, blob_fd, size=None,
+            namespace='', status=SyncStatus.PENDING_UPLOAD):
+        previous_state = yield self.get_sync_status(blob_id)
+        unavailable = SyncStatus.UNAVAILABLE_STATUSES
+        if previous_state and previous_state[0] in unavailable:
+            yield self.delete(blob_id, namespace=namespace)
+            status = SyncStatus.SYNCED
+        logger.info("Saving blob in local database...")
+        insert = 'INSERT INTO blobs (blob_id, namespace, payload, sync_status)'
+        insert += ' VALUES (?, ?, zeroblob(?), ?)'
+        values = (blob_id, namespace, size, status)
+        irow = yield self.dbpool.insertAndGetLastRowid(insert, values)
+        yield self.dbpool.write_blob('blobs', 'payload', irow, blob_fd)
+        logger.info("Finished saving blob in local database.")
+
+    @defer.inlineCallbacks
+    def get(self, blob_id, namespace=''):
+        # TODO we can also stream the blob value using sqlite
+        # incremental interface for blobs - and just return the raw fd instead
+        select = 'SELECT payload FROM blobs WHERE blob_id = ? AND namespace = ?'
+        values = (blob_id, namespace,)
+        avoid_values = SyncStatus.UNAVAILABLE_STATUSES
+        select += ' AND sync_status NOT IN (%s)'
+        select %= ','.join(['?' for _ in avoid_values])
+        values += avoid_values
+        result = yield self.dbpool.runQuery(select, values)
+        if result:
+            defer.returnValue(BytesIO(str(result[0][0])))
+
+    @defer.inlineCallbacks
+    def get_sync_status(self, blob_id):
+        select = 'SELECT sync_status, retries FROM blobs WHERE blob_id = ?'
+        result = yield self.dbpool.runQuery(select, (blob_id,))
+        if result:
+            defer.returnValue((result[0][0], result[0][1]))
+
+    @defer.inlineCallbacks
+    def list(self, namespace='', sync_status=False):
+        query = 'select blob_id from blobs where namespace = ?'
+        values = (namespace,)
+        if sync_status:
+            query += ' and sync_status = ?'
+            values += (sync_status,)
+        else:
+            avoid_values = SyncStatus.UNAVAILABLE_STATUSES
+            query += ' AND sync_status NOT IN (%s)'
+            query %= ','.join(['?' for _ in avoid_values])
+            values += avoid_values
+        result = yield self.dbpool.runQuery(query, values)
+        if result:
+            defer.returnValue([b_id[0] for b_id in result])
+        else:
+            defer.returnValue([])
+
+    @defer.inlineCallbacks
+    def update_sync_status(self, blob_id, sync_status, namespace=""):
+        query = 'SELECT sync_status FROM blobs WHERE blob_id = ?'
+        result = yield self.dbpool.runQuery(query, (blob_id,))
+
+        if not result:
+            insert = 'INSERT INTO blobs'
+            insert += ' (blob_id, namespace, payload, sync_status)'
+            insert += ' VALUES (?, ?, zeroblob(0), ?)'
+            values = (blob_id, namespace, sync_status)
+            yield self.dbpool.runOperation(insert, values)
+            return
+
+        update = 'UPDATE blobs SET sync_status = ? WHERE blob_id = ?'
+        values = (sync_status, blob_id,)
+        result = yield self.dbpool.runOperation(update, values)
+
+    def update_batch_sync_status(self, blob_id_list, sync_status,
+                                 namespace=''):
+        insert = 'INSERT INTO blobs (blob_id, namespace, payload, sync_status)'
+        first_blob_id, blob_id_list = blob_id_list[0], blob_id_list[1:]
+        insert += ' VALUES (?, ?, zeroblob(0), ?)'
+        values = (first_blob_id, namespace, sync_status)
+        for blob_id in blob_id_list:
+            insert += ', (?, ?, zeroblob(0), ?)'
+            values += (blob_id, namespace, sync_status)
+        return self.dbpool.runQuery(insert, values)
+
+    def increment_retries(self, blob_id):
+        query = 'update blobs set retries = retries + 1 where blob_id = ?'
+        return self.dbpool.runQuery(query, (blob_id,))
+
+    @defer.inlineCallbacks
+    def list_namespaces(self):
+        query = 'select namespace from blobs'
+        result = yield self.dbpool.runQuery(query)
+        if result:
+            defer.returnValue([namespace[0] for namespace in result])
+        else:
+            defer.returnValue([])
+
+    @defer.inlineCallbacks
+    def exists(self, blob_id, namespace=''):
+        query = 'SELECT blob_id from blobs WHERE blob_id = ? AND namespace = ?'
+        result = yield self.dbpool.runQuery(query, (blob_id, namespace,))
+        defer.returnValue(bool(len(result)))
+
+    def delete(self, blob_id, namespace=''):
+        query = 'DELETE FROM blobs WHERE blob_id = ? AND namespace = ?'
+        return self.dbpool.runQuery(query, (blob_id, namespace,))
+
+
+def _init_sync_table(conn):
+    maybe_create = """
+        CREATE TABLE IF NOT EXISTS
+        sync_state (
+        blob_id PRIMARY KEY,
+        sync_status INT default %s,
+        retries INT default 0)"""
+    default_status = SyncStatus.PENDING_UPLOAD
+    maybe_create %= default_status
+    conn.execute(maybe_create)
+
+
+def _init_blob_table(conn):
+    maybe_create = (
+        "CREATE TABLE IF NOT EXISTS "
+        "blobs ("
+        "blob_id PRIMARY KEY, "
+        "payload BLOB)")
+    conn.execute(maybe_create)
+    columns = [row[1] for row in conn.execute("pragma"
+               " table_info(blobs)").fetchall()]
+    if 'namespace' not in columns:
+        # namespace migration
+        conn.execute('ALTER TABLE blobs ADD COLUMN namespace TEXT')
+    if 'sync_status' not in columns:
+        # sync status migration
+        default_status = SyncStatus.PENDING_UPLOAD
+        sync_column = 'ALTER TABLE blobs ADD COLUMN sync_status INT default %s'
+        sync_column %= default_status
+        conn.execute(sync_column)
+        conn.execute('ALTER TABLE blobs ADD COLUMN retries INT default 0')
+
+
+class ConnectionPool(adbapi.ConnectionPool):
+
+    def insertAndGetLastRowid(self, *args, **kwargs):
+        """
+        Execute an SQL query and return the last rowid.
+
+        See: https://sqlite.org/c3ref/last_insert_rowid.html
+        """
+        return self.runInteraction(
+            self._insertAndGetLastRowid, *args, **kwargs)
+
+    def _insertAndGetLastRowid(self, trans, *args, **kw):
+        trans.execute(*args, **kw)
+        return trans.lastrowid
+
+    def blob(self, table, column, irow, flags):
+        """
+        Open a BLOB for incremental I/O.
+
+        Return a handle to the BLOB that would be selected by:
+
+          SELECT column FROM table WHERE rowid = irow;
+
+        See: https://sqlite.org/c3ref/blob_open.html
+
+        :param table: The table in which to lookup the blob.
+        :type table: str
+        :param column: The column where the BLOB is located.
+        :type column: str
+        :param irow: The rowid of the BLOB.
+        :type irow: int
+        :param flags: If zero, the BLOB is opened read-only. If non-zero,
+                      it is opened for reading and writing.
+        :type flags: int
+
+        :return: A BLOB handle.
+        :rtype: pysqlcipher.dbapi2.Blob
+        """
+        return self.runInteraction(self._blob, table, column, irow, flags)
+
+    def write_blob(self, table, column, irow, blob_fd):
+        return self.runInteraction(self._write_blob, table, column, irow,
+                                   blob_fd)
+
+    def _write_blob(self, trans, table, column, irow, blob_fd):
+        blob_fd.seek(0)
+        with trans._connection.blob(table, column, irow, 1) as handle:
+            data = blob_fd.read(2**12)
+            while data:
+                handle.write(data)
+                data = blob_fd.read(2**12)
+
+    def _blob(self, trans, table, column, irow, flags):
+        # TODO: should not use transaction private variable here
+        handle = trans._connection.blob(table, column, irow, flags)
+        return handle
-- cgit v1.2.3
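
A minimal usage sketch of the backend this patch introduces, assuming
pysqlcipher and leap.soledad are installed and a Twisted reactor can run;
the path, user, key derivation and blob id below are illustrative only:

    # Sketch: store a blob and read it back through SQLiteBlobBackend.
    import os
    from io import BytesIO
    from twisted.internet import defer, task

    from leap.soledad.client._db.blobs.sql import SQLiteBlobBackend

    @defer.inlineCallbacks
    def main(reactor):
        key = os.urandom(32)  # illustrative; real code derives this key
        backend = SQLiteBlobBackend('/tmp/blobs-example', key=key,
                                    user='alice')
        data = b'hello blob'
        # put() preallocates the row with zeroblob(size), then streams
        # the file object into it through incremental blob I/O.
        yield backend.put('blob-1', BytesIO(data), size=len(data))
        fd = yield backend.get('blob-1')
        print(fd.read())
        backend.close()

    task.react(main)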
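The zeroblob-plus-streaming pattern behind put(), insertAndGetLastRowid()
and write_blob() is standard SQLite incremental blob I/O. The sketch below
shows the same idea against the stdlib sqlite3 module, whose
Connection.blobopen() (Python 3.11+) plays the role of pysqlcipher's
Connection.blob() used in the patch:

    # Sketch of zeroblob + incremental write on plain sqlite3.
    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE blobs (blob_id PRIMARY KEY, payload BLOB)')
    payload = b'x' * (2 ** 20)  # 1 MiB of example data

    # Reserve space without materializing the payload in the SQL statement.
    cur = conn.execute('INSERT INTO blobs VALUES (?, zeroblob(?))',
                       ('b1', len(payload)))
    rowid = cur.lastrowid  # same role as insertAndGetLastRowid() above

    # Stream into the reserved region in 4 KiB chunks, as _write_blob() does.
    with conn.blobopen('blobs', 'payload', rowid) as handle:
        for offset in range(0, len(payload), 2 ** 12):
            handle.write(payload[offset:offset + 2 ** 12])

    conn.commit()

This keeps memory usage bounded by the chunk size instead of the blob size,
which is why the TODO in get() suggests using the same interface for reads.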
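update_batch_sync_status() builds one multi-row VALUES insert (supported by
SQLite 3.7.11 and later) instead of issuing one INSERT per blob. As an
illustration of the expansion, for blob_id_list = ['a', 'b', 'c'],
namespace = '' and sync_status = SyncStatus.PENDING_DOWNLOAD (3), the
generated statement and parameters are:

    # What the string-building loop in update_batch_sync_status() produces:
    insert = ("INSERT INTO blobs (blob_id, namespace, payload, sync_status)"
              " VALUES (?, ?, zeroblob(0), ?)"
              ", (?, ?, zeroblob(0), ?)"
              ", (?, ?, zeroblob(0), ?)")
    values = ('a', '', 3, 'b', '', 3, 'c', '', 3)
    # zeroblob(0) reserves an empty payload: these rows only track sync
    # state until the actual blob content is downloaded.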