summaryrefslogtreecommitdiff
path: root/src/leap/soledad/common/l2db/query_parser.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/soledad/common/l2db/query_parser.py')
-rw-r--r--src/leap/soledad/common/l2db/query_parser.py371
1 files changed, 371 insertions, 0 deletions
diff --git a/src/leap/soledad/common/l2db/query_parser.py b/src/leap/soledad/common/l2db/query_parser.py
new file mode 100644
index 00000000..15a9ac80
--- /dev/null
+++ b/src/leap/soledad/common/l2db/query_parser.py
@@ -0,0 +1,371 @@
+# Copyright 2011 Canonical Ltd.
+# Copyright 2016 LEAP Encryption Access Project
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+"""
+Code for parsing Index definitions.
+"""
+
+import re
+
+from leap.soledad.common.l2db import errors
+
+
+class Getter(object):
+ """Get values from a document based on a specification."""
+
+ def get(self, raw_doc):
+ """Get a value from the document.
+
+ :param raw_doc: a python dictionary to get the value from.
+ :return: A list of values that match the description.
+ """
+ raise NotImplementedError(self.get)
+
+
+class StaticGetter(Getter):
+ """A getter that returns a defined value (independent of the doc)."""
+
+ def __init__(self, value):
+ """Create a StaticGetter.
+
+ :param value: the value to return when get is called.
+ """
+ if value is None:
+ self.value = []
+ elif isinstance(value, list):
+ self.value = value
+ else:
+ self.value = [value]
+
+ def get(self, raw_doc):
+ return self.value
+
+
+def extract_field(raw_doc, subfields, index=0):
+ if not isinstance(raw_doc, dict):
+ return []
+ val = raw_doc.get(subfields[index])
+ if val is None:
+ return []
+ if index < len(subfields) - 1:
+ if isinstance(val, list):
+ results = []
+ for item in val:
+ results.extend(extract_field(item, subfields, index + 1))
+ return results
+ if isinstance(val, dict):
+ return extract_field(val, subfields, index + 1)
+ return []
+ if isinstance(val, dict):
+ return []
+ if isinstance(val, list):
+ # Strip anything in the list that isn't a simple type
+ return [v for v in val if not isinstance(v, (dict, list))]
+ return [val]
+
+
+class ExtractField(Getter):
+ """Extract a field from the document."""
+
+ def __init__(self, field):
+ """Create an ExtractField object.
+
+ When a document is passed to get() this will return a value
+ from the document based on the field specifier passed to
+ the constructor.
+
+ None will be returned if the field is nonexistant, or refers to an
+ object, rather than a simple type or list of simple types.
+
+ :param field: a specifier for the field to return.
+ This is either a field name, or a dotted field name.
+ """
+ self.field = field.split('.')
+
+ def get(self, raw_doc):
+ return extract_field(raw_doc, self.field)
+
+
+class Transformation(Getter):
+ """A transformation on a value from another Getter."""
+
+ name = None
+ arity = 1
+ args = ['expression']
+
+ def __init__(self, inner):
+ """Create a transformation.
+
+ :param inner: the argument(s) to the transformation.
+ """
+ self.inner = inner
+
+ def get(self, raw_doc):
+ inner_values = self.inner.get(raw_doc)
+ assert isinstance(inner_values, list),\
+ 'get() should always return a list'
+ return self.transform(inner_values)
+
+ def transform(self, values):
+ """Transform the values.
+
+ This should be implemented by subclasses to transform the
+ value when get() is called.
+
+ :param values: the values from the other Getter
+ :return: the transformed values.
+ """
+ raise NotImplementedError(self.transform)
+
+
+class Lower(Transformation):
+ """Lowercase a string.
+
+ This transformation will return None for non-string inputs. However,
+ it will lowercase any strings in a list, dropping any elements
+ that are not strings.
+ """
+
+ name = "lower"
+
+ def _can_transform(self, val):
+ return isinstance(val, basestring)
+
+ def transform(self, values):
+ if not values:
+ return []
+ return [val.lower() for val in values if self._can_transform(val)]
+
+
+class Number(Transformation):
+ """Convert an integer to a zero padded string.
+
+ This transformation will return None for non-integer inputs. However, it
+ will transform any integers in a list, dropping any elements that are not
+ integers.
+ """
+
+ name = 'number'
+ arity = 2
+ args = ['expression', int]
+
+ def __init__(self, inner, number):
+ super(Number, self).__init__(inner)
+ self.padding = "%%0%sd" % number
+
+ def _can_transform(self, val):
+ return isinstance(val, int) and not isinstance(val, bool)
+
+ def transform(self, values):
+ """Transform any integers in values into zero padded strings."""
+ if not values:
+ return []
+ return [self.padding % (v,) for v in values if self._can_transform(v)]
+
+
+class Bool(Transformation):
+ """Convert bool to string."""
+
+ name = "bool"
+ args = ['expression']
+
+ def _can_transform(self, val):
+ return isinstance(val, bool)
+
+ def transform(self, values):
+ """Transform any booleans in values into strings."""
+ if not values:
+ return []
+ return [('1' if v else '0') for v in values if self._can_transform(v)]
+
+
+class SplitWords(Transformation):
+ """Split a string on whitespace.
+
+ This Getter will return [] for non-string inputs. It will however
+ split any strings in an input list, discarding any elements that
+ are not strings.
+ """
+
+ name = "split_words"
+
+ def _can_transform(self, val):
+ return isinstance(val, basestring)
+
+ def transform(self, values):
+ if not values:
+ return []
+ result = set()
+ for value in values:
+ if self._can_transform(value):
+ for word in value.split():
+ result.add(word)
+ return list(result)
+
+
+class Combine(Transformation):
+ """Combine multiple expressions into a single index."""
+
+ name = "combine"
+ # variable number of args
+ arity = -1
+
+ def __init__(self, *inner):
+ super(Combine, self).__init__(inner)
+
+ def get(self, raw_doc):
+ inner_values = []
+ for inner in self.inner:
+ inner_values.extend(inner.get(raw_doc))
+ return self.transform(inner_values)
+
+ def transform(self, values):
+ return values
+
+
+class IsNull(Transformation):
+ """Indicate whether the input is None.
+
+ This Getter returns a bool indicating whether the input is nil.
+ """
+
+ name = "is_null"
+
+ def transform(self, values):
+ return [len(values) == 0]
+
+
+def check_fieldname(fieldname):
+ if fieldname.endswith('.'):
+ raise errors.IndexDefinitionParseError(
+ "Fieldname cannot end in '.':%s^" % (fieldname,))
+
+
+class Parser(object):
+ """Parse an index expression into a sequence of transformations."""
+
+ _transformations = {}
+ _delimiters = re.compile("\(|\)|,")
+
+ def __init__(self):
+ self._tokens = []
+
+ def _set_expression(self, expression):
+ self._open_parens = 0
+ self._tokens = []
+ expression = expression.strip()
+ while expression:
+ delimiter = self._delimiters.search(expression)
+ if delimiter:
+ idx = delimiter.start()
+ if idx == 0:
+ result, expression = (expression[:1], expression[1:])
+ self._tokens.append(result)
+ else:
+ result, expression = (expression[:idx], expression[idx:])
+ result = result.strip()
+ if result:
+ self._tokens.append(result)
+ else:
+ expression = expression.strip()
+ if expression:
+ self._tokens.append(expression)
+ expression = None
+
+ def _get_token(self):
+ if self._tokens:
+ return self._tokens.pop(0)
+
+ def _peek_token(self):
+ if self._tokens:
+ return self._tokens[0]
+
+ @staticmethod
+ def _to_getter(term):
+ if isinstance(term, Getter):
+ return term
+ check_fieldname(term)
+ return ExtractField(term)
+
+ def _parse_op(self, op_name):
+ self._get_token() # '('
+ op = self._transformations.get(op_name, None)
+ if op is None:
+ raise errors.IndexDefinitionParseError(
+ "Unknown operation: %s" % op_name)
+ args = []
+ while True:
+ args.append(self._parse_term())
+ sep = self._get_token()
+ if sep == ')':
+ break
+ if sep != ',':
+ raise errors.IndexDefinitionParseError(
+ "Unexpected token '%s' in parentheses." % (sep,))
+ parsed = []
+ for i, arg in enumerate(args):
+ arg_type = op.args[i % len(op.args)]
+ if arg_type == 'expression':
+ inner = self._to_getter(arg)
+ else:
+ try:
+ inner = arg_type(arg)
+ except ValueError as e:
+ raise errors.IndexDefinitionParseError(
+ "Invalid value %r for argument type %r "
+ "(%r)." % (arg, arg_type, e))
+ parsed.append(inner)
+ return op(*parsed)
+
+ def _parse_term(self):
+ term = self._get_token()
+ if term is None:
+ raise errors.IndexDefinitionParseError(
+ "Unexpected end of index definition.")
+ if term in (',', ')', '('):
+ raise errors.IndexDefinitionParseError(
+ "Unexpected token '%s' at start of expression." % (term,))
+ next_token = self._peek_token()
+ if next_token == '(':
+ return self._parse_op(term)
+ return term
+
+ def parse(self, expression):
+ self._set_expression(expression)
+ term = self._to_getter(self._parse_term())
+ if self._peek_token():
+ raise errors.IndexDefinitionParseError(
+ "Unexpected token '%s' after end of expression."
+ % (self._peek_token(),))
+ return term
+
+ def parse_all(self, fields):
+ return [self.parse(field) for field in fields]
+
+ @classmethod
+ def register_transormation(cls, transform):
+ assert transform.name not in cls._transformations, (
+ "Transform %s already registered for %s"
+ % (transform.name, cls._transformations[transform.name]))
+ cls._transformations[transform.name] = transform
+
+
+Parser.register_transormation(SplitWords)
+Parser.register_transormation(Lower)
+Parser.register_transormation(Number)
+Parser.register_transormation(Bool)
+Parser.register_transormation(IsNull)
+Parser.register_transormation(Combine)