Diffstat (limited to 'src')
90 files changed, 16623 insertions, 0 deletions
diff --git a/src/leap/__init__.py b/src/leap/__init__.py new file mode 100644 index 00000000..f48ad105 --- /dev/null +++ b/src/leap/__init__.py @@ -0,0 +1,6 @@ +# See http://peak.telecommunity.com/DevCenter/setuptools#namespace-packages +try: +    __import__('pkg_resources').declare_namespace(__name__) +except ImportError: +    from pkgutil import extend_path +    __path__ = extend_path(__path__, __name__) diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py new file mode 100644 index 00000000..f48ad105 --- /dev/null +++ b/src/leap/soledad/__init__.py @@ -0,0 +1,6 @@ +# See http://peak.telecommunity.com/DevCenter/setuptools#namespace-packages +try: +    __import__('pkg_resources').declare_namespace(__name__) +except ImportError: +    from pkgutil import extend_path +    __path__ = extend_path(__path__, __name__) diff --git a/src/leap/soledad/_version.py b/src/leap/soledad/_version.py new file mode 100644 index 00000000..4d465c10 --- /dev/null +++ b/src/leap/soledad/_version.py @@ -0,0 +1,484 @@ + +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. Generated by +# versioneer-0.16 (https://github.com/warner/python-versioneer) + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys + + +def get_keywords(): +    """Get the keywords needed to look up the version information.""" +    # these strings will be replaced by git during git-archive. +    # setup.py/versioneer.py will grep for the variable names, so they must +    # each be defined on a line of their own. _version.py will just call +    # get_keywords(). 
+    git_refnames = "$Format:%d$" +    git_full = "$Format:%H$" +    keywords = {"refnames": git_refnames, "full": git_full} +    return keywords + + +class VersioneerConfig: +    """Container for Versioneer configuration parameters.""" + + +def get_config(): +    """Create, populate and return the VersioneerConfig() object.""" +    # these strings are filled in when 'setup.py versioneer' creates +    # _version.py +    cfg = VersioneerConfig() +    cfg.VCS = "git" +    cfg.style = "pep440" +    cfg.tag_prefix = "" +    cfg.parentdir_prefix = "None" +    cfg.versionfile_source = "src/leap/soledad/_version.py" +    cfg.verbose = False +    return cfg + + +class NotThisMethod(Exception): +    """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method):  # decorator +    """Decorator to mark a method as the handler for a particular VCS.""" +    def decorate(f): +        """Store f in HANDLERS[vcs][method].""" +        if vcs not in HANDLERS: +            HANDLERS[vcs] = {} +        HANDLERS[vcs][method] = f +        return f +    return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False): +    """Call the given command(s).""" +    assert isinstance(commands, list) +    p = None +    for c in commands: +        try: +            dispcmd = str([c] + args) +            # remember shell=False, so use git.cmd on windows, not just git +            p = subprocess.Popen([c] + args, cwd=cwd, stdout=subprocess.PIPE, +                                 stderr=(subprocess.PIPE if hide_stderr +                                         else None)) +            break +        except EnvironmentError: +            e = sys.exc_info()[1] +            if e.errno == errno.ENOENT: +                continue +            if verbose: +                print("unable to run %s" % dispcmd) +                print(e) +            return None +    else: +        if verbose: +            print("unable to find command, tried %s" % (commands,)) +        return None +    stdout = p.communicate()[0].strip() +    if sys.version_info[0] >= 3: +        stdout = stdout.decode() +    if p.returncode != 0: +        if verbose: +            print("unable to run %s (error)" % dispcmd) +        return None +    return stdout + + +def versions_from_parentdir(parentdir_prefix, root, verbose): +    """Try to determine the version from the parent directory name. + +    Source tarballs conventionally unpack into a directory that includes +    both the project name and a version string. +    """ +    dirname = os.path.basename(root) +    if not dirname.startswith(parentdir_prefix): +        if verbose: +            print("guessing rootdir is '%s', but '%s' doesn't start with " +                  "prefix '%s'" % (root, dirname, parentdir_prefix)) +        raise NotThisMethod("rootdir doesn't start with parentdir_prefix") +    return {"version": dirname[len(parentdir_prefix):], +            "full-revisionid": None, +            "dirty": False, "error": None} + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): +    """Extract version information from the given file.""" +    # the code embedded in _version.py can just fetch the value of these +    # keywords. When used from setup.py, we don't want to import _version.py, +    # so we do it with a regexp instead. This function is not used from +    # _version.py. 
+    keywords = {} +    try: +        f = open(versionfile_abs, "r") +        for line in f.readlines(): +            if line.strip().startswith("git_refnames ="): +                mo = re.search(r'=\s*"(.*)"', line) +                if mo: +                    keywords["refnames"] = mo.group(1) +            if line.strip().startswith("git_full ="): +                mo = re.search(r'=\s*"(.*)"', line) +                if mo: +                    keywords["full"] = mo.group(1) +        f.close() +    except EnvironmentError: +        pass +    return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): +    """Get version information from git keywords.""" +    if not keywords: +        raise NotThisMethod("no keywords at all, weird") +    refnames = keywords["refnames"].strip() +    if refnames.startswith("$Format"): +        if verbose: +            print("keywords are unexpanded, not using") +        raise NotThisMethod("unexpanded keywords, not a git-archive tarball") +    refs = set([r.strip() for r in refnames.strip("()").split(",")]) +    # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of +    # just "foo-1.0". If we see a "tag: " prefix, prefer those. +    TAG = "tag: " +    tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) +    if not tags: +        # Either we're using git < 1.8.3, or there really are no tags. We use +        # a heuristic: assume all version tags have a digit. The old git %d +        # expansion behaves like git log --decorate=short and strips out the +        # refs/heads/ and refs/tags/ prefixes that would let us distinguish +        # between branches and tags. By ignoring refnames without digits, we +        # filter out many common branch names like "release" and +        # "stabilization", as well as "HEAD" and "master". +        tags = set([r for r in refs if re.search(r'\d', r)]) +        if verbose: +            print("discarding '%s', no digits" % ",".join(refs-tags)) +    if verbose: +        print("likely tags: %s" % ",".join(sorted(tags))) +    for ref in sorted(tags): +        # sorting will prefer e.g. "2.0" over "2.0rc1" +        if ref.startswith(tag_prefix): +            r = ref[len(tag_prefix):] +            if verbose: +                print("picking %s" % r) +            return {"version": r, +                    "full-revisionid": keywords["full"].strip(), +                    "dirty": False, "error": None +                    } +    # no suitable tags, so version is "0+unknown", but full hex is still there +    if verbose: +        print("no suitable tags, using unknown + full revision id") +    return {"version": "0+unknown", +            "full-revisionid": keywords["full"].strip(), +            "dirty": False, "error": "no suitable tags"} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +    """Get version from 'git describe' in the root of the source tree. + +    This only gets called if the git-archive 'subst' keywords were *not* +    expanded, and _version.py hasn't already been rewritten with a short +    version string, meaning we're inside a checked out source tree. 
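    A hedged illustration of the parsing done below (tag, distance and hash are
    made up): a 'git describe' output of "0.10.5-3-gabc1234-dirty" yields roughly

        pieces = {"closest-tag": "0.10.5", "distance": 3, "short": "abc1234",
                  "dirty": True, "long": "<full 40-char revision id>", "error": None}

    and render_pep440(pieces) would then render "0.10.5+3.gabc1234.dirty".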
+    """ +    if not os.path.exists(os.path.join(root, ".git")): +        if verbose: +            print("no .git in %s" % root) +        raise NotThisMethod("no .git directory") + +    GITS = ["git"] +    if sys.platform == "win32": +        GITS = ["git.cmd", "git.exe"] +    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] +    # if there isn't one, this yields HEX[-dirty] (no NUM) +    describe_out = run_command(GITS, ["describe", "--tags", "--dirty", +                                      "--always", "--long", +                                      "--match", "%s*" % tag_prefix], +                               cwd=root) +    # --long was added in git-1.5.5 +    if describe_out is None: +        raise NotThisMethod("'git describe' failed") +    describe_out = describe_out.strip() +    full_out = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) +    if full_out is None: +        raise NotThisMethod("'git rev-parse' failed") +    full_out = full_out.strip() + +    pieces = {} +    pieces["long"] = full_out +    pieces["short"] = full_out[:7]  # maybe improved later +    pieces["error"] = None + +    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] +    # TAG might have hyphens. +    git_describe = describe_out + +    # look for -dirty suffix +    dirty = git_describe.endswith("-dirty") +    pieces["dirty"] = dirty +    if dirty: +        git_describe = git_describe[:git_describe.rindex("-dirty")] + +    # now we have TAG-NUM-gHEX or HEX + +    if "-" in git_describe: +        # TAG-NUM-gHEX +        mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) +        if not mo: +            # unparseable. Maybe git-describe is misbehaving? +            pieces["error"] = ("unable to parse git-describe output: '%s'" +                               % describe_out) +            return pieces + +        # tag +        full_tag = mo.group(1) +        if not full_tag.startswith(tag_prefix): +            if verbose: +                fmt = "tag '%s' doesn't start with prefix '%s'" +                print(fmt % (full_tag, tag_prefix)) +            pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" +                               % (full_tag, tag_prefix)) +            return pieces +        pieces["closest-tag"] = full_tag[len(tag_prefix):] + +        # distance: number of commits since tag +        pieces["distance"] = int(mo.group(2)) + +        # commit: short hex revision ID +        pieces["short"] = mo.group(3) + +    else: +        # HEX: no tags +        pieces["closest-tag"] = None +        count_out = run_command(GITS, ["rev-list", "HEAD", "--count"], +                                cwd=root) +        pieces["distance"] = int(count_out)  # total number of commits + +    return pieces + + +def plus_or_dot(pieces): +    """Return a + if we don't already have one, else return a .""" +    if "+" in pieces.get("closest-tag", ""): +        return "." +    return "+" + + +def render_pep440(pieces): +    """Build up version string, with post-release "local version identifier". + +    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you +    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + +    Exceptions: +    1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        if pieces["distance"] or pieces["dirty"]: +            rendered += plus_or_dot(pieces) +            rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) +            if pieces["dirty"]: +                rendered += ".dirty" +    else: +        # exception #1 +        rendered = "0+untagged.%d.g%s" % (pieces["distance"], +                                          pieces["short"]) +        if pieces["dirty"]: +            rendered += ".dirty" +    return rendered + + +def render_pep440_pre(pieces): +    """TAG[.post.devDISTANCE] -- No -dirty. + +    Exceptions: +    1: no tags. 0.post.devDISTANCE +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        if pieces["distance"]: +            rendered += ".post.dev%d" % pieces["distance"] +    else: +        # exception #1 +        rendered = "0.post.dev%d" % pieces["distance"] +    return rendered + + +def render_pep440_post(pieces): +    """TAG[.postDISTANCE[.dev0]+gHEX] . + +    The ".dev0" means dirty. Note that .dev0 sorts backwards +    (a dirty tree will appear "older" than the corresponding clean one), +    but you shouldn't be releasing software with -dirty anyways. + +    Exceptions: +    1: no tags. 0.postDISTANCE[.dev0] +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        if pieces["distance"] or pieces["dirty"]: +            rendered += ".post%d" % pieces["distance"] +            if pieces["dirty"]: +                rendered += ".dev0" +            rendered += plus_or_dot(pieces) +            rendered += "g%s" % pieces["short"] +    else: +        # exception #1 +        rendered = "0.post%d" % pieces["distance"] +        if pieces["dirty"]: +            rendered += ".dev0" +        rendered += "+g%s" % pieces["short"] +    return rendered + + +def render_pep440_old(pieces): +    """TAG[.postDISTANCE[.dev0]] . + +    The ".dev0" means dirty. + +    Eexceptions: +    1: no tags. 0.postDISTANCE[.dev0] +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        if pieces["distance"] or pieces["dirty"]: +            rendered += ".post%d" % pieces["distance"] +            if pieces["dirty"]: +                rendered += ".dev0" +    else: +        # exception #1 +        rendered = "0.post%d" % pieces["distance"] +        if pieces["dirty"]: +            rendered += ".dev0" +    return rendered + + +def render_git_describe(pieces): +    """TAG[-DISTANCE-gHEX][-dirty]. + +    Like 'git describe --tags --dirty --always'. + +    Exceptions: +    1: no tags. HEX[-dirty]  (note: no 'g' prefix) +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        if pieces["distance"]: +            rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) +    else: +        # exception #1 +        rendered = pieces["short"] +    if pieces["dirty"]: +        rendered += "-dirty" +    return rendered + + +def render_git_describe_long(pieces): +    """TAG-DISTANCE-gHEX[-dirty]. + +    Like 'git describe --tags --dirty --always -long'. +    The distance/hash is unconditional. + +    Exceptions: +    1: no tags. 
HEX[-dirty]  (note: no 'g' prefix) +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) +    else: +        # exception #1 +        rendered = pieces["short"] +    if pieces["dirty"]: +        rendered += "-dirty" +    return rendered + + +def render(pieces, style): +    """Render the given version pieces into the requested style.""" +    if pieces["error"]: +        return {"version": "unknown", +                "full-revisionid": pieces.get("long"), +                "dirty": None, +                "error": pieces["error"]} + +    if not style or style == "default": +        style = "pep440"  # the default + +    if style == "pep440": +        rendered = render_pep440(pieces) +    elif style == "pep440-pre": +        rendered = render_pep440_pre(pieces) +    elif style == "pep440-post": +        rendered = render_pep440_post(pieces) +    elif style == "pep440-old": +        rendered = render_pep440_old(pieces) +    elif style == "git-describe": +        rendered = render_git_describe(pieces) +    elif style == "git-describe-long": +        rendered = render_git_describe_long(pieces) +    else: +        raise ValueError("unknown style '%s'" % style) + +    return {"version": rendered, "full-revisionid": pieces["long"], +            "dirty": pieces["dirty"], "error": None} + + +def get_versions(): +    """Get version information or return default if unable to do so.""" +    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have +    # __file__, we can work backwards from there to the root. Some +    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which +    # case we can only use expanded keywords. + +    cfg = get_config() +    verbose = cfg.verbose + +    try: +        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, +                                          verbose) +    except NotThisMethod: +        pass + +    try: +        root = os.path.realpath(__file__) +        # versionfile_source is the relative path from the top of the source +        # tree (where the .git directory might live) to this file. Invert +        # this to find the root from __file__. +        for i in cfg.versionfile_source.split('/'): +            root = os.path.dirname(root) +    except NameError: +        return {"version": "0+unknown", "full-revisionid": None, +                "dirty": None, +                "error": "unable to find root of source tree"} + +    try: +        pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) +        return render(pieces, cfg.style) +    except NotThisMethod: +        pass + +    try: +        if cfg.parentdir_prefix: +            return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) +    except NotThisMethod: +        pass + +    return {"version": "0+unknown", "full-revisionid": None, +            "dirty": None, +            "error": "unable to compute version"} diff --git a/src/leap/soledad/client/__init__.py b/src/leap/soledad/client/__init__.py new file mode 100644 index 00000000..bcad78db --- /dev/null +++ b/src/leap/soledad/client/__init__.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# __init__.py +# Copyright (C) 2013, 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Soledad - Synchronization Of Locally Encrypted Data Among Devices. +""" +from leap.soledad.common import soledad_assert + +from .api import Soledad +from ._document import Document, AttachmentStates +from ._version import get_versions + +__version__ = get_versions()['version'] +del get_versions + +__all__ = ['soledad_assert', 'Soledad', 'Document', 'AttachmentStates', +           '__version__'] diff --git a/src/leap/soledad/client/_crypto.py b/src/leap/soledad/client/_crypto.py new file mode 100644 index 00000000..8cedf52e --- /dev/null +++ b/src/leap/soledad/client/_crypto.py @@ -0,0 +1,557 @@ +# -*- coding: utf-8 -*- +# _crypto.py +# Copyright (C) 2016 LEAP Encryption Access Project +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +""" +Cryptographic operations for the soledad client. + +This module implements streaming crypto operations. +It replaces the old client.crypto module, that will be deprecated in soledad +0.12. + +The algorithm for encrypting and decrypting is as follow: + +The KEY is a 32 bytes value. +The IV is a random 16 bytes value. +The PREAMBLE is a packed_structure with encryption metadata, such as IV. +The SEPARATOR is a space. + +Encryption +---------- + +IV = os.urandom(16) +PREAMBLE = BLOB_SIGNATURE_MAGIC, ENC_SCHEME, ENC_METHOD, time, IV, doc_id, rev, +and size. + +PREAMBLE = base64_encoded(PREAMBLE) +CIPHERTEXT = base64_encoded(AES_GCM(KEY, cleartext) + resulting_tag) if armor + +CIPHERTEXT = AES_GCM(KEY, cleartext) + resulting_tag if not armor +# "resulting_tag" came from AES-GCM encryption. It will be the last 16 bytes of +# our ciphertext. + +encrypted_payload = PREAMBLE + SEPARATOR + CIPHERTEXT + +Decryption +---------- + +Ciphertext and Tag CAN come encoded in base64 (with armor=True) or raw (with +armor=False). Preamble will always come encoded in base64. + +PREAMBLE, CIPHERTEXT = PAYLOAD.SPLIT(' ', 1) + +PREAMBLE = base64_decode(PREAMBLE) +CIPHERTEXT = base64_decode(CIPHERTEXT) if armor else CIPHERTEXT + +CIPHERTEXT, TAG = CIPHERTEXT[:-16], CIPHERTEXT[-16:] +CLEARTEXT = aes_gcm_decrypt(KEY, IV, CIPHERTEXT, TAG, associated_data=PREAMBLE) + +AES-GCM will check preamble authenticity as well, since we are using +Authenticated Encryption with Associated Data (AEAD). Ciphertext and associated +data (PREAMBLE) authenticity will both be checked together during decryption. +PREAMBLE consistency (if it matches the desired document, for instance) is +checked during PREAMBLE reading. 
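A minimal round-trip sketch using the symmetric helpers defined below in this
module (the 32-byte key is illustrative; no associated data is passed):

    key = os.urandom(32)
    iv, ciphertext = encrypt_sym('secret data', key)   # AES-256-GCM; 16-byte tag appended
    assert decrypt_sym(ciphertext, key, iv) == 'secret data'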
+""" + + +import base64 +import hashlib +import warnings +import hmac +import os +import struct +import time + +from io import BytesIO +from collections import namedtuple + +from twisted.internet import defer +from twisted.internet import interfaces +from twisted.web.client import FileBodyProducer + +from leap.soledad.common import soledad_assert +from cryptography.exceptions import InvalidTag +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.backends import default_backend + +from zope.interface import implementer + + +SECRET_LENGTH = 64 +SEPARATOR = ' '  # Anything that doesn't belong to base64 encoding + +CRYPTO_BACKEND = default_backend() + +PACMAN = struct.Struct('2sbbQ16s255p255pQ') +LEGACY_PACMAN = struct.Struct('2sbbQ16s255p255p') +BLOB_SIGNATURE_MAGIC = '\x13\x37' + + +ENC_SCHEME = namedtuple('SCHEME', 'symkey')(1) +ENC_METHOD = namedtuple('METHOD', 'aes_256_ctr aes_256_gcm')(1, 2) +DocInfo = namedtuple('DocInfo', 'doc_id rev') + + +class EncryptionDecryptionError(Exception): +    pass + + +class InvalidBlob(Exception): +    pass + + +class SoledadCrypto(object): +    """ +    This class provides convenient methods for document encryption and +    decryption using BlobEncryptor and BlobDecryptor classes. +    """ +    def __init__(self, secret): +        """ +        Initialize the crypto object. + +        :param secret: The Soledad remote storage secret. +        :type secret: str +        """ +        self.secret = secret + +    def encrypt_doc(self, doc): +        """ +        Creates and configures a BlobEncryptor, asking it to start encryption +        and wrapping the result as a simple JSON string with a "raw" key. + +        :param doc: the document to be encrypted. +        :type doc: Document +        :return: A deferred whose callback will be invoked with a JSON string +            containing the ciphertext as the value of "raw" key. +        :rtype: twisted.internet.defer.Deferred +        """ + +        def put_raw(blob): +            raw = blob.getvalue() +            return '{"raw": "' + raw + '"}' + +        content = BytesIO(str(doc.get_json())) +        info = DocInfo(doc.doc_id, doc.rev) +        del doc +        encryptor = BlobEncryptor(info, content, secret=self.secret) +        d = encryptor.encrypt() +        d.addCallback(put_raw) +        return d + +    def decrypt_doc(self, doc): +        """ +        Creates and configures a BlobDecryptor, asking it decrypt and returning +        the decrypted cleartext content from the encrypted document. + +        :param doc: the document to be decrypted. +        :type doc: Document +        :return: The decrypted cleartext content of the document. +        :rtype: str +        """ +        info = DocInfo(doc.doc_id, doc.rev) +        ciphertext = BytesIO() +        payload = doc.content['raw'] +        del doc +        ciphertext.write(str(payload)) +        decryptor = BlobDecryptor(info, ciphertext, secret=self.secret) +        return decryptor.decrypt() + + +def encrypt_sym(data, key, method=ENC_METHOD.aes_256_gcm): +    """ +    Encrypt data using AES-256 cipher in selected mode. + +    :param data: The data to be encrypted. +    :type data: str +    :param key: The key used to encrypt data (must be 256 bits long). +    :type key: str + +    :return: A tuple with the initialization vector and the ciphertext, both +        encoded as base64. 
+    :rtype: (str, str) +    """ +    mode = _mode_by_method(method) +    encryptor = AESWriter(key, mode=mode) +    encryptor.write(data) +    _, ciphertext = encryptor.end() +    iv = base64.b64encode(encryptor.iv) +    tag = encryptor.tag or '' +    return iv, ciphertext + tag + + +def decrypt_sym(data, key, iv, method=ENC_METHOD.aes_256_gcm): +    """ +    Decrypt data using AES-256 cipher in selected mode. + +    :param data: The data to be decrypted. +    :type data: str +    :param key: The symmetric key used to decrypt data (must be 256 bits +                long). +    :type key: str +    :param iv: The base64 encoded initialization vector. +    :type iv: str + +    :return: The decrypted data. +    :rtype: str +    """ +    _iv = base64.b64decode(str(iv)) +    mode = _mode_by_method(method) +    tag = None +    if mode == modes.GCM: +        data, tag = data[:-16], data[-16:] +    decryptor = AESWriter(key, _iv, tag=tag, mode=mode) +    decryptor.write(data) +    _, plaintext = decryptor.end() +    return plaintext + + +# TODO maybe rename this to Encryptor, since it will be used by blobs an non +# blobs in soledad. +class BlobEncryptor(object): +    """ +    Produces encrypted data from the cleartext data associated with a given +    Document using AES-256 cipher in GCM mode. + +    The production happens using a Twisted's FileBodyProducer, which uses a +    Cooperator to schedule calls and can be paused/resumed. Each call takes at +    most 65536 bytes from the input. + +    Both the production input and output are file descriptors, so they can be +    applied to a stream of data. +    """ +    # TODO +    # This class needs further work to allow for proper streaming. +    # Right now we HAVE TO WAIT until the end of the stream before encoding the +    # result. It should be possible to do that just encoding the chunks and +    # passing them to a sink, but for that we have to encode the chunks at +    # proper alignment (3 bytes?) with b64 if armor is defined. + +    def __init__(self, doc_info, content_fd, secret=None, armor=True, +                 sink=None): +        if not secret: +            raise EncryptionDecryptionError('no secret given') + +        self.doc_id = doc_info.doc_id +        self.rev = doc_info.rev +        self.armor = armor + +        self._content_fd = content_fd +        self._content_size = self._get_rounded_size(content_fd) +        self._producer = FileBodyProducer(content_fd, readSize=2**16) + +        self.sym_key = _get_sym_key_for_doc(doc_info.doc_id, secret) +        self._aes = AESWriter(self.sym_key, _buffer=sink) +        self._aes.authenticate(self._encode_preamble()) + +    def _get_rounded_size(self, fd): +        """ +        Returns a rounded value in order to minimize information leaks due to +        the original size being exposed. +        """ +        fd.seek(0, os.SEEK_END) +        size = _ceiling(fd.tell()) +        fd.seek(0) +        return size + +    @property +    def iv(self): +        return self._aes.iv + +    @property +    def tag(self): +        return self._aes.tag + +    def encrypt(self): +        """ +        Starts producing encrypted data from the cleartext data. + +        :return: A deferred which will be fired when encryption ends and whose +                 callback will be invoked with the resulting ciphertext. +        :rtype: twisted.internet.defer.Deferred +        """ +        # XXX pass a sink to aes? 
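        # A hedged usage sketch of this class (doc_id, rev and secret are made up;
        # the deferred fires with a BytesIO holding "preamble SEPARATOR ciphertext+tag"):
        #
        #   info = DocInfo('doc-1', '1-deadbeef')
        #   encryptor = BlobEncryptor(info, BytesIO('hello'), secret=remote_storage_secret)
        #   d = encryptor.encrypt()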
+        d = self._producer.startProducing(self._aes) +        d.addCallback(lambda _: self._end_crypto_stream_and_encode_result()) +        return d + +    def _encode_preamble(self): +        current_time = int(time.time()) + +        preamble = PACMAN.pack( +            BLOB_SIGNATURE_MAGIC, +            ENC_SCHEME.symkey, +            ENC_METHOD.aes_256_gcm, +            current_time, +            self.iv, +            str(self.doc_id), +            str(self.rev), +            self._content_size) +        return preamble + +    def _end_crypto_stream_and_encode_result(self): + +        # TODO ---- this needs to be refactored to allow PROPER streaming +        # We should write the preamble as soon as possible, +        # Is it possible to write the AES stream as soon as it is encrypted by +        # chunks? +        # FIXME also, it needs to be able to encode chunks with base64 if armor + +        preamble, encrypted = self._aes.end() +        result = BytesIO() +        result.write( +            base64.urlsafe_b64encode(preamble)) +        result.write(SEPARATOR) + +        if self.armor: +            result.write( +                base64.urlsafe_b64encode(encrypted + self.tag)) +        else: +            result.write(encrypted + self.tag) + +        result.seek(0) +        return defer.succeed(result) + + +# TODO maybe rename this to just Decryptor, since it will be used by blobs +# and non blobs in soledad. +class BlobDecryptor(object): +    """ +    Decrypts an encrypted blob associated with a given Document. + +    Will raise an exception if the blob doesn't have the expected structure, or +    if the GCM tag doesn't verify. +    """ +    def __init__(self, doc_info, ciphertext_fd, result=None, +                 secret=None, armor=True, start_stream=True, tag=None): +        if not secret: +            raise EncryptionDecryptionError('no secret given') + +        self.doc_id = doc_info.doc_id +        self.rev = doc_info.rev +        self.fd = ciphertext_fd +        self.armor = armor +        self._producer = None +        self.result = result or BytesIO() +        sym_key = _get_sym_key_for_doc(doc_info.doc_id, secret) +        self.size = None +        self.tag = None + +        preamble, iv = self._consume_preamble() +        soledad_assert(preamble) +        soledad_assert(iv) + +        self._aes = AESWriter(sym_key, iv, self.result, tag=tag or self.tag) +        self._aes.authenticate(preamble) +        if start_stream: +            self._start_stream() + +    @property +    def decrypted_content_size(self): +        return self._aes.written + +    def _start_stream(self): +        self._producer = FileBodyProducer(self.fd, readSize=2**16) + +    def _consume_preamble(self): +        """ +        Consume the preamble and write remaining bytes as ciphertext. This +        function is called during a stream and can be holding both, so we need +        to consume only preamble and store the remaining. 
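        For illustration, a non-legacy preamble is packed and unpacked with PACMAN
        roughly like this (all field values are made up; iv is a 16-byte string):

            preamble = PACMAN.pack('\x13\x37', 1, 2, 1485000000, iv, 'doc-1', '1-deadbeef', 4096)
            magic, sch, meth, ts, iv, doc_id, rev, doc_size = PACMAN.unpack(preamble)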
+        """ +        self.fd.seek(0) +        try: +            parts = self.fd.getvalue().split(SEPARATOR, 1) +            preamble = base64.urlsafe_b64decode(parts[0]) +            if len(parts) == 2: +                ciphertext = parts[1] +                if self.armor: +                    ciphertext = base64.urlsafe_b64decode(ciphertext) +                self.tag, ciphertext = ciphertext[-16:], ciphertext[:-16] +                self.fd.seek(0) +                self.fd.write(ciphertext) +                self.fd.seek(len(ciphertext)) +                self.fd.truncate() +                self.fd.seek(0) + +        except (TypeError, ValueError): +            raise InvalidBlob + +        try: +            if len(preamble) == LEGACY_PACMAN.size: +                warnings.warn("Decrypting a legacy document without size. " + +                              "This will be deprecated in 0.12. Doc was: " + +                              "doc_id: %s rev: %s" % (self.doc_id, self.rev), +                              Warning) +                unpacked_data = LEGACY_PACMAN.unpack(preamble) +                magic, sch, meth, ts, iv, doc_id, rev = unpacked_data +            elif len(preamble) == PACMAN.size: +                unpacked_data = PACMAN.unpack(preamble) +                magic, sch, meth, ts, iv, doc_id, rev, doc_size = unpacked_data +                self.size = doc_size +            else: +                raise InvalidBlob("Unexpected preamble size %d", len(preamble)) +        except struct.error as e: +            raise InvalidBlob(e) + +        if magic != BLOB_SIGNATURE_MAGIC: +            raise InvalidBlob +        # TODO check timestamp. Just as a sanity check, but for instance +        # we can refuse to process something that is in the future or +        # too far in the past (1984 would be nice, hehe) +        if sch != ENC_SCHEME.symkey: +            raise InvalidBlob('Invalid scheme: %s' % sch) +        if meth != ENC_METHOD.aes_256_gcm: +            raise InvalidBlob('Invalid encryption scheme: %s' % meth) +        if rev != self.rev: +            msg = 'Invalid revision. Expected: %s, was: %s' % (self.rev, rev) +            raise InvalidBlob(msg) +        if doc_id != self.doc_id: +            msg = 'Invalid doc_id. ' +            + 'Expected: %s, was: %s' % (self.doc_id, doc_id) +            raise InvalidBlob(msg) + +        return preamble, iv + +    def _end_stream(self): +        try: +            self._aes.end() +        except InvalidTag: +            raise InvalidBlob('Invalid Tag. Blob authentication failed.') +        fd = self.result +        fd.seek(0) +        return self.result + +    def decrypt(self): +        """ +        Starts producing encrypted data from the cleartext data. + +        :return: A deferred which will be fired when encryption ends and whose +            callback will be invoked with the resulting ciphertext. 
+        :rtype: twisted.internet.defer.Deferred +        """ +        d = self.startProducing() +        d.addCallback(lambda _: self._end_stream()) +        return d + +    def startProducing(self): +        if not self._producer: +            self._start_stream() +        return self._producer.startProducing(self._aes) + +    def endStream(self): +        self._end_stream() + +    def write(self, data): +        self._aes.write(data) + +    def close(self): +        result = self._aes.end() +        return result + + +@implementer(interfaces.IConsumer) +class AESWriter(object): +    """ +    A Twisted's Consumer implementation that takes an input file descriptor and +    applies AES-256 cipher in GCM mode. + +    It is used both for encryption and decryption of a stream, depending of the +    value of the tag parameter. If you pass a tag, it will operate in +    decryption mode, verifying the authenticity of the preamble and ciphertext. +    If no tag is passed, encryption mode is assumed, which will generate a tag. +    """ + +    def __init__(self, key, iv=None, _buffer=None, tag=None, mode=modes.GCM): +        if len(key) != 32: +            raise EncryptionDecryptionError('key is not 256 bits') + +        if tag is not None: +            # if tag, we're decrypting +            assert iv is not None + +        self.iv = iv or os.urandom(16) +        self.buffer = _buffer or BytesIO() +        cipher = _get_aes_cipher(key, self.iv, tag, mode) +        cipher = cipher.decryptor() if tag else cipher.encryptor() +        self.cipher, self.aead = cipher, '' +        self.written = 0 + +    def authenticate(self, data): +        self.aead += data +        self.cipher.authenticate_additional_data(data) + +    @property +    def tag(self): +        return getattr(self.cipher, 'tag', None) + +    def write(self, data): +        self.written += len(data) +        self.buffer.write(self.cipher.update(data)) + +    def end(self): +        self.buffer.write(self.cipher.finalize()) +        return self.aead, self.buffer.getvalue() + + +def is_symmetrically_encrypted(content): +    """ +    Returns True if the document was symmetrically encrypted. +    'EzcB' is the base64 encoding of \x13\x37 magic number and 1 (symmetrically +    encrypted value for enc_scheme flag). + +    :param doc: The document content as string +    :type doc: str + +    :rtype: bool +    """ +    sym_signature = '{"raw": "EzcB' +    return content and content.startswith(sym_signature) + + +# utils + + +def _hmac_sha256(key, data): +    return hmac.new(key, data, hashlib.sha256).digest() + + +def _get_sym_key_for_doc(doc_id, secret): +    key = secret[SECRET_LENGTH:] +    return _hmac_sha256(key, doc_id) + + +def _get_aes_cipher(key, iv, tag, mode=modes.GCM): +    mode = mode(iv, tag) if mode == modes.GCM else mode(iv) +    return Cipher(algorithms.AES(key), mode, backend=CRYPTO_BACKEND) + + +def _mode_by_method(method): +    if method == ENC_METHOD.aes_256_gcm: +        return modes.GCM +    else: +        return modes.CTR + + +def _ceiling(size): +    """ +    Some simplistic ceiling scheme that uses powers of 2. +    We report everything below 4096 bytes as that minimum threshold. +    See #8759 for research pending for less simplistic/aggresive strategies. 
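    For instance:

        _ceiling(3000)    # => 4096
        _ceiling(4096)    # => 8192
        _ceiling(70000)   # => 131072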
+    """ +    for i in xrange(12, 31): +        step = 2 ** i +        if size < step: +            return step diff --git a/src/leap/soledad/client/_db/__init__.py b/src/leap/soledad/client/_db/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/soledad/client/_db/__init__.py diff --git a/src/leap/soledad/client/_db/adbapi.py b/src/leap/soledad/client/_db/adbapi.py new file mode 100644 index 00000000..5c28d108 --- /dev/null +++ b/src/leap/soledad/client/_db/adbapi.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- +# adbapi.py +# Copyright (C) 2013, 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +An asyncrhonous interface to soledad using sqlcipher backend. +It uses twisted.enterprise.adbapi. +""" +import re +import sys + +from functools import partial + +from twisted.enterprise import adbapi +from twisted.internet.defer import DeferredSemaphore +from twisted.python import compat +from zope.proxy import ProxyBase, setProxiedObject + +from leap.soledad.common.log import getLogger +from leap.soledad.common.errors import DatabaseAccessError + +from . import sqlcipher +from . import pragmas + +if sys.version_info[0] < 3: +    from pysqlcipher import dbapi2 +else: +    from pysqlcipher3 import dbapi2 + + +logger = getLogger(__name__) + + +""" +How long the SQLCipher connection should wait for the lock to go away until +raising an exception. +""" +SQLCIPHER_CONNECTION_TIMEOUT = 10 + +""" +How many times a SQLCipher query should be retried in case of timeout. +""" +SQLCIPHER_MAX_RETRIES = 20 + + +def getConnectionPool(opts, openfun=None, driver="pysqlcipher"): +    """ +    Return a connection pool. + +    :param opts: +        Options for the SQLCipher connection. +    :type opts: SQLCipherOptions +    :param openfun: +        Callback invoked after every connect() on the underlying DB-API +        object. +    :type openfun: callable +    :param driver: +        The connection driver. +    :type driver: str + +    :return: A U1DB connection pool. +    :rtype: U1DBConnectionPool +    """ +    if openfun is None and driver == "pysqlcipher": +        openfun = partial(pragmas.set_init_pragmas, opts=opts) +    return U1DBConnectionPool( +        opts, +        # the following params are relayed "as is" to twisted's +        # ConnectionPool. +        "%s.dbapi2" % driver, opts.path, timeout=SQLCIPHER_CONNECTION_TIMEOUT, +        check_same_thread=False, cp_openfun=openfun) + + +class U1DBConnection(adbapi.Connection): +    """ +    A wrapper for a U1DB connection instance. +    """ + +    u1db_wrapper = sqlcipher.SoledadSQLCipherWrapper +    """ +    The U1DB wrapper to use. +    """ + +    def __init__(self, pool, init_u1db=False): +        """ +        :param pool: The pool of connections to that owns this connection. +        :type pool: adbapi.ConnectionPool +        :param init_u1db: Wether the u1db database should be initialized. 
+        :type init_u1db: bool +        """ +        self.init_u1db = init_u1db +        try: +            adbapi.Connection.__init__(self, pool) +        except dbapi2.DatabaseError as e: +            raise DatabaseAccessError( +                'Error initializing connection to sqlcipher database: %s' +                % str(e)) + +    def reconnect(self): +        """ +        Reconnect to the U1DB database. +        """ +        if self._connection is not None: +            self._pool.disconnect(self._connection) +        self._connection = self._pool.connect() + +        if self.init_u1db: +            self._u1db = self.u1db_wrapper( +                self._connection, +                self._pool.opts) + +    def __getattr__(self, name): +        """ +        Route the requested attribute either to the U1DB wrapper or to the +        connection. + +        :param name: The name of the attribute. +        :type name: str +        """ +        if name.startswith('u1db_'): +            attr = re.sub('^u1db_', '', name) +            return getattr(self._u1db, attr) +        else: +            return getattr(self._connection, name) + + +class U1DBTransaction(adbapi.Transaction): +    """ +    A wrapper for a U1DB 'cursor' object. +    """ + +    def __getattr__(self, name): +        """ +        Route the requested attribute either to the U1DB wrapper of the +        connection or to the actual connection cursor. + +        :param name: The name of the attribute. +        :type name: str +        """ +        if name.startswith('u1db_'): +            attr = re.sub('^u1db_', '', name) +            return getattr(self._connection._u1db, attr) +        else: +            return getattr(self._cursor, name) + + +class U1DBConnectionPool(adbapi.ConnectionPool): +    """ +    Represent a pool of connections to an U1DB database. +    """ + +    connectionFactory = U1DBConnection +    transactionFactory = U1DBTransaction + +    def __init__(self, opts, *args, **kwargs): +        """ +        Initialize the connection pool. +        """ +        self.opts = opts +        try: +            adbapi.ConnectionPool.__init__(self, *args, **kwargs) +        except dbapi2.DatabaseError as e: +            raise DatabaseAccessError( +                'Error initializing u1db connection pool: %s' % str(e)) + +        # all u1db connections, hashed by thread-id +        self._u1dbconnections = {} + +        # The replica uid, primed by the connections on init. +        self.replica_uid = ProxyBase(None) + +        try: +            conn = self.connectionFactory( +                self, init_u1db=True) +            replica_uid = conn._u1db._real_replica_uid +            setProxiedObject(self.replica_uid, replica_uid) +        except DatabaseAccessError as e: +            self.threadpool.stop() +            raise DatabaseAccessError( +                "Error initializing connection factory: %s" % str(e)) + +    def runU1DBQuery(self, meth, *args, **kw): +        """ +        Execute a U1DB query in a thread, using a pooled connection. + +        Concurrent threads trying to update the same database may timeout +        because of other threads holding the database lock. Because of this, +        we will retry SQLCIPHER_MAX_RETRIES times and fail after that. + +        :param meth: The U1DB wrapper method name. +        :type meth: str + +        :return: a Deferred which will fire the return value of +            'self._runU1DBQuery(Transaction(...), *args, **kw)', or a Failure. 
+        :rtype: twisted.internet.defer.Deferred +        """ +        meth = "u1db_%s" % meth +        semaphore = DeferredSemaphore(SQLCIPHER_MAX_RETRIES) + +        def _run_interaction(): +            return self.runInteraction( +                self._runU1DBQuery, meth, *args, **kw) + +        def _errback(failure): +            failure.trap(dbapi2.OperationalError) +            if failure.getErrorMessage() == "database is locked": +                logger.warn("database operation timed out") +                should_retry = semaphore.acquire() +                if should_retry: +                    logger.warn("trying again...") +                    return _run_interaction() +                logger.warn("giving up!") +            return failure + +        d = _run_interaction() +        d.addErrback(_errback) +        return d + +    def _runU1DBQuery(self, trans, meth, *args, **kw): +        """ +        Execute a U1DB query. + +        :param trans: An U1DB transaction. +        :type trans: adbapi.Transaction +        :param meth: the U1DB wrapper method name. +        :type meth: str +        """ +        meth = getattr(trans, meth) +        return meth(*args, **kw) +        # XXX should return a fetchall? + +    # XXX add _runOperation too + +    def _runInteraction(self, interaction, *args, **kw): +        """ +        Interact with the database and return the result. + +        :param interaction: +            A callable object whose first argument is an +            L{adbapi.Transaction}. +        :type interaction: callable +        :return: a Deferred which will fire the return value of +            'interaction(Transaction(...), *args, **kw)', or a Failure. +        :rtype: twisted.internet.defer.Deferred +        """ +        tid = self.threadID() +        u1db = self._u1dbconnections.get(tid) +        conn = self.connectionFactory( +            self, init_u1db=not bool(u1db)) + +        if self.replica_uid is None: +            replica_uid = conn._u1db._real_replica_uid +            setProxiedObject(self.replica_uid, replica_uid) + +        if u1db is None: +            self._u1dbconnections[tid] = conn._u1db +        else: +            conn._u1db = u1db + +        trans = self.transactionFactory(self, conn) +        try: +            result = interaction(trans, *args, **kw) +            trans.close() +            conn.commit() +            return result +        except: +            excType, excValue, excTraceback = sys.exc_info() +            try: +                conn.rollback() +            except: +                logger.error(None, "Rollback failed") +            compat.reraise(excValue, excTraceback) + +    def finalClose(self): +        """ +        A final close, only called by the shutdown trigger. 
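        A minimal usage sketch of this pool (assumes `opts` is a SQLCipherOptions
        for an existing database, and that the wrapped u1db API exposes get_doc):

            pool = getConnectionPool(opts)
            d = pool.runU1DBQuery('get_doc', 'some-doc-id')
            d.addCallback(on_doc)   # on_doc is a hypothetical callback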
+        """ +        self.shutdownID = None +        if self.threadpool.started: +            self.threadpool.stop() +        self.running = False +        for conn in self.connections.values(): +            self._close(conn) +        for u1db in self._u1dbconnections.values(): +            self._close(u1db) +        self.connections.clear() diff --git a/src/leap/soledad/client/_db/blobs.py b/src/leap/soledad/client/_db/blobs.py new file mode 100644 index 00000000..10b90c71 --- /dev/null +++ b/src/leap/soledad/client/_db/blobs.py @@ -0,0 +1,554 @@ +# -*- coding: utf-8 -*- +# _blobs.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Clientside BlobBackend Storage. +""" + +from urlparse import urljoin + +import binascii +import os +import base64 + +from io import BytesIO +from functools import partial + +from twisted.logger import Logger +from twisted.enterprise import adbapi +from twisted.internet import defer +from twisted.web.client import FileBodyProducer + +import treq + +from leap.soledad.common.errors import SoledadError +from leap.common.files import mkdir_p + +from .._document import BlobDoc +from .._crypto import DocInfo +from .._crypto import BlobEncryptor +from .._crypto import BlobDecryptor +from .._http import HTTPClient +from .._pipes import TruncatedTailPipe +from .._pipes import PreamblePipe + +from . import pragmas +from . import sqlcipher + + +logger = Logger() +FIXED_REV = 'ImmutableRevision'  # Blob content is immutable + + +class BlobAlreadyExistsError(SoledadError): +    pass + + +class ConnectionPool(adbapi.ConnectionPool): + +    def insertAndGetLastRowid(self, *args, **kwargs): +        """ +        Execute an SQL query and return the last rowid. + +        See: https://sqlite.org/c3ref/last_insert_rowid.html +        """ +        return self.runInteraction( +            self._insertAndGetLastRowid, *args, **kwargs) + +    def _insertAndGetLastRowid(self, trans, *args, **kw): +        trans.execute(*args, **kw) +        return trans.lastrowid + +    def blob(self, table, column, irow, flags): +        """ +        Open a BLOB for incremental I/O. + +        Return a handle to the BLOB that would be selected by: + +          SELECT column FROM table WHERE rowid = irow; + +        See: https://sqlite.org/c3ref/blob_open.html + +        :param table: The table in which to lookup the blob. +        :type table: str +        :param column: The column where the BLOB is located. +        :type column: str +        :param rowid: The rowid of the BLOB. +        :type rowid: int +        :param flags: If zero, BLOB is opened for read-only. If non-zero, +                      BLOB is opened for RW. +        :type flags: int + +        :return: A BLOB handle. 
+        :rtype: pysqlcipher.dbapi.Blob +        """ +        return self.runInteraction(self._blob, table, column, irow, flags) + +    def _blob(self, trans, table, column, irow, flags): +        # TODO: should not use transaction private variable here +        handle = trans._connection.blob(table, column, irow, flags) +        return handle + + +def check_http_status(code): +    if code == 409: +        raise BlobAlreadyExistsError() +    elif code != 200: +        raise SoledadError("Server Error") + + +class DecrypterBuffer(object): + +    def __init__(self, blob_id, secret, tag): +        self.doc_info = DocInfo(blob_id, FIXED_REV) +        self.secret = secret +        self.tag = tag +        self.preamble_pipe = PreamblePipe(self._make_decryptor) + +    def _make_decryptor(self, preamble): +        self.decrypter = BlobDecryptor( +            self.doc_info, preamble, +            secret=self.secret, +            armor=False, +            start_stream=False, +            tag=self.tag) +        return TruncatedTailPipe(self.decrypter, tail_size=len(self.tag)) + +    def write(self, data): +        self.preamble_pipe.write(data) + +    def close(self): +        real_size = self.decrypter.decrypted_content_size +        return self.decrypter._end_stream(), real_size + + +class BlobManager(object): +    """ +    Ideally, the decrypting flow goes like this: + +    - GET a blob from remote server. +    - Decrypt the preamble +    - Allocate a zeroblob in the sqlcipher sink +    - Mark the blob as unusable (ie, not verified) +    - Decrypt the payload incrementally, and write chunks to sqlcipher +      ** Is it possible to use a small buffer for the aes writer w/o +      ** allocating all the memory in openssl? +    - Finalize the AES decryption +    - If preamble + payload verifies correctly, mark the blob as usable + +    """ + +    def __init__( +            self, local_path, remote, key, secret, user, token=None, +            cert_file=None): +        if local_path: +            mkdir_p(os.path.dirname(local_path)) +            self.local = SQLiteBlobBackend(local_path, key) +        self.remote = remote +        self.secret = secret +        self.user = user +        self._client = HTTPClient(user, token, cert_file) + +    def close(self): +        if hasattr(self, 'local') and self.local: +            return self.local.close() + +    @defer.inlineCallbacks +    def remote_list(self): +        uri = urljoin(self.remote, self.user + '/') +        data = yield self._client.get(uri) +        defer.returnValue((yield data.json())) + +    def local_list(self): +        return self.local.list() + +    @defer.inlineCallbacks +    def send_missing(self): +        our_blobs = yield self.local_list() +        server_blobs = yield self.remote_list() +        missing = [b_id for b_id in our_blobs if b_id not in server_blobs] +        logger.info("Amount of documents missing on server: %s" % len(missing)) +        # TODO: Send concurrently when we are able to stream directly from db +        for blob_id in missing: +            fd = yield self.local.get(blob_id) +            logger.info("Upload local blob: %s" % blob_id) +            yield self._encrypt_and_upload(blob_id, fd) + +    @defer.inlineCallbacks +    def fetch_missing(self): +        # TODO: Use something to prioritize user requests over general new docs +        our_blobs = yield self.local_list() +        server_blobs = yield self.remote_list() +        docs_we_want = [b_id for b_id in server_blobs if b_id not in our_blobs] +        
logger.info("Fetching new docs from server: %s" % len(docs_we_want)) +        # TODO: Fetch concurrently when we are able to stream directly into db +        for blob_id in docs_we_want: +            logger.info("Fetching new doc: %s" % blob_id) +            yield self.get(blob_id) + +    @defer.inlineCallbacks +    def put(self, doc, size): +        if (yield self.local.exists(doc.blob_id)): +            error_message = "Blob already exists: %s" % doc.blob_id +            raise BlobAlreadyExistsError(error_message) +        fd = doc.blob_fd +        # TODO this is a tee really, but ok... could do db and upload +        # concurrently. not sure if we'd gain something. +        yield self.local.put(doc.blob_id, fd, size=size) +        # In fact, some kind of pipe is needed here, where each write on db +        # handle gets forwarded into a write on the connection handle +        fd = yield self.local.get(doc.blob_id) +        yield self._encrypt_and_upload(doc.blob_id, fd) + +    @defer.inlineCallbacks +    def get(self, blob_id): +        local_blob = yield self.local.get(blob_id) +        if local_blob: +            logger.info("Found blob in local database: %s" % blob_id) +            defer.returnValue(local_blob) + +        result = yield self._download_and_decrypt(blob_id) + +        if not result: +            defer.returnValue(None) +        blob, size = result + +        if blob: +            logger.info("Got decrypted blob of type: %s" % type(blob)) +            blob.seek(0) +            yield self.local.put(blob_id, blob, size=size) +            defer.returnValue((yield self.local.get(blob_id))) +        else: +            # XXX we shouldn't get here, but we will... +            # lots of ugly error handling possible: +            # 1. retry, might be network error +            # 2. try later, maybe didn't finished streaming +            # 3.. resignation, might be error while verifying +            logger.error('sorry, dunno what happened') + +    @defer.inlineCallbacks +    def _encrypt_and_upload(self, blob_id, fd): +        # TODO ------------------------------------------ +        # this is wrong, is doing 2 stages. +        # the crypto producer can be passed to +        # the uploader and react as data is written. +        # try to rewrite as a tube: pass the fd to aes and let aes writer +        # produce data to the treq request fd. 
+        # ------------------------------------------------
+        logger.info("Starting upload of blob: %s" % blob_id)
+        doc_info = DocInfo(blob_id, FIXED_REV)
+        uri = urljoin(self.remote, self.user + "/" + blob_id)
+        crypter = BlobEncryptor(doc_info, fd, secret=self.secret,
+                                armor=False)
+        fd = yield crypter.encrypt()
+        response = yield self._client.put(uri, data=fd)
+        check_http_status(response.code)
+        logger.info("Finished upload: %s" % (blob_id,))
+
+    @defer.inlineCallbacks
+    def _download_and_decrypt(self, blob_id):
+        logger.info("Starting download of blob: %s" % blob_id)
+        # TODO this needs to be connected in a tube
+        uri = urljoin(self.remote, self.user + '/' + blob_id)
+        data = yield self._client.get(uri)
+
+        if data.code == 404:
+            logger.warn("Blob not found in server: %s" % blob_id)
+            defer.returnValue(None)
+        elif not data.headers.hasHeader('Tag'):
+            logger.error("Server didn't send a tag header for: %s" % blob_id)
+            defer.returnValue(None)
+        tag = data.headers.getRawHeaders('Tag')[0]
+        tag = base64.urlsafe_b64decode(tag)
+        buf = DecrypterBuffer(blob_id, self.secret, tag)
+
+        # incrementally collect the body of the response
+        yield treq.collect(data, buf.write)
+        fd, size = buf.close()
+        logger.info("Finished download: (%s, %d)" % (blob_id, size))
+        defer.returnValue((fd, size))
+
+    @defer.inlineCallbacks
+    def delete(self, blob_id):
+        logger.info("Starting deletion of blob: %s" % blob_id)
+        yield self._delete_from_remote(blob_id)
+        if (yield self.local.exists(blob_id)):
+            yield self.local.delete(blob_id)
+
+    def _delete_from_remote(self, blob_id):
+        # TODO this needs to be connected in a tube
+        uri = urljoin(self.remote, self.user + '/' + blob_id)
+        return self._client.delete(uri)
+
+
+class SQLiteBlobBackend(object):
+
+    def __init__(self, path, key=None):
+        self.path = os.path.abspath(
+            os.path.join(path, 'soledad_blob.db'))
+        mkdir_p(os.path.dirname(self.path))
+        if not key:
+            raise ValueError('key cannot be None')
+        backend = 'pysqlcipher.dbapi2'
+        opts = sqlcipher.SQLCipherOptions(
+            '/tmp/ignored', binascii.b2a_hex(key),
+            is_raw_key=True, create=True)
+        pragmafun = partial(pragmas.set_init_pragmas, opts=opts)
+        openfun = _sqlcipherInitFactory(pragmafun)
+
+        self.dbpool = ConnectionPool(
+            backend, self.path, check_same_thread=False, timeout=5,
+            cp_openfun=openfun, cp_min=1, cp_max=2, cp_name='blob_pool')
+
+    def close(self):
+        from twisted._threads import AlreadyQuit
+        try:
+            self.dbpool.close()
+        except AlreadyQuit:
+            pass
+
+    @defer.inlineCallbacks
+    def put(self, blob_id, blob_fd, size=None):
+        logger.info("Saving blob in local database...")
+        insert = 'INSERT INTO blobs (blob_id, payload) VALUES (?, zeroblob(?))'
+        irow = yield self.dbpool.insertAndGetLastRowid(insert, (blob_id, size))
+        handle = yield self.dbpool.blob('blobs', 'payload', irow, 1)
+        blob_fd.seek(0)
+        # XXX I have to copy the buffer here so that I'm able to
+        # return a non-closed file to the caller (blobmanager.get)
+        # FIXME should remove this duplication!
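For reference, the zeroblob pattern used in put() above (reserve space, then stream the payload through an incremental handle) can be isolated into a helper; this sketch reuses only names already present in this module (the pool extensions insertAndGetLastRowid and blob, plus FileBodyProducer).

@defer.inlineCallbacks
def _write_blob_incrementally(dbpool, blob_id, source_fd, size):
    # zeroblob() makes sqlite allocate the payload pages up front, so the
    # content can then be streamed into the row without building it in memory
    insert = 'INSERT INTO blobs (blob_id, payload) VALUES (?, zeroblob(?))'
    irow = yield dbpool.insertAndGetLastRowid(insert, (blob_id, size))
    handle = yield dbpool.blob('blobs', 'payload', irow, 1)
    source_fd.seek(0)
    yield FileBodyProducer(source_fd).startProducing(handle)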
+        # have a look at how treq does cope with closing the handle +        # for uploading a file +        producer = FileBodyProducer(blob_fd) +        done = yield producer.startProducing(handle) +        logger.info("Finished saving blob in local database.") +        defer.returnValue(done) + +    @defer.inlineCallbacks +    def get(self, blob_id): +        # TODO we can also stream the blob value using sqlite +        # incremental interface for blobs - and just return the raw fd instead +        select = 'SELECT payload FROM blobs WHERE blob_id = ?' +        result = yield self.dbpool.runQuery(select, (blob_id,)) +        if result: +            defer.returnValue(BytesIO(str(result[0][0]))) + +    @defer.inlineCallbacks +    def list(self): +        query = 'select blob_id from blobs' +        result = yield self.dbpool.runQuery(query) +        if result: +            defer.returnValue([b_id[0] for b_id in result]) +        else: +            defer.returnValue([]) + +    @defer.inlineCallbacks +    def exists(self, blob_id): +        query = 'SELECT blob_id from blobs WHERE blob_id = ?' +        result = yield self.dbpool.runQuery(query, (blob_id,)) +        defer.returnValue(bool(len(result))) + +    def delete(self, blob_id): +        query = 'DELETE FROM blobs WHERE blob_id = ?' +        return self.dbpool.runQuery(query, (blob_id,)) + + +def _init_blob_table(conn): +    maybe_create = ( +        "CREATE TABLE IF NOT EXISTS " +        "blobs (" +        "blob_id PRIMARY KEY, " +        "payload BLOB)") +    conn.execute(maybe_create) + + +def _sqlcipherInitFactory(fun): +    def _initialize(conn): +        fun(conn) +        _init_blob_table(conn) +    return _initialize + + +# +# testing facilities +# + +@defer.inlineCallbacks +def testit(reactor): +    # configure logging to stdout +    from twisted.python import log +    import sys +    log.startLogging(sys.stdout) + +    # parse command line arguments +    import argparse + +    parser = argparse.ArgumentParser() +    parser.add_argument('--url', default='http://localhost:9000/') +    parser.add_argument('--path', default='/tmp/blobs') +    parser.add_argument('--secret', default='secret') +    parser.add_argument('--uuid', default='user') +    parser.add_argument('--token', default=None) +    parser.add_argument('--cert-file', default='') + +    subparsers = parser.add_subparsers(help='sub-command help', dest='action') + +    # parse upload command +    parser_upload = subparsers.add_parser( +        'upload', help='upload blob and bypass local db') +    parser_upload.add_argument('payload') +    parser_upload.add_argument('blob_id') + +    # parse download command +    parser_download = subparsers.add_parser( +        'download', help='download blob and bypass local db') +    parser_download.add_argument('blob_id') +    parser_download.add_argument('--output-file', default='/tmp/incoming-file') + +    # parse put command +    parser_put = subparsers.add_parser( +        'put', help='put blob in local db and upload') +    parser_put.add_argument('payload') +    parser_put.add_argument('blob_id') + +    # parse get command +    parser_get = subparsers.add_parser( +        'get', help='get blob from local db, get if needed') +    parser_get.add_argument('blob_id') + +    # parse delete command +    parser_get = subparsers.add_parser( +        'delete', help='delete blob from local and remote db') +    parser_get.add_argument('blob_id') + +    # parse list command +    parser_get = subparsers.add_parser( +        'list', 
help='list local and remote blob ids') + +    # parse send_missing command +    parser_get = subparsers.add_parser( +        'send_missing', help='send all pending upload blobs') + +    # parse send_missing command +    parser_get = subparsers.add_parser( +        'fetch_missing', help='fetch all new server blobs') + +    # parse arguments +    args = parser.parse_args() + +    # TODO convert these into proper unittests + +    def _manager(): +        mkdir_p(os.path.dirname(args.path)) +        manager = BlobManager( +            args.path, args.url, +            'A' * 32, args.secret, +            args.uuid, args.token, args.cert_file) +        return manager + +    @defer.inlineCallbacks +    def _upload(blob_id, payload): +        logger.info(":: Starting upload only: %s" % str((blob_id, payload))) +        manager = _manager() +        with open(payload, 'r') as fd: +            yield manager._encrypt_and_upload(blob_id, fd) +        logger.info(":: Finished upload only: %s" % str((blob_id, payload))) + +    @defer.inlineCallbacks +    def _download(blob_id): +        logger.info(":: Starting download only: %s" % blob_id) +        manager = _manager() +        result = yield manager._download_and_decrypt(blob_id) +        logger.info(":: Result of download: %s" % str(result)) +        if result: +            fd, _ = result +            with open(args.output_file, 'w') as f: +                logger.info(":: Writing data to %s" % args.output_file) +                f.write(fd.read()) +        logger.info(":: Finished download only: %s" % blob_id) + +    @defer.inlineCallbacks +    def _put(blob_id, payload): +        logger.info(":: Starting full put: %s" % blob_id) +        manager = _manager() +        size = os.path.getsize(payload) +        with open(payload) as fd: +            doc = BlobDoc(fd, blob_id) +            result = yield manager.put(doc, size=size) +        logger.info(":: Result of put: %s" % str(result)) +        logger.info(":: Finished full put: %s" % blob_id) + +    @defer.inlineCallbacks +    def _get(blob_id): +        logger.info(":: Starting full get: %s" % blob_id) +        manager = _manager() +        fd = yield manager.get(blob_id) +        if fd: +            logger.info(":: Result of get: " + fd.getvalue()) +        logger.info(":: Finished full get: %s" % blob_id) + +    @defer.inlineCallbacks +    def _delete(blob_id): +        logger.info(":: Starting deletion of: %s" % blob_id) +        manager = _manager() +        yield manager.delete(blob_id) +        logger.info(":: Finished deletion of: %s" % blob_id) + +    @defer.inlineCallbacks +    def _list(): +        logger.info(":: Listing local blobs") +        manager = _manager() +        local_list = yield manager.local_list() +        logger.info(":: Local list: %s" % local_list) +        logger.info(":: Listing remote blobs") +        remote_list = yield manager.remote_list() +        logger.info(":: Remote list: %s" % remote_list) + +    @defer.inlineCallbacks +    def _send_missing(): +        logger.info(":: Sending local pending upload docs") +        manager = _manager() +        yield manager.send_missing() +        logger.info(":: Finished sending missing docs") + +    @defer.inlineCallbacks +    def _fetch_missing(): +        logger.info(":: Fetching remote new docs") +        manager = _manager() +        yield manager.fetch_missing() +        logger.info(":: Finished fetching new docs") + +    if args.action == 'upload': +        yield _upload(args.blob_id, args.payload) +    elif 
args.action == 'download': +        yield _download(args.blob_id) +    elif args.action == 'put': +        yield _put(args.blob_id, args.payload) +    elif args.action == 'get': +        yield _get(args.blob_id) +    elif args.action == 'delete': +        yield _delete(args.blob_id) +    elif args.action == 'list': +        yield _list() +    elif args.action == 'send_missing': +        yield _send_missing() +    elif args.action == 'fetch_missing': +        yield _fetch_missing() + + +if __name__ == '__main__': +    from twisted.internet.task import react +    react(testit) diff --git a/src/leap/soledad/client/_db/dbschema.sql b/src/leap/soledad/client/_db/dbschema.sql new file mode 100644 index 00000000..ae027fc5 --- /dev/null +++ b/src/leap/soledad/client/_db/dbschema.sql @@ -0,0 +1,42 @@ +-- Database schema +CREATE TABLE transaction_log ( +    generation INTEGER PRIMARY KEY AUTOINCREMENT, +    doc_id TEXT NOT NULL, +    transaction_id TEXT NOT NULL +); +CREATE TABLE document ( +    doc_id TEXT PRIMARY KEY, +    doc_rev TEXT NOT NULL, +    content TEXT +); +CREATE TABLE document_fields ( +    doc_id TEXT NOT NULL, +    field_name TEXT NOT NULL, +    value TEXT +); +CREATE INDEX document_fields_field_value_doc_idx +    ON document_fields(field_name, value, doc_id); + +CREATE TABLE sync_log ( +    replica_uid TEXT PRIMARY KEY, +    known_generation INTEGER, +    known_transaction_id TEXT +); +CREATE TABLE conflicts ( +    doc_id TEXT, +    doc_rev TEXT, +    content TEXT, +    CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev) +); +CREATE TABLE index_definitions ( +    name TEXT, +    offset INT, +    field TEXT, +    CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset) +); +create index index_definitions_field on index_definitions(field); +CREATE TABLE u1db_config ( +    name TEXT PRIMARY KEY, +    value TEXT +); +INSERT INTO u1db_config VALUES ('sql_schema', '0'); diff --git a/src/leap/soledad/client/_db/pragmas.py b/src/leap/soledad/client/_db/pragmas.py new file mode 100644 index 00000000..870ed63e --- /dev/null +++ b/src/leap/soledad/client/_db/pragmas.py @@ -0,0 +1,379 @@ +# -*- coding: utf-8 -*- +# pragmas.py +# Copyright (C) 2013, 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Different pragmas used in the initialization of the SQLCipher database. +""" +import string +import threading +import os + +from leap.soledad.common import soledad_assert +from leap.soledad.common.log import getLogger + + +logger = getLogger(__name__) + + +_db_init_lock = threading.Lock() + + +def set_init_pragmas(conn, opts=None, extra_queries=None): +    """ +    Set the initialization pragmas. + +    This includes the crypto pragmas, and any other options that must +    be passed early to sqlcipher db. 
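A minimal initialization sketch (it assumes the pysqlcipher dbapi2 module and the SQLCipherOptions container from the sibling sqlcipher.py module; path and key are illustrative only):

from pysqlcipher import dbapi2

opts = SQLCipherOptions('/tmp/example.db', 'a' * 64, is_raw_key=True)
conn = dbapi2.connect(opts.path)
# the crypto pragmas must run before anything else touches the database
set_init_pragmas(conn, opts=opts)
conn.execute('CREATE TABLE IF NOT EXISTS example (id TEXT PRIMARY KEY)')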
+    """ +    soledad_assert(opts is not None) +    extra_queries = [] if extra_queries is None else extra_queries +    with _db_init_lock: +        # only one execution path should initialize the db +        _set_init_pragmas(conn, opts, extra_queries) + + +def _set_init_pragmas(conn, opts, extra_queries): + +    sync_off = os.environ.get('LEAP_SQLITE_NOSYNC') +    memstore = os.environ.get('LEAP_SQLITE_MEMSTORE') +    nowal = os.environ.get('LEAP_SQLITE_NOWAL') + +    set_crypto_pragmas(conn, opts) + +    if not nowal: +        set_write_ahead_logging(conn) +    if sync_off: +        set_synchronous_off(conn) +    else: +        set_synchronous_normal(conn) +    if memstore: +        set_mem_temp_store(conn) + +    for query in extra_queries: +        conn.cursor().execute(query) + + +def set_crypto_pragmas(db_handle, sqlcipher_opts): +    """ +    Set cryptographic params (key, cipher, KDF number of iterations and +    cipher page size). + +    :param db_handle: +    :type db_handle: +    :param sqlcipher_opts: options for the SQLCipherDatabase +    :type sqlcipher_opts: SQLCipherOpts instance +    """ +    # XXX assert CryptoOptions +    opts = sqlcipher_opts +    _set_key(db_handle, opts.key, opts.is_raw_key) +    _set_cipher(db_handle, opts.cipher) +    _set_kdf_iter(db_handle, opts.kdf_iter) +    _set_cipher_page_size(db_handle, opts.cipher_page_size) + + +def _set_key(db_handle, key, is_raw_key): +    """ +    Set the ``key`` for use with the database. + +    The process of creating a new, encrypted database is called 'keying' +    the database. SQLCipher uses just-in-time key derivation at the point +    it is first needed for an operation. This means that the key (and any +    options) must be set before the first operation on the database. As +    soon as the database is touched (e.g. SELECT, CREATE TABLE, UPDATE, +    etc.) and pages need to be read or written, the key is prepared for +    use. + +    Implementation Notes: + +    * PRAGMA key should generally be called as the first operation on a +        database. + +    :param key: The key for use with the database. +    :type key: str +    :param is_raw_key: +        Whether C{key} is a raw 64-char hex string or a passphrase that should +        be hashed to obtain the encyrption key. +    :type is_raw_key: bool +    """ +    if is_raw_key: +        _set_key_raw(db_handle, key) +    else: +        _set_key_passphrase(db_handle, key) + + +def _set_key_passphrase(db_handle, passphrase): +    """ +    Set a passphrase for encryption key derivation. + +    The key itself can be a passphrase, which is converted to a key using +    PBKDF2 key derivation. The result is used as the encryption key for +    the database. By using this method, there is no way to alter the KDF; +    if you want to do so you should use a raw key instead and derive the +    key using your own KDF. + +    :param db_handle: A handle to the SQLCipher database. +    :type db_handle: pysqlcipher.Connection +    :param passphrase: The passphrase used to derive the encryption key. +    :type passphrase: str +    """ +    db_handle.cursor().execute("PRAGMA key = '%s'" % passphrase) + + +def _set_key_raw(db_handle, key): +    """ +    Set a raw hexadecimal encryption key. + +    It is possible to specify an exact byte sequence using a blob literal. +    With this method, it is the calling application's responsibility to +    ensure that the data provided is a 64 character hex string, which will +    be converted directly to 32 bytes (256 bits) of key data. 
+ +    :param db_handle: A handle to the SQLCipher database. +    :type db_handle: pysqlcipher.Connection +    :param key: A 64 character hex string. +    :type key: str +    """ +    if not all(c in string.hexdigits for c in key): +        raise NotAnHexString(key) +    db_handle.cursor().execute('PRAGMA key = "x\'%s"' % key) + + +def _set_cipher(db_handle, cipher='aes-256-cbc'): +    """ +    Set the cipher and mode to use for symmetric encryption. + +    SQLCipher uses aes-256-cbc as the default cipher and mode of +    operation. It is possible to change this, though not generally +    recommended, using PRAGMA cipher. + +    SQLCipher makes direct use of libssl, so all cipher options available +    to libssl are also available for use with SQLCipher. See `man enc` for +    OpenSSL's supported ciphers. + +    Implementation Notes: + +    * PRAGMA cipher must be called after PRAGMA key and before the first +        actual database operation or it will have no effect. + +    * If a non-default value is used PRAGMA cipher to create a database, +        it must also be called every time that database is opened. + +    * SQLCipher does not implement its own encryption. Instead it uses the +        widely available and peer-reviewed OpenSSL libcrypto for all +        cryptographic functions. + +    :param db_handle: A handle to the SQLCipher database. +    :type db_handle: pysqlcipher.Connection +    :param cipher: The cipher and mode to use. +    :type cipher: str +    """ +    db_handle.cursor().execute("PRAGMA cipher = '%s'" % cipher) + + +def _set_kdf_iter(db_handle, kdf_iter=4000): +    """ +    Set the number of iterations for the key derivation function. + +    SQLCipher uses PBKDF2 key derivation to strengthen the key and make it +    resistent to brute force and dictionary attacks. The default +    configuration uses 4000 PBKDF2 iterations (effectively 16,000 SHA1 +    operations). PRAGMA kdf_iter can be used to increase or decrease the +    number of iterations used. + +    Implementation Notes: + +    * PRAGMA kdf_iter must be called after PRAGMA key and before the first +        actual database operation or it will have no effect. + +    * If a non-default value is used PRAGMA kdf_iter to create a database, +        it must also be called every time that database is opened. + +    * It is not recommended to reduce the number of iterations if a +        passphrase is in use. + +    :param db_handle: A handle to the SQLCipher database. +    :type db_handle: pysqlcipher.Connection +    :param kdf_iter: The number of iterations to use. +    :type kdf_iter: int +    """ +    db_handle.cursor().execute("PRAGMA kdf_iter = '%d'" % kdf_iter) + + +def _set_cipher_page_size(db_handle, cipher_page_size=1024): +    """ +    Set the page size of the encrypted database. + +    SQLCipher 2 introduced the new PRAGMA cipher_page_size that can be +    used to adjust the page size for the encrypted database. The default +    page size is 1024 bytes, but it can be desirable for some applications +    to use a larger page size for increased performance. For instance, +    some recent testing shows that increasing the page size can noticeably +    improve performance (5-30%) for certain queries that manipulate a +    large number of pages (e.g. selects without an index, large inserts in +    a transaction, big deletes). + +    To adjust the page size, call the pragma immediately after setting the +    key for the first time and each subsequent time that you open the +    database. 
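Because non-default cipher settings must be re-issued on every open, a database created with custom parameters is reopened roughly like this (a sketch; it assumes the pysqlcipher dbapi2 module, and the values are illustrative):

def reopen(path, passphrase):
    conn = dbapi2.connect(path)
    cur = conn.cursor()
    cur.execute("PRAGMA key = '%s'" % passphrase)     # always first
    cur.execute("PRAGMA cipher = 'aes-256-cbc'")      # then the cipher options,
    cur.execute("PRAGMA kdf_iter = '4000'")           # repeated on every open
    cur.execute("PRAGMA cipher_page_size = '4096'")
    return conn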
+
+    Implementation Notes:
+
+    * PRAGMA cipher_page_size must be called after PRAGMA key and before
+        the first actual database operation or it will have no effect.
+
+    * If a non-default value is used with PRAGMA cipher_page_size to create a
+        database, it must also be called every time that database is opened.
+
+    :param db_handle: A handle to the SQLCipher database.
+    :type db_handle: pysqlcipher.Connection
+    :param cipher_page_size: The page size.
+    :type cipher_page_size: int
+    """
+    db_handle.cursor().execute(
+        "PRAGMA cipher_page_size = '%d'" % cipher_page_size)
+
+
+# XXX UNUSED ?
+def set_rekey(db_handle, new_key, is_raw_key):
+    """
+    Change the key of an existing encrypted database.
+
+    To change the key on an existing encrypted database, it must first be
+    unlocked with the current encryption key. Once the database is
+    readable and writeable, PRAGMA rekey can be used to re-encrypt every
+    page in the database with a new key.
+
+    * PRAGMA rekey must be called after PRAGMA key. It can be called at any
+        time once the database is readable.
+
+    * PRAGMA rekey cannot be used to encrypt a standard SQLite
+        database! It is only useful for changing the key on an existing
+        database.
+
+    * Previous versions of SQLCipher provided PRAGMA rekey_cipher and
+        PRAGMA rekey_kdf_iter. These are deprecated and should not be
+        used. Instead, use sqlcipher_export().
+
+    :param db_handle: A handle to the SQLCipher database.
+    :type db_handle: pysqlcipher.Connection
+    :param new_key: The new key.
+    :type new_key: str
+    :param is_raw_key: Whether C{new_key} is a raw 64-char hex string or a
+                    passphrase that should be hashed to obtain the encryption
+                    key.
+    :type is_raw_key: bool
+    """
+    if is_raw_key:
+        _set_rekey_raw(db_handle, new_key)
+    else:
+        _set_rekey_passphrase(db_handle, new_key)
+
+
+def _set_rekey_passphrase(db_handle, passphrase):
+    """
+    Change the passphrase for encryption key derivation.
+
+    The key itself can be a passphrase, which is converted to a key using
+    PBKDF2 key derivation. The result is used as the encryption key for
+    the database.
+
+    :param db_handle: A handle to the SQLCipher database.
+    :type db_handle: pysqlcipher.Connection
+    :param passphrase: The passphrase used to derive the encryption key.
+    :type passphrase: str
+    """
+    db_handle.cursor().execute("PRAGMA rekey = '%s'" % passphrase)
+
+
+def _set_rekey_raw(db_handle, key):
+    """
+    Change the raw hexadecimal encryption key.
+
+    It is possible to specify an exact byte sequence using a blob literal.
+    With this method, it is the calling application's responsibility to
+    ensure that the data provided is a 64 character hex string, which will
+    be converted directly to 32 bytes (256 bits) of key data.
+
+    :param db_handle: A handle to the SQLCipher database.
+    :type db_handle: pysqlcipher.Connection
+    :param key: A 64 character hex string.
+    :type key: str
+    """
+    if not all(c in string.hexdigits for c in key):
+        raise NotAnHexString(key)
+    db_handle.cursor().execute('PRAGMA rekey = "x\'%s"' % key)
+
+
+def set_synchronous_off(db_handle):
+    """
+    Change the setting of the "synchronous" flag to OFF.
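Combining set_rekey with the keying helpers above, changing the passphrase of an existing database is a two-step sequence (a sketch; the passphrases are illustrative, and the sqlite_master read is only there to force the current key to be checked before rekeying):

# unlock the database with its current key ...
_set_key(db_handle, 'old passphrase', is_raw_key=False)
db_handle.cursor().execute('SELECT count(*) FROM sqlite_master')
# ... then re-encrypt every page with the new one
set_rekey(db_handle, 'new passphrase', is_raw_key=False)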
+    """ +    logger.debug("sqlcipher: setting synchronous off") +    db_handle.cursor().execute('PRAGMA synchronous=OFF') + + +def set_synchronous_normal(db_handle): +    """ +    Change the setting of the "synchronous" flag to NORMAL. +    """ +    logger.debug("sqlcipher: setting synchronous normal") +    db_handle.cursor().execute('PRAGMA synchronous=NORMAL') + + +def set_mem_temp_store(db_handle): +    """ +    Use a in-memory store for temporary tables. +    """ +    logger.debug("sqlcipher: setting temp_store memory") +    db_handle.cursor().execute('PRAGMA temp_store=MEMORY') + + +def set_write_ahead_logging(db_handle): +    """ +    Enable write-ahead logging, and set the autocheckpoint to 50 pages. + +    Setting the autocheckpoint to a small value, we make the reads not +    suffer too much performance degradation. + +    From the sqlite docs: + +    "There is a tradeoff between average read performance and average write +    performance. To maximize the read performance, one wants to keep the +    WAL as small as possible and hence run checkpoints frequently, perhaps +    as often as every COMMIT. To maximize write performance, one wants to +    amortize the cost of each checkpoint over as many writes as possible, +    meaning that one wants to run checkpoints infrequently and let the WAL +    grow as large as possible before each checkpoint. The decision of how +    often to run checkpoints may therefore vary from one application to +    another depending on the relative read and write performance +    requirements of the application. The default strategy is to run a +    checkpoint once the WAL reaches 1000 pages" +    """ +    logger.debug("sqlcipher: setting write-ahead logging") +    db_handle.cursor().execute('PRAGMA journal_mode=WAL') + +    # The optimum value can still use a little bit of tuning, but we favor +    # small sizes of the WAL file to get fast reads, since we assume that +    # the writes will be quick enough to not block too much. + +    db_handle.cursor().execute('PRAGMA wal_autocheckpoint=50') + + +class NotAnHexString(Exception): +    """ +    Raised when trying to (raw) key the database with a non-hex string. +    """ +    pass diff --git a/src/leap/soledad/client/_db/sqlcipher.py b/src/leap/soledad/client/_db/sqlcipher.py new file mode 100644 index 00000000..d22017bd --- /dev/null +++ b/src/leap/soledad/client/_db/sqlcipher.py @@ -0,0 +1,633 @@ +# -*- coding: utf-8 -*- +# sqlcipher.py +# Copyright (C) 2013, 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +A U1DB backend that uses SQLCipher as its persistence layer. + +The SQLCipher API (http://sqlcipher.net/sqlcipher-api/) is fully implemented, +with the exception of the following statements: + +  * PRAGMA cipher_use_hmac +  * PRAGMA cipher_default_use_mac + +SQLCipher 2.0 introduced a per-page HMAC to validate that the page data has +not be tampered with. 
By default, when creating or opening a database using +SQLCipher 2, SQLCipher will attempt to use an HMAC check. This change in +database format means that SQLCipher 2 can't operate on version 1.1.x +databases by default. Thus, in order to provide backward compatibility with +SQLCipher 1.1.x, PRAGMA cipher_use_hmac can be used to disable the HMAC +functionality on specific databases. + +In some very specific cases, it is not possible to call PRAGMA cipher_use_hmac +as one of the first operations on a database. An example of this is when +trying to ATTACH a 1.1.x database to the main database. In these cases PRAGMA +cipher_default_use_hmac can be used to globally alter the default use of HMAC +when opening a database. + +So, as the statements above were introduced for backwards compatibility with +SQLCipher 1.1 databases, we do not implement them as all SQLCipher databases +handled by Soledad should be created by SQLCipher >= 2.0. +""" +import os +import sys + +from functools import partial + +from twisted.internet import reactor +from twisted.internet import defer +from twisted.enterprise import adbapi + +from leap.soledad.common.log import getLogger +from leap.soledad.common.l2db import errors as u1db_errors +from leap.soledad.common.errors import DatabaseAccessError + +from leap.soledad.client.http_target import SoledadHTTPSyncTarget +from leap.soledad.client.sync import SoledadSynchronizer + +from .._document import Document +from . import sqlite +from . import pragmas + +if sys.version_info[0] < 3: +    from pysqlcipher import dbapi2 as sqlcipher_dbapi2 +else: +    from pysqlcipher3 import dbapi2 as sqlcipher_dbapi2 + +logger = getLogger(__name__) + + +# Monkey-patch u1db.backends.sqlite with pysqlcipher.dbapi2 +sqlite.dbapi2 = sqlcipher_dbapi2 + + +# we may want to collect statistics from the sync process +DO_STATS = False +if os.environ.get('SOLEDAD_STATS'): +    DO_STATS = True + + +def initialize_sqlcipher_db(opts, on_init=None, check_same_thread=True): +    """ +    Initialize a SQLCipher database. + +    :param opts: +    :type opts: SQLCipherOptions +    :param on_init: a tuple of queries to be executed on initialization +    :type on_init: tuple +    :return: pysqlcipher.dbapi2.Connection +    """ +    # Note: There seemed to be a bug in sqlite 3.5.9 (with python2.6) +    #       where without re-opening the database on Windows, it +    #       doesn't see the transaction that was just committed +    # Removing from here now, look at the pysqlite implementation if the +    # bug shows up in windows. + +    if not os.path.isfile(opts.path) and not opts.create: +        raise u1db_errors.DatabaseDoesNotExist() + +    conn = sqlcipher_dbapi2.connect( +        opts.path, check_same_thread=check_same_thread) +    pragmas.set_init_pragmas(conn, opts, extra_queries=on_init) +    return conn + + +def initialize_sqlcipher_adbapi_db(opts, extra_queries=None): +    from leap.soledad.client import sqlcipher_adbapi +    return sqlcipher_adbapi.getConnectionPool( +        opts, extra_queries=extra_queries) + + +class SQLCipherOptions(object): +    """ +    A container with options for the initialization of an SQLCipher database. +    """ + +    @classmethod +    def copy(cls, source, path=None, key=None, create=None, +             is_raw_key=None, cipher=None, kdf_iter=None, +             cipher_page_size=None, sync_db_key=None): +        """ +        Return a copy of C{source} with parameters different than None +        replaced by new values. 
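For instance, options for a companion database can be derived from an existing set without repeating the crypto parameters (a sketch; paths and key are illustrative only):

main_opts = SQLCipherOptions('/tmp/main.db', 'a' * 64, is_raw_key=True)
# same key, cipher and KDF settings, but another file that must already exist
sync_opts = SQLCipherOptions.copy(main_opts, path='/tmp/sync.db', create=False)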
+        """ +        local_vars = locals() +        args = [] +        kwargs = {} + +        for name in ["path", "key"]: +            val = local_vars[name] +            if val is not None: +                args.append(val) +            else: +                args.append(getattr(source, name)) + +        for name in ["create", "is_raw_key", "cipher", "kdf_iter", +                     "cipher_page_size", "sync_db_key"]: +            val = local_vars[name] +            if val is not None: +                kwargs[name] = val +            else: +                kwargs[name] = getattr(source, name) + +        return SQLCipherOptions(*args, **kwargs) + +    def __init__(self, path, key, create=True, is_raw_key=False, +                 cipher='aes-256-cbc', kdf_iter=4000, cipher_page_size=1024, +                 sync_db_key=None): +        """ +        :param path: The filesystem path for the database to open. +        :type path: str +        :param create: +            True/False, should the database be created if it doesn't +            already exist? +        :param create: bool +        :param is_raw_key: +            Whether ``password`` is a raw 64-char hex string or a passphrase +            that should be hashed to obtain the encyrption key. +        :type raw_key: bool +        :param cipher: The cipher and mode to use. +        :type cipher: str +        :param kdf_iter: The number of iterations to use. +        :type kdf_iter: int +        :param cipher_page_size: The page size. +        :type cipher_page_size: int +        """ +        self.path = path +        self.key = key +        self.is_raw_key = is_raw_key +        self.create = create +        self.cipher = cipher +        self.kdf_iter = kdf_iter +        self.cipher_page_size = cipher_page_size +        self.sync_db_key = sync_db_key + +    def __str__(self): +        """ +        Return string representation of options, for easy debugging. + +        :return: String representation of options. +        :rtype: str +        """ +        attr_names = filter(lambda a: not a.startswith('_'), dir(self)) +        attr_str = [] +        for a in attr_names: +            attr_str.append(a + "=" + str(getattr(self, a))) +        name = self.__class__.__name__ +        return "%s(%s)" % (name, ', '.join(attr_str)) + + +# +# The SQLCipher database +# + +class SQLCipherDatabase(sqlite.SQLitePartialExpandDatabase): +    """ +    A U1DB implementation that uses SQLCipher as its persistence layer. +    """ + +    # The attribute _index_storage_value will be used as the lookup key for the +    # implementation of the SQLCipher storage backend. +    _index_storage_value = 'expand referenced encrypted' + +    def __init__(self, opts): +        """ +        Connect to an existing SQLCipher database, creating a new sqlcipher +        database file if needed. + +        *** IMPORTANT *** + +        Don't forget to close the database after use by calling the close() +        method otherwise some resources might not be freed and you may +        experience several kinds of leakages. + +        *** IMPORTANT *** + +        :param opts: options for initialization of the SQLCipher database. 
+        :type opts: SQLCipherOptions +        """ +        # ensure the db is encrypted if the file already exists +        if os.path.isfile(opts.path): +            self._db_handle = _assert_db_is_encrypted(opts) +        else: +            # connect to the sqlcipher database +            self._db_handle = initialize_sqlcipher_db(opts) + +        # TODO --------------------------------------------------- +        # Everything else in this initialization has to be factored +        # out, so it can be used from SoledadSQLCipherWrapper.__init__ +        # too. +        # --------------------------------------------------------- + +        self._ensure_schema() +        self.set_document_factory(doc_factory) +        self._prime_replica_uid() + +    def _prime_replica_uid(self): +        """ +        In the u1db implementation, _replica_uid is a property +        that returns the value in _real_replica_uid, and does +        a db query if no value found. +        Here we prime the replica uid during initialization so +        that we don't have to wait for the query afterwards. +        """ +        self._real_replica_uid = None +        self._get_replica_uid() + +    def _extra_schema_init(self, c): +        """ +        Add any extra fields, etc to the basic table definitions. + +        This method is called by u1db.backends.sqlite_backend._initialize() +        method, which is executed when the database schema is created. Here, +        we use it to include the "syncable" property for LeapDocuments. + +        :param c: The cursor for querying the database. +        :type c: dbapi2.cursor +        """ +        c.execute( +            'ALTER TABLE document ' +            'ADD COLUMN syncable BOOL NOT NULL DEFAULT TRUE') + +    # +    # SQLCipher API methods +    # + +    # Extra query methods: extensions to the base u1db sqlite implmentation. + +    def get_count_from_index(self, index_name, *key_values): +        """ +        Return the count for a given combination of index_name +        and key values. + +        Extension method made from similar methods in u1db version 13.09 + +        :param index_name: The index to query +        :type index_name: str +        :param key_values: values to match. eg, if you have +                           an index with 3 fields then you would have: +                           get_from_index(index_name, val1, val2, val3) +        :type key_values: tuple +        :return: count. +        :rtype: int +        """ +        c = self._db_handle.cursor() +        definition = self._get_index_definition(index_name) + +        if len(key_values) != len(definition): +            raise u1db_errors.InvalidValueForIndex() +        tables = ["document_fields d%d" % i for i in range(len(definition))] +        novalue_where = ["d.doc_id = d%d.doc_id" +                         " AND d%d.field_name = ?" +                         % (i, i) for i in range(len(definition))] +        exact_where = [novalue_where[i] + (" AND d%d.value = ?" 
% (i,)) +                       for i in range(len(definition))] +        args = [] +        where = [] +        for idx, (field, value) in enumerate(zip(definition, key_values)): +            args.append(field) +            where.append(exact_where[idx]) +            args.append(value) + +        tables = ["document_fields d%d" % i for i in range(len(definition))] +        statement = ( +            "SELECT COUNT(*) FROM document d, %s WHERE %s " % ( +                ', '.join(tables), +                ' AND '.join(where), +            )) +        try: +            c.execute(statement, tuple(args)) +        except sqlcipher_dbapi2.OperationalError as e: +            raise sqlcipher_dbapi2.OperationalError( +                str(e) + '\nstatement: %s\nargs: %s\n' % (statement, args)) +        res = c.fetchall() +        return res[0][0] + +    def close(self): +        """ +        Close db connections. +        """ +        # TODO should be handled by adbapi instead +        # TODO syncdb should be stopped first + +        if logger is not None:  # logger might be none if called from __del__ +            logger.debug("SQLCipher backend: closing") + +        # close the actual database +        if getattr(self, '_db_handle', False): +            self._db_handle.close() +            self._db_handle = None + +    # indexes + +    def _put_and_update_indexes(self, old_doc, doc): +        """ +        Update a document and all indexes related to it. + +        :param old_doc: The old version of the document. +        :type old_doc: u1db.Document +        :param doc: The new version of the document. +        :type doc: u1db.Document +        """ +        sqlite.SQLitePartialExpandDatabase._put_and_update_indexes( +            self, old_doc, doc) +        c = self._db_handle.cursor() +        c.execute('UPDATE document SET syncable=? WHERE doc_id=?', +                  (doc.syncable, doc.doc_id)) + +    def _get_doc(self, doc_id, check_for_conflicts=False): +        """ +        Get just the document content, without fancy handling. + +        :param doc_id: The unique document identifier +        :type doc_id: str +        :param include_deleted: If set to True, deleted documents will be +            returned with empty content. Otherwise asking for a deleted +            document will return None. +        :type include_deleted: bool + +        :return: a Document object. +        :type: u1db.Document +        """ +        doc = sqlite.SQLitePartialExpandDatabase._get_doc( +            self, doc_id, check_for_conflicts) +        if doc: +            c = self._db_handle.cursor() +            c.execute('SELECT syncable FROM document WHERE doc_id=?', +                      (doc.doc_id,)) +            result = c.fetchone() +            doc.syncable = bool(result[0]) +        return doc + +    def __del__(self): +        """ +        Free resources when deleting or garbage collecting the database. + +        This is only here to minimze problems if someone ever forgets to call +        the close() method after using the database; you should not rely on +        garbage collecting to free up the database resources. +        """ +        self.close() + + +class SQLCipherU1DBSync(SQLCipherDatabase): +    """ +    Soledad syncer implementation. +    """ + +    """ +    The name of the local symmetrically encrypted documents to +    sync database file. 
+    """ +    LOCAL_SYMMETRIC_SYNC_FILE_NAME = 'sync.u1db' + +    """ +    Period or recurrence of the Looping Call that will do the encryption to the +    syncdb (in seconds). +    """ +    ENCRYPT_LOOP_PERIOD = 1 + +    def __init__(self, opts, soledad_crypto, replica_uid, cert_file): +        self._opts = opts +        self._path = opts.path +        self._crypto = soledad_crypto +        self.__replica_uid = replica_uid +        self._cert_file = cert_file + +        # storage for the documents received during a sync +        self.received_docs = [] + +        self.running = False +        self._db_handle = None + +        # initialize the main db before scheduling a start +        self._initialize_main_db() +        self._reactor = reactor +        self._reactor.callWhenRunning(self._start) + +        if DO_STATS: +            self.sync_phase = None + +    def commit(self): +        self._db_handle.commit() + +    @property +    def _replica_uid(self): +        return str(self.__replica_uid) + +    def _start(self): +        if not self.running: +            self.running = True + +    def _initialize_main_db(self): +        try: +            self._db_handle = initialize_sqlcipher_db( +                self._opts, check_same_thread=False) +            self._real_replica_uid = None +            self._ensure_schema() +            self.set_document_factory(doc_factory) +        except sqlcipher_dbapi2.DatabaseError as e: +            raise DatabaseAccessError(str(e)) + +    @defer.inlineCallbacks +    def sync(self, url, creds=None): +        """ +        Synchronize documents with remote replica exposed at url. + +        It is not safe to initiate more than one sync process and let them run +        concurrently. It is responsibility of the caller to ensure that there +        are no concurrent sync processes running. This is currently controlled +        by the main Soledad object because it may also run post-sync hooks, +        which should be run while the lock is locked. + +        :param url: The url of the target replica to sync with. +        :type url: str +        :param creds: optional dictionary giving credentials to authorize the +                      operation with the server. +        :type creds: dict + +        :return: +            A Deferred, that will fire with the local generation (type `int`) +            before the synchronisation was performed. +        :rtype: Deferred +        """ +        syncer = self._get_syncer(url, creds=creds) +        if DO_STATS: +            self.sync_phase = syncer.sync_phase +            self.syncer = syncer +            self.sync_exchange_phase = syncer.sync_exchange_phase +        local_gen_before_sync = yield syncer.sync() +        self.received_docs = syncer.received_docs +        defer.returnValue(local_gen_before_sync) + +    def _get_syncer(self, url, creds=None): +        """ +        Get a synchronizer for ``url`` using ``creds``. + +        :param url: The url of the target replica to sync with. +        :type url: str +        :param creds: optional dictionary giving credentials. +                      to authorize the operation with the server. +        :type creds: dict + +        :return: A synchronizer. +        :rtype: Synchronizer +        """ +        return SoledadSynchronizer( +            self, +            SoledadHTTPSyncTarget( +                url, +                # XXX is the replica_uid ready? 
+                self._replica_uid, +                creds=creds, +                crypto=self._crypto, +                cert_file=self._cert_file)) + +    # +    # Symmetric encryption of syncing docs +    # + +    def get_generation(self): +        # FIXME +        # XXX this SHOULD BE a callback +        return self._get_generation() + + +class U1DBSQLiteBackend(sqlite.SQLitePartialExpandDatabase): +    """ +    A very simple wrapper for u1db around sqlcipher backend. + +    Instead of initializing the database on the fly, it just uses an existing +    connection that is passed to it in the initializer. + +    It can be used in tests and debug runs to initialize the adbapi with plain +    sqlite connections, decoupled from the sqlcipher layer. +    """ + +    def __init__(self, conn): +        self._db_handle = conn +        self._real_replica_uid = None +        self._ensure_schema() +        self._factory = Document + + +class SoledadSQLCipherWrapper(SQLCipherDatabase): +    """ +    A wrapper for u1db that uses the Soledad-extended sqlcipher backend. + +    Instead of initializing the database on the fly, it just uses an existing +    connection that is passed to it in the initializer. + +    It can be used from adbapi to initialize a soledad database after +    getting a regular connection to a sqlcipher database. +    """ +    def __init__(self, conn, opts): +        self._db_handle = conn +        self._real_replica_uid = None +        self._ensure_schema() +        self.set_document_factory(doc_factory) +        self._prime_replica_uid() + + +def _assert_db_is_encrypted(opts): +    """ +    Assert that the sqlcipher file contains an encrypted database. + +    When opening an existing database, PRAGMA key will not immediately +    throw an error if the key provided is incorrect. To test that the +    database can be successfully opened with the provided key, it is +    necessary to perform some operation on the database (i.e. read from +    it) and confirm it is success. + +    The easiest way to do this is select off the sqlite_master table, +    which will attempt to read the first page of the database and will +    parse the schema. + +    :param opts: +    """ +    # We try to open an encrypted database with the regular u1db +    # backend should raise a DatabaseError exception. +    # If the regular backend succeeds, then we need to stop because +    # the database was not properly initialized. +    try: +        sqlite.SQLitePartialExpandDatabase(opts.path) +    except sqlcipher_dbapi2.DatabaseError: +        # assert that we can access it using SQLCipher with the given +        # key +        dummy_query = ('SELECT count(*) FROM sqlite_master',) +        return initialize_sqlcipher_db(opts, on_init=dummy_query) +    else: +        raise DatabaseIsNotEncrypted() + +# +# Exceptions +# + + +class DatabaseIsNotEncrypted(Exception): +    """ +    Exception raised when trying to open non-encrypted databases. +    """ +    pass + + +def doc_factory(doc_id=None, rev=None, json='{}', has_conflicts=False, +                syncable=True): +    """ +    Return a default Soledad Document. 
+    Used in the initialization for SQLCipherDatabase +    """ +    return Document(doc_id=doc_id, rev=rev, json=json, +                    has_conflicts=has_conflicts, syncable=syncable) + + +sqlite.SQLiteDatabase.register_implementation(SQLCipherDatabase) + + +# +# twisted.enterprise.adbapi SQLCipher implementation +# + +SQLCIPHER_CONNECTION_TIMEOUT = 10 + + +def getConnectionPool(opts, extra_queries=None): +    openfun = partial( +        pragmas.set_init_pragmas, +        opts=opts, +        extra_queries=extra_queries) +    return SQLCipherConnectionPool( +        database=opts.path, +        check_same_thread=False, +        cp_openfun=openfun, +        timeout=SQLCIPHER_CONNECTION_TIMEOUT) + + +class SQLCipherConnection(adbapi.Connection): +    pass + + +class SQLCipherTransaction(adbapi.Transaction): +    pass + + +class SQLCipherConnectionPool(adbapi.ConnectionPool): + +    connectionFactory = SQLCipherConnection +    transactionFactory = SQLCipherTransaction + +    def __init__(self, *args, **kwargs): +        adbapi.ConnectionPool.__init__( +            self, "pysqlcipher.dbapi2", *args, **kwargs) diff --git a/src/leap/soledad/client/_db/sqlite.py b/src/leap/soledad/client/_db/sqlite.py new file mode 100644 index 00000000..4f7b1259 --- /dev/null +++ b/src/leap/soledad/client/_db/sqlite.py @@ -0,0 +1,930 @@ +# Copyright 2011 Canonical Ltd. +# Copyright 2016 LEAP Encryption Access Project +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. + +""" +A L2DB implementation that uses SQLite as its persistence layer. 
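A small usage sketch of this backend, using only the API defined below in this file (the path and document content are illustrative only):

from leap.soledad.common.l2db import Document

db = SQLiteDatabase.open_database('/tmp/example.u1db', create=True)
db.put_doc(Document(doc_id='doc-1', json='{"title": "example"}'))
doc = db.get_doc('doc-1')
db.close()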
+""" + +import errno +import os +import json +import sys +import time +import uuid +import pkg_resources + +from sqlite3 import dbapi2 + +from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget +from leap.soledad.common.l2db import ( +    Document, errors, +    query_parser, vectorclock) + + +class SQLiteDatabase(CommonBackend): +    """A U1DB implementation that uses SQLite as its persistence layer.""" + +    _sqlite_registry = {} + +    def __init__(self, sqlite_file, document_factory=None): +        """Create a new sqlite file.""" +        self._db_handle = dbapi2.connect(sqlite_file) +        self._real_replica_uid = None +        self._ensure_schema() +        self._factory = document_factory or Document + +    def set_document_factory(self, factory): +        self._factory = factory + +    def get_sync_target(self): +        return SQLiteSyncTarget(self) + +    @classmethod +    def _which_index_storage(cls, c): +        try: +            c.execute("SELECT value FROM u1db_config" +                      " WHERE name = 'index_storage'") +        except dbapi2.OperationalError as e: +            # The table does not exist yet +            return None, e +        else: +            return c.fetchone()[0], None + +    WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5 + +    @classmethod +    def _open_database(cls, sqlite_file, document_factory=None): +        if not os.path.isfile(sqlite_file): +            raise errors.DatabaseDoesNotExist() +        tries = 2 +        while True: +            # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6) +            #       where without re-opening the database on Windows, it +            #       doesn't see the transaction that was just committed +            db_handle = dbapi2.connect(sqlite_file) +            c = db_handle.cursor() +            v, err = cls._which_index_storage(c) +            db_handle.close() +            if v is not None: +                break +            # possibly another process is initializing it, wait for it to be +            # done +            if tries == 0: +                raise err  # go for the richest error? +            tries -= 1 +            time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL) +        return SQLiteDatabase._sqlite_registry[v]( +            sqlite_file, document_factory=document_factory) + +    @classmethod +    def open_database(cls, sqlite_file, create, backend_cls=None, +                      document_factory=None): +        try: +            return cls._open_database( +                sqlite_file, document_factory=document_factory) +        except errors.DatabaseDoesNotExist: +            if not create: +                raise +            if backend_cls is None: +                # default is SQLitePartialExpandDatabase +                backend_cls = SQLitePartialExpandDatabase +            return backend_cls(sqlite_file, document_factory=document_factory) + +    @staticmethod +    def delete_database(sqlite_file): +        try: +            os.unlink(sqlite_file) +        except OSError as ex: +            if ex.errno == errno.ENOENT: +                raise errors.DatabaseDoesNotExist() +            raise + +    @staticmethod +    def register_implementation(klass): +        """Register that we implement an SQLiteDatabase. + +        The attribute _index_storage_value will be used as the lookup key. 
+        """ +        SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass + +    def _get_sqlite_handle(self): +        """Get access to the underlying sqlite database. + +        This should only be used by the test suite, etc, for examining the +        state of the underlying database. +        """ +        return self._db_handle + +    def _close_sqlite_handle(self): +        """Release access to the underlying sqlite database.""" +        self._db_handle.close() + +    def close(self): +        self._close_sqlite_handle() + +    def _is_initialized(self, c): +        """Check if this database has been initialized.""" +        c.execute("PRAGMA case_sensitive_like=ON") +        try: +            c.execute("SELECT value FROM u1db_config" +                      " WHERE name = 'sql_schema'") +        except dbapi2.OperationalError: +            # The table does not exist yet +            val = None +        else: +            val = c.fetchone() +        if val is not None: +            return True +        return False + +    def _initialize(self, c): +        """Create the schema in the database.""" +        # read the script with sql commands +        # TODO: Change how we set up the dependency. Most likely use something +        #   like lp:dirspec to grab the file from a common resource +        #   directory. Doesn't specifically need to be handled until we get +        #   to the point of packaging this. +        schema_content = pkg_resources.resource_string( +            __name__, 'dbschema.sql') +        # Note: We'd like to use c.executescript() here, but it seems that +        #       executescript always commits, even if you set +        #       isolation_level = None, so if we want to properly handle +        #       exclusive locking and rollbacks between processes, we need +        #       to execute it line-by-line +        for line in schema_content.split(';'): +            if not line: +                continue +            c.execute(line) +        # add extra fields +        self._extra_schema_init(c) +        # A unique identifier should be set for this replica. Implementations +        # don't have to strictly use uuid here, but we do want the uid to be +        # unique amongst all databases that will sync with each other. +        # We might extend this to using something with hostname for easier +        # debugging. 
+        self._set_replica_uid_in_transaction(uuid.uuid4().hex) +        c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)", +                  (self._index_storage_value,)) + +    def _ensure_schema(self): +        """Ensure that the database schema has been created.""" +        old_isolation_level = self._db_handle.isolation_level +        c = self._db_handle.cursor() +        if self._is_initialized(c): +            return +        try: +            # autocommit/own mgmt of transactions +            self._db_handle.isolation_level = None +            with self._db_handle: +                # only one execution path should initialize the db +                c.execute("begin exclusive") +                if self._is_initialized(c): +                    return +                self._initialize(c) +        finally: +            self._db_handle.isolation_level = old_isolation_level + +    def _extra_schema_init(self, c): +        """Add any extra fields, etc to the basic table definitions.""" + +    def _parse_index_definition(self, index_field): +        """Parse a field definition for an index, returning a Getter.""" +        # Note: We may want to keep a Parser object around, and cache the +        #       Getter objects for a greater length of time. Specifically, if +        #       you create a bunch of indexes, and then insert 50k docs, you'll +        #       re-parse the indexes between puts. The time to insert the docs +        #       is still likely to dominate put_doc time, though. +        parser = query_parser.Parser() +        getter = parser.parse(index_field) +        return getter + +    def _update_indexes(self, doc_id, raw_doc, getters, db_cursor): +        """Update document_fields for a single document. + +        :param doc_id: Identifier for this document +        :param raw_doc: The python dict representation of the document. +        :param getters: A list of [(field_name, Getter)]. Getter.get will be +            called to evaluate the index definition for this document, and the +            results will be inserted into the db. +        :param db_cursor: An sqlite Cursor. +        :return: None +        """ +        values = [] +        for field_name, getter in getters: +            for idx_value in getter.get(raw_doc): +                values.append((doc_id, field_name, idx_value)) +        if values: +            db_cursor.executemany( +                "INSERT INTO document_fields VALUES (?, ?, ?)", values) + +    def _set_replica_uid(self, replica_uid): +        """Force the replica_uid to be set.""" +        with self._db_handle: +            self._set_replica_uid_in_transaction(replica_uid) + +    def _set_replica_uid_in_transaction(self, replica_uid): +        """Set the replica_uid. 
A transaction should already be held.""" +        c = self._db_handle.cursor() +        c.execute("INSERT OR REPLACE INTO u1db_config" +                  " VALUES ('replica_uid', ?)", +                  (replica_uid,)) +        self._real_replica_uid = replica_uid + +    def _get_replica_uid(self): +        if self._real_replica_uid is not None: +            return self._real_replica_uid +        c = self._db_handle.cursor() +        c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'") +        val = c.fetchone() +        if val is None: +            return None +        self._real_replica_uid = val[0] +        return self._real_replica_uid + +    _replica_uid = property(_get_replica_uid) + +    def _get_generation(self): +        c = self._db_handle.cursor() +        c.execute('SELECT max(generation) FROM transaction_log') +        val = c.fetchone()[0] +        if val is None: +            return 0 +        return val + +    def _get_generation_info(self): +        c = self._db_handle.cursor() +        c.execute( +            'SELECT max(generation), transaction_id FROM transaction_log ') +        val = c.fetchone() +        if val[0] is None: +            return(0, '') +        return val + +    def _get_trans_id_for_gen(self, generation): +        if generation == 0: +            return '' +        c = self._db_handle.cursor() +        c.execute( +            'SELECT transaction_id FROM transaction_log WHERE generation = ?', +            (generation,)) +        val = c.fetchone() +        if val is None: +            raise errors.InvalidGeneration +        return val[0] + +    def _get_transaction_log(self): +        c = self._db_handle.cursor() +        c.execute("SELECT doc_id, transaction_id FROM transaction_log" +                  " ORDER BY generation") +        return c.fetchall() + +    def _get_doc(self, doc_id, check_for_conflicts=False): +        """Get just the document content, without fancy handling.""" +        c = self._db_handle.cursor() +        if check_for_conflicts: +            c.execute( +                "SELECT document.doc_rev, document.content, " +                "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN " +                "conflicts ON conflicts.doc_id = document.doc_id WHERE " +                "document.doc_id = ? GROUP BY document.doc_id, " +                "document.doc_rev, document.content;", (doc_id,)) +        else: +            c.execute( +                "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?", +                (doc_id,)) +        val = c.fetchone() +        if val is None: +            return None +        doc_rev, content, conflicts = val +        doc = self._factory(doc_id, doc_rev, content) +        doc.has_conflicts = conflicts > 0 +        return doc + +    def _has_conflicts(self, doc_id): +        c = self._db_handle.cursor() +        c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? 
LIMIT 1", +                  (doc_id,)) +        val = c.fetchone() +        if val is None: +            return False +        else: +            return True + +    def get_doc(self, doc_id, include_deleted=False): +        doc = self._get_doc(doc_id, check_for_conflicts=True) +        if doc is None: +            return None +        if doc.is_tombstone() and not include_deleted: +            return None +        return doc + +    def get_all_docs(self, include_deleted=False): +        """Get all documents from the database.""" +        generation = self._get_generation() +        results = [] +        c = self._db_handle.cursor() +        c.execute( +            "SELECT document.doc_id, document.doc_rev, document.content, " +            "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts " +            "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, " +            "document.doc_rev, document.content;") +        rows = c.fetchall() +        for doc_id, doc_rev, content, conflicts in rows: +            if content is None and not include_deleted: +                continue +            doc = self._factory(doc_id, doc_rev, content) +            doc.has_conflicts = conflicts > 0 +            results.append(doc) +        return (generation, results) + +    def put_doc(self, doc): +        if doc.doc_id is None: +            raise errors.InvalidDocId() +        self._check_doc_id(doc.doc_id) +        self._check_doc_size(doc) +        with self._db_handle: +            old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) +            if old_doc and old_doc.has_conflicts: +                raise errors.ConflictedDoc() +            if old_doc and doc.rev is None and old_doc.is_tombstone(): +                new_rev = self._allocate_doc_rev(old_doc.rev) +            else: +                if old_doc is not None: +                        if old_doc.rev != doc.rev: +                            raise errors.RevisionConflict() +                else: +                    if doc.rev is not None: +                        raise errors.RevisionConflict() +                new_rev = self._allocate_doc_rev(doc.rev) +            doc.rev = new_rev +            self._put_and_update_indexes(old_doc, doc) +        return new_rev + +    def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none): +        """Convert a dict representation into named fields. + +        So something like: {'key1': 'val1', 'key2': 'val2'} +        gets converted into: [(doc_id, 'key1', 'val1', 0) +                              (doc_id, 'key2', 'val2', 0)] +        :param doc_id: Just added to every record. +        :param base_field: if set, these are nested keys, so each field should +            be appropriately prefixed. +        :param raw_doc: The python dictionary. +        """ +        # TODO: Handle lists +        values = [] +        for field_name, value in raw_doc.iteritems(): +            if value is None and not save_none: +                continue +            if base_field: +                full_name = base_field + '.' 
+ field_name +            else: +                full_name = field_name +            if value is None or isinstance(value, (int, float, basestring)): +                values.append((doc_id, full_name, value, len(values))) +            else: +                subvalues = self._expand_to_fields(doc_id, full_name, value, +                                                   save_none) +                for _, subfield_name, val, _ in subvalues: +                    values.append((doc_id, subfield_name, val, len(values))) +        return values + +    def _put_and_update_indexes(self, old_doc, doc): +        """Actually insert a document into the database. + +        This both updates the existing documents content, and any indexes that +        refer to this document. +        """ +        raise NotImplementedError(self._put_and_update_indexes) + +    def whats_changed(self, old_generation=0): +        c = self._db_handle.cursor() +        c.execute("SELECT generation, doc_id, transaction_id" +                  " FROM transaction_log" +                  " WHERE generation > ? ORDER BY generation DESC", +                  (old_generation,)) +        results = c.fetchall() +        cur_gen = old_generation +        seen = set() +        changes = [] +        newest_trans_id = '' +        for generation, doc_id, trans_id in results: +            if doc_id not in seen: +                changes.append((doc_id, generation, trans_id)) +                seen.add(doc_id) +        if changes: +            cur_gen = changes[0][1]  # max generation +            newest_trans_id = changes[0][2] +            changes.reverse() +        else: +            c.execute("SELECT generation, transaction_id" +                      " FROM transaction_log ORDER BY generation DESC LIMIT 1") +            results = c.fetchone() +            if not results: +                cur_gen = 0 +                newest_trans_id = '' +            else: +                cur_gen, newest_trans_id = results + +        return cur_gen, newest_trans_id, changes + +    def delete_doc(self, doc): +        with self._db_handle: +            old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) +            if old_doc is None: +                raise errors.DocumentDoesNotExist +            if old_doc.rev != doc.rev: +                raise errors.RevisionConflict() +            if old_doc.is_tombstone(): +                raise errors.DocumentAlreadyDeleted +            if old_doc.has_conflicts: +                raise errors.ConflictedDoc() +            new_rev = self._allocate_doc_rev(doc.rev) +            doc.rev = new_rev +            doc.make_tombstone() +            self._put_and_update_indexes(old_doc, doc) +        return new_rev + +    def _get_conflicts(self, doc_id): +        c = self._db_handle.cursor() +        c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?", +                  (doc_id,)) +        return [self._factory(doc_id, doc_rev, content) +                for doc_rev, content in c.fetchall()] + +    def get_doc_conflicts(self, doc_id): +        with self._db_handle: +            conflict_docs = self._get_conflicts(doc_id) +            if not conflict_docs: +                return [] +            this_doc = self._get_doc(doc_id) +            this_doc.has_conflicts = True +            return [this_doc] + conflict_docs + +    def _get_replica_gen_and_trans_id(self, other_replica_uid): +        c = self._db_handle.cursor() +        c.execute("SELECT known_generation, known_transaction_id FROM sync_log" +       
           " WHERE replica_uid = ?", +                  (other_replica_uid,)) +        val = c.fetchone() +        if val is None: +            other_gen = 0 +            trans_id = '' +        else: +            other_gen = val[0] +            trans_id = val[1] +        return other_gen, trans_id + +    def _set_replica_gen_and_trans_id(self, other_replica_uid, +                                      other_generation, other_transaction_id): +        with self._db_handle: +            self._do_set_replica_gen_and_trans_id( +                other_replica_uid, other_generation, other_transaction_id) + +    def _do_set_replica_gen_and_trans_id(self, other_replica_uid, +                                         other_generation, +                                         other_transaction_id): +            c = self._db_handle.cursor() +            c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)", +                      (other_replica_uid, other_generation, +                       other_transaction_id)) + +    def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, +                          replica_gen=None, replica_trans_id=None): +        return super(SQLiteDatabase, self)._put_doc_if_newer( +            doc, +            save_conflict=save_conflict, +            replica_uid=replica_uid, replica_gen=replica_gen, +            replica_trans_id=replica_trans_id) + +    def _add_conflict(self, c, doc_id, my_doc_rev, my_content): +        c.execute("INSERT INTO conflicts VALUES (?, ?, ?)", +                  (doc_id, my_doc_rev, my_content)) + +    def _delete_conflicts(self, c, doc, conflict_revs): +        deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs] +        c.executemany("DELETE FROM conflicts" +                      " WHERE doc_id=? AND doc_rev=?", deleting) +        doc.has_conflicts = self._has_conflicts(doc.doc_id) + +    def _prune_conflicts(self, doc, doc_vcr): +        if self._has_conflicts(doc.doc_id): +            autoresolved = False +            c_revs_to_prune = [] +            for c_doc in self._get_conflicts(doc.doc_id): +                c_vcr = vectorclock.VectorClockRev(c_doc.rev) +                if doc_vcr.is_newer(c_vcr): +                    c_revs_to_prune.append(c_doc.rev) +                elif doc.same_content_as(c_doc): +                    c_revs_to_prune.append(c_doc.rev) +                    doc_vcr.maximize(c_vcr) +                    autoresolved = True +            if autoresolved: +                doc_vcr.increment(self._replica_uid) +                doc.rev = doc_vcr.as_str() +            c = self._db_handle.cursor() +            self._delete_conflicts(c, doc, c_revs_to_prune) + +    def _force_doc_sync_conflict(self, doc): +        my_doc = self._get_doc(doc.doc_id) +        c = self._db_handle.cursor() +        self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) +        self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json()) +        doc.has_conflicts = True +        self._put_and_update_indexes(my_doc, doc) + +    def resolve_doc(self, doc, conflicted_doc_revs): +        with self._db_handle: +            cur_doc = self._get_doc(doc.doc_id) +            # TODO: https://bugs.launchpad.net/u1db/+bug/928274 +            #       I think we have a logic bug in resolve_doc +            #       Specifically, cur_doc.rev is always in the final vector +            #       clock of revisions that we supersede, even if it wasn't in +            #       conflicted_doc_revs. 
We still add it as a conflict, but the +            #       fact that _put_doc_if_newer propagates resolutions means I +            #       think that conflict could accidentally be resolved. We need +            #       to add a test for this case first. (create a rev, create a +            #       conflict, create another conflict, resolve the first rev +            #       and first conflict, then make sure that the resolved +            #       rev doesn't supersede the second conflict rev.) It *might* +            #       not matter, because the superseding rev is in as a +            #       conflict, but it does seem incorrect +            new_rev = self._ensure_maximal_rev(cur_doc.rev, +                                               conflicted_doc_revs) +            superseded_revs = set(conflicted_doc_revs) +            c = self._db_handle.cursor() +            doc.rev = new_rev +            if cur_doc.rev in superseded_revs: +                self._put_and_update_indexes(cur_doc, doc) +            else: +                self._add_conflict(c, doc.doc_id, new_rev, doc.get_json()) +            # TODO: Is there some way that we could construct a rev that would +            #       end up in superseded_revs, such that we add a conflict, and +            #       then immediately delete it? +            self._delete_conflicts(c, doc, superseded_revs) + +    def list_indexes(self): +        """Return the list of indexes and their definitions.""" +        c = self._db_handle.cursor() +        # TODO: How do we test the ordering? +        c.execute("SELECT name, field FROM index_definitions" +                  " ORDER BY name, offset") +        definitions = [] +        cur_name = None +        for name, field in c.fetchall(): +            if cur_name != name: +                definitions.append((name, [])) +                cur_name = name +            definitions[-1][-1].append(field) +        return definitions + +    def _get_index_definition(self, index_name): +        """Return the stored definition for a given index_name.""" +        c = self._db_handle.cursor() +        c.execute("SELECT field FROM index_definitions" +                  " WHERE name = ? ORDER BY offset", (index_name,)) +        fields = [x[0] for x in c.fetchall()] +        if not fields: +            raise errors.IndexDoesNotExist +        return fields + +    @staticmethod +    def _strip_glob(value): +        """Remove the trailing * from a value.""" +        assert value[-1] == '*' +        return value[:-1] + +    def _format_query(self, definition, key_values): +        # First, build the definition. We join the document_fields table +        # against itself, as many times as the 'width' of our definition. +        # We then do a query for each key_value, one-at-a-time. +        # Note: All of these strings are static, we could cache them, etc. +        tables = ["document_fields d%d" % i for i in range(len(definition))] +        novalue_where = ["d.doc_id = d%d.doc_id" +                         " AND d%d.field_name = ?" +                         % (i, i) for i in range(len(definition))] +        wildcard_where = [novalue_where[i] + +                          (" AND d%d.value NOT NULL" % (i,)) +                          for i in range(len(definition))] +        exact_where = [novalue_where[i] + +                       (" AND d%d.value = ?" % (i,)) +                       for i in range(len(definition))] +        like_where = [novalue_where[i] + +                      (" AND d%d.value GLOB ?" 
% (i,)) +                      for i in range(len(definition))] +        is_wildcard = False +        # Merge the lists together, so that: +        # [field1, field2, field3], [val1, val2, val3] +        # Becomes: +        # (field1, val1, field2, val2, field3, val3) +        args = [] +        where = [] +        for idx, (field, value) in enumerate(zip(definition, key_values)): +            args.append(field) +            if value.endswith('*'): +                if value == '*': +                    where.append(wildcard_where[idx]) +                else: +                    # This is a glob match +                    if is_wildcard: +                        # We can't have a partial wildcard following +                        # another wildcard +                        raise errors.InvalidGlobbing +                    where.append(like_where[idx]) +                    args.append(value) +                is_wildcard = True +            else: +                if is_wildcard: +                    raise errors.InvalidGlobbing +                where.append(exact_where[idx]) +                args.append(value) +        statement = ( +            "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " +            "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " +            "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " +            "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( +                ['d%d.value' % i for i in range(len(definition))]))) +        return statement, args + +    def get_from_index(self, index_name, *key_values): +        definition = self._get_index_definition(index_name) +        if len(key_values) != len(definition): +            raise errors.InvalidValueForIndex() +        statement, args = self._format_query(definition, key_values) +        c = self._db_handle.cursor() +        try: +            c.execute(statement, tuple(args)) +        except dbapi2.OperationalError as e: +            raise dbapi2.OperationalError( +                str(e) + +                '\nstatement: %s\nargs: %s\n' % (statement, args)) +        res = c.fetchall() +        results = [] +        for row in res: +            doc = self._factory(row[0], row[1], row[2]) +            doc.has_conflicts = row[3] > 0 +            results.append(doc) +        return results + +    def _format_range_query(self, definition, start_value, end_value): +        tables = ["document_fields d%d" % i for i in range(len(definition))] +        novalue_where = [ +            "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in +            range(len(definition))] +        wildcard_where = [ +            novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in +            range(len(definition))] +        like_where = [ +            novalue_where[i] + ( +                " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in +            range(len(definition))] +        range_where_lower = [ +            novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in +            range(len(definition))] +        range_where_upper = [ +            novalue_where[i] + (" AND d%d.value <= ?" 
% (i,)) for i in +            range(len(definition))] +        args = [] +        where = [] +        if start_value: +            if isinstance(start_value, basestring): +                start_value = (start_value,) +            if len(start_value) != len(definition): +                raise errors.InvalidValueForIndex() +            is_wildcard = False +            for idx, (field, value) in enumerate(zip(definition, start_value)): +                args.append(field) +                if value.endswith('*'): +                    if value == '*': +                        where.append(wildcard_where[idx]) +                    else: +                        # This is a glob match +                        if is_wildcard: +                            # We can't have a partial wildcard following +                            # another wildcard +                            raise errors.InvalidGlobbing +                        where.append(range_where_lower[idx]) +                        args.append(self._strip_glob(value)) +                    is_wildcard = True +                else: +                    if is_wildcard: +                        raise errors.InvalidGlobbing +                    where.append(range_where_lower[idx]) +                    args.append(value) +        if end_value: +            if isinstance(end_value, basestring): +                end_value = (end_value,) +            if len(end_value) != len(definition): +                raise errors.InvalidValueForIndex() +            is_wildcard = False +            for idx, (field, value) in enumerate(zip(definition, end_value)): +                args.append(field) +                if value.endswith('*'): +                    if value == '*': +                        where.append(wildcard_where[idx]) +                    else: +                        # This is a glob match +                        if is_wildcard: +                            # We can't have a partial wildcard following +                            # another wildcard +                            raise errors.InvalidGlobbing +                        where.append(like_where[idx]) +                        args.append(self._strip_glob(value)) +                        args.append(value) +                    is_wildcard = True +                else: +                    if is_wildcard: +                        raise errors.InvalidGlobbing +                    where.append(range_where_upper[idx]) +                    args.append(value) +        statement = ( +            "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " +            "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " +            "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " +            "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( +                ['d%d.value' % i for i in range(len(definition))]))) +        return statement, args + +    def get_range_from_index(self, index_name, start_value=None, +                             end_value=None): +        """Return all documents with key values in the specified range.""" +        definition = self._get_index_definition(index_name) +        statement, args = self._format_range_query( +            definition, start_value, end_value) +        c = self._db_handle.cursor() +        try: +            c.execute(statement, tuple(args)) +        except dbapi2.OperationalError as e: +            raise dbapi2.OperationalError( +                str(e) + +                '\nstatement: %s\nargs: %s\n' % 
(statement, args)) +        res = c.fetchall() +        results = [] +        for row in res: +            doc = self._factory(row[0], row[1], row[2]) +            doc.has_conflicts = row[3] > 0 +            results.append(doc) +        return results + +    def get_index_keys(self, index_name): +        c = self._db_handle.cursor() +        definition = self._get_index_definition(index_name) +        value_fields = ', '.join([ +            'd%d.value' % i for i in range(len(definition))]) +        tables = ["document_fields d%d" % i for i in range(len(definition))] +        novalue_where = [ +            "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in +            range(len(definition))] +        where = [ +            novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in +            range(len(definition))] +        statement = ( +            "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( +                value_fields, ', '.join(tables), ' AND '.join(where), +                value_fields)) +        try: +            c.execute(statement, tuple(definition)) +        except dbapi2.OperationalError as e: +            raise dbapi2.OperationalError( +                str(e) + +                '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) +        return c.fetchall() + +    def delete_index(self, index_name): +        with self._db_handle: +            c = self._db_handle.cursor() +            c.execute("DELETE FROM index_definitions WHERE name = ?", +                      (index_name,)) +            c.execute( +                "DELETE FROM document_fields WHERE document_fields.field_name " +                " NOT IN (SELECT field from index_definitions)") + + +class SQLiteSyncTarget(CommonSyncTarget): + +    def get_sync_info(self, source_replica_uid): +        source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( +            source_replica_uid) +        my_gen, my_trans_id = self._db._get_generation_info() +        return ( +            self._db._replica_uid, my_gen, my_trans_id, source_gen, +            source_trans_id) + +    def record_sync_info(self, source_replica_uid, source_replica_generation, +                         source_replica_transaction_id): +        if self._trace_hook: +            self._trace_hook('record_sync_info') +        self._db._set_replica_gen_and_trans_id( +            source_replica_uid, source_replica_generation, +            source_replica_transaction_id) + + +class SQLitePartialExpandDatabase(SQLiteDatabase): +    """An SQLite Backend that expands documents into a document_field table. + +    It stores the original document text in document.doc. For fields that are +    indexed, the data goes into document_fields. +    """ + +    _index_storage_value = 'expand referenced' + +    def _get_indexed_fields(self): +        """Determine what fields are indexed.""" +        c = self._db_handle.cursor() +        c.execute("SELECT field FROM index_definitions") +        return set([x[0] for x in c.fetchall()]) + +    def _evaluate_index(self, raw_doc, field): +        parser = query_parser.Parser() +        getter = parser.parse(field) +        return getter.get(raw_doc) + +    def _put_and_update_indexes(self, old_doc, doc): +        c = self._db_handle.cursor() +        if doc and not doc.is_tombstone(): +            raw_doc = json.loads(doc.get_json()) +        else: +            raw_doc = {} +        if old_doc is not None: +            c.execute("UPDATE document SET doc_rev=?, content=?" 
+                      " WHERE doc_id = ?", +                      (doc.rev, doc.get_json(), doc.doc_id)) +            c.execute("DELETE FROM document_fields WHERE doc_id = ?", +                      (doc.doc_id,)) +        else: +            c.execute("INSERT INTO document (doc_id, doc_rev, content)" +                      " VALUES (?, ?, ?)", +                      (doc.doc_id, doc.rev, doc.get_json())) +        indexed_fields = self._get_indexed_fields() +        if indexed_fields: +            # It is expected that len(indexed_fields) is shorter than +            # len(raw_doc) +            getters = [(field, self._parse_index_definition(field)) +                       for field in indexed_fields] +            self._update_indexes(doc.doc_id, raw_doc, getters, c) +        trans_id = self._allocate_transaction_id() +        c.execute("INSERT INTO transaction_log(doc_id, transaction_id)" +                  " VALUES (?, ?)", (doc.doc_id, trans_id)) + +    def create_index(self, index_name, *index_expressions): +        with self._db_handle: +            c = self._db_handle.cursor() +            cur_fields = self._get_indexed_fields() +            definition = [(index_name, idx, field) +                          for idx, field in enumerate(index_expressions)] +            try: +                c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", +                              definition) +            except dbapi2.IntegrityError as e: +                stored_def = self._get_index_definition(index_name) +                if stored_def == [x[-1] for x in definition]: +                    return +                raise errors.IndexNameTakenError( +                    str(e) + +                    str(sys.exc_info()[2]) +                ) +            new_fields = set( +                [f for f in index_expressions if f not in cur_fields]) +            if new_fields: +                self._update_all_indexes(new_fields) + +    def _iter_all_docs(self): +        c = self._db_handle.cursor() +        c.execute("SELECT doc_id, content FROM document") +        while True: +            next_rows = c.fetchmany() +            if not next_rows: +                break +            for row in next_rows: +                yield row + +    def _update_all_indexes(self, new_fields): +        """Iterate all the documents, and add content to document_fields. + +        :param new_fields: The index definitions that need to be added. +        """ +        getters = [(field, self._parse_index_definition(field)) +                   for field in new_fields] +        c = self._db_handle.cursor() +        for doc_id, doc in self._iter_all_docs(): +            if doc is None: +                continue +            raw_doc = json.loads(doc) +            self._update_indexes(doc_id, raw_doc, getters, c) + + +SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase) diff --git a/src/leap/soledad/client/_document.py b/src/leap/soledad/client/_document.py new file mode 100644 index 00000000..9c8577cb --- /dev/null +++ b/src/leap/soledad/client/_document.py @@ -0,0 +1,254 @@ +# -*- coding: utf-8 -*- +# _document.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Public interfaces for adding extra client features to the generic +SoledadDocument. +""" + +import weakref +import uuid + +from twisted.internet import defer + +from zope.interface import Interface +from zope.interface import implementer + +from leap.soledad.common.document import SoledadDocument + + +class IDocumentWithAttachment(Interface): +    """ +    A document that can have an attachment. +    """ + +    def set_store(self, store): +        """ +        Set the store used by this file to manage attachments. + +        :param store: The store used to manage attachments. +        :type store: Soledad +        """ + +    def put_attachment(self, fd): +        """ +        Attach data to this document. + +        Add the attachment to local storage, enqueue for upload. + +        The document content will be updated with a pointer to the attachment, +        but the document has to be manually put in the database to reflect +        modifications. + +        :param fd: A file-like object whose content will be attached to this +                   document. +        :type fd: file-like + +        :return: A deferred which fires when the attachment has been added to +                 local storage. +        :rtype: Deferred +        """ + +    def get_attachment(self): +        """ +        Return the data attached to this document. + +        If document content contains a pointer to the attachment, try to get +        the attachment from local storage and, if not found, from remote +        storage. + +        :return: A deferred which fires with a file like-object whose content +                 is the attachment of this document, or None if nothing is +                 attached. +        :rtype: Deferred +        """ + +    def delete_attachment(self): +        """ +        Delete the attachment of this document. + +        The pointer to the attachment will be removed from the document +        content, but the document has to be manually put in the database to +        reflect modifications. + +        :return: A deferred which fires when the attachment has been deleted +                 from local storage. +        :rtype: Deferred +        """ + +    def get_attachment_state(self): +        """ +        Return the state of the attachment of this document. + +        The state is a member of AttachmentStates and is of one of NONE, +        LOCAL, REMOTE or SYNCED. + +        :return: A deferred which fires with The state of the attachment of +                 this document. +        :rtype: Deferred +        """ + +    def is_dirty(self): +        """ +        Return whether this document's content differs from the contents stored +        in local database. + +        :return: A deferred which fires with True or False, depending on +                 whether this document is dirty or not. +        :rtype: Deferred +        """ + +    def upload_attachment(self): +        """ +        Upload this document's attachment. 
+ +        :return: A deferred which fires with the state of the attachment after +                 it's been uploaded, or NONE if there's no attachment for this +                 document. +        :rtype: Deferred +        """ + +    def download_attachment(self): +        """ +        Download this document's attachment. + +        :return: A deferred which fires with the state of the attachment after +                 it's been downloaded, or NONE if there's no attachment for +                 this document. +        :rtype: Deferred +        """ + + +class BlobDoc(object): + +    # TODO probably not needed, but convenient for testing for now. + +    def __init__(self, content, blob_id): + +        self.blob_id = blob_id +        self.is_blob = True +        self.blob_fd = content +        if blob_id is None: +            blob_id = uuid.uuid4().get_hex() +        self.blob_id = blob_id + + +class AttachmentStates(object): +    NONE = 0 +    LOCAL = 1 +    REMOTE = 2 +    SYNCED = 4 + + +@implementer(IDocumentWithAttachment) +class Document(SoledadDocument): + +    def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, +                 syncable=True, store=None): +        SoledadDocument.__init__(self, doc_id=doc_id, rev=rev, json=json, +                                 has_conflicts=has_conflicts, +                                 syncable=syncable) +        self.set_store(store) + +    # +    # properties +    # + +    @property +    def _manager(self): +        if not self.store or not hasattr(self.store, 'blobmanager'): +            raise Exception('No blob manager found to manage attachments.') +        return self.store.blobmanager + +    @property +    def _blob_id(self): +        if self.content and 'blob_id' in self.content: +            return self.content['blob_id'] +        return None + +    def get_store(self): +        return self._store() if self._store else None + +    def set_store(self, store): +        self._store = weakref.ref(store) if store else None + +    store = property(get_store, set_store) + +    # +    # attachment api +    # + +    def put_attachment(self, fd): +        # add pointer to content +        blob_id = self._blob_id or str(uuid.uuid4()) +        if not self.content: +            self.content = {} +        self.content['blob_id'] = blob_id +        # put using manager +        blob = BlobDoc(fd, blob_id) +        fd.seek(0, 2) +        size = fd.tell() +        fd.seek(0) +        return self._manager.put(blob, size) + +    def get_attachment(self): +        if not self._blob_id: +            return defer.succeed(None) +        return self._manager.get(self._blob_id) + +    def delete_attachment(self): +        raise NotImplementedError + +    @defer.inlineCallbacks +    def get_attachment_state(self): +        state = AttachmentStates.NONE + +        if not self._blob_id: +            defer.returnValue(state) + +        local_list = yield self._manager.local_list() +        if self._blob_id in local_list: +            state |= AttachmentStates.LOCAL + +        remote_list = yield self._manager.remote_list() +        if self._blob_id in remote_list: +            state |= AttachmentStates.REMOTE + +        defer.returnValue(state) + +    @defer.inlineCallbacks +    def is_dirty(self): +        stored = yield self.store.get_doc(self.doc_id) +        if stored.content != self.content: +            defer.returnValue(True) +        defer.returnValue(False) + +    @defer.inlineCallbacks +    def upload_attachment(self): +        if 
not self._blob_id: +            defer.returnValue(AttachmentStates.NONE) + +        fd = yield self._manager.get_blob(self._blob_id) +        # TODO: turn following method into a public one +        yield self._manager._encrypt_and_upload(self._blob_id, fd) +        defer.returnValue(self.get_attachment_state()) + +    @defer.inlineCallbacks +    def download_attachment(self): +        if not self._blob_id: +            defer.returnValue(None) +        yield self.get_attachment() +        defer.returnValue(self.get_attachment_state()) diff --git a/src/leap/soledad/client/_http.py b/src/leap/soledad/client/_http.py new file mode 100644 index 00000000..2a6b9e39 --- /dev/null +++ b/src/leap/soledad/client/_http.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# _http.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +A twisted-based, TLS-pinned, token-authenticated HTTP client. +""" +import base64 + +from twisted.internet import reactor +from twisted.web.iweb import IAgent +from twisted.web.client import Agent +from twisted.web.http_headers import Headers + +from treq.client import HTTPClient as _HTTPClient + +from zope.interface import implementer + +from leap.common.certs import get_compatible_ssl_context_factory + + +__all__ = ['HTTPClient', 'PinnedTokenAgent'] + + +class HTTPClient(_HTTPClient): + +    def __init__(self, uuid, token, cert_file): +        self._agent = PinnedTokenAgent(uuid, token, cert_file) +        super(self.__class__, self).__init__(self._agent) + +    def set_token(self, token): +        self._agent.set_token(token) + + +@implementer(IAgent) +class PinnedTokenAgent(Agent): + +    def __init__(self, uuid, token, cert_file): +        self._uuid = uuid +        self._token = None +        self._creds = None +        self.set_token(token) +        # pin this agent with the platform TLS certificate +        factory = get_compatible_ssl_context_factory(cert_file) +        Agent.__init__(self, reactor, contextFactory=factory) + +    def set_token(self, token): +        self._token = token +        self._creds = self._encoded_creds() + +    def _encoded_creds(self): +        creds = '%s:%s' % (self._uuid, self._token) +        encoded = base64.b64encode(creds) +        return 'Token %s' % encoded + +    def request(self, method, uri, headers=None, bodyProducer=None): +        # authenticate the request +        headers = headers or Headers() +        headers.addRawHeader('Authorization', self._creds) +        # perform the authenticated request +        return Agent.request( +            self, method, uri, headers=headers, bodyProducer=bodyProducer) diff --git a/src/leap/soledad/client/_pipes.py b/src/leap/soledad/client/_pipes.py new file mode 100644 index 00000000..eef3f1f9 --- /dev/null +++ b/src/leap/soledad/client/_pipes.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# _pipes.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can 
redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Components for piping data on streams. +""" +from io import BytesIO + + +__all__ = ['TruncatedTailPipe', 'PreamblePipe'] + + +class TruncatedTailPipe(object): +    """ +    Truncate the last `tail_size` bytes from the stream. +    """ + +    def __init__(self, output=None, tail_size=16): +        self.tail_size = tail_size +        self.output = output or BytesIO() +        self.buffer = BytesIO() + +    def write(self, data): +        self.buffer.write(data) +        if self.buffer.tell() > self.tail_size: +            self._truncate_tail() + +    def _truncate_tail(self): +        overflow_size = self.buffer.tell() - self.tail_size +        self.buffer.seek(0) +        self.output.write(self.buffer.read(overflow_size)) +        remaining = self.buffer.read() +        self.buffer.seek(0) +        self.buffer.write(remaining) +        self.buffer.truncate() + +    def close(self): +        return self.output + + +class PreamblePipe(object): +    """ +    Consumes data until a space is found, then calls a callback with it and +    starts forwarding data to consumer returned by this callback. +    """ + +    def __init__(self, callback): +        self.callback = callback +        self.preamble = BytesIO() +        self.output = None + +    def write(self, data): +        if not self.output: +            self._write_preamble(data) +        else: +            self.output.write(data) + +    def _write_preamble(self, data): +        if ' ' not in data: +            self.preamble.write(data) +            return +        preamble_chunk, remaining = data.split(' ', 1) +        self.preamble.write(preamble_chunk) +        self.output = self.callback(self.preamble) +        self.output.write(remaining) diff --git a/src/leap/soledad/client/_recovery_code.py b/src/leap/soledad/client/_recovery_code.py new file mode 100644 index 00000000..04235a29 --- /dev/null +++ b/src/leap/soledad/client/_recovery_code.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# _recovery_code.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. 
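The RecoveryCode helper defined in this file hex-encodes six random bytes, so the generated code comes out as twelve hexadecimal characters (hexlify doubles the size, as the class comment notes). A minimal standalone sketch of that behaviour, using only the standard library; the printed value is of course just an example:

    import binascii
    import os

    # hexlify doubles the size: 6 random bytes become a 12-character code
    code = binascii.hexlify(os.urandom(6))
    assert len(code) == 12  # e.g. '9f2c41aa03be'
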
+ +import os +import binascii + +from leap.soledad.common.log import getLogger + +logger = getLogger(__name__) + + +class RecoveryCode(object): + +    # When we turn this string to hex, it will double in size +    code_length = 6 + +    def generate(self): +        logger.info("generating new recovery code...") +        return binascii.hexlify(os.urandom(self.code_length)) diff --git a/src/leap/soledad/client/_secrets/__init__.py b/src/leap/soledad/client/_secrets/__init__.py new file mode 100644 index 00000000..b6c81cda --- /dev/null +++ b/src/leap/soledad/client/_secrets/__init__.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# _secrets/__init__.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import os +import scrypt + +from leap.soledad.common.log import getLogger + +from leap.soledad.client._secrets.storage import SecretsStorage +from leap.soledad.client._secrets.crypto import SecretsCrypto +from leap.soledad.client._secrets.util import emit, UserDataMixin + + +logger = getLogger(__name__) + + +class Secrets(UserDataMixin): + +    lengths = { +        'remote_secret': 512,  # remote_secret is used to encrypt remote data. +        'local_salt': 64,      # local_salt is used in conjunction with +        'local_secret': 448,   # local_secret to derive a local_key for storage +    } + +    def __init__(self, soledad): +        self._soledad = soledad +        self._secrets = {} +        self.crypto = SecretsCrypto(soledad) +        self.storage = SecretsStorage(soledad) +        self._bootstrap() + +    # +    # bootstrap +    # + +    def _bootstrap(self): + +        # attempt to load secrets from local storage +        encrypted = self.storage.load_local() +        if encrypted: +            self._secrets = self.crypto.decrypt(encrypted) +            # maybe update the format of storage of local secret. +            if encrypted['version'] < self.crypto.VERSION: +                self.store_secrets() +            return + +        # no secret was found in local storage, so this is a first run of +        # soledad for this user in this device. It is mandatory that we check +        # if there's a secret stored in server. +        encrypted = self.storage.load_remote() +        if encrypted: +            self._secrets = self.crypto.decrypt(encrypted) +            self.store_secrets() +            return + +        # we have *not* found a secret neither in local nor in remote storage, +        # so we have to generate a new one, and then store it. 
+        self._secrets = self._generate() +        self.store_secrets() + +    # +    # generation +    # + +    @emit('creating') +    def _generate(self): +        logger.info("generating new set of secrets...") +        secrets = {} +        for name, length in self.lengths.iteritems(): +            secret = os.urandom(length) +            secrets[name] = secret +        logger.info("new set of secrets successfully generated") +        return secrets + +    # +    # crypto +    # + +    def store_secrets(self): +        # TODO: we have to improve the logic here, as we want to make sure that +        # whatever is stored locally should only be used after remote storage +        # is successful. Otherwise, this soledad could start encrypting with a +        # secret while another soledad in another device could start encrypting +        # with another secret, which would lead to decryption failures during +        # sync. +        encrypted = self.crypto.encrypt(self._secrets) +        self.storage.save_local(encrypted) +        self.storage.save_remote(encrypted) + +    # +    # secrets +    # + +    @property +    def remote_secret(self): +        return self._secrets.get('remote_secret') + +    @property +    def local_salt(self): +        return self._secrets.get('local_salt') + +    @property +    def local_secret(self): +        return self._secrets.get('local_secret') + +    @property +    def local_key(self): +        # local storage key is scrypt-derived from `local_secret` and +        # `local_salt` above +        secret = scrypt.hash( +            password=self.local_secret, +            salt=self.local_salt, +            buflen=32,  # we need a key with 256 bits (32 bytes) +        ) +        return secret diff --git a/src/leap/soledad/client/_secrets/crypto.py b/src/leap/soledad/client/_secrets/crypto.py new file mode 100644 index 00000000..8148151d --- /dev/null +++ b/src/leap/soledad/client/_secrets/crypto.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# _secrets/crypto.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. 
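The Secrets class above derives its 256-bit local storage key with scrypt from local_secret and local_salt, and the SecretsCrypto class defined in this file applies the same scrypt KDF to the user passphrase. A standalone sketch of the derivation, with illustrative values only, assuming the same scrypt module the client already depends on and the secret lengths declared in Secrets.lengths:

    import os
    import scrypt

    local_secret = os.urandom(448)  # lengths taken from Secrets.lengths above
    local_salt = os.urandom(64)

    # scrypt-derive a 32-byte (256-bit) key for the local encrypted storage
    local_key = scrypt.hash(password=local_secret, salt=local_salt, buflen=32)
    assert len(local_key) == 32
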
+ +import binascii +import json +import os +import scrypt + +from leap.soledad.common import soledad_assert +from leap.soledad.common.log import getLogger + +from leap.soledad.client._crypto import encrypt_sym, decrypt_sym, ENC_METHOD +from leap.soledad.client._secrets.util import SecretsError + + +logger = getLogger(__name__) + + +class SecretsCrypto(object): + +    VERSION = 2 + +    def __init__(self, soledad): +        self._soledad = soledad + +    def _get_key(self, salt): +        passphrase = self._soledad.passphrase.encode('utf8') +        key = scrypt.hash(passphrase, salt, buflen=32) +        return key + +    # +    # encryption +    # + +    def encrypt(self, secrets): +        encoded = {} +        for name, value in secrets.iteritems(): +            encoded[name] = binascii.b2a_base64(value) +        plaintext = json.dumps(encoded) +        salt = os.urandom(64)  # TODO: get salt length from somewhere else +        key = self._get_key(salt) +        iv, ciphertext = encrypt_sym(plaintext, key, +                                     method=ENC_METHOD.aes_256_gcm) +        encrypted = { +            'version': self.VERSION, +            'kdf': 'scrypt', +            'kdf_salt': binascii.b2a_base64(salt), +            'kdf_length': len(key), +            'cipher': ENC_METHOD.aes_256_gcm, +            'length': len(plaintext), +            'iv': str(iv), +            'secrets': binascii.b2a_base64(ciphertext), +        } +        return encrypted + +    # +    # decryption +    # + +    def decrypt(self, data): +        version = data.setdefault('version', 1) +        method = getattr(self, '_decrypt_v%d' % version) +        try: +            return method(data) +        except Exception as e: +            logger.error('error decrypting secrets: %r' % e) +            raise SecretsError(e) + +    def _decrypt_v1(self, data): +        # get encrypted secret from dictionary: the old format allowed for +        # storage of more than one secret, but this feature was never used and +        # soledad has been using only one secret so far. As there is a corner +        # case where the old 'active_secret' key might not be set, we just +        # ignore it and pop the only secret found in the 'storage_secrets' key. 
+        secret_id = data['storage_secrets'].keys().pop() +        encrypted = data['storage_secrets'][secret_id] + +        # assert that we know how to decrypt the secret +        soledad_assert('cipher' in encrypted) +        cipher = encrypted['cipher'] +        if cipher == 'aes256': +            cipher = ENC_METHOD.aes_256_ctr +        soledad_assert(cipher in ENC_METHOD) + +        # decrypt +        salt = binascii.a2b_base64(encrypted['kdf_salt']) +        key = self._get_key(salt) +        separator = ':' +        iv, ciphertext = encrypted['secret'].split(separator, 1) +        ciphertext = binascii.a2b_base64(ciphertext) +        plaintext = self._decrypt(key, iv, ciphertext, encrypted, cipher) + +        # create secrets dictionary +        secrets = { +            'remote_secret': plaintext[0:512], +            'local_salt': plaintext[512:576], +            'local_secret': plaintext[576:1024], +        } +        return secrets + +    def _decrypt_v2(self, encrypted): +        cipher = encrypted['cipher'] +        soledad_assert(cipher in ENC_METHOD) + +        salt = binascii.a2b_base64(encrypted['kdf_salt']) +        key = self._get_key(salt) +        iv = encrypted['iv'] +        ciphertext = binascii.a2b_base64(encrypted['secrets']) +        plaintext = self._decrypt( +            key, iv, ciphertext, encrypted, cipher) +        encoded = json.loads(plaintext) +        secrets = {} +        for name, value in encoded.iteritems(): +            secrets[name] = binascii.a2b_base64(value) +        return secrets + +    def _decrypt(self, key, iv, ciphertext, encrypted, method): +        # assert some properties of the stored secret +        soledad_assert(encrypted['kdf'] == 'scrypt') +        soledad_assert(encrypted['kdf_length'] == len(key)) +        # decrypt +        plaintext = decrypt_sym(ciphertext, key, iv, method) +        soledad_assert(encrypted['length'] == len(plaintext)) +        return plaintext diff --git a/src/leap/soledad/client/_secrets/storage.py b/src/leap/soledad/client/_secrets/storage.py new file mode 100644 index 00000000..85713a48 --- /dev/null +++ b/src/leap/soledad/client/_secrets/storage.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +# _secrets/storage.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. 
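The SecretsStorage class defined in this file addresses the encrypted secrets document in the shared database by a deterministic id: the sha256 hex digest of the user passphrase concatenated with the uuid (see _remote_doc_id below). A short sketch of that derivation with made-up credentials:

    from hashlib import sha256

    passphrase = u'an example passphrase'.encode('utf8')
    uuid = '6c852a7a10d9479ea3e2990f89b8a456'  # hypothetical user uuid
    doc_id = sha256('%s%s' % (passphrase, uuid)).hexdigest()
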
+ +import json +import six.moves.urllib.parse as urlparse + +from hashlib import sha256 + +from leap.soledad.common import SHARED_DB_NAME +from leap.soledad.common.log import getLogger + +from leap.soledad.client.shared_db import SoledadSharedDatabase +from leap.soledad.client._document import Document +from leap.soledad.client._secrets.util import emit, UserDataMixin + + +logger = getLogger(__name__) + + +class SecretsStorage(UserDataMixin): + +    def __init__(self, soledad): +        self._soledad = soledad +        self._shared_db = self._soledad.shared_db or self._init_shared_db() +        self.__remote_doc = None + +    @property +    def _creds(self): +        uuid = self._soledad.uuid +        token = self._soledad.token +        return {'token': {'uuid': uuid, 'token': token}} + +    # +    # local storage +    # + +    def load_local(self): +        path = self._soledad.secrets_path +        logger.info("trying to load secrets from disk: %s" % path) +        try: +            with open(path, 'r') as f: +                encrypted = json.loads(f.read()) +            logger.info("secrets loaded successfully from disk") +            return encrypted +        except IOError: +            logger.warn("secrets not found in disk") +        return None + +    def save_local(self, encrypted): +        path = self._soledad.secrets_path +        json_data = json.dumps(encrypted) +        with open(path, 'w') as f: +            f.write(json_data) + +    # +    # remote storage +    # + +    def _init_shared_db(self): +        url = urlparse.urljoin(self._soledad.server_url, SHARED_DB_NAME) +        creds = self._creds +        db = SoledadSharedDatabase.open_database(url, creds) +        return db + +    def _remote_doc_id(self): +        passphrase = self._soledad.passphrase.encode('utf8') +        uuid = self._soledad.uuid +        text = '%s%s' % (passphrase, uuid) +        digest = sha256(text).hexdigest() +        return digest + +    @property +    def _remote_doc(self): +        if not self.__remote_doc and self._shared_db: +            doc = self._get_remote_doc() +            self.__remote_doc = doc +        return self.__remote_doc + +    @emit('downloading') +    def _get_remote_doc(self): +        logger.info('trying to load secrets from server...') +        doc = self._shared_db.get_doc(self._remote_doc_id()) +        if doc: +            logger.info('secrets loaded successfully from server') +        else: +            logger.warn('secrets not found in server') +        return doc + +    def load_remote(self): +        doc = self._remote_doc +        if not doc: +            return None +        encrypted = doc.content +        return encrypted + +    @emit('uploading') +    def save_remote(self, encrypted): +        doc = self._remote_doc +        if not doc: +            doc = Document(doc_id=self._remote_doc_id()) +        doc.content = encrypted +        db = self._shared_db +        if not db: +            logger.warn('no shared db found') +            return +        db.put_doc(doc) diff --git a/src/leap/soledad/client/_secrets/util.py b/src/leap/soledad/client/_secrets/util.py new file mode 100644 index 00000000..6401889b --- /dev/null +++ b/src/leap/soledad/client/_secrets/util.py @@ -0,0 +1,63 @@ +# -*- coding:utf-8 -*- +# _secrets/util.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the 
License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +from leap.soledad.client import events + + +class SecretsError(Exception): +    pass + + +class UserDataMixin(object): +    """ +    When emitting an event, we have to pass a dictionary containing user data. +    This class only defines a property so we don't have to define it in +    multiple places. +    """ + +    @property +    def _user_data(self): +        uuid = self._soledad.uuid +        userid = self._soledad.userid +        # TODO: seems that uuid and userid hold the same value! We should check +        # whether we should pass something different or if the events api +        # really needs two different values. +        return {'uuid': uuid, 'userid': userid} + + +def emit(verb): +    def _decorator(method): +        def _decorated(self, *args, **kwargs): + +            # emit starting event +            user_data = self._user_data +            name = 'SOLEDAD_' + verb.upper() + '_KEYS' +            event = getattr(events, name) +            events.emit_async(event, user_data) + +            # run the method +            result = method(self, *args, **kwargs) + +            # emit a finished event +            name = 'SOLEDAD_DONE_' + verb.upper() + '_KEYS' +            event = getattr(events, name) +            events.emit_async(event, user_data) + +            return result +        return _decorated +    return _decorator diff --git a/src/leap/soledad/client/api.py b/src/leap/soledad/client/api.py new file mode 100644 index 00000000..c62b43f0 --- /dev/null +++ b/src/leap/soledad/client/api.py @@ -0,0 +1,848 @@ +# -*- coding: utf-8 -*- +# api.py +# Copyright (C) 2013, 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Soledad - Synchronization Of Locally Encrypted Data Among Devices. + +This module holds the public api for Soledad. + +Soledad is the part of LEAP that manages storage and synchronization of +application data. It is built on top of U1DB reference Python API and +implements (1) a SQLCipher backend for local storage in the client, (2) a +SyncTarget that encrypts data before syncing, and (3) a CouchDB backend for +remote storage in the server side. 
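
A minimal usage sketch of this public API (not part of this changeset): the constructor arguments mirror Soledad.__init__ as defined below, and all paths, URLs, and tokens are placeholder values. Instantiation bootstraps the secrets, create_doc and sync return Twisted deferreds, and error handling is omitted for brevity.

    # Hedged example of driving the Soledad API from a Twisted reactor.
    from twisted.internet import defer, reactor

    from leap.soledad.client.api import Soledad


    @defer.inlineCallbacks
    def main():
        sol = Soledad(
            uuid='some-user-uuid',
            passphrase=u'a strong passphrase',
            secrets_path='/tmp/example/soledad.json',
            local_db_path='/tmp/example/soledad.u1db',
            server_url='https://soledad.example.org:2323/',
            cert_file='/tmp/example/ca.crt',
            auth_token='an-auth-token')
        doc = yield sol.create_doc({'hello': 'world'})
        print("created doc %s" % doc.doc_id)
        local_gen = yield sol.sync()
        print("synced at local generation %s" % local_gen)
        sol.close()
        reactor.stop()


    reactor.callWhenRunning(main)
    reactor.run()
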
+""" +import binascii +import errno +import os +import socket +import ssl +import uuid + +from itertools import chain +import six.moves.http_client as httplib +import six.moves.urllib.parse as urlparse +from six import StringIO +from collections import defaultdict + +from twisted.internet import defer +from zope.interface import implementer + +from leap.common.config import get_path_prefix +from leap.common.plugins import collect_plugins + +from leap.soledad.common import soledad_assert +from leap.soledad.common import soledad_assert_type +from leap.soledad.common.log import getLogger +from leap.soledad.common.l2db.remote import http_client +from leap.soledad.common.l2db.remote.ssl_match_hostname import match_hostname +from leap.soledad.common.errors import DatabaseAccessError + +from . import events as soledad_events +from . import interfaces as soledad_interfaces +from ._crypto import SoledadCrypto +from ._db import adbapi +from ._db import blobs +from ._db import sqlcipher +from ._recovery_code import RecoveryCode +from ._secrets import Secrets + + +logger = getLogger(__name__) + + +# we may want to collect statistics from the sync process +DO_STATS = False +if os.environ.get('SOLEDAD_STATS'): +    DO_STATS = True + + +# +# Constants +# + +""" +Path to the certificate file used to certify the SSL connection between +Soledad client and server. +""" +SOLEDAD_CERT = None + + +@implementer(soledad_interfaces.ILocalStorage, +             soledad_interfaces.ISyncableStorage, +             soledad_interfaces.ISecretsStorage) +class Soledad(object): +    """ +    Soledad provides encrypted data storage and sync. + +    A Soledad instance is used to store and retrieve data in a local encrypted +    database and synchronize this database with Soledad server. + +    This class is also responsible for bootstrapping users' account by +    creating cryptographic secrets and/or storing/fetching them on Soledad +    server. +    """ + +    local_db_file_name = 'soledad.u1db' +    secrets_file_name = "soledad.json" +    default_prefix = os.path.join(get_path_prefix(), 'leap', 'soledad') + +    """ +    A dictionary that holds locks which avoid multiple sync attempts from the +    same database replica. The dictionary indexes are the paths to each local +    db, so we guarantee that only one sync happens for a local db at a time. +    """ +    _sync_lock = defaultdict(defer.DeferredLock) + +    def __init__(self, uuid, passphrase, secrets_path, local_db_path, +                 server_url, cert_file, shared_db=None, +                 auth_token=None): +        """ +        Initialize configuration, cryptographic keys and dbs. + +        :param uuid: User's uuid. +        :type uuid: str + +        :param passphrase: +            The passphrase for locking and unlocking encryption secrets for +            local and remote storage. +        :type passphrase: unicode + +        :param secrets_path: +            Path for storing encrypted key used for symmetric encryption. +        :type secrets_path: str + +        :param local_db_path: Path for local encrypted storage db. +        :type local_db_path: str + +        :param server_url: +            URL for Soledad server. This is used either to sync with the user's +            remote db and to interact with the shared recovery database. +        :type server_url: str + +        :param cert_file: +            Path to the certificate of the ca used to validate the SSL +            certificate used by the remote soledad server. 
+        :type cert_file: str + +        :param shared_db: +            The shared database. +        :type shared_db: HTTPDatabase + +        :param auth_token: +            Authorization token for accessing remote databases. +        :type auth_token: str + +        :raise BootstrapSequenceError: +            Raised when the secret initialization sequence (i.e. retrieval +            from server or generation and storage on server) has failed for +            some reason. +        """ +        # store config params +        self.uuid = uuid +        self.passphrase = passphrase +        self.secrets_path = secrets_path +        self._local_db_path = local_db_path +        self.server_url = server_url +        self.shared_db = shared_db +        self.token = auth_token + +        self._dbsyncer = None + +        # configure SSL certificate +        global SOLEDAD_CERT +        SOLEDAD_CERT = cert_file + +        self._init_config_with_defaults() +        self._init_working_dirs() + +        self._recovery_code = RecoveryCode() +        self._secrets = Secrets(self) +        self._crypto = SoledadCrypto(self._secrets.remote_secret) +        self._init_blobmanager() + +        try: +            # initialize database access, trap any problems so we can shutdown +            # smoothly. +            self._init_u1db_sqlcipher_backend() +            self._init_u1db_syncer() +        except DatabaseAccessError: +            # oops! something went wrong with backend initialization. We +            # have to close any thread-related stuff we have already opened +            # here, otherwise there might be zombie threads that may clog the +            # reactor. +            if hasattr(self, '_dbpool'): +                self._dbpool.close() +            raise + +    # +    # initialization/destruction methods +    # + +    def _init_config_with_defaults(self): +        """ +        Initialize configuration using default values for missing params. +        """ +        soledad_assert_type(self.passphrase, unicode) + +        def initialize(attr, val): +            return ((getattr(self, attr, None) is None) and +                    setattr(self, attr, val)) + +        initialize("_secrets_path", os.path.join( +            self.default_prefix, self.secrets_file_name)) +        initialize("_local_db_path", os.path.join( +            self.default_prefix, self.local_db_file_name)) +        # initialize server_url +        soledad_assert(self.server_url is not None, +                       'Missing URL for Soledad server.') + +    def _init_working_dirs(self): +        """ +        Create work directories. + +        :raise OSError: in case file exists and is not a dir. +        """ +        paths = map(lambda x: os.path.dirname(x), [ +            self._local_db_path, self._secrets_path]) +        for path in paths: +            create_path_if_not_exists(path) + +    def _init_u1db_sqlcipher_backend(self): +        """ +        Initialize the U1DB SQLCipher database for local storage. + +        Instantiates a modified twisted adbapi that will maintain a threadpool +        with a u1db-sqclipher connection for each thread, and will return +        deferreds for each u1db query. + +        Currently, Soledad uses the default SQLCipher cipher, i.e. +        'aes-256-cbc'. We use scrypt to derive a 256-bit encryption key, +        and internally the SQLCipherDatabase initialization uses the 'raw +        PRAGMA key' format to handle the key to SQLCipher. 
+        """ +        tohex = binascii.b2a_hex +        # sqlcipher only accepts the hex version +        key = tohex(self._secrets.local_key) + +        opts = sqlcipher.SQLCipherOptions( +            self._local_db_path, key, +            is_raw_key=True, create=True) +        self._sqlcipher_opts = opts +        self._dbpool = adbapi.getConnectionPool(opts) + +    def _init_u1db_syncer(self): +        """ +        Initialize the U1DB synchronizer. +        """ +        replica_uid = self._dbpool.replica_uid +        self._dbsyncer = sqlcipher.SQLCipherU1DBSync( +            self._sqlcipher_opts, self._crypto, replica_uid, +            SOLEDAD_CERT) + +    def sync_stats(self): +        sync_phase = 0 +        if getattr(self._dbsyncer, 'sync_phase', None): +            sync_phase = self._dbsyncer.sync_phase[0] +        sync_exchange_phase = 0 +        if getattr(self._dbsyncer, 'syncer', None): +            if getattr(self._dbsyncer.syncer, 'sync_exchange_phase', None): +                _p = self._dbsyncer.syncer.sync_exchange_phase[0] +                sync_exchange_phase = _p +        return sync_phase, sync_exchange_phase + +    def _init_blobmanager(self): +        path = os.path.join(os.path.dirname(self._local_db_path), 'blobs') +        url = urlparse.urljoin(self.server_url, 'blobs/%s' % uuid) +        key = self._secrets.local_key +        self.blobmanager = blobs.BlobManager(path, url, key, self.uuid, +                                             self.token, SOLEDAD_CERT) + +    # +    # Closing methods +    # + +    def close(self): +        """ +        Close underlying U1DB database. +        """ +        logger.debug("closing soledad") +        self._dbpool.close() +        self.blobmanager.close() +        if getattr(self, '_dbsyncer', None): +            self._dbsyncer.close() + +    # +    # ILocalStorage +    # + +    def _defer(self, meth, *args, **kw): +        """ +        Defer a method to be run on a U1DB connection pool. + +        :param meth: A method to defer to the U1DB connection pool. +        :type meth: callable +        :return: A deferred. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._dbpool.runU1DBQuery(meth, *args, **kw) + +    def put_doc(self, doc): +        """ +        Update a document. + +        If the document currently has conflicts, put will fail. +        If the database specifies a maximum document size and the document +        exceeds it, put will fail and raise a DocumentTooBig exception. + +        ============================== WARNING ============================== +        This method converts the document's contents to unicode in-place. This +        means that after calling `put_doc(doc)`, the contents of the +        document, i.e. `doc.content`, might be different from before the +        call. +        ============================== WARNING ============================== + +        :param doc: A document with new content. +        :type doc: leap.soledad.common.document.Document +        :return: A deferred whose callback will be invoked with the new +            revision identifier for the document. The document object will +            also be updated. +        :rtype: twisted.internet.defer.Deferred +        """ +        d = self._defer("put_doc", doc) +        return d + +    def delete_doc(self, doc): +        """ +        Mark a document as deleted. + +        Will abort if the current revision doesn't match doc.rev. +        This will also set doc.content to None. 
+ +        :param doc: A document to be deleted. +        :type doc: leap.soledad.common.document.Document +        :return: A deferred. +        :rtype: twisted.internet.defer.Deferred +        """ +        soledad_assert(doc is not None, "delete_doc doesn't accept None.") +        return self._defer("delete_doc", doc) + +    def get_doc(self, doc_id, include_deleted=False): +        """ +        Get the JSON string for the given document. + +        :param doc_id: The unique document identifier +        :type doc_id: str +        :param include_deleted: If set to True, deleted documents will be +            returned with empty content. Otherwise asking for a deleted +            document will return None. +        :type include_deleted: bool +        :return: A deferred whose callback will be invoked with a document +            object. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer( +            "get_doc", doc_id, include_deleted=include_deleted) + +    def get_docs( +            self, doc_ids, check_for_conflicts=True, include_deleted=False): +        """ +        Get the JSON content for many documents. + +        :param doc_ids: A list of document identifiers. +        :type doc_ids: list +        :param check_for_conflicts: If set to False, then the conflict check +            will be skipped, and 'None' will be returned instead of True/False. +        :type check_for_conflicts: bool +        :param include_deleted: If set to True, deleted documents will be +            returned with empty content. Otherwise deleted documents will not +            be included in the results. +        :type include_deleted: bool +        :return: A deferred whose callback will be invoked with an iterable +            giving the document object for each document id in matching +            doc_ids order. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer( +            "get_docs", doc_ids, check_for_conflicts=check_for_conflicts, +            include_deleted=include_deleted) + +    def get_all_docs(self, include_deleted=False): +        """ +        Get the JSON content for all documents in the database. + +        :param include_deleted: If set to True, deleted documents will be +            returned with empty content. Otherwise deleted documents will not +            be included in the results. +        :type include_deleted: bool + +        :return: A deferred which, when fired, will pass the a tuple +            containing (generation, [Document]) to the callback, with the +            current generation of the database, followed by a list of all the +            documents in the database. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer("get_all_docs", include_deleted) + +    @defer.inlineCallbacks +    def create_doc(self, content, doc_id=None): +        """ +        Create a new document. + +        You can optionally specify the document identifier, but the document +        must not already exist. See 'put_doc' if you want to override an +        existing document. +        If the database specifies a maximum document size and the document +        exceeds it, create will fail and raise a DocumentTooBig exception. + +        :param content: A Python dictionary. +        :type content: dict +        :param doc_id: An optional identifier specifying the document id. +        :type doc_id: str +        :return: A deferred whose callback will be invoked with a document. 
+        :rtype: twisted.internet.defer.Deferred +        """ +        # TODO we probably should pass an optional "encoding" parameter to +        # create_doc (and probably to put_doc too). There are cases (mail +        # payloads for example) in which we already have the encoding in the +        # headers, so we don't need to guess it. +        doc = yield self._defer("create_doc", content, doc_id=doc_id) +        doc.set_store(self) +        defer.returnValue(doc) + +    def create_doc_from_json(self, json, doc_id=None): +        """ +        Create a new document. + +        You can optionally specify the document identifier, but the document +        must not already exist. See 'put_doc' if you want to override an +        existing document. +        If the database specifies a maximum document size and the document +        exceeds it, create will fail and raise a DocumentTooBig exception. + +        :param json: The JSON document string +        :type json: dict +        :param doc_id: An optional identifier specifying the document id. +        :type doc_id: str +        :return: A deferred whose callback will be invoked with a document. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer("create_doc_from_json", json, doc_id=doc_id) + +    def create_index(self, index_name, *index_expressions): +        """ +        Create a named index, which can then be queried for future lookups. + +        Creating an index which already exists is not an error, and is cheap. +        Creating an index which does not match the index_expressions of the +        existing index is an error. +        Creating an index will block until the expressions have been evaluated +        and the index generated. + +        :param index_name: A unique name which can be used as a key prefix +        :type index_name: str +        :param index_expressions: index expressions defining the index +            information. + +            Examples: + +            "fieldname", or "fieldname.subfieldname" to index alphabetically +            sorted on the contents of a field. + +            "number(fieldname, width)", "lower(fieldname)" +        :type index_expresions: list of str +        :return: A deferred. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer("create_index", index_name, *index_expressions) + +    def delete_index(self, index_name): +        """ +        Remove a named index. + +        :param index_name: The name of the index we are removing +        :type index_name: str +        :return: A deferred. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer("delete_index", index_name) + +    def list_indexes(self): +        """ +        List the definitions of all known indexes. + +        :return: A deferred whose callback will be invoked with a list of +            [('index-name', ['field', 'field2'])] definitions. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer("list_indexes") + +    def get_from_index(self, index_name, *key_values): +        """ +        Return documents that match the keys supplied. + +        You must supply exactly the same number of values as have been defined +        in the index. It is possible to do a prefix match by using '*' to +        indicate a wildcard match. You can only supply '*' to trailing entries, +        (eg 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) 
+        It is also possible to append a '*' to the last supplied value (eg +        'val*', '*', '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + +        :param index_name: The index to query +        :type index_name: str +        :param key_values: values to match. eg, if you have +            an index with 3 fields then you would have: +            get_from_index(index_name, val1, val2, val3) +        :type key_values: list +        :return: A deferred whose callback will be invoked with a list of +            [Document]. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer("get_from_index", index_name, *key_values) + +    def get_count_from_index(self, index_name, *key_values): +        """ +        Return the count for a given combination of index_name +        and key values. + +        Extension method made from similar methods in u1db version 13.09 + +        :param index_name: The index to query +        :type index_name: str +        :param key_values: values to match. eg, if you have +                           an index with 3 fields then you would have: +                           get_from_index(index_name, val1, val2, val3) +        :type key_values: tuple +        :return: A deferred whose callback will be invoked with the count. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer("get_count_from_index", index_name, *key_values) + +    def get_range_from_index(self, index_name, start_value, end_value): +        """ +        Return documents that fall within the specified range. + +        Both ends of the range are inclusive. For both start_value and +        end_value, one must supply exactly the same number of values as have +        been defined in the index, or pass None. In case of a single column +        index, a string is accepted as an alternative for a tuple with a single +        value. It is possible to do a prefix match by using '*' to indicate +        a wildcard match. You can only supply '*' to trailing entries, (eg +        'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) It is also +        possible to append a '*' to the last supplied value (eg 'val*', '*', +        '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + +        :param index_name: The index to query +        :type index_name: str +        :param start_values: tuples of values that define the lower bound of +            the range. eg, if you have an index with 3 fields then you would +            have: (val1, val2, val3) +        :type start_values: tuple +        :param end_values: tuples of values that define the upper bound of the +            range. eg, if you have an index with 3 fields then you would have: +            (val1, val2, val3) +        :type end_values: tuple +        :return: A deferred whose callback will be invoked with a list of +            [Document]. +        :rtype: twisted.internet.defer.Deferred +        """ + +        return self._defer( +            "get_range_from_index", index_name, start_value, end_value) + +    def get_index_keys(self, index_name): +        """ +        Return all keys under which documents are indexed in this index. + +        :param index_name: The index to query +        :type index_name: str +        :return: A deferred whose callback will be invoked with a list of +            tuples of indexed keys. 
+        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer("get_index_keys", index_name) + +    def get_doc_conflicts(self, doc_id): +        """ +        Get the list of conflicts for the given document. + +        The order of the conflicts is such that the first entry is the value +        that would be returned by "get_doc". + +        :param doc_id: The unique document identifier +        :type doc_id: str +        :return: A deferred whose callback will be invoked with a list of the +            Document entries that are conflicted. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer("get_doc_conflicts", doc_id) + +    def resolve_doc(self, doc, conflicted_doc_revs): +        """ +        Mark a document as no longer conflicted. + +        We take the list of revisions that the client knows about that it is +        superseding. This may be a different list from the actual current +        conflicts, in which case only those are removed as conflicted.  This +        may fail if the conflict list is significantly different from the +        supplied information. (sync could have happened in the background from +        the time you GET_DOC_CONFLICTS until the point where you RESOLVE) + +        :param doc: A Document with the new content to be inserted. +        :type doc: Document +        :param conflicted_doc_revs: A list of revisions that the new content +            supersedes. +        :type conflicted_doc_revs: list(str) +        :return: A deferred. +        :rtype: twisted.internet.defer.Deferred +        """ +        return self._defer("resolve_doc", doc, conflicted_doc_revs) + +    @property +    def local_db_path(self): +        return self._local_db_path + +    @property +    def userid(self): +        return self.uuid + +    # +    # ISyncableStorage +    # + +    def sync(self): +        """ +        Synchronize documents with the server replica. + +        This method uses a lock to prevent multiple concurrent sync processes +        over the same local db file. + +        :return: A deferred lock that will run the actual sync process when +                 the lock is acquired, and which will fire with with the local +                 generation before the synchronization was performed. +        :rtype: twisted.internet.defer.Deferred +        """ +        # maybe bypass sync +        # TODO: That's because bitmask may not provide us a token, but +        # this should be handled on the caller side. Here, calling us without +        # a token is a real error. +        if not self.token: +            generation = self._dbsyncer.get_generation() +            return defer.succeed(generation) + +        d = self.sync_lock.run( +            self._sync) +        return d + +    def _sync(self): +        """ +        Synchronize documents with the server replica. + +        :return: A deferred whose callback will be invoked with the local +            generation before the synchronization was performed. 
+        :rtype: twisted.internet.defer.Deferred +        """ +        sync_url = urlparse.urljoin(self.server_url, 'user-%s' % self.uuid) +        if not self._dbsyncer: +            return +        creds = {'token': {'uuid': self.uuid, 'token': self.token}} +        d = self._dbsyncer.sync(sync_url, creds=creds) + +        def _sync_callback(local_gen): +            self._last_received_docs = docs = self._dbsyncer.received_docs + +            # Post-Sync Hooks +            if docs: +                iface = soledad_interfaces.ISoledadPostSyncPlugin +                suitable_plugins = collect_plugins(iface) +                for plugin in suitable_plugins: +                    watched = plugin.watched_doc_types +                    r = [filter( +                        lambda s: s.startswith(preffix), +                        docs) for preffix in watched] +                    filtered = list(chain(*r)) +                    plugin.process_received_docs(filtered) + +            return local_gen + +        def _sync_errback(failure): +            s = StringIO() +            failure.printDetailedTraceback(file=s) +            msg = "got exception when syncing!\n" + s.getvalue() +            logger.error(msg) +            return failure + +        def _emit_done_data_sync(passthrough): +            user_data = {'uuid': self.uuid, 'userid': self.userid} +            soledad_events.emit_async( +                soledad_events.SOLEDAD_DONE_DATA_SYNC, user_data) +            return passthrough + +        d.addCallbacks(_sync_callback, _sync_errback) +        d.addCallback(_emit_done_data_sync) +        return d + +    @property +    def sync_lock(self): +        """ +        Class based lock to prevent concurrent syncs using the same local db +        file. + +        :return: A shared lock based on this instance's db file path. +        :rtype: DeferredLock +        """ +        return self._sync_lock[self._local_db_path] + +    @property +    def syncing(self): +        """ +        Return wether Soledad is currently synchronizing with the server. + +        :return: Wether Soledad is currently synchronizing with the server. +        :rtype: bool +        """ +        return self.sync_lock.locked + +    # +    # ISecretsStorage +    # + +    @property +    def secrets(self): +        """ +        Return the secrets object. + +        :return: The secrets object. +        :rtype: Secrets +        """ +        return self._secrets + +    def change_passphrase(self, new_passphrase): +        """ +        Change the passphrase that encrypts the storage secret. + +        :param new_passphrase: The new passphrase. +        :type new_passphrase: unicode + +        :raise NoStorageSecret: Raised if there's no storage secret available. +        """ +        self.passphrase = new_passphrase +        self._secrets.store_secrets() + +    # +    # Raw SQLCIPHER Queries +    # + +    def raw_sqlcipher_query(self, *args, **kw): +        """ +        Run a raw sqlcipher query in the local database, and return a deferred +        that will be fired with the result. +        """ +        return self._dbpool.runQuery(*args, **kw) + +    def raw_sqlcipher_operation(self, *args, **kw): +        """ +        Run a raw sqlcipher operation in the local database, and return a +        deferred that will be fired with None. 
+        """ +        return self._dbpool.runOperation(*args, **kw) + +    # +    # Service authentication +    # + +    @defer.inlineCallbacks +    def get_or_create_service_token(self, service): +        """ +        Return the stored token for a given service, or generates and stores a +        random one if it does not exist. + +        These tokens can be used to authenticate services. +        """ +        # FIXME this could use the local sqlcipher database, to avoid +        # problems with different replicas creating different tokens. + +        yield self.create_index('by-servicetoken', 'type', 'service') +        docs = yield self._get_token_for_service(service) +        if docs: +            doc = docs[0] +            defer.returnValue(doc.content['token']) +        else: +            token = str(uuid.uuid4()).replace('-', '')[-24:] +            yield self._set_token_for_service(service, token) +            defer.returnValue(token) + +    def _get_token_for_service(self, service): +        return self.get_from_index('by-servicetoken', 'servicetoken', service) + +    def _set_token_for_service(self, service, token): +        doc = {'type': 'servicetoken', 'service': service, 'token': token} +        return self.create_doc(doc) + +    def create_recovery_code(self): +        return self._recovery_code.generate() + + +def create_path_if_not_exists(path): +    try: +        if not os.path.isdir(path): +            logger.info('creating directory: %s.' % path) +        os.makedirs(path) +    except OSError as exc: +        if exc.errno == errno.EEXIST and os.path.isdir(path): +            pass +        else: +            raise + +# ---------------------------------------------------------------------------- +# Monkey patching u1db to be able to provide a custom SSL cert +# ---------------------------------------------------------------------------- + + +# We need a more reasonable timeout (in seconds) +SOLEDAD_TIMEOUT = 120 + + +class VerifiedHTTPSConnection(httplib.HTTPSConnection): +    """ +    HTTPSConnection verifying server side certificates. +    """ +    # derived from httplib.py + +    def connect(self): +        """ +        Connect to a host on a given (SSL) port. 
+        """ +        try: +            source = self.source_address +            sock = socket.create_connection((self.host, self.port), +                                            SOLEDAD_TIMEOUT, source) +        except AttributeError: +            # source_address was introduced in 2.7 +            sock = socket.create_connection((self.host, self.port), +                                            SOLEDAD_TIMEOUT) +        if self._tunnel_host: +            self.sock = sock +            self._tunnel() + +        self.sock = ssl.wrap_socket(sock, +                                    ca_certs=SOLEDAD_CERT, +                                    cert_reqs=ssl.CERT_REQUIRED) +        match_hostname(self.sock.getpeercert(), self.host) + + +old__VerifiedHTTPSConnection = http_client._VerifiedHTTPSConnection +http_client._VerifiedHTTPSConnection = VerifiedHTTPSConnection diff --git a/src/leap/soledad/client/auth.py b/src/leap/soledad/client/auth.py new file mode 100644 index 00000000..78e9bf1b --- /dev/null +++ b/src/leap/soledad/client/auth.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# auth.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Methods for token-based authentication. + +These methods have to be included in all classes that extend HTTPClient so +they can do token-based auth requests to the Soledad server. +""" +import base64 + +from leap.soledad.common.l2db import errors + + +class TokenBasedAuth(object): +    """ +    Encapsulate token-auth methods for classes that inherit from +    u1db.remote.http_client.HTTPClient. +    """ + +    def set_token_credentials(self, uuid, token): +        """ +        Store given credentials so we can sign the request later. + +        :param uuid: The user's uuid. +        :type uuid: str +        :param token: The authentication token. +        :type token: str +        """ +        self._creds = {'token': (uuid, token)} + +    def _sign_request(self, method, url_query, params): +        """ +        Return an authorization header to be included in the HTTP request, in +        the form: + +            [('Authorization', 'Token <(base64 encoded) uuid:token>')] + +        :param method: The HTTP method. +        :type method: str +        :param url_query: The URL query string. +        :type url_query: str +        :param params: A list with encoded query parameters. +        :type param: list + +        :return: The Authorization header. 
+        :rtype: list of tuple +        """ +        if 'token' in self._creds: +            uuid, token = self._creds['token'] +            auth = '%s:%s' % (uuid, token) +            b64_token = base64.b64encode(auth) +            return [('Authorization', 'Token %s' % b64_token)] +        else: +            raise errors.UnknownAuthMethod( +                'Wrong credentials: %s' % self._creds) diff --git a/src/leap/soledad/client/crypto.py b/src/leap/soledad/client/crypto.py new file mode 100644 index 00000000..0f19c964 --- /dev/null +++ b/src/leap/soledad/client/crypto.py @@ -0,0 +1,448 @@ +# -*- coding: utf-8 -*- +# crypto.py +# Copyright (C) 2013, 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Cryptographic utilities for Soledad. +""" +import os +import binascii +import hmac +import hashlib +import json + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + +from leap.soledad.common import soledad_assert +from leap.soledad.common import soledad_assert_type +from leap.soledad.common import crypto +from leap.soledad.common.log import getLogger +import warnings + + +logger = getLogger(__name__) +warnings.warn("'soledad.client.crypto' MODULE DEPRECATED", +              DeprecationWarning, stacklevel=2) + + +MAC_KEY_LENGTH = 64 + +crypto_backend = default_backend() + + +def encrypt_sym(data, key): +    """ +    Encrypt data using AES-256 cipher in CTR mode. + +    :param data: The data to be encrypted. +    :type data: str +    :param key: The key used to encrypt data (must be 256 bits long). +    :type key: str + +    :return: A tuple with the initialization vector and the encrypted data. +    :rtype: (long, str) +    """ +    soledad_assert_type(key, str) +    soledad_assert( +        len(key) == 32,  # 32 x 8 = 256 bits. +        'Wrong key size: %s bits (must be 256 bits long).' % +        (len(key) * 8)) + +    iv = os.urandom(16) +    cipher = Cipher(algorithms.AES(key), modes.CTR(iv), backend=crypto_backend) +    encryptor = cipher.encryptor() +    ciphertext = encryptor.update(data) + encryptor.finalize() + +    return binascii.b2a_base64(iv), ciphertext + + +def decrypt_sym(data, key, iv): +    """ +    Decrypt some data previously encrypted using AES-256 cipher in CTR mode. + +    :param data: The data to be decrypted. +    :type data: str +    :param key: The symmetric key used to decrypt data (must be 256 bits +                long). +    :type key: str +    :param iv: The initialization vector. +    :type iv: long + +    :return: The decrypted data. +    :rtype: str +    """ +    soledad_assert_type(key, str) +    # assert params +    soledad_assert( +        len(key) == 32,  # 32 x 8 = 256 bits. +        'Wrong key size: %s (must be 256 bits long).' 
% len(key)) +    iv = binascii.a2b_base64(iv) +    cipher = Cipher(algorithms.AES(key), modes.CTR(iv), backend=crypto_backend) +    decryptor = cipher.decryptor() +    return decryptor.update(data) + decryptor.finalize() + + +def doc_mac_key(doc_id, secret): +    """ +    Generate a key for calculating a MAC for a document whose id is +    C{doc_id}. + +    The key is derived using HMAC having sha256 as underlying hash +    function. The key used for HMAC is the first MAC_KEY_LENGTH characters +    of Soledad's storage secret. The HMAC message is C{doc_id}. + +    :param doc_id: The id of the document. +    :type doc_id: str + +    :param secret: The Soledad storage secret +    :type secret: str + +    :return: The key. +    :rtype: str +    """ +    soledad_assert(secret is not None) +    return hmac.new( +        secret[:MAC_KEY_LENGTH], +        doc_id, +        hashlib.sha256).digest() + + +class SoledadCrypto(object): +    """ +    General cryptographic functionality encapsulated in a +    object that can be passed along. +    """ +    def __init__(self, secret): +        """ +        Initialize the crypto object. + +        :param secret: The Soledad remote storage secret. +        :type secret: str +        """ +        self._secret = secret + +    def doc_mac_key(self, doc_id): +        return doc_mac_key(doc_id, self._secret) + +    def doc_passphrase(self, doc_id): +        """ +        Generate a passphrase for symmetric encryption of document's contents. + +        The password is derived using HMAC having sha256 as underlying hash +        function. The key used for HMAC are the first +        C{soledad.REMOTE_STORAGE_SECRET_LENGTH} bytes of Soledad's storage +        secret stripped from the first MAC_KEY_LENGTH characters. The HMAC +        message is C{doc_id}. + +        :param doc_id: The id of the document that will be encrypted using +            this passphrase. +        :type doc_id: str + +        :return: The passphrase. +        :rtype: str +        """ +        soledad_assert(self._secret is not None) +        return hmac.new( +            self._secret[MAC_KEY_LENGTH:], +            doc_id, +            hashlib.sha256).digest() + +    def encrypt_doc(self, doc): +        """ +        Wrapper around encrypt_docstr that accepts the document as argument. + +        :param doc: the document. +        :type doc: Document +        """ +        key = self.doc_passphrase(doc.doc_id) + +        return encrypt_docstr( +            doc.get_json(), doc.doc_id, doc.rev, key, self._secret) + +    def decrypt_doc(self, doc): +        """ +        Wrapper around decrypt_doc_dict that accepts the document as argument. + +        :param doc: the document. +        :type doc: Document + +        :return: json string with the decrypted document +        :rtype: str +        """ +        key = self.doc_passphrase(doc.doc_id) +        return decrypt_doc_dict( +            doc.content, doc.doc_id, doc.rev, key, self._secret) + +    @property +    def secret(self): +        return self._secret + + +# +# Crypto utilities for a Document. +# + +def mac_doc(doc_id, doc_rev, ciphertext, enc_scheme, enc_method, enc_iv, +            mac_method, secret): +    """ +    Calculate a MAC for C{doc} using C{ciphertext}. + +    Current MAC method used is HMAC, with the following parameters: + +        * key: sha256(storage_secret, doc_id) +        * msg: doc_id + doc_rev + ciphertext +        * digestmod: sha256 + +    :param doc_id: The id of the document. 
+    :type doc_id: str +    :param doc_rev: The revision of the document. +    :type doc_rev: str +    :param ciphertext: The content of the document. +    :type ciphertext: str +    :param enc_scheme: The encryption scheme. +    :type enc_scheme: str +    :param enc_method: The encryption method. +    :type enc_method: str +    :param enc_iv: The encryption initialization vector. +    :type enc_iv: str +    :param mac_method: The MAC method to use. +    :type mac_method: str +    :param secret: The Soledad storage secret +    :type secret: str + +    :return: The calculated MAC. +    :rtype: str + +    :raise crypto.UnknownMacMethodError: Raised when C{mac_method} is unknown. +    """ +    try: +        soledad_assert(mac_method == crypto.MacMethods.HMAC) +    except AssertionError: +        raise crypto.UnknownMacMethodError +    template = "{doc_id}{doc_rev}{ciphertext}{enc_scheme}{enc_method}{enc_iv}" +    content = template.format( +        doc_id=doc_id, +        doc_rev=doc_rev, +        ciphertext=ciphertext, +        enc_scheme=enc_scheme, +        enc_method=enc_method, +        enc_iv=enc_iv) +    return hmac.new( +        doc_mac_key(doc_id, secret), +        content, +        hashlib.sha256).digest() + + +def encrypt_docstr(docstr, doc_id, doc_rev, key, secret): +    """ +    Encrypt C{doc}'s content. + +    Encrypt doc's contents using AES-256 CTR mode and return a valid JSON +    string representing the following: + +        { +            crypto.ENC_JSON_KEY: '<encrypted doc JSON string>', +            crypto.ENC_SCHEME_KEY: 'symkey', +            crypto.ENC_METHOD_KEY: crypto.EncryptionMethods.AES_256_CTR, +            crypto.ENC_IV_KEY: '<the initial value used to encrypt>', +            MAC_KEY: '<mac>' +            crypto.MAC_METHOD_KEY: 'hmac' +        } + +    :param docstr: A representation of the document to be encrypted. +    :type docstr: str or unicode. + +    :param doc_id: The document id. +    :type doc_id: str + +    :param doc_rev: The document revision. +    :type doc_rev: str + +    :param key: The key used to encrypt ``data`` (must be 256 bits long). +    :type key: str + +    :param secret: The Soledad storage secret (used for MAC auth). +    :type secret: str + +    :return: The JSON serialization of the dict representing the encrypted +             content. +    :rtype: str +    """ +    enc_scheme = crypto.EncryptionSchemes.SYMKEY +    enc_method = crypto.EncryptionMethods.AES_256_CTR +    mac_method = crypto.MacMethods.HMAC +    enc_iv, ciphertext = encrypt_sym( +        str(docstr),  # encryption/decryption routines expect str +        key) +    mac = binascii.b2a_hex(  # store the mac as hex. +        mac_doc( +            doc_id, +            doc_rev, +            ciphertext, +            enc_scheme, +            enc_method, +            enc_iv, +            mac_method, +            secret)) +    # Return a representation for the encrypted content. In the following, we +    # convert binary data to hexadecimal representation so the JSON +    # serialization does not complain about what it tries to serialize. 
+    hex_ciphertext = binascii.b2a_hex(ciphertext) +    logger.debug("encrypting doc: %s" % doc_id) +    return json.dumps({ +        crypto.ENC_JSON_KEY: hex_ciphertext, +        crypto.ENC_SCHEME_KEY: enc_scheme, +        crypto.ENC_METHOD_KEY: enc_method, +        crypto.ENC_IV_KEY: enc_iv, +        crypto.MAC_KEY: mac, +        crypto.MAC_METHOD_KEY: mac_method, +    }) + + +def _verify_doc_mac(doc_id, doc_rev, ciphertext, enc_scheme, enc_method, +                    enc_iv, mac_method, secret, doc_mac): +    """ +    Verify that C{doc_mac} is a correct MAC for the given document. + +    :param doc_id: The id of the document. +    :type doc_id: str +    :param doc_rev: The revision of the document. +    :type doc_rev: str +    :param ciphertext: The content of the document. +    :type ciphertext: str +    :param enc_scheme: The encryption scheme. +    :type enc_scheme: str +    :param enc_method: The encryption method. +    :type enc_method: str +    :param enc_iv: The encryption initialization vector. +    :type enc_iv: str +    :param mac_method: The MAC method to use. +    :type mac_method: str +    :param secret: The Soledad storage secret +    :type secret: str +    :param doc_mac: The MAC to be verified against. +    :type doc_mac: str + +    :raise crypto.UnknownMacMethodError: Raised when C{mac_method} is unknown. +    :raise crypto.WrongMacError: Raised when MAC could not be verified. +    """ +    calculated_mac = mac_doc( +        doc_id, +        doc_rev, +        ciphertext, +        enc_scheme, +        enc_method, +        enc_iv, +        mac_method, +        secret) +    # we compare mac's hashes to avoid possible timing attacks that might +    # exploit python's builtin comparison operator behaviour, which fails +    # immediatelly when non-matching bytes are found. +    doc_mac_hash = hashlib.sha256( +        binascii.a2b_hex(  # the mac is stored as hex +            doc_mac)).digest() +    calculated_mac_hash = hashlib.sha256(calculated_mac).digest() + +    if doc_mac_hash != calculated_mac_hash: +        logger.warn("wrong MAC while decrypting doc...") +        raise crypto.WrongMacError("Could not authenticate document's " +                                   "contents.") + + +def decrypt_doc_dict(doc_dict, doc_id, doc_rev, key, secret): +    """ +    Decrypt a symmetrically encrypted C{doc}'s content. + +    Return the JSON string representation of the document's decrypted content. + +    The passed doc_dict argument should have the following structure: + +        { +            crypto.ENC_JSON_KEY: '<enc_blob>', +            crypto.ENC_SCHEME_KEY: '<enc_scheme>', +            crypto.ENC_METHOD_KEY: '<enc_method>', +            crypto.ENC_IV_KEY: '<initial value used to encrypt>',  # (optional) +            MAC_KEY: '<mac>' +            crypto.MAC_METHOD_KEY: 'hmac' +        } + +    C{enc_blob} is the encryption of the JSON serialization of the document's +    content. For now Soledad just deals with documents whose C{enc_scheme} is +    crypto.EncryptionSchemes.SYMKEY and C{enc_method} is +    crypto.EncryptionMethods.AES_256_CTR. + +    :param doc_dict: The content of the document to be decrypted. +    :type doc_dict: dict + +    :param doc_id: The document id. +    :type doc_id: str + +    :param doc_rev: The document revision. +    :type doc_rev: str + +    :param key: The key used to encrypt ``data`` (must be 256 bits long). +    :type key: str + +    :param secret: The Soledad storage secret. 
+    :type secret: str + +    :return: The JSON serialization of the decrypted content. +    :rtype: str + +    :raise UnknownEncryptionMethodError: Raised when trying to decrypt from an +        unknown encryption method. +    """ +    # assert document dictionary structure +    expected_keys = set([ +        crypto.ENC_JSON_KEY, +        crypto.ENC_SCHEME_KEY, +        crypto.ENC_METHOD_KEY, +        crypto.ENC_IV_KEY, +        crypto.MAC_KEY, +        crypto.MAC_METHOD_KEY, +    ]) +    soledad_assert(expected_keys.issubset(set(doc_dict.keys()))) + +    ciphertext = binascii.a2b_hex(doc_dict[crypto.ENC_JSON_KEY]) +    enc_scheme = doc_dict[crypto.ENC_SCHEME_KEY] +    enc_method = doc_dict[crypto.ENC_METHOD_KEY] +    enc_iv = doc_dict[crypto.ENC_IV_KEY] +    doc_mac = doc_dict[crypto.MAC_KEY] +    mac_method = doc_dict[crypto.MAC_METHOD_KEY] + +    soledad_assert(enc_scheme == crypto.EncryptionSchemes.SYMKEY) + +    _verify_doc_mac( +        doc_id, doc_rev, ciphertext, enc_scheme, enc_method, +        enc_iv, mac_method, secret, doc_mac) + +    return decrypt_sym(ciphertext, key, enc_iv) + + +def is_symmetrically_encrypted(doc): +    """ +    Return True if the document was symmetrically encrypted. + +    :param doc: The document to check. +    :type doc: Document + +    :rtype: bool +    """ +    if doc.content and crypto.ENC_SCHEME_KEY in doc.content: +        if doc.content[crypto.ENC_SCHEME_KEY] \ +                == crypto.EncryptionSchemes.SYMKEY: +            return True +    return False diff --git a/src/leap/soledad/client/events.py b/src/leap/soledad/client/events.py new file mode 100644 index 00000000..058be59c --- /dev/null +++ b/src/leap/soledad/client/events.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# signal.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +""" +Signaling functions. 
+""" + +from leap.common.events import emit_async +from leap.common.events import catalog + + +SOLEDAD_CREATING_KEYS = catalog.SOLEDAD_CREATING_KEYS +SOLEDAD_DONE_CREATING_KEYS = catalog.SOLEDAD_DONE_CREATING_KEYS +SOLEDAD_DOWNLOADING_KEYS = catalog.SOLEDAD_DOWNLOADING_KEYS +SOLEDAD_DONE_DOWNLOADING_KEYS = \ +    catalog.SOLEDAD_DONE_DOWNLOADING_KEYS +SOLEDAD_UPLOADING_KEYS = catalog.SOLEDAD_UPLOADING_KEYS +SOLEDAD_DONE_UPLOADING_KEYS = \ +    catalog.SOLEDAD_DONE_UPLOADING_KEYS +SOLEDAD_NEW_DATA_TO_SYNC = catalog.SOLEDAD_NEW_DATA_TO_SYNC +SOLEDAD_DONE_DATA_SYNC = catalog.SOLEDAD_DONE_DATA_SYNC +SOLEDAD_SYNC_SEND_STATUS = catalog.SOLEDAD_SYNC_SEND_STATUS +SOLEDAD_SYNC_RECEIVE_STATUS = catalog.SOLEDAD_SYNC_RECEIVE_STATUS + + +__all__ = [ +    "catalog", +    "emit_async", +    "SOLEDAD_CREATING_KEYS", +    "SOLEDAD_DONE_CREATING_KEYS", +    "SOLEDAD_DOWNLOADING_KEYS", +    "SOLEDAD_DONE_DOWNLOADING_KEYS", +    "SOLEDAD_UPLOADING_KEYS", +    "SOLEDAD_DONE_UPLOADING_KEYS", +    "SOLEDAD_NEW_DATA_TO_SYNC", +    "SOLEDAD_DONE_DATA_SYNC", +    "SOLEDAD_SYNC_SEND_STATUS", +    "SOLEDAD_SYNC_RECEIVE_STATUS", +] diff --git a/src/leap/soledad/client/examples/README b/src/leap/soledad/client/examples/README new file mode 100644 index 00000000..3aed8377 --- /dev/null +++ b/src/leap/soledad/client/examples/README @@ -0,0 +1,4 @@ +Right now, you can find here both an example of use +and the benchmarking scripts. +TODO move benchmark scripts to root scripts/ folder, +and leave here only a minimal example. diff --git a/src/leap/soledad/client/examples/benchmarks/.gitignore b/src/leap/soledad/client/examples/benchmarks/.gitignore new file mode 100644 index 00000000..2211df63 --- /dev/null +++ b/src/leap/soledad/client/examples/benchmarks/.gitignore @@ -0,0 +1 @@ +*.txt diff --git a/src/leap/soledad/client/examples/benchmarks/get_sample.sh b/src/leap/soledad/client/examples/benchmarks/get_sample.sh new file mode 100755 index 00000000..1995eee1 --- /dev/null +++ b/src/leap/soledad/client/examples/benchmarks/get_sample.sh @@ -0,0 +1,3 @@ +#!/bin/sh +mkdir tmp +wget http://www.gutenberg.org/cache/epub/101/pg101.txt -O hacker_crackdown.txt diff --git a/src/leap/soledad/client/examples/benchmarks/measure_index_times.py b/src/leap/soledad/client/examples/benchmarks/measure_index_times.py new file mode 100644 index 00000000..f9349758 --- /dev/null +++ b/src/leap/soledad/client/examples/benchmarks/measure_index_times.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# measure_index_times.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. +""" +Measure u1db retrieval times for different u1db index situations. 
+""" +from __future__ import print_function +from functools import partial +import datetime +import hashlib +import os +import sys + +from twisted.internet import defer, reactor + +from leap.soledad.common import l2db +from leap.soledad.client import adbapi +from leap.soledad.client._db.sqlcipher import SQLCipherOptions + + +folder = os.environ.get("TMPDIR", "tmp") +numdocs = int(os.environ.get("DOCS", "1000")) +silent = os.environ.get("SILENT", False) +tmpdb = os.path.join(folder, "test.soledad") + + +sample_file = os.environ.get("SAMPLE", "hacker_crackdown.txt") +sample_path = os.path.join(os.curdir, sample_file) + +try: +    with open(sample_file) as f: +        SAMPLE = f.readlines() +except Exception: +    print("[!] Problem opening sample file. Did you download " +          "the sample, or correctly set 'SAMPLE' env var?") +    sys.exit(1) + +if numdocs > len(SAMPLE): +    print("[!] Sorry! The requested DOCS number is larger than " +          "the num of lines in our sample file") +    sys.exit(1) + + +def debug(*args): +    if not silent: +        print(*args) + + +debug("[+] db path:", tmpdb) +debug("[+] num docs", numdocs) + +if os.path.isfile(tmpdb): +    debug("[+] Removing existing db file...") +    os.remove(tmpdb) + +start_time = datetime.datetime.now() + +opts = SQLCipherOptions(tmpdb, "secret", create=True) +dbpool = adbapi.getConnectionPool(opts) + + +def createDoc(doc): +    return dbpool.runU1DBQuery("create_doc", doc) + + +db_indexes = { +    'by-chash': ['chash'], +    'by-number': ['number']} + + +def create_indexes(_): +    deferreds = [] +    for index, definition in db_indexes.items(): +        d = dbpool.runU1DBQuery("create_index", index, *definition) +        deferreds.append(d) +    return defer.gatherResults(deferreds) + + +class TimeWitness(object): +    def __init__(self, init_time): +        self.init_time = init_time + +    def get_time_count(self): +        return datetime.datetime.now() - self.init_time + + +def get_from_index(_): +    init_time = datetime.datetime.now() +    debug("GETTING FROM INDEX...", init_time) + +    def printValue(res, time): +        print("RESULT->", res) +        print("Index Query Took: ", time.get_time_count()) +        return res + +    d = dbpool.runU1DBQuery( +        "get_from_index", "by-chash", +        # "1150c7f10fabce0a57ce13071349fc5064f15bdb0cc1bf2852f74ef3f103aff5") +        # XXX this is line 89 from the hacker crackdown... +        # Should accept any other optional hash as an enviroment variable. +        "57793320d4997a673fc7062652da0596c36a4e9fbe31310d2281e67d56d82469") +    d.addCallback(printValue, TimeWitness(init_time)) +    return d + + +def getAllDocs(): +    return dbpool.runU1DBQuery("get_all_docs") + + +def errBack(e): +    debug("[!] ERROR FOUND!!!") +    e.printTraceback() +    reactor.stop() + + +def countDocs(_): +    debug("counting docs...") +    d = getAllDocs() +    d.addCallbacks(printResult, errBack) +    d.addCallbacks(allDone, errBack) +    return d + + +def printResult(r, **kwargs): +    if kwargs: +        debug(*kwargs.values()) +    elif isinstance(r, l2db.Document): +        debug(r.doc_id, r.content['number']) +    else: +        len_results = len(r[1]) +        debug("GOT %s results" % len(r[1])) + +        if len_results == numdocs: +            debug("ALL GOOD") +        else: +            debug("[!] 
MISSING DOCS!!!!!") +            raise ValueError("We didn't expect this result len") + + +def allDone(_): +    debug("ALL DONE!") + +    end_time = datetime.datetime.now() +    print((end_time - start_time).total_seconds()) +    reactor.stop() + + +def insert_docs(_): +    deferreds = [] +    for i in range(numdocs): +        payload = SAMPLE[i] +        chash = hashlib.sha256(payload).hexdigest() +        doc = {"number": i, "payload": payload, 'chash': chash} +        d = createDoc(doc) +        d.addCallbacks(partial(printResult, i=i, chash=chash, payload=payload), +                       lambda e: e.printTraceback()) +        deferreds.append(d) +    return defer.gatherResults(deferreds, consumeErrors=True) + + +d = create_indexes(None) +d.addCallback(insert_docs) +d.addCallback(get_from_index) +d.addCallback(countDocs) + +reactor.run() diff --git a/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py b/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py new file mode 100644 index 00000000..4f273c64 --- /dev/null +++ b/src/leap/soledad/client/examples/benchmarks/measure_index_times_custom_docid.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# measure_index_times.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. +""" +Measure u1db retrieval times for different u1db index situations. +""" +from __future__ import print_function +from functools import partial +import datetime +import hashlib +import os +import sys + +from twisted.internet import defer, reactor + +from leap.soledad.client import adbapi +from leap.soledad.client._db.sqlcipher import SQLCipherOptions +from leap.soledad.common import l2db + + +folder = os.environ.get("TMPDIR", "tmp") +numdocs = int(os.environ.get("DOCS", "1000")) +silent = os.environ.get("SILENT", False) +tmpdb = os.path.join(folder, "test.soledad") + + +sample_file = os.environ.get("SAMPLE", "hacker_crackdown.txt") +sample_path = os.path.join(os.curdir, sample_file) + +try: +    with open(sample_file) as f: +        SAMPLE = f.readlines() +except Exception: +    print("[!] Problem opening sample file. Did you download " +          "the sample, or correctly set 'SAMPLE' env var?") +    sys.exit(1) + +if numdocs > len(SAMPLE): +    print("[!] Sorry! 
The requested DOCS number is larger than " +          "the num of lines in our sample file") +    sys.exit(1) + + +def debug(*args): +    if not silent: +        print(*args) + + +debug("[+] db path:", tmpdb) +debug("[+] num docs", numdocs) + +if os.path.isfile(tmpdb): +    debug("[+] Removing existing db file...") +    os.remove(tmpdb) + +start_time = datetime.datetime.now() + +opts = SQLCipherOptions(tmpdb, "secret", create=True) +dbpool = adbapi.getConnectionPool(opts) + + +def createDoc(doc, doc_id): +    return dbpool.runU1DBQuery("create_doc", doc, doc_id=doc_id) + + +db_indexes = { +    'by-chash': ['chash'], +    'by-number': ['number']} + + +def create_indexes(_): +    deferreds = [] +    for index, definition in db_indexes.items(): +        d = dbpool.runU1DBQuery("create_index", index, *definition) +        deferreds.append(d) +    return defer.gatherResults(deferreds) + + +class TimeWitness(object): +    def __init__(self, init_time): +        self.init_time = init_time + +    def get_time_count(self): +        return datetime.datetime.now() - self.init_time + + +def get_from_index(_): +    init_time = datetime.datetime.now() +    debug("GETTING FROM INDEX...", init_time) + +    def printValue(res, time): +        print("RESULT->", res) +        print("Index Query Took: ", time.get_time_count()) +        return res + +    d = dbpool.runU1DBQuery( +        "get_doc", +        # "1150c7f10fabce0a57ce13071349fc5064f15bdb0cc1bf2852f74ef3f103aff5") +        # XXX this is line 89 from the hacker crackdown... +        # Should accept any other optional hash as an enviroment variable. +        "57793320d4997a673fc7062652da0596c36a4e9fbe31310d2281e67d56d82469") +    d.addCallback(printValue, TimeWitness(init_time)) +    return d + + +def getAllDocs(): +    return dbpool.runU1DBQuery("get_all_docs") + + +def errBack(e): +    debug("[!] ERROR FOUND!!!") +    e.printTraceback() +    reactor.stop() + + +def countDocs(_): +    debug("counting docs...") +    d = getAllDocs() +    d.addCallbacks(printResult, errBack) +    d.addCallbacks(allDone, errBack) +    return d + + +def printResult(r, **kwargs): +    if kwargs: +        debug(*kwargs.values()) +    elif isinstance(r, l2db.Document): +        debug(r.doc_id, r.content['number']) +    else: +        len_results = len(r[1]) +        debug("GOT %s results" % len(r[1])) + +        if len_results == numdocs: +            debug("ALL GOOD") +        else: +            debug("[!] 
MISSING DOCS!!!!!") +            raise ValueError("We didn't expect this result len") + + +def allDone(_): +    debug("ALL DONE!") + +    end_time = datetime.datetime.now() +    print((end_time - start_time).total_seconds()) +    reactor.stop() + + +def insert_docs(_): +    deferreds = [] +    for i in range(numdocs): +        payload = SAMPLE[i] +        chash = hashlib.sha256(payload).hexdigest() +        doc = {"number": i, "payload": payload, 'chash': chash} +        d = createDoc(doc, doc_id=chash) +        d.addCallbacks(partial(printResult, i=i, chash=chash, payload=payload), +                       lambda e: e.printTraceback()) +        deferreds.append(d) +    return defer.gatherResults(deferreds, consumeErrors=True) + + +d = create_indexes(None) +d.addCallback(insert_docs) +d.addCallback(get_from_index) +d.addCallback(countDocs) + +reactor.run() diff --git a/src/leap/soledad/client/examples/compare.txt b/src/leap/soledad/client/examples/compare.txt new file mode 100644 index 00000000..19a1325a --- /dev/null +++ b/src/leap/soledad/client/examples/compare.txt @@ -0,0 +1,8 @@ +TIMES=100 TMPDIR=/media/sdb5/leap python use_adbapi.py  1.34s user 0.16s system 53% cpu 2.832 total +TIMES=100 TMPDIR=/media/sdb5/leap python use_api.py  1.22s user 0.14s system 62% cpu 2.181 total + +TIMES=1000 TMPDIR=/media/sdb5/leap python use_api.py  2.18s user 0.34s system 27% cpu 9.213 total +TIMES=1000 TMPDIR=/media/sdb5/leap python use_adbapi.py  2.40s user 0.34s system 39% cpu 7.004 total + +TIMES=5000 TMPDIR=/media/sdb5/leap python use_api.py  6.63s user 1.27s system 13% cpu 57.882 total +TIMES=5000 TMPDIR=/media/sdb5/leap python use_adbapi.py  6.84s user 1.26s system 36% cpu 22.367 total diff --git a/src/leap/soledad/client/examples/manifest.phk b/src/leap/soledad/client/examples/manifest.phk new file mode 100644 index 00000000..2c86c07d --- /dev/null +++ b/src/leap/soledad/client/examples/manifest.phk @@ -0,0 +1,50 @@ +The Hacker's Manifesto + +The Hacker's Manifesto +by: The Mentor + +Another one got caught today, it's all over the papers. "Teenager  +Arrested in Computer Crime Scandal", "Hacker Arrested after Bank  +Tampering." "Damn kids. They're all alike." But did you, in your  +three-piece psychology and 1950's technobrain, ever take a look behind  +the eyes of the hacker? Did you ever wonder what made him tick, what  +forces shaped him, what may have molded him? I am a hacker, enter my  +world. Mine is a world that begins with school. I'm smarter than most of  +the other kids, this crap they teach us bores me. "Damn underachiever.  +They're all alike." I'm in junior high or high school.  I've listened to  +teachers explain for the fifteenth time how to reduce a fraction. I  +understand it. "No, Ms. Smith, I didn't show my work. I did it in +my head." "Damn kid. Probably copied it. They're all alike." I made a  +discovery today. I found a computer. Wait a second, this is cool. It does  +what I want it to. If it makes a mistake, it's because I screwed it up.  +Not because it doesn't like me, or feels threatened by me, or thinks I'm  +a smart ass, or doesn't like teaching and shouldn't be here. Damn kid.  +All he does is play games. They're all alike. And then it happened... a  +door opened to a world... rushing through the phone line like heroin  +through an addict's veins, an electronic pulse is sent out, a refuge from  +the day-to-day incompetencies is sought... a board is found. "This is  +it... this is where I belong..." I know everyone here... 
even if I've  +never met them, never talked to them, may never hear from them again... I  +know you all... Damn kid. Tying up the phone line again. They're all  +alike... You bet your ass we're all alike... we've been spoon-fed baby  +food at school when we hungered for steak... the bits of meat that you  +did let slip through were pre-chewed and tasteless. We've been dominated  +by sadists, or ignored by the apathetic. The few that had something to  +teach found us willing pupils, but those few are like drops of water in  +the desert. This is our world now... the world of the electron and the  +switch, the beauty of the baud. We make use of a service already existing  +without paying for what could be dirt-cheap if it wasn't run by  +profiteering gluttons, and you call us criminals. We explore... and you  +call us criminals. We seek after knowledge... and you call us criminals.  +We exist without skin color, without nationality, without religious  +bias... and you call us criminals. You build atomic bombs, you wage wars,  +you murder, cheat, and lie to us and try to make us believe it's for our  +own good, yet we're the criminals. Yes, I am a criminal. My crime is that  +of curiosity. My crime is that of judging people by what they say and  +think, not what they look like. My crime is that of outsmarting you,  +something that you will never forgive me for. I am a hacker, and this is  +my manifesto.  You may stop this individual, but you can't stop us all...  +after all, we're all alike. + +This was the last published file written by The Mentor. Shortly after  +releasing it, he was busted by the FBI. The Mentor, sadly missed. diff --git a/src/leap/soledad/client/examples/plot-async-db.py b/src/leap/soledad/client/examples/plot-async-db.py new file mode 100644 index 00000000..018a1a1d --- /dev/null +++ b/src/leap/soledad/client/examples/plot-async-db.py @@ -0,0 +1,45 @@ +import csv +from matplotlib import pyplot as plt + +FILE = "bench.csv" + +# config the plot +plt.xlabel('number of inserts') +plt.ylabel('time (seconds)') +plt.title('SQLCipher parallelization') + +kwargs = { +    'linewidth': 1.0, +    'linestyle': '-', +} + +series = (('sync', 'r'), +          ('async', 'g')) + +data = {'mark': [], +        'sync': [], +        'async': []} + +with open(FILE, 'rb') as csvfile: +    series_reader = csv.reader(csvfile, delimiter=',') +    for m, s, a in series_reader: +        data['mark'].append(int(m)) +        data['sync'].append(float(s)) +        data['async'].append(float(a)) + +xmax = max(data['mark']) +xmin = min(data['mark']) +ymax = max(data['sync'] + data['async']) +ymin = min(data['sync'] + data['async']) + +for run in series: +    name = run[0] +    color = run[1] +    plt.plot(data['mark'], data[name], label=name, color=color, **kwargs) + +plt.axes().annotate("", xy=(xmax, ymax)) +plt.axes().annotate("", xy=(xmin, ymin)) + +plt.grid() +plt.legend() +plt.show() diff --git a/src/leap/soledad/client/examples/run_benchmark.py b/src/leap/soledad/client/examples/run_benchmark.py new file mode 100644 index 00000000..ddedf433 --- /dev/null +++ b/src/leap/soledad/client/examples/run_benchmark.py @@ -0,0 +1,30 @@ +""" +Run a mini-benchmark between regular api and dbapi +""" +import commands +import os +import time + +TMPDIR = os.environ.get("TMPDIR", "/tmp") +CSVFILE = 'bench.csv' + +cmd = "SILENT=1 TIMES={times} TMPDIR={tmpdir} python ./use_{version}api.py" + + +def parse_time(r): +    return r.split('\n')[-1] + + +with open(CSVFILE, 'w') as log: + +    for times in range(0, 10000, 
500): +        cmd1 = cmd.format(times=times, tmpdir=TMPDIR, version="") +        sync_time = parse_time(commands.getoutput(cmd1)) + +        cmd2 = cmd.format(times=times, tmpdir=TMPDIR, version="adb") +        async_time = parse_time(commands.getoutput(cmd2)) + +        print times, sync_time, async_time +        log.write("%s, %s, %s\n" % (times, sync_time, async_time)) +        log.flush() +        time.sleep(2) diff --git a/src/leap/soledad/client/examples/soledad_sync.py b/src/leap/soledad/client/examples/soledad_sync.py new file mode 100644 index 00000000..3aed10eb --- /dev/null +++ b/src/leap/soledad/client/examples/soledad_sync.py @@ -0,0 +1,63 @@ +from leap.bitmask.config.providerconfig import ProviderConfig +from leap.bitmask.crypto.srpauth import SRPAuth +from leap.soledad.client import Soledad +from twisted.internet import reactor +import logging +logging.basicConfig(level=logging.DEBUG) + + +# EDIT THIS -------------------------------------------- +user = u"USERNAME" +uuid = u"USERUUID" +_pass = u"USERPASS" +server_url = "https://soledad.server.example.org:2323" +# EDIT THIS -------------------------------------------- + +secrets_path = "/tmp/%s.secrets" % uuid +local_db_path = "/tmp/%s.soledad" % uuid +cert_file = "/tmp/cacert.pem" +provider_config = '/tmp/cdev.json' + + +provider = ProviderConfig() +provider.load(provider_config) + +soledad = None + + +def printStuff(r): +    print r + + +def printErr(err): +    logging.exception(err.value) + + +def init_soledad(_): +    token = srpauth.get_token() +    print "token", token + +    global soledad +    soledad = Soledad(uuid, _pass, secrets_path, local_db_path, +                      server_url, cert_file, +                      auth_token=token) + +    def getall(_): +        d = soledad.get_all_docs() +        return d + +    d1 = soledad.create_doc({"test": 42}) +    d1.addCallback(getall) +    d1.addCallbacks(printStuff, printErr) + +    d2 = soledad.sync() +    d2.addCallbacks(printStuff, printErr) +    d2.addBoth(lambda r: reactor.stop()) + + +srpauth = SRPAuth(provider) + +d = srpauth.authenticate(user, _pass) +d.addCallbacks(init_soledad, printErr) + +reactor.run() diff --git a/src/leap/soledad/client/examples/use_adbapi.py b/src/leap/soledad/client/examples/use_adbapi.py new file mode 100644 index 00000000..ddb1eaae --- /dev/null +++ b/src/leap/soledad/client/examples/use_adbapi.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# use_adbapi.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. +""" +Example of use of the asynchronous soledad api. 
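+
+Illustrative invocation (a sketch, assuming manifest.phk is available in the
+working directory, since the script reads it as the document payload):
+
+    SILENT=1 TIMES=1000 TMPDIR=/tmp python use_adbapi.py
+
+With SILENT set, only the total elapsed seconds are printed, which is the
+value run_benchmark.py collects into bench.csv.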
+""" +from __future__ import print_function +import datetime +import os + +from twisted.internet import defer, reactor + +from leap.soledad.client import adbapi +from leap.soledad.client._db.sqlcipher import SQLCipherOptions +from leap.soledad.common import l2db + + +folder = os.environ.get("TMPDIR", "tmp") +times = int(os.environ.get("TIMES", "1000")) +silent = os.environ.get("SILENT", False) + +tmpdb = os.path.join(folder, "test.soledad") + + +def debug(*args): +    if not silent: +        print(*args) + + +debug("[+] db path:", tmpdb) +debug("[+] times", times) + +if os.path.isfile(tmpdb): +    debug("[+] Removing existing db file...") +    os.remove(tmpdb) + +start_time = datetime.datetime.now() + +opts = SQLCipherOptions(tmpdb, "secret", create=True) +dbpool = adbapi.getConnectionPool(opts) + + +def createDoc(doc): +    return dbpool.runU1DBQuery("create_doc", doc) + + +def getAllDocs(): +    return dbpool.runU1DBQuery("get_all_docs") + + +def countDocs(_): +    debug("counting docs...") +    d = getAllDocs() +    d.addCallbacks(printResult, lambda e: e.printTraceback()) +    d.addBoth(allDone) + + +def printResult(r): +    if isinstance(r, l2db.Document): +        debug(r.doc_id, r.content['number']) +    else: +        len_results = len(r[1]) +        debug("GOT %s results" % len(r[1])) + +        if len_results == times: +            debug("ALL GOOD") +        else: +            raise ValueError("We didn't expect this result len") + + +def allDone(_): +    debug("ALL DONE!") +    if silent: +        end_time = datetime.datetime.now() +        print((end_time - start_time).total_seconds()) +    reactor.stop() + + +deferreds = [] +payload = open('manifest.phk').read() + +for i in range(times): +    doc = {"number": i, "payload": payload} +    d = createDoc(doc) +    d.addCallbacks(printResult, lambda e: e.printTraceback()) +    deferreds.append(d) + + +all_done = defer.gatherResults(deferreds, consumeErrors=True) +all_done.addCallback(countDocs) + +reactor.run() diff --git a/src/leap/soledad/client/examples/use_api.py b/src/leap/soledad/client/examples/use_api.py new file mode 100644 index 00000000..db77c4b3 --- /dev/null +++ b/src/leap/soledad/client/examples/use_api.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# use_api.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. +""" +Example of use of the soledad api. 
+""" +from __future__ import print_function +import datetime +import os + +from leap.soledad.client import sqlcipher +from leap.soledad.client.sqlcipher import SQLCipherOptions + + +folder = os.environ.get("TMPDIR", "tmp") +times = int(os.environ.get("TIMES", "1000")) +silent = os.environ.get("SILENT", False) + +tmpdb = os.path.join(folder, "test.soledad") + + +def debug(*args): +    if not silent: +        print(*args) + + +debug("[+] db path:", tmpdb) +debug("[+] times", times) + +if os.path.isfile(tmpdb): +    debug("[+] Removing existing db file...") +    os.remove(tmpdb) + +start_time = datetime.datetime.now() + +opts = SQLCipherOptions(tmpdb, "secret", create=True) +db = sqlcipher.SQLCipherDatabase(opts) + + +def allDone(): +    debug("ALL DONE!") + + +payload = open('manifest.phk').read() + +for i in range(times): +    doc = {"number": i, "payload": payload} +    d = db.create_doc(doc) +    debug(d.doc_id, d.content['number']) + +debug("Count", len(db.get_all_docs()[1])) +if silent: +    end_time = datetime.datetime.now() +    print((end_time - start_time).total_seconds()) + +allDone() diff --git a/src/leap/soledad/client/http_target/__init__.py b/src/leap/soledad/client/http_target/__init__.py new file mode 100644 index 00000000..b67d03f6 --- /dev/null +++ b/src/leap/soledad/client/http_target/__init__.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# __init__.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +""" +A U1DB backend for encrypting data before sending to server and decrypting +after receiving. +""" + + +import os + +from twisted.web.client import Agent +from twisted.internet import reactor + +from leap.common.certs import get_compatible_ssl_context_factory +from leap.soledad.common.log import getLogger +from leap.soledad.client.http_target.send import HTTPDocSender +from leap.soledad.client.http_target.api import SyncTargetAPI +from leap.soledad.client.http_target.fetch import HTTPDocFetcher +from leap.soledad.client import crypto as old_crypto + + +logger = getLogger(__name__) + + +# we may want to collect statistics from the sync process +DO_STATS = False +if os.environ.get('SOLEDAD_STATS'): +    DO_STATS = True + + +class SoledadHTTPSyncTarget(SyncTargetAPI, HTTPDocSender, HTTPDocFetcher): + +    """ +    A SyncTarget that encrypts data before sending and decrypts data after +    receiving. + +    Normally encryption will have been written to the sync database upon +    document modification. The sync database is also used to write temporarily +    the parsed documents that the remote send us, before being decrypted and +    written to the main database. +    """ +    def __init__(self, url, source_replica_uid, creds, crypto, cert_file): +        """ +        Initialize the sync target. + +        :param url: The server sync url. 
+        :type url: str +        :param source_replica_uid: The source replica uid which we use when +                                   deferring decryption. +        :type source_replica_uid: str +        :param creds: A dictionary containing the uuid and token. +        :type creds: creds +        :param crypto: An instance of SoledadCrypto so we can encrypt/decrypt +                        document contents when syncing. +        :type crypto: soledad._crypto.SoledadCrypto +        :param cert_file: Path to the certificate of the ca used to validate +                          the SSL certificate used by the remote soledad +                          server. +        :type cert_file: str +        """ +        if url.endswith("/"): +            url = url[:-1] +        self._url = str(url) + "/sync-from/" + str(source_replica_uid) +        self.source_replica_uid = source_replica_uid +        self._auth_header = None +        self._uuid = None +        self.set_creds(creds) +        self._crypto = crypto +        # TODO: DEPRECATED CRYPTO +        self._deprecated_crypto = old_crypto.SoledadCrypto(crypto.secret) +        self._insert_doc_cb = None + +        # Twisted default Agent with our own ssl context factory +        factory = get_compatible_ssl_context_factory(cert_file) +        self._http = Agent(reactor, factory) + +        if DO_STATS: +            self.sync_exchange_phase = [0] diff --git a/src/leap/soledad/client/http_target/api.py b/src/leap/soledad/client/http_target/api.py new file mode 100644 index 00000000..c68185c6 --- /dev/null +++ b/src/leap/soledad/client/http_target/api.py @@ -0,0 +1,248 @@ +# -*- coding: utf-8 -*- +# api.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +import os +import json +import base64 + +from six import StringIO +from uuid import uuid4 + +from twisted.internet import defer +from twisted.web.http_headers import Headers +from twisted.web.client import FileBodyProducer + +from leap.soledad.client.http_target.support import readBody +from leap.soledad.common.errors import InvalidAuthTokenError +from leap.soledad.common.l2db.errors import HTTPError +from leap.soledad.common.l2db import SyncTarget + + +# we may want to collect statistics from the sync process +DO_STATS = False +if os.environ.get('SOLEDAD_STATS'): +    DO_STATS = True + + +class SyncTargetAPI(SyncTarget): +    """ +    Declares public methods and implements u1db.SyncTarget. +    """ + +    @property +    def uuid(self): +        return self._uuid + +    def set_creds(self, creds): +        """ +        Update credentials. + +        :param creds: A dictionary containing the uuid and token. 
+        :type creds: dict +        """ +        uuid = creds['token']['uuid'] +        token = creds['token']['token'] +        self._uuid = uuid +        auth = '%s:%s' % (uuid, token) +        b64_token = base64.b64encode(auth) +        self._auth_header = {'Authorization': ['Token %s' % b64_token]} + +    @property +    def _base_header(self): +        return self._auth_header.copy() if self._auth_header else {} + +    def _http_request(self, url, method='GET', body=None, headers=None, +                      content_type=None, body_reader=readBody, +                      body_producer=None): +        headers = headers or self._base_header +        if content_type: +            headers.update({'content-type': [content_type]}) +        if not body_producer and body: +            body = FileBodyProducer(StringIO(body)) +        elif body_producer: +            # Upload case, check send.py +            body = body_producer(body) +        d = self._http.request( +            method, url, headers=Headers(headers), bodyProducer=body) +        d.addCallback(body_reader) +        d.addErrback(_unauth_to_invalid_token_error) +        return d + +    @defer.inlineCallbacks +    def get_sync_info(self, source_replica_uid): +        """ +        Return information about known state of remote database. + +        Return the replica_uid and the current database generation of the +        remote database, and its last-seen database generation for the client +        replica. + +        :param source_replica_uid: The client-size replica uid. +        :type source_replica_uid: str + +        :return: A deferred which fires with (target_replica_uid, +                 target_replica_generation, target_trans_id, +                 source_replica_last_known_generation, +                 source_replica_last_known_transaction_id) +        :rtype: twisted.internet.defer.Deferred +        """ +        raw = yield self._http_request(self._url) +        res = json.loads(raw) +        defer.returnValue(( +            res['target_replica_uid'], +            res['target_replica_generation'], +            res['target_replica_transaction_id'], +            res['source_replica_generation'], +            res['source_transaction_id'] +        )) + +    def record_sync_info( +            self, source_replica_uid, source_replica_generation, +            source_replica_transaction_id): +        """ +        Record tip information for another replica. + +        After sync_exchange has been processed, the caller will have +        received new content from this replica. This call allows the +        source replica instigating the sync to inform us what their +        generation became after applying the documents we returned. + +        This is used to allow future sync operations to not need to repeat data +        that we just talked about. It also means that if this is called at the +        wrong time, there can be database records that will never be +        synchronized. + +        :param source_replica_uid: The identifier for the source replica. +        :type source_replica_uid: str +        :param source_replica_generation: The database generation for the +                                          source replica. +        :type source_replica_generation: int +        :param source_replica_transaction_id: The transaction id associated +                                              with the source replica +                                              generation. 
+        :type source_replica_transaction_id: str + +        :return: A deferred which fires with the result of the query. +        :rtype: twisted.internet.defer.Deferred +        """ +        data = json.dumps({ +            'generation': source_replica_generation, +            'transaction_id': source_replica_transaction_id +        }) +        return self._http_request( +            self._url, +            method='PUT', +            body=data, +            content_type='application/json') + +    @defer.inlineCallbacks +    def sync_exchange(self, docs_by_generation, source_replica_uid, +                      last_known_generation, last_known_trans_id, +                      insert_doc_cb, ensure_callback=None, +                      sync_id=None): +        """ +        Find out which documents the remote database does not know about, +        encrypt and send them. After that, receive documents from the remote +        database. + +        :param docs_by_generations: A list of (doc_id, generation, trans_id) +                                    of local documents that were changed since +                                    the last local generation the remote +                                    replica knows about. +        :type docs_by_generations: list of tuples + +        :param source_replica_uid: The uid of the source replica. +        :type source_replica_uid: str + +        :param last_known_generation: Target's last known generation. +        :type last_known_generation: int + +        :param last_known_trans_id: Target's last known transaction id. +        :type last_known_trans_id: str + +        :param insert_doc_cb: A callback for inserting received documents from +                              target. If not overriden, this will call u1db +                              insert_doc_from_target in synchronizer, which +                              implements the TAKE OTHER semantics. +        :type insert_doc_cb: function + +        :param ensure_callback: A callback that ensures we know the target +                                replica uid if the target replica was just +                                created. +        :type ensure_callback: function + +        :return: A deferred which fires with the new generation and +                 transaction id of the target replica. 
+        :rtype: twisted.internet.defer.Deferred +        """ +        # ---------- phase 1: send docs to server ---------------------------- +        if DO_STATS: +            self.sync_exchange_phase[0] += 1 +        # -------------------------------------------------------------------- + +        self._ensure_callback = ensure_callback + +        if sync_id is None: +            sync_id = str(uuid4()) +        self.source_replica_uid = source_replica_uid + +        # save a reference to the callback so we can use it after decrypting +        self._insert_doc_cb = insert_doc_cb + +        gen_after_send, trans_id_after_send = yield self._send_docs( +            docs_by_generation, +            last_known_generation, +            last_known_trans_id, +            sync_id) + +        # ---------- phase 2: receive docs ----------------------------------- +        if DO_STATS: +            self.sync_exchange_phase[0] += 1 +        # -------------------------------------------------------------------- + +        cur_target_gen, cur_target_trans_id = yield self._receive_docs( +            last_known_generation, last_known_trans_id, +            ensure_callback, sync_id) + +        # update gen and trans id info in case we just sent and did not +        # receive docs. +        if gen_after_send is not None and gen_after_send > cur_target_gen: +            cur_target_gen = gen_after_send +            cur_target_trans_id = trans_id_after_send + +        # ---------- phase 3: sync exchange is over -------------------------- +        if DO_STATS: +            self.sync_exchange_phase[0] += 1 +        # -------------------------------------------------------------------- + +        defer.returnValue([cur_target_gen, cur_target_trans_id]) + + +def _unauth_to_invalid_token_error(failure): +    """ +    An errback to translate unauthorized errors to our own invalid token +    class. + +    :param failure: The original failure. +    :type failure: twisted.python.failure.Failure + +    :return: Either the original failure or an invalid auth token error. +    :rtype: twisted.python.failure.Failure +    """ +    failure.trap(HTTPError) +    if failure.value.status == 401: +        raise InvalidAuthTokenError +    return failure diff --git a/src/leap/soledad/client/http_target/fetch.py b/src/leap/soledad/client/http_target/fetch.py new file mode 100644 index 00000000..9d456830 --- /dev/null +++ b/src/leap/soledad/client/http_target/fetch.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# fetch.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. 
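+
+# Note: this module implements the receive half of the sync exchange driven
+# by SyncTargetAPI.sync_exchange (see api.py): phase 1 sends local changes
+# through HTTPDocSender._send_docs, and phase 2 receives remote changes
+# through HTTPDocFetcher._receive_docs defined below.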
+import json +from twisted.internet import defer +from twisted.internet import threads + +from leap.soledad.client.events import SOLEDAD_SYNC_RECEIVE_STATUS +from leap.soledad.client.events import emit_async +from leap.soledad.client.http_target.support import RequestBody +from leap.soledad.common.log import getLogger +from leap.soledad.client._crypto import is_symmetrically_encrypted +from leap.soledad.common.l2db import errors +from leap.soledad.client import crypto as old_crypto + +from .._document import Document +from . import fetch_protocol + +logger = getLogger(__name__) + + +class HTTPDocFetcher(object): +    """ +    Handles Document fetching from Soledad server, using HTTP as transport. +    Steps: +    * Prepares metadata by asking server for one document +    * Fetch the total on response and prepare to ask all remaining +    * (async) Documents will come encrypted. +              So we parse, decrypt and insert locally as they arrive. +    """ + +    # The uuid of the local replica. +    # Any class inheriting from this one should provide a meaningful attribute +    # if the sync status event is meant to be used somewhere else. + +    uuid = 'undefined' +    userid = 'undefined' + +    @defer.inlineCallbacks +    def _receive_docs(self, last_known_generation, last_known_trans_id, +                      ensure_callback, sync_id): +        new_generation = last_known_generation +        new_transaction_id = last_known_trans_id +        # Acts as a queue, ensuring line order on async processing +        # as `self._insert_doc_cb` cant be run concurrently or out of order. +        # DeferredSemaphore solves the concurrency and its implementation uses +        # a queue, solving the ordering. +        # FIXME: Find a proper solution to avoid surprises on Twisted changes +        self.semaphore = defer.DeferredSemaphore(1) + +        metadata = yield self._fetch_all( +            last_known_generation, last_known_trans_id, +            sync_id) +        number_of_changes, ngen, ntrans = self._parse_metadata(metadata) + +        # wait for pending inserts +        yield self.semaphore.acquire() + +        if ngen: +            new_generation = ngen +            new_transaction_id = ntrans + +        defer.returnValue([new_generation, new_transaction_id]) + +    def _fetch_all(self, last_known_generation, +                   last_known_trans_id, sync_id): +        # add remote replica metadata to the request +        body = RequestBody( +            last_known_generation=last_known_generation, +            last_known_trans_id=last_known_trans_id, +            sync_id=sync_id, +            ensure=self._ensure_callback is not None) +        self._received_docs = 0 +        # build a stream reader with _doc_parser as a callback +        body_reader = fetch_protocol.build_body_reader(self._doc_parser) +        # start download stream +        return self._http_request( +            self._url, +            method='POST', +            body=str(body), +            content_type='application/x-soledad-sync-get', +            body_reader=body_reader) + +    @defer.inlineCallbacks +    def _doc_parser(self, doc_info, content, total): +        """ +        Insert a received document into the local replica, decrypting +        if necessary. The case where it's not decrypted is when a doc gets +        inserted from Server side with a GPG encrypted content. + +        :param doc_info: Dictionary representing Document information. 
+        :type doc_info: dict +        :param content: The Document's content. +        :type idx: str +        :param total: The total number of operations. +        :type total: int +        """ +        yield self.semaphore.run(self.__atomic_doc_parse, doc_info, content, +                                 total) + +    @defer.inlineCallbacks +    def __atomic_doc_parse(self, doc_info, content, total): +        doc = Document(doc_info['id'], doc_info['rev'], content) +        if is_symmetrically_encrypted(content): +            content = (yield self._crypto.decrypt_doc(doc)).getvalue() +        elif old_crypto.is_symmetrically_encrypted(doc): +            content = self._deprecated_crypto.decrypt_doc(doc) +        doc.set_json(content) + +        # TODO insert blobs here on the blob backend +        # FIXME: This is wrong. Using the very same SQLite connection object +        # from multiple threads is dangerous. We should bring the dbpool here +        # or find an alternative.  Deferring to a thread only helps releasing +        # the reactor for other tasks as this is an IO intensive call. +        yield threads.deferToThread(self._insert_doc_cb, +                                    doc, doc_info['gen'], doc_info['trans_id']) +        self._received_docs += 1 +        user_data = {'uuid': self.uuid, 'userid': self.userid} +        _emit_receive_status(user_data, self._received_docs, total=total) + +    def _parse_metadata(self, metadata): +        """ +        Parse the response from the server containing the sync metadata. + +        :param response: Metadata as string +        :type response: str + +        :return: (number_of_changes, new_gen, new_trans_id) +        :rtype: tuple +        """ +        try: +            metadata = json.loads(metadata) +            # make sure we have replica_uid from fresh new dbs +            if self._ensure_callback and 'replica_uid' in metadata: +                self._ensure_callback(metadata['replica_uid']) +            return (metadata['number_of_changes'], metadata['new_generation'], +                    metadata['new_transaction_id']) +        except (ValueError, KeyError): +            raise errors.BrokenSyncStream('Metadata parsing failed') + + +def _emit_receive_status(user_data, received_docs, total): +    content = {'received': received_docs, 'total': total} +    emit_async(SOLEDAD_SYNC_RECEIVE_STATUS, user_data, content) + +    if received_docs % 20 == 0: +        msg = "%d/%d" % (received_docs, total) +        logger.debug("Sync receive status: %s" % msg) diff --git a/src/leap/soledad/client/http_target/fetch_protocol.py b/src/leap/soledad/client/http_target/fetch_protocol.py new file mode 100644 index 00000000..851eb3a1 --- /dev/null +++ b/src/leap/soledad/client/http_target/fetch_protocol.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- +# fetch_protocol.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. 
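+
+# Illustrative sync stream handled by DocStreamReceiver below. The shape is
+# the one described in its docstring; the concrete field values here are
+# invented for the example:
+#
+#   [\r\n
+#   {"number_of_changes": 2, "new_generation": 9, "new_transaction_id": "T-9"},\r\n
+#   {"id": "doc-1", "rev": "1-ab", "gen": 8, "trans_id": "T-8"},\r\n
+#   {"some": "content"},\r\n
+#   {"id": "doc-2", "rev": "1-cd", "gen": 9, "trans_id": "T-9"},\r\n
+#   {"some": "content"},\r\n
+#   ]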
+import json +from functools import partial +from six import StringIO +from twisted.web._newclient import ResponseDone +from leap.soledad.common.l2db import errors +from leap.soledad.common.l2db.remote import utils +from leap.soledad.common.log import getLogger +from .support import ReadBodyProtocol +from .support import readBody + +logger = getLogger(__name__) + + +class DocStreamReceiver(ReadBodyProtocol): +    """ +    A protocol implementation that can parse incoming data from server based +    on a line format specified on u1db implementation. Except that we split doc +    attributes from content to ease parsing and increment throughput for larger +    documents. +    [\r\n +    {metadata},\r\n +    {doc_info},\r\n +    {content},\r\n +    ... +    {doc_info},\r\n +    {content},\r\n +    ] +    """ + +    def __init__(self, response, deferred, doc_reader): +        self.deferred = deferred +        self.status = response.code if response else None +        self.message = response.phrase if response else None +        self.headers = response.headers if response else {} +        self.delimiter = '\r\n' +        self.metadata = '' +        self._doc_reader = doc_reader +        self.reset() + +    def reset(self): +        self._line = 0 +        self._buffer = StringIO() +        self._properly_finished = False + +    def connectionLost(self, reason): +        """ +        Deliver the accumulated response bytes to the waiting L{Deferred}, if +        the response body has been completely received without error. +        """ +        if self.deferred.called: +            return +        try: +            if reason.check(ResponseDone): +                self.dataBuffer = self.metadata +            else: +                self.dataBuffer = self.finish() +        except errors.BrokenSyncStream as e: +            return self.deferred.errback(e) +        return ReadBodyProtocol.connectionLost(self, reason) + +    def consumeBufferLines(self): +        """ +        Consumes lines from buffer and rewind it, writing remaining data +        that didn't formed a line back into buffer. +        """ +        content = self._buffer.getvalue()[0:self._buffer.tell()] +        self._buffer.seek(0) +        lines = content.split(self.delimiter) +        self._buffer.write(lines.pop(-1)) +        return lines + +    def dataReceived(self, data): +        """ +        Buffer incoming data until a line breaks comes in. We check only +        the incoming data for efficiency. +        """ +        self._buffer.write(data) +        if '\n' not in data: +            return +        lines = self.consumeBufferLines() +        while lines: +            line, _ = utils.check_and_strip_comma(lines.pop(0)) +            self.lineReceived(line) +            self._line += 1 + +    def lineReceived(self, line): +        """ +        Protocol implementation. 
+        0:      [\r\n +        1:      {metadata},\r\n +        (even): {doc_info},\r\n +        (odd):  {data},\r\n +        (last): ] +        """ +        if self._properly_finished: +            raise errors.BrokenSyncStream("Reading a finished stream") +        if ']' == line: +            self._properly_finished = True +        elif self._line == 0: +            if line is not '[': +                raise errors.BrokenSyncStream("Invalid start") +        elif self._line == 1: +            self.metadata = line +            if 'error' in self.metadata: +                raise errors.BrokenSyncStream("Error from server: %s" % line) +            self.total = json.loads(line).get('number_of_changes', -1) +        elif (self._line % 2) == 0: +            self.current_doc = json.loads(line) +            if 'error' in self.current_doc: +                raise errors.BrokenSyncStream("Error from server: %s" % line) +        else: +            d = self._doc_reader( +                self.current_doc, line.strip() or None, self.total) +            d.addErrback(self.deferred.errback) + +    def finish(self): +        """ +        Checks that ']' came and stream was properly closed. +        """ +        if not self._properly_finished: +            raise errors.BrokenSyncStream('Stream not properly closed') +        content = self._buffer.getvalue()[0:self._buffer.tell()] +        self._buffer.close() +        return content + + +def build_body_reader(doc_reader): +    """ +    Get the documents from a sync stream and call doc_reader on each +    doc received. + +    @param doc_reader: Function to be called for processing an incoming doc. +        Will be called with doc metadata (dict parsed from 1st line) and doc +        content (string) +    @type doc_reader: function + +    @return: A function that can be called by the http Agent to create and +    configure the proper protocol. +    """ +    protocolClass = partial(DocStreamReceiver, doc_reader=doc_reader) +    return partial(readBody, protocolClass=protocolClass) diff --git a/src/leap/soledad/client/http_target/send.py b/src/leap/soledad/client/http_target/send.py new file mode 100644 index 00000000..2b286ec5 --- /dev/null +++ b/src/leap/soledad/client/http_target/send.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +# send.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +import json + +from twisted.internet import defer + +from leap.soledad.common.log import getLogger +from leap.soledad.client.events import emit_async +from leap.soledad.client.events import SOLEDAD_SYNC_SEND_STATUS +from leap.soledad.client.http_target.support import RequestBody +from .send_protocol import DocStreamProducer + +logger = getLogger(__name__) + + +class HTTPDocSender(object): +    """ +    Handles Document uploading from Soledad server, using HTTP as transport. +    They need to be encrypted and metadata prepared before sending. 
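+
+    Rough flow of the code below: _send_docs builds a RequestBody carrying
+    the sync metadata, _send_batch queues one _prepare_one_doc call per
+    document, and DocStreamProducer runs those calls (encrypting each
+    document) while streaming the resulting entries to the server.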
+    """ + +    # The uuid of the local replica. +    # Any class inheriting from this one should provide a meaningful attribute +    # if the sync status event is meant to be used somewhere else. + +    uuid = 'undefined' +    userid = 'undefined' + +    @defer.inlineCallbacks +    def _send_docs(self, docs_by_generation, last_known_generation, +                   last_known_trans_id, sync_id): + +        if not docs_by_generation: +            defer.returnValue([None, None]) + +        # add remote replica metadata to the request +        body = RequestBody( +            last_known_generation=last_known_generation, +            last_known_trans_id=last_known_trans_id, +            sync_id=sync_id, +            ensure=self._ensure_callback is not None) +        result = yield self._send_batch(body, docs_by_generation) +        response_dict = json.loads(result)[0] +        gen_after_send = response_dict['new_generation'] +        trans_id_after_send = response_dict['new_transaction_id'] +        defer.returnValue([gen_after_send, trans_id_after_send]) + +    @defer.inlineCallbacks +    def _send_batch(self, body, docs): +        total, calls = len(docs), [] +        for i, entry in enumerate(docs): +            calls.append((self._prepare_one_doc, +                         entry, body, i + 1, total)) +        result = yield self._send_request(body, calls) +        _emit_send_status(self.uuid, body.consumed, total) + +        defer.returnValue(result) + +    def _send_request(self, body, calls): +        return self._http_request( +            self._url, +            method='POST', +            body=(body, calls), +            content_type='application/x-soledad-sync-put', +            body_producer=DocStreamProducer) + +    @defer.inlineCallbacks +    def _prepare_one_doc(self, entry, body, idx, total): +        get_doc_call, gen, trans_id = entry +        doc, content = yield self._encrypt_doc(get_doc_call) +        body.insert_info( +            id=doc.doc_id, rev=doc.rev, content=content, gen=gen, +            trans_id=trans_id, number_of_docs=total, +            doc_idx=idx) +        _emit_send_status(self.uuid, body.consumed, total) + +    @defer.inlineCallbacks +    def _encrypt_doc(self, get_doc_call): +        f, args, kwargs = get_doc_call +        doc = yield f(*args, **kwargs) +        if doc.is_tombstone(): +            defer.returnValue((doc, None)) +        else: +            content = yield self._crypto.encrypt_doc(doc) +            defer.returnValue((doc, content)) + + +def _emit_send_status(user_data, idx, total): +    content = {'sent': idx, 'total': total} +    emit_async(SOLEDAD_SYNC_SEND_STATUS, user_data, content) + +    msg = "%d/%d" % (idx, total) +    logger.debug("Sync send status: %s" % msg) diff --git a/src/leap/soledad/client/http_target/send_protocol.py b/src/leap/soledad/client/http_target/send_protocol.py new file mode 100644 index 00000000..4941aa34 --- /dev/null +++ b/src/leap/soledad/client/http_target/send_protocol.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# send_protocol.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +from zope.interface import implementer +from twisted.internet import defer +from twisted.internet import reactor +from twisted.web.iweb import IBodyProducer +from twisted.web.iweb import UNKNOWN_LENGTH + + +@implementer(IBodyProducer) +class DocStreamProducer(object): +    """ +    A producer that writes the body of a request to a consumer. +    """ + +    def __init__(self, producer): +        """ +        Initialize the string produer. + +        :param producer: A RequestBody instance and a list of producer calls +        :type producer: (.support.RequestBody, [(function, *args)]) +        """ +        self.body, self.producer = producer +        self.length = UNKNOWN_LENGTH +        self.pause = False +        self.stop = False + +    @defer.inlineCallbacks +    def startProducing(self, consumer): +        """ +        Write the body to the consumer. + +        :param consumer: Any IConsumer provider. +        :type consumer: twisted.internet.interfaces.IConsumer + +        :return: A Deferred that fires when production ends. +        :rtype: twisted.internet.defer.Deferred +        """ +        while self.producer and not self.stop: +            if self.pause: +                yield self.sleep(0.001) +                continue +            call = self.producer.pop(0) +            fun, args = call[0], call[1:] +            yield fun(*args) +            consumer.write(self.body.pop(1, leave_open=True)) +        consumer.write(self.body.pop(0))  # close stream + +    def sleep(self, secs): +        d = defer.Deferred() +        reactor.callLater(secs, d.callback, None) +        return d + +    def pauseProducing(self): +        self.pause = True + +    def stopProducing(self): +        self.stop = True + +    def resumeProducing(self): +        self.pause = False diff --git a/src/leap/soledad/client/http_target/support.py b/src/leap/soledad/client/http_target/support.py new file mode 100644 index 00000000..d8d8e420 --- /dev/null +++ b/src/leap/soledad/client/http_target/support.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +# support.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +import warnings +import json + +from twisted.internet import defer +from twisted.web.client import _ReadBodyProtocol +from twisted.web.client import PartialDownloadError +from twisted.web._newclient import ResponseDone +from twisted.web._newclient import PotentialDataLoss + +from leap.soledad.common.l2db import errors +from leap.soledad.common.l2db.remote import http_errors + +# we want to make sure that HTTP errors will raise appropriate u1db errors, +# that is, fire errbacks with the appropriate failures, in the context of +# twisted. 
Because of that, we redefine the http body reader used by the HTTP +# client below. + + +class ReadBodyProtocol(_ReadBodyProtocol): +    """ +    From original Twisted implementation, focused on adding our error +    handling and ensuring that the proper u1db error is raised. +    """ + +    def __init__(self, response, deferred): +        """ +        Initialize the protocol, additionally storing the response headers. +        """ +        _ReadBodyProtocol.__init__( +            self, response.code, response.phrase, deferred) +        self.headers = response.headers + +    # ---8<--- snippet from u1db.remote.http_client, modified to use errbacks +    def _error(self, respdic): +        descr = respdic.get("error") +        exc_cls = errors.wire_description_to_exc.get(descr) +        if exc_cls is not None: +            message = respdic.get("message") +            self.deferred.errback(exc_cls(message)) +        else: +            self.deferred.errback( +                errors.HTTPError(self.status, respdic, self.headers)) +    # ---8<--- end of snippet from u1db.remote.http_client + +    def connectionLost(self, reason): +        """ +        Deliver the accumulated response bytes to the waiting L{Deferred}, if +        the response body has been completely received without error. +        """ +        if reason.check(ResponseDone): + +            body = b''.join(self.dataBuffer) + +            # ---8<--- snippet from u1db.remote.http_client +            if self.status in (200, 201): +                self.deferred.callback(body) +            elif self.status in http_errors.ERROR_STATUSES: +                try: +                    respdic = json.loads(body) +                except ValueError: +                    self.deferred.errback( +                        errors.HTTPError(self.status, body, self.headers)) +                else: +                    self._error(respdic) +            # special cases +            elif self.status == 503: +                self.deferred.errback(errors.Unavailable(body, self.headers)) +            else: +                self.deferred.errback( +                    errors.HTTPError(self.status, body, self.headers)) +            # ---8<--- end of snippet from u1db.remote.http_client + +        elif reason.check(PotentialDataLoss): +            self.deferred.errback( +                PartialDownloadError(self.status, self.message, +                                     b''.join(self.dataBuffer))) +        else: +            self.deferred.errback(reason) + + +def readBody(response, protocolClass=ReadBodyProtocol): +    """ +    Get the body of an L{IResponse} and return it as a byte string. + +    This is a helper function for clients that don't want to incrementally +    receive the body of an HTTP response. + +    @param response: The HTTP response for which the body will be read. +    @type response: L{IResponse} provider + +    @return: A L{Deferred} which will fire with the body of the response. +        Cancelling it will close the connection to the server immediately. +    """ +    def cancel(deferred): +        """ +        Cancel a L{readBody} call, close the connection to the HTTP server +        immediately, if it is still open. + +        @param deferred: The cancelled L{defer.Deferred}. 
+        """
+        abort = getAbort()
+        if abort is not None:
+            abort()
+
+    d = defer.Deferred(cancel)
+    protocol = protocolClass(response, d)
+
+    def getAbort():
+        return getattr(protocol.transport, 'abortConnection', None)
+
+    response.deliverBody(protocol)
+
+    if protocol.transport is not None and getAbort() is None:
+        warnings.warn(
+            'Using readBody with a transport that does not have an '
+            'abortConnection method',
+            category=DeprecationWarning,
+            stacklevel=2)
+
+    return d
+
+
+class RequestBody(object):
+    """
+    This class is a helper to generate send and fetch requests.
+    The expected format is something like:
+    [
+    {headers},
+    {entry1},
+    {...},
+    {entryN},
+    ]
+    """
+
+    def __init__(self, **header_dict):
+        """
+        Creates a new RequestBody holding header information.
+
+        :param header_dict: A dictionary with the headers.
+        :type header_dict: dict
+        """
+        self.headers = header_dict
+        self.entries = []
+        self.consumed = 0
+
+    def insert_info(self, **entry_dict):
+        """
+        Dumps an entry into JSON format and adds it to the entries list.
+        Adds the 'content' key on a new line if it is present.
+
+        :param entry_dict: Entry as a dictionary
+        :type entry_dict: dict
+        """
+        content = ''
+        if 'content' in entry_dict:
+            content = ',\r\n' + (entry_dict['content'] or '')
+        entry = json.dumps(entry_dict) + content
+        self.entries.append(entry)
+
+    def pop(self, amount=10, leave_open=False):
+        """
+        Removes entries and returns them formatted and ready
+        to be sent.
+
+        :param amount: number of entries to pop and format
+        :type amount: int
+
+        :param leave_open: flag to skip stream closing
+        :type leave_open: bool
+
+        :return: formatted body ready to be sent
+        :rtype: str
+        """
+        start = self.consumed == 0
+        amount = min([len(self.entries), amount])
+        entries = [self.entries.pop(0) for i in xrange(amount)]
+        self.consumed += amount
+        end = len(self.entries) == 0 if not leave_open else False
+        return self.entries_to_str(entries, start, end)
+
+    def __str__(self):
+        return self.pop(len(self.entries))
+
+    def __len__(self):
+        return len(self.entries)
+
+    def entries_to_str(self, entries=None, start=True, end=True):
+        """
+        Format a list of entries into the body format expected
+        by the server.
+
+        :param entries: entries to format
+        :type entries: list
+
+        :return: formatted body ready to be sent
+        :rtype: str
+        """
+        data = ''
+        if start:
+            data = '[\r\n' + json.dumps(self.headers)
+        data += ''.join(',\r\n' + entry for entry in entries)
+        if end:
+            data += '\r\n]'
+        return data
diff --git a/src/leap/soledad/client/interfaces.py b/src/leap/soledad/client/interfaces.py
new file mode 100644
index 00000000..0600449f
--- /dev/null
+++ b/src/leap/soledad/client/interfaces.py
@@ -0,0 +1,368 @@
+# -*- coding: utf-8 -*-
+# interfaces.py
+# Copyright (C) 2014 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. +""" +Interfaces used by the Soledad Client. +""" +from zope.interface import Interface, Attribute + +# +# Plugins +# + + +class ISoledadPostSyncPlugin(Interface): +    """ +    I implement the minimal methods and attributes for a plugin that can be +    called after a soledad synchronization has ended. +    """ + +    def process_received_docs(self, doc_id_list): +        """ +        Do something with the passed list of doc_ids received after the last +        sync. + +        :param doc_id_list: a list of strings for the received doc_ids +        """ + +    watched_doc_types = Attribute(""" +        a tuple of the watched doc types for this plugin. So far, the +        `doc-types` convention is just the preffix of the doc_id, which is +        basically its first character, followed by a dash. So, for instance, +        `M-` is used for meta-docs in mail, and `F-` is used for flag-docs in +        mail. For now there's no central register of all the doc-types +        used.""") + + +# +# Soledad storage +# + +class ILocalStorage(Interface): +    """ +    I implement core methods for the u1db local storage of documents and +    indexes. +    """ +    local_db_path = Attribute( +        "The path for the local database replica") +    local_db_file_name = Attribute( +        "The name of the local SQLCipher U1DB database file") +    uuid = Attribute("The user uuid") +    default_prefix = Attribute( +        "Prefix for default values for path") + +    def put_doc(self, doc): +        """ +        Update a document in the local encrypted database. + +        :param doc: the document to update +        :type doc: Document + +        :return: +            a deferred that will fire with the new revision identifier for +            the document +        :rtype: Deferred +        """ + +    def delete_doc(self, doc): +        """ +        Delete a document from the local encrypted database. + +        :param doc: the document to delete +        :type doc: Document + +        :return: +            a deferred that will fire with ... +        :rtype: Deferred +        """ + +    def get_doc(self, doc_id, include_deleted=False): +        """ +        Retrieve a document from the local encrypted database. + +        :param doc_id: the unique document identifier +        :type doc_id: str +        :param include_deleted: +            if True, deleted documents will be returned with empty content; +            otherwise asking for a deleted document will return None +        :type include_deleted: bool + +        :return: +            A deferred that will fire with the document object, containing a +            Document, or None if it could not be found +        :rtype: Deferred +        """ + +    def get_docs(self, doc_ids, check_for_conflicts=True, +                 include_deleted=False): +        """ +        Get the content for many documents. 
+ +        :param doc_ids: a list of document identifiers +        :type doc_ids: list +        :param check_for_conflicts: if set False, then the conflict check will +            be skipped, and 'None' will be returned instead of True/False +        :type check_for_conflicts: bool + +        :return: +            A deferred that will fire with an iterable giving the Document +            object for each document id in matching doc_ids order. +        :rtype: Deferred +        """ + +    def get_all_docs(self, include_deleted=False): +        """ +        Get the JSON content for all documents in the database. + +        :param include_deleted: If set to True, deleted documents will be +                                returned with empty content. Otherwise deleted +                                documents will not be included in the results. +        :return: +            A deferred that will fire with (generation, [Document]): that is, +            the current generation of the database, followed by a list of all +            the documents in the database. +        :rtype: Deferred +        """ + +    def create_doc(self, content, doc_id=None): +        """ +        Create a new document in the local encrypted database. + +        :param content: the contents of the new document +        :type content: dict +        :param doc_id: an optional identifier specifying the document id +        :type doc_id: str + +        :return: +            A deferred tht will fire with the new document (Document +            instance). +        :rtype: Deferred +        """ + +    def create_doc_from_json(self, json, doc_id=None): +        """ +        Create a new document. + +        You can optionally specify the document identifier, but the document +        must not already exist. See 'put_doc' if you want to override an +        existing document. +        If the database specifies a maximum document size and the document +        exceeds it, create will fail and raise a DocumentTooBig exception. + +        :param json: The JSON document string +        :type json: str +        :param doc_id: An optional identifier specifying the document id. +        :type doc_id: +        :return: +            A deferred that will fire with the new document (A Document +            instance) +        :rtype: Deferred +        """ + +    def create_index(self, index_name, *index_expressions): +        """ +        Create an named index, which can then be queried for future lookups. +        Creating an index which already exists is not an error, and is cheap. +        Creating an index which does not match the index_expressions of the +        existing index is an error. +        Creating an index will block until the expressions have been evaluated +        and the index generated. + +        :param index_name: A unique name which can be used as a key prefix +        :type index_name: str +        :param index_expressions: +            index expressions defining the index information. +        :type index_expressions: dict + +            Examples: + +            "fieldname", or "fieldname.subfieldname" to index alphabetically +            sorted on the contents of a field. + +            "number(fieldname, width)", "lower(fieldname)" +        """ + +    def delete_index(self, index_name): +        """ +        Remove a named index. 
+ +        :param index_name: The name of the index we are removing +        :type index_name: str +        """ + +    def list_indexes(self): +        """ +        List the definitions of all known indexes. + +        :return: A list of [('index-name', ['field', 'field2'])] definitions. +        :rtype: Deferred +        """ + +    def get_from_index(self, index_name, *key_values): +        """ +        Return documents that match the keys supplied. + +        You must supply exactly the same number of values as have been defined +        in the index. It is possible to do a prefix match by using '*' to +        indicate a wildcard match. You can only supply '*' to trailing entries, +        (eg 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) +        It is also possible to append a '*' to the last supplied value (eg +        'val*', '*', '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + +        :param index_name: The index to query +        :type index_name: str +        :param key_values: values to match. eg, if you have +                           an index with 3 fields then you would have: +                           get_from_index(index_name, val1, val2, val3) +        :type key_values: tuple +        :return: List of [Document] +        :rtype: list +        """ + +    def get_count_from_index(self, index_name, *key_values): +        """ +        Return the count of the documents that match the keys and +        values supplied. + +        :param index_name: The index to query +        :type index_name: str +        :param key_values: values to match. eg, if you have +                           an index with 3 fields then you would have: +                           get_from_index(index_name, val1, val2, val3) +        :type key_values: tuple +        :return: count. +        :rtype: int +        """ + +    def get_range_from_index(self, index_name, start_value, end_value): +        """ +        Return documents that fall within the specified range. + +        Both ends of the range are inclusive. For both start_value and +        end_value, one must supply exactly the same number of values as have +        been defined in the index, or pass None. In case of a single column +        index, a string is accepted as an alternative for a tuple with a single +        value. It is possible to do a prefix match by using '*' to indicate +        a wildcard match. You can only supply '*' to trailing entries, (eg +        'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) It is also +        possible to append a '*' to the last supplied value (eg 'val*', '*', +        '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + +        :param index_name: The index to query +        :type index_name: str +        :param start_values: tuples of values that define the lower bound of +            the range. eg, if you have an index with 3 fields then you would +            have: (val1, val2, val3) +        :type start_values: tuple +        :param end_values: tuples of values that define the upper bound of the +            range. eg, if you have an index with 3 fields then you would have: +            (val1, val2, val3) +        :type end_values: tuple +        :return: A deferred that will fire with a list of [Document] +        :rtype: Deferred +        """ + +    def get_index_keys(self, index_name): +        """ +        Return all keys under which documents are indexed in this index. 
+ +        :param index_name: The index to query +        :type index_name: str +        :return: +            A deferred that will fire with a list of tuples of indexed keys. +        :rtype: Deferred +        """ + +    def get_doc_conflicts(self, doc_id): +        """ +        Get the list of conflicts for the given document. + +        :param doc_id: the document id +        :type doc_id: str + +        :return: +            A deferred that will fire with a list of the document entries that +            are conflicted. +        :rtype: Deferred +        """ + +    def resolve_doc(self, doc, conflicted_doc_revs): +        """ +        Mark a document as no longer conflicted. + +        :param doc: a document with the new content to be inserted. +        :type doc: Document +        :param conflicted_doc_revs: +            A deferred that will fire with a list of revisions that the new +            content supersedes. +        :type conflicted_doc_revs: list +        """ + + +class ISyncableStorage(Interface): +    """ +    I implement methods to synchronize with a remote replica. +    """ +    replica_uid = Attribute("The uid of the local replica") +    syncing = Attribute( +        "Property, True if the syncer is syncing.") +    token = Attribute("The authentication Token.") + +    def sync(self): +        """ +        Synchronize the local encrypted replica with a remote replica. + +        This method blocks until a syncing lock is acquired, so there are no +        attempts of concurrent syncs from the same client replica. + +        :param url: the url of the target replica to sync with +        :type url: str + +        :return: +            A deferred that will fire with the local generation before the +            synchronisation was performed. +        :rtype: str +        """ + +    def stop_sync(self): +        """ +        Stop the current syncing process. +        """ + + +class ISecretsStorage(Interface): +    """ +    I implement methods needed for initializing and accessing secrets, that are +    synced against the Shared Recovery Database. +    """ +    secrets_file_name = Attribute( +        "The name of the file where the storage secrets will be stored") + +    # XXX this used internally from secrets, so it might be good to preserve +    # as a public boundary with other components. + +    # We should also probably document its interface. +    secrets = Attribute("A SoledadSecrets object containing access to secrets") + +    def change_passphrase(self, new_passphrase): +        """ +        Change the passphrase that encrypts the storage secret. + +        :param new_passphrase: The new passphrase. +        :type new_passphrase: unicode + +        :raise NoStorageSecret: Raised if there's no storage secret available. +        """ diff --git a/src/leap/soledad/client/shared_db.py b/src/leap/soledad/client/shared_db.py new file mode 100644 index 00000000..4f70c74b --- /dev/null +++ b/src/leap/soledad/client/shared_db.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +# shared_db.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +A shared database for storing/retrieving encrypted key material. +""" +from leap.soledad.common.l2db.remote.http_database import HTTPDatabase + +from leap.soledad.client.auth import TokenBasedAuth + + +# ---------------------------------------------------------------------------- +# Soledad shared database +# ---------------------------------------------------------------------------- + +# TODO could have a hierarchy of soledad exceptions. + + +class NoTokenForAuth(Exception): +    """ +    No token was found for token-based authentication. +    """ + + +class Unauthorized(Exception): +    """ +    User does not have authorization to perform task. +    """ + + +class ImproperlyConfiguredError(Exception): +    """ +    Wrong parameters in the database configuration. +    """ + + +class SoledadSharedDatabase(HTTPDatabase, TokenBasedAuth): +    """ +    This is a shared recovery database that enables users to store their +    encryption secrets in the server and retrieve them afterwards. +    """ +    # TODO: prevent client from messing with the shared DB. +    # TODO: define and document API. + +    # +    # Token auth methods. +    # + +    def set_token_credentials(self, uuid, token): +        """ +        Store given credentials so we can sign the request later. + +        :param uuid: The user's uuid. +        :type uuid: str +        :param token: The authentication token. +        :type token: str +        """ +        TokenBasedAuth.set_token_credentials(self, uuid, token) + +    def _sign_request(self, method, url_query, params): +        """ +        Return an authorization header to be included in the HTTP request. + +        :param method: The HTTP method. +        :type method: str +        :param url_query: The URL query string. +        :type url_query: str +        :param params: A list with encoded query parameters. +        :type param: list + +        :return: The Authorization header. +        :rtype: list of tuple +        """ +        return TokenBasedAuth._sign_request(self, method, url_query, params) + +    # +    # Modified HTTPDatabase methods. +    # + +    @staticmethod +    def open_database(url, creds=None): +        """ +        Open a Soledad shared database. + +        :param url: URL of the remote database. +        :type url: str +        :param creds: A tuple containing the authentication method and +            credentials. +        :type creds: tuple + +        :return: The shared database in the given url. +        :rtype: SoledadSharedDatabase +        """ +        db = SoledadSharedDatabase(url, creds=creds) +        return db + +    @staticmethod +    def delete_database(url): +        """ +        Dummy method that prevents from deleting shared database. + +        :raise: This will always raise an Unauthorized exception. + +        :param url: The database URL. +        :type url: str +        """ +        raise Unauthorized("Can't delete shared database.") + +    def __init__(self, url, document_factory=None, creds=None): +        """ +        Initialize database with auth token and encryption powers. + +        :param url: URL of the remote database. +        :type url: str +        :param document_factory: A factory for U1BD documents. 
+        :type document_factory: u1db.Document +        :param creds: A tuple containing the authentication method and +            credentials. +        :type creds: tuple +        """ +        HTTPDatabase.__init__(self, url, document_factory, creds) diff --git a/src/leap/soledad/client/sync.py b/src/leap/soledad/client/sync.py new file mode 100644 index 00000000..2a927189 --- /dev/null +++ b/src/leap/soledad/client/sync.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +# sync.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Soledad synchronization utilities. +""" +import os + +from twisted.internet import defer + +from leap.soledad.common.log import getLogger +from leap.soledad.common.l2db import errors +from leap.soledad.common.l2db.sync import Synchronizer +from leap.soledad.common.errors import BackendNotReadyError + + +logger = getLogger(__name__) + + +# we may want to collect statistics from the sync process +DO_STATS = False +if os.environ.get('SOLEDAD_STATS'): +    DO_STATS = True + + +class SoledadSynchronizer(Synchronizer): +    """ +    Collect the state around synchronizing 2 U1DB replicas. + +    Synchronization is bi-directional, in that new items in the source are sent +    to the target, and new items in the target are returned to the source. +    However, it still recognizes that one side is initiating the request. Also, +    at the moment, conflicts are only created in the source. + +    Also modified to allow for interrupting the synchronization process. +    """ +    received_docs = [] + +    def __init__(self, *args, **kwargs): +        Synchronizer.__init__(self, *args, **kwargs) +        if DO_STATS: +            self.sync_phase = [0] +            self.sync_exchange_phase = None + +    @defer.inlineCallbacks +    def sync(self): +        """ +        Synchronize documents between source and target. + +        :return: A deferred which will fire after the sync has finished with +                 the local generation before the synchronization was performed. 
+        :rtype: twisted.internet.defer.Deferred +        """ + +        sync_target = self.sync_target +        self.received_docs = [] + +        # ---------- phase 1: get sync info from server ---------------------- +        if DO_STATS: +            self.sync_phase[0] += 1 +            self.sync_exchange_phase = self.sync_target.sync_exchange_phase +        # -------------------------------------------------------------------- + +        # get target identifier, its current generation, +        # and its last-seen database generation for this source +        ensure_callback = None +        try: +            (self.target_replica_uid, target_gen, target_trans_id, +             target_my_gen, target_my_trans_id) = yield \ +                sync_target.get_sync_info(self.source._replica_uid) +        except (errors.DatabaseDoesNotExist, BackendNotReadyError) as e: +            logger.warn("Database isn't ready on server. Will be created.") +            logger.warn("Reason: %s" % e.__class__) +            self.target_replica_uid = None +            target_gen, target_trans_id = 0, '' +            target_my_gen, target_my_trans_id = 0, '' + +        logger.debug("target replica uid: %s" % self.target_replica_uid) +        logger.debug("target generation: %d" % target_gen) +        logger.debug("target trans id: %s" % target_trans_id) +        logger.debug("target my gen: %d" % target_my_gen) +        logger.debug("target my trans_id: %s" % target_my_trans_id) +        logger.debug("source replica_uid: %s" % self.source._replica_uid) + +        # make sure we'll have access to target replica uid once it exists +        if self.target_replica_uid is None: + +            def ensure_callback(replica_uid): +                self.target_replica_uid = replica_uid + +        # make sure we're not syncing one replica with itself +        if self.target_replica_uid == self.source._replica_uid: +            raise errors.InvalidReplicaUID + +        # validate the info the target has about the source replica +        self.source.validate_gen_and_trans_id( +            target_my_gen, target_my_trans_id) + +        # ---------- phase 2: what's changed --------------------------------- +        if DO_STATS: +            self.sync_phase[0] += 1 +        # -------------------------------------------------------------------- + +        # what's changed since that generation and this current gen +        my_gen, _, changes = self.source.whats_changed(target_my_gen) +        logger.debug("there are %d documents to send" % len(changes)) + +        # get source last-seen database generation for the target +        if self.target_replica_uid is None: +            target_last_known_gen, target_last_known_trans_id = 0, '' +        else: +            target_last_known_gen, target_last_known_trans_id = \ +                self.source._get_replica_gen_and_trans_id( +                    self.target_replica_uid) +            logger.debug( +                "last known target gen: %d" % target_last_known_gen) +            logger.debug( +                "last known target trans_id: %s" % target_last_known_trans_id) + +        # validate transaction ids +        if not changes and target_last_known_gen == target_gen: +            if target_trans_id != target_last_known_trans_id: +                raise errors.InvalidTransactionId +            defer.returnValue(my_gen) + +        # ---------- phase 3: sync exchange ---------------------------------- +        if DO_STATS: +            self.sync_phase[0] += 1 +        # 
-------------------------------------------------------------------- + +        docs_by_generation = self._docs_by_gen_from_changes(changes) + +        # exchange documents and try to insert the returned ones with +        # the target, return target synced-up-to gen. +        new_gen, new_trans_id = yield sync_target.sync_exchange( +            docs_by_generation, self.source._replica_uid, +            target_last_known_gen, target_last_known_trans_id, +            self._insert_doc_from_target, ensure_callback=ensure_callback) +        ids_sent = [doc_id for doc_id, _, _ in changes] +        logger.debug("target gen after sync: %d" % new_gen) +        logger.debug("target trans_id after sync: %s" % new_trans_id) +        if hasattr(self.source, 'commit'):  # sqlcipher backend speed up +            self.source.commit()  # insert it all in a single transaction +        info = { +            "target_replica_uid": self.target_replica_uid, +            "new_gen": new_gen, +            "new_trans_id": new_trans_id, +            "my_gen": my_gen +        } +        self._syncing_info = info + +        # ---------- phase 4: complete sync ---------------------------------- +        if DO_STATS: +            self.sync_phase[0] += 1 +        # -------------------------------------------------------------------- + +        yield self.complete_sync() + +        _, _, changes = self.source.whats_changed(target_my_gen) +        changed_doc_ids = [doc_id for doc_id, _, _ in changes] + +        just_received = list(set(changed_doc_ids) - set(ids_sent)) +        self.received_docs = just_received + +        # ---------- phase 5: sync is over ----------------------------------- +        if DO_STATS: +            self.sync_phase[0] += 1 +        # -------------------------------------------------------------------- + +        defer.returnValue(my_gen) + +    def _docs_by_gen_from_changes(self, changes): +        docs_by_generation = [] +        kwargs = {'include_deleted': True} +        for doc_id, gen, trans in changes: +            get_doc = (self.source.get_doc, (doc_id,), kwargs) +            docs_by_generation.append((get_doc, gen, trans)) +        return docs_by_generation + +    def complete_sync(self): +        """ +        Last stage of the synchronization: +            (a) record last known generation and transaction uid for the remote +            replica, and +            (b) make target aware of our current reached generation. + +        :return: A deferred which will fire when the sync has been completed. +        :rtype: twisted.internet.defer.Deferred +        """ +        logger.debug("completing deferred last step in sync...") + +        # record target synced-up-to generation including applying what we +        # sent +        info = self._syncing_info +        self.source._set_replica_gen_and_trans_id( +            info["target_replica_uid"], info["new_gen"], info["new_trans_id"]) + +        # if gapless record current reached generation with target +        return self._record_sync_info_with_the_target(info["my_gen"]) + +    def _record_sync_info_with_the_target(self, start_generation): +        """ +        Store local replica metadata in server. + +        :param start_generation: The local generation when the sync was +                                 started. +        :type start_generation: int + +        :return: A deferred which will fire when the operation has been +                 completed. 
+        :rtype: twisted.internet.defer.Deferred
+        """
+        cur_gen, trans_id = self.source._get_generation_info()
+        if (cur_gen == start_generation + self.num_inserted and
+                self.num_inserted > 0):
+            return self.sync_target.record_sync_info(
+                self.source._replica_uid, cur_gen, trans_id)
+        return defer.succeed(None)
diff --git a/src/leap/soledad/common/README.txt b/src/leap/soledad/common/README.txt
new file mode 100644
index 00000000..0a252650
--- /dev/null
+++ b/src/leap/soledad/common/README.txt
@@ -0,0 +1,70 @@
+Soledad common package
+======================
+
+This package contains Soledad bits used by both server and client.
+
+Couch L2DB Backend
+------------------
+
+L2DB backends rely on some atomic operations that modify document contents
+and metadata (conflicts, transaction ids and indexes). The only atomic
+operation in Couch is a document put, so every u1db atomic operation has to
+be mapped to a couch document put (see the sketch after the notes below).
+
+The atomic operations in the U1DB SQLite reference backend implementation may
+be identified by the use of a context manager to access the underlying
+database. A listing of the methods involved in each atomic operation is
+depicted below. The top-level elements correspond to the atomic operations
+that have to be mapped, and items on deeper levels of the list have to be
+implemented in a way that all changes will be pushed with just one operation.
+
+    * _set_replica_uid
+    * put_doc:
+        * _get_doc
+        * _put_and_update_indexes
+            * insert/update the document
+            * insert into transaction log
+    * delete_doc
+        * _get_doc
+        * _put_and_update_indexes
+    * get_doc_conflicts
+        * _get_conflicts
+    * _set_replica_gen_and_trans_id
+        * _do_set_replica_gen_and_trans_id
+    * _put_doc_if_newer
+        * _get_doc
+        * _validate_source (**)
+            * _get_replica_gen_and_trans_id
+        * cases:
+            * is newer:
+                * _prune_conflicts (**)
+                    * _has_conflicts
+                    * _delete_conflicts
+                * _put_and_update_indexes
+            * same content as:
+                * _put_and_update_indexes
+            * conflicted:
+                * _force_doc_sync_conflict
+                    * _prune_conflicts
+                    * _add_conflict
+                    * _put_and_update_indexes
+        * _do_set_replica_gen_and_trans_id
+    * resolve_doc
+        * _get_doc
+        * cases:
+            * doc is superseded
+                * _put_and_update_indexes
+            * else
+                * _add_conflict
+        * _delete_conflicts
+    * delete_index
+    * create_index
+
+Notes:
+
+  * Currently, the couch backend does not implement indexing, so what is
+    depicted as `_put_and_update_indexes` above will be found as `_put_doc` in
+    the backend.
+
+  * Conflict updates are part of document put using couch update functions,
+    and as such are part of the same atomic operation as document put.
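To make the mapping above concrete, here is a minimal illustrative sketch of how a single u1db "put" could be collapsed into one Couch document write. This is not the backend code added by this changeset: the atomic_put() helper, the u1db_rev/u1db_conflicts/u1db_transactions field names and the server URL are assumptions chosen for brevity, and only the couchdb-python Database.get()/save() calls are taken from the real library.

    # Illustrative only: bundle content, u1db revision, conflicts and a new
    # transaction id into one Couch document, so the whole u1db atomic
    # operation becomes a single PUT. Field names and helper are hypothetical.
    import uuid

    from couchdb.client import Server


    def atomic_put(db, doc_id, rev, content, conflicts=None):
        couch_doc = db.get(doc_id) or {'_id': doc_id}
        couch_doc['content'] = content
        couch_doc['u1db_rev'] = rev
        couch_doc['u1db_conflicts'] = conflicts or []
        # the transaction log entry rides along in the same write, so the
        # generation only advances if the content update itself succeeds
        trans_id = 'T-%s' % uuid.uuid4().hex
        couch_doc.setdefault('u1db_transactions', []).append(trans_id)
        db.save(couch_doc)  # single PUT: the only step that must be atomic
        return trans_id


    # usage sketch against a local Couch instance (URL is an assumption)
    server = Server('http://localhost:5984')
    db = server.create('example-user-db')
    atomic_put(db, 'doc-1', '1-aaaa', {'subject': 'hello'})

The actual CouchDatabase backend introduced later in this patch keeps richer bookkeeping than this sketch, for instance dedicated per-generation documents (see _get_gen_doc_id in couch/__init__.py) and conflict handling via update functions, but the guiding idea is the one stated above: each u1db atomic operation has to end up as one Couch document put.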
diff --git a/src/leap/soledad/common/__init__.py b/src/leap/soledad/common/__init__.py new file mode 100644 index 00000000..4948ad20 --- /dev/null +++ b/src/leap/soledad/common/__init__.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# __init__.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +from leap.common.check import leap_assert as soledad_assert +from leap.common.check import leap_assert_type as soledad_assert_type + +from ._version import get_versions + +""" +Soledad routines common to client and server. +""" + + +# +# Global constants +# + +SHARED_DB_NAME = 'shared' + + +# +# Global functions +# + +__version__ = get_versions()['version'] +del get_versions + + +__all__ = [ +    "soledad_assert", +    "soledad_assert_type", +    "__version__", +] diff --git a/src/leap/soledad/common/backend.py b/src/leap/soledad/common/backend.py new file mode 100644 index 00000000..4a29ca87 --- /dev/null +++ b/src/leap/soledad/common/backend.py @@ -0,0 +1,642 @@ +# -*- coding: utf-8 -*- +# backend.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +"""A L2DB generic backend.""" + +import functools + +from leap.soledad.common.document import ServerDocument +from leap.soledad.common.l2db import vectorclock +from leap.soledad.common.l2db.errors import ( +    RevisionConflict, +    InvalidDocId, +    ConflictedDoc, +    DocumentDoesNotExist, +    DocumentAlreadyDeleted, +) +from leap.soledad.common.l2db.backends import CommonBackend +from leap.soledad.common.l2db.backends import CommonSyncTarget + + +class SoledadBackend(CommonBackend): +    BATCH_SUPPORT = False + +    """ +    A L2DB backend implementation. +    """ + +    def __init__(self, database, replica_uid=None): +        """ +        Create a new backend. 
+ +        :param database: the database implementation +        :type database: Database +        :param replica_uid: an optional unique replica identifier +        :type replica_uid: str +        """ +        # save params +        self._factory = ServerDocument +        self._real_replica_uid = None +        self._cache = None +        self._dbname = database._dbname +        self._database = database +        self.batching = False +        if replica_uid is not None: +            self._set_replica_uid(replica_uid) + +    def batch_start(self): +        if not self.BATCH_SUPPORT: +            return +        self.batching = True +        self.after_batch_callbacks = {} +        self._database.batch_start() +        if not self._cache: +            # batching needs cache +            self._cache = {} +        self._get_generation()  # warm up gen info + +    def batch_end(self): +        if not self.BATCH_SUPPORT: +            return +        self._database.batch_end() +        self.batching = False +        for name in self.after_batch_callbacks: +            self.after_batch_callbacks[name]() +        self.after_batch_callbacks = None + +    @property +    def cache(self): +        if self._cache is not None: +            return self._cache +        else: +            return {} + +    def init_caching(self, cache): +        """ +        Start using cache by setting internal _cache attribute. + +        :param cache: the cache instance, anything that behaves like a dict +        :type cache: dict +        """ +        self._cache = cache + +    def get_sync_target(self): +        """ +        Return a SyncTarget object, for another u1db to synchronize with. + +        :return: The sync target. +        :rtype: SoledadSyncTarget +        """ +        return SoledadSyncTarget(self) + +    def delete_database(self): +        """ +        Delete a U1DB database. +        """ +        self._database.delete_database() + +    def close(self): +        """ +        Release any resources associated with this database. + +        :return: True if db was succesfully closed. +        :rtype: bool +        """ +        self._database.close() +        return True + +    def __del__(self): +        """ +        Close the database upon garbage collection. +        """ +        self.close() + +    def _set_replica_uid(self, replica_uid): +        """ +        Force the replica uid to be set. + +        :param replica_uid: The new replica uid. +        :type replica_uid: str +        """ +        self._database.set_replica_uid(replica_uid) +        self._real_replica_uid = replica_uid +        self.cache['replica_uid'] = self._real_replica_uid + +    def _get_replica_uid(self): +        """ +        Get the replica uid. + +        :return: The replica uid. +        :rtype: str +        """ +        if self._real_replica_uid is not None: +            self.cache['replica_uid'] = self._real_replica_uid +            return self._real_replica_uid +        if 'replica_uid' in self.cache: +            return self.cache['replica_uid'] +        self._real_replica_uid = self._database.get_replica_uid() +        self._set_replica_uid(self._real_replica_uid) +        return self._real_replica_uid + +    _replica_uid = property(_get_replica_uid, _set_replica_uid) + +    replica_uid = property(_get_replica_uid) + +    def _get_generation(self): +        """ +        Return the current generation. + +        :return: The current generation. 
+        :rtype: int + +        :raise SoledadError: Raised by database on operation failure +        """ +        return self._get_generation_info()[0] + +    def _get_generation_info(self): +        """ +        Return the current generation. + +        :return: A tuple containing the current generation and transaction id. +        :rtype: (int, str) + +        :raise SoledadError: Raised by database on operation failure +        """ +        cur_gen, newest_trans_id = self._database.get_generation_info() +        return (cur_gen, newest_trans_id) + +    def _get_trans_id_for_gen(self, generation): +        """ +        Get the transaction id corresponding to a particular generation. + +        :param generation: The generation for which to get the transaction id. +        :type generation: int + +        :return: The transaction id for C{generation}. +        :rtype: str + +        :raise InvalidGeneration: Raised when the generation does not exist. + +        """ +        return self._database.get_trans_id_for_gen(generation) + +    def _get_transaction_log(self): +        """ +        This is only for the test suite, it is not part of the api. + +        :return: The complete transaction log. +        :rtype: [(str, str)] + +        """ +        return self._database.get_transaction_log() + +    def _get_doc(self, doc_id, check_for_conflicts=False): +        """ +        Extract the document from storage. + +        This can return None if the document doesn't exist. + +        :param doc_id: The unique document identifier +        :type doc_id: str +        :param check_for_conflicts: If set to False, then the conflict check +                                    will be skipped. +        :type check_for_conflicts: bool + +        :return: The document. +        :rtype: ServerDocument +        """ +        return self._database.get_doc(doc_id, check_for_conflicts) + +    def get_doc(self, doc_id, include_deleted=False): +        """ +        Get the JSON string for the given document. + +        :param doc_id: The unique document identifier +        :type doc_id: str +        :param include_deleted: If set to True, deleted documents will be +            returned with empty content. Otherwise asking for a deleted +            document will return None. +        :type include_deleted: bool + +        :return: A document object. +        :rtype: ServerDocument. +        """ +        doc = self._get_doc(doc_id, check_for_conflicts=True) +        if doc is None: +            return None +        if doc.is_tombstone() and not include_deleted: +            return None +        return doc + +    def get_all_docs(self, include_deleted=False): +        """ +        Get the JSON content for all documents in the database. + +        :param include_deleted: If set to True, deleted documents will be +                                returned with empty content. Otherwise deleted +                                documents will not be included in the results. +        :type include_deleted: bool + +        :return: (generation, [ServerDocument]) +            The current generation of the database, followed by a list of all +            the documents in the database. +        :rtype: (int, [ServerDocument]) +        """ +        return self._database.get_all_docs(include_deleted) + +    def _put_doc(self, old_doc, doc): +        """ +        Put the document in the backend database. 
+ +        Note that C{old_doc} must have been fetched with the parameter +        C{check_for_conflicts} equal to True, so we can properly update the +        new document using the conflict information from the old one. + +        :param old_doc: The old document version. +        :type old_doc: ServerDocument +        :param doc: The document to be put. +        :type doc: ServerDocument +        """ +        self._database.save_document(old_doc, doc, +                                     self._allocate_transaction_id()) + +    def put_doc(self, doc): +        """ +        Update a document. + +        If the document currently has conflicts, put will fail. +        If the database specifies a maximum document size and the document +        exceeds it, put will fail and raise a DocumentTooBig exception. + +        :param doc: A Document with new content. +        :return: new_doc_rev - The new revision identifier for the document. +            The Document object will also be updated. + +        :raise InvalidDocId: Raised if the document's id is invalid. +        :raise DocumentTooBig: Raised if the document size is too big. +        :raise ConflictedDoc: Raised if the document has conflicts. +        """ +        if doc.doc_id is None: +            raise InvalidDocId() +        self._check_doc_id(doc.doc_id) +        self._check_doc_size(doc) +        old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) +        if old_doc and old_doc.has_conflicts: +            raise ConflictedDoc() +        if old_doc and doc.rev is None and old_doc.is_tombstone(): +            new_rev = self._allocate_doc_rev(old_doc.rev) +        else: +            if old_doc is not None: +                    if old_doc.rev != doc.rev: +                        raise RevisionConflict() +            else: +                if doc.rev is not None: +                    raise RevisionConflict() +            new_rev = self._allocate_doc_rev(doc.rev) +        doc.rev = new_rev +        self._put_doc(old_doc, doc) +        return new_rev + +    def whats_changed(self, old_generation=0): +        """ +        Return a list of documents that have changed since old_generation. + +        :param old_generation: The generation of the database in the old +                               state. +        :type old_generation: int + +        :return: (generation, trans_id, [(doc_id, generation, trans_id),...]) +                 The current generation of the database, its associated +                 transaction id, and a list of of changed documents since +                 old_generation, represented by tuples with for each document +                 its doc_id and the generation and transaction id corresponding +                 to the last intervening change and sorted by generation (old +                 changes first) +        :rtype: (int, str, [(str, int, str)]) +        """ +        return self._database.whats_changed(old_generation) + +    def delete_doc(self, doc): +        """ +        Mark a document as deleted. + +        Will abort if the current revision doesn't match doc.rev. +        This will also set doc.content to None. + +        :param doc: The document to mark as deleted. +        :type doc: ServerDocument. + +        :raise DocumentDoesNotExist: Raised if the document does not +                                            exist. +        :raise RevisionConflict: Raised if the revisions do not match. 
+        :raise DocumentAlreadyDeleted: Raised if the document is +                                              already deleted. +        :raise ConflictedDoc: Raised if the doc has conflicts. +        """ +        old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) +        if old_doc is None: +            raise DocumentDoesNotExist +        if old_doc.rev != doc.rev: +            raise RevisionConflict() +        if old_doc.is_tombstone(): +            raise DocumentAlreadyDeleted +        if old_doc.has_conflicts: +            raise ConflictedDoc() +        new_rev = self._allocate_doc_rev(doc.rev) +        doc.rev = new_rev +        doc.make_tombstone() +        self._put_doc(old_doc, doc) +        return new_rev + +    def get_doc_conflicts(self, doc_id): +        """ +        Get the conflicted versions of a document. + +        :param doc_id: The document id. +        :type doc_id: str + +        :return: A list of conflicted versions of the document. +        :rtype: list +        """ +        return self._database.get_doc_conflicts(doc_id) + +    def _get_replica_gen_and_trans_id(self, other_replica_uid): +        """ +        Return the last known generation and transaction id for the other db +        replica. + +        When you do a synchronization with another replica, the Database keeps +        track of what generation the other database replica was at, and what +        the associated transaction id was.  This is used to determine what data +        needs to be sent, and if two databases are claiming to be the same +        replica. + +        :param other_replica_uid: The identifier for the other replica. +        :type other_replica_uid: str + +        :return: A tuple containing the generation and transaction id we +                 encountered during synchronization. If we've never +                 synchronized with the replica, this is (0, ''). +        :rtype: (int, str) +        """ +        if other_replica_uid in self.cache: +            return self.cache[other_replica_uid] +        gen, trans_id = \ +            self._database.get_replica_gen_and_trans_id(other_replica_uid) +        self.cache[other_replica_uid] = (gen, trans_id) +        return (gen, trans_id) + +    def _set_replica_gen_and_trans_id(self, other_replica_uid, +                                      other_generation, other_transaction_id): +        """ +        Set the last-known generation and transaction id for the other +        database replica. + +        We have just performed some synchronization, and we want to track what +        generation the other replica was at. See also +        _get_replica_gen_and_trans_id. + +        :param other_replica_uid: The U1DB identifier for the other replica. +        :type other_replica_uid: str +        :param other_generation: The generation number for the other replica. +        :type other_generation: int +        :param other_transaction_id: The transaction id associated with the +            generation. 
+        :type other_transaction_id: str +        """ +        if other_replica_uid is not None and other_generation is not None: +            self.cache[other_replica_uid] = (other_generation, +                                             other_transaction_id) +            self._database.set_replica_gen_and_trans_id(other_replica_uid, +                                                        other_generation, +                                                        other_transaction_id) + +    def _do_set_replica_gen_and_trans_id( +            self, other_replica_uid, other_generation, other_transaction_id): +        """ +        _put_doc_if_newer from super class is calling it. So we declare this. + +        :param other_replica_uid: The U1DB identifier for the other replica. +        :type other_replica_uid: str +        :param other_generation: The generation number for the other replica. +        :type other_generation: int +        :param other_transaction_id: The transaction id associated with the +                                     generation. +        :type other_transaction_id: str +        """ +        args = [other_replica_uid, other_generation, other_transaction_id] +        callback = functools.partial(self._set_replica_gen_and_trans_id, *args) +        if self.batching: +            self.after_batch_callbacks['set_source_info'] = callback +        else: +            callback() + +    def _force_doc_sync_conflict(self, doc): +        """ +        Add a conflict and force a document put. + +        :param doc: The document to be put. +        :type doc: ServerDocument +        """ +        my_doc = self._get_doc(doc.doc_id) +        self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) +        doc.add_conflict(self._factory(doc.doc_id, my_doc.rev, +                                       my_doc.get_json())) +        doc.has_conflicts = True +        self._put_doc(my_doc, doc) + +    def resolve_doc(self, doc, conflicted_doc_revs): +        """ +        Mark a document as no longer conflicted. + +        We take the list of revisions that the client knows about that it is +        superseding. This may be a different list from the actual current +        conflicts, in which case only those are removed as conflicted.  This +        may fail if the conflict list is significantly different from the +        supplied information. (sync could have happened in the background from +        the time you GET_DOC_CONFLICTS until the point where you RESOLVE) + +        :param doc: A Document with the new content to be inserted. +        :type doc: ServerDocument +        :param conflicted_doc_revs: A list of revisions that the new content +                                    supersedes. +        :type conflicted_doc_revs: [str] + +        :raise SoledadError: Raised by database on operation failure +        """ +        cur_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) +        new_rev = self._ensure_maximal_rev(cur_doc.rev, +                                           conflicted_doc_revs) +        superseded_revs = set(conflicted_doc_revs) +        doc.rev = new_rev +        # this backend stores conflicts as properties of the documents, so we +        # have to copy these conflicts over to the document being updated. +        if cur_doc.rev in superseded_revs: +            # the newer doc version will supersede the one in the database, so +            # we copy conflicts before updating the backend. 
+            doc.set_conflicts(cur_doc.get_conflicts())  # copy conflicts over. +            doc.delete_conflicts(superseded_revs) +            self._put_doc(cur_doc, doc) +        else: +            # the newer doc version does not supersede the one in the +            # database, so we will add a conflict to the database and copy +            # those over to the document the user has in her hands. +            cur_doc.add_conflict(doc) +            cur_doc.delete_conflicts(superseded_revs) +            self._put_doc(cur_doc, cur_doc)  # just update conflicts +            # backend has been updated with current conflicts, now copy them +            # to the current document. +            doc.set_conflicts(cur_doc.get_conflicts()) + +    def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, +                          replica_trans_id='', number_of_docs=None, +                          doc_idx=None, sync_id=None): +        """ +        Insert/update document into the database with a given revision. + +        This api is used during synchronization operations. + +        If a document would conflict and save_conflict is set to True, the +        content will be selected as the 'current' content for doc.doc_id, +        even though doc.rev doesn't supersede the currently stored revision. +        The currently stored document will be added to the list of conflict +        alternatives for the given doc_id. + +        This forces the new content to be 'current' so that we get convergence +        after synchronizing, even if people don't resolve conflicts. Users can +        then notice that their content is out of date, update it, and +        synchronize again. (The alternative is that users could synchronize and +        think the data has propagated, but their local copy looks fine, and the +        remote copy is never updated again.) + +        :param doc: A document object +        :type doc: ServerDocument +        :param save_conflict: If this document is a conflict, do you want to +                              save it as a conflict, or just ignore it. +        :type save_conflict: bool +        :param replica_uid: A unique replica identifier. +        :type replica_uid: str +        :param replica_gen: The generation of the replica corresponding to the +                            this document. The replica arguments are optional, +                            but are used during synchronization. +        :type replica_gen: int +        :param replica_trans_id: The transaction_id associated with the +                                 generation. +        :type replica_trans_id: str +        :param number_of_docs: The total amount of documents sent on this sync +                               session. +        :type number_of_docs: int +        :param doc_idx: The index of the current document being sent. +        :type doc_idx: int +        :param sync_id: The id of the current sync session. +        :type sync_id: str + +        :return: (state, at_gen) -  If we don't have doc_id already, or if +                 doc_rev supersedes the existing document revision, then the +                 content will be inserted, and state is 'inserted'.  If +                 doc_rev is less than or equal to the existing revision, then +                 the put is ignored and state is respecitvely 'superseded' or +                 'converged'.  If doc_rev is not strictly superseded or +                 supersedes, then state is 'conflicted'. 
The document will not +                 be inserted if save_conflict is False.  For 'inserted' or +                 'converged', at_gen is the insertion/current generation. +        :rtype: (str, int) +        """ +        if not isinstance(doc, ServerDocument): +            doc = self._factory(doc.doc_id, doc.rev, doc.get_json()) +        my_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) +        if my_doc: +            doc.set_conflicts(my_doc.get_conflicts()) +        return CommonBackend._put_doc_if_newer(self, doc, save_conflict, +                                               replica_uid, replica_gen, +                                               replica_trans_id) + +    def _put_and_update_indexes(self, cur_doc, doc): +        self._put_doc(cur_doc, doc) + +    def get_docs(self, doc_ids, check_for_conflicts=True, +                 include_deleted=False, read_content=True): +        """ +        Get the JSON content for many documents. + +        :param doc_ids: A list of document identifiers or None for all. +        :type doc_ids: list +        :param check_for_conflicts: If set to False, then the conflict check +                                    will be skipped, and 'None' will be +                                    returned instead of True/False. +        :type check_for_conflicts: bool +        :param include_deleted: If set to True, deleted documents will be +                                returned with empty content. Otherwise deleted +                                documents will not be included in the results. +        :return: iterable giving the Document object for each document id +                 in matching doc_ids order. +        :rtype: iterable +        """ +        return self._database.get_docs(doc_ids, check_for_conflicts, +                                       include_deleted, read_content) + +    def _prune_conflicts(self, doc, doc_vcr): +        """ +        Prune conflicts that are older then the current document's revision, or +        whose content match to the current document's content. +        Originally in u1db.CommonBackend + +        :param doc: The document to have conflicts pruned. +        :type doc: ServerDocument +        :param doc_vcr: A vector clock representing the current document's +                        revision. +        :type doc_vcr: u1db.vectorclock.VectorClock +        """ +        if doc.has_conflicts: +            autoresolved = False +            c_revs_to_prune = [] +            for c_doc in doc._conflicts: +                c_vcr = vectorclock.VectorClockRev(c_doc.rev) +                if doc_vcr.is_newer(c_vcr): +                    c_revs_to_prune.append(c_doc.rev) +                elif doc.same_content_as(c_doc): +                    c_revs_to_prune.append(c_doc.rev) +                    doc_vcr.maximize(c_vcr) +                    autoresolved = True +            if autoresolved: +                doc_vcr.increment(self._replica_uid) +                doc.rev = doc_vcr.as_str() +            doc.delete_conflicts(c_revs_to_prune) + + +class SoledadSyncTarget(CommonSyncTarget): + +    """ +    Functionality for using a SoledadBackend as a synchronization target. 
+    """ + +    def get_sync_info(self, source_replica_uid): +        source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( +            source_replica_uid) +        my_gen, my_trans_id = self._db._get_generation_info() +        return ( +            self._db._replica_uid, my_gen, my_trans_id, source_gen, +            source_trans_id) + +    def record_sync_info(self, source_replica_uid, source_replica_generation, +                         source_replica_transaction_id): +        if self._trace_hook: +            self._trace_hook('record_sync_info') +        self._db._set_replica_gen_and_trans_id( +            source_replica_uid, source_replica_generation, +            source_replica_transaction_id) diff --git a/src/leap/soledad/common/command.py b/src/leap/soledad/common/command.py new file mode 100644 index 00000000..66aa6b7a --- /dev/null +++ b/src/leap/soledad/common/command.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# command.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +""" +Utility to sanitize and run shell commands. +""" + + +import subprocess + + +def exec_validated_cmd(cmd, argument, validator=None): +    """ +    Executes cmd, validating argument with a validator function. + +    :param cmd: command. +    :type dbname: str +    :param argument: argument. +    :type argument: str +    :param validator: optional function to validate argument +    :type validator: function + +    :return: exit code and stdout or stderr (if code != 0) +    :rtype: (int, str) +    """ +    if validator and not validator(argument): +        return 1, "invalid argument" +    command = cmd.split(' ') +    command.append(argument) +    try: +        process = subprocess.Popen(command, stdout=subprocess.PIPE, +                                   stderr=subprocess.PIPE) +    except OSError as e: +        return 1, e +    (out, err) = process.communicate() +    code = process.wait() +    if code is not 0: +        return code, err +    else: +        return code, out diff --git a/src/leap/soledad/common/couch/__init__.py b/src/leap/soledad/common/couch/__init__.py new file mode 100644 index 00000000..2343e849 --- /dev/null +++ b/src/leap/soledad/common/couch/__init__.py @@ -0,0 +1,812 @@ +# -*- coding: utf-8 -*- +# __init__.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. 
If not, see <http://www.gnu.org/licenses/>. + + +"""A U1DB backend that uses CouchDB as its persistence layer.""" + + +import json +import copy +import re +import uuid +import binascii + +from six import StringIO +from six.moves.urllib.parse import urljoin +from contextlib import contextmanager + + +from couchdb.client import Server, Database +from couchdb.http import ( +    ResourceConflict, +    ResourceNotFound, +    Session, +    urljoin as couch_urljoin, +    Resource, +) +from leap.soledad.common.l2db.errors import ( +    DatabaseDoesNotExist, +    InvalidGeneration, +    RevisionConflict, +) +from leap.soledad.common.l2db.remote import http_app + + +from .support import MultipartWriter +from leap.soledad.common.errors import InvalidURLError +from leap.soledad.common.document import ServerDocument +from leap.soledad.common.backend import SoledadBackend + + +COUCH_TIMEOUT = 120  # timeout for transfers between Soledad server and Couch + + +def list_users_dbs(couch_url): +    """ +    Retrieves a list with all databases that starts with 'user-' on CouchDB. +    Those databases belongs to users. So, the list will contain all the +    database names in the form of 'user-{uuid4}'. + +    :param couch_url: The couch url with needed credentials +    :type couch_url: str + +    :return: The list of all database names from users. +    :rtype: [str] +    """ +    with couch_server(couch_url) as server: +        users = [dbname for dbname in server if dbname.startswith('user-')] +    return users + + +# monkey-patch the u1db http app to use ServerDocument +http_app.Document = ServerDocument + + +@contextmanager +def couch_server(url): +    """ +    Provide a connection to a couch server and cleanup after use. + +    For database creation and deletion we use an ephemeral connection to the +    couch server. That connection has to be properly closed, so we provide it +    as a context manager. + +    :param url: The URL of the Couch server. +    :type url: str +    """ +    session = Session(timeout=COUCH_TIMEOUT) +    server = Server(url=url, full_commit=False, session=session) +    yield server + + +def _get_gen_doc_id(gen): +    return 'gen-%s' % str(gen).zfill(10) + + +GENERATION_KEY = 'gen' +TRANSACTION_ID_KEY = 'trans_id' +REPLICA_UID_KEY = 'replica_uid' +DOC_ID_KEY = 'doc_id' +SCHEMA_VERSION_KEY = 'schema_version' + +CONFIG_DOC_ID = '_local/config' +SYNC_DOC_ID_PREFIX = '_local/sync_' +SCHEMA_VERSION = 1 + + +class CouchDatabase(object): +    """ +    Holds CouchDB related code. +    This class gives methods to encapsulate database operations and hide +    CouchDB details from backend code. +    """ + +    @classmethod +    def open_database(cls, url, create, replica_uid=None, +                      database_security=None): +        """ +        Open a U1DB database using CouchDB as backend. + +        :param url: the url of the database replica +        :type url: str +        :param create: should the replica be created if it does not exist? +        :type create: bool +        :param replica_uid: an optional unique replica identifier +        :type replica_uid: str +        :param database_security: security rules as CouchDB security doc +        :type database_security: dict + +        :return: the database instance +        :rtype: SoledadBackend + +        :raise DatabaseDoesNotExist: Raised if database does not exist. 
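+
+        Example (illustrative only; the server URL, credentials and
+        database name below are placeholders):
+
+            db = CouchDatabase.open_database(
+                'http://user:pass@localhost:5984/user-1234abcd',
+                create=True)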
+        """ +        # get database from url +        m = re.match('(^https?://[^/]+)/(.+)$', url) +        if not m: +            raise InvalidURLError +        url = m.group(1) +        dbname = m.group(2) +        with couch_server(url) as server: +            if dbname not in server: +                if create: +                    server.create(dbname) +                else: +                    raise DatabaseDoesNotExist() +        db = cls(url, dbname, ensure_security=create, +                 database_security=database_security) +        return SoledadBackend( +            db, replica_uid=replica_uid) + +    def __init__(self, url, dbname, ensure_security=False, +                 database_security=None): +        """ +        :param url: Couch server URL with necessary credentials +        :type url: string +        :param dbname: Couch database name +        :type dbname: string +        :param ensure_security: will PUT a _security ddoc if set +        :type ensure_security: bool +        :param database_security: security rules as CouchDB security doc +        :type database_security: dict +        """ +        self._session = Session(timeout=COUCH_TIMEOUT) +        self._url = url +        self._dbname = dbname +        self._database = self.get_couch_database(url, dbname) +        self.batching = False +        self.batch_generation = None +        self.batch_docs = {} +        if ensure_security: +            self.ensure_security_ddoc(database_security) + +    def batch_start(self): +        self.batching = True +        self.batch_generation = self.get_generation_info() +        ids = set(row.id for row in self._database.view('_all_docs')) +        self.batched_ids = ids + +    def batch_end(self): +        self.batching = False +        self.batch_generation = None +        self.__perform_batch() + +    def get_couch_database(self, url, dbname): +        """ +        Generate a couchdb.Database instance given a url and dbname. + +        :param url: CouchDB's server url with credentials +        :type url: str +        :param dbname: Database name +        :type dbname: str + +        :return: couch library database instance +        :rtype: couchdb.Database + +        :raise DatabaseDoesNotExist: Raised if database does not exist. +        """ +        try: +            return Database( +                urljoin(url, dbname), +                self._session) +        except ResourceNotFound: +            raise DatabaseDoesNotExist() + +    def ensure_security_ddoc(self, security_config=None): +        """ +        Make sure that only soledad user is able to access this database as +        an unprivileged member, meaning that administration access will +        be forbidden even inside an user database. +        The goal is to make sure that only the lowest access level is given +        to the unprivileged CouchDB user set on the server process. 
+        This is achieved by creating a _security design document, see: +        http://docs.couchdb.org/en/latest/api/database/security.html + +        :param security_config: security configuration parsed from conf file +        :type security_config: dict +        """ +        security_config = security_config or {} +        security = self._database.resource.get_json('_security')[2] +        security['members'] = {'names': [], 'roles': []} +        security['members']['names'] = security_config.get('members', +                                                           ['soledad']) +        security['members']['roles'] = security_config.get('members_roles', []) +        security['admins'] = {'names': [], 'roles': []} +        security['admins']['names'] = security_config.get('admins', []) +        security['admins']['roles'] = security_config.get('admins_roles', []) +        self._database.resource.put_json('_security', body=security) + +    def delete_database(self): +        """ +        Delete a U1DB CouchDB database. +        """ +        with couch_server(self._url) as server: +            del(server[self._dbname]) + +    def set_replica_uid(self, replica_uid): +        """ +        Force the replica uid to be set. + +        :param replica_uid: The new replica uid. +        :type replica_uid: str +        """ +        try: +            # set on existent config document +            doc = self._database[CONFIG_DOC_ID] +            doc[REPLICA_UID_KEY] = replica_uid +        except ResourceNotFound: +            # or create the config document +            doc = { +                '_id': CONFIG_DOC_ID, +                REPLICA_UID_KEY: replica_uid, +                SCHEMA_VERSION_KEY: SCHEMA_VERSION, +            } +        self._database.save(doc) + +    def get_replica_uid(self): +        """ +        Get the replica uid. + +        :return: The replica uid. +        :rtype: str +        """ +        try: +            # grab replica_uid from server +            doc = self._database[CONFIG_DOC_ID] +            replica_uid = doc[REPLICA_UID_KEY] +            return replica_uid +        except ResourceNotFound: +            # create a unique replica_uid +            replica_uid = uuid.uuid4().hex +            self.set_replica_uid(replica_uid) +            return replica_uid + +    def close(self): +        self._database = None + +    def get_all_docs(self, include_deleted=False): +        """ +        Get the JSON content for all documents in the database. + +        :param include_deleted: If set to True, deleted documents will be +                                returned with empty content. Otherwise deleted +                                documents will not be included in the results. +        :type include_deleted: bool + +        :return: (generation, [ServerDocument]) +            The current generation of the database, followed by a list of all +            the documents in the database. +        :rtype: (int, [ServerDocument]) +        """ + +        generation, _ = self.get_generation_info() +        results = list( +            self.get_docs(None, True, include_deleted)) +        return (generation, results) + +    def get_docs(self, doc_ids, check_for_conflicts=True, +                 include_deleted=False, read_content=True): +        """ +        Get the JSON content for many documents. + +        Use couch's `_all_docs` view to get the documents indicated in +        `doc_ids`, + +        :param doc_ids: A list of document identifiers or None for all. 
+        :type doc_ids: list +        :param check_for_conflicts: If set to False, then the conflict check +                                    will be skipped, and 'None' will be +                                    returned instead of True/False. +        :type check_for_conflicts: bool +        :param include_deleted: If set to True, deleted documents will be +                                returned with empty content. Otherwise deleted +                                documents will not be included in the results. + +        :return: iterable giving the Document object for each document id +                 in matching doc_ids order. +        :rtype: iterable +        """ +        params = {'include_docs': 'true', 'attachments': 'false'} +        if doc_ids is not None: +            params['keys'] = doc_ids +        view = self._database.view("_all_docs", **params) +        for row in view.rows: +            result = copy.deepcopy(row['doc']) +            for file_name in result.get('_attachments', {}).keys(): +                data = self._database.get_attachment(result, file_name) +                if data: +                    if read_content: +                        data = data.read() +                    result['_attachments'][file_name] = {'data': data} +            doc = self.__parse_doc_from_couch( +                result, result['_id'], +                check_for_conflicts=check_for_conflicts, decode=False) +            # filter out non-u1db or deleted documents +            if not doc or (not include_deleted and doc.is_tombstone()): +                continue +            yield doc + +    def get_doc(self, doc_id, check_for_conflicts=False): +        """ +        Extract the document from storage. + +        This can return None if the document doesn't exist. + +        :param doc_id: The unique document identifier +        :type doc_id: str +        :param check_for_conflicts: If set to False, then the conflict check +                                    will be skipped. +        :type check_for_conflicts: bool + +        :return: The document. +        :rtype: ServerDocument +        """ +        doc_from_batch = self.__check_batch_before_get(doc_id) +        if doc_from_batch: +            return doc_from_batch +        if self.batching and doc_id not in self.batched_ids: +            return None +        if doc_id not in self._database: +            return None +        # get document with all attachments (u1db content and eventual +        # conflicts) +        result = self.json_from_resource([doc_id], attachments=True) +        return self.__parse_doc_from_couch(result, doc_id, check_for_conflicts) + +    def __check_batch_before_get(self, doc_id): +        """ +        If doc_id is staged for batching, then we need to commit the batch +        before going ahead. This avoids consistency problems, like trying to +        get a document that isn't persisted and processing like it is missing. 
+ +        :param doc_id: The unique document identifier +        :type doc_id: str +        """ +        if doc_id in self.batch_docs: +            couch_doc = self.batch_docs[doc_id] +            rev = self.__perform_batch(doc_id) +            couch_doc['_rev'] = rev +            self.batched_ids.add(doc_id) +            return self.__parse_doc_from_couch(couch_doc, doc_id, True) +        return None + +    def __perform_batch(self, doc_id=None): +        status = self._database.update(self.batch_docs.values()) +        rev = None +        for ok, stored_doc_id, rev_or_error in status: +            if not ok: +                error = rev_or_error +                if type(error) is ResourceConflict: +                    raise RevisionConflict +                raise error +            elif doc_id == stored_doc_id: +                rev = rev_or_error +        self.batch_docs.clear() +        return rev + +    def __parse_doc_from_couch(self, result, doc_id, +                               check_for_conflicts=False, decode=True): +        # restrict to u1db documents +        if 'u1db_rev' not in result: +            return None +        doc = ServerDocument(doc_id, result['u1db_rev']) +        # set contents or make tombstone +        if '_attachments' not in result \ +                or 'u1db_content' not in result['_attachments']: +            doc.make_tombstone() +        elif decode: +            doc.content = json.loads( +                binascii.a2b_base64( +                    result['_attachments']['u1db_content']['data'])) +        else: +            doc._json = result['_attachments']['u1db_content']['data'] +        # determine if there are conflicts +        if check_for_conflicts \ +                and '_attachments' in result \ +                and 'u1db_conflicts' in result['_attachments']: +            if decode: +                conflicts = binascii.a2b_base64( +                    result['_attachments']['u1db_conflicts']['data']) +            else: +                conflicts = result['_attachments']['u1db_conflicts']['data'] +            conflicts = json.loads(conflicts) +            doc.set_conflicts(self._build_conflicts(doc.doc_id, conflicts)) +        # store couch revision +        doc.couch_rev = result['_rev'] +        return doc + +    def _build_conflicts(self, doc_id, attached_conflicts): +        """ +        Build the conflicted documents list from the conflicts attachment +        fetched from a couch document. + +        :param attached_conflicts: The document's conflicts as fetched from a +                                   couch document attachment. +        :type attached_conflicts: dict +        """ +        conflicts = [] +        for doc_rev, content in attached_conflicts: +            doc = ServerDocument(doc_id, doc_rev) +            if content is None: +                doc.make_tombstone() +            else: +                doc.content = content +            conflicts.append(doc) +        return conflicts + +    def get_trans_id_for_gen(self, generation): +        """ +        Get the transaction id corresponding to a particular generation. + +        :param generation: The generation for which to get the transaction id. +        :type generation: int + +        :return: The transaction id for C{generation}. +        :rtype: str + +        :raise InvalidGeneration: Raised when the generation does not exist. 
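+
+        Note that generation 0 (the empty database) maps to the empty
+        transaction id ''.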
+        """ +        if generation == 0: +            return '' +        log = self._get_transaction_log(start=generation, end=generation) +        if not log: +            raise InvalidGeneration +        _, _, trans_id = log[0] +        return trans_id + +    def get_replica_gen_and_trans_id(self, other_replica_uid): +        """ +        Return the last known generation and transaction id for the other db +        replica. + +        When you do a synchronization with another replica, the Database keeps +        track of what generation the other database replica was at, and what +        the associated transaction id was.  This is used to determine what data +        needs to be sent, and if two databases are claiming to be the same +        replica. + +        :param other_replica_uid: The identifier for the other replica. +        :type other_replica_uid: str + +        :return: A tuple containing the generation and transaction id we +                 encountered during synchronization. If we've never +                 synchronized with the replica, this is (0, ''). +        :rtype: (int, str) +        """ +        doc_id = '%s%s' % (SYNC_DOC_ID_PREFIX, other_replica_uid) +        try: +            doc = self._database[doc_id] +        except ResourceNotFound: +            doc = { +                '_id': doc_id, +                GENERATION_KEY: 0, +                REPLICA_UID_KEY: str(other_replica_uid), +                TRANSACTION_ID_KEY: '', +            } +            self._database.save(doc) +        gen, trans_id = doc[GENERATION_KEY], doc[TRANSACTION_ID_KEY] +        return gen, trans_id + +    def get_doc_conflicts(self, doc_id, couch_rev=None): +        """ +        Get the conflicted versions of a document. + +        If the C{couch_rev} parameter is not None, conflicts for a specific +        document's couch revision are returned. + +        :param couch_rev: The couch document revision. +        :type couch_rev: str + +        :return: A list of conflicted versions of the document. +        :rtype: list +        """ +        # request conflicts attachment from server +        params = {} +        conflicts = [] +        if couch_rev is not None: +            params['rev'] = couch_rev  # restric document's couch revision +        else: +            # TODO: move into resource logic! +            first_entry = self.get_doc(doc_id, check_for_conflicts=True) +            conflicts.append(first_entry) + +        try: +            response = self.json_from_resource([doc_id, 'u1db_conflicts'], +                                               **params) +            return conflicts + self._build_conflicts( +                doc_id, json.loads(response.read())) +        except ResourceNotFound: +            return [] + +    def set_replica_gen_and_trans_id( +            self, other_replica_uid, other_generation, other_transaction_id): +        """ +        Set the last-known generation and transaction id for the other +        database replica. + +        We have just performed some synchronization, and we want to track what +        generation the other replica was at. See also +        get_replica_gen_and_trans_id. + +        :param other_replica_uid: The U1DB identifier for the other replica. +        :type other_replica_uid: str +        :param other_generation: The generation number for the other replica. +        :type other_generation: int +        :param other_transaction_id: The transaction id associated with the +                                     generation. 
+        :type other_transaction_id: str +        """ +        doc_id = '%s%s' % (SYNC_DOC_ID_PREFIX, other_replica_uid) +        try: +            doc = self._database[doc_id] +        except ResourceNotFound: +            doc = {'_id': doc_id} +        doc[GENERATION_KEY] = other_generation +        doc[TRANSACTION_ID_KEY] = other_transaction_id +        self._database.save(doc) + +    def get_transaction_log(self): +        """ +        This is only for the test suite, it is not part of the api. + +        :return: The complete transaction log. +        :rtype: [(str, str)] +        """ +        log = self._get_transaction_log() +        return map(lambda i: (i[1], i[2]), log) + +    def _get_gen_docs( +            self, start=0, end=9999999999, descending=None, limit=None): +        params = {} +        if descending: +            params['descending'] = 'true' +            # honor couch way of traversing the view tree in reverse order +            start, end = end, start +        params['startkey'] = _get_gen_doc_id(start) +        params['endkey'] = _get_gen_doc_id(end) +        params['include_docs'] = 'true' +        if limit: +            params['limit'] = limit +        view = self._database.view("_all_docs", **params) +        return view.rows + +    def _get_transaction_log(self, start=0, end=9999999999): +        # get current gen and trans_id +        rows = self._get_gen_docs(start=start, end=end) +        log = [] +        for row in rows: +            doc = row['doc'] +            log.append(( +                doc[GENERATION_KEY], +                doc[DOC_ID_KEY], +                doc[TRANSACTION_ID_KEY])) +        return log + +    def whats_changed(self, old_generation=0): +        """ +        Return a list of documents that have changed since old_generation. + +        :param old_generation: The generation of the database in the old +                               state. +        :type old_generation: int + +        :return: (generation, trans_id, [(doc_id, generation, trans_id),...]) +                 The current generation of the database, its associated +                 transaction id, and a list of of changed documents since +                 old_generation, represented by tuples with for each document +                 its doc_id and the generation and transaction id corresponding +                 to the last intervening change and sorted by generation (old +                 changes first) +        :rtype: (int, str, [(str, int, str)]) +        """ +        changes = [] +        cur_generation, last_trans_id = self.get_generation_info() +        relevant_tail = self._get_transaction_log(start=old_generation + 1) +        seen = set() +        for generation, doc_id, trans_id in reversed(relevant_tail): +            if doc_id not in seen: +                changes.append((doc_id, generation, trans_id)) +                seen.add(doc_id) +        changes.reverse() +        return (cur_generation, last_trans_id, changes) + +    def get_generation_info(self): +        """ +        Return the current generation. + +        :return: A tuple containing the current generation and transaction id. 
+        :rtype: (int, str) +        """ +        if self.batching and self.batch_generation: +            return self.batch_generation +        rows = self._get_gen_docs(descending=True, limit=1) +        if not rows: +            return 0, '' +        gen_doc = rows.pop()['doc'] +        return gen_doc[GENERATION_KEY], gen_doc[TRANSACTION_ID_KEY] + +    def json_from_resource(self, doc_path, **kwargs): +        """ +        Get a resource from it's path and gets a doc's JSON using provided +        parameters. + +        :param doc_path: The path to resource. +        :type doc_path: [str] + +        :return: The request's data parsed from JSON to a dict. +        :rtype: dict +        """ +        if doc_path is not None: +            resource = self._database.resource(*doc_path) +        else: +            resource = self._database.resource() +        _, _, data = resource.get_json(**kwargs) +        return data + +    def _allocate_new_generation(self, doc_id, transaction_id, save=True): +        """ +        Allocate a new generation number for a document modification. + +        We need to allocate a new generation to this document modification by +        creating a new gen doc. In order to avoid concurrent database updates +        from allocating the same new generation, we will try to create the +        document until we succeed, meaning that no other piece of code holds +        the same generation number as ours. + +        The loop below would only be executed more than once if: + +          1. there's more than one thread trying to modify the user's database, +             and + +          2. the execution of getting the current generation and saving the gen +             doc different threads get interleaved (one of them will succeed +             and the others will fail and try again). + +        Number 1 only happens when more than one user device is syncing at the +        same time. Number 2 depends on not-so-frequent coincidence of +        code execution. + +        Also, in the race between threads for a generation number there's +        always one thread that wins. so if there are N threads in the race, the +        expected number of repetitions of the loop for each thread would be +        N/2. If N is equal to the number of devices that the user has, the +        number of possible repetitions of the loop should always be low. +        """ +        while True: +            try: +                # add the gen document +                gen, _ = self.get_generation_info() +                new_gen = gen + 1 +                gen_doc = { +                    '_id': _get_gen_doc_id(new_gen), +                    GENERATION_KEY: new_gen, +                    DOC_ID_KEY: doc_id, +                    TRANSACTION_ID_KEY: transaction_id, +                } +                if save: +                    self._database.save(gen_doc) +                break  # succeeded allocating a new generation, proceed +            except ResourceConflict: +                pass  # try again! +        return gen_doc + +    def save_document(self, old_doc, doc, transaction_id): +        """ +        Put the document in the Couch backend database. + +        Note that C{old_doc} must have been fetched with the parameter +        C{check_for_conflicts} equal to True, so we can properly update the +        new document using the conflict information from the old one. + +        :param old_doc: The old document version. 
+        :type old_doc: ServerDocument +        :param doc: The document to be put. +        :type doc: ServerDocument + +        :raise RevisionConflict: Raised when trying to update a document but +                                 couch revisions mismatch. +        """ +        attachments = {}  # we save content and conflicts as attachments +        parts = []  # and we put it using couch's multipart PUT +        # save content as attachment +        if doc.is_tombstone() is False: +            content = doc.get_json() +            attachments['u1db_content'] = { +                'follows': True, +                'content_type': 'application/octet-stream', +                'length': len(content), +            } +            parts.append(content) + +        # save conflicts as attachment +        if doc.has_conflicts is True: +            conflicts = json.dumps( +                map(lambda cdoc: (cdoc.rev, cdoc.content), +                    doc.get_conflicts())) +            attachments['u1db_conflicts'] = { +                'follows': True, +                'content_type': 'application/octet-stream', +                'length': len(conflicts), +            } +            parts.append(conflicts) + +        # build the couch document +        couch_doc = { +            '_id': doc.doc_id, +            'u1db_rev': doc.rev, +            '_attachments': attachments, +        } +        # if we are updating a doc we have to add the couch doc revision +        if old_doc is not None and hasattr(old_doc, 'couch_rev'): +            couch_doc['_rev'] = old_doc.couch_rev +        # prepare the multipart PUT +        if not self.batching: +            buf = StringIO() +            envelope = MultipartWriter(buf) +            # the order in which attachments are described inside the +            # serialization of the couch document must match the order in +            # which they are actually written in the multipart structure. +            # Because of that, we use `sorted_keys=True` in the json +            # serialization (so "u1db_conflicts" comes before +            # "u1db_content" on the couch document attachments +            # description), and also reverse the order of the parts before +            # writing them, so the "conflict" part is written before the +            # "content" part. 
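+            #
+            # The resulting multipart body is therefore roughly (an
+            # illustrative sketch for a document that has both content and
+            # conflicts; the boundary string is generated at runtime):
+            #
+            #   --<boundary>
+            #   Content-Type: application/json
+            #   {"_attachments": {...}, "_id": "...", "u1db_rev": "..."}
+            #   --<boundary>
+            #   Content-Type: application/octet-stream
+            #   <u1db_conflicts JSON>
+            #   --<boundary>
+            #   Content-Type: application/octet-stream
+            #   <u1db_content JSON>
+            #   --<boundary>--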
+            envelope.add( +                'application/json', +                json.dumps(couch_doc, sort_keys=True)) +            parts.reverse() +            for part in parts: +                envelope.add('application/octet-stream', part) +            envelope.close() +            # try to save and fail if there's a revision conflict +            try: +                resource = self._new_resource() +                resource.put_json( +                    doc.doc_id, body=str(buf.getvalue()), +                    headers=envelope.headers) +            except ResourceConflict: +                raise RevisionConflict() +            self._allocate_new_generation(doc.doc_id, transaction_id) +        else: +            for name, attachment in attachments.items(): +                del attachment['follows'] +                del attachment['length'] +                index = 0 if name is 'u1db_content' else 1 +                attachment['data'] = binascii.b2a_base64( +                    parts[index]).strip() +            couch_doc['_attachments'] = attachments +            gen_doc = self._allocate_new_generation( +                doc.doc_id, transaction_id, save=False) +            self.batch_docs[doc.doc_id] = couch_doc +            self.batch_docs[gen_doc['_id']] = gen_doc +            last_gen, last_trans_id = self.batch_generation +            self.batch_generation = (last_gen + 1, transaction_id) + +    def _new_resource(self, *path): +        """ +        Return a new resource for accessing a couch database. + +        :return: A resource for accessing a couch database. +        :rtype: couchdb.http.Resource +        """ +        # Workaround for: https://leap.se/code/issues/5448 +        url = couch_urljoin(self._database.resource.url, *path) +        resource = Resource(url, Session(timeout=COUCH_TIMEOUT)) +        resource.credentials = self._database.resource.credentials +        resource.headers = self._database.resource.headers.copy() +        return resource diff --git a/src/leap/soledad/common/couch/state.py b/src/leap/soledad/common/couch/state.py new file mode 100644 index 00000000..8cbe0934 --- /dev/null +++ b/src/leap/soledad/common/couch/state.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +# state.py +# Copyright (C) 2015,2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Server state using CouchDatabase as backend. 
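+
+It provides the CouchServerState class, which the WSGI server uses to open,
+create and check per-user couch databases.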
+""" +import couchdb +import re + +from six.moves.urllib.parse import urljoin + +from leap.soledad.common.log import getLogger +from leap.soledad.common.couch import CouchDatabase +from leap.soledad.common.couch import CONFIG_DOC_ID +from leap.soledad.common.couch import SCHEMA_VERSION +from leap.soledad.common.couch import SCHEMA_VERSION_KEY +from leap.soledad.common.command import exec_validated_cmd +from leap.soledad.common.l2db.remote.server_state import ServerState +from leap.soledad.common.l2db.errors import Unauthorized +from leap.soledad.common.errors import WrongCouchSchemaVersionError +from leap.soledad.common.errors import MissingCouchConfigDocumentError + + +logger = getLogger(__name__) + + +def is_db_name_valid(name): +    """ +    Validate a user database using a regular expression. + +    :param name: database name. +    :type name: str + +    :return: boolean for name vailidity +    :rtype: bool +    """ +    db_name_regex = "^user-[a-f0-9]+$" +    return re.match(db_name_regex, name) is not None + + +class CouchServerState(ServerState): + +    """ +    Inteface of the WSGI server with the CouchDB backend. +    """ + +    def __init__(self, couch_url, create_cmd=None, +                 check_schema_versions=False): +        """ +        Initialize the couch server state. + +        :param couch_url: The URL for the couch database. +        :type couch_url: str +        :param create_cmd: Command to be executed for user db creation. It will +                           receive a properly sanitized parameter with user db +                           name and should access CouchDB with necessary +                           privileges, which server lacks for security reasons. +        :type create_cmd: str +        :param check_schema_versions: Whether to check couch schema version of +                                      user dbs. Set to False as this is only +                                      intended to run once during start-up. +        :type check_schema_versions: bool +        """ +        self.couch_url = couch_url +        self.create_cmd = create_cmd +        if check_schema_versions: +            self._check_schema_versions() + +    def _check_schema_versions(self): +        """ +        Check that all user databases use the correct couch schema. +        """ +        server = couchdb.client.Server(self.couch_url) +        for dbname in server: +            if not dbname.startswith('user-'): +                continue +            db = server[dbname] + +            # if there are documents, ensure that a config doc exists +            config_doc = db.get(CONFIG_DOC_ID) +            if config_doc: +                if config_doc[SCHEMA_VERSION_KEY] != SCHEMA_VERSION: +                    logger.error( +                        "Unsupported database schema in database %s" % dbname) +                    raise WrongCouchSchemaVersionError(dbname) +            else: +                result = db.view('_all_docs', limit=1) +                if result.total_rows != 0: +                    logger.error( +                        "Missing couch config document in database %s" +                        % dbname) +                    raise MissingCouchConfigDocumentError(dbname) + +    def open_database(self, dbname): +        """ +        Open a couch database. + +        :param dbname: The name of the database to open. +        :type dbname: str + +        :return: The SoledadBackend object. 
+        :rtype: SoledadBackend +        """ +        url = urljoin(self.couch_url, dbname) +        db = CouchDatabase.open_database(url, create=False) +        return db + +    def ensure_database(self, dbname): +        """ +        Ensure couch database exists. + +        :param dbname: The name of the database to ensure. +        :type dbname: str + +        :raise Unauthorized: If disabled or other error was raised. + +        :return: The SoledadBackend object and its replica_uid. +        :rtype: (SoledadBackend, str) +        """ +        if not self.create_cmd: +            raise Unauthorized() +        else: +            code, out = exec_validated_cmd(self.create_cmd, dbname, +                                           validator=is_db_name_valid) +            if code is not 0: +                logger.error(""" +                    Error while creating database (%s) with (%s) command. +                    Output: %s +                    Exit code: %d +                    """ % (dbname, self.create_cmd, out, code)) +                raise Unauthorized() +        db = self.open_database(dbname) +        return db, db.replica_uid + +    def delete_database(self, dbname): +        """ +        Delete couch database. + +        :param dbname: The name of the database to delete. +        :type dbname: str + +        :raise Unauthorized: Always, because Soledad server is not allowed to +                             delete databases. +        """ +        raise Unauthorized() diff --git a/src/leap/soledad/common/couch/support.py b/src/leap/soledad/common/couch/support.py new file mode 100644 index 00000000..bfc4fef6 --- /dev/null +++ b/src/leap/soledad/common/couch/support.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# support.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +import sys + + +""" +Monkey patches and temporary code that may be removed with version changes. +""" + + +# for bigcouch +# TODO: Remove if bigcouch support is dropped +class MultipartWriter(object): + +    """ +    A multipart writer adapted from python-couchdb's one so we can PUT +    documents using couch's multipart PUT. + +    This stripped down version does not allow for nested structures, and +    contains only the essential things we need to PUT SoledadDocuments to the +    couch backend. Also, please note that this is a patch. The couchdb lib has +    another implementation that works fine with CouchDB 1.6, but removing this +    now will break compatibility with bigcouch. +    """ + +    CRLF = '\r\n' + +    def __init__(self, fileobj, headers=None, boundary=None): +        """ +        Initialize the multipart writer. 
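+
+        :param fileobj: A file-like object the multipart body is written to.
+        :type fileobj: file
+        :param headers: Optional extra headers for the multipart envelope.
+        :type headers: dict
+        :param boundary: Optional boundary string; a random one is generated
+                         if not given.
+        :type boundary: str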
+        """ +        self.fileobj = fileobj +        if boundary is None: +            boundary = self._make_boundary() +        self._boundary = boundary +        self._build_headers('related', headers) + +    def add(self, mimetype, content, headers={}): +        """ +        Add a part to the multipart stream. +        """ +        self.fileobj.write('--') +        self.fileobj.write(self._boundary) +        self.fileobj.write(self.CRLF) +        headers['Content-Type'] = mimetype +        self._write_headers(headers) +        if content: +            # XXX: throw an exception if a boundary appears in the content?? +            self.fileobj.write(content) +            self.fileobj.write(self.CRLF) + +    def close(self): +        """ +        Close the multipart stream. +        """ +        self.fileobj.write('--') +        self.fileobj.write(self._boundary) +        # be careful not to have anything after '--', otherwise old couch +        # versions (including bigcouch) will fail. +        self.fileobj.write('--') + +    def _make_boundary(self): +        """ +        Create a boundary to discern multi parts. +        """ +        try: +            from uuid import uuid4 +            return '==' + uuid4().hex + '==' +        except ImportError: +            from random import randrange +            token = randrange(sys.maxint) +            format = '%%0%dd' % len(repr(sys.maxint - 1)) +            return '===============' + (format % token) + '==' + +    def _write_headers(self, headers): +        """ +        Write a part header in the buffer stream. +        """ +        if headers: +            for name in sorted(headers.keys()): +                value = headers[name] +                self.fileobj.write(name) +                self.fileobj.write(': ') +                self.fileobj.write(value) +                self.fileobj.write(self.CRLF) +        self.fileobj.write(self.CRLF) + +    def _build_headers(self, subtype, headers): +        """ +        Build the main headers of the multipart stream. + +        This is here so we can send headers separete from content using +        python-couchdb API. +        """ +        self.headers = {} +        self.headers['Content-Type'] = 'multipart/%s; boundary="%s"' % \ +                                       (subtype, self._boundary) +        if headers: +            for name in sorted(headers.keys()): +                value = headers[name] +                self.headers[name] = value diff --git a/src/leap/soledad/common/crypto.py b/src/leap/soledad/common/crypto.py new file mode 100644 index 00000000..c13c4aa7 --- /dev/null +++ b/src/leap/soledad/common/crypto.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# crypto.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +""" +Soledad common crypto bits. +""" + + +# +# Encryption schemes used for encryption. 
+# + +class EncryptionSchemes(object): + +    """ +    Representation of encryption schemes used to encrypt documents. +    """ + +    NONE = 'none' +    SYMKEY = 'symkey' +    PUBKEY = 'pubkey' + + +class UnknownEncryptionSchemeError(Exception): + +    """ +    Raised when trying to decrypt from unknown encryption schemes. +    """ +    pass + + +class EncryptionMethods(object): + +    """ +    Representation of encryption methods that can be used. +    """ + +    AES_256_CTR = 'aes-256-ctr' + + +class UnknownEncryptionMethodError(Exception): + +    """ +    Raised when trying to encrypt/decrypt with unknown method. +    """ +    pass + + +class MacMethods(object): + +    """ +    Representation of MAC methods used to authenticate document's contents. +    """ + +    HMAC = 'hmac' + + +class UnknownMacMethodError(Exception): + +    """ +    Raised when trying to authenticate document's content with unknown MAC +    mehtod. +    """ +    pass + + +class WrongMacError(Exception): + +    """ +    Raised when failing to authenticate document's contents based on MAC. +    """ + + +# +# Crypto utilities for a SoledadDocument. +# + +ENC_JSON_KEY = '_enc_json' +ENC_SCHEME_KEY = '_enc_scheme' +ENC_METHOD_KEY = '_enc_method' +ENC_IV_KEY = '_enc_iv' +MAC_KEY = '_mac' +MAC_METHOD_KEY = '_mac_method' diff --git a/src/leap/soledad/common/document.py b/src/leap/soledad/common/document.py new file mode 100644 index 00000000..6c26a29f --- /dev/null +++ b/src/leap/soledad/common/document.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +# document.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +""" +A Soledad Document is an l2db.Document with lasers. +""" + + +from .l2db import Document + + +# +# SoledadDocument +# + +class SoledadDocument(Document): + +    """ +    Encryptable and syncable document. + +    LEAP Documents can be flagged as syncable or not, so the replicas +    might not sync every document. +    """ + +    def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False, +                 syncable=True): +        """ +        Container for handling an encryptable document. + +        @param doc_id: The unique document identifier. +        @type doc_id: str +        @param rev: The revision identifier of the document. +        @type rev: str +        @param json: The JSON string for this document. +        @type json: str +        @param has_conflicts: Boolean indicating if this document has conflicts +        @type has_conflicts: bool +        @param syncable: Should this document be synced with remote replicas? +        @type syncable: bool +        """ +        Document.__init__(self, doc_id, rev, json, has_conflicts) +        self._syncable = syncable + +    def _get_syncable(self): +        """ +        Return whether this document is syncable. + +        @return: Is this document syncable? 
+        @rtype: bool +        """ +        return self._syncable + +    def _set_syncable(self, syncable=True): +        """ +        Determine if this document should be synced with remote replicas. + +        @param syncable: Should this document be synced with remote replicas? +        @type syncable: bool +        """ +        self._syncable = syncable + +    syncable = property( +        _get_syncable, +        _set_syncable, +        doc="Determine if document should be synced with server." +    ) + +    def _get_rev(self): +        """ +        Get the document revision. + +        Returning the revision as string solves the following exception in +        Twisted web: +            exceptions.TypeError: Can only pass-through bytes on Python 2 + +        @return: The document revision. +        @rtype: str +        """ +        if self._rev is None: +            return None +        return str(self._rev) + +    def _set_rev(self, rev): +        """ +        Set document revision. + +        @param rev: The new document revision. +        @type rev: bytes +        """ +        self._rev = rev + +    rev = property( +        _get_rev, +        _set_rev, +        doc="Wrapper to ensure `doc.rev` is always returned as bytes.") + + +class ServerDocument(SoledadDocument): +    """ +    This is the document used by server to hold conflicts and transactions +    on a database. + +    The goal is to ensure an atomic and consistent update of the database. +    """ + +    def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False): +        """ +        Container for handling a document that stored on server. + +        :param doc_id: The unique document identifier. +        :type doc_id: str +        :param rev: The revision identifier of the document. +        :type rev: str +        :param json: The JSON string for this document. +        :type json: str +        :param has_conflicts: Boolean indicating if this document has conflicts +        :type has_conflicts: bool +        """ +        SoledadDocument.__init__(self, doc_id, rev, json, has_conflicts) +        self._conflicts = None + +    def get_conflicts(self): +        """ +        Get the conflicted versions of the document. + +        :return: The conflicted versions of the document. +        :rtype: [ServerDocument] +        """ +        return self._conflicts or [] + +    def set_conflicts(self, conflicts): +        """ +        Set the conflicted versions of the document. + +        :param conflicts: The conflicted versions of the document. +        :type conflicts: list +        """ +        self._conflicts = conflicts +        self.has_conflicts = len(self._conflicts) > 0 + +    def add_conflict(self, doc): +        """ +        Add a conflict to this document. + +        :param doc: The conflicted version to be added. +        :type doc: Document +        """ +        if self._conflicts is None: +            raise Exception("Fetch conflicts first!") +        self._conflicts.append(doc) +        self.has_conflicts = len(self._conflicts) > 0 + +    def delete_conflicts(self, conflict_revs): +        """ +        Delete conflicted versions of this document. + +        :param conflict_revs: The conflicted revisions to be deleted. 
+        :type conflict_revs: [str] +        """ +        if self._conflicts is None: +            raise Exception("Fetch conflicts first!") +        self._conflicts = filter( +            lambda doc: doc.rev not in conflict_revs, +            self._conflicts) +        self.has_conflicts = len(self._conflicts) > 0 diff --git a/src/leap/soledad/common/errors.py b/src/leap/soledad/common/errors.py new file mode 100644 index 00000000..d543a3de --- /dev/null +++ b/src/leap/soledad/common/errors.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# errors.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +""" +Soledad errors. +""" + +from .l2db import errors +from .l2db.remote import http_errors + + +def register_exception(cls): +    """ +    A small decorator that registers exceptions in u1db maps. +    """ +    # update u1db "wire description to status" and "wire description to +    # exception" maps. +    http_errors.wire_description_to_status.update({ +        cls.wire_description: cls.status}) +    errors.wire_description_to_exc.update({ +        cls.wire_description: cls}) +    # do not modify the exception +    return cls + + +class SoledadError(errors.U1DBError): + +    """ +    Base Soledad HTTP errors. +    """ +    pass + + +# +# Authorization errors +# + + +class DatabaseAccessError(Exception): +    pass + + +@register_exception +class InvalidAuthTokenError(errors.Unauthorized): + +    """ +    Exception raised when failing to get authorization for some action because +    the provided token either does not exist in the tokens database, has a +    distinct structure from the expected one, or is associated with a user +    with a distinct uuid than the one provided by the client. +    """ + +    wire_descrition = "invalid auth token" +    status = 401 + + +# +# SoledadBackend errors +# u1db error statuses also have to be updated +http_errors.ERROR_STATUSES = set( +    http_errors.wire_description_to_status.values()) + + +class InvalidURLError(Exception): +    """ +    Exception raised when Soledad encounters a malformed URL. +    """ + + +class BackendNotReadyError(SoledadError): +    """ +    Generic exception raised when the backend is not ready to dispatch a client +    request. +    """ +    wire_description = "backend not ready" +    status = 500 + + +class WrongCouchSchemaVersionError(SoledadError): +    """ +    Raised in case there is a user database with wrong couch schema version. +    """ + + +class MissingCouchConfigDocumentError(SoledadError): +    """ +    Raised if a database has documents but lacks the couch config document. +    """ diff --git a/src/leap/soledad/common/l2db/__init__.py b/src/leap/soledad/common/l2db/__init__.py new file mode 100644 index 00000000..43d61b1d --- /dev/null +++ b/src/leap/soledad/common/l2db/__init__.py @@ -0,0 +1,694 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. 
+# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. + +"""L2DB""" + +import json + +from leap.soledad.common.l2db.errors import InvalidJSON, InvalidContent + +__version_info__ = (13, 9) +__version__ = '.'.join(map(lambda x: '%02d' % x, __version_info__)) + + +def open(path, create, document_factory=None): +    """Open a database at the given location. + +    Will raise u1db.errors.DatabaseDoesNotExist if create=False and the +    database does not already exist. + +    :param path: The filesystem path for the database to open. +    :param create: True/False, should the database be created if it doesn't +        already exist? +    :param document_factory: A function that will be called with the same +        parameters as Document.__init__. +    :return: An instance of Database. +    """ +    from leap.soledad.client._db import sqlite +    return sqlite.SQLiteDatabase.open_database( +        path, create=create, document_factory=document_factory) + + +# constraints on database names (relevant for remote access, as regex) +DBNAME_CONSTRAINTS = r"[a-zA-Z0-9][a-zA-Z0-9.-]*" + +# constraints on doc ids (as regex) +# (no slashes, and no characters outside the ascii range) +DOC_ID_CONSTRAINTS = r"[a-zA-Z0-9.%_-]+" + + +class Database(object): +    """A JSON Document data store. + +    This data store can be synchronized with other u1db.Database instances. +    """ + +    def set_document_factory(self, factory): +        """Set the document factory that will be used to create objects to be +        returned as documents by the database. + +        :param factory: A function that returns an object which at minimum must +            satisfy the same interface as does the class DocumentBase. +            Subclassing that class is the easiest way to create such +            a function. +        """ +        raise NotImplementedError(self.set_document_factory) + +    def set_document_size_limit(self, limit): +        """Set the maximum allowed document size for this database. + +        :param limit: Maximum allowed document size in bytes. +        """ +        raise NotImplementedError(self.set_document_size_limit) + +    def whats_changed(self, old_generation=0): +        """Return a list of documents that have changed since old_generation. +        This allows APPS to only store a db generation before going +        'offline', and then when coming back online they can use this +        data to update whatever extra data they are storing. + +        :param old_generation: The generation of the database in the old +            state. 
+        :return: (generation, trans_id, [(doc_id, generation, trans_id),...]) +            The current generation of the database, its associated transaction +            id, and a list of of changed documents since old_generation, +            represented by tuples with for each document its doc_id and the +            generation and transaction id corresponding to the last intervening +            change and sorted by generation (old changes first) +        """ +        raise NotImplementedError(self.whats_changed) + +    def get_doc(self, doc_id, include_deleted=False): +        """Get the JSON string for the given document. + +        :param doc_id: The unique document identifier +        :param include_deleted: If set to True, deleted documents will be +            returned with empty content. Otherwise asking for a deleted +            document will return None. +        :return: a Document object. +        """ +        raise NotImplementedError(self.get_doc) + +    def get_docs(self, doc_ids, check_for_conflicts=True, +                 include_deleted=False): +        """Get the JSON content for many documents. + +        :param doc_ids: A list of document identifiers. +        :param check_for_conflicts: If set to False, then the conflict check +            will be skipped, and 'None' will be returned instead of True/False. +        :param include_deleted: If set to True, deleted documents will be +            returned with empty content. Otherwise deleted documents will not +            be included in the results. +        :return: iterable giving the Document object for each document id +            in matching doc_ids order. +        """ +        raise NotImplementedError(self.get_docs) + +    def get_all_docs(self, include_deleted=False): +        """Get the JSON content for all documents in the database. + +        :param include_deleted: If set to True, deleted documents will be +            returned with empty content. Otherwise deleted documents will not +            be included in the results. +        :return: (generation, [Document]) +            The current generation of the database, followed by a list of all +            the documents in the database. +        """ +        raise NotImplementedError(self.get_all_docs) + +    def create_doc(self, content, doc_id=None): +        """Create a new document. + +        You can optionally specify the document identifier, but the document +        must not already exist. See 'put_doc' if you want to override an +        existing document. +        If the database specifies a maximum document size and the document +        exceeds it, create will fail and raise a DocumentTooBig exception. + +        :param content: A Python dictionary. +        :param doc_id: An optional identifier specifying the document id. +        :return: Document +        """ +        raise NotImplementedError(self.create_doc) + +    def create_doc_from_json(self, json, doc_id=None): +        """Create a new document. + +        You can optionally specify the document identifier, but the document +        must not already exist. See 'put_doc' if you want to override an +        existing document. +        If the database specifies a maximum document size and the document +        exceeds it, create will fail and raise a DocumentTooBig exception. + +        :param json: The JSON document string +        :param doc_id: An optional identifier specifying the document id. 
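+
+        Example (a sketch; 'db' is a hypothetical open database instance):
+
+            doc = db.create_doc_from_json('{"title": "my note"}')
+            # doc.doc_id and doc.rev are filled in by the database
+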
+        :return: Document +        """ +        raise NotImplementedError(self.create_doc_from_json) + +    def put_doc(self, doc): +        """Update a document. +        If the document currently has conflicts, put will fail. +        If the database specifies a maximum document size and the document +        exceeds it, put will fail and raise a DocumentTooBig exception. + +        :param doc: A Document with new content. +        :return: new_doc_rev - The new revision identifier for the document. +            The Document object will also be updated. +        """ +        raise NotImplementedError(self.put_doc) + +    def delete_doc(self, doc): +        """Mark a document as deleted. +        Will abort if the current revision doesn't match doc.rev. +        This will also set doc.content to None. +        """ +        raise NotImplementedError(self.delete_doc) + +    def create_index(self, index_name, *index_expressions): +        """Create an named index, which can then be queried for future lookups. +        Creating an index which already exists is not an error, and is cheap. +        Creating an index which does not match the index_expressions of the +        existing index is an error. +        Creating an index will block until the expressions have been evaluated +        and the index generated. + +        :param index_name: A unique name which can be used as a key prefix +        :param index_expressions: index expressions defining the index +            information. + +            Examples: + +            "fieldname", or "fieldname.subfieldname" to index alphabetically +            sorted on the contents of a field. + +            "number(fieldname, width)", "lower(fieldname)" +        """ +        raise NotImplementedError(self.create_index) + +    def delete_index(self, index_name): +        """Remove a named index. + +        :param index_name: The name of the index we are removing +        """ +        raise NotImplementedError(self.delete_index) + +    def list_indexes(self): +        """List the definitions of all known indexes. + +        :return: A list of [('index-name', ['field', 'field2'])] definitions. +        """ +        raise NotImplementedError(self.list_indexes) + +    def get_from_index(self, index_name, *key_values): +        """Return documents that match the keys supplied. + +        You must supply exactly the same number of values as have been defined +        in the index. It is possible to do a prefix match by using '*' to +        indicate a wildcard match. You can only supply '*' to trailing entries, +        (eg 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) +        It is also possible to append a '*' to the last supplied value (eg +        'val*', '*', '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + +        :param index_name: The index to query +        :param key_values: values to match. eg, if you have +            an index with 3 fields then you would have: +            get_from_index(index_name, val1, val2, val3) +        :return: List of [Document] +        """ +        raise NotImplementedError(self.get_from_index) + +    def get_range_from_index(self, index_name, start_value, end_value): +        """Return documents that fall within the specified range. + +        Both ends of the range are inclusive. For both start_value and +        end_value, one must supply exactly the same number of values as have +        been defined in the index, or pass None. 
In case of a single column +        index, a string is accepted as an alternative for a tuple with a single +        value. It is possible to do a prefix match by using '*' to indicate +        a wildcard match. You can only supply '*' to trailing entries, (eg +        'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) It is also +        possible to append a '*' to the last supplied value (eg 'val*', '*', +        '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + +        :param index_name: The index to query +        :param start_values: tuples of values that define the lower bound of +            the range. eg, if you have an index with 3 fields then you would +            have: (val1, val2, val3) +        :param end_values: tuples of values that define the upper bound of the +            range. eg, if you have an index with 3 fields then you would have: +            (val1, val2, val3) +        :return: List of [Document] +        """ +        raise NotImplementedError(self.get_range_from_index) + +    def get_index_keys(self, index_name): +        """Return all keys under which documents are indexed in this index. + +        :param index_name: The index to query +        :return: [] A list of tuples of indexed keys. +        """ +        raise NotImplementedError(self.get_index_keys) + +    def get_doc_conflicts(self, doc_id): +        """Get the list of conflicts for the given document. + +        The order of the conflicts is such that the first entry is the value +        that would be returned by "get_doc". + +        :return: [doc] A list of the Document entries that are conflicted. +        """ +        raise NotImplementedError(self.get_doc_conflicts) + +    def resolve_doc(self, doc, conflicted_doc_revs): +        """Mark a document as no longer conflicted. + +        We take the list of revisions that the client knows about that it is +        superseding. This may be a different list from the actual current +        conflicts, in which case only those are removed as conflicted.  This +        may fail if the conflict list is significantly different from the +        supplied information. (sync could have happened in the background from +        the time you GET_DOC_CONFLICTS until the point where you RESOLVE) + +        :param doc: A Document with the new content to be inserted. +        :param conflicted_doc_revs: A list of revisions that the new content +            supersedes. +        """ +        raise NotImplementedError(self.resolve_doc) + +    def get_sync_target(self): +        """Return a SyncTarget object, for another u1db to synchronize with. + +        :return: An instance of SyncTarget. +        """ +        raise NotImplementedError(self.get_sync_target) + +    def close(self): +        """Release any resources associated with this database.""" +        raise NotImplementedError(self.close) + +    def sync(self, url, creds=None, autocreate=True): +        """Synchronize documents with remote replica exposed at url. + +        :param url: the url of the target replica to sync with. +        :param creds: optional dictionary giving credentials +            to authorize the operation with the server. For using OAuth +            the form of creds is: +                {'oauth': { +                 'consumer_key': ..., +                 'consumer_secret': ..., +                 'token_key': ..., +                 'token_secret': ... +                }} +        :param autocreate: ask the target to create the db if non-existent. 
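+
+        Example (a sketch; the URL and credential values are placeholders):
+
+            local_gen = db.sync('https://example.org/user-db',
+                                creds={'oauth': {
+                                    'consumer_key': 'ck',
+                                    'consumer_secret': 'cs',
+                                    'token_key': 'tk',
+                                    'token_secret': 'ts'}})
+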
+        :return: local_gen_before_sync The local generation before the
+            synchronisation was performed. This is useful to pass into
+            whats_changed, if an application wants to know which documents were
+            affected by a synchronisation.
+        """
+        from leap.soledad.common.l2db.sync import Synchronizer
+        from leap.soledad.common.l2db.remote.http_target import HTTPSyncTarget
+        return Synchronizer(self, HTTPSyncTarget(url, creds=creds)).sync(
+            autocreate=autocreate)
+
+    def _get_replica_gen_and_trans_id(self, other_replica_uid):
+        """Return the last known generation and transaction id for the other db
+        replica.
+
+        When you do a synchronization with another replica, the Database keeps
+        track of what generation the other database replica was at, and what
+        the associated transaction id was.  This is used to determine what data
+        needs to be sent, and if two databases are claiming to be the same
+        replica.
+
+        :param other_replica_uid: The identifier for the other replica.
+        :return: (gen, trans_id) The generation and transaction id we
+            encountered during synchronization. If we've never synchronized
+            with the replica, this is (0, '').
+        """
+        raise NotImplementedError(self._get_replica_gen_and_trans_id)
+
+    def _set_replica_gen_and_trans_id(self, other_replica_uid,
+                                      other_generation, other_transaction_id):
+        """Set the last-known generation and transaction id for the other
+        database replica.
+
+        We have just performed some synchronization, and we want to track what
+        generation the other replica was at. See also
+        _get_replica_gen_and_trans_id.
+        :param other_replica_uid: The U1DB identifier for the other replica.
+        :param other_generation: The generation number for the other replica.
+        :param other_transaction_id: The transaction id associated with the
+            generation.
+        """
+        raise NotImplementedError(self._set_replica_gen_and_trans_id)
+
+    def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen,
+                          replica_trans_id=''):
+        """Insert/update document into the database with a given revision.
+
+        This api is used during synchronization operations.
+
+        If a document would conflict and save_conflict is set to True, the
+        content will be selected as the 'current' content for doc.doc_id,
+        even though doc.rev doesn't supersede the currently stored revision.
+        The currently stored document will be added to the list of conflict
+        alternatives for the given doc_id.
+
+        This forces the new content to be 'current' so that we get convergence
+        after synchronizing, even if people don't resolve conflicts. Users can
+        then notice that their content is out of date, update it, and
+        synchronize again. (The alternative is that users could synchronize and
+        think the data has propagated, but their local copy looks fine, and the
+        remote copy is never updated again.)
+
+        :param doc: A Document object
+        :param save_conflict: If this document is a conflict, do you want to
+            save it as a conflict, or just ignore it.
+        :param replica_uid: A unique replica identifier.
+        :param replica_gen: The generation of the replica corresponding to
+            this document.
The replica arguments are optional, but are used +            during synchronization. +        :param replica_trans_id: The transaction_id associated with the +            generation. +        :return: (state, at_gen) -  If we don't have doc_id already, +            or if doc_rev supersedes the existing document revision, +            then the content will be inserted, and state is 'inserted'. +            If doc_rev is less than or equal to the existing revision, +            then the put is ignored and state is respecitvely 'superseded' +            or 'converged'. +            If doc_rev is not strictly superseded or supersedes, then +            state is 'conflicted'. The document will not be inserted if +            save_conflict is False. +            For 'inserted' or 'converged', at_gen is the insertion/current +            generation. +        """ +        raise NotImplementedError(self._put_doc_if_newer) + + +class DocumentBase(object): +    """Container for handling a single document. + +    :ivar doc_id: Unique identifier for this document. +    :ivar rev: The revision identifier of the document. +    :ivar json_string: The JSON string for this document. +    :ivar has_conflicts: Boolean indicating if this document has conflicts +    """ + +    def __init__(self, doc_id, rev, json_string, has_conflicts=False): +        self.doc_id = doc_id +        self.rev = rev +        if json_string is not None: +            try: +                value = json.loads(json_string) +            except ValueError: +                raise InvalidJSON +            if not isinstance(value, dict): +                raise InvalidJSON +        self._json = json_string +        self.has_conflicts = has_conflicts + +    def same_content_as(self, other): +        """Compare the content of two documents.""" +        if self._json: +            c1 = json.loads(self._json) +        else: +            c1 = None +        if other._json: +            c2 = json.loads(other._json) +        else: +            c2 = None +        return c1 == c2 + +    def __repr__(self): +        if self.has_conflicts: +            extra = ', conflicted' +        else: +            extra = '' +        return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id, +                                     self.rev, extra, self.get_json()) + +    def __hash__(self): +        raise NotImplementedError(self.__hash__) + +    def __eq__(self, other): +        if not isinstance(other, Document): +            return NotImplemented +        return ( +            self.doc_id == other.doc_id and self.rev == other.rev and +            self.same_content_as(other) and self.has_conflicts == +            other.has_conflicts) + +    def __lt__(self, other): +        """This is meant for testing, not part of the official api. + +        It is implemented so that sorted([Document, Document]) can be used. +        It doesn't imply that users would want their documents to be sorted in +        this order. +        """ +        # Since this is just for testing, we don't worry about comparing +        # against things that aren't a Document. 
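+        # Ordering is by doc_id, then rev, then serialized content (which may
+        # be None for tombstones).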
+        return ((self.doc_id, self.rev, self.get_json()) < +                (other.doc_id, other.rev, other.get_json())) + +    def get_json(self): +        """Get the json serialization of this document.""" +        if self._json is not None: +            return self._json +        return None + +    def get_size(self): +        """Calculate the total size of the document.""" +        size = 0 +        json = self.get_json() +        if json: +            size += len(json) +        if self.rev: +            size += len(self.rev) +        if self.doc_id: +            size += len(self.doc_id) +        return size + +    def set_json(self, json_string): +        """Set the json serialization of this document.""" +        if json_string is not None: +            try: +                value = json.loads(json_string) +            except ValueError: +                raise InvalidJSON +            if not isinstance(value, dict): +                raise InvalidJSON +        self._json = json_string + +    def make_tombstone(self): +        """Make this document into a tombstone.""" +        self._json = None + +    def is_tombstone(self): +        """Return True if the document is a tombstone, False otherwise.""" +        if self._json is not None: +            return False +        return True + + +class Document(DocumentBase): +    """Container for handling a single document. + +    :ivar doc_id: Unique identifier for this document. +    :ivar rev: The revision identifier of the document. +    :ivar json: The JSON string for this document. +    :ivar has_conflicts: Boolean indicating if this document has conflicts +    """ + +    # The following part of the API is optional: no implementation is forced to +    # have it but if the language supports dictionaries/hashtables, it makes +    # Documents a lot more user friendly. + +    def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False): +        # TODO: We convert the json in the superclass to check its validity so +        # we might as well set _content here directly since the price is +        # already being paid. 
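+        # Only one of _json / _content holds the data at any given time;
+        # _get_content() lazily converts _json into a dict on first access.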
+        super(Document, self).__init__(doc_id, rev, json, has_conflicts) +        self._content = None + +    def same_content_as(self, other): +        """Compare the content of two documents.""" +        if self._json: +            c1 = json.loads(self._json) +        else: +            c1 = self._content +        if other._json: +            c2 = json.loads(other._json) +        else: +            c2 = other._content +        return c1 == c2 + +    def get_json(self): +        """Get the json serialization of this document.""" +        json_string = super(Document, self).get_json() +        if json_string is not None: +            return json_string +        if self._content is not None: +            return json.dumps(self._content) +        return None + +    def set_json(self, json): +        """Set the json serialization of this document.""" +        self._content = None +        super(Document, self).set_json(json) + +    def make_tombstone(self): +        """Make this document into a tombstone.""" +        self._content = None +        super(Document, self).make_tombstone() + +    def is_tombstone(self): +        """Return True if the document is a tombstone, False otherwise.""" +        if self._content is not None: +            return False +        return super(Document, self).is_tombstone() + +    def _get_content(self): +        """Get the dictionary representing this document.""" +        if self._json is not None: +            self._content = json.loads(self._json) +            self._json = None +        if self._content is not None: +            return self._content +        return None + +    def _set_content(self, content): +        """Set the dictionary representing this document.""" +        try: +            tmp = json.dumps(content) +        except TypeError: +            raise InvalidContent( +                "Can not be converted to JSON: %r" % (content,)) +        if not tmp.startswith('{'): +            raise InvalidContent( +                "Can not be converted to a JSON object: %r." % (content,)) +        # We might as well store the JSON at this point since we did the work +        # of encoding it, and it doesn't lose any information. +        self._json = tmp +        self._content = None + +    content = property( +        _get_content, _set_content, doc="Content of the Document.") + +    # End of optional part. + + +class SyncTarget(object): +    """Functionality for using a Database as a synchronization target.""" + +    def get_sync_info(self, source_replica_uid): +        """Return information about known state. + +        Return the replica_uid and the current database generation of this +        database, and the last-seen database generation for source_replica_uid + +        :param source_replica_uid: Another replica which we might have +            synchronized with in the past. +        :return: (target_replica_uid, target_replica_generation, +            target_trans_id, source_replica_last_known_generation, +            source_replica_last_known_transaction_id) +        """ +        raise NotImplementedError(self.get_sync_info) + +    def record_sync_info(self, source_replica_uid, source_replica_generation, +                         source_replica_transaction_id): +        """Record tip information for another replica. + +        After sync_exchange has been processed, the caller will have +        received new content from this replica. 
This call allows the +        source replica instigating the sync to inform us what their +        generation became after applying the documents we returned. + +        This is used to allow future sync operations to not need to repeat data +        that we just talked about. It also means that if this is called at the +        wrong time, there can be database records that will never be +        synchronized. + +        :param source_replica_uid: The identifier for the source replica. +        :param source_replica_generation: +            The database generation for the source replica. +        :param source_replica_transaction_id: The transaction id associated +            with the source replica generation. +        """ +        raise NotImplementedError(self.record_sync_info) + +    def sync_exchange(self, docs_by_generation, source_replica_uid, +                      last_known_generation, last_known_trans_id, +                      return_doc_cb, ensure_callback=None): +        """Incorporate the documents sent from the source replica. + +        This is not meant to be called by client code directly, but is used as +        part of sync(). + +        This adds docs to the local store, and determines documents that need +        to be returned to the source replica. + +        Documents must be supplied in docs_by_generation paired with +        the generation of their latest change in order from the oldest +        change to the newest, that means from the oldest generation to +        the newest. + +        Documents are also returned paired with the generation of +        their latest change in order from the oldest change to the +        newest. + +        :param docs_by_generation: A list of [(Document, generation, +            transaction_id)] tuples indicating documents which should be +            updated on this replica paired with the generation and transaction +            id of their latest change. +        :param source_replica_uid: The source replica's identifier +        :param last_known_generation: The last generation that the source +            replica knows about this target replica +        :param last_known_trans_id: The last transaction id that the source +            replica knows about this target replica +        :param: return_doc_cb(doc, gen): is a callback +            used to return documents to the source replica, it will +            be invoked in turn with Documents that have changed since +            last_known_generation together with the generation of +            their last change. +        :param: ensure_callback(replica_uid): if set the target may create +            the target db if not yet existent, the callback can then +            be used to inform of the created db replica uid. +        :return: new_generation - After applying docs_by_generation, this is +            the current generation for this replica +        """ +        raise NotImplementedError(self.sync_exchange) + +    def _set_trace_hook(self, cb): +        """Set a callback that will be invoked to trace database actions. + +        The callback will be passed a string indicating the current state, and +        the sync target object.  Implementations do not have to implement this +        api, it is used by the test suite. + +        :param cb: A callable that takes cb(state) +        """ +        raise NotImplementedError(self._set_trace_hook) + +    def _set_trace_hook_shallow(self, cb): +        """Set a callback that will be invoked to trace database actions. 
+ +        Similar to _set_trace_hook, for implementations that don't offer +        state changes from the inner working of sync_exchange(). + +        :param cb: A callable that takes cb(state) +        """ +        self._set_trace_hook(cb) diff --git a/src/leap/soledad/common/l2db/backends/__init__.py b/src/leap/soledad/common/l2db/backends/__init__.py new file mode 100644 index 00000000..c731c3d3 --- /dev/null +++ b/src/leap/soledad/common/l2db/backends/__init__.py @@ -0,0 +1,204 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. + +"""Abstract classes and common implementations for the backends.""" + +import re +import json +import uuid + +from leap.soledad.common import l2db +from leap.soledad.common.l2db import sync as l2db_sync +from leap.soledad.common.l2db import errors +from leap.soledad.common.l2db.vectorclock import VectorClockRev + + +check_doc_id_re = re.compile("^" + l2db.DOC_ID_CONSTRAINTS + "$", re.UNICODE) + + +class CommonSyncTarget(l2db_sync.LocalSyncTarget): +    pass + + +class CommonBackend(l2db.Database): + +    document_size_limit = 0 + +    def _allocate_doc_id(self): +        """Generate a unique identifier for this document.""" +        return 'D-' + uuid.uuid4().hex  # 'D-' stands for document + +    def _allocate_transaction_id(self): +        return 'T-' + uuid.uuid4().hex  # 'T-' stands for transaction + +    def _allocate_doc_rev(self, old_doc_rev): +        vcr = VectorClockRev(old_doc_rev) +        vcr.increment(self._replica_uid) +        return vcr.as_str() + +    def _check_doc_id(self, doc_id): +        if not check_doc_id_re.match(doc_id): +            raise errors.InvalidDocId() + +    def _check_doc_size(self, doc): +        if not self.document_size_limit: +            return +        if doc.get_size() > self.document_size_limit: +            raise errors.DocumentTooBig + +    def _get_generation(self): +        """Return the current generation. + +        """ +        raise NotImplementedError(self._get_generation) + +    def _get_generation_info(self): +        """Return the current generation and transaction id. + +        """ +        raise NotImplementedError(self._get_generation_info) + +    def _get_doc(self, doc_id, check_for_conflicts=False): +        """Extract the document from storage. + +        This can return None if the document doesn't exist. 
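+
+        A behavioural sketch (illustrative, using the in-memory backend from
+        l2db.backends.inmemory):
+
+            db = InMemoryDatabase('replica-1')
+            db._get_doc('no-such-id')   # returns None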
+        """ +        raise NotImplementedError(self._get_doc) + +    def _has_conflicts(self, doc_id): +        """Return True if the doc has conflicts, False otherwise.""" +        raise NotImplementedError(self._has_conflicts) + +    def create_doc(self, content, doc_id=None): +        if not isinstance(content, dict): +            raise errors.InvalidContent +        json_string = json.dumps(content) +        return self.create_doc_from_json(json_string, doc_id) + +    def create_doc_from_json(self, json, doc_id=None): +        if doc_id is None: +            doc_id = self._allocate_doc_id() +        doc = self._factory(doc_id, None, json) +        self.put_doc(doc) +        return doc + +    def _get_transaction_log(self): +        """This is only for the test suite, it is not part of the api.""" +        raise NotImplementedError(self._get_transaction_log) + +    def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): +        raise NotImplementedError(self._put_and_update_indexes) + +    def get_docs(self, doc_ids, check_for_conflicts=True, +                 include_deleted=False): +        for doc_id in doc_ids: +            doc = self._get_doc( +                doc_id, check_for_conflicts=check_for_conflicts) +            if doc.is_tombstone() and not include_deleted: +                continue +            yield doc + +    def _get_trans_id_for_gen(self, generation): +        """Get the transaction id corresponding to a particular generation. + +        Raises an InvalidGeneration when the generation does not exist. + +        """ +        raise NotImplementedError(self._get_trans_id_for_gen) + +    def validate_gen_and_trans_id(self, generation, trans_id): +        """Validate the generation and transaction id. + +        Raises an InvalidGeneration when the generation does not exist, and an +        InvalidTransactionId when it does but with a different transaction id. + +        """ +        if generation == 0: +            return +        known_trans_id = self._get_trans_id_for_gen(generation) +        if known_trans_id != trans_id: +            raise errors.InvalidTransactionId + +    def _validate_source(self, other_replica_uid, other_generation, +                         other_transaction_id): +        """Validate the new generation and transaction id. + +        other_generation must be greater than what we have stored for this +        replica, *or* it must be the same and the transaction_id must be the +        same as well. 
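+
+        For example (illustrative): with a stored state of (5, 'T-abc') for
+        the other replica, an incoming generation 4 raises InvalidGeneration,
+        generation 5 with a different transaction id raises
+        InvalidTransactionId, and generation 6 is accepted.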
+        """ +        (old_generation, +         old_transaction_id) = self._get_replica_gen_and_trans_id( +             other_replica_uid) +        if other_generation < old_generation: +            raise errors.InvalidGeneration +        if other_generation > old_generation: +            return +        if other_transaction_id == old_transaction_id: +            return +        raise errors.InvalidTransactionId + +    def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, +                          replica_trans_id=''): +        cur_doc = self._get_doc(doc.doc_id) +        doc_vcr = VectorClockRev(doc.rev) +        if cur_doc is None: +            cur_vcr = VectorClockRev(None) +        else: +            cur_vcr = VectorClockRev(cur_doc.rev) +        self._validate_source(replica_uid, replica_gen, replica_trans_id) +        if doc_vcr.is_newer(cur_vcr): +            rev = doc.rev +            self._prune_conflicts(doc, doc_vcr) +            if doc.rev != rev: +                # conflicts have been autoresolved +                state = 'superseded' +            else: +                state = 'inserted' +            self._put_and_update_indexes(cur_doc, doc) +        elif doc.rev == cur_doc.rev: +            # magical convergence +            state = 'converged' +        elif cur_vcr.is_newer(doc_vcr): +            # Don't add this to seen_ids, because we have something newer, +            # so we should send it back, and we should not generate a +            # conflict +            state = 'superseded' +        elif cur_doc.same_content_as(doc): +            # the documents have been edited to the same thing at both ends +            doc_vcr.maximize(cur_vcr) +            doc_vcr.increment(self._replica_uid) +            doc.rev = doc_vcr.as_str() +            self._put_and_update_indexes(cur_doc, doc) +            state = 'superseded' +        else: +            state = 'conflicted' +            if save_conflict: +                self._force_doc_sync_conflict(doc) +        if replica_uid is not None and replica_gen is not None: +            self._do_set_replica_gen_and_trans_id( +                replica_uid, replica_gen, replica_trans_id) +        return state, self._get_generation() + +    def _ensure_maximal_rev(self, cur_rev, extra_revs): +        vcr = VectorClockRev(cur_rev) +        for rev in extra_revs: +            vcr.maximize(VectorClockRev(rev)) +        vcr.increment(self._replica_uid) +        return vcr.as_str() + +    def set_document_size_limit(self, limit): +        self.document_size_limit = limit diff --git a/src/leap/soledad/common/l2db/backends/inmemory.py b/src/leap/soledad/common/l2db/backends/inmemory.py new file mode 100644 index 00000000..6fd251af --- /dev/null +++ b/src/leap/soledad/common/l2db/backends/inmemory.py @@ -0,0 +1,466 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. 
+ +"""The in-memory Database class for U1DB.""" + +import json + +from leap.soledad.common.l2db import ( +    Document, errors, +    query_parser, vectorclock) +from leap.soledad.common.l2db.backends import CommonBackend, CommonSyncTarget + + +def get_prefix(value): +    key_prefix = '\x01'.join(value) +    return key_prefix.rstrip('*') + + +class InMemoryDatabase(CommonBackend): +    """A database that only stores the data internally.""" + +    def __init__(self, replica_uid, document_factory=None): +        self._transaction_log = [] +        self._docs = {} +        # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner' +        self._conflicts = {} +        self._other_generations = {} +        self._indexes = {} +        self._replica_uid = replica_uid +        self._factory = document_factory or Document + +    def _set_replica_uid(self, replica_uid): +        """Force the replica_uid to be set.""" +        self._replica_uid = replica_uid + +    def set_document_factory(self, factory): +        self._factory = factory + +    def close(self): +        # This is a no-op, We don't want to free the data because one client +        # may be closing it, while another wants to inspect the results. +        pass + +    def _get_replica_gen_and_trans_id(self, other_replica_uid): +        return self._other_generations.get(other_replica_uid, (0, '')) + +    def _set_replica_gen_and_trans_id(self, other_replica_uid, +                                      other_generation, other_transaction_id): +        self._do_set_replica_gen_and_trans_id( +            other_replica_uid, other_generation, other_transaction_id) + +    def _do_set_replica_gen_and_trans_id(self, other_replica_uid, +                                         other_generation, +                                         other_transaction_id): +        # TODO: to handle race conditions, we may want to check if the current +        #       value is greater than this new value. +        self._other_generations[other_replica_uid] = (other_generation, +                                                      other_transaction_id) + +    def get_sync_target(self): +        return InMemorySyncTarget(self) + +    def _get_transaction_log(self): +        # snapshot! 
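+        # Return a copy so that callers cannot mutate the internal log.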
+        return self._transaction_log[:] + +    def _get_generation(self): +        return len(self._transaction_log) + +    def _get_generation_info(self): +        if not self._transaction_log: +            return 0, '' +        return len(self._transaction_log), self._transaction_log[-1][1] + +    def _get_trans_id_for_gen(self, generation): +        if generation == 0: +            return '' +        if generation > len(self._transaction_log): +            raise errors.InvalidGeneration +        return self._transaction_log[generation - 1][1] + +    def put_doc(self, doc): +        if doc.doc_id is None: +            raise errors.InvalidDocId() +        self._check_doc_id(doc.doc_id) +        self._check_doc_size(doc) +        old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) +        if old_doc and old_doc.has_conflicts: +            raise errors.ConflictedDoc() +        if old_doc and doc.rev is None and old_doc.is_tombstone(): +            new_rev = self._allocate_doc_rev(old_doc.rev) +        else: +            if old_doc is not None: +                if old_doc.rev != doc.rev: +                    raise errors.RevisionConflict() +            else: +                if doc.rev is not None: +                    raise errors.RevisionConflict() +            new_rev = self._allocate_doc_rev(doc.rev) +        doc.rev = new_rev +        self._put_and_update_indexes(old_doc, doc) +        return new_rev + +    def _put_and_update_indexes(self, old_doc, doc): +        for index in self._indexes.itervalues(): +            if old_doc is not None and not old_doc.is_tombstone(): +                index.remove_json(old_doc.doc_id, old_doc.get_json()) +            if not doc.is_tombstone(): +                index.add_json(doc.doc_id, doc.get_json()) +        trans_id = self._allocate_transaction_id() +        self._docs[doc.doc_id] = (doc.rev, doc.get_json()) +        self._transaction_log.append((doc.doc_id, trans_id)) + +    def _get_doc(self, doc_id, check_for_conflicts=False): +        try: +            doc_rev, content = self._docs[doc_id] +        except KeyError: +            return None +        doc = self._factory(doc_id, doc_rev, content) +        if check_for_conflicts: +            doc.has_conflicts = (doc.doc_id in self._conflicts) +        return doc + +    def _has_conflicts(self, doc_id): +        return doc_id in self._conflicts + +    def get_doc(self, doc_id, include_deleted=False): +        doc = self._get_doc(doc_id, check_for_conflicts=True) +        if doc is None: +            return None +        if doc.is_tombstone() and not include_deleted: +            return None +        return doc + +    def get_all_docs(self, include_deleted=False): +        """Return all documents in the database.""" +        generation = self._get_generation() +        results = [] +        for doc_id, (doc_rev, content) in self._docs.items(): +            if content is None and not include_deleted: +                continue +            doc = self._factory(doc_id, doc_rev, content) +            doc.has_conflicts = self._has_conflicts(doc_id) +            results.append(doc) +        return (generation, results) + +    def get_doc_conflicts(self, doc_id): +        if doc_id not in self._conflicts: +            return [] +        result = [self._get_doc(doc_id)] +        result[0].has_conflicts = True +        result.extend([self._factory(doc_id, rev, content) +                       for rev, content in self._conflicts[doc_id]]) +        return result + +    def _replace_conflicts(self, 
doc, conflicts): +        if not conflicts: +            del self._conflicts[doc.doc_id] +        else: +            self._conflicts[doc.doc_id] = conflicts +        doc.has_conflicts = bool(conflicts) + +    def _prune_conflicts(self, doc, doc_vcr): +        if self._has_conflicts(doc.doc_id): +            autoresolved = False +            remaining_conflicts = [] +            cur_conflicts = self._conflicts[doc.doc_id] +            for c_rev, c_doc in cur_conflicts: +                c_vcr = vectorclock.VectorClockRev(c_rev) +                if doc_vcr.is_newer(c_vcr): +                    continue +                if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)): +                    doc_vcr.maximize(c_vcr) +                    autoresolved = True +                    continue +                remaining_conflicts.append((c_rev, c_doc)) +            if autoresolved: +                doc_vcr.increment(self._replica_uid) +                doc.rev = doc_vcr.as_str() +            self._replace_conflicts(doc, remaining_conflicts) + +    def resolve_doc(self, doc, conflicted_doc_revs): +        cur_doc = self._get_doc(doc.doc_id) +        if cur_doc is None: +            cur_rev = None +        else: +            cur_rev = cur_doc.rev +        new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs) +        superseded_revs = set(conflicted_doc_revs) +        remaining_conflicts = [] +        cur_conflicts = self._conflicts[doc.doc_id] +        for c_rev, c_doc in cur_conflicts: +            if c_rev in superseded_revs: +                continue +            remaining_conflicts.append((c_rev, c_doc)) +        doc.rev = new_rev +        if cur_rev in superseded_revs: +            self._put_and_update_indexes(cur_doc, doc) +        else: +            remaining_conflicts.append((new_rev, doc.get_json())) +        self._replace_conflicts(doc, remaining_conflicts) + +    def delete_doc(self, doc): +        if doc.doc_id not in self._docs: +            raise errors.DocumentDoesNotExist +        if self._docs[doc.doc_id][1] in ('null', None): +            raise errors.DocumentAlreadyDeleted +        doc.make_tombstone() +        self.put_doc(doc) + +    def create_index(self, index_name, *index_expressions): +        if index_name in self._indexes: +            if self._indexes[index_name]._definition == list( +                    index_expressions): +                return +            raise errors.IndexNameTakenError +        index = InMemoryIndex(index_name, list(index_expressions)) +        for doc_id, (doc_rev, doc) in self._docs.iteritems(): +            if doc is not None: +                index.add_json(doc_id, doc) +        self._indexes[index_name] = index + +    def delete_index(self, index_name): +        try: +            del self._indexes[index_name] +        except KeyError: +            pass + +    def list_indexes(self): +        definitions = [] +        for idx in self._indexes.itervalues(): +            definitions.append((idx._name, idx._definition)) +        return definitions + +    def get_from_index(self, index_name, *key_values): +        try: +            index = self._indexes[index_name] +        except KeyError: +            raise errors.IndexDoesNotExist +        doc_ids = index.lookup(key_values) +        result = [] +        for doc_id in doc_ids: +            result.append(self._get_doc(doc_id, check_for_conflicts=True)) +        return result + +    def get_range_from_index(self, index_name, start_value=None, +                             
end_value=None): +        """Return all documents with key values in the specified range.""" +        try: +            index = self._indexes[index_name] +        except KeyError: +            raise errors.IndexDoesNotExist +        if isinstance(start_value, basestring): +            start_value = (start_value,) +        if isinstance(end_value, basestring): +            end_value = (end_value,) +        doc_ids = index.lookup_range(start_value, end_value) +        result = [] +        for doc_id in doc_ids: +            result.append(self._get_doc(doc_id, check_for_conflicts=True)) +        return result + +    def get_index_keys(self, index_name): +        try: +            index = self._indexes[index_name] +        except KeyError: +            raise errors.IndexDoesNotExist +        keys = index.keys() +        # XXX inefficiency warning +        return list(set([tuple(key.split('\x01')) for key in keys])) + +    def whats_changed(self, old_generation=0): +        changes = [] +        relevant_tail = self._transaction_log[old_generation:] +        # We don't use len(self._transaction_log) because _transaction_log may +        # get mutated by a concurrent operation. +        cur_generation = old_generation + len(relevant_tail) +        last_trans_id = '' +        if relevant_tail: +            last_trans_id = relevant_tail[-1][1] +        elif self._transaction_log: +            last_trans_id = self._transaction_log[-1][1] +        seen = set() +        generation = cur_generation +        for doc_id, trans_id in reversed(relevant_tail): +            if doc_id not in seen: +                changes.append((doc_id, generation, trans_id)) +                seen.add(doc_id) +            generation -= 1 +        changes.reverse() +        return (cur_generation, last_trans_id, changes) + +    def _force_doc_sync_conflict(self, doc): +        my_doc = self._get_doc(doc.doc_id) +        self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) +        self._conflicts.setdefault(doc.doc_id, []).append( +            (my_doc.rev, my_doc.get_json())) +        doc.has_conflicts = True +        self._put_and_update_indexes(my_doc, doc) + + +class InMemoryIndex(object): +    """Interface for managing an Index.""" + +    def __init__(self, index_name, index_definition): +        self._name = index_name +        self._definition = index_definition +        self._values = {} +        parser = query_parser.Parser() +        self._getters = parser.parse_all(self._definition) + +    def evaluate_json(self, doc): +        """Determine the 'key' after applying this index to the doc.""" +        raw = json.loads(doc) +        return self.evaluate(raw) + +    def evaluate(self, obj): +        """Evaluate a dict object, applying this definition.""" +        all_rows = [[]] +        for getter in self._getters: +            new_rows = [] +            keys = getter.get(obj) +            if not keys: +                return [] +            for key in keys: +                new_rows.extend([row + [key] for row in all_rows]) +            all_rows = new_rows +        all_rows = ['\x01'.join(row) for row in all_rows] +        return all_rows + +    def add_json(self, doc_id, doc): +        """Add this json doc to the index.""" +        keys = self.evaluate_json(doc) +        if not keys: +            return +        for key in keys: +            self._values.setdefault(key, []).append(doc_id) + +    def remove_json(self, doc_id, doc): +        """Remove this json doc from the index.""" +        keys = 
self.evaluate_json(doc) +        if keys: +            for key in keys: +                doc_ids = self._values[key] +                doc_ids.remove(doc_id) +                if not doc_ids: +                    del self._values[key] + +    def _find_non_wildcards(self, values): +        """Check if this should be a wildcard match. + +        Further, this will raise an exception if the syntax is improperly +        defined. + +        :return: The offset of the last value we need to match against. +        """ +        if len(values) != len(self._definition): +            raise errors.InvalidValueForIndex() +        is_wildcard = False +        last = 0 +        for idx, val in enumerate(values): +            if val.endswith('*'): +                if val != '*': +                    # We have an 'x*' style wildcard +                    if is_wildcard: +                        # We were already in wildcard mode, so this is invalid +                        raise errors.InvalidGlobbing +                    last = idx + 1 +                is_wildcard = True +            else: +                if is_wildcard: +                    # We were in wildcard mode, we can't follow that with +                    # non-wildcard +                    raise errors.InvalidGlobbing +                last = idx + 1 +        if not is_wildcard: +            return -1 +        return last + +    def lookup(self, values): +        """Find docs that match the values.""" +        last = self._find_non_wildcards(values) +        if last == -1: +            return self._lookup_exact(values) +        else: +            return self._lookup_prefix(values[:last]) + +    def lookup_range(self, start_values, end_values): +        """Find docs within the range.""" +        # TODO: Wildly inefficient, which is unlikely to be a problem for the +        # inmemory implementation. +        if start_values: +            self._find_non_wildcards(start_values) +            start_values = get_prefix(start_values) +        if end_values: +            if self._find_non_wildcards(end_values) == -1: +                exact = True +            else: +                exact = False +            end_values = get_prefix(end_values) +        found = [] +        for key, doc_ids in sorted(self._values.iteritems()): +            if start_values and start_values > key: +                continue +            if end_values and end_values < key: +                if exact: +                    break +                else: +                    if not key.startswith(end_values): +                        break +            found.extend(doc_ids) +        return found + +    def keys(self): +        """Find the indexed keys.""" +        return self._values.keys() + +    def _lookup_prefix(self, value): +        """Find docs that match the prefix string in values.""" +        # TODO: We need a different data structure to make prefix style fast, +        #       some sort of sorted list would work, but a plain dict doesn't. 
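+        # For now this is a linear scan over every indexed key, which is fine
+        # for the in-memory backend.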
+        key_prefix = get_prefix(value) +        all_doc_ids = [] +        for key, doc_ids in sorted(self._values.iteritems()): +            if key.startswith(key_prefix): +                all_doc_ids.extend(doc_ids) +        return all_doc_ids + +    def _lookup_exact(self, value): +        """Find docs that match exactly.""" +        key = '\x01'.join(value) +        if key in self._values: +            return self._values[key] +        return () + + +class InMemorySyncTarget(CommonSyncTarget): + +    def get_sync_info(self, source_replica_uid): +        source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( +            source_replica_uid) +        my_gen, my_trans_id = self._db._get_generation_info() +        return ( +            self._db._replica_uid, my_gen, my_trans_id, source_gen, +            source_trans_id) + +    def record_sync_info(self, source_replica_uid, source_replica_generation, +                         source_transaction_id): +        if self._trace_hook: +            self._trace_hook('record_sync_info') +        self._db._set_replica_gen_and_trans_id( +            source_replica_uid, source_replica_generation, +            source_transaction_id) diff --git a/src/leap/soledad/common/l2db/errors.py b/src/leap/soledad/common/l2db/errors.py new file mode 100644 index 00000000..b502fc2d --- /dev/null +++ b/src/leap/soledad/common/l2db/errors.py @@ -0,0 +1,194 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. + +"""A list of errors that u1db can raise.""" + + +class U1DBError(Exception): +    """Generic base class for U1DB errors.""" + +    # description/tag for identifying the error during transmission (http,...) 
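+    # Subclasses override this with a more specific tag; a value of None
+    # means the error is not registered in the wire_description_to_exc map
+    # defined at the bottom of this module.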
+    wire_description = "error"
+
+    def __init__(self, message=None):
+        self.message = message
+
+
+class RevisionConflict(U1DBError):
+    """The document revision supplied does not match the current version."""
+
+    wire_description = "revision conflict"
+
+
+class InvalidJSON(U1DBError):
+    """Content was not valid json."""
+
+
+class InvalidContent(U1DBError):
+    """Content was not a python dictionary."""
+
+
+class InvalidDocId(U1DBError):
+    """A document was requested with an invalid document identifier."""
+
+    wire_description = "invalid document id"
+
+
+class MissingDocIds(U1DBError):
+    """Needs document ids."""
+
+    wire_description = "missing document ids"
+
+
+class DocumentTooBig(U1DBError):
+    """Document exceeds the maximum document size for this database."""
+
+    wire_description = "document too big"
+
+
+class UserQuotaExceeded(U1DBError):
+    """The user quota for this database has been exceeded."""
+
+    wire_description = "user quota exceeded"
+
+
+class SubscriptionNeeded(U1DBError):
+    """User needs a subscription to be able to use this replica."""
+
+    wire_description = "user needs subscription"
+
+
+class InvalidTransactionId(U1DBError):
+    """Invalid transaction for generation."""
+
+    wire_description = "invalid transaction id"
+
+
+class InvalidGeneration(U1DBError):
+    """Generation was previously synced with a different transaction id."""
+
+    wire_description = "invalid generation"
+
+
+class InvalidReplicaUID(U1DBError):
+    """Attempting to sync a database with itself."""
+
+    wire_description = "invalid replica uid"
+
+
+class ConflictedDoc(U1DBError):
+    """The document is conflicted, you must call resolve before put()"""
+
+
+class InvalidValueForIndex(U1DBError):
+    """The values supplied do not match the index definition."""
+
+
+class InvalidGlobbing(U1DBError):
+    """Raised if wildcard matches are not strictly at the tail of the request.
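+
+    For example (illustrative): the lookup values ('val', '*', '*') are
+    accepted, while ('*', 'val', 'val') raises this error, because a wildcard
+    may only be followed by further wildcards.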
+    """
+
+
+class DocumentDoesNotExist(U1DBError):
+    """The document does not exist."""
+
+    wire_description = "document does not exist"
+
+
+class DocumentAlreadyDeleted(U1DBError):
+    """The document was already deleted."""
+
+    wire_description = "document already deleted"
+
+
+class DatabaseDoesNotExist(U1DBError):
+    """The database does not exist."""
+
+    wire_description = "database does not exist"
+
+
+class IndexNameTakenError(U1DBError):
+    """The given index name is already taken."""
+
+
+class IndexDefinitionParseError(U1DBError):
+    """The index definition cannot be parsed."""
+
+
+class IndexDoesNotExist(U1DBError):
+    """No index of that name exists."""
+
+
+class Unauthorized(U1DBError):
+    """Request wasn't authorized properly."""
+
+    wire_description = "unauthorized"
+
+
+class HTTPError(U1DBError):
+    """Unspecific HTTP error."""
+
+    wire_description = None
+
+    def __init__(self, status, message=None, headers={}):
+        self.status = status
+        self.message = message
+        self.headers = headers
+
+    def __str__(self):
+        if not self.message:
+            return "HTTPError(%d)" % self.status
+        else:
+            return "HTTPError(%d, %r)" % (self.status, self.message)
+
+
+class Unavailable(HTTPError):
+    """Server not available to serve request."""
+
+    wire_description = "unavailable"
+
+    def __init__(self, message=None, headers={}):
+        super(Unavailable, self).__init__(503, message, headers)
+
+    def __str__(self):
+        if not self.message:
+            return "Unavailable()"
+        else:
+            return "Unavailable(%r)" % self.message
+
+
+class BrokenSyncStream(U1DBError):
+    """Unterminated or otherwise broken sync exchange stream."""
+
+    wire_description = None
+
+
+class UnknownAuthMethod(U1DBError):
+    """Unknown authorization method."""
+
+    wire_description = None
+
+
+# mapping wire (transmission) descriptions/tags for errors to the exceptions
+wire_description_to_exc = dict(
+    (x.wire_description, x) for x in globals().values()
+    if getattr(x, 'wire_description', None) not in (None, "error"))
+wire_description_to_exc["error"] = U1DBError
+
+
+#
+# wire error descriptions not corresponding to an exception
+DOCUMENT_DELETED = "document deleted"
diff --git a/src/leap/soledad/common/l2db/query_parser.py b/src/leap/soledad/common/l2db/query_parser.py
new file mode 100644
index 00000000..15a9ac80
--- /dev/null
+++ b/src/leap/soledad/common/l2db/query_parser.py
@@ -0,0 +1,371 @@
+# Copyright 2011 Canonical Ltd.
+# Copyright 2016 LEAP Encryption Access Project
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db.  If not, see <http://www.gnu.org/licenses/>.
+"""
+Code for parsing Index definitions.
+"""
+
+import re
+
+from leap.soledad.common.l2db import errors
+
+
+class Getter(object):
+    """Get values from a document based on a specification."""
+
+    def get(self, raw_doc):
+        """Get a value from the document.
+ +        :param raw_doc: a python dictionary to get the value from. +        :return: A list of values that match the description. +        """ +        raise NotImplementedError(self.get) + + +class StaticGetter(Getter): +    """A getter that returns a defined value (independent of the doc).""" + +    def __init__(self, value): +        """Create a StaticGetter. + +        :param value: the value to return when get is called. +        """ +        if value is None: +            self.value = [] +        elif isinstance(value, list): +            self.value = value +        else: +            self.value = [value] + +    def get(self, raw_doc): +        return self.value + + +def extract_field(raw_doc, subfields, index=0): +    if not isinstance(raw_doc, dict): +        return [] +    val = raw_doc.get(subfields[index]) +    if val is None: +        return [] +    if index < len(subfields) - 1: +        if isinstance(val, list): +            results = [] +            for item in val: +                results.extend(extract_field(item, subfields, index + 1)) +            return results +        if isinstance(val, dict): +            return extract_field(val, subfields, index + 1) +        return [] +    if isinstance(val, dict): +        return [] +    if isinstance(val, list): +        # Strip anything in the list that isn't a simple type +        return [v for v in val if not isinstance(v, (dict, list))] +    return [val] + + +class ExtractField(Getter): +    """Extract a field from the document.""" + +    def __init__(self, field): +        """Create an ExtractField object. + +        When a document is passed to get() this will return a value +        from the document based on the field specifier passed to +        the constructor. + +        None will be returned if the field is nonexistant, or refers to an +        object, rather than a simple type or list of simple types. + +        :param field: a specifier for the field to return. +            This is either a field name, or a dotted field name. +        """ +        self.field = field.split('.') + +    def get(self, raw_doc): +        return extract_field(raw_doc, self.field) + + +class Transformation(Getter): +    """A transformation on a value from another Getter.""" + +    name = None +    arity = 1 +    args = ['expression'] + +    def __init__(self, inner): +        """Create a transformation. + +        :param inner: the argument(s) to the transformation. +        """ +        self.inner = inner + +    def get(self, raw_doc): +        inner_values = self.inner.get(raw_doc) +        assert isinstance(inner_values, list),\ +            'get() should always return a list' +        return self.transform(inner_values) + +    def transform(self, values): +        """Transform the values. + +        This should be implemented by subclasses to transform the +        value when get() is called. + +        :param values: the values from the other Getter +        :return: the transformed values. +        """ +        raise NotImplementedError(self.transform) + + +class Lower(Transformation): +    """Lowercase a string. + +    This transformation will return None for non-string inputs. However, +    it will lowercase any strings in a list, dropping any elements +    that are not strings. 
+    """ + +    name = "lower" + +    def _can_transform(self, val): +        return isinstance(val, basestring) + +    def transform(self, values): +        if not values: +            return [] +        return [val.lower() for val in values if self._can_transform(val)] + + +class Number(Transformation): +    """Convert an integer to a zero padded string. + +    This transformation will return None for non-integer inputs. However, it +    will transform any integers in a list, dropping any elements that are not +    integers. +    """ + +    name = 'number' +    arity = 2 +    args = ['expression', int] + +    def __init__(self, inner, number): +        super(Number, self).__init__(inner) +        self.padding = "%%0%sd" % number + +    def _can_transform(self, val): +        return isinstance(val, int) and not isinstance(val, bool) + +    def transform(self, values): +        """Transform any integers in values into zero padded strings.""" +        if not values: +            return [] +        return [self.padding % (v,) for v in values if self._can_transform(v)] + + +class Bool(Transformation): +    """Convert bool to string.""" + +    name = "bool" +    args = ['expression'] + +    def _can_transform(self, val): +        return isinstance(val, bool) + +    def transform(self, values): +        """Transform any booleans in values into strings.""" +        if not values: +            return [] +        return [('1' if v else '0') for v in values if self._can_transform(v)] + + +class SplitWords(Transformation): +    """Split a string on whitespace. + +    This Getter will return [] for non-string inputs. It will however +    split any strings in an input list, discarding any elements that +    are not strings. +    """ + +    name = "split_words" + +    def _can_transform(self, val): +        return isinstance(val, basestring) + +    def transform(self, values): +        if not values: +            return [] +        result = set() +        for value in values: +            if self._can_transform(value): +                for word in value.split(): +                    result.add(word) +        return list(result) + + +class Combine(Transformation): +    """Combine multiple expressions into a single index.""" + +    name = "combine" +    # variable number of args +    arity = -1 + +    def __init__(self, *inner): +        super(Combine, self).__init__(inner) + +    def get(self, raw_doc): +        inner_values = [] +        for inner in self.inner: +            inner_values.extend(inner.get(raw_doc)) +        return self.transform(inner_values) + +    def transform(self, values): +        return values + + +class IsNull(Transformation): +    """Indicate whether the input is None. + +    This Getter returns a bool indicating whether the input is nil. 
+    """ + +    name = "is_null" + +    def transform(self, values): +        return [len(values) == 0] + + +def check_fieldname(fieldname): +    if fieldname.endswith('.'): +        raise errors.IndexDefinitionParseError( +            "Fieldname cannot end in '.':%s^" % (fieldname,)) + + +class Parser(object): +    """Parse an index expression into a sequence of transformations.""" + +    _transformations = {} +    _delimiters = re.compile("\(|\)|,") + +    def __init__(self): +        self._tokens = [] + +    def _set_expression(self, expression): +        self._open_parens = 0 +        self._tokens = [] +        expression = expression.strip() +        while expression: +            delimiter = self._delimiters.search(expression) +            if delimiter: +                idx = delimiter.start() +                if idx == 0: +                    result, expression = (expression[:1], expression[1:]) +                    self._tokens.append(result) +                else: +                    result, expression = (expression[:idx], expression[idx:]) +                    result = result.strip() +                    if result: +                        self._tokens.append(result) +            else: +                expression = expression.strip() +                if expression: +                    self._tokens.append(expression) +                expression = None + +    def _get_token(self): +        if self._tokens: +            return self._tokens.pop(0) + +    def _peek_token(self): +        if self._tokens: +            return self._tokens[0] + +    @staticmethod +    def _to_getter(term): +        if isinstance(term, Getter): +            return term +        check_fieldname(term) +        return ExtractField(term) + +    def _parse_op(self, op_name): +        self._get_token()  # '(' +        op = self._transformations.get(op_name, None) +        if op is None: +            raise errors.IndexDefinitionParseError( +                "Unknown operation: %s" % op_name) +        args = [] +        while True: +            args.append(self._parse_term()) +            sep = self._get_token() +            if sep == ')': +                break +            if sep != ',': +                raise errors.IndexDefinitionParseError( +                    "Unexpected token '%s' in parentheses." % (sep,)) +        parsed = [] +        for i, arg in enumerate(args): +            arg_type = op.args[i % len(op.args)] +            if arg_type == 'expression': +                inner = self._to_getter(arg) +            else: +                try: +                    inner = arg_type(arg) +                except ValueError as e: +                    raise errors.IndexDefinitionParseError( +                        "Invalid value %r for argument type %r " +                        "(%r)." % (arg, arg_type, e)) +            parsed.append(inner) +        return op(*parsed) + +    def _parse_term(self): +        term = self._get_token() +        if term is None: +            raise errors.IndexDefinitionParseError( +                "Unexpected end of index definition.") +        if term in (',', ')', '('): +            raise errors.IndexDefinitionParseError( +                "Unexpected token '%s' at start of expression." 
% (term,)) +        next_token = self._peek_token() +        if next_token == '(': +            return self._parse_op(term) +        return term + +    def parse(self, expression): +        self._set_expression(expression) +        term = self._to_getter(self._parse_term()) +        if self._peek_token(): +            raise errors.IndexDefinitionParseError( +                "Unexpected token '%s' after end of expression." +                % (self._peek_token(),)) +        return term + +    def parse_all(self, fields): +        return [self.parse(field) for field in fields] + +    @classmethod +    def register_transormation(cls, transform): +        assert transform.name not in cls._transformations, ( +            "Transform %s already registered for %s" +            % (transform.name, cls._transformations[transform.name])) +        cls._transformations[transform.name] = transform + + +Parser.register_transormation(SplitWords) +Parser.register_transormation(Lower) +Parser.register_transormation(Number) +Parser.register_transormation(Bool) +Parser.register_transormation(IsNull) +Parser.register_transormation(Combine) diff --git a/src/leap/soledad/common/l2db/remote/__init__.py b/src/leap/soledad/common/l2db/remote/__init__.py new file mode 100644 index 00000000..3f32e381 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. diff --git a/src/leap/soledad/common/l2db/remote/http_app.py b/src/leap/soledad/common/l2db/remote/http_app.py new file mode 100644 index 00000000..a4eddb36 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_app.py @@ -0,0 +1,660 @@ +# Copyright 2011-2012 Canonical Ltd. +# Copyright 2016 LEAP Encryption Access Project +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. + +""" +HTTP Application exposing U1DB. +""" +# TODO -- deprecate, use twisted/txaio. 
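Before the HTTP application code, a minimal sketch of how the index-expression parser defined in query_parser.py just above is typically driven. The sample document, field names and printed results are illustrative only, and since the module is Python 2-era code (it relies on basestring), the sketch assumes a Python 2 interpreter:

    from leap.soledad.common.l2db.query_parser import Parser

    parser = Parser()

    # "lower(name.first)" parses to Lower(ExtractField(['name', 'first']))
    getter = parser.parse("lower(name.first)")
    print(getter.get({"name": {"first": "Alice"}}))    # ['alice']

    # the second argument of number() is converted via op.args to an int
    getter = parser.parse("number(priority, 4)")
    print(getter.get({"priority": 7}))                 # ['0007']

    # unknown operations or stray tokens raise errors.IndexDefinitionParseError
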
+ +import functools +import six.moves.http_client as httplib +import inspect +import json +import sys +import six.moves.urllib.parse as urlparse + +import routes.mapper + +from leap.soledad.common.l2db import ( +    __version__ as _u1db_version, +    DBNAME_CONSTRAINTS, Document, +    errors, sync) +from leap.soledad.common.l2db.remote import http_errors, utils + + +def parse_bool(expression): +    """Parse boolean querystring parameter.""" +    if expression == 'true': +        return True +    return False + + +def parse_list(expression): +    if not expression: +        return [] +    return [t.strip() for t in expression.split(',')] + + +def none_or_str(expression): +    if expression is None: +        return None +    return str(expression) + + +class BadRequest(Exception): +    """Bad request.""" + + +class _FencedReader(object): +    """Read and get lines from a file but not past a given length.""" + +    MAXCHUNK = 8192 + +    def __init__(self, rfile, total, max_entry_size): +        self.rfile = rfile +        self.remaining = total +        self.max_entry_size = max_entry_size +        self._kept = None + +    def read_chunk(self, atmost): +        if self._kept is not None: +            # ignore atmost, kept data should be a subchunk anyway +            kept, self._kept = self._kept, None +            return kept +        if self.remaining == 0: +            return '' +        data = self.rfile.read(min(self.remaining, atmost)) +        self.remaining -= len(data) +        return data + +    def getline(self): +        line_parts = [] +        size = 0 +        while True: +            chunk = self.read_chunk(self.MAXCHUNK) +            if chunk == '': +                break +            nl = chunk.find("\n") +            if nl != -1: +                size += nl + 1 +                if size > self.max_entry_size: +                    raise BadRequest +                line_parts.append(chunk[:nl + 1]) +                rest = chunk[nl + 1:] +                self._kept = rest or None +                break +            else: +                size += len(chunk) +                if size > self.max_entry_size: +                    raise BadRequest +                line_parts.append(chunk) +        return ''.join(line_parts) + + +def http_method(**control): +    """Decoration for handling of query arguments and content for a HTTP +       method. + +       args and content here are the query arguments and body of the incoming +       HTTP requests. + +       Match query arguments to python method arguments: +           w = http_method()(f) +           w(self, args, content) => args["content"]=content; +                                     f(self, **args) + +       JSON deserialize content to arguments: +           w = http_method(content_as_args=True,...)(f) +           w(self, args, content) => args.update(json.loads(content)); +                                     f(self, **args) + +       Support conversions (e.g int): +           w = http_method(Arg=Conv,...)(f) +           w(self, args, content) => args["Arg"]=Conv(args["Arg"]); +                                     f(self, **args) + +       Enforce no use of query arguments: +           w = http_method(no_query=True,...)(f) +           w(self, args, content) raises BadRequest if args is not empty + +       Argument mismatches, deserialisation failures produce BadRequest. 
+    """ +    content_as_args = control.pop('content_as_args', False) +    no_query = control.pop('no_query', False) +    conversions = control.items() + +    def wrap(f): +        argspec = inspect.getargspec(f) +        assert argspec.args[0] == "self" +        nargs = len(argspec.args) +        ndefaults = len(argspec.defaults or ()) +        required_args = set(argspec.args[1:nargs - ndefaults]) +        all_args = set(argspec.args) + +        @functools.wraps(f) +        def wrapper(self, args, content): +            if no_query and args: +                raise BadRequest() +            if content is not None: +                if content_as_args: +                    try: +                        args.update(json.loads(content)) +                    except ValueError: +                        raise BadRequest() +                else: +                    args["content"] = content +            if not (required_args <= set(args) <= all_args): +                raise BadRequest("Missing required arguments.") +            for name, conv in conversions: +                if name not in args: +                    continue +                try: +                    args[name] = conv(args[name]) +                except ValueError: +                    raise BadRequest() +            return f(self, **args) + +        return wrapper + +    return wrap + + +class URLToResource(object): +    """Mappings from URLs to resources.""" + +    def __init__(self): +        self._map = routes.mapper.Mapper(controller_scan=None) + +    def register(self, resource_cls): +        # register +        self._map.connect(None, resource_cls.url_pattern, +                          resource_cls=resource_cls, +                          requirements={"dbname": DBNAME_CONSTRAINTS}) +        self._map.create_regs() +        return resource_cls + +    def match(self, path): +        params = self._map.match(path) +        if params is None: +            return None, None +        resource_cls = params.pop('resource_cls') +        return resource_cls, params + + +url_to_resource = URLToResource() + + +@url_to_resource.register +class GlobalResource(object): +    """Global (root) resource.""" + +    url_pattern = "/" + +    def __init__(self, state, responder): +        self.state = state +        self.responder = responder + +    @http_method() +    def get(self): +        info = self.state.global_info() +        info['version'] = _u1db_version +        self.responder.send_response_json(**info) + + +@url_to_resource.register +class DatabaseResource(object): +    """Database resource.""" + +    url_pattern = "/{dbname}" + +    def __init__(self, dbname, state, responder): +        self.dbname = dbname +        self.state = state +        self.responder = responder + +    @http_method() +    def get(self): +        self.state.check_database(self.dbname) +        self.responder.send_response_json(200) + +    @http_method(content_as_args=True) +    def put(self): +        self.state.ensure_database(self.dbname) +        self.responder.send_response_json(200, ok=True) + +    @http_method() +    def delete(self): +        self.state.delete_database(self.dbname) +        self.responder.send_response_json(200, ok=True) + + +@url_to_resource.register +class DocsResource(object): +    """Documents resource.""" + +    url_pattern = "/{dbname}/docs" + +    def __init__(self, dbname, state, responder): +        self.responder = responder +        self.db = state.open_database(dbname) + +    @http_method(doc_ids=parse_list, 
check_for_conflicts=parse_bool, +                 include_deleted=parse_bool) +    def get(self, doc_ids=None, check_for_conflicts=True, +            include_deleted=False): +        if doc_ids is None: +            raise errors.MissingDocIds +        docs = self.db.get_docs(doc_ids, include_deleted=include_deleted) +        self.responder.content_type = 'application/json' +        self.responder.start_response(200) +        self.responder.start_stream(), +        for doc in docs: +            entry = dict( +                doc_id=doc.doc_id, doc_rev=doc.rev, content=doc.get_json(), +                has_conflicts=doc.has_conflicts) +            self.responder.stream_entry(entry) +        self.responder.end_stream() +        self.responder.finish_response() + + +@url_to_resource.register +class AllDocsResource(object): +    """All Documents resource.""" + +    url_pattern = "/{dbname}/all-docs" + +    def __init__(self, dbname, state, responder): +        self.responder = responder +        self.db = state.open_database(dbname) + +    @http_method(include_deleted=parse_bool) +    def get(self, include_deleted=False): +        gen, docs = self.db.get_all_docs(include_deleted=include_deleted) +        self.responder.content_type = 'application/json' +        # returning a x-u1db-generation header is optional +        # HTTPDatabase will fallback to return -1 if it's missing +        self.responder.start_response(200, +                                      headers={'x-u1db-generation': str(gen)}) +        self.responder.start_stream(), +        for doc in docs: +            entry = dict( +                doc_id=doc.doc_id, doc_rev=doc.rev, content=doc.get_json(), +                has_conflicts=doc.has_conflicts) +            self.responder.stream_entry(entry) +        self.responder.end_stream() +        self.responder.finish_response() + + +@url_to_resource.register +class DocResource(object): +    """Document resource.""" + +    url_pattern = "/{dbname}/doc/{id:.*}" + +    def __init__(self, dbname, id, state, responder): +        self.id = id +        self.responder = responder +        self.db = state.open_database(dbname) + +    @http_method(old_rev=str) +    def put(self, content, old_rev=None): +        doc = Document(self.id, old_rev, content) +        doc_rev = self.db.put_doc(doc) +        if old_rev is None: +            status = 201  # created +        else: +            status = 200 +        self.responder.send_response_json(status, rev=doc_rev) + +    @http_method(old_rev=str) +    def delete(self, old_rev=None): +        doc = Document(self.id, old_rev, None) +        self.db.delete_doc(doc) +        self.responder.send_response_json(200, rev=doc.rev) + +    @http_method(include_deleted=parse_bool) +    def get(self, include_deleted=False): +        doc = self.db.get_doc(self.id, include_deleted=include_deleted) +        if doc is None: +            wire_descr = errors.DocumentDoesNotExist.wire_description +            self.responder.send_response_json( +                http_errors.wire_description_to_status[wire_descr], +                error=wire_descr, +                headers={ +                    'x-u1db-rev': '', +                    'x-u1db-has-conflicts': 'false' +                }) +            return +        headers = { +            'x-u1db-rev': doc.rev, +            'x-u1db-has-conflicts': json.dumps(doc.has_conflicts) +        } +        if doc.is_tombstone(): +            self.responder.send_response_json( +                http_errors.wire_description_to_status[ 
+                    errors.DOCUMENT_DELETED], +                error=errors.DOCUMENT_DELETED, +                headers=headers) +        else: +            self.responder.send_response_content( +                doc.get_json(), headers=headers) + + +@url_to_resource.register +class SyncResource(object): +    """Sync endpoint resource.""" + +    # maximum allowed request body size +    max_request_size = 15 * 1024 * 1024  # 15Mb +    # maximum allowed entry/line size in request body +    max_entry_size = 10 * 1024 * 1024    # 10Mb + +    url_pattern = "/{dbname}/sync-from/{source_replica_uid}" + +    # pluggable +    sync_exchange_class = sync.SyncExchange + +    def __init__(self, dbname, source_replica_uid, state, responder): +        self.source_replica_uid = source_replica_uid +        self.responder = responder +        self.state = state +        self.dbname = dbname +        self.replica_uid = None + +    def get_target(self): +        return self.state.open_database(self.dbname).get_sync_target() + +    @http_method() +    def get(self): +        result = self.get_target().get_sync_info(self.source_replica_uid) +        self.responder.send_response_json( +            target_replica_uid=result[0], target_replica_generation=result[1], +            target_replica_transaction_id=result[2], +            source_replica_uid=self.source_replica_uid, +            source_replica_generation=result[3], +            source_transaction_id=result[4]) + +    @http_method(generation=int, +                 content_as_args=True, no_query=True) +    def put(self, generation, transaction_id): +        self.get_target().record_sync_info(self.source_replica_uid, +                                           generation, +                                           transaction_id) +        self.responder.send_response_json(ok=True) + +    # Implements the same logic as LocalSyncTarget.sync_exchange + +    @http_method(last_known_generation=int, last_known_trans_id=none_or_str, +                 content_as_args=True) +    def post_args(self, last_known_generation, last_known_trans_id=None, +                  ensure=False): +        if ensure: +            db, self.replica_uid = self.state.ensure_database(self.dbname) +        else: +            db = self.state.open_database(self.dbname) +        db.validate_gen_and_trans_id( +            last_known_generation, last_known_trans_id) +        self.sync_exch = self.sync_exchange_class( +            db, self.source_replica_uid, last_known_generation) + +    @http_method(content_as_args=True) +    def post_stream_entry(self, id, rev, content, gen, trans_id): +        doc = Document(id, rev, content) +        self.sync_exch.insert_doc_from_source(doc, gen, trans_id) + +    def post_end(self): + +        def send_doc(doc, gen, trans_id): +            entry = dict(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), +                         gen=gen, trans_id=trans_id) +            self.responder.stream_entry(entry) + +        new_gen = self.sync_exch.find_changes_to_return() +        self.responder.content_type = 'application/x-u1db-sync-stream' +        self.responder.start_response(200) +        self.responder.start_stream(), +        header = {"new_generation": new_gen, +                  "new_transaction_id": self.sync_exch.new_trans_id} +        if self.replica_uid is not None: +            header['replica_uid'] = self.replica_uid +        self.responder.stream_entry(header) +        self.sync_exch.return_docs(send_doc) +        self.responder.end_stream() +        
self.responder.finish_response() + + +class HTTPResponder(object): +    """Encode responses from the server back to the client.""" + +    # a multi document response will put args and documents +    # each on one line of the response body + +    def __init__(self, start_response): +        self._started = False +        self._stream_state = -1 +        self._no_initial_obj = True +        self.sent_response = False +        self._start_response = start_response +        self._write = None +        self.content_type = 'application/json' +        self.content = [] + +    def start_response(self, status, obj_dic=None, headers={}): +        """start sending response with optional first json object.""" +        if self._started: +            return +        self._started = True +        status_text = httplib.responses[status] +        self._write = self._start_response( +            '%d %s' % (status, status_text), +            [('content-type', self.content_type), +             ('cache-control', 'no-cache')] + +            headers.items()) +        # xxx version in headers +        if obj_dic is not None: +            self._no_initial_obj = False +            self._write(json.dumps(obj_dic) + "\r\n") + +    def finish_response(self): +        """finish sending response.""" +        self.sent_response = True + +    def send_response_json(self, status=200, headers={}, **kwargs): +        """send and finish response with json object body from keyword args.""" +        content = json.dumps(kwargs) + "\r\n" +        self.send_response_content(content, headers=headers, status=status) + +    def send_response_content(self, content, status=200, headers={}): +        """send and finish response with content""" +        headers['content-length'] = str(len(content)) +        self.start_response(status, headers=headers) +        if self._stream_state == 1: +            self.content = [',\r\n', content] +        else: +            self.content = [content] +        self.finish_response() + +    def start_stream(self): +        "start stream (array) as part of the response." +        assert self._started and self._no_initial_obj +        self._stream_state = 0 +        self._write("[") + +    def stream_entry(self, entry): +        "send stream entry as part of the response." +        assert self._stream_state != -1 +        if self._stream_state == 0: +            self._stream_state = 1 +            self._write('\r\n') +        else: +            self._write(',\r\n') +        if type(entry) == dict: +            entry = json.dumps(entry) +        self._write(entry) + +    def end_stream(self): +        "end stream (array)." 
+        assert self._stream_state != -1 +        self._write("\r\n]\r\n") + + +class HTTPInvocationByMethodWithBody(object): +    """Invoke methods on a resource.""" + +    def __init__(self, resource, environ, parameters): +        self.resource = resource +        self.environ = environ +        self.max_request_size = getattr( +            resource, 'max_request_size', parameters.max_request_size) +        self.max_entry_size = getattr( +            resource, 'max_entry_size', parameters.max_entry_size) + +    def _lookup(self, method): +        try: +            return getattr(self.resource, method) +        except AttributeError: +            raise BadRequest() + +    def __call__(self): +        args = urlparse.parse_qsl(self.environ['QUERY_STRING'], +                                  strict_parsing=False) +        try: +            args = dict( +                (k.decode('utf-8'), v.decode('utf-8')) for k, v in args) +        except ValueError: +            raise BadRequest() +        method = self.environ['REQUEST_METHOD'].lower() +        if method in ('get', 'delete'): +            meth = self._lookup(method) +            return meth(args, None) +        else: +            # we expect content-length > 0, reconsider if we move +            # to support chunked enconding +            try: +                content_length = int(self.environ['CONTENT_LENGTH']) +            except (ValueError, KeyError): +                raise BadRequest +            if content_length <= 0: +                raise BadRequest +            if content_length > self.max_request_size: +                raise BadRequest +            reader = _FencedReader(self.environ['wsgi.input'], content_length, +                                   self.max_entry_size) +            content_type = self.environ.get('CONTENT_TYPE', '') +            content_type = content_type.split(';', 1)[0].strip() +            if content_type == 'application/json': +                meth = self._lookup(method) +                body = reader.read_chunk(sys.maxint) +                return meth(args, body) +            elif content_type == 'application/x-u1db-sync-stream': +                meth_args = self._lookup('%s_args' % method) +                meth_entry = self._lookup('%s_stream_entry' % method) +                meth_end = self._lookup('%s_end' % method) +                body_getline = reader.getline +                if body_getline().strip() != '[': +                    raise BadRequest() +                line = body_getline() +                line, comma = utils.check_and_strip_comma(line.strip()) +                meth_args(args, line) +                while True: +                    line = body_getline() +                    entry = line.strip() +                    if entry == ']': +                        break +                    if not entry or not comma:  # empty or no prec comma +                        raise BadRequest +                    entry, comma = utils.check_and_strip_comma(entry) +                    meth_entry({}, entry) +                if comma or body_getline():  # extra comma or data +                    raise BadRequest +                return meth_end() +            else: +                raise BadRequest() + + +class HTTPApp(object): + +    # maximum allowed request body size +    max_request_size = 15 * 1024 * 1024  # 15Mb +    # maximum allowed entry/line size in request body +    max_entry_size = 10 * 1024 * 1024    # 10Mb + +    def __init__(self, state): +        self.state = state + +    def 
_lookup_resource(self, environ, responder): +        resource_cls, params = url_to_resource.match(environ['PATH_INFO']) +        if resource_cls is None: +            raise BadRequest  # 404 instead? +        resource = resource_cls( +            state=self.state, responder=responder, **params) +        return resource + +    def __call__(self, environ, start_response): +        responder = HTTPResponder(start_response) +        self.request_begin(environ) +        try: +            resource = self._lookup_resource(environ, responder) +            HTTPInvocationByMethodWithBody(resource, environ, self)() +        except errors.U1DBError as e: +            self.request_u1db_error(environ, e) +            status = http_errors.wire_description_to_status.get( +                e.wire_description, 500) +            responder.send_response_json(status, error=e.wire_description) +        except BadRequest: +            self.request_bad_request(environ) +            responder.send_response_json(400, error="bad request") +        except KeyboardInterrupt: +            raise +        except: +            self.request_failed(environ) +            raise +        else: +            self.request_done(environ) +        return responder.content + +    # hooks for tracing requests + +    def request_begin(self, environ): +        """Hook called at the beginning of processing a request.""" +        pass + +    def request_done(self, environ): +        """Hook called when done processing a request.""" +        pass + +    def request_u1db_error(self, environ, exc): +        """Hook called when processing a request resulted in a U1DBError. + +        U1DBError passed as exc. +        """ +        pass + +    def request_bad_request(self, environ): +        """Hook called when processing a bad request. + +        No actual processing was done. +        """ +        pass + +    def request_failed(self, environ): +        """Hook called when processing a request failed unexpectedly. + +        Invoked from an except block, so there's interpreter exception +        information available. +        """ +        pass diff --git a/src/leap/soledad/common/l2db/remote/http_client.py b/src/leap/soledad/common/l2db/remote/http_client.py new file mode 100644 index 00000000..1124b038 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_client.py @@ -0,0 +1,178 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. + +"""Base class to make requests to a remote HTTP server.""" + +import json +import socket +import ssl +import sys +import urllib +import six.moves.urllib.parse as urlparse +import six.moves.http_client as httplib +from time import sleep +from leap.soledad.common.l2db import errors +from leap.soledad.common.l2db.remote import http_errors + +from leap.soledad.common.l2db.remote.ssl_match_hostname import match_hostname + +# Ubuntu/debian +# XXX other... 
+CA_CERTS = "/etc/ssl/certs/ca-certificates.crt" + + +def _encode_query_parameter(value): +    """Encode query parameter.""" +    if isinstance(value, bool): +        if value: +            value = 'true' +        else: +            value = 'false' +    return unicode(value).encode('utf-8') + + +class _VerifiedHTTPSConnection(httplib.HTTPSConnection): +    """HTTPSConnection verifying server side certificates.""" +    # derived from httplib.py + +    def connect(self): +        "Connect to a host on a given (SSL) port." + +        sock = socket.create_connection((self.host, self.port), +                                        self.timeout, self.source_address) +        if self._tunnel_host: +            self.sock = sock +            self._tunnel() +        if sys.platform.startswith('linux'): +            cert_opts = { +                'cert_reqs': ssl.CERT_REQUIRED, +                'ca_certs': CA_CERTS +            } +        else: +            # XXX no cert verification implemented elsewhere for now +            cert_opts = {} +        self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, +                                    ssl_version=ssl.PROTOCOL_SSLv3, +                                    **cert_opts +                                    ) +        if cert_opts: +            match_hostname(self.sock.getpeercert(), self.host) + + +class HTTPClientBase(object): +    """Base class to make requests to a remote HTTP server.""" + +    # Will use these delays to retry on 503 befor finally giving up. The final +    # 0 is there to not wait after the final try fails. +    _delays = (1, 1, 2, 4, 0) + +    def __init__(self, url, creds=None): +        self._url = urlparse.urlsplit(url) +        self._conn = None +        self._creds = {} +        if creds is not None: +            if len(creds) != 1: +                raise errors.UnknownAuthMethod() +            auth_meth, credentials = creds.items()[0] +            try: +                set_creds = getattr(self, 'set_%s_credentials' % auth_meth) +            except AttributeError: +                raise errors.UnknownAuthMethod(auth_meth) +            set_creds(**credentials) + +    def _ensure_connection(self): +        if self._conn is not None: +            return +        if self._url.scheme == 'https': +            connClass = _VerifiedHTTPSConnection +        else: +            connClass = httplib.HTTPConnection +        self._conn = connClass(self._url.hostname, self._url.port) + +    def close(self): +        if self._conn: +            self._conn.close() +            self._conn = None + +    # xxx retry mechanism? 
+ +    def _error(self, respdic): +        descr = respdic.get("error") +        exc_cls = errors.wire_description_to_exc.get(descr) +        if exc_cls is not None: +            message = respdic.get("message") +            raise exc_cls(message) + +    def _response(self): +        resp = self._conn.getresponse() +        body = resp.read() +        headers = dict(resp.getheaders()) +        if resp.status in (200, 201): +            return body, headers +        elif resp.status in http_errors.ERROR_STATUSES: +            try: +                respdic = json.loads(body) +            except ValueError: +                pass +            else: +                self._error(respdic) +        # special case +        if resp.status == 503: +            raise errors.Unavailable(body, headers) +        raise errors.HTTPError(resp.status, body, headers) + +    def _sign_request(self, method, url_query, params): +        raise NotImplementedError + +    def _request(self, method, url_parts, params=None, body=None, +                 content_type=None): +        self._ensure_connection() +        unquoted_url = url_query = self._url.path +        if url_parts: +            if not url_query.endswith('/'): +                url_query += '/' +                unquoted_url = url_query +            url_query += '/'.join(urllib.quote(part, safe='') +                                  for part in url_parts) +            # oauth performs its own quoting +            unquoted_url += '/'.join(url_parts) +        encoded_params = {} +        if params: +            for key, value in params.items(): +                key = unicode(key).encode('utf-8') +                encoded_params[key] = _encode_query_parameter(value) +            url_query += ('?' + urllib.urlencode(encoded_params)) +        if body is not None and not isinstance(body, basestring): +            body = json.dumps(body) +            content_type = 'application/json' +        headers = {} +        if content_type: +            headers['content-type'] = content_type +        headers.update( +            self._sign_request(method, unquoted_url, encoded_params)) +        for delay in self._delays: +            try: +                self._conn.request(method, url_query, body, headers) +                return self._response() +            except errors.Unavailable as e: +                sleep(delay) +        raise e + +    def _request_json(self, method, url_parts, params=None, body=None, +                      content_type=None): +        res, headers = self._request(method, url_parts, params, body, +                                     content_type) +        return json.loads(res), headers diff --git a/src/leap/soledad/common/l2db/remote/http_database.py b/src/leap/soledad/common/l2db/remote/http_database.py new file mode 100644 index 00000000..7e61e5a4 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_database.py @@ -0,0 +1,158 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  
If not, see <http://www.gnu.org/licenses/>. + +"""HTTPDatabase to access a remote db over the HTTP API.""" + +import json +import uuid + +from leap.soledad.common.l2db import ( +    Database, +    Document, +    errors) +from leap.soledad.common.l2db.remote import ( +    http_client, +    http_errors, +    http_target) + + +DOCUMENT_DELETED_STATUS = http_errors.wire_description_to_status[ +    errors.DOCUMENT_DELETED] + + +class HTTPDatabase(http_client.HTTPClientBase, Database): +    """Implement the Database API to a remote HTTP server.""" + +    def __init__(self, url, document_factory=None, creds=None): +        super(HTTPDatabase, self).__init__(url, creds=creds) +        self._factory = document_factory or Document + +    def set_document_factory(self, factory): +        self._factory = factory + +    @staticmethod +    def open_database(url, create): +        db = HTTPDatabase(url) +        db.open(create) +        return db + +    @staticmethod +    def delete_database(url): +        db = HTTPDatabase(url) +        db._delete() +        db.close() + +    def open(self, create): +        if create: +            self._ensure() +        else: +            self._check() + +    def _check(self): +        return self._request_json('GET', [])[0] + +    def _ensure(self): +        self._request_json('PUT', [], {}, {}) + +    def _delete(self): +        self._request_json('DELETE', [], {}, {}) + +    def put_doc(self, doc): +        if doc.doc_id is None: +            raise errors.InvalidDocId() +        params = {} +        if doc.rev is not None: +            params['old_rev'] = doc.rev +        res, headers = self._request_json('PUT', ['doc', doc.doc_id], params, +                                          doc.get_json(), 'application/json') +        doc.rev = res['rev'] +        return res['rev'] + +    def get_doc(self, doc_id, include_deleted=False): +        try: +            res, headers = self._request( +                'GET', ['doc', doc_id], {"include_deleted": include_deleted}) +        except errors.DocumentDoesNotExist: +            return None +        except errors.HTTPError as e: +            if (e.status == DOCUMENT_DELETED_STATUS and +                    'x-u1db-rev' in e.headers): +                        res = None +                        headers = e.headers +            else: +                raise +        doc_rev = headers['x-u1db-rev'] +        has_conflicts = json.loads(headers['x-u1db-has-conflicts']) +        doc = self._factory(doc_id, doc_rev, res) +        doc.has_conflicts = has_conflicts +        return doc + +    def _build_docs(self, res): +        for doc_dict in json.loads(res): +            doc = self._factory( +                doc_dict['doc_id'], doc_dict['doc_rev'], doc_dict['content']) +            doc.has_conflicts = doc_dict['has_conflicts'] +            yield doc + +    def get_docs(self, doc_ids, check_for_conflicts=True, +                 include_deleted=False): +        if not doc_ids: +            return [] +        doc_ids = ','.join(doc_ids) +        res, headers = self._request( +            'GET', ['docs'], { +                "doc_ids": doc_ids, "include_deleted": include_deleted, +                "check_for_conflicts": check_for_conflicts}) +        return self._build_docs(res) + +    def get_all_docs(self, include_deleted=False): +        res, headers = self._request( +            'GET', ['all-docs'], {"include_deleted": include_deleted}) +        gen = -1 +        if 'x-u1db-generation' in headers: +            gen = 
int(headers['x-u1db-generation']) +        return gen, list(self._build_docs(res)) + +    def _allocate_doc_id(self): +        return 'D-%s' % (uuid.uuid4().hex,) + +    def create_doc(self, content, doc_id=None): +        if not isinstance(content, dict): +            raise errors.InvalidContent +        json_string = json.dumps(content) +        return self.create_doc_from_json(json_string, doc_id) + +    def create_doc_from_json(self, content, doc_id=None): +        if doc_id is None: +            doc_id = self._allocate_doc_id() +        res, headers = self._request_json('PUT', ['doc', doc_id], {}, +                                          content, 'application/json') +        new_doc = self._factory(doc_id, res['rev'], content) +        return new_doc + +    def delete_doc(self, doc): +        if doc.doc_id is None: +            raise errors.InvalidDocId() +        params = {'old_rev': doc.rev} +        res, headers = self._request_json( +            'DELETE', ['doc', doc.doc_id], params) +        doc.make_tombstone() +        doc.rev = res['rev'] + +    def get_sync_target(self): +        st = http_target.HTTPSyncTarget(self._url.geturl()) +        st._creds = self._creds +        return st diff --git a/src/leap/soledad/common/l2db/remote/http_errors.py b/src/leap/soledad/common/l2db/remote/http_errors.py new file mode 100644 index 00000000..ee4cfefa --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_errors.py @@ -0,0 +1,48 @@ +# Copyright 2011-2012 Canonical Ltd. +# Copyright 2016 LEAP Encryption Access Project +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. + +""" +Information about the encoding of errors over HTTP. 
+""" + +from leap.soledad.common.l2db import errors + + +# error wire descriptions mapping to HTTP status codes +wire_description_to_status = dict([ +    (errors.InvalidDocId.wire_description, 400), +    (errors.MissingDocIds.wire_description, 400), +    (errors.Unauthorized.wire_description, 401), +    (errors.DocumentTooBig.wire_description, 403), +    (errors.UserQuotaExceeded.wire_description, 403), +    (errors.SubscriptionNeeded.wire_description, 403), +    (errors.DatabaseDoesNotExist.wire_description, 404), +    (errors.DocumentDoesNotExist.wire_description, 404), +    (errors.DocumentAlreadyDeleted.wire_description, 404), +    (errors.RevisionConflict.wire_description, 409), +    (errors.InvalidGeneration.wire_description, 409), +    (errors.InvalidReplicaUID.wire_description, 409), +    (errors.InvalidTransactionId.wire_description, 409), +    (errors.Unavailable.wire_description, 503), +    # without matching exception +    (errors.DOCUMENT_DELETED, 404) +]) + + +ERROR_STATUSES = set(wire_description_to_status.values()) +# 400 included explicitly for tests +ERROR_STATUSES.add(400) diff --git a/src/leap/soledad/common/l2db/remote/http_target.py b/src/leap/soledad/common/l2db/remote/http_target.py new file mode 100644 index 00000000..38804f01 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/http_target.py @@ -0,0 +1,125 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. 
+ +"""SyncTarget API implementation to a remote HTTP server.""" + +import json + +from leap.soledad.common.l2db import Document, SyncTarget +from leap.soledad.common.l2db.errors import BrokenSyncStream +from leap.soledad.common.l2db.remote import ( +    http_client, utils) + + +class HTTPSyncTarget(http_client.HTTPClientBase, SyncTarget): +    """Implement the SyncTarget api to a remote HTTP server.""" + +    @staticmethod +    def connect(url): +        return HTTPSyncTarget(url) + +    def get_sync_info(self, source_replica_uid): +        self._ensure_connection() +        res, _ = self._request_json('GET', ['sync-from', source_replica_uid]) +        return (res['target_replica_uid'], res['target_replica_generation'], +                res['target_replica_transaction_id'], +                res['source_replica_generation'], res['source_transaction_id']) + +    def record_sync_info(self, source_replica_uid, source_replica_generation, +                         source_transaction_id): +        self._ensure_connection() +        if self._trace_hook:  # for tests +            self._trace_hook('record_sync_info') +        self._request_json('PUT', ['sync-from', source_replica_uid], {}, +                           {'generation': source_replica_generation, +                               'transaction_id': source_transaction_id}) + +    def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): +        parts = data.splitlines()  # one at a time +        if not parts or parts[0] != '[': +            raise BrokenSyncStream +        data = parts[1:-1] +        comma = False +        if data: +            line, comma = utils.check_and_strip_comma(data[0]) +            res = json.loads(line) +            if ensure_callback and 'replica_uid' in res: +                ensure_callback(res['replica_uid']) +            for entry in data[1:]: +                if not comma:  # missing in between comma +                    raise BrokenSyncStream +                line, comma = utils.check_and_strip_comma(entry) +                entry = json.loads(line) +                doc = Document(entry['id'], entry['rev'], entry['content']) +                return_doc_cb(doc, entry['gen'], entry['trans_id']) +        if parts[-1] != ']': +            try: +                partdic = json.loads(parts[-1]) +            except ValueError: +                pass +            else: +                if isinstance(partdic, dict): +                    self._error(partdic) +            raise BrokenSyncStream +        if not data or comma:  # no entries or bad extra comma +            raise BrokenSyncStream +        return res + +    def sync_exchange(self, docs_by_generations, source_replica_uid, +                      last_known_generation, last_known_trans_id, +                      return_doc_cb, ensure_callback=None): +        self._ensure_connection() +        if self._trace_hook:  # for tests +            self._trace_hook('sync_exchange') +        url = '%s/sync-from/%s' % (self._url.path, source_replica_uid) +        self._conn.putrequest('POST', url) +        self._conn.putheader('content-type', 'application/x-u1db-sync-stream') +        for header_name, header_value in self._sign_request('POST', url, {}): +            self._conn.putheader(header_name, header_value) +        entries = ['['] +        size = 1 + +        def prepare(**dic): +            entry = comma + '\r\n' + json.dumps(dic) +            entries.append(entry) +            return len(entry) + +        comma = '' +        size += prepare( +          
  last_known_generation=last_known_generation, +            last_known_trans_id=last_known_trans_id, +            ensure=ensure_callback is not None) +        comma = ',' +        for doc, gen, trans_id in docs_by_generations: +            size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), +                            gen=gen, trans_id=trans_id) +        entries.append('\r\n]') +        size += len(entries[-1]) +        self._conn.putheader('content-length', str(size)) +        self._conn.endheaders() +        for entry in entries: +            self._conn.send(entry) +        entries = None +        data, _ = self._response() +        res = self._parse_sync_stream(data, return_doc_cb, ensure_callback) +        data = None +        return res['new_generation'], res['new_transaction_id'] + +    # for tests +    _trace_hook = None + +    def _set_trace_hook_shallow(self, cb): +        self._trace_hook = cb diff --git a/src/leap/soledad/common/l2db/remote/server_state.py b/src/leap/soledad/common/l2db/remote/server_state.py new file mode 100644 index 00000000..d4c3c45f --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/server_state.py @@ -0,0 +1,68 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. + +"""State for servers exposing a set of U1DB databases.""" + + +class ServerState(object): +    """Passed to a Request when it is instantiated. + +    This is used to track server-side state, such as working-directory, open +    databases, etc. +    """ + +    def __init__(self): +        self._workingdir = None + +    def set_workingdir(self, path): +        self._workingdir = path + +    def global_info(self): +        """Return global information about the server.""" +        return {} + +    def _relpath(self, relpath): +        # Note: We don't want to allow absolute paths here, because we +        #       don't want to expose the filesystem. We should also check that +        #       relpath doesn't have '..' in it, etc. +        return self._workingdir + '/' + relpath + +    def open_database(self, path): +        """Open a database at the given location.""" +        from leap.soledad.client._db import sqlite +        full_path = self._relpath(path) +        return sqlite.SQLiteDatabase.open_database(full_path, create=False) + +    def check_database(self, path): +        """Check if the database at the given location exists. + +        Simply returns if it does or raises DatabaseDoesNotExist. 
+        """ +        db = self.open_database(path) +        db.close() + +    def ensure_database(self, path): +        """Ensure database at the given location.""" +        from leap.soledad.client._db import sqlite +        full_path = self._relpath(path) +        db = sqlite.SQLiteDatabase.open_database(full_path, create=True) +        return db, db._replica_uid + +    def delete_database(self, path): +        """Delete database at the given location.""" +        from leap.soledad.client._db import sqlite +        full_path = self._relpath(path) +        sqlite.SQLiteDatabase.delete_database(full_path) diff --git a/src/leap/soledad/common/l2db/remote/ssl_match_hostname.py b/src/leap/soledad/common/l2db/remote/ssl_match_hostname.py new file mode 100644 index 00000000..ce82f1b2 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/ssl_match_hostname.py @@ -0,0 +1,65 @@ +"""The match_hostname() function from Python 3.2, essential when using SSL.""" +# XXX put it here until it's packaged + +import re + +__version__ = '3.2a3' + + +class CertificateError(ValueError): +    pass + + +def _dnsname_to_pat(dn): +    pats = [] +    for frag in dn.split(r'.'): +        if frag == '*': +            # When '*' is a fragment by itself, it matches a non-empty dotless +            # fragment. +            pats.append('[^.]+') +        else: +            # Otherwise, '*' matches any dotless fragment. +            frag = re.escape(frag) +            pats.append(frag.replace(r'\*', '[^.]*')) +    return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + + +def match_hostname(cert, hostname): +    """Verify that *cert* (in decoded format as returned by +    SSLSocket.getpeercert()) matches the *hostname*.  RFC 2818 rules +    are mostly followed, but IP addresses are not accepted for *hostname*. + +    CertificateError is raised on failure. On success, the function +    returns nothing. +    """ +    if not cert: +        raise ValueError("empty or no certificate") +    dnsnames = [] +    san = cert.get('subjectAltName', ()) +    for key, value in san: +        if key == 'DNS': +            if _dnsname_to_pat(value).match(hostname): +                return +            dnsnames.append(value) +    if not san: +        # The subject is only checked when subjectAltName is empty +        for sub in cert.get('subject', ()): +            for key, value in sub: +                # XXX according to RFC 2818, the most specific Common Name +                # must be used. +                if key == 'commonName': +                    if _dnsname_to_pat(value).match(hostname): +                        return +                    dnsnames.append(value) +    if len(dnsnames) > 1: +        raise CertificateError( +            "hostname %r doesn't match either of %s" +            % (hostname, ', '.join(map(repr, dnsnames)))) +    elif len(dnsnames) == 1: +        raise CertificateError( +            "hostname %r doesn't match %r" +            % (hostname, dnsnames[0])) +    else: +        raise CertificateError( +            "no appropriate commonName or " +            "subjectAltName fields were found") diff --git a/src/leap/soledad/common/l2db/remote/utils.py b/src/leap/soledad/common/l2db/remote/utils.py new file mode 100644 index 00000000..14cedea9 --- /dev/null +++ b/src/leap/soledad/common/l2db/remote/utils.py @@ -0,0 +1,23 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. 
+# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. + +"""Utilities for details of the procotol.""" + + +def check_and_strip_comma(line): +    if line and line[-1] == ',': +        return line[:-1], True +    return line, False diff --git a/src/leap/soledad/common/l2db/sync.py b/src/leap/soledad/common/l2db/sync.py new file mode 100644 index 00000000..32281f30 --- /dev/null +++ b/src/leap/soledad/common/l2db/sync.py @@ -0,0 +1,311 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. + +"""The synchronization utilities for U1DB.""" +from six.moves import zip as izip + +from leap.soledad.common import l2db +from leap.soledad.common.l2db import errors + + +class Synchronizer(object): +    """Collect the state around synchronizing 2 U1DB replicas. + +    Synchronization is bi-directional, in that new items in the source are sent +    to the target, and new items in the target are returned to the source. +    However, it still recognizes that one side is initiating the request. Also, +    at the moment, conflicts are only created in the source. +    """ + +    def __init__(self, source, sync_target): +        """Create a new Synchronization object. + +        :param source: A Database +        :param sync_target: A SyncTarget +        """ +        self.source = source +        self.sync_target = sync_target +        self.target_replica_uid = None +        self.num_inserted = 0 + +    def _insert_doc_from_target(self, doc, replica_gen, trans_id): +        """Try to insert synced document from target. + +        Implements TAKE OTHER semantics: any document from the target +        that is in conflict will be taken as the new official value, +        while the current conflicting value will be stored alongside +        as a conflict. In the process indexes will be updated etc. + +        :return: None +        """ +        # Increases self.num_inserted depending whether the document +        # was effectively inserted. 
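The check_and_strip_comma() helper from utils.py above is what both ends of the sync protocol use to peel the trailing comma off each streamed entry; a quick sketch (the entry contents are illustrative):

from leap.soledad.common.l2db.remote.utils import check_and_strip_comma

check_and_strip_comma('{"id": "doc-1", "gen": 3},')  # -> ('{"id": "doc-1", "gen": 3}', True)
check_and_strip_comma(']')                           # -> (']', False)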
+        state, _ = self.source._put_doc_if_newer( +            doc, save_conflict=True, +            replica_uid=self.target_replica_uid, replica_gen=replica_gen, +            replica_trans_id=trans_id) +        if state == 'inserted': +            self.num_inserted += 1 +        elif state == 'converged': +            # magical convergence +            pass +        elif state == 'superseded': +            # we have something newer, will be taken care of at the next sync +            pass +        else: +            assert state == 'conflicted' +            # The doc was saved as a conflict, so the database was updated +            self.num_inserted += 1 + +    def _record_sync_info_with_the_target(self, start_generation): +        """Record our new after sync generation with the target if gapless. + +        Any documents received from the target will cause the local +        database to increment its generation. We do not want to send +        them back to the target in a future sync. However, there could +        also be concurrent updates from another process doing eg +        'put_doc' while the sync was running. And we do want to +        synchronize those documents.  We can tell if there was a +        concurrent update by comparing our new generation number +        versus the generation we started, and how many documents we +        inserted from the target. If it matches exactly, then we can +        record with the target that they are fully up to date with our +        new generation. +        """ +        cur_gen, trans_id = self.source._get_generation_info() +        last_gen = start_generation + self.num_inserted +        if (cur_gen == last_gen and self.num_inserted > 0): +                self.sync_target.record_sync_info( +                    self.source._replica_uid, cur_gen, trans_id) + +    def sync(self, callback=None, autocreate=False): +        """Synchronize documents between source and target.""" +        sync_target = self.sync_target +        # get target identifier, its current generation, +        # and its last-seen database generation for this source +        try: +            (self.target_replica_uid, target_gen, target_trans_id, +             target_my_gen, target_my_trans_id) = sync_target.get_sync_info( +                self.source._replica_uid) +        except errors.DatabaseDoesNotExist: +            if not autocreate: +                raise +            # will try to ask sync_exchange() to create the db +            self.target_replica_uid = None +            target_gen, target_trans_id = 0, '' +            target_my_gen, target_my_trans_id = 0, '' + +            def ensure_callback(replica_uid): +                self.target_replica_uid = replica_uid + +        else: +            ensure_callback = None +        if self.target_replica_uid == self.source._replica_uid: +            raise errors.InvalidReplicaUID +        # validate the generation and transaction id the target knows about us +        self.source.validate_gen_and_trans_id( +            target_my_gen, target_my_trans_id) +        # what's changed since that generation and this current gen +        my_gen, _, changes = self.source.whats_changed(target_my_gen) + +        # this source last-seen database generation for the target +        if self.target_replica_uid is None: +            target_last_known_gen, target_last_known_trans_id = 0, '' +        else: +            target_last_known_gen, target_last_known_trans_id = ( +                self.source._get_replica_gen_and_trans_id(  # nopep8 
+                    self.target_replica_uid)) +        if not changes and target_last_known_gen == target_gen: +            if target_trans_id != target_last_known_trans_id: +                raise errors.InvalidTransactionId +            return my_gen +        changed_doc_ids = [doc_id for doc_id, _, _ in changes] +        # prepare to send all the changed docs +        docs_to_send = self.source.get_docs( +            changed_doc_ids, +            check_for_conflicts=False, include_deleted=True) +        # TODO: there must be a way to not iterate twice +        docs_by_generation = zip( +            docs_to_send, (gen for _, gen, _ in changes), +            (trans for _, _, trans in changes)) + +        # exchange documents and try to insert the returned ones with +        # the target, return target synced-up-to gen +        new_gen, new_trans_id = sync_target.sync_exchange( +            docs_by_generation, self.source._replica_uid, +            target_last_known_gen, target_last_known_trans_id, +            self._insert_doc_from_target, ensure_callback=ensure_callback) +        # record target synced-up-to generation including applying what we sent +        self.source._set_replica_gen_and_trans_id( +            self.target_replica_uid, new_gen, new_trans_id) + +        # if gapless record current reached generation with target +        self._record_sync_info_with_the_target(my_gen) + +        return my_gen + + +class SyncExchange(object): +    """Steps and state for carrying through a sync exchange on a target.""" + +    def __init__(self, db, source_replica_uid, last_known_generation): +        self._db = db +        self.source_replica_uid = source_replica_uid +        self.source_last_known_generation = last_known_generation +        self.seen_ids = {}  # incoming ids not superseded +        self.changes_to_return = None +        self.new_gen = None +        self.new_trans_id = None +        # for tests +        self._incoming_trace = [] +        self._trace_hook = None +        self._db._last_exchange_log = { +            'receive': {'docs': self._incoming_trace}, +            'return': None +        } + +    def _set_trace_hook(self, cb): +        self._trace_hook = cb + +    def _trace(self, state): +        if not self._trace_hook: +            return +        self._trace_hook(state) + +    def insert_doc_from_source(self, doc, source_gen, trans_id): +        """Try to insert synced document from source. + +        Conflicting documents are not inserted but will be sent over +        to the sync source. + +        It keeps track of progress by storing the document source +        generation as well. + +        The 1st step of a sync exchange is to call this repeatedly to +        try insert all incoming documents from the source. + +        :param doc: A Document object. +        :param source_gen: The source generation of doc. 
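To make the "gapless" check in _record_sync_info_with_the_target() above concrete, here is the same arithmetic with made-up numbers:

# made-up numbers illustrating the gapless check
start_generation = 10   # source generation when sync() started
num_inserted = 3        # documents taken from the target during this sync
cur_gen = 13            # source generation after inserting them

# 10 + 3 == 13: every generation added since the sync started came from the
# target, so it is safe to tell the target we are synced up to generation 13.
# If a concurrent put_doc() had pushed cur_gen to 14, the condition would fail
# and nothing would be recorded, so those extra changes get sent next time.
safe_to_record = (cur_gen == start_generation + num_inserted and num_inserted > 0)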
+        :return: None +        """ +        state, at_gen = self._db._put_doc_if_newer( +            doc, save_conflict=False, +            replica_uid=self.source_replica_uid, replica_gen=source_gen, +            replica_trans_id=trans_id) +        if state == 'inserted': +            self.seen_ids[doc.doc_id] = at_gen +        elif state == 'converged': +            # magical convergence +            self.seen_ids[doc.doc_id] = at_gen +        elif state == 'superseded': +            # we have something newer that we will return +            pass +        else: +            # conflict that we will returne +            assert state == 'conflicted' +        # for tests +        self._incoming_trace.append((doc.doc_id, doc.rev)) +        self._db._last_exchange_log['receive'].update({ +            'source_uid': self.source_replica_uid, +            'source_gen': source_gen +        }) + +    def find_changes_to_return(self): +        """Find changes to return. + +        Find changes since last_known_generation in db generation +        order using whats_changed. It excludes documents ids that have +        already been considered (superseded by the sender, etc). + +        :return: new_generation - the generation of this database +            which the caller can consider themselves to be synchronized after +            processing the returned documents. +        """ +        self._db._last_exchange_log['receive'].update({  # for tests +            'last_known_gen': self.source_last_known_generation +        }) +        self._trace('before whats_changed') +        gen, trans_id, changes = self._db.whats_changed( +            self.source_last_known_generation) +        self._trace('after whats_changed') +        self.new_gen = gen +        self.new_trans_id = trans_id +        seen_ids = self.seen_ids +        # changed docs that weren't superseded by or converged with +        self.changes_to_return = [ +            (doc_id, gen, trans_id) for (doc_id, gen, trans_id) in changes if +            # there was a subsequent update +            doc_id not in seen_ids or seen_ids.get(doc_id) < gen] +        return self.new_gen + +    def return_docs(self, return_doc_cb): +        """Return the changed documents and their last change generation +        repeatedly invoking the callback return_doc_cb. + +        The final step of a sync exchange. + +        :param: return_doc_cb(doc, gen, trans_id): is a callback +                used to return the documents with their last change generation +                to the target replica. 
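A small worked example of the seen_ids filter used by find_changes_to_return() above (doc ids, generations and transaction ids are invented):

# invented data: what the source sent us vs. what changed locally
seen_ids = {'doc-a': 7, 'doc-c': 5}   # docs inserted from the source, at these gens
changes = [('doc-a', 9, 'T-2'),       # changed again after we inserted it -> return
           ('doc-b', 8, 'T-3'),       # never came from the source -> return
           ('doc-c', 5, 'T-1')]       # unchanged since the source sent it -> skip

changes_to_return = [
    (doc_id, gen, trans_id) for (doc_id, gen, trans_id) in changes
    if doc_id not in seen_ids or seen_ids.get(doc_id) < gen]
# -> [('doc-a', 9, 'T-2'), ('doc-b', 8, 'T-3')]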
+        :return: None +        """ +        changes_to_return = self.changes_to_return +        # return docs, including conflicts +        changed_doc_ids = [doc_id for doc_id, _, _ in changes_to_return] +        self._trace('before get_docs') +        docs = self._db.get_docs( +            changed_doc_ids, check_for_conflicts=False, include_deleted=True) + +        docs_by_gen = izip( +            docs, (gen for _, gen, _ in changes_to_return), +            (trans_id for _, _, trans_id in changes_to_return)) +        _outgoing_trace = []  # for tests +        for doc, gen, trans_id in docs_by_gen: +            return_doc_cb(doc, gen, trans_id) +            _outgoing_trace.append((doc.doc_id, doc.rev)) +        # for tests +        self._db._last_exchange_log['return'] = { +            'docs': _outgoing_trace, +            'last_gen': self.new_gen} + + +class LocalSyncTarget(l2db.SyncTarget): +    """Common sync target implementation logic for all local sync targets.""" + +    def __init__(self, db): +        self._db = db +        self._trace_hook = None + +    def sync_exchange(self, docs_by_generations, source_replica_uid, +                      last_known_generation, last_known_trans_id, +                      return_doc_cb, ensure_callback=None): +        self._db.validate_gen_and_trans_id( +            last_known_generation, last_known_trans_id) +        sync_exch = SyncExchange( +            self._db, source_replica_uid, last_known_generation) +        if self._trace_hook: +            sync_exch._set_trace_hook(self._trace_hook) +        # 1st step: try to insert incoming docs and record progress +        for doc, doc_gen, trans_id in docs_by_generations: +            sync_exch.insert_doc_from_source(doc, doc_gen, trans_id) +        # 2nd step: find changed documents (including conflicts) to return +        new_gen = sync_exch.find_changes_to_return() +        # final step: return docs and record source replica sync point +        sync_exch.return_docs(return_doc_cb) +        return new_gen, sync_exch.new_trans_id + +    def _set_trace_hook(self, cb): +        self._trace_hook = cb diff --git a/src/leap/soledad/common/l2db/vectorclock.py b/src/leap/soledad/common/l2db/vectorclock.py new file mode 100644 index 00000000..42bceaa8 --- /dev/null +++ b/src/leap/soledad/common/l2db/vectorclock.py @@ -0,0 +1,89 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db.  If not, see <http://www.gnu.org/licenses/>. + +"""VectorClockRev helper class.""" + + +class VectorClockRev(object): +    """Track vector clocks for multiple replica ids. + +    This allows simple comparison to determine if one VectorClockRev is +    newer/older/in-conflict-with another VectorClockRev without having to +    examine history. Every replica has a strictly increasing revision. 
When +    creating a new revision, they include all revisions for all other replicas +    which the new revision dominates, and increment their own revision to +    something greater than the current value. +    """ + +    def __init__(self, value): +        self._values = self._expand(value) + +    def __repr__(self): +        s = self.as_str() +        return '%s(%s)' % (self.__class__.__name__, s) + +    def as_str(self): +        s = '|'.join(['%s:%d' % (m, r) for m, r +                      in sorted(self._values.items())]) +        return s + +    def _expand(self, value): +        result = {} +        if value is None: +            return result +        for replica_info in value.split('|'): +            replica_uid, counter = replica_info.split(':') +            counter = int(counter) +            result[replica_uid] = counter +        return result + +    def is_newer(self, other): +        """Is this VectorClockRev strictly newer than other. +        """ +        if not self._values: +            return False +        if not other._values: +            return True +        this_is_newer = False +        other_expand = dict(other._values) +        for key, value in self._values.iteritems(): +            if key in other_expand: +                other_value = other_expand.pop(key) +                if other_value > value: +                    return False +                elif other_value < value: +                    this_is_newer = True +            else: +                this_is_newer = True +        if other_expand: +            return False +        return this_is_newer + +    def increment(self, replica_uid): +        """Increase the 'replica_uid' section of this vector clock. + +        :return: A string representing the new vector clock value +        """ +        self._values[replica_uid] = self._values.get(replica_uid, 0) + 1 + +    def maximize(self, other_vcr): +        for replica_uid, counter in other_vcr._values.iteritems(): +            if replica_uid not in self._values: +                self._values[replica_uid] = counter +            else: +                this_counter = self._values[replica_uid] +                if this_counter < counter: +                    self._values[replica_uid] = counter diff --git a/src/leap/soledad/common/log.py b/src/leap/soledad/common/log.py new file mode 100644 index 00000000..59a47726 --- /dev/null +++ b/src/leap/soledad/common/log.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# log.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +""" +This module centralizes logging facilities and allows for different behaviours, +as using the python logging module instead of twisted logger, and to print logs +to stdout, mainly for development purposes. 
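Stepping back to the VectorClockRev class above, a short sketch of its comparison and merge semantics (replica names and counters are invented):

from leap.soledad.common.l2db.vectorclock import VectorClockRev

a = VectorClockRev('alice:2|bob:1')
b = VectorClockRev('alice:3')

a.is_newer(b)   # False: b has seen alice:3, which a has not
b.is_newer(a)   # False: a has seen bob:1, which b has not -> concurrent edits

# resolving such a conflict: take the maximum of both clocks, then bump our own
a.maximize(b)
a.increment('alice')
a.as_str()      # 'alice:4|bob:1'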
+"""
+
+
+import os
+import sys
+import time
+
+from twisted.logger import Logger
+from twisted.logger import textFileLogObserver
+from twisted.logger import LogLevel
+from twisted.logger import InvalidLogLevelError
+from twisted.python.failure import Failure
+
+
+# What follows is a patched class to correctly log namespace and level when
+# using the default formatter and --syslog option in twistd. This seems to be a
+# known bug but it has not been reported to upstream yet.
+
+class SyslogLogger(Logger):
+
+    def emit(self, level, format=None, **kwargs):
+        if level not in LogLevel.iterconstants():
+            self.failure(
+                "Got invalid log level {invalidLevel!r} in {logger}.emit().",
+                Failure(InvalidLogLevelError(level)),
+                invalidLevel=level,
+                logger=self,
+            )
+            return
+
+        event = kwargs
+        event.update(
+            log_logger=self, log_level=level, log_namespace=self.namespace,
+            log_source=self.source, log_format=format, log_time=time.time(),
+        )
+
+        # ---------------------------------8<---------------------------------
+        # this is a workaround for the mess between twisted's legacy log system
+        # and twistd's --syslog option.
+        event["system"] = "%s#%s" % (self.namespace, level.name)
+        # ---------------------------------8<---------------------------------
+
+        if "log_trace" in event:
+            event["log_trace"].append((self, self.observer))
+
+        self.observer(event)
+
+
+def getLogger(*args, **kwargs):
+
+    if os.environ.get('SOLEDAD_USE_PYTHON_LOGGING'):
+        import logging
+        return logging.getLogger(__name__)
+
+    if os.environ.get('SOLEDAD_LOG_TO_STDOUT'):
+        kwargs.update({'observer': textFileLogObserver(sys.stdout)})
+
+    return SyslogLogger(*args, **kwargs)
+
+
+__all__ = ['getLogger'] diff --git a/src/leap/soledad/common/tests/__init__.py b/src/leap/soledad/common/tests/__init__.py
new file mode 100644
index 00000000..acebb77b
--- /dev/null
+++ b/src/leap/soledad/common/tests/__init__.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# __init__.py
+# Copyright (C) 2013, 2014 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+"""
+Tests to make sure Soledad provides U1DB functionality and more.
+"""
+
+
+import os
+
+
+def load_tests():
+    """
+    Build a test suite that includes all tests in leap.soledad.common.tests
+    but does not include tests in the u1db_tests/ subfolder. The reasons for
+    not including those tests are:
+
+        1. they by themselves only test u1db functionality in the u1db module
+           (even though we use them as a basis for testing soledad
+           functionalities).
+
+        2. they would fail because we monkey patch u1db's remote http server
+           to add soledad functionality we need.
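A minimal sketch of how the getLogger() helper above is meant to be used; the environment variables are the ones the helper itself checks, and the message text is invented:

import os
os.environ['SOLEDAD_LOG_TO_STDOUT'] = '1'   # route logs to stdout (development)
# os.environ['SOLEDAD_USE_PYTHON_LOGGING'] = '1'  # or: fall back to stdlib logging

from leap.soledad.common.log import getLogger

logger = getLogger()
logger.info('synced {count} documents for {user}', count=10, user='user-1234')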
+    """ +    import unittest +    import glob +    import imp +    tests_prefix = os.path.join( +        '.', 'src', 'leap', 'soledad', 'common', 'tests') +    suite = unittest.TestSuite() +    for testcase in glob.glob(os.path.join(tests_prefix, 'test_*.py')): +        modname = os.path.basename(os.path.splitext(testcase)[0]) +        f, pathname, description = imp.find_module(modname, [tests_prefix]) +        module = imp.load_module(modname, f, pathname, description) +        suite.addTest(unittest.TestLoader().loadTestsFromModule(module)) +    return suite diff --git a/src/leap/soledad/common/tests/test_command.py b/src/leap/soledad/common/tests/test_command.py new file mode 100644 index 00000000..2136bb8f --- /dev/null +++ b/src/leap/soledad/common/tests/test_command.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# test_command.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Tests for command execution using a validator function for arguments. +""" +from twisted.trial import unittest +from leap.soledad.common.command import exec_validated_cmd + + +def validator(arg): +    return True if arg is 'valid' else False + + +class ExecuteValidatedCommandTest(unittest.TestCase): + +    def test_argument_validation(self): +        status, out = exec_validated_cmd("command", "invalid arg", validator) +        self.assertEquals(status, 1) +        self.assertEquals(out, "invalid argument") +        status, out = exec_validated_cmd("echo", "valid", validator) +        self.assertEquals(status, 0) +        self.assertEquals(out, "valid\n") + +    def test_return_status_code_success(self): +        status, out = exec_validated_cmd("echo", "arg") +        self.assertEquals(status, 0) +        self.assertEquals(out, "arg\n") + +    def test_handle_command_with_spaces(self): +        status, out = exec_validated_cmd("echo I am", "an argument") +        self.assertEquals(status, 0, out) +        self.assertEquals(out, "I am an argument\n") + +    def test_handle_oserror_on_invalid_command(self): +        status, out = exec_validated_cmd("inexistent command with", "args") +        self.assertEquals(status, 1) +        self.assertIn("No such file or directory", out) + +    def test_return_status_code_number_on_failure(self): +        status, out = exec_validated_cmd("ls", "user-bebacafe") +        self.assertNotEquals(status, 0) +        self.assertIn('No such file or directory\n', out) diff --git a/src/leap/soledad/server/__init__.py b/src/leap/soledad/server/__init__.py new file mode 100644 index 00000000..a4080f13 --- /dev/null +++ b/src/leap/soledad/server/__init__.py @@ -0,0 +1,192 @@ +# -*- coding: utf-8 -*- +# server.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your 
option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +""" +The Soledad Server allows for recovery document storage and database +synchronization. +""" + +import six.moves.urllib.parse as urlparse +import sys + +from leap.soledad.common.l2db.remote import http_app, utils +from leap.soledad.common import SHARED_DB_NAME + +from .sync import SyncResource +from .sync import MAX_REQUEST_SIZE +from .sync import MAX_ENTRY_SIZE + +from ._version import get_versions +from ._config import get_config + + +__all__ = [ +    'SoledadApp', +    'get_config', +    '__version__', +] + + +# ---------------------------------------------------------------------------- +# Soledad WSGI application +# ---------------------------------------------------------------------------- + + +class SoledadApp(http_app.HTTPApp): +    """ +    Soledad WSGI application +    """ + +    SHARED_DB_NAME = SHARED_DB_NAME +    """ +    The name of the shared database that holds user's encrypted secrets. +    """ + +    max_request_size = MAX_REQUEST_SIZE * 1024 * 1024 +    max_entry_size = MAX_ENTRY_SIZE * 1024 * 1024 + +    def __call__(self, environ, start_response): +        """ +        Handle a WSGI call to the Soledad application. + +        @param environ: Dictionary containing CGI variables. +        @type environ: dict +        @param start_response: Callable of the form start_response(status, +            response_headers, exc_info=None). +        @type start_response: callable + +        @return: HTTP application results. +        @rtype: list +        """ +        return http_app.HTTPApp.__call__(self, environ, start_response) + + +# ---------------------------------------------------------------------------- +# WSGI resources registration +# ---------------------------------------------------------------------------- + +# monkey patch u1db with a new resource map +http_app.url_to_resource = http_app.URLToResource() + +# register u1db unmodified resources +http_app.url_to_resource.register(http_app.GlobalResource) +http_app.url_to_resource.register(http_app.DatabaseResource) +http_app.url_to_resource.register(http_app.DocsResource) +http_app.url_to_resource.register(http_app.DocResource) + +# register Soledad's new or modified resources +http_app.url_to_resource.register(SyncResource) + + +# ---------------------------------------------------------------------------- +# Modified HTTP method invocation (to account for splitted sync) +# ---------------------------------------------------------------------------- + +class HTTPInvocationByMethodWithBody( +        http_app.HTTPInvocationByMethodWithBody): +    """ +    Invoke methods on a resource. +    """ + +    def __call__(self): +        """ +        Call an HTTP method of a resource. + +        This method was rewritten to allow for a sync flow which uses one POST +        request for each transferred document (back and forth). + +        Usual U1DB sync process transfers all documents from client to server +        and back in only one POST request. 
This is inconvenient for some +        reasons, as lack of possibility of gracefully interrupting the sync +        process, and possible timeouts for when dealing with large documents +        that have to be retrieved and encrypted/decrypted. Because of those, +        we split the sync process into many POST requests. +        """ +        args = urlparse.parse_qsl(self.environ['QUERY_STRING'], +                                  strict_parsing=False) +        try: +            args = dict( +                (k.decode('utf-8'), v.decode('utf-8')) for k, v in args) +        except ValueError: +            raise http_app.BadRequest() +        method = self.environ['REQUEST_METHOD'].lower() +        if method in ('get', 'delete'): +            meth = self._lookup(method) +            return meth(args, None) +        else: +            # we expect content-length > 0, reconsider if we move +            # to support chunked enconding +            try: +                content_length = int(self.environ['CONTENT_LENGTH']) +            except (ValueError, KeyError): +                # raise http_app.BadRequest +                content_length = self.max_request_size +            if content_length <= 0: +                raise http_app.BadRequest +            if content_length > self.max_request_size: +                raise http_app.BadRequest +            reader = http_app._FencedReader( +                self.environ['wsgi.input'], content_length, +                self.max_entry_size) +            content_type = self.environ.get('CONTENT_TYPE') +            if content_type == 'application/json': +                meth = self._lookup(method) +                body = reader.read_chunk(sys.maxint) +                return meth(args, body) +            elif content_type.startswith('application/x-soledad-sync'): +                # read one line and validate it +                body_getline = reader.getline +                if body_getline().strip() != '[': +                    raise http_app.BadRequest() +                line = body_getline() +                line, comma = utils.check_and_strip_comma(line.strip()) +                meth_args = self._lookup('%s_args' % method) +                meth_args(args, line) +                # handle incoming documents +                if content_type == 'application/x-soledad-sync-put': +                    meth_put = self._lookup('%s_put' % method) +                    meth_end = self._lookup('%s_end' % method) +                    while True: +                        entry = body_getline().strip() +                        if entry == ']':  # end of incoming document stream +                            break +                        if not entry or not comma:  # empty or no prec comma +                            raise http_app.BadRequest +                        entry, comma = utils.check_and_strip_comma(entry) +                        content = body_getline().strip() +                        content, comma = utils.check_and_strip_comma(content) +                        meth_put({'content': content or None}, entry) +                    if comma or body_getline():  # extra comma or data +                        raise http_app.BadRequest +                    return meth_end() +                # handle outgoing documents +                elif content_type == 'application/x-soledad-sync-get': +                    meth_get = self._lookup('%s_get' % method) +                    return meth_get() +                else: +                    raise http_app.BadRequest() +     
       else: +                raise http_app.BadRequest() + + +# monkey patch server with new http invocation +http_app.HTTPInvocationByMethodWithBody = HTTPInvocationByMethodWithBody + + +__version__ = get_versions()['version'] +del get_versions diff --git a/src/leap/soledad/server/_blobs.py b/src/leap/soledad/server/_blobs.py new file mode 100644 index 00000000..10678360 --- /dev/null +++ b/src/leap/soledad/server/_blobs.py @@ -0,0 +1,234 @@ +# -*- coding: utf-8 -*- +# _blobs.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +""" +Blobs Server implementation. + +This is a very simplistic implementation for the time being. +Clients should be able to opt-in util the feature is complete. + +A more performant BlobsBackend can (and should) be implemented for production +environments. +""" +import os +import base64 +import json +import re + +from twisted.logger import Logger +from twisted.web import static +from twisted.web import resource +from twisted.web.client import FileBodyProducer +from twisted.web.server import NOT_DONE_YET +from twisted.internet import utils, defer + +from zope.interface import implementer + +from leap.common.files import mkdir_p +from leap.soledad.server import interfaces + + +__all__ = ['BlobsResource'] + + +logger = Logger() + +# Used for sanitizers, we accept only letters, numbers, '-' and '_' +VALID_STRINGS = re.compile('^[a-zA-Z0-9_-]+$') + + +# for the future: +# [ ] isolate user avatar in a safer way +# [ ] catch timeout in the server (and delete incomplete upload) +# [ ] chunking (should we do it on the client or on the server?) 
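The FilesystemBlobsBackend defined just below shards each user's blobs into nested directories derived from the blob id; a quick sketch of the resulting layout (paths and ids are invented):

from leap.soledad.server._blobs import FilesystemBlobsBackend

backend = FilesystemBlobsBackend(blobs_path='/srv/leap/blobs')

# _get_path() joins [user, blob_id[0], blob_id[0:3], blob_id[0:6], blob_id]:
backend._get_path('user-1234', 'deadbeef1234')
# -> '/srv/leap/blobs/user-1234/d/dea/deadbe/deadbeef1234'

# reads, writes, listing and quota accounting all stay inside
# /srv/leap/blobs/<user>/, which _validate_path() enforces.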
+ + +@implementer(interfaces.IBlobsBackend) +class FilesystemBlobsBackend(object): + +    def __init__(self, blobs_path='/tmp/blobs/', quota=200 * 1024): +        self.quota = quota +        if not os.path.isdir(blobs_path): +            os.makedirs(blobs_path) +        self.path = blobs_path + +    def read_blob(self, user, blob_id, request): +        logger.info('reading blob: %s - %s' % (user, blob_id)) +        path = self._get_path(user, blob_id) +        logger.debug('blob path: %s' % path) +        _file = static.File(path, defaultType='application/octet-stream') +        return _file.render_GET(request) + +    @defer.inlineCallbacks +    def write_blob(self, user, blob_id, request): +        path = self._get_path(user, blob_id) +        try: +            mkdir_p(os.path.split(path)[0]) +        except OSError: +            pass +        if os.path.isfile(path): +            # 409 - Conflict +            request.setResponseCode(409) +            request.write("Blob already exists: %s" % blob_id) +            defer.returnValue(None) +        used = yield self.get_total_storage(user) +        if used > self.quota: +            logger.error("Error 507: Quota exceeded for user: %s" % user) +            request.setResponseCode(507) +            request.write('Quota Exceeded!') +            defer.returnValue(None) +        logger.info('writing blob: %s - %s' % (user, blob_id)) +        fbp = FileBodyProducer(request.content) +        yield fbp.startProducing(open(path, 'wb')) + +    def delete_blob(self, user, blob_id): +        blob_path = self._get_path(user, blob_id) +        os.unlink(blob_path) + +    def get_blob_size(user, blob_id): +        raise NotImplementedError + +    def list_blobs(self, user, request): +        blob_ids = [] +        base_path = self._get_path(user) +        for _, _, filenames in os.walk(base_path): +            blob_ids += filenames +        return json.dumps(blob_ids) + +    def get_total_storage(self, user): +        return self._get_disk_usage(self._get_path(user)) + +    def add_tag_header(self, user, blob_id, request): +        with open(self._get_path(user, blob_id)) as doc_file: +            doc_file.seek(-16, 2) +            tag = base64.urlsafe_b64encode(doc_file.read()) +            request.responseHeaders.setRawHeaders('Tag', [tag]) + +    @defer.inlineCallbacks +    def _get_disk_usage(self, start_path): +        if not os.path.isdir(start_path): +            defer.returnValue(0) +        cmd = ['/usr/bin/du', '-s', '-c', start_path] +        output = yield utils.getProcessOutput(cmd[0], cmd[1:]) +        size = output.split()[0] +        defer.returnValue(int(size)) + +    def _validate_path(self, desired_path, user, blob_id): +        if not VALID_STRINGS.match(user): +            raise Exception("Invalid characters on user: %s" % user) +        if blob_id and not VALID_STRINGS.match(blob_id): +            raise Exception("Invalid characters on blob_id: %s" % blob_id) +        desired_path = os.path.realpath(desired_path)  # expand path references +        root = os.path.realpath(self.path) +        if not desired_path.startswith(root + os.sep + user): +            err = "User %s tried accessing a invalid path: %s" % (user, +                                                                  desired_path) +            raise Exception(err) +        return desired_path + +    def _get_path(self, user, blob_id=False): +        parts = [user] +        if blob_id: +            parts += [blob_id[0], blob_id[0:3], blob_id[0:6]] +            parts += 
[blob_id] +        path = os.path.join(self.path, *parts) +        return self._validate_path(path, user, blob_id) + + +class ImproperlyConfiguredException(Exception): +    pass + + +class BlobsResource(resource.Resource): + +    isLeaf = True + +    # Allowed backend classes are defined here +    handlers = {"filesystem": FilesystemBlobsBackend} + +    def __init__(self, backend, blobs_path, **backend_kwargs): +        resource.Resource.__init__(self) +        self._blobs_path = blobs_path +        backend_kwargs.update({'blobs_path': blobs_path}) +        if backend not in self.handlers: +            raise ImproperlyConfiguredException("No such backend: %s", backend) +        self._handler = self.handlers[backend](**backend_kwargs) +        assert interfaces.IBlobsBackend.providedBy(self._handler) + +    # TODO double check credentials, we can have then +    # under request. + +    def render_GET(self, request): +        logger.info("http get: %s" % request.path) +        user, blob_id = self._validate(request) +        if not blob_id: +            return self._handler.list_blobs(user, request) +        self._handler.add_tag_header(user, blob_id, request) +        return self._handler.read_blob(user, blob_id, request) + +    def render_DELETE(self, request): +        logger.info("http put: %s" % request.path) +        user, blob_id = self._validate(request) +        self._handler.delete_blob(user, blob_id) +        return '' + +    def render_PUT(self, request): +        logger.info("http put: %s" % request.path) +        user, blob_id = self._validate(request) +        d = self._handler.write_blob(user, blob_id, request) +        d.addCallback(lambda _: request.finish()) +        d.addErrback(self._error, request) +        return NOT_DONE_YET + +    def _error(self, e, request): +        logger.error('Error processing request: %s' % e.getErrorMessage()) +        request.setResponseCode(500) +        request.finish() + +    def _validate(self, request): +        for arg in request.postpath: +            if arg and not VALID_STRINGS.match(arg): +                raise Exception('Invalid blob resource argument: %s' % arg) +        return request.postpath + + +if __name__ == '__main__': +    # A dummy blob server +    # curl -X PUT --data-binary @/tmp/book.pdf localhost:9000/user/someid +    # curl -X GET -o /dev/null localhost:9000/user/somerandomstring +    from twisted.python import log +    import sys +    log.startLogging(sys.stdout) + +    from twisted.web.server import Site +    from twisted.internet import reactor + +    # parse command line arguments +    import argparse + +    parser = argparse.ArgumentParser() +    parser.add_argument('--port', default=9000, type=int) +    parser.add_argument('--path', default='/tmp/blobs/user') +    args = parser.parse_args() + +    root = BlobsResource("filesystem", args.path) +    # I picture somethink like +    # BlobsResource(backend="filesystem", backend_opts={'path': '/tmp/blobs'}) + +    factory = Site(root) +    reactor.listenTCP(args.port, factory) +    reactor.run() diff --git a/src/leap/soledad/server/_config.py b/src/leap/soledad/server/_config.py new file mode 100644 index 00000000..e89e70d6 --- /dev/null +++ b/src/leap/soledad/server/_config.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# config.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the 
License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +import configparser + + +__all__ = ['get_config'] + + +CONFIG_DEFAULTS = { +    'soledad-server': { +        'couch_url': 'http://localhost:5984', +        'create_cmd': None, +        'admin_netrc': '/etc/couchdb/couchdb-admin.netrc', +        'batching': True, +        'blobs': False, +        'blobs_path': '/srv/leap/soledad/blobs', +    }, +    'database-security': { +        'members': ['soledad'], +        'members_roles': [], +        'admins': [], +        'admins_roles': [] +    } +} + + +_config = None + + +def get_config(section='soledad-server'): +    global _config +    if not _config: +        _config = _load_config('/etc/soledad/soledad-server.conf') +    return _config[section] + + +def _load_config(file_path): +    """ +    Load server configuration from file. + +    @param file_path: The path to the configuration file. +    @type file_path: str + +    @return: A dictionary with the configuration. +    @rtype: dict +    """ +    conf = dict(CONFIG_DEFAULTS) +    config = configparser.SafeConfigParser() +    config.read(file_path) +    for section in conf: +        if not config.has_section(section): +            continue +        for key, value in conf[section].items(): +            if not config.has_option(section, key): +                continue +            elif type(value) == bool: +                conf[section][key] = config.getboolean(section, key) +            elif type(value) == list: +                values = config.get(section, key).split(',') +                values = [v.strip() for v in values] +                conf[section][key] = values +            else: +                conf[section][key] = config.get(section, key) +    # TODO: implement basic parsing/sanitization of options comming from +    # config file. +    return conf diff --git a/src/leap/soledad/server/_resource.py b/src/leap/soledad/server/_resource.py new file mode 100644 index 00000000..49c4b742 --- /dev/null +++ b/src/leap/soledad/server/_resource.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# resource.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +A twisted resource that serves the Soledad Server. 
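For reference, the get_config()/_load_config() pair above expects a flat INI file whose values are coerced against CONFIG_DEFAULTS; a sketch of what /etc/soledad/soledad-server.conf could look like (all values here are illustrative):

[soledad-server]
couch_url = http://127.0.0.1:5984
batching = true
blobs = true
blobs_path = /srv/leap/soledad/blobs

[database-security]
members = soledad
admins = admin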
+""" +from twisted.web.resource import Resource + +from ._server_info import ServerInfo +from ._wsgi import get_sync_resource + + +__all__ = ['SoledadResource', 'SoledadAnonResource'] + + +class _Robots(Resource): +    def render_GET(self, request): +        return ( +            'User-agent: *\n' +            'Disallow: /\n' +            '# you are not a robot, are you???') + + +class SoledadAnonResource(Resource): + +    """ +    The parts of Soledad Server that unauthenticated users can see. +    This is nice because this means that a non-authenticated user will get 404 +    for anything that is not in this minimal resource tree. +    """ + +    def __init__(self, enable_blobs=False): +        Resource.__init__(self) +        server_info = ServerInfo(enable_blobs) +        self.putChild('', server_info) +        self.putChild('robots.txt', _Robots()) + + +class SoledadResource(Resource): +    """ +    This is a dummy twisted resource, used only to allow different entry points +    for the Soledad Server. +    """ + +    def __init__(self, blobs_resource=None, sync_pool=None): +        """ +        Initialize the Soledad resource. + +        :param blobs_resource: a resource to serve blobs, if enabled. +        :type blobs_resource: _blobs.BlobsResource + +        :param sync_pool: A pool to pass to the WSGI sync resource. +        :type sync_pool: twisted.python.threadpool.ThreadPool +        """ +        Resource.__init__(self) + +        # requests to / return server information +        server_info = ServerInfo(bool(blobs_resource)) +        self.putChild('', server_info) + +        # requests to /blobs will serve blobs if enabled +        if blobs_resource: +            self.putChild('blobs', blobs_resource) + +        # other requests are routed to legacy sync resource +        self._sync_resource = get_sync_resource(sync_pool) + +    def getChild(self, path, request): +        """ +        Route requests to legacy WSGI sync resource dynamically. +        """ +        request.postpath.insert(0, request.prepath.pop()) +        return self._sync_resource diff --git a/src/leap/soledad/server/_server_info.py b/src/leap/soledad/server/_server_info.py new file mode 100644 index 00000000..50659338 --- /dev/null +++ b/src/leap/soledad/server/_server_info.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# _server_info.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Resource that announces information about the server. +""" +import json + +from twisted.web.resource import Resource + +from leap.soledad.server import __version__ + + +__all__ = ['ServerInfo'] + + +class ServerInfo(Resource): +    """ +    Return information about the server. 
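The ServerInfo resource being defined here answers GET requests with a small JSON document, roughly of this shape (the version string is illustrative):

{"blobs": false, "version": "0.10.5"}

A client can read the "blobs" flag to decide whether the server accepts blob storage at all.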
+    """ + +    isLeaf = True + +    def __init__(self, blobs_enabled): +        self._info = { +            "blobs": blobs_enabled, +            "version": __version__ +        } + +    def render_GET(self, request): +        return json.dumps(self._info) diff --git a/src/leap/soledad/server/_wsgi.py b/src/leap/soledad/server/_wsgi.py new file mode 100644 index 00000000..f6ff6b26 --- /dev/null +++ b/src/leap/soledad/server/_wsgi.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# application.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +A WSGI application that serves Soledad synchronization. +""" +from twisted.internet import reactor +from twisted.web.wsgi import WSGIResource + +from leap.soledad.server import SoledadApp +from leap.soledad.server.gzip_middleware import GzipMiddleware +from leap.soledad.common.backend import SoledadBackend +from leap.soledad.common.couch.state import CouchServerState +from leap.soledad.common.log import getLogger + +from twisted.logger import Logger +log = Logger() + +__all__ = ['init_couch_state', 'get_sync_resource'] + + +def _get_couch_state(conf): +    state = CouchServerState(conf['couch_url'], create_cmd=conf['create_cmd'], +                             check_schema_versions=True) +    SoledadBackend.BATCH_SUPPORT = conf.get('batching', False) +    return state + + +_app = SoledadApp(None)  # delay state init +wsgi_application = GzipMiddleware(_app) + + +# During its initialization, the couch state verifies if all user databases +# contain a config document with the correct couch schema version stored, and +# will log an error and raise an exception if that is not the case. +# +# If this verification made too early (i.e.  before the reactor has started and +# the twistd web logging facilities have been setup), the logging will not +# work.  Because of that, we delay couch state initialization until the reactor +# is running. + +def init_couch_state(conf): +    try: +        _app.state = _get_couch_state(conf) +    except Exception as e: +        logger = getLogger() +        logger.error(str(e)) +        reactor.stop() + + +def get_sync_resource(pool): +    return WSGIResource(reactor, pool, wsgi_application) diff --git a/src/leap/soledad/server/auth.py b/src/leap/soledad/server/auth.py new file mode 100644 index 00000000..1357b289 --- /dev/null +++ b/src/leap/soledad/server/auth.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# auth.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Twisted http token auth. +""" +import binascii +import time + +from hashlib import sha512 +from zope.interface import implementer + +from twisted.cred import error +from twisted.cred.checkers import ICredentialsChecker +from twisted.cred.credentials import IUsernamePassword +from twisted.cred.credentials import IAnonymous +from twisted.cred.credentials import Anonymous +from twisted.cred.credentials import UsernamePassword +from twisted.cred.portal import IRealm +from twisted.cred.portal import Portal +from twisted.internet import defer +from twisted.logger import Logger +from twisted.web.iweb import ICredentialFactory +from twisted.web.resource import IResource + +from leap.soledad.common.couch import couch_server + +from ._resource import SoledadResource, SoledadAnonResource +from ._blobs import BlobsResource +from ._config import get_config + + +log = Logger() + + +@implementer(IRealm) +class SoledadRealm(object): + +    def __init__(self, sync_pool, conf=None): +        assert sync_pool is not None +        if conf is None: +            conf = get_config() +        blobs = conf['blobs'] +        blobs_resource = BlobsResource("filesystem", +                                       conf['blobs_path']) if blobs else None +        self.anon_resource = SoledadAnonResource( +            enable_blobs=blobs) +        self.auth_resource = SoledadResource( +            blobs_resource=blobs_resource, +            sync_pool=sync_pool) + +    def requestAvatar(self, avatarId, mind, *interfaces): + +        # Anonymous access +        if IAnonymous.providedBy(avatarId): +            return (IResource, self.anon_resource, +                    lambda: None) + +        # Authenticated access +        else: +            if IResource in interfaces: +                return (IResource, self.auth_resource, +                        lambda: None) +        raise NotImplementedError() + + +@implementer(ICredentialsChecker) +class TokenChecker(object): + +    credentialInterfaces = [IUsernamePassword, IAnonymous] + +    TOKENS_DB_PREFIX = "tokens_" +    TOKENS_DB_EXPIRE = 30 * 24 * 3600  # 30 days in seconds +    TOKENS_TYPE_KEY = "type" +    TOKENS_TYPE_DEF = "Token" +    TOKENS_USER_ID_KEY = "user_id" + +    def __init__(self): +        self._couch_url = get_config().get('couch_url') + +    def _get_server(self): +        return couch_server(self._couch_url) + +    def _tokens_dbname(self): +        # the tokens db rotates every 30 days, and the current db name is +        # "tokens_NNN", where NNN is the number of seconds since epoch +        # divide dby the rotate period in seconds. When rotating, old and +        # new tokens db coexist during a certain window of time and valid +        # tokens are replicated from the old db to the new one. See: +        # https://leap.se/code/issues/6785 +        dbname = self.TOKENS_DB_PREFIX + \ +            str(int(time.time() / self.TOKENS_DB_EXPIRE)) +        return dbname + +    def _tokens_db(self): +        dbname = self._tokens_dbname() + +        # TODO -- leaking abstraction here: this module shouldn't need +        # to known anything about the context manager. 
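To make the token-database rotation described above concrete, here is the same computation with a fixed timestamp (the timestamp and token value are invented):

from hashlib import sha512

TOKENS_DB_EXPIRE = 30 * 24 * 3600        # 2592000 seconds
now = 1488000000                         # some instant early in 2017
dbname = "tokens_" + str(int(now / TOKENS_DB_EXPIRE))
# -> 'tokens_574'; every token issued in the same 30-day window lands in this db

# inside that db, the document key is the sha512 hex digest of the raw token
key = sha512(b'my-secret-token').hexdigest()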
hide that in the couch +        # module +        with self._get_server() as server: +            db = server[dbname] +        return db + +    def requestAvatarId(self, credentials): +        if IAnonymous.providedBy(credentials): +            return defer.succeed(Anonymous()) + +        uuid = credentials.username +        token = credentials.password + +        # lookup key is a hash of the token to prevent timing attacks. +        # TODO cache the tokens already! + +        db = self._tokens_db() +        token = db.get(sha512(token).hexdigest()) +        if token is None: +            return defer.fail(error.UnauthorizedLogin()) + +        # TODO -- use cryptography constant time builtin comparison. +        # we compare uuid hashes to avoid possible timing attacks that +        # might exploit python's builtin comparison operator behaviour, +        # which fails immediatelly when non-matching bytes are found. +        couch_uuid_hash = sha512(token[self.TOKENS_USER_ID_KEY]).digest() +        req_uuid_hash = sha512(uuid).digest() +        if token[self.TOKENS_TYPE_KEY] != self.TOKENS_TYPE_DEF \ +                or couch_uuid_hash != req_uuid_hash: +            return defer.fail(error.UnauthorizedLogin()) + +        return defer.succeed(uuid) + + +@implementer(ICredentialFactory) +class TokenCredentialFactory(object): + +    scheme = 'token' + +    def getChallenge(self, request): +        return {} + +    def decode(self, response, request): +        try: +            creds = binascii.a2b_base64(response + b'===') +        except binascii.Error: +            raise error.LoginFailed('Invalid credentials') + +        creds = creds.split(b':', 1) +        if len(creds) == 2: +            return UsernamePassword(*creds) +        else: +            raise error.LoginFailed('Invalid credentials') + + +def portalFactory(sync_pool): +    realm = SoledadRealm(sync_pool=sync_pool) +    checker = TokenChecker() +    return Portal(realm, [checker]) + + +credentialFactory = TokenCredentialFactory() diff --git a/src/leap/soledad/server/caching.py b/src/leap/soledad/server/caching.py new file mode 100644 index 00000000..9a049a39 --- /dev/null +++ b/src/leap/soledad/server/caching.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# caching.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Server side caching. Using beaker for now. 
+""" +from beaker.cache import CacheManager + + +def setup_caching(): +    _cache_manager = CacheManager(type='memory') +    return _cache_manager + + +_cache_manager = setup_caching() + + +def get_cache_for(key, expire=3600): +    return _cache_manager.get_cache(key, expire=expire) diff --git a/src/leap/soledad/server/entrypoint.py b/src/leap/soledad/server/entrypoint.py new file mode 100644 index 00000000..c06b740e --- /dev/null +++ b/src/leap/soledad/server/entrypoint.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# entrypoint.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +The entrypoint for Soledad server. + +This is the entrypoint for the application that is loaded from the initscript +or the systemd script. +""" + +from twisted.internet import reactor +from twisted.python import threadpool + +from .auth import portalFactory +from .session import SoledadSession +from ._config import get_config +from ._wsgi import init_couch_state + + +# load configuration from file +conf = get_config() + + +class SoledadEntrypoint(SoledadSession): + +    def __init__(self): +        pool = threadpool.ThreadPool(name='wsgi') +        reactor.callWhenRunning(pool.start) +        reactor.addSystemEventTrigger('after', 'shutdown', pool.stop) +        portal = portalFactory(pool) +        SoledadSession.__init__(self, portal) + + +# see the comments in application.py recarding why couch state has to be +# initialized when the reactor is running + +reactor.callWhenRunning(init_couch_state, conf) diff --git a/src/leap/soledad/server/gzip_middleware.py b/src/leap/soledad/server/gzip_middleware.py new file mode 100644 index 00000000..c77f9f67 --- /dev/null +++ b/src/leap/soledad/server/gzip_middleware.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# gzip_middleware.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. +""" +Gzip middleware for WSGI apps. +""" +from six import StringIO +from gzip import GzipFile + + +class GzipMiddleware(object): +    """ +    GzipMiddleware class for WSGI. 
+    """ +    def __init__(self, app, compresslevel=9): +        self.app = app +        self.compresslevel = compresslevel + +    def __call__(self, environ, start_response): +        if 'gzip' not in environ.get('HTTP_ACCEPT_ENCODING', ''): +            return self.app(environ, start_response) + +        buffer = StringIO.StringIO() +        output = GzipFile( +            mode='wb', +            compresslevel=self.compresslevel, +            fileobj=buffer +        ) + +        start_response_args = [] + +        def dummy_start_response(status, headers, exc_info=None): +            start_response_args.append(status) +            start_response_args.append(headers) +            start_response_args.append(exc_info) +            return output.write + +        app_iter = self.app(environ, dummy_start_response) +        for line in app_iter: +            output.write(line) +        if hasattr(app_iter, 'close'): +            app_iter.close() +        output.close() +        buffer.seek(0) +        result = buffer.getvalue() +        headers = [] +        for name, value in start_response_args[1]: +            if name.lower() != 'content-length': +                headers.append((name, value)) +        headers.append(('Content-Length', str(len(result)))) +        headers.append(('Content-Encoding', 'gzip')) +        start_response(start_response_args[0], headers, start_response_args[2]) +        buffer.close() +        return [result] diff --git a/src/leap/soledad/server/interfaces.py b/src/leap/soledad/server/interfaces.py new file mode 100644 index 00000000..67b04bc3 --- /dev/null +++ b/src/leap/soledad/server/interfaces.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# interfaces.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +from zope.interface import Interface + + +class IBlobsBackend(Interface): + +    """ +    An interface for a BlobsBackend. +    """ + +    def read_blob(user, blob_id, request): +        """ +        Read blob with a given blob_id, and write it to the passed request. + +        :returns: a deferred that fires upon finishing. +        """ + +    def write_blob(user, blob_id, request): +        """ +        Write blob to the storage, reading it from the passed request. + +        :returns: a deferred that fires upon finishing. +        """ + +    def delete_blob(user, blob_id): +        """ +        Delete the given blob_id. +        """ + +    def get_blob_size(user, blob_id): +        """ +        Get the size of the given blob id. +        """ + +    def list_blobs(user, request): +        """ +        Returns a json-encoded list of ids from user's blob. + +        :returns: a deferred that fires upon finishing. +        """ + +    def get_total_storage(user): +        """ +        Get the size used by a given user as the sum of all the blobs stored +        unders its namespace. 
+        """ + +    def add_tag_header(user, blob_id, request): +        """ +        Adds a header 'Tag' to the passed request object, containing the last +        16 bytes of the encoded blob, which according to the spec contains the +        tag. + +        :returns: a deferred that fires upon finishing. +        """ diff --git a/src/leap/soledad/server/session.py b/src/leap/soledad/server/session.py new file mode 100644 index 00000000..1c1b5345 --- /dev/null +++ b/src/leap/soledad/server/session.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +# session.py +# Copyright (C) 2017 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Twisted resource containing an authenticated Soledad session. +""" +from zope.interface import implementer + +from twisted.cred.credentials import Anonymous +from twisted.cred import error +from twisted.python import log +from twisted.web import util +from twisted.web._auth import wrapper +from twisted.web.guard import HTTPAuthSessionWrapper +from twisted.web.resource import ErrorPage +from twisted.web.resource import IResource + +from leap.soledad.server.auth import credentialFactory +from leap.soledad.server.url_mapper import URLMapper + + +@implementer(IResource) +class UnauthorizedResource(wrapper.UnauthorizedResource): +    isLeaf = True + +    def __init__(self): +        pass + +    def render(self, request): +        request.setResponseCode(401) +        if request.method == b'HEAD': +            return b'' +        return b'Unauthorized' + +    def getChildWithDefault(self, path, request): +        return self + + +@implementer(IResource) +class SoledadSession(HTTPAuthSessionWrapper): + +    def __init__(self, portal): +        self._mapper = URLMapper() +        self._portal = portal +        self._credentialFactory = credentialFactory +        # expected by the contract of the parent class +        self._credentialFactories = [credentialFactory] + +    def _matchPath(self, request): +        match = self._mapper.match(request.path, request.method) +        return match + +    def _parseHeader(self, header): +        elements = header.split(b' ') +        scheme = elements[0].lower() +        if scheme == self._credentialFactory.scheme: +            return (b' '.join(elements[1:])) +        return None + +    def _authorizedResource(self, request): +        # check whether the path of the request exists in the app +        match = self._matchPath(request) +        if not match: +            return UnauthorizedResource() + +        # get authorization header or fail +        header = request.getHeader(b'authorization') +        if not header: +            return util.DeferredResource(self._login(Anonymous())) + +        # parse the authorization header +        auth_data = self._parseHeader(header) +        if not auth_data: +            return UnauthorizedResource() + +        # decode the credentials from the parsed header +        try: +            
credentials = self._credentialFactory.decode(auth_data, request) +        except error.LoginFailed: +            return UnauthorizedResource() +        except: +            # If you port this to the newer log facility, be aware that +            # the tests rely on the error to be logged. +            log.err(None, "Unexpected failure from credentials factory") +            return ErrorPage(500, None, None) + +        # make sure the uuid given in path corresponds to the one given in +        # the credentials +        request_uuid = match.get('uuid') +        if request_uuid and request_uuid != credentials.username: +            return ErrorPage(500, None, None) + +        # if all checks pass, try to login with credentials +        return util.DeferredResource(self._login(credentials)) diff --git a/src/leap/soledad/server/state.py b/src/leap/soledad/server/state.py new file mode 100644 index 00000000..f269b77e --- /dev/null +++ b/src/leap/soledad/server/state.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +# state.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Server side synchronization infrastructure. +""" +from leap.soledad.server import caching + + +class ServerSyncState(object): +    """ +    The state of one sync session, as stored on backend server. + +    On server side, the ongoing syncs metadata is maintained in +    a caching layer. +    """ + +    def __init__(self, source_replica_uid, sync_id): +        """ +        Initialize the sync state object. + +        :param sync_id: The id of current sync +        :type sync_id: str +        :param source_replica_uid: The source replica uid +        :type source_replica_uid: str +        """ +        self._source_replica_uid = source_replica_uid +        self._sync_id = sync_id +        caching_key = source_replica_uid + sync_id +        self._storage = caching.get_cache_for(caching_key) + +    def _put_dict_info(self, key, value): +        """ +        Put some information about the sync state. + +        :param key: The key for the info to be put. +        :type key: str +        :param value: The value for the info to be put. +        :type value: str +        """ +        if key not in self._storage: +            self._storage[key] = [] +        info_list = self._storage.get(key) +        info_list.append(value) +        self._storage[key] = info_list + +    def put_seen_id(self, seen_id, gen): +        """ +        Put one seen id on the sync state. + +        :param seen_id: The doc_id of a document seen during sync. +        :type seen_id: str +        :param gen: The corresponding db generation. +        :type gen: int +        """ +        self._put_dict_info( +            'seen_id', +            (seen_id, gen)) + +    def seen_ids(self): +        """ +        Return all document ids seen during the sync. + +        :return: A dict with doc ids seen during the sync. 
+        :rtype: dict +        """ +        if 'seen_id' in self._storage: +            seen_ids = self._storage.get('seen_id') +        else: +            seen_ids = [] +        return dict(seen_ids) + +    def put_changes_to_return(self, gen, trans_id, changes_to_return): +        """ +        Put the calculated changes to return in the backend sync state. + +        :param gen: The target database generation that will be synced. +        :type gen: int +        :param trans_id: The target database transaction id that will be +                         synced. +        :type trans_id: str +        :param changes_to_return: A list of tuples with the changes to be +                                  returned during the sync process. +        :type changes_to_return: list +        """ +        self._put_dict_info( +            'changes_to_return', +            { +                'gen': gen, +                'trans_id': trans_id, +                'changes_to_return': changes_to_return, +            } +        ) + +    def sync_info(self): +        """ +        Return information about the current sync state. + +        :return: The generation and transaction id of the target database +                 which will be synced, and the number of documents to return, +                 or a tuple of Nones if those have not already been sent to +                 server. +        :rtype: tuple +        """ +        gen = trans_id = number_of_changes = None +        if 'changes_to_return' in self._storage: +            info = self._storage.get('changes_to_return')[0] +            gen = info['gen'] +            trans_id = info['trans_id'] +            number_of_changes = len(info['changes_to_return']) +        return gen, trans_id, number_of_changes + +    def next_change_to_return(self, received): +        """ +        Return the next change to be returned to the source syncing replica. + +        :param received: How many documents the source replica has already +                         received during the current sync process. +        :type received: int +        """ +        gen = trans_id = next_change_to_return = None +        if 'changes_to_return' in self._storage: +            info = self._storage.get('changes_to_return')[0] +            gen = info['gen'] +            trans_id = info['trans_id'] +            if received < len(info['changes_to_return']): +                next_change_to_return = (info['changes_to_return'][received]) +        return gen, trans_id, next_change_to_return diff --git a/src/leap/soledad/server/sync.py b/src/leap/soledad/server/sync.py new file mode 100644 index 00000000..6791c06c --- /dev/null +++ b/src/leap/soledad/server/sync.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- +# sync.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Server side synchronization infrastructure. 
+""" +import time +from six.moves import zip as izip + +from leap.soledad.common.l2db import sync +from leap.soledad.common.l2db.remote import http_app +from leap.soledad.server.caching import get_cache_for +from leap.soledad.server.state import ServerSyncState +from leap.soledad.common.document import ServerDocument + + +MAX_REQUEST_SIZE = float('inf')  # It's a stream. +MAX_ENTRY_SIZE = 200  # in Mb +ENTRY_CACHE_SIZE = 8192 * 1024 + + +class SyncExchange(sync.SyncExchange): + +    def __init__(self, db, source_replica_uid, last_known_generation, sync_id): +        """ +        :param db: The target syncing database. +        :type db: SoledadBackend +        :param source_replica_uid: The uid of the source syncing replica. +        :type source_replica_uid: str +        :param last_known_generation: The last target replica generation the +                                      source replica knows about. +        :type last_known_generation: int +        :param sync_id: The id of the current sync session. +        :type sync_id: str +        """ +        self._db = db +        self.source_replica_uid = source_replica_uid +        self.source_last_known_generation = last_known_generation +        self.sync_id = sync_id +        self.new_gen = None +        self.new_trans_id = None +        self._trace_hook = None +        # recover sync state +        self._sync_state = ServerSyncState(self.source_replica_uid, sync_id) + +    def find_changes_to_return(self): +        """ +        Find changes to return. + +        Find changes since last_known_generation in db generation +        order using whats_changed. It excludes documents ids that have +        already been considered (superseded by the sender, etc). + +        :return: the generation of this database, which the caller can +                 consider themselves to be synchronized after processing +                 allreturned documents, and the amount of documents to be sent +                 to the source syncing replica. +        :rtype: int +        """ +        # check if changes to return have already been calculated +        new_gen, new_trans_id, number_of_changes = self._sync_state.sync_info() +        if number_of_changes is None: +            self._trace('before whats_changed') +            new_gen, new_trans_id, changes = self._db.whats_changed( +                self.source_last_known_generation) +            self._trace('after whats_changed') +            seen_ids = self._sync_state.seen_ids() +            # changed docs that weren't superseded by or converged with +            self.changes_to_return = [ +                (doc_id, gen, trans_id) for (doc_id, gen, trans_id) in changes +                # there was a subsequent update +                if doc_id not in seen_ids or seen_ids.get(doc_id) < gen] +            self._sync_state.put_changes_to_return( +                new_gen, new_trans_id, self.changes_to_return) +            number_of_changes = len(self.changes_to_return) +        self.new_gen = new_gen +        self.new_trans_id = new_trans_id +        return self.new_gen, number_of_changes + +    def return_docs(self, return_doc_cb): +        """Return the changed documents and their last change generation +        repeatedly invoking the callback return_doc_cb. + +        The final step of a sync exchange. + +        :param: return_doc_cb(doc, gen, trans_id): is a callback +                used to return the documents with their last change generation +                to the target replica. 
+        :return: None
+        """
+        changes_to_return = self.changes_to_return
+        # return docs, including conflicts.
+        # content as a file-object (will be read when writing)
+        changed_doc_ids = [doc_id for doc_id, _, _ in changes_to_return]
+        docs = self._db.get_docs(
+            changed_doc_ids, check_for_conflicts=False,
+            include_deleted=True, read_content=False)
+
+        docs_by_gen = izip(
+            docs, (gen for _, gen, _ in changes_to_return),
+            (trans_id for _, _, trans_id in changes_to_return))
+        for doc, gen, trans_id in docs_by_gen:
+            return_doc_cb(doc, gen, trans_id)
+
+    def batched_insert_from_source(self, entries, sync_id):
+        if not entries:
+            return
+        self._db.batch_start()
+        for entry in entries:
+            doc, gen, trans_id, number_of_docs, doc_idx = entry
+            self.insert_doc_from_source(doc, gen, trans_id, number_of_docs,
+                                        doc_idx, sync_id)
+        self._db.batch_end()
+
+    def insert_doc_from_source(
+            self, doc, source_gen, trans_id,
+            number_of_docs=None, doc_idx=None, sync_id=None):
+        """Try to insert synced document from source.
+
+        Conflicting documents are not inserted but will be sent over
+        to the sync source.
+
+        It keeps track of progress by storing the document source
+        generation as well.
+
+        The first step of a sync exchange is to call this repeatedly to
+        try to insert all incoming documents from the source.
+
+        :param doc: A Document object.
+        :type doc: Document
+        :param source_gen: The source generation of doc.
+        :type source_gen: int
+        :param trans_id: The transaction id of that document change.
+        :type trans_id: str
+        :param number_of_docs: The total amount of documents sent on this sync
+                               session.
+        :type number_of_docs: int
+        :param doc_idx: The index of the current document.
+        :type doc_idx: int
+        :param sync_id: The id of the current sync session.
+        :type sync_id: str
+        """
+        state, at_gen = self._db._put_doc_if_newer(
+            doc, save_conflict=False, replica_uid=self.source_replica_uid,
+            replica_gen=source_gen, replica_trans_id=trans_id,
+            number_of_docs=number_of_docs, doc_idx=doc_idx, sync_id=sync_id)
+        if state == 'inserted':
+            self._sync_state.put_seen_id(doc.doc_id, at_gen)
+        elif state == 'converged':
+            # magical convergence
+            self._sync_state.put_seen_id(doc.doc_id, at_gen)
+        elif state == 'superseded':
+            # we have something newer that we will return
+            pass
+        else:
+            # conflict that we will return
+            assert state == 'conflicted'
+
+
+class SyncResource(http_app.SyncResource):
+
+    max_request_size = MAX_REQUEST_SIZE * 1024 * 1024
+    max_entry_size = MAX_ENTRY_SIZE * 1024 * 1024
+
+    sync_exchange_class = SyncExchange
+
+    @http_app.http_method(
+        last_known_generation=int, last_known_trans_id=http_app.none_or_str,
+        sync_id=http_app.none_or_str, content_as_args=True)
+    def post_args(self, last_known_generation, last_known_trans_id=None,
+                  sync_id=None, ensure=False):
+        """
+        Handle the initial arguments for the sync POST request from the client.
+ +        :param last_known_generation: The last server replica generation the +                                      client knows about. +        :type last_known_generation: int +        :param last_known_trans_id: The last server replica transaction_id the +                                    client knows about. +        :type last_known_trans_id: str +        :param sync_id: The id of the current sync session. +        :type sync_id: str +        :param ensure: Whether the server replica should be created if it does +                       not already exist. +        :type ensure: bool +        """ +        # create or open the database +        cache = get_cache_for('db-' + sync_id + self.dbname, expire=120) +        if ensure: +            db, self.replica_uid = self.state.ensure_database(self.dbname) +        else: +            db = self.state.open_database(self.dbname) +        db.init_caching(cache) +        # validate the information the client has about server replica +        db.validate_gen_and_trans_id( +            last_known_generation, last_known_trans_id) +        # get a sync exchange object +        self.sync_exch = self.sync_exchange_class( +            db, self.source_replica_uid, last_known_generation, sync_id) +        self._sync_id = sync_id +        self._staging = [] +        self._staging_size = 0 + +    @http_app.http_method(content_as_args=True) +    def post_put( +            self, id, rev, content, gen, +            trans_id, number_of_docs, doc_idx): +        """ +        Put one incoming document into the server replica. + +        :param id: The id of the incoming document. +        :type id: str +        :param rev: The revision of the incoming document. +        :type rev: str +        :param content: The content of the incoming document. +        :type content: dict +        :param gen: The source replica generation corresponding to the +                    revision of the incoming document. +        :type gen: int +        :param trans_id: The source replica transaction id corresponding to +                         the revision of the incoming document. +        :type trans_id: str +        :param number_of_docs: The total amount of documents sent on this sync +                               session. +        :type number_of_docs: int +        :param doc_idx: The index of the current document. +        :type doc_idx: int +        """ +        doc = ServerDocument(id, rev, json=content) +        self._staging_size += len(content or '') +        self._staging.append((doc, gen, trans_id, number_of_docs, doc_idx)) +        if self._staging_size > ENTRY_CACHE_SIZE or doc_idx == number_of_docs: +            self.sync_exch.batched_insert_from_source(self._staging, +                                                      self._sync_id) +            self._staging = [] +            self._staging_size = 0 + +    def post_get(self): +        """ +        Return syncing documents to the client. 
+        """ +        def send_doc(doc, gen, trans_id): +            entry = dict(id=doc.doc_id, rev=doc.rev, +                         gen=gen, trans_id=trans_id) +            self.responder.stream_entry(entry) +            content_reader = doc.get_json() +            if content_reader: +                content = content_reader.read() +                self.responder.stream_entry(content) +                content_reader.close() +                # throttle at 5mb/s +                # FIXME: twistd cant control througput +                # we need to either use gunicorn or go async +                time.sleep(len(content) / (5.0 * 1024 * 1024)) +            else: +                self.responder.stream_entry('') + +        new_gen, number_of_changes = \ +            self.sync_exch.find_changes_to_return() +        self.responder.content_type = 'application/x-u1db-sync-response' +        self.responder.start_response(200) +        self.responder.start_stream(), +        header = { +            "new_generation": new_gen, +            "new_transaction_id": self.sync_exch.new_trans_id, +            "number_of_changes": number_of_changes, +        } +        if self.replica_uid is not None: +            header['replica_uid'] = self.replica_uid +        self.responder.stream_entry(header) +        self.sync_exch.return_docs(send_doc) +        self.responder.end_stream() +        self.responder.finish_response() + +    def post_end(self): +        """ +        Return the current generation and transaction_id after inserting one +        incoming document. +        """ +        self.responder.content_type = 'application/x-soledad-sync-response' +        self.responder.start_response(200) +        self.responder.start_stream(), +        new_gen, new_trans_id = self.sync_exch._db._get_generation_info() +        header = { +            "new_generation": new_gen, +            "new_transaction_id": new_trans_id, +        } +        if self.replica_uid is not None: +            header['replica_uid'] = self.replica_uid +        self.responder.stream_entry(header) +        self.responder.end_stream() +        self.responder.finish_response() diff --git a/src/leap/soledad/server/url_mapper.py b/src/leap/soledad/server/url_mapper.py new file mode 100644 index 00000000..b50a81cd --- /dev/null +++ b/src/leap/soledad/server/url_mapper.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# url_mapper.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +An URL mapper that represents authorized paths. +""" +from routes.mapper import Mapper + +from leap.soledad.common import SHARED_DB_NAME +from leap.soledad.common.l2db import DBNAME_CONSTRAINTS + + +class URLMapper(object): +    """ +    Maps the URLs users can access. 
+    """ + +    def __init__(self): +        self._map = Mapper(controller_scan=None) +        self._connect_urls() +        self._map.create_regs() + +    def match(self, path, method): +        environ = {'PATH_INFO': path, 'REQUEST_METHOD': method} +        return self._map.match(environ=environ) + +    def _connect(self, pattern, http_methods): +        self._map.connect( +            None, pattern, http_methods=http_methods, +            conditions=dict(method=http_methods), +            requirements={'dbname': DBNAME_CONSTRAINTS}) + +    def _connect_urls(self): +        """ +        Register the authorization info in the mapper using C{SHARED_DB_NAME} +        as the user's database name. + +        This method sets up the following authorization rules: + +            URL path                        | Authorized actions +            ---------------------------------------------------- +            /                               | GET +            /robots.txt                     | GET +            /shared-db                      | GET +            /shared-db/doc/{any_id}         | GET, PUT, DELETE +            /user-{uuid}/sync-from/{source} | GET, PUT, POST +            /blobs/{uuid}/{blob_id}         | GET, PUT, POST +            /blobs/{uuid}                   | GET +        """ +        # auth info for global resource +        self._connect('/', ['GET']) +        # robots +        self._connect('/robots.txt', ['GET']) +        # auth info for shared-db database resource +        self._connect('/%s' % SHARED_DB_NAME, ['GET']) +        # auth info for shared-db doc resource +        self._connect('/%s/doc/{id:.*}' % SHARED_DB_NAME, +                      ['GET', 'PUT', 'DELETE']) +        # auth info for user-db sync resource +        self._connect('/user-{uuid}/sync-from/{source_replica_uid}', +                      ['GET', 'PUT', 'POST']) +        # auth info for blobs resource +        self._connect('/blobs/{uuid}/{blob_id}', ['GET', 'PUT']) +        self._connect('/blobs/{uuid}', ['GET'])  | 
