summaryrefslogtreecommitdiff
path: root/client/src
diff options
context:
space:
mode:
authorKali Kaneko <kali@leap.se>2013-12-13 05:16:20 -0400
committerKali Kaneko <kali@leap.se>2013-12-13 14:00:19 -0400
commitb1a5a88a5f53ac9ff5a56625620b45e949404a99 (patch)
tree2631693f7b308487a8e72c3add8fc48446280af2 /client/src
parentbc56fd56528feb0e5417ed943f07a49d79ddb960 (diff)
get_count_from_index
Diffstat (limited to 'client/src')
-rw-r--r--client/src/leap/soledad/client/__init__.py17
-rw-r--r--client/src/leap/soledad/client/sqlcipher.py53
2 files changed, 69 insertions, 1 deletions
diff --git a/client/src/leap/soledad/client/__init__.py b/client/src/leap/soledad/client/__init__.py
index 62f93b3d..0d36c247 100644
--- a/client/src/leap/soledad/client/__init__.py
+++ b/client/src/leap/soledad/client/__init__.py
@@ -955,6 +955,23 @@ class Soledad(object):
if self._db:
return self._db.get_from_index(index_name, *key_values)
+ def get_count_from_index(self, index_name, *key_values):
+ """
+ Return the count of the documents that match the keys and
+ values supplied.
+
+ :param index_name: The index to query
+ :type index_name: str
+ :param key_values: values to match. eg, if you have
+ an index with 3 fields then you would have:
+ get_from_index(index_name, val1, val2, val3)
+ :type key_values: tuple
+ :return: count.
+ :rtype: int
+ """
+ if self._db:
+ return self._db.get_count_from_index(index_name, *key_values)
+
def get_range_from_index(self, index_name, start_value, end_value):
"""
Return documents that fall within the specified range.
diff --git a/client/src/leap/soledad/client/sqlcipher.py b/client/src/leap/soledad/client/sqlcipher.py
index 894c6f97..c7aebbe6 100644
--- a/client/src/leap/soledad/client/sqlcipher.py
+++ b/client/src/leap/soledad/client/sqlcipher.py
@@ -49,8 +49,8 @@ import time
import string
import threading
-
from u1db.backends import sqlite_backend
+from u1db import errors
from pysqlcipher import dbapi2
from u1db import errors as u1db_errors
from leap.soledad.common.document import SoledadDocument
@@ -697,6 +697,57 @@ class SQLCipherDatabase(sqlite_backend.SQLitePartialExpandDatabase):
# XXX change passphrase param!
db_handle.cursor().execute('PRAGMA rekey = "x\'%s"' % passphrase)
+ # Extra query methods: extensions to the base sqlite implmentation.
+
+ def get_count_from_index(self, index_name, *key_values):
+ """
+ Returns the count for a given combination of index_name
+ and key values.
+
+ Extension method made from similar methods in u1db version 13.09
+
+ :param index_name: The index to query
+ :type index_name: str
+ :param key_values: values to match. eg, if you have
+ an index with 3 fields then you would have:
+ get_from_index(index_name, val1, val2, val3)
+ :type key_values: tuple
+ :return: count.
+ :rtype: int
+ """
+ c = self._db_handle.cursor()
+ definition = self._get_index_definition(index_name)
+
+ if len(key_values) != len(definition):
+ raise errors.InvalidValueForIndex()
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = ["d.doc_id = d%d.doc_id"
+ " AND d%d.field_name = ?"
+ % (i, i) for i in range(len(definition))]
+ exact_where = [novalue_where[i]
+ + (" AND d%d.value = ?" % (i,))
+ for i in range(len(definition))]
+ args = []
+ where = []
+ for idx, (field, value) in enumerate(zip(definition, key_values)):
+ args.append(field)
+ where.append(exact_where[idx])
+ args.append(value)
+
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ statement = (
+ "SELECT COUNT(*) FROM document d, %s WHERE %s " % (
+ ', '.join(tables),
+ ' AND '.join(where),
+ ))
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(
+ str(e) + '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ return res[0][0]
+
def __del__(self):
"""
Closes db_handle upon object destruction.