diff options
Diffstat (limited to 'client')
| -rw-r--r-- | client/changes/feature_4616_sqlite_count_by_index | 1 | ||||
| -rw-r--r-- | client/src/leap/soledad/client/__init__.py | 17 | ||||
| -rw-r--r-- | client/src/leap/soledad/client/sqlcipher.py | 53 | 
3 files changed, 70 insertions, 1 deletions
| diff --git a/client/changes/feature_4616_sqlite_count_by_index b/client/changes/feature_4616_sqlite_count_by_index new file mode 100644 index 00000000..c7819d38 --- /dev/null +++ b/client/changes/feature_4616_sqlite_count_by_index @@ -0,0 +1 @@ +  o Adds a get_count_by_index to sqlcipher u1db backend. Related to: #4616 diff --git a/client/src/leap/soledad/client/__init__.py b/client/src/leap/soledad/client/__init__.py index 62f93b3d..0d36c247 100644 --- a/client/src/leap/soledad/client/__init__.py +++ b/client/src/leap/soledad/client/__init__.py @@ -955,6 +955,23 @@ class Soledad(object):          if self._db:              return self._db.get_from_index(index_name, *key_values) +    def get_count_from_index(self, index_name, *key_values): +        """ +        Return the count of the documents that match the keys and +        values supplied. + +        :param index_name: The index to query +        :type index_name: str +        :param key_values: values to match. eg, if you have +                           an index with 3 fields then you would have: +                           get_from_index(index_name, val1, val2, val3) +        :type key_values: tuple +        :return: count. +        :rtype: int +        """ +        if self._db: +            return self._db.get_count_from_index(index_name, *key_values) +      def get_range_from_index(self, index_name, start_value, end_value):          """          Return documents that fall within the specified range. diff --git a/client/src/leap/soledad/client/sqlcipher.py b/client/src/leap/soledad/client/sqlcipher.py index 894c6f97..c7aebbe6 100644 --- a/client/src/leap/soledad/client/sqlcipher.py +++ b/client/src/leap/soledad/client/sqlcipher.py @@ -49,8 +49,8 @@ import time  import string  import threading -  from u1db.backends import sqlite_backend +from u1db import errors  from pysqlcipher import dbapi2  from u1db import errors as u1db_errors  from leap.soledad.common.document import SoledadDocument @@ -697,6 +697,57 @@ class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase):          # XXX change passphrase param!          db_handle.cursor().execute('PRAGMA rekey = "x\'%s"' % passphrase) +    # Extra query methods: extensions to the base sqlite implmentation. + +    def get_count_from_index(self, index_name, *key_values): +        """ +        Returns the count for a given combination of index_name +        and key values. + +        Extension method made from similar methods in u1db version 13.09 + +        :param index_name: The index to query +        :type index_name: str +        :param key_values: values to match. eg, if you have +                           an index with 3 fields then you would have: +                           get_from_index(index_name, val1, val2, val3) +        :type key_values: tuple +        :return: count. +        :rtype: int +        """ +        c = self._db_handle.cursor() +        definition = self._get_index_definition(index_name) + +        if len(key_values) != len(definition): +            raise errors.InvalidValueForIndex() +        tables = ["document_fields d%d" % i for i in range(len(definition))] +        novalue_where = ["d.doc_id = d%d.doc_id" +                         " AND d%d.field_name = ?" +                         % (i, i) for i in range(len(definition))] +        exact_where = [novalue_where[i] +                       + (" AND d%d.value = ?" % (i,)) +                       for i in range(len(definition))] +        args = [] +        where = [] +        for idx, (field, value) in enumerate(zip(definition, key_values)): +            args.append(field) +            where.append(exact_where[idx]) +            args.append(value) + +        tables = ["document_fields d%d" % i for i in range(len(definition))] +        statement = ( +            "SELECT COUNT(*) FROM document d, %s WHERE %s " % ( +                ', '.join(tables), +                ' AND '.join(where), +            )) +        try: +            c.execute(statement, tuple(args)) +        except dbapi2.OperationalError, e: +            raise dbapi2.OperationalError( +                str(e) + '\nstatement: %s\nargs: %s\n' % (statement, args)) +        res = c.fetchall() +        return res[0][0] +      def __del__(self):          """          Closes db_handle upon object destruction. | 
