diff options
-rw-r--r-- | client/changes/feature_4616_sqlite_count_by_index | 1 | ||||
-rw-r--r-- | client/src/leap/soledad/client/__init__.py | 17 | ||||
-rw-r--r-- | client/src/leap/soledad/client/sqlcipher.py | 53 |
3 files changed, 70 insertions, 1 deletions
diff --git a/client/changes/feature_4616_sqlite_count_by_index b/client/changes/feature_4616_sqlite_count_by_index new file mode 100644 index 00000000..c7819d38 --- /dev/null +++ b/client/changes/feature_4616_sqlite_count_by_index @@ -0,0 +1 @@ + o Adds a get_count_by_index to sqlcipher u1db backend. Related to: #4616 diff --git a/client/src/leap/soledad/client/__init__.py b/client/src/leap/soledad/client/__init__.py index 4c1a4195..61337680 100644 --- a/client/src/leap/soledad/client/__init__.py +++ b/client/src/leap/soledad/client/__init__.py @@ -973,6 +973,23 @@ class Soledad(object): if self._db: return self._db.get_from_index(index_name, *key_values) + def get_count_from_index(self, index_name, *key_values): + """ + Return the count of the documents that match the keys and + values supplied. + + :param index_name: The index to query + :type index_name: str + :param key_values: values to match. eg, if you have + an index with 3 fields then you would have: + get_from_index(index_name, val1, val2, val3) + :type key_values: tuple + :return: count. + :rtype: int + """ + if self._db: + return self._db.get_count_from_index(index_name, *key_values) + def get_range_from_index(self, index_name, start_value, end_value): """ Return documents that fall within the specified range. diff --git a/client/src/leap/soledad/client/sqlcipher.py b/client/src/leap/soledad/client/sqlcipher.py index 894c6f97..c7aebbe6 100644 --- a/client/src/leap/soledad/client/sqlcipher.py +++ b/client/src/leap/soledad/client/sqlcipher.py @@ -49,8 +49,8 @@ import time import string import threading - from u1db.backends import sqlite_backend +from u1db import errors from pysqlcipher import dbapi2 from u1db import errors as u1db_errors from leap.soledad.common.document import SoledadDocument @@ -697,6 +697,57 @@ class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase): # XXX change passphrase param! db_handle.cursor().execute('PRAGMA rekey = "x\'%s"' % passphrase) + # Extra query methods: extensions to the base sqlite implmentation. + + def get_count_from_index(self, index_name, *key_values): + """ + Returns the count for a given combination of index_name + and key values. + + Extension method made from similar methods in u1db version 13.09 + + :param index_name: The index to query + :type index_name: str + :param key_values: values to match. eg, if you have + an index with 3 fields then you would have: + get_from_index(index_name, val1, val2, val3) + :type key_values: tuple + :return: count. + :rtype: int + """ + c = self._db_handle.cursor() + definition = self._get_index_definition(index_name) + + if len(key_values) != len(definition): + raise errors.InvalidValueForIndex() + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = ["d.doc_id = d%d.doc_id" + " AND d%d.field_name = ?" + % (i, i) for i in range(len(definition))] + exact_where = [novalue_where[i] + + (" AND d%d.value = ?" % (i,)) + for i in range(len(definition))] + args = [] + where = [] + for idx, (field, value) in enumerate(zip(definition, key_values)): + args.append(field) + where.append(exact_where[idx]) + args.append(value) + + tables = ["document_fields d%d" % i for i in range(len(definition))] + statement = ( + "SELECT COUNT(*) FROM document d, %s WHERE %s " % ( + ', '.join(tables), + ' AND '.join(where), + )) + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError( + str(e) + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + return res[0][0] + def __del__(self): """ Closes db_handle upon object destruction. |