diff options
Diffstat (limited to 'src')
162 files changed, 30117 insertions, 1283 deletions
diff --git a/src/eip_client.egg-info/PKG-INFO b/src/eip_client.egg-info/PKG-INFO deleted file mode 100644 index e4bc754e..00000000 --- a/src/eip_client.egg-info/PKG-INFO +++ /dev/null @@ -1,11 +0,0 @@ -Metadata-Version: 1.0 -Name: eip-client -Version: 0.1dev -Summary: the internet encryption toolkit -Home-page: http://leap.se -Author: leap project -Author-email: info@leap.se -License: GPL -Description: UNKNOWN -Keywords: leap,client,qt,encryption -Platform: all diff --git a/src/eip_client.egg-info/SOURCES.txt b/src/eip_client.egg-info/SOURCES.txt deleted file mode 100644 index 05688ff1..00000000 --- a/src/eip_client.egg-info/SOURCES.txt +++ /dev/null @@ -1,28 +0,0 @@ -MANIFEST.in -README.txt -setup.cfg -setup.py -docs/LICENSE.txt -docs/leap.1 -setup/linux/polkit/net.openvpn.gui.leap.policy -setup/scripts/leap -src/eip_client.egg-info/PKG-INFO -src/eip_client.egg-info/SOURCES.txt -src/eip_client.egg-info/dependency_links.txt -src/eip_client.egg-info/entry_points.txt -src/eip_client.egg-info/not-zip-safe -src/eip_client.egg-info/top_level.txt -src/leap/__init__.py -src/leap/app.py -src/leap/baseapp/__init__.py -src/leap/baseapp/config.py -src/leap/baseapp/mainwindow.py -src/leap/eip/__init__.py -src/leap/eip/conductor.py -src/leap/eip/vpnmanager.py -src/leap/eip/vpnwatcher.py -src/leap/gui/__init__.py -src/leap/gui/mainwindow_rc.py -src/leap/utils/__init__.py -src/leap/utils/coroutines.py -src/leap/utils/leap_argparse.py
\ No newline at end of file diff --git a/src/eip_client.egg-info/dependency_links.txt b/src/eip_client.egg-info/dependency_links.txt deleted file mode 100644 index 8b137891..00000000 --- a/src/eip_client.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/eip_client.egg-info/entry_points.txt b/src/eip_client.egg-info/entry_points.txt deleted file mode 100644 index a184cd05..00000000 --- a/src/eip_client.egg-info/entry_points.txt +++ /dev/null @@ -1,3 +0,0 @@ - - # -*- Entry points: -*- -
\ No newline at end of file diff --git a/src/eip_client.egg-info/not-zip-safe b/src/eip_client.egg-info/not-zip-safe deleted file mode 100644 index 8b137891..00000000 --- a/src/eip_client.egg-info/not-zip-safe +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/eip_client.egg-info/top_level.txt b/src/eip_client.egg-info/top_level.txt deleted file mode 100644 index 2905ed7d..00000000 --- a/src/eip_client.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -leap diff --git a/src/leap/__init__.py b/src/leap/__init__.py index e69de29b..5e003931 100644 --- a/src/leap/__init__.py +++ b/src/leap/__init__.py @@ -0,0 +1,35 @@ +""" +LEAP Encryption Access Project +website: U{https://leap.se/} +""" + +from leap import eip +from leap import baseapp +from leap import util + +__all__ = [eip, baseapp, util] + +__version__ = "unknown" +try: + from ._version import get_versions + __version__ = get_versions()['version'] + del get_versions +except ImportError: + #running on a tree that has not run + #the setup.py setver + pass + +__appname__ = "unknown" +try: + from leap._appname import __appname__ +except ImportError: + #running on a tree that has not run + #the setup.py setver + pass + +__full_version__ = __appname__ + '/' + str(__version__) + +try: + from leap._branding import BRANDING as __branding +except ImportError: + __branding = {} diff --git a/src/leap/_version.py b/src/leap/_version.py new file mode 100644 index 00000000..c33430ea --- /dev/null +++ b/src/leap/_version.py @@ -0,0 +1,197 @@ + +IN_LONG_VERSION_PY = True +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (build by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. Generated by +# versioneer-0.7+ (https://github.com/warner/python-versioneer) + +# these strings will be replaced by git during git-archive +git_refnames = "$Format:%d$" +git_full = "$Format:%H$" + + +import subprocess +import sys + +def run_command(args, cwd=None, verbose=False): + try: + # remember shell=False, so use git.cmd on windows, not just git + p = subprocess.Popen(args, stdout=subprocess.PIPE, cwd=cwd) + except EnvironmentError: + e = sys.exc_info()[1] + if verbose: + print("unable to run %s" % args[0]) + print(e) + return None + stdout = p.communicate()[0].strip() + if sys.version >= '3': + stdout = stdout.decode() + if p.returncode != 0: + if verbose: + print("unable to run %s (error)" % args[0]) + return None + return stdout + + +import sys +import re +import os.path + +def get_expanded_variables(versionfile_source): + # the code embedded in _version.py can just fetch the value of these + # variables. When used from setup.py, we don't want to import + # _version.py, so we do it with a regexp instead. This function is not + # used from _version.py. + variables = {} + try: + for line in open(versionfile_source,"r").readlines(): + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + variables["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + variables["full"] = mo.group(1) + except EnvironmentError: + pass + return variables + +def versions_from_expanded_variables(variables, tag_prefix, verbose=False): + refnames = variables["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("variables are unexpanded, not using") + return {} # unexpanded, so not in an unpacked git-archive tarball + refs = set([r.strip() for r in refnames.strip("()").split(",")]) + for ref in list(refs): + if not re.search(r'\d', ref): + if verbose: + print("discarding '%s', no digits" % ref) + refs.discard(ref) + # Assume all version tags have a digit. git's %d expansion + # behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us + # distinguish between branches and tags. By ignoring refnames + # without digits, we filter out many common branch names like + # "release" and "stabilization", as well as "HEAD" and "master". + if verbose: + print("remaining refs: %s" % ",".join(sorted(refs))) + for ref in sorted(refs): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + if verbose: + print("picking %s" % r) + return { "version": r, + "full": variables["full"].strip() } + # no suitable tags, so we use the full revision id + if verbose: + print("no suitable tags, using full revision id") + return { "version": variables["full"].strip(), + "full": variables["full"].strip() } + +def versions_from_vcs(tag_prefix, versionfile_source, verbose=False): + # this runs 'git' from the root of the source tree. That either means + # someone ran a setup.py command (and this code is in versioneer.py, so + # IN_LONG_VERSION_PY=False, thus the containing directory is the root of + # the source tree), or someone ran a project-specific entry point (and + # this code is in _version.py, so IN_LONG_VERSION_PY=True, thus the + # containing directory is somewhere deeper in the source tree). This only + # gets called if the git-archive 'subst' variables were *not* expanded, + # and _version.py hasn't already been rewritten with a short version + # string, meaning we're inside a checked out source tree. + + try: + here = os.path.abspath(__file__) + except NameError: + # some py2exe/bbfreeze/non-CPython implementations don't do __file__ + return {} # not always correct + + # versionfile_source is the relative path from the top of the source tree + # (where the .git directory might live) to this file. Invert this to find + # the root from __file__. + root = here + if IN_LONG_VERSION_PY: + for i in range(len(versionfile_source.split("/"))): + root = os.path.dirname(root) + else: + root = os.path.dirname(here) + if not os.path.exists(os.path.join(root, ".git")): + if verbose: + print("no .git in %s" % root) + return {} + + GIT = "git" + if sys.platform == "win32": + GIT = "git.cmd" + stdout = run_command([GIT, "describe", "--tags", "--dirty", "--always"], + cwd=root) + if stdout is None: + return {} + if not stdout.startswith(tag_prefix): + if verbose: + print("tag '%s' doesn't start with prefix '%s'" % (stdout, tag_prefix)) + return {} + tag = stdout[len(tag_prefix):] + stdout = run_command([GIT, "rev-parse", "HEAD"], cwd=root) + if stdout is None: + return {} + full = stdout.strip() + if tag.endswith("-dirty"): + full += "-dirty" + return {"version": tag, "full": full} + + +def versions_from_parentdir(parentdir_prefix, versionfile_source, verbose=False): + if IN_LONG_VERSION_PY: + # We're running from _version.py. If it's from a source tree + # (execute-in-place), we can work upwards to find the root of the + # tree, and then check the parent directory for a version string. If + # it's in an installed application, there's no hope. + try: + here = os.path.abspath(__file__) + except NameError: + # py2exe/bbfreeze/non-CPython don't have __file__ + return {} # without __file__, we have no hope + # versionfile_source is the relative path from the top of the source + # tree to _version.py. Invert this to find the root from __file__. + root = here + for i in range(len(versionfile_source.split("/"))): + root = os.path.dirname(root) + else: + # we're running from versioneer.py, which means we're running from + # the setup.py in a source tree. sys.argv[0] is setup.py in the root. + here = os.path.abspath(sys.argv[0]) + root = os.path.dirname(here) + + # Source tarballs conventionally unpack into a directory that includes + # both the project name and a version string. + dirname = os.path.basename(root) + if not dirname.startswith(parentdir_prefix): + if verbose: + print("guessing rootdir is '%s', but '%s' doesn't start with prefix '%s'" % + (root, dirname, parentdir_prefix)) + return None + return {"version": dirname[len(parentdir_prefix):], "full": ""} + +tag_prefix = "" +parentdir_prefix = "leap_client-" +versionfile_source = "src/leap/_version.py" + +def get_versions(default={"version": "unknown", "full": ""}, verbose=False): + variables = { "refnames": git_refnames, "full": git_full } + ver = versions_from_expanded_variables(variables, tag_prefix, verbose) + if not ver: + ver = versions_from_vcs(tag_prefix, versionfile_source, verbose) + if not ver: + ver = versions_from_parentdir(parentdir_prefix, versionfile_source, + verbose) + if not ver: + ver = default + return ver + diff --git a/src/leap/app.py b/src/leap/app.py index 0a61fd4f..d594c7cd 100644 --- a/src/leap/app.py +++ b/src/leap/app.py @@ -1,12 +1,24 @@ +# vim: tabstop=8 expandtab shiftwidth=4 softtabstop=4 +from functools import partial import logging +import signal + # This is only needed for Python v2 but is harmless for Python v3. import sip sip.setapi('QVariant', 2) +sip.setapi('QString', 2) from PyQt4.QtGui import (QApplication, QSystemTrayIcon, QMessageBox) +from PyQt4.QtCore import QTimer +from leap import __version__ as VERSION from leap.baseapp.mainwindow import LeapWindow -logger = logging.getLogger(name=__name__) + +def sigint_handler(*args, **kwargs): + logger = kwargs.get('logger', None) + logger.debug('SIGINT catched. shutting down...') + mainwindow = args[0] + mainwindow.shutdownSignal.emit() def main(): @@ -15,26 +27,77 @@ def main(): long live to the (hidden) leap window! """ import sys - from leap.utils import leap_argparse + from leap.util import leap_argparse parser, opts = leap_argparse.init_leapc_args() debug = getattr(opts, 'debug', False) - #XXX get debug level and set logger accordingly + # XXX get severity from command line args if debug: - logger.debug('args: ', opts) + level = logging.DEBUG + else: + level = logging.WARNING + + logger = logging.getLogger(name='leap') + logger.setLevel(level) + console = logging.StreamHandler() + console.setLevel(level) + formatter = logging.Formatter( + '%(asctime)s ' + '- %(name)s - %(levelname)s - %(message)s') + console.setFormatter(formatter) + logger.addHandler(console) + #logger.debug(opts) + logger.info('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~') + logger.info('LEAP client version %s', VERSION) + logger.info('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~') + logfile = getattr(opts, 'log_file', False) + if logfile: + logger.debug('setting logfile to %s ', logfile) + fileh = logging.FileHandler(logfile) + fileh.setLevel(logging.DEBUG) + fileh.setFormatter(formatter) + logger.addHandler(fileh) + + logger.info('Starting app') app = QApplication(sys.argv) + # needed for initializing qsettings + # it will write .config/leap/leap.conf + # top level app settings + # in a platform independent way + app.setOrganizationName("leap") + app.setApplicationName("leap") + app.setOrganizationDomain("leap.se") + if not QSystemTrayIcon.isSystemTrayAvailable(): QMessageBox.critical(None, "Systray", - "I couldn't detect any \ -system tray on this system.") + "I couldn't detect" + "any system tray on this system.") sys.exit(1) if not debug: QApplication.setQuitOnLastWindowClosed(False) window = LeapWindow(opts) - window.show() + + # this dummy timer ensures that + # control is given to the outside loop, so we + # can hook our sigint handler. + timer = QTimer() + timer.start(500) + timer.timeout.connect(lambda: None) + + sigint_window = partial(sigint_handler, window, logger=logger) + signal.signal(signal.SIGINT, sigint_window) + + if debug: + # we only show the main window + # if debug mode active. + # if not, it will be set visible + # from the systray menu. + window.show() + + # run main loop sys.exit(app.exec_()) if __name__ == "__main__": diff --git a/src/leap/utils/__init__.py b/src/leap/base/__init__.py index e69de29b..e69de29b 100644 --- a/src/leap/utils/__init__.py +++ b/src/leap/base/__init__.py diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py new file mode 100644 index 00000000..50533278 --- /dev/null +++ b/src/leap/base/auth.py @@ -0,0 +1,376 @@ +import binascii +import json +import logging +#import urlparse + +import requests +import srp + +from PyQt4 import QtCore + +from leap.base import constants as baseconstants +from leap.crypto import leapkeyring +from leap.util.web import get_https_domain_and_port + +logger = logging.getLogger(__name__) + +SIGNUP_TIMEOUT = getattr(baseconstants, 'SIGNUP_TIMEOUT', 5) + +""" +Registration and authentication classes for the +SRP auth mechanism used in the leap platform. + +We're using the srp library which uses a c-based implementation +of the protocol if the c extension is available, and a python-based +one if not. +""" + + +class ImproperlyConfigured(Exception): + """ + """ + + +class SRPAuthenticationError(Exception): + """ + exception raised + for authentication errors + """ + + +def null_check(value, value_name): + try: + assert value is not None + except AssertionError: + raise ImproperlyConfigured( + "%s parameter cannot be None" % value_name) + + +safe_unhexlify = lambda x: binascii.unhexlify(x) \ + if (len(x) % 2 == 0) else binascii.unhexlify('0' + x) + + +class LeapSRPRegister(object): + + def __init__(self, + schema="https", + provider=None, + port=None, + verify=True, + register_path="1/users.json", + method="POST", + fetcher=requests, + srp=srp, + hashfun=srp.SHA256, + ng_constant=srp.NG_1024): + + null_check(provider, provider) + + self.schema = schema + + # XXX FIXME + self.provider = provider + self.port = port + # XXX splitting server,port + # deprecate port call. + domain, port = get_https_domain_and_port(provider) + self.provider = domain + self.port = port + + self.verify = verify + self.register_path = register_path + self.method = method + self.fetcher = fetcher + self.srp = srp + self.HASHFUN = hashfun + self.NG = ng_constant + + self.init_session() + + def init_session(self): + self.session = self.fetcher.session() + + def get_registration_uri(self): + # XXX assert is https! + # use urlparse + if self.port: + uri = "%s://%s:%s/%s" % ( + self.schema, + self.provider, + self.port, + self.register_path) + else: + uri = "%s://%s/%s" % ( + self.schema, + self.provider, + self.register_path) + + return uri + + def register_user(self, username, password, keep=False): + """ + @rtype: tuple + @rparam: (ok, request) + """ + salt, vkey = self.srp.create_salted_verification_key( + username, + password, + self.HASHFUN, + self.NG) + + user_data = { + 'user[login]': username, + 'user[password_verifier]': binascii.hexlify(vkey), + 'user[password_salt]': binascii.hexlify(salt)} + + uri = self.get_registration_uri() + logger.debug('post to uri: %s' % uri) + + # XXX get self.method + req = self.session.post( + uri, data=user_data, + timeout=SIGNUP_TIMEOUT, + verify=self.verify) + logger.debug(req) + logger.debug('user_data: %s', user_data) + #logger.debug('response: %s', req.text) + # we catch it in the form + #req.raise_for_status() + return (req.ok, req) + + +class SRPAuth(requests.auth.AuthBase): + + def __init__(self, username, password, server=None, verify=None): + # sanity check + null_check(server, 'server') + self.username = username + self.password = password + self.server = server + self.verify = verify + + self.init_data = None + self.session = requests.session() + + self.init_srp() + + def get_json_data(self, response): + return json.loads(response.content) + + def init_srp(self): + usr = srp.User( + self.username, + self.password, + srp.SHA256, + srp.NG_1024) + uname, A = usr.start_authentication() + + self.srp_usr = usr + self.A = A + + def get_auth_data(self): + return { + 'login': self.username, + 'A': binascii.hexlify(self.A) + } + + def get_init_data(self): + try: + init_session = self.session.post( + self.server + '/1/sessions.json/', + data=self.get_auth_data(), + verify=self.verify) + except requests.exceptions.ConnectionError: + raise SRPAuthenticationError( + "No connection made (salt).") + if init_session.status_code not in (200, ): + raise SRPAuthenticationError( + "No valid response (salt).") + + # XXX should get auth_result.json instead + self.init_data = self.get_json_data(init_session) + return self.init_data + + def get_server_proof_data(self): + try: + auth_result = self.session.put( + #self.server + '/1/sessions.json/' + self.username, + self.server + '/1/sessions/' + self.username, + data={'client_auth': binascii.hexlify(self.M)}, + verify=self.verify) + except requests.exceptions.ConnectionError: + raise SRPAuthenticationError( + "No connection made (HAMK).") + + if auth_result.status_code not in (200, ): + raise SRPAuthenticationError( + "No valid response (HAMK).") + + # XXX should get auth_result.json instead + try: + self.auth_data = self.get_json_data(auth_result) + except ValueError: + raise SRPAuthenticationError( + "No valid data sent (HAMK)") + + return self.auth_data + + def authenticate(self): + logger.debug('start authentication...') + + init_data = self.get_init_data() + salt = init_data.get('salt', None) + B = init_data.get('B', None) + + # XXX refactor this function + # move checks and un-hex + # to routines + + if not salt or not B: + raise SRPAuthenticationError( + "Server did not send initial data.") + + try: + unhex_salt = safe_unhexlify(salt) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (salt)") + try: + unhex_B = safe_unhexlify(B) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (B)") + + self.M = self.srp_usr.process_challenge( + unhex_salt, + unhex_B + ) + + proof_data = self.get_server_proof_data() + + HAMK = proof_data.get("M2", None) + if not HAMK: + errors = proof_data.get('errors', None) + if errors: + logger.error(errors) + raise SRPAuthenticationError("Server did not send HAMK.") + + try: + unhex_HAMK = safe_unhexlify(HAMK) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (HAMK)") + + self.srp_usr.verify_session( + unhex_HAMK) + + try: + assert self.srp_usr.authenticated() + logger.debug('user is authenticated!') + except (AssertionError): + raise SRPAuthenticationError( + "Auth verification failed.") + + def __call__(self, req): + self.authenticate() + req.session = self.session + return req + + +def srpauth_protected(user=None, passwd=None, server=None, verify=True): + """ + decorator factory that accepts + user and password keyword arguments + and add those to the decorated request + """ + def srpauth(fn): + def wrapper(*args, **kwargs): + if user and passwd: + auth = SRPAuth(user, passwd, server, verify) + kwargs['auth'] = auth + kwargs['verify'] = verify + return fn(*args, **kwargs) + return wrapper + return srpauth + + +def get_leap_credentials(): + settings = QtCore.QSettings() + full_username = settings.value('eip_username') + username, domain = full_username.split('@') + seed = settings.value('%s_seed' % domain, None) + password = leapkeyring.leap_get_password(full_username, seed=seed) + return (username, password) + + +# XXX TODO +# Pass verify as single argument, +# in srpauth_protected style + +def magick_srpauth(fn): + """ + decorator that gets user and password + from the config file and adds those to + the decorated request + """ + logger.debug('magick srp auth decorator called') + + def wrapper(*args, **kwargs): + #uri = args[0] + # XXX Ugh! + # Problem with this approach. + # This won't work when we're using + # api.foo.bar + # Unless we keep a table with the + # equivalencies... + user, passwd = get_leap_credentials() + + # XXX pass verify and server too + # (pop) + auth = SRPAuth(user, passwd) + kwargs['auth'] = auth + return fn(*args, **kwargs) + return wrapper + + +if __name__ == "__main__": + """ + To test against test_provider (twisted version) + Register an user: (will be valid during the session) + >>> python auth.py add test password + + Test login with that user: + >>> python auth.py login test password + """ + + import sys + + if len(sys.argv) not in (4, 5): + print 'Usage: auth <add|login> <user> <pass> [server]' + sys.exit(0) + + action = sys.argv[1] + user = sys.argv[2] + passwd = sys.argv[3] + + if len(sys.argv) == 5: + SERVER = sys.argv[4] + else: + SERVER = "https://localhost:8443" + + if action == "login": + + @srpauth_protected( + user=user, passwd=passwd, server=SERVER, verify=False) + def test_srp_protected_get(*args, **kwargs): + req = requests.get(*args, **kwargs) + req.raise_for_status + return req + + req = test_srp_protected_get('https://localhost:8443/1/cert') + print 'cert :', req.content[:200] + "..." + sys.exit(0) + + if action == "add": + auth = LeapSRPRegister(provider=SERVER, verify=False) + auth.register_user(user, passwd) diff --git a/src/leap/base/authentication.py b/src/leap/base/authentication.py new file mode 100644 index 00000000..09ff1d07 --- /dev/null +++ b/src/leap/base/authentication.py @@ -0,0 +1,11 @@ +""" +Authentication Base Class +""" + + +class Authentication(object): + """ + I have no idea how Authentication (certs,?) + will be done, but stub it here. + """ + pass diff --git a/src/leap/base/checks.py b/src/leap/base/checks.py new file mode 100644 index 00000000..23446f4a --- /dev/null +++ b/src/leap/base/checks.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +import logging +import platform +import socket + +import netifaces +import ping +import requests + +from leap.base import constants +from leap.base import exceptions + +logger = logging.getLogger(name=__name__) + + +class LeapNetworkChecker(object): + """ + all network related checks + """ + def __init__(self, *args, **kwargs): + provider_gw = kwargs.pop('provider_gw', None) + self.provider_gateway = provider_gw + + def run_all(self, checker=None): + if not checker: + checker = self + #self.error = None # ? + + # for MVS + checker.check_tunnel_default_interface() + checker.check_internet_connection() + checker.is_internet_up() + + if self.provider_gateway: + checker.ping_gateway(self.provider_gateway) + + def check_internet_connection(self): + try: + # XXX remove this hardcoded random ip + # ping leap.se or eip provider instead...? + requests.get('http://216.172.161.165') + + except (requests.HTTPError, requests.RequestException) as e: + raise exceptions.NoInternetConnection(e.message) + except requests.ConnectionError as e: + error = "Unidentified Connection Error" + if e.message == "[Errno 113] No route to host": + if not self.is_internet_up(): + error = "No valid internet connection found." + else: + error = "Provider server appears to be down." + logger.error(error) + raise exceptions.NoInternetConnection(error) + logger.debug('Network appears to be up.') + + def is_internet_up(self): + iface, gateway = self.get_default_interface_gateway() + self.ping_gateway(self.provider_gateway) + + def check_tunnel_default_interface(self): + """ + Raises an TunnelNotDefaultRouteError + (including when no routes are present) + """ + if not platform.system() == "Linux": + raise NotImplementedError + + f = open("/proc/net/route") + route_table = f.readlines() + f.close() + #toss out header + route_table.pop(0) + + if not route_table: + raise exceptions.TunnelNotDefaultRouteError() + + line = route_table.pop(0) + iface, destination = line.split('\t')[0:2] + if not destination == '00000000' or not iface == 'tun0': + raise exceptions.TunnelNotDefaultRouteError() + + def get_default_interface_gateway(self): + """only impletemented for linux so far.""" + if not platform.system() == "Linux": + raise NotImplementedError + + # XXX use psutil + f = open("/proc/net/route") + route_table = f.readlines() + f.close() + #toss out header + route_table.pop(0) + + default_iface = None + gateway = None + while route_table: + line = route_table.pop(0) + iface, destination, gateway = line.split('\t')[0:3] + if destination == '00000000': + default_iface = iface + break + + if not default_iface: + raise exceptions.NoDefaultInterfaceFoundError + + if default_iface not in netifaces.interfaces(): + raise exceptions.InterfaceNotFoundError + + return default_iface, gateway + + def ping_gateway(self, gateway): + # TODO: Discuss how much packet loss (%) is acceptable. + + # XXX -- validate gateway + # -- is it a valid ip? (there's something in util) + # -- is it a domain? + # -- can we resolve? -- raise NoDNSError if not. + packet_loss = ping.quiet_ping(gateway)[0] + if packet_loss > constants.MAX_ICMP_PACKET_LOSS: + raise exceptions.NoConnectionToGateway + + def check_name_resolution(self, domain_name): + try: + socket.gethostbyname(domain_name) + return True + except socket.gaierror: + raise exceptions.CannotResolveDomainError diff --git a/src/leap/base/config.py b/src/leap/base/config.py new file mode 100644 index 00000000..0255fbab --- /dev/null +++ b/src/leap/base/config.py @@ -0,0 +1,279 @@ +""" +Configuration Base Class +""" +import grp +import json +import logging +import socket +import tempfile +import os + +logger = logging.getLogger(name=__name__) + +import requests + +from leap.base import exceptions +from leap.base import constants +from leap.base.pluggableconfig import PluggableConfig +from leap.util.fileutil import (mkdir_p) + +# move to base! +from leap.eip import exceptions as eipexceptions + + +class BaseLeapConfig(object): + slug = None + + # XXX we have to enforce that every derived class + # has a slug (via interface) + # get property getter that raises NI.. + + def save(self): + raise NotImplementedError("abstract base class") + + def load(self): + raise NotImplementedError("abstract base class") + + def get_config(self, *kwargs): + raise NotImplementedError("abstract base class") + + @property + def config(self): + return self.get_config() + + def get_value(self, *kwargs): + raise NotImplementedError("abstract base class") + + +class MetaConfigWithSpec(type): + """ + metaclass for JSONLeapConfig classes. + It creates a configuration spec out of + the `spec` dictionary. The `properties` attribute + of the spec dict is turn into the `schema` attribute + of the new class (which will be used to validate against). + """ + # XXX in the near future, this is the + # place where we want to enforce + # singletons, read-only and similar stuff. + + def __new__(meta, classname, bases, classDict): + schema_obj = classDict.get('spec', None) + + # not quite happy with this workaround. + # I want to raise if missing spec dict, but only + # for grand-children of this metaclass. + # maybe should use abc module for this. + abcderived = ("JSONLeapConfig",) + if schema_obj is None and classname not in abcderived: + raise exceptions.ImproperlyConfigured( + "missing spec dict on your derived class (%s)" % classname) + + # we create a configuration spec attribute + # from the spec dict + config_class = type( + classname + "Spec", + (PluggableConfig, object), + {'options': schema_obj}) + classDict['spec'] = config_class + + return type.__new__(meta, classname, bases, classDict) + +########################################################## +# some hacking still in progress: + +# Configs have: + +# - a slug (from where a filename/folder is derived) +# - a spec (for validation and defaults). +# this spec is conformant to the json-schema. +# basically a dict that will be used +# for type casting and validation, and defaults settings. + +# all config objects, since they are derived from BaseConfig, implement basic +# useful methods: +# - save +# - load + +########################################################## + + +class JSONLeapConfig(BaseLeapConfig): + + __metaclass__ = MetaConfigWithSpec + + def __init__(self, *args, **kwargs): + # sanity check + try: + assert self.slug is not None + except AssertionError: + raise exceptions.ImproperlyConfigured( + "missing slug on JSONLeapConfig" + " derived class") + try: + assert self.spec is not None + except AssertionError: + raise exceptions.ImproperlyConfigured( + "missing spec on JSONLeapConfig" + " derived class") + assert issubclass(self.spec, PluggableConfig) + + self.domain = kwargs.pop('domain', None) + self._config = self.spec(format="json") + self._config.load() + self.fetcher = kwargs.pop('fetcher', requests) + + # mandatory baseconfig interface + + def save(self, to=None): + if to is None: + to = self.filename + folder, filename = os.path.split(to) + if folder and not os.path.isdir(folder): + mkdir_p(folder) + self._config.serialize(to) + + def load(self, fromfile=None, from_uri=None, fetcher=None, verify=False): + if from_uri is not None: + fetched = self.fetch(from_uri, fetcher=fetcher, verify=verify) + if fetched: + return + if fromfile is None: + fromfile = self.filename + if os.path.isfile(fromfile): + self._config.load(fromfile=fromfile) + else: + logger.error('tried to load config from non-existent path') + logger.error('Not Found: %s', fromfile) + + def fetch(self, uri, fetcher=None, verify=True): + if not fetcher: + fetcher = self.fetcher + logger.debug('verify: %s', verify) + logger.debug('uri: %s', uri) + request = fetcher.get(uri, verify=verify) + # XXX should send a if-modified-since header + + # XXX get 404, ... + # and raise a UnableToFetch... + request.raise_for_status() + fd, fname = tempfile.mkstemp(suffix=".json") + + if request.json: + self._config.load(json.dumps(request.json)) + + else: + # not request.json + # might be server did not announce content properly, + # let's try deserializing all the same. + try: + self._config.load(request.content) + except ValueError: + raise eipexceptions.LeapBadConfigFetchedError + + return True + + def get_config(self): + return self._config.config + + # public methods + + def get_filename(self): + return self._slug_to_filename() + + @property + def filename(self): + return self.get_filename() + + def validate(self, data): + logger.debug('validating schema') + self._config.validate(data) + return True + + # private + + def _slug_to_filename(self): + # is this going to work in winland if slug is "foo/bar" ? + folder, filename = os.path.split(self.slug) + config_file = get_config_file(filename, folder) + return config_file + + def exists(self): + return os.path.isfile(self.filename) + + +# +# utility functions +# +# (might be moved to some class as we see fit, but +# let's remain functional for a while) +# maybe base.config.util ?? +# + + +def get_config_dir(): + """ + get the base dir for all leap config + @rparam: config path + @rtype: string + """ + # TODO + # check for $XDG_CONFIG_HOME var? + # get a more sensible path for win/mac + # kclair: opinion? ^^ + + return os.path.expanduser( + os.path.join('~', + '.config', + 'leap')) + + +def get_config_file(filename, folder=None): + """ + concatenates the given filename + with leap config dir. + @param filename: name of the file + @type filename: string + @rparam: full path to config file + """ + path = [] + path.append(get_config_dir()) + if folder is not None: + path.append(folder) + path.append(filename) + return os.path.join(*path) + + +def get_default_provider_path(): + default_subpath = os.path.join("providers", + constants.DEFAULT_PROVIDER) + default_provider_path = get_config_file( + '', + folder=default_subpath) + return default_provider_path + + +def get_provider_path(domain): + # XXX if not domain, return get_default_provider_path + default_subpath = os.path.join("providers", domain) + provider_path = get_config_file( + '', + folder=default_subpath) + return provider_path + + +def validate_ip(ip_str): + """ + raises exception if the ip_str is + not a valid representation of an ip + """ + socket.inet_aton(ip_str) + + +def get_username(): + return os.getlogin() + + +def get_groupname(): + gid = os.getgroups()[-1] + return grp.getgrgid(gid).gr_name diff --git a/src/leap/base/connection.py b/src/leap/base/connection.py new file mode 100644 index 00000000..41d13935 --- /dev/null +++ b/src/leap/base/connection.py @@ -0,0 +1,115 @@ +""" +Base Connection Classs +""" +from __future__ import (division, unicode_literals, print_function) + +import logging + +from leap.base.authentication import Authentication + +logger = logging.getLogger(name=__name__) + + +class Connection(Authentication): + # JSONLeapConfig + #spec = {} + + def __init__(self, *args, **kwargs): + self.connection_state = None + self.desired_connection_state = None + #XXX FIXME diamond inheritance gotcha.. + #If you inherit from >1 class, + #super is only initializing one + #of the bases..!! + # I think we better pass config as a constructor + # parameter -- kali 2012-08-30 04:33 + super(Connection, self).__init__(*args, **kwargs) + + def connect(self): + """ + entry point for connection process + """ + pass + + def disconnect(self): + """ + disconnects client + """ + pass + + #def shutdown(self): + #""" + #shutdown and quit + #""" + #self.desired_con_state = self.status.DISCONNECTED + + def connection_state(self): + """ + returns the current connection state + """ + return self.status.current + + def desired_connection_state(self): + """ + returns the desired_connection state + """ + return self.desired_connection_state + + def get_icon_name(self): + """ + get icon name from status object + """ + return self.status.get_state_icon() + + # + # private methods + # + + def _disconnect(self): + """ + private method for disconnecting + """ + if self.subp is not None: + self.subp.terminate() + self.subp = None + # XXX signal state changes! :) + + def _is_alive(self): + """ + don't know yet + """ + pass + + def _connect(self): + """ + entry point for connection cascade methods. + """ + #conn_result = ConState.DISCONNECTED + try: + conn_result = self._try_connection() + except UnrecoverableError as except_msg: + logger.error("FATAL: %s" % unicode(except_msg)) + conn_result = self.status.UNRECOVERABLE + except Exception as except_msg: + self.error_queue.append(except_msg) + logger.error("Failed Connection: %s" % + unicode(except_msg)) + return conn_result + + +class ConnectionError(Exception): + """ + generic connection error + """ + def __str__(self): + if len(self.args) >= 1: + return repr(self.args[0]) + else: + raise self() + + +class UnrecoverableError(ConnectionError): + """ + we cannot do anything about it, sorry + """ + pass diff --git a/src/leap/base/constants.py b/src/leap/base/constants.py new file mode 100644 index 00000000..f7be8d98 --- /dev/null +++ b/src/leap/base/constants.py @@ -0,0 +1,32 @@ +"""constants to be used in base module""" +from leap import __branding +APP_NAME = __branding.get("short_name", "leap") + +# default provider placeholder +# using `example.org` we make sure that this +# is not going to be resolved during the tests phases +# (we expect testers to add it to their /etc/hosts + +DEFAULT_PROVIDER = __branding.get( + "provider_domain", + "testprovider.example.org") + +DEFINITION_EXPECTED_PATH = "provider.json" + +DEFAULT_PROVIDER_DEFINITION = { + u'api_uri': u'https://api.%s/' % DEFAULT_PROVIDER, + u'api_version': u'0.1.0', + u'ca_cert_fingerprint': u'8aab80ae4326fd30721689db813733783fe0bd7e', + u'ca_cert_uri': u'https://%s/cacert.pem' % DEFAULT_PROVIDER, + u'description': {u'en': u'This is a test provider'}, + u'display_name': {u'en': u'Test Provider'}, + u'domain': u'%s' % DEFAULT_PROVIDER, + u'enrollment_policy': u'open', + u'public_key': u'cb7dbd679f911e85bc2e51bd44afd7308ee19c21', + u'serial': 1, + u'services': [u'eip'], + u'version': u'0.1.0'} + +MAX_ICMP_PACKET_LOSS = 10 + +ROUTE_CHECK_INTERVAL = 10 diff --git a/src/leap/base/exceptions.py b/src/leap/base/exceptions.py new file mode 100644 index 00000000..227da953 --- /dev/null +++ b/src/leap/base/exceptions.py @@ -0,0 +1,77 @@ +""" +Exception attributes and their meaning/uses +------------------------------------------- + +* critical: if True, will abort execution prematurely, + after attempting any cleaning + action. + +* failfirst: breaks any error_check loop that is examining + the error queue. + +* message: the message that will be used in the __repr__ of the exception. + +* usermessage: the message that will be passed to user in ErrorDialogs + in Qt-land. +""" + + +class LeapException(Exception): + """ + base LeapClient exception + sets some parameters that we will check + during error checking routines + """ + critical = False + failfirst = False + warning = False + + +class CriticalError(LeapException): + """ + we cannot do anything about it + """ + critical = True + failfirst = True + + +# In use ??? +# don't thing so. purge if not... + +class MissingConfigFileError(Exception): + pass + + +class ImproperlyConfigured(Exception): + pass + + +class NoDefaultInterfaceFoundError(LeapException): + message = "no default interface found" + usermessage = "Looks like your computer is not connected to the internet" + + +class InterfaceNotFoundError(LeapException): + # XXX should take iface arg on init maybe? + message = "interface not found" + + +class NoConnectionToGateway(CriticalError): + message = "no connection to gateway" + usermessage = "Looks like there are problems with your internet connection" + + +class NoInternetConnection(CriticalError): + message = "No Internet connection found" + usermessage = "It looks like there is no internet connection." + # and now we try to connect to our web to troubleshoot LOL :P + + +class CannotResolveDomainError(LeapException): + message = "Cannot resolve domain" + usermessage = "Domain cannot be found" + + +class TunnelNotDefaultRouteError(CriticalError): + message = "Tunnel connection dissapeared. VPN down?" + usermessage = "The Encrypted Connection was lost. Shutting down..." diff --git a/src/leap/base/network.py b/src/leap/base/network.py new file mode 100644 index 00000000..3aba3f61 --- /dev/null +++ b/src/leap/base/network.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +from __future__ import (print_function) +import logging +import threading + +from leap.eip.config import get_eip_gateway +from leap.base.checks import LeapNetworkChecker +from leap.base.constants import ROUTE_CHECK_INTERVAL +from leap.base.exceptions import TunnelNotDefaultRouteError +from leap.util.coroutines import (launch_thread, process_events) + +from time import sleep + +logger = logging.getLogger(name=__name__) + + +class NetworkCheckerThread(object): + """ + Manages network checking thread that makes sure we have a working network + connection. + """ + def __init__(self, *args, **kwargs): + self.status_signals = kwargs.pop('status_signals', None) + #self.watcher_cb = kwargs.pop('status_signals', None) + self.error_cb = kwargs.pop( + 'error_cb', + lambda exc: logger.error("%s", exc.message)) + self.shutdown = threading.Event() + + # XXX get provider_gateway and pass it to checker + # see in eip.config for function + # #718 + self.checker = LeapNetworkChecker( + provider_gw=get_eip_gateway()) + + def start(self): + self.process_handle = self._launch_recurrent_network_checks( + (self.error_cb,)) + + def stop(self): + self.shutdown.set() + logger.debug("network checked stopped.") + + def run_checks(self): + pass + + #private methods + + #here all the observers in fail_callbacks expect one positional argument, + #which is exception so we can try by passing a lambda with logger to + #check it works. + def _network_checks_thread(self, fail_callbacks): + #TODO: replace this with waiting for a signal from openvpn + while True: + try: + self.checker.check_tunnel_default_interface() + break + except TunnelNotDefaultRouteError: + # XXX ??? why do we sleep here??? + # aa: If the openvpn isn't up and running yet, + # let's give it a moment to breath. + sleep(1) + + fail_observer_dict = dict((( + observer, + process_events(observer)) for observer in fail_callbacks)) + while not self.shutdown.is_set(): + try: + self.checker.check_tunnel_default_interface() + self.checker.check_internet_connection() + sleep(ROUTE_CHECK_INTERVAL) + except Exception as exc: + for obs in fail_observer_dict: + fail_observer_dict[obs].send(exc) + sleep(ROUTE_CHECK_INTERVAL) + #reset event + self.shutdown.clear() + + def _launch_recurrent_network_checks(self, fail_callbacks): + #we need to wrap the fail callback in a tuple + watcher = launch_thread( + self._network_checks_thread, + (fail_callbacks,)) + return watcher diff --git a/src/leap/base/pluggableconfig.py b/src/leap/base/pluggableconfig.py new file mode 100644 index 00000000..b8615ad8 --- /dev/null +++ b/src/leap/base/pluggableconfig.py @@ -0,0 +1,421 @@ +""" +generic configuration handlers +""" +import copy +import json +import logging +import os +import time +import urlparse + +import jsonschema + +logger = logging.getLogger(__name__) + + +__all__ = ['PluggableConfig', + 'adaptors', + 'types', + 'UnknownOptionException', + 'MissingValueException', + 'ConfigurationProviderException', + 'TypeCastException'] + +# exceptions + + +class UnknownOptionException(Exception): + """exception raised when a non-configuration + value is present in the configuration""" + + +class MissingValueException(Exception): + """exception raised when a required value is missing""" + + +class ConfigurationProviderException(Exception): + """exception raised when a configuration provider is missing, etc""" + + +class TypeCastException(Exception): + """exception raised when a + configuration item cannot be coerced to a type""" + + +class ConfigAdaptor(object): + """ + abstract base class for config adaotors for + serialization/deserialization and custom validation + and type casting. + """ + def read(self, filename): + raise NotImplementedError("abstract base class") + + def write(self, config, filename): + with open(filename, 'w') as f: + self._write(f, config) + + def _write(self, fp, config): + raise NotImplementedError("abstract base class") + + def validate(self, config, schema): + raise NotImplementedError("abstract base class") + + +adaptors = {} + + +class JSONSchemaEncoder(json.JSONEncoder): + """ + custom default encoder that + casts python objects to json objects for + the schema validation + """ + def default(self, obj): + if obj is str: + return 'string' + if obj is unicode: + return 'string' + if obj is int: + return 'integer' + if obj is list: + return 'array' + if obj is dict: + return 'object' + if obj is bool: + return 'boolean' + + +class JSONAdaptor(ConfigAdaptor): + indent = 2 + extensions = ['json'] + + def read(self, _from): + if isinstance(_from, file): + _from_string = _from.read() + if isinstance(_from, str): + _from_string = _from + return json.loads(_from_string) + + def _write(self, fp, config): + fp.write(json.dumps(config, + indent=self.indent, + sort_keys=True)) + + def validate(self, config, schema_obj): + schema_json = JSONSchemaEncoder().encode(schema_obj) + schema = json.loads(schema_json) + jsonschema.validate(config, schema) + + +adaptors['json'] = JSONAdaptor() + +# +# Adaptors +# +# Allow to apply a predefined set of types to the +# specs, so it checks the validity of formats and cast it +# to proper python types. + +# TODO: +# - multilingual object. +# - HTTPS uri + + +class DateType(object): + fmt = '%Y-%m-%d' + + def to_python(self, data): + return time.strptime(data, self.fmt) + + def get_prep_value(self, data): + return time.strftime(self.fmt, data) + + +class URIType(object): + + def to_python(self, data): + parsed = urlparse.urlparse(data) + if not parsed.scheme: + raise TypeCastException("uri %s has no schema" % data) + return parsed + + def get_prep_value(self, data): + return data.geturl() + + +class HTTPSURIType(object): + + def to_python(self, data): + parsed = urlparse.urlparse(data) + if not parsed.scheme: + raise TypeCastException("uri %s has no schema" % data) + if parsed.scheme != "https": + raise TypeCastException( + "uri %s does not has " + "https schema" % data) + return parsed + + def get_prep_value(self, data): + return data.geturl() + + +types = { + 'date': DateType(), + 'uri': URIType(), + 'https-uri': HTTPSURIType(), +} + + +class PluggableConfig(object): + + options = {} + + def __init__(self, + adaptors=adaptors, + types=types, + format=None): + + self.config = {} + self.adaptors = adaptors + self.types = types + self._format = format + + @property + def option_dict(self): + if hasattr(self, 'options') and isinstance(self.options, dict): + return self.options.get('properties', None) + + def items(self): + """ + act like an iterator + """ + if isinstance(self.option_dict, dict): + return self.option_dict.items() + return self.options + + def validate(self, config, format=None): + """ + validate config + """ + schema = self.options + if format is None: + format = self._format + + if format: + adaptor = self.get_adaptor(self._format) + adaptor.validate(config, schema) + else: + # we really should make format mandatory... + logger.error('no format passed to validate') + + # first round of validation is ok. + # now we proceed to cast types if any specified. + self.to_python(config) + + def to_python(self, config): + """ + cast types following first type and then format indications. + """ + unseen_options = [i for i in config if i not in self.option_dict] + if unseen_options: + raise UnknownOptionException( + "Unknown options: %s" % ', '.join(unseen_options)) + + for key, value in config.items(): + _type = self.option_dict[key].get('type') + if _type is None and 'default' in self.option_dict[key]: + _type = type(self.option_dict[key]['default']) + if _type is not None: + tocast = True + if not callable(_type) and isinstance(value, _type): + tocast = False + if tocast: + try: + config[key] = _type(value) + except BaseException, e: + raise TypeCastException( + "Could not coerce %s, %s, " + "to type %s: %s" % (key, value, _type.__name__, e)) + _format = self.option_dict[key].get('format', None) + _ftype = self.types.get(_format, None) + if _ftype: + try: + config[key] = _ftype.to_python(value) + except BaseException, e: + raise TypeCastException( + "Could not coerce %s, %s, " + "to format %s: %s" % (key, value, + _ftype.__class__.__name__, + e)) + + return config + + def prep_value(self, config): + """ + the inverse of to_python method, + called just before serialization + """ + for key, value in config.items(): + _format = self.option_dict[key].get('format', None) + _ftype = self.types.get(_format, None) + if _ftype and hasattr(_ftype, 'get_prep_value'): + try: + config[key] = _ftype.get_prep_value(value) + except BaseException, e: + raise TypeCastException( + "Could not serialize %s, %s, " + "by format %s: %s" % (key, value, + _ftype.__class__.__name__, + e)) + else: + config[key] = value + return config + + # methods for adding configuration + + def get_default_values(self): + """ + return a config options from configuration defaults + """ + defaults = {} + for key, value in self.items(): + if 'default' in value: + defaults[key] = value['default'] + return copy.deepcopy(defaults) + + def get_adaptor(self, format): + """ + get specified format adaptor or + guess for a given filename + """ + adaptor = self.adaptors.get(format, None) + if adaptor: + return adaptor + + # not registered in adaptors dict, let's try all + for adaptor in self.adaptors.values(): + if format in adaptor.extensions: + return adaptor + + def filename2format(self, filename): + extension = os.path.splitext(filename)[-1] + return extension.lstrip('.') or None + + def serialize(self, filename, format=None, full=False): + if not format: + format = self._format + if not format: + format = self.filename2format(filename) + if not format: + raise Exception('Please specify a format') + # TODO: more specific exception type + + adaptor = self.get_adaptor(format) + if not adaptor: + raise Exception("Adaptor not found for format: %s" % format) + + config = copy.deepcopy(self.config) + serializable = self.prep_value(config) + adaptor.write(serializable, filename) + + def deserialize(self, string=None, fromfile=None, format=None): + """ + load configuration from a file or string + """ + + def _try_deserialize(): + if fromfile: + with open(fromfile, 'r') as f: + content = adaptor.read(f) + elif string: + content = adaptor.read(string) + return content + + # XXX cleanup this! + + if fromfile: + assert os.path.exists(fromfile) + if not format: + format = self.filename2format(fromfile) + + if not format: + format = self._format + if format: + adaptor = self.get_adaptor(format) + else: + adaptor = None + + if adaptor: + content = _try_deserialize() + return content + + # no adaptor, let's try rest of adaptors + + adaptors = self.adaptors[:] + + if format: + adaptors.sort( + key=lambda x: int( + format in x.extensions), + reverse=True) + + for adaptor in adaptors: + content = _try_deserialize() + return content + + def load(self, *args, **kwargs): + """ + load from string or file + if no string of fromfile option is given, + it will attempt to load from defaults + defined in the schema. + """ + string = args[0] if args else None + fromfile = kwargs.get("fromfile", None) + content = None + + # start with defaults, so we can + # have partial values applied. + content = self.get_default_values() + if string and isinstance(string, str): + content = self.deserialize(string) + + if not string and fromfile is not None: + #import ipdb;ipdb.set_trace() + content = self.deserialize(fromfile=fromfile) + + if not content: + logger.error('no content could be loaded') + # XXX raise! + return + + # lazy evaluation until first level of nesting + # to allow lambdas with context-dependant info + # like os.path.expanduser + for k, v in content.iteritems(): + if callable(v): + content[k] = v() + + self.validate(content) + self.config = content + return True + + +def testmain(): + from tests import test_validation as t + import pprint + + config = PluggableConfig(_format="json") + properties = copy.deepcopy(t.sample_spec) + + config.options = properties + config.load(fromfile='data.json') + + print 'config' + pprint.pprint(config.config) + + config.serialize('/tmp/testserial.json') + +if __name__ == "__main__": + testmain() diff --git a/src/leap/base/providers.py b/src/leap/base/providers.py new file mode 100644 index 00000000..d41f3695 --- /dev/null +++ b/src/leap/base/providers.py @@ -0,0 +1,29 @@ +"""all dealing with leap-providers: definition files, updating""" +from leap.base import config as baseconfig +from leap.base import specs + + +class LeapProviderDefinition(baseconfig.JSONLeapConfig): + spec = specs.leap_provider_spec + + def _get_slug(self): + domain = getattr(self, 'domain', None) + if domain: + path = baseconfig.get_provider_path(domain) + else: + path = baseconfig.get_default_provider_path() + + return baseconfig.get_config_file( + 'provider.json', folder=path) + + def _set_slug(self, *args, **kwargs): + raise AttributeError("you cannot set slug") + + slug = property(_get_slug, _set_slug) + + +class LeapProviderSet(object): + # we gather them from the filesystem + # TODO: (MVS+) + def __init__(self): + self.count = 0 diff --git a/src/leap/base/specs.py b/src/leap/base/specs.py new file mode 100644 index 00000000..b4bb8dcf --- /dev/null +++ b/src/leap/base/specs.py @@ -0,0 +1,59 @@ +leap_provider_spec = { + 'description': 'provider definition', + 'type': 'object', + 'properties': { + 'serial': { + 'type': int, + 'default': 1, + 'required': True, + }, + 'version': { + 'type': unicode, + 'default': '0.1.0' + #'required': True + }, + 'domain': { + 'type': unicode, # XXX define uri type + 'default': 'testprovider.example.org' + #'required': True, + }, + 'display_name': { + 'type': dict, # XXX multilingual object? + 'default': {u'en': u'Test Provider'} + #'required': True + }, + 'description': { + 'type': dict, + 'default': {u'en': u'Test provider'} + }, + 'enrollment_policy': { + 'type': unicode, # oneof ?? + 'default': 'open' + }, + 'services': { + 'type': list, # oneof ?? + 'default': ['eip'] + }, + 'api_version': { + 'type': unicode, + 'default': '0.1.0' # version regexp + }, + 'api_uri': { + 'type': unicode # uri + }, + 'public_key': { + 'type': unicode # fingerprint + }, + 'ca_cert_fingerprint': { + 'type': unicode, + }, + 'ca_cert_uri': { + 'type': unicode, + 'format': 'https-uri' + }, + 'languages': { + 'type': list, + 'default': ['en'] + } + } +} diff --git a/src/leap/base/tests/__init__.py b/src/leap/base/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/base/tests/__init__.py diff --git a/src/leap/base/tests/test_checks.py b/src/leap/base/tests/test_checks.py new file mode 100644 index 00000000..8d573b1e --- /dev/null +++ b/src/leap/base/tests/test_checks.py @@ -0,0 +1,124 @@ +try: + import unittest2 as unittest +except ImportError: + import unittest +import os + +from mock import (patch, Mock) +from StringIO import StringIO + +import ping +import requests + +from leap.base import checks +from leap.base import exceptions +from leap.testing.basetest import BaseLeapTest + +_uid = os.getuid() + + +class LeapNetworkCheckTest(BaseLeapTest): + __name__ = "leap_network_check_tests" + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_checker_should_implement_check_methods(self): + checker = checks.LeapNetworkChecker() + + self.assertTrue(hasattr(checker, "check_internet_connection"), + "missing meth") + self.assertTrue(hasattr(checker, "check_tunnel_default_interface"), + "missing meth") + self.assertTrue(hasattr(checker, "is_internet_up"), + "missing meth") + self.assertTrue(hasattr(checker, "ping_gateway"), + "missing meth") + + def test_checker_should_actually_call_all_tests(self): + checker = checks.LeapNetworkChecker() + mc = Mock() + checker.run_all(checker=mc) + self.assertTrue(mc.check_internet_connection.called, "not called") + self.assertTrue(mc.check_tunnel_default_interface.called, "not called") + self.assertTrue(mc.is_internet_up.called, "not called") + + # ping gateway only called if we pass provider_gw + checker = checks.LeapNetworkChecker(provider_gw="0.0.0.0") + mc = Mock() + checker.run_all(checker=mc) + self.assertTrue(mc.check_internet_connection.called, "not called") + self.assertTrue(mc.check_tunnel_default_interface.called, "not called") + self.assertTrue(mc.ping_gateway.called, "not called") + self.assertTrue(mc.is_internet_up.called, "not called") + + def test_get_default_interface_no_interface(self): + checker = checks.LeapNetworkChecker() + with patch('leap.base.checks.open', create=True) as mock_open: + with self.assertRaises(exceptions.NoDefaultInterfaceFoundError): + mock_open.return_value = StringIO( + "Iface\tDestination Gateway\t" + "Flags\tRefCntd\tUse\tMetric\t" + "Mask\tMTU\tWindow\tIRTT") + checker.get_default_interface_gateway() + + def test_check_tunnel_default_interface(self): + checker = checks.LeapNetworkChecker() + with patch('leap.base.checks.open', create=True) as mock_open: + with self.assertRaises(exceptions.TunnelNotDefaultRouteError): + mock_open.return_value = StringIO( + "Iface\tDestination Gateway\t" + "Flags\tRefCntd\tUse\tMetric\t" + "Mask\tMTU\tWindow\tIRTT") + checker.check_tunnel_default_interface() + + with patch('leap.base.checks.open', create=True) as mock_open: + with self.assertRaises(exceptions.TunnelNotDefaultRouteError): + mock_open.return_value = StringIO( + "Iface\tDestination Gateway\t" + "Flags\tRefCntd\tUse\tMetric\t" + "Mask\tMTU\tWindow\tIRTT\n" + "wlan0\t00000000\t0102A8C0\t" + "0003\t0\t0\t0\t00000000\t0\t0\t0") + checker.check_tunnel_default_interface() + + with patch('leap.base.checks.open', create=True) as mock_open: + mock_open.return_value = StringIO( + "Iface\tDestination Gateway\t" + "Flags\tRefCntd\tUse\tMetric\t" + "Mask\tMTU\tWindow\tIRTT\n" + "tun0\t00000000\t01002A0A\t0003\t0\t0\t0\t00000080\t0\t0\t0") + checker.check_tunnel_default_interface() + + def test_ping_gateway_fail(self): + checker = checks.LeapNetworkChecker() + with patch.object(ping, "quiet_ping") as mocked_ping: + with self.assertRaises(exceptions.NoConnectionToGateway): + mocked_ping.return_value = [11, "", ""] + checker.ping_gateway("4.2.2.2") + + def test_check_internet_connection_failures(self): + checker = checks.LeapNetworkChecker() + with patch.object(requests, "get") as mocked_get: + mocked_get.side_effect = requests.HTTPError + with self.assertRaises(exceptions.NoInternetConnection): + checker.check_internet_connection() + + with patch.object(requests, "get") as mocked_get: + mocked_get.side_effect = requests.RequestException + with self.assertRaises(exceptions.NoInternetConnection): + checker.check_internet_connection() + + #TODO: Mock possible errors that can be raised by is_internet_up + with patch.object(requests, "get") as mocked_get: + mocked_get.side_effect = requests.ConnectionError + with self.assertRaises(exceptions.NoInternetConnection): + checker.check_internet_connection() + + @unittest.skipUnless(_uid == 0, "root only") + def test_ping_gateway(self): + checker = checks.LeapNetworkChecker() + checker.ping_gateway("4.2.2.2") diff --git a/src/leap/base/tests/test_config.py b/src/leap/base/tests/test_config.py new file mode 100644 index 00000000..d03149b2 --- /dev/null +++ b/src/leap/base/tests/test_config.py @@ -0,0 +1,247 @@ +import json +import os +import platform +import socket +#import tempfile + +import mock +import requests + +from leap.base import config +from leap.base import constants +from leap.base import exceptions +from leap.eip import constants as eipconstants +from leap.util.fileutil import mkdir_p +from leap.testing.basetest import BaseLeapTest + + +try: + import unittest2 as unittest +except ImportError: + import unittest + +_system = platform.system() + + +class JSONLeapConfigTest(BaseLeapTest): + def setUp(self): + pass + + def tearDown(self): + pass + + def test_metaclass(self): + with self.assertRaises(exceptions.ImproperlyConfigured) as exc: + class DummyTestConfig(config.JSONLeapConfig): + __metaclass__ = config.MetaConfigWithSpec + exc.startswith("missing spec dict") + + class DummyTestConfig(config.JSONLeapConfig): + __metaclass__ = config.MetaConfigWithSpec + spec = {'properties': {}} + with self.assertRaises(exceptions.ImproperlyConfigured) as exc: + DummyTestConfig() + exc.startswith("missing slug") + + class DummyTestConfig(config.JSONLeapConfig): + __metaclass__ = config.MetaConfigWithSpec + spec = {'properties': {}} + slug = "foo" + DummyTestConfig() + +######################################3 +# +# provider fetch tests block +# + + +class ProviderTest(BaseLeapTest): + # override per test fixtures + + def setUp(self): + pass + + def tearDown(self): + pass + + +# XXX depreacated. similar test in eip.checks + +#class BareHomeTestCase(ProviderTest): +# + #__name__ = "provider_config_tests_bare_home" +# + #def test_should_raise_if_missing_eip_json(self): + #with self.assertRaises(exceptions.MissingConfigFileError): + #config.get_config_json(os.path.join(self.home, 'eip.json')) + + +class ProviderDefinitionTestCase(ProviderTest): + # XXX MOVE TO eip.test_checks + # -- kali 2012-08-24 00:38 + + __name__ = "provider_config_tests" + + def setUp(self): + # dump a sample eip file + # XXX Move to Use EIP Spec Instead!!! + # XXX tests to be moved to eip.checks and eip.providers + # XXX can use eipconfig.dump_default_eipconfig + + path = os.path.join(self.home, '.config', 'leap') + mkdir_p(path) + with open(os.path.join(path, 'eip.json'), 'w') as fp: + json.dump(eipconstants.EIP_SAMPLE_JSON, fp) + + +# these tests below should move to +# eip.checks +# config.Configuration has been deprecated + +# TODO: +# - We're instantiating a ProviderTest because we're doing the home wipeoff +# on setUpClass instead of the setUp (for speedup of the general cases). + +# We really should be testing all of them in the same testCase, and +# doing an extra wipe of the tempdir... but be careful!!!! do not mess with +# os.environ home more than needed... that could potentially bite! + +# XXX actually, another thing to fix here is separating tests: +# - test that requests has been called. +# - check deeper for error types/msgs + +# we SHOULD inject requests dep in the constructor +# (so we can pass mock easily). + + +#class ProviderFetchConError(ProviderTest): + #def test_connection_error(self): + #with mock.patch.object(requests, "get") as mock_method: + #mock_method.side_effect = requests.ConnectionError + #cf = config.Configuration() + #self.assertIsInstance(cf.error, str) +# +# +#class ProviderFetchHttpError(ProviderTest): + #def test_file_not_found(self): + #with mock.patch.object(requests, "get") as mock_method: + #mock_method.side_effect = requests.HTTPError + #cf = config.Configuration() + #self.assertIsInstance(cf.error, str) +# +# +#class ProviderFetchInvalidUrl(ProviderTest): + #def test_invalid_url(self): + #cf = config.Configuration("ht") + #self.assertTrue(cf.error) + + +# end provider fetch tests +########################################### + + +class ConfigHelperFunctions(BaseLeapTest): + + __name__ = "config_helper_tests" + + def setUp(self): + pass + + def tearDown(self): + pass + + # tests + + @unittest.skipUnless(_system == "Linux", "linux only") + def test_lin_get_config_file(self): + """ + config file path where expected? (linux) + """ + self.assertEqual( + config.get_config_file( + 'test', folder="foo/bar"), + os.path.expanduser( + '~/.config/leap/foo/bar/test') + ) + + @unittest.skipUnless(_system == "Darwin", "mac only") + def test_mac_get_config_file(self): + """ + config file path where expected? (mac) + """ + self._missing_test_for_plat(do_raise=True) + + @unittest.skipUnless(_system == "Windows", "win only") + def test_win_get_config_file(self): + """ + config file path where expected? + """ + self._missing_test_for_plat(do_raise=True) + + # + # XXX hey, I'm raising exceptions here + # on purpose. just wanted to make sure + # that the skip stuff is doing it right. + # If you're working on win/macos tests, + # feel free to remove tests that you see + # are too redundant. + + @unittest.skipUnless(_system == "Linux", "linux only") + def test_lin_get_config_dir(self): + """ + nice config dir? (linux) + """ + self.assertEqual( + config.get_config_dir(), + os.path.expanduser('~/.config/leap')) + + @unittest.skipUnless(_system == "Darwin", "mac only") + def test_mac_get_config_dir(self): + """ + nice config dir? (mac) + """ + self._missing_test_for_plat(do_raise=True) + + @unittest.skipUnless(_system == "Windows", "win only") + def test_win_get_config_dir(self): + """ + nice config dir? (win) + """ + self._missing_test_for_plat(do_raise=True) + + # provider paths + + @unittest.skipUnless(_system == "Linux", "linux only") + def test_get_default_provider_path(self): + """ + is default provider path ok? + """ + self.assertEqual( + config.get_default_provider_path(), + os.path.expanduser( + '~/.config/leap/providers/%s/' % + constants.DEFAULT_PROVIDER) + ) + + # validate ip + + def test_validate_ip(self): + """ + check our ip validation + """ + config.validate_ip('3.3.3.3') + with self.assertRaises(socket.error): + config.validate_ip('255.255.255.256') + with self.assertRaises(socket.error): + config.validate_ip('foobar') + + @unittest.skip + def test_validate_domain(self): + """ + code to be written yet + """ + raise NotImplementedError + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/base/tests/test_providers.py b/src/leap/base/tests/test_providers.py new file mode 100644 index 00000000..15c4ed58 --- /dev/null +++ b/src/leap/base/tests/test_providers.py @@ -0,0 +1,143 @@ +import copy +import json +try: + import unittest2 as unittest +except ImportError: + import unittest +import os + +import jsonschema + +from leap import __branding as BRANDING +from leap.testing.basetest import BaseLeapTest +from leap.base import providers + + +EXPECTED_DEFAULT_CONFIG = { + u"api_version": u"0.1.0", + u"description": {u'en': u"Test provider"}, + u"display_name": {u'en': u"Test Provider"}, + u"domain": u"testprovider.example.org", + u"enrollment_policy": u"open", + u"serial": 1, + u"services": [ + u"eip" + ], + u"languages": [u"en"], + u"version": u"0.1.0" +} + + +class TestLeapProviderDefinition(BaseLeapTest): + def setUp(self): + self.domain = "testprovider.example.org" + self.definition = providers.LeapProviderDefinition( + domain=self.domain) + self.definition.save() + self.definition.load() + self.config = self.definition.config + + def tearDown(self): + if hasattr(self, 'testfile') and os.path.isfile(self.testfile): + os.remove(self.testfile) + + # tests + + # XXX most of these tests can be made more abstract + # and moved to test_baseconfig *triangulate!* + + def test_provider_slug_property(self): + slug = self.definition.slug + self.assertEquals( + slug, + os.path.join( + self.home, + '.config', 'leap', 'providers', + '%s' % self.domain, + 'provider.json')) + with self.assertRaises(AttributeError): + self.definition.slug = 23 + + def test_provider_dump(self): + # check a good provider definition is dumped to disk + self.testfile = self.get_tempfile('test.json') + self.definition.save(to=self.testfile) + deserialized = json.load(open(self.testfile, 'rb')) + self.maxDiff = None + self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG) + + def test_provider_dump_to_slug(self): + # same as above, but we test the ability to save to a + # file generated from the slug. + # XXX THIS TEST SHOULD MOVE TO test_baseconfig + self.definition.save() + filename = self.definition.filename + self.assertTrue(os.path.isfile(filename)) + deserialized = json.load(open(filename, 'rb')) + self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG) + + def test_provider_load(self): + # check loading provider from disk file + self.testfile = self.get_tempfile('test_load.json') + with open(self.testfile, 'w') as wf: + wf.write(json.dumps(EXPECTED_DEFAULT_CONFIG)) + self.definition.load(fromfile=self.testfile) + self.assertDictEqual(self.config, + EXPECTED_DEFAULT_CONFIG) + + def test_provider_validation(self): + self.definition.validate(self.config) + _config = copy.deepcopy(self.config) + _config['serial'] = 'aaa' + with self.assertRaises(jsonschema.ValidationError): + self.definition.validate(_config) + + @unittest.skip + def test_load_malformed_json_definition(self): + raise NotImplementedError + + @unittest.skip + def test_type_validation(self): + # check various type validation + # type cast + raise NotImplementedError + + +class TestLeapProviderSet(BaseLeapTest): + + def setUp(self): + self.providers = providers.LeapProviderSet() + + def tearDown(self): + pass + ### + + def test_get_zero_count(self): + self.assertEqual(self.providers.count, 0) + + @unittest.skip + def test_count_defined_providers(self): + # check the method used for making + # the list of providers + raise NotImplementedError + + @unittest.skip + def test_get_default_provider(self): + raise NotImplementedError + + @unittest.skip + def test_should_be_at_least_one_provider_after_init(self): + # when we init an empty environment, + # there should be at least one provider, + # that will be a dump of the default provider definition + # somehow a high level test + raise NotImplementedError + + @unittest.skip + def test_get_eip_remote_from_default_provider(self): + # from: default provider + # expect: remote eip domain + raise NotImplementedError + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/base/tests/test_validation.py b/src/leap/base/tests/test_validation.py new file mode 100644 index 00000000..87e99648 --- /dev/null +++ b/src/leap/base/tests/test_validation.py @@ -0,0 +1,92 @@ +import copy +import datetime +#import json +try: + import unittest2 as unittest +except ImportError: + import unittest +import os + +import jsonschema + +from leap.base.config import JSONLeapConfig +from leap.base import pluggableconfig +from leap.testing.basetest import BaseLeapTest + +SAMPLE_CONFIG_DICT = { + 'prop_one': 1, + 'prop_uri': "http://example.org", + 'prop_date': '2012-12-12', +} + +EXPECTED_CONFIG = { + 'prop_one': 1, + 'prop_uri': "http://example.org", + 'prop_date': datetime.datetime(2012, 12, 12) +} + +sample_spec = { + 'description': 'sample schema definition', + 'type': 'object', + 'properties': { + 'prop_one': { + 'type': int, + 'default': 1, + 'required': True + }, + 'prop_uri': { + 'type': str, + 'default': 'http://example.org', + 'required': True, + 'format': 'uri' + }, + 'prop_date': { + 'type': str, + 'default': '2012-12-12', + 'format': 'date' + } + } +} + + +class SampleConfig(JSONLeapConfig): + spec = sample_spec + + @property + def slug(self): + return os.path.expanduser('~/sampleconfig.json') + + +class TestJSONLeapConfigValidation(BaseLeapTest): + def setUp(self): + self.sampleconfig = SampleConfig() + self.sampleconfig.save() + self.sampleconfig.load() + self.config = self.sampleconfig.config + + def tearDown(self): + if hasattr(self, 'testfile') and os.path.isfile(self.testfile): + os.remove(self.testfile) + + # tests + + def test_good_validation(self): + self.sampleconfig.validate(SAMPLE_CONFIG_DICT) + + def test_broken_int(self): + _config = copy.deepcopy(SAMPLE_CONFIG_DICT) + _config['prop_one'] = '1' + with self.assertRaises(jsonschema.ValidationError): + self.sampleconfig.validate(_config) + + def test_format_property(self): + # JsonSchema Validator does not check the format property. + # We should have to extend the Configuration class + blah = copy.deepcopy(SAMPLE_CONFIG_DICT) + blah['prop_uri'] = 'xxx' + with self.assertRaises(pluggableconfig.TypeCastException): + self.sampleconfig.validate(blah) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/baseapp/config.py b/src/leap/baseapp/config.py deleted file mode 100644 index efdb4726..00000000 --- a/src/leap/baseapp/config.py +++ /dev/null @@ -1,40 +0,0 @@ -import ConfigParser -import os - - -def get_config(config_file=None): - """ - temporary method for getting configs, - mainly for early stage development process. - in the future we will get preferences - from the storage api - """ - config = ConfigParser.ConfigParser() - #config.readfp(open('defaults.cfg')) - #XXX does this work on win / mac also??? - conf_path_list = ['eip.cfg', # XXX build a - # proper path with platform-specific places - # XXX make .config/foo - os.path.expanduser('~/.eip.cfg')] - if config_file: - config.readfp(config_file) - else: - config.read(conf_path_list) - return config - - -# XXX wrapper around config? to get default values - -def get_with_defaults(config, section, option): - if config.has_option(section, option): - return config.get(section, option) - else: - # XXX lookup in defaults dict??? - pass - - -def get_vpn_stdout_mockup(): - command = "python" - args = ["-u", "-c", "from eip_client import fakeclient;\ -fakeclient.write_output()"] - return command, args diff --git a/src/leap/baseapp/constants.py b/src/leap/baseapp/constants.py new file mode 100644 index 00000000..e312be21 --- /dev/null +++ b/src/leap/baseapp/constants.py @@ -0,0 +1,6 @@ +# This timer used for polling vpn manager state. + +# XXX what is an optimum polling interval? +# too little will be overkill, too much will +# miss transition states. +TIMER_MILLISECONDS = 250.0 diff --git a/src/leap/baseapp/dialogs.py b/src/leap/baseapp/dialogs.py new file mode 100644 index 00000000..3cb539cf --- /dev/null +++ b/src/leap/baseapp/dialogs.py @@ -0,0 +1,58 @@ +# vim: tabstop=8 expandtab shiftwidth=4 softtabstop=4 +import logging + +from PyQt4.QtGui import (QDialog, QFrame, QPushButton, QLabel, QMessageBox) + +logger = logging.getLogger(name=__name__) + + +class ErrorDialog(QDialog): + def __init__(self, parent=None, errtype=None, msg=None, label=None): + super(ErrorDialog, self).__init__(parent) + frameStyle = QFrame.Sunken | QFrame.Panel + self.warningLabel = QLabel() + self.warningLabel.setFrameStyle(frameStyle) + self.warningButton = QPushButton("QMessageBox.&warning()") + + if msg is not None: + self.msg = msg + if label is not None: + self.label = label + if errtype == "critical": + self.criticalMessage(self.msg, self.label) + + def warningMessage(self, msg, label): + msgBox = QMessageBox(QMessageBox.Warning, + "QMessageBox.warning()", msg, + QMessageBox.NoButton, self) + msgBox.addButton("&Ok", QMessageBox.AcceptRole) + if msgBox.exec_() == QMessageBox.AcceptRole: + pass + # do whatever we want to do after + # closing the dialog. we can pass that + # in the constructor + + def criticalMessage(self, msg, label): + msgBox = QMessageBox(QMessageBox.Critical, + "QMessageBox.critical()", msg, + QMessageBox.NoButton, self) + msgBox.addButton("&Ok", QMessageBox.AcceptRole) + msgBox.exec_() + + # It's critical, so we exit. + # We should better emit a signal and connect it + # with the proper shutdownAndQuit method, but + # this suffices for now. + logger.info('Quitting') + import sys + sys.exit() + + def confirmMessage(self, msg, label, action): + msgBox = QMessageBox(QMessageBox.Critical, + "QMessageBox.critical()", msg, + QMessageBox.NoButton, self) + msgBox.addButton("&Ok", QMessageBox.AcceptRole) + msgBox.addButton("&Cancel", QMessageBox.RejectRole) + + if msgBox.exec_() == QMessageBox.AcceptRole: + action() diff --git a/src/leap/baseapp/eip.py b/src/leap/baseapp/eip.py new file mode 100644 index 00000000..54acbc0e --- /dev/null +++ b/src/leap/baseapp/eip.py @@ -0,0 +1,217 @@ +from __future__ import print_function +import logging +import time +#import sys + +from PyQt4 import QtCore + +from leap.baseapp.dialogs import ErrorDialog +from leap.baseapp import constants +from leap.eip import exceptions as eip_exceptions +from leap.eip.eipconnection import EIPConnection + +logger = logging.getLogger(name=__name__) + + +class EIPConductorAppMixin(object): + """ + initializes an instance of EIPConnection, + gathers errors, and passes status-change signals + from Qt land along to the conductor. + Connects the eip connect/disconnect logic + to the switches in the app (buttons/menu items). + """ + + def __init__(self, *args, **kwargs): + opts = kwargs.pop('opts') + config_file = getattr(opts, 'config_file', None) + provider = kwargs.pop('provider') + + self.eip_service_started = False + + # conductor (eip connection) is in charge of all + # vpn-related configuration / monitoring. + # we pass a tuple of signals that will be + # triggered when status changes. + + self.conductor = EIPConnection( + watcher_cb=self.newLogLine.emit, + config_file=config_file, + checker_signals=(self.eipStatusChange.emit, ), + status_signals=(self.openvpnStatusChange.emit, ), + debug=self.debugmode, + ovpn_verbosity=opts.openvpn_verb, + provider=provider) + + self.skip_download = opts.no_provider_checks + self.skip_verify = opts.no_ca_verify + + def run_eip_checks(self): + """ + runs eip checks and + the error checking loop + """ + logger.debug('running EIP CHECKS') + self.conductor.run_checks( + skip_download=self.skip_download, + skip_verify=self.skip_verify) + self.error_check() + + self.start_eipconnection.emit() + + def error_check(self): + """ + consumes the conductor error queue. + pops errors, and acts accordingly (launching user dialogs). + """ + logger.debug('error check') + + errq = self.conductor.error_queue + while errq.qsize() != 0: + logger.debug('%s errors left in conductor queue', errq.qsize()) + # we get exception and original traceback from queue + error, tb = errq.get() + + # redundant log, debugging the loop. + logger.error('%s: %s', error.__class__.__name__, error.message) + + if issubclass(error.__class__, eip_exceptions.EIPClientError): + self.triggerEIPError.emit(error) + + else: + # deprecated form of raising exception. + raise error, None, tb + + if error.failfirst is True: + break + + @QtCore.pyqtSlot(object) + def onEIPError(self, error): + """ + check severity and launches + dialogs informing user about the errors. + in the future we plan to derive errors to + our log viewer. + """ + + if getattr(error, 'usermessage', None): + message = error.usermessage + else: + message = error.message + + # XXX + # check headless = False before + # launching dialog. + # (so Qt tests can assert stuff) + + if error.critical: + logger.critical(error.message) + #critical error (non recoverable), + #we give user some info and quit. + #(critical error dialog will exit app) + ErrorDialog(errtype="critical", + msg=message, + label="critical error") + elif error.warning: + logger.warning(error.message) + + else: + dialog = ErrorDialog() + dialog.warningMessage(message, 'error') + + @QtCore.pyqtSlot() + def statusUpdate(self): + """ + polls status and updates ui with real time + info about transferred bytes / connection state. + right now is triggered by a timer tick + (timer controlled by StatusAwareTrayIcon class) + """ + # TODO I guess it's too expensive to poll + # continously. move to signal events instead. + # (i.e., subscribe to connection status changes + # from openvpn manager) + + if not self.eip_service_started: + # there is a race condition + # going on here. Depending on how long we take + # to init the qt app, the management socket + # is not ready yet. + return + + #if self.conductor.with_errors: + #XXX how to wait on pkexec??? + #something better that this workaround, plz!! + #I removed the pkexec pass authentication at all. + #time.sleep(5) + #logger.debug('timeout') + #logger.error('errors. disconnect') + #self.start_or_stopVPN() # is stop + + state = self.conductor.poll_connection_state() + if not state: + return + + ts, con_status, ok, ip, remote = state + self.set_statusbarMessage(con_status) + self.setIconToolTip() + + ts = time.strftime("%a %b %d %X", ts) + if self.debugmode: + self.updateTS.setText(ts) + self.status_label.setText(con_status) + self.ip_label.setText(ip) + self.remote_label.setText(remote) + + # status i/o + + status = self.conductor.get_status_io() + if status and self.debugmode: + #XXX move this to systray menu indicators + ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read) = status + ts = time.strftime("%a %b %d %X", ts) + self.updateTS.setText(ts) + self.tun_read_bytes.setText(tun_read) + self.tun_write_bytes.setText(tun_write) + + @QtCore.pyqtSlot() + def start_or_stopVPN(self): + """ + stub for running child process with vpn + """ + if self.conductor.has_errors(): + logger.debug('not starting vpn; conductor has errors') + + if self.eip_service_started is False: + try: + self.conductor.connect() + + except eip_exceptions.EIPNoCommandError as exc: + self.triggerEIPError.emit(exc) + + except Exception as err: + # raise generic exception (Bad Thing Happened?) + logger.exception(err) + else: + # no errors, so go on. + if self.debugmode: + self.startStopButton.setText('&Disconnect') + self.eip_service_started = True + self.toggleEIPAct() + + # XXX decouple! (timer is init by icons class). + # we could bring Timer Init to this Mixin + # or to its own Mixin. + self.timer.start(constants.TIMER_MILLISECONDS) + self.network_checker.start() + return + + if self.eip_service_started is True: + self.network_checker.stop() + self.conductor.disconnect() + if self.debugmode: + self.startStopButton.setText('&Connect') + self.eip_service_started = False + self.toggleEIPAct() + self.timer.stop() + return diff --git a/src/leap/baseapp/leap_app.py b/src/leap/baseapp/leap_app.py new file mode 100644 index 00000000..4b63dd2f --- /dev/null +++ b/src/leap/baseapp/leap_app.py @@ -0,0 +1,153 @@ +import logging + +import sip +sip.setapi('QVariant', 2) + +from PyQt4 import QtCore +from PyQt4 import QtGui + +from leap.gui import mainwindow_rc + +logger = logging.getLogger(name=__name__) + + +APP_LOGO = ':/images/leap-color-small.png' + + +class MainWindowMixin(object): + """ + create the main window + for leap app + """ + + def __init__(self, *args, **kwargs): + # XXX set initial visibility + # debug = no visible + + widget = QtGui.QWidget() + self.setCentralWidget(widget) + + mainLayout = QtGui.QVBoxLayout() + # add widgets to layout + #self.createWindowHeader() + #mainLayout.addWidget(self.headerBox) + + # created in systray + mainLayout.addWidget(self.statusIconBox) + if self.debugmode: + mainLayout.addWidget(self.statusBox) + mainLayout.addWidget(self.loggerBox) + widget.setLayout(mainLayout) + + self.createMainActions() + self.createMainMenus() + + self.setWindowTitle("LEAP Client") + self.set_app_icon() + self.set_statusbarMessage('ready') + + def createMainActions(self): + #self.openAct = QtGui.QAction("&Open...", self, shortcut="Ctrl+O", + #triggered=self.open) + + self.firstRunWizardAct = QtGui.QAction( + "&First run wizard...", self, + triggered=self.stop_connection_and_launch_first_run_wizard) + self.aboutAct = QtGui.QAction("&About", self, triggered=self.about) + + #self.aboutQtAct = QtGui.QAction("About &Qt", self, + #triggered=QtGui.qApp.aboutQt) + + def createMainMenus(self): + self.connMenu = QtGui.QMenu("&Connections", self) + #self.viewMenu.addSeparator() + self.connMenu.addAction(self.quitAction) + + self.settingsMenu = QtGui.QMenu("&Settings", self) + self.settingsMenu.addAction(self.firstRunWizardAct) + + self.helpMenu = QtGui.QMenu("&Help", self) + self.helpMenu.addAction(self.aboutAct) + #self.helpMenu.addAction(self.aboutQtAct) + + self.menuBar().addMenu(self.connMenu) + self.menuBar().addMenu(self.settingsMenu) + self.menuBar().addMenu(self.helpMenu) + + def stop_connection_and_launch_first_run_wizard(self): + settings = QtCore.QSettings() + settings.setValue('FirstRunWizardDone', False) + logger.debug('should run first run wizard again...') + + status = self.conductor.get_icon_name() + if status != "disconnected": + self.start_or_stopVPN() + + self.launch_first_run_wizard() + #from leap.gui.firstrunwizard import FirstRunWizard + #wizard = FirstRunWizard( + #parent=self, + #success_cb=self.initReady.emit) + #wizard.show() + + def set_app_icon(self): + icon = QtGui.QIcon(APP_LOGO) + self.setWindowIcon(icon) + + #def createWindowHeader(self): + #""" + #description lines for main window + #""" + #self.headerBox = QtGui.QGroupBox() + #self.headerLabel = QtGui.QLabel( + #"<font size=40>LEAP Encryption Access Project</font>") + #self.headerLabelSub = QtGui.QLabel( + #"<br><i>your internet encryption toolkit</i>") +# + #pixmap = QtGui.QPixmap(APP_LOGO) + #leap_lbl = QtGui.QLabel() + #leap_lbl.setPixmap(pixmap) +# + #headerLayout = QtGui.QHBoxLayout() + #headerLayout.addWidget(leap_lbl) + #headerLayout.addWidget(self.headerLabel) + #headerLayout.addWidget(self.headerLabelSub) + #headerLayout.addStretch() + #self.headerBox.setLayout(headerLayout) + + def set_statusbarMessage(self, msg): + self.statusBar().showMessage(msg) + + def closeEvent(self, event): + """ + redefines close event (persistent window behaviour) + """ + if self.trayIcon.isVisible() and not self.debugmode: + QtGui.QMessageBox.information( + self, "Systray", + "The program will keep running " + "in the system tray. To " + "terminate the program, choose " + "<b>Quit</b> in the " + "context menu of the system tray entry.") + self.hide() + event.ignore() + return + self.cleanupAndQuit() + + def cleanupAndQuit(self): + """ + cleans state before shutting down app. + """ + # save geometry for restoring + settings = QtCore.QSettings() + geom_key = "DebugGeometry" if self.debugmode else "Geometry" + settings.setValue(geom_key, self.saveGeometry()) + + # TODO:make sure to shutdown all child process / threads + # in conductor + # XXX send signal instead? + logger.info('Shutting down') + self.conductor.cleanup() + logger.info('Exiting. Bye.') + QtGui.qApp.quit() diff --git a/src/leap/baseapp/log.py b/src/leap/baseapp/log.py new file mode 100644 index 00000000..8a7f81c3 --- /dev/null +++ b/src/leap/baseapp/log.py @@ -0,0 +1,65 @@ +import logging + +from PyQt4 import QtGui +from PyQt4 import QtCore + +vpnlogger = logging.getLogger('leap.openvpn') + + +class LogPaneMixin(object): + """ + a simple log pane + that writes new lines as they come + """ + + def createLogBrowser(self): + """ + creates Browser widget for displaying logs + (in debug mode only). + """ + self.loggerBox = QtGui.QGroupBox() + logging_layout = QtGui.QVBoxLayout() + self.logbrowser = QtGui.QTextBrowser() + + startStopButton = QtGui.QPushButton("&Connect") + self.startStopButton = startStopButton + + logging_layout.addWidget(self.logbrowser) + logging_layout.addWidget(self.startStopButton) + self.loggerBox.setLayout(logging_layout) + + # status box + + self.statusBox = QtGui.QGroupBox() + grid = QtGui.QGridLayout() + + self.updateTS = QtGui.QLabel('') + self.status_label = QtGui.QLabel('Disconnected') + self.ip_label = QtGui.QLabel('') + self.remote_label = QtGui.QLabel('') + + tun_read_label = QtGui.QLabel("tun read") + self.tun_read_bytes = QtGui.QLabel("0") + tun_write_label = QtGui.QLabel("tun write") + self.tun_write_bytes = QtGui.QLabel("0") + + grid.addWidget(self.updateTS, 0, 0) + grid.addWidget(self.status_label, 0, 1) + grid.addWidget(self.ip_label, 1, 0) + grid.addWidget(self.remote_label, 1, 1) + grid.addWidget(tun_read_label, 2, 0) + grid.addWidget(self.tun_read_bytes, 2, 1) + grid.addWidget(tun_write_label, 3, 0) + grid.addWidget(self.tun_write_bytes, 3, 1) + + self.statusBox.setLayout(grid) + + @QtCore.pyqtSlot(str) + def onLoggerNewLine(self, line): + """ + simple slot: writes new line to logger Pane. + """ + msg = line[:-1] + if self.debugmode: + self.logbrowser.append(msg) + vpnlogger.info(msg) diff --git a/src/leap/baseapp/mainwindow.py b/src/leap/baseapp/mainwindow.py index 68b6de8f..8d61bf5c 100644 --- a/src/leap/baseapp/mainwindow.py +++ b/src/leap/baseapp/mainwindow.py @@ -1,398 +1,170 @@ # vim: set fileencoding=utf-8 : #!/usr/bin/env python import logging -import time -logger = logging.getLogger(name=__name__) -from PyQt4.QtGui import (QMainWindow, QWidget, QVBoxLayout, QMessageBox, - QSystemTrayIcon, QGroupBox, QLabel, QPixmap, - QHBoxLayout, QIcon, - QPushButton, QGridLayout, QAction, QMenu, - QTextBrowser, qApp) -from PyQt4.QtCore import (pyqtSlot, pyqtSignal, QTimer) +import sip +sip.setapi('QString', 2) +sip.setapi('QVariant', 2) + +from PyQt4 import QtCore +from PyQt4 import QtGui -from leap.gui import mainwindow_rc -from leap.eip.conductor import EIPConductor +from leap.baseapp.eip import EIPConductorAppMixin +from leap.baseapp.log import LogPaneMixin +from leap.baseapp.systray import StatusAwareTrayIconMixin +from leap.baseapp.network import NetworkCheckerAppMixin +from leap.baseapp.leap_app import MainWindowMixin +from leap.eip.checks import ProviderCertChecker +from leap.gui.threads import FunThread +logger = logging.getLogger(name=__name__) -class LeapWindow(QMainWindow): - #XXX tbd: refactor into model / view / controller - #and put in its own modules... - newLogLine = pyqtSignal([str]) - statusChange = pyqtSignal([object]) +class LeapWindow(QtGui.QMainWindow, + MainWindowMixin, EIPConductorAppMixin, + StatusAwareTrayIconMixin, + NetworkCheckerAppMixin, + LogPaneMixin): + """ + main window for the leap app. + Initializes all of its base classes + We keep here some signal initialization + that gets tricky otherwise. + """ + + # signals + + newLogLine = QtCore.pyqtSignal([str]) + mainappReady = QtCore.pyqtSignal([]) + initReady = QtCore.pyqtSignal([]) + networkError = QtCore.pyqtSignal([object]) + triggerEIPError = QtCore.pyqtSignal([object]) + start_eipconnection = QtCore.pyqtSignal([]) + shutdownSignal = QtCore.pyqtSignal([]) + + # this is status change got from openvpn management + openvpnStatusChange = QtCore.pyqtSignal([object]) + # this is global eip status + eipStatusChange = QtCore.pyqtSignal([str]) def __init__(self, opts): - super(LeapWindow, self).__init__() + logger.debug('init leap window') self.debugmode = getattr(opts, 'debug', False) - - self.vpn_service_started = False - - self.createWindowHeader() - self.createIconGroupBox() - - self.createActions() - self.createTrayIcon() + super(LeapWindow, self).__init__() if self.debugmode: self.createLogBrowser() - # create timer - self.timer = QTimer() - - # bind signals - - self.trayIcon.activated.connect(self.iconActivated) - self.newLogLine.connect(self.onLoggerNewLine) - self.statusChange.connect(self.onStatusChange) - self.timer.timeout.connect(self.onTimerTick) - - widget = QWidget() - self.setCentralWidget(widget) + settings = QtCore.QSettings() + self.provider_domain = settings.value("provider_domain", None) + self.eip_username = settings.value("eip_username", None) - # add widgets to layout - mainLayout = QVBoxLayout() - mainLayout.addWidget(self.headerBox) - mainLayout.addWidget(self.statusIconBox) - if self.debugmode: - mainLayout.addWidget(self.statusBox) - mainLayout.addWidget(self.loggerBox) - widget.setLayout(mainLayout) + logger.debug('provider: %s', self.provider_domain) + logger.debug('eip_username: %s', self.eip_username) - # - # conductor is in charge of all - # vpn-related configuration / monitoring. - # we pass a tuple of signals that will be - # triggered when status changes. - # - config_file = getattr(opts, 'config_file', None) - self.conductor = EIPConductor( - watcher_cb=self.newLogLine.emit, - config_file=config_file, - status_signals=(self.statusChange.emit, )) + EIPConductorAppMixin.__init__( + self, opts=opts, provider=self.provider_domain) + StatusAwareTrayIconMixin.__init__(self) + NetworkCheckerAppMixin.__init__(self) + MainWindowMixin.__init__(self) - self.trayIcon.show() + geom_key = "DebugGeometry" if self.debugmode else "Geometry" + geom = settings.value(geom_key) + if geom: + self.restoreGeometry(geom) - self.setWindowTitle("Leap") - self.resize(400, 300) + # XXX check for wizard + self.wizard_done = settings.value("FirstRunWizardDone") - self.set_statusbarMessage('ready') + self.initchecks = FunThread(self.run_eip_checks) - if self.conductor.autostart: - self.start_or_stopVPN() + # bind signals + self.initchecks.finished.connect( + lambda: logger.debug('Initial checks thread finished')) + self.trayIcon.activated.connect(self.iconActivated) + self.newLogLine.connect( + lambda line: self.onLoggerNewLine(line)) + self.timer.timeout.connect( + lambda: self.onTimerTick()) + self.networkError.connect( + lambda exc: self.onNetworkError(exc)) + self.triggerEIPError.connect( + lambda exc: self.onEIPError(exc)) - def closeEvent(self, event): - """ - redefines close event (persistent window behaviour) - """ - if self.trayIcon.isVisible() and not self.debugmode: - QMessageBox.information(self, "Systray", - "The program will keep running " - "in the system tray. To " - "terminate the program, choose " - "<b>Quit</b> in the " - "context menu of the system tray entry.") - self.hide() - event.ignore() if self.debugmode: + self.startStopButton.clicked.connect( + lambda: self.start_or_stopVPN()) + self.start_eipconnection.connect( + lambda: self.start_or_stopVPN()) + self.shutdownSignal.connect( + self.cleanupAndQuit) + + # status change. + # TODO unify + self.openvpnStatusChange.connect( + lambda status: self.onOpenVPNStatusChange(status)) + self.eipStatusChange.connect( + lambda newstatus: self.onEIPConnStatusChange(newstatus)) + self.eipStatusChange.connect( + lambda newstatus: self.toggleEIPAct()) + + # do first run wizard and init signals + self.mainappReady.connect(self.do_first_run_wizard_check) + self.initReady.connect(self.runchecks_and_eipconnect) + + # ... all ready. go! + # connected to do_first_run_wizard_check + self.mainappReady.emit() + + def do_first_run_wizard_check(self): + """ + checks whether first run wizard needs to be run + launches it if needed + and emits initReady signal if not. + """ + + logger.debug('first run wizard check...') + need_wizard = False + + # do checks (can overlap if wizard was interrupted) + if not self.wizard_done: + need_wizard = True + + if not self.provider_domain: + need_wizard = True + else: + pcertchecker = ProviderCertChecker(domain=self.provider_domain) + if not pcertchecker.is_cert_valid(do_raise=False): + logger.warning('missing valid client cert. need wizard') + need_wizard = True + + # launch wizard if needed + if need_wizard: + self.launch_first_run_wizard() + else: # no wizard needed + logger.debug('running first run wizard') + self.initReady.emit() + + def launch_first_run_wizard(self): + """ + launches wizard and blocks + """ + from leap.gui.firstrun.wizard import FirstRunWizard + wizard = FirstRunWizard( + self.conductor, + parent=self, + eip_username=self.eip_username, + start_eipconnection_signal=self.start_eipconnection, + eip_statuschange_signal=self.eipStatusChange, + quitcallback=self.onWizardCancel) + wizard.show() + + def onWizardCancel(self): + if not self.wizard_done: + logger.debug( + 'clicked on Cancel during first ' + 'run wizard. shutting down') self.cleanupAndQuit() - def setIcon(self, name): - icon = self.Icons.get(name) - self.trayIcon.setIcon(icon) - self.setWindowIcon(icon) - - def setToolTip(self): - """ - get readable status and place it on systray tooltip - """ - status = self.conductor.status.get_readable_status() - self.trayIcon.setToolTip(status) - - def iconActivated(self, reason): - """ - handles left click, left double click - showing the trayicon menu - """ - #XXX there's a bug here! - #menu shows on (0,0) corner first time, - #until double clicked at least once. - if reason in (QSystemTrayIcon.Trigger, - QSystemTrayIcon.DoubleClick): - self.trayIconMenu.show() - - def createWindowHeader(self): - """ - description lines for main window - """ - #XXX good candidate to refactor out! :) - self.headerBox = QGroupBox() - self.headerLabel = QLabel("<font size=40><b>E</b>ncryption \ -<b>I</b>nternet <b>P</b>roxy</font>") - self.headerLabelSub = QLabel("<i>trust your \ -technolust</i>") - - pixmap = QPixmap(':/images/leapfrog.jpg') - frog_lbl = QLabel() - frog_lbl.setPixmap(pixmap) - - headerLayout = QHBoxLayout() - headerLayout.addWidget(frog_lbl) - headerLayout.addWidget(self.headerLabel) - headerLayout.addWidget(self.headerLabelSub) - headerLayout.addStretch() - self.headerBox.setLayout(headerLayout) - - def getIcon(self, icon_name): - # XXX get from connection dict - icons = {'disconnected': 0, - 'connecting': 1, - 'connected': 2} - return icons.get(icon_name, None) - - def createIconGroupBox(self): - """ - dummy icongroupbox - (to be removed from here -- reference only) - """ - icons = { - 'disconnected': ':/images/conn_error.png', - 'connecting': ':/images/conn_connecting.png', - 'connected': ':/images/conn_connected.png' - } - con_widgets = { - 'disconnected': QLabel(), - 'connecting': QLabel(), - 'connected': QLabel(), - } - con_widgets['disconnected'].setPixmap( - QPixmap(icons['disconnected'])) - con_widgets['connecting'].setPixmap( - QPixmap(icons['connecting'])) - con_widgets['connected'].setPixmap( - QPixmap(icons['connected'])), - self.ConnectionWidgets = con_widgets - - con_icons = { - 'disconnected': QIcon(icons['disconnected']), - 'connecting': QIcon(icons['connecting']), - 'connected': QIcon(icons['connected']) - } - self.Icons = con_icons - - self.statusIconBox = QGroupBox("Connection Status") - statusIconLayout = QHBoxLayout() - statusIconLayout.addWidget(self.ConnectionWidgets['disconnected']) - statusIconLayout.addWidget(self.ConnectionWidgets['connecting']) - statusIconLayout.addWidget(self.ConnectionWidgets['connected']) - statusIconLayout.itemAt(1).widget().hide() - statusIconLayout.itemAt(2).widget().hide() - self.statusIconBox.setLayout(statusIconLayout) - - def createActions(self): - """ - creates actions to be binded to tray icon - """ - self.connectVPNAction = QAction("Connect to &VPN", self, - triggered=self.hide) - # XXX change action name on (dis)connect - self.dis_connectAction = QAction("&(Dis)connect", self, - triggered=self.start_or_stopVPN) - self.minimizeAction = QAction("Mi&nimize", self, - triggered=self.hide) - self.maximizeAction = QAction("Ma&ximize", self, - triggered=self.showMaximized) - self.restoreAction = QAction("&Restore", self, - triggered=self.showNormal) - self.quitAction = QAction("&Quit", self, - triggered=self.cleanupAndQuit) - - def createTrayIcon(self): - """ - creates the tray icon - """ - self.trayIconMenu = QMenu(self) - - self.trayIconMenu.addAction(self.connectVPNAction) - self.trayIconMenu.addAction(self.dis_connectAction) - self.trayIconMenu.addSeparator() - self.trayIconMenu.addAction(self.minimizeAction) - self.trayIconMenu.addAction(self.maximizeAction) - self.trayIconMenu.addAction(self.restoreAction) - self.trayIconMenu.addSeparator() - self.trayIconMenu.addAction(self.quitAction) - - self.trayIcon = QSystemTrayIcon(self) - self.trayIcon.setContextMenu(self.trayIconMenu) - - def createLogBrowser(self): - """ - creates Browser widget for displaying logs - (in debug mode only). - """ - self.loggerBox = QGroupBox() - logging_layout = QVBoxLayout() - self.logbrowser = QTextBrowser() - - startStopButton = QPushButton("&Connect") - startStopButton.clicked.connect(self.start_or_stopVPN) - self.startStopButton = startStopButton - - logging_layout.addWidget(self.logbrowser) - logging_layout.addWidget(self.startStopButton) - self.loggerBox.setLayout(logging_layout) - - # status box - - self.statusBox = QGroupBox() - grid = QGridLayout() - - self.updateTS = QLabel('') - self.status_label = QLabel('Disconnected') - self.ip_label = QLabel('') - self.remote_label = QLabel('') - - tun_read_label = QLabel("tun read") - self.tun_read_bytes = QLabel("0") - tun_write_label = QLabel("tun write") - self.tun_write_bytes = QLabel("0") - - grid.addWidget(self.updateTS, 0, 0) - grid.addWidget(self.status_label, 0, 1) - grid.addWidget(self.ip_label, 1, 0) - grid.addWidget(self.remote_label, 1, 1) - grid.addWidget(tun_read_label, 2, 0) - grid.addWidget(self.tun_read_bytes, 2, 1) - grid.addWidget(tun_write_label, 3, 0) - grid.addWidget(self.tun_write_bytes, 3, 1) - - self.statusBox.setLayout(grid) - - @pyqtSlot(str) - def onLoggerNewLine(self, line): - """ - simple slot: writes new line to logger Pane. - """ - if self.debugmode: - self.logbrowser.append(line[:-1]) - - def set_statusbarMessage(self, msg): - self.statusBar().showMessage(msg) - - @pyqtSlot(object) - def onStatusChange(self, status): - """ - slot for status changes. triggers new signals for - updating icon, status bar, etc. - """ - - print('STATUS CHANGED! (on Qt-land)') - print('%s -> %s' % (status.previous, status.current)) - icon_name = self.conductor.get_icon_name() - self.setIcon(icon_name) - print 'icon = ', icon_name - - # change connection pixmap widget - self.setConnWidget(icon_name) - - def setConnWidget(self, icon_name): - #print 'changing icon to %s' % icon_name - oldlayout = self.statusIconBox.layout() - - # XXX reuse with icons - # XXX move states to StateWidget - states = {"disconnected": 0, - "connecting": 1, - "connected": 2} - - for i in range(3): - oldlayout.itemAt(i).widget().hide() - new = states[icon_name] - oldlayout.itemAt(new).widget().show() - - @pyqtSlot() - def start_or_stopVPN(self): - """ - stub for running child process with vpn - """ - if self.vpn_service_started is False: - self.conductor.connect() - if self.debugmode: - self.startStopButton.setText('&Disconnect') - self.vpn_service_started = True - - # XXX what is optimum polling interval? - # too little is overkill, too much - # will miss transition states.. - - self.timer.start(250.0) - return - if self.vpn_service_started is True: - self.conductor.disconnect() - # FIXME this should trigger also - # statuschange event. why isn't working?? - if self.debugmode: - self.startStopButton.setText('&Connect') - self.vpn_service_started = False - self.timer.stop() - return - - @pyqtSlot() - def onTimerTick(self): - self.statusUpdate() - - @pyqtSlot() - def statusUpdate(self): - """ - called on timer tick - polls status and updates ui with real time - info about transferred bytes / connection state. - """ - # XXX it's too expensive to poll - # continously. move to signal events instead. - - if not self.vpn_service_started: - return - - # XXX remove all access to manager layer - # from here. - if self.conductor.manager.with_errors: - #XXX how to wait on pkexec??? - #something better that this workaround, plz!! - time.sleep(10) - print('errors. disconnect.') - self.start_or_stopVPN() # is stop - - state = self.conductor.poll_connection_state() - if not state: - return - - ts, con_status, ok, ip, remote = state - self.set_statusbarMessage(con_status) - self.setToolTip() - - ts = time.strftime("%a %b %d %X", ts) - if self.debugmode: - self.updateTS.setText(ts) - self.status_label.setText(con_status) - self.ip_label.setText(ip) - self.remote_label.setText(remote) - - # status i/o - - status = self.conductor.manager.get_status_io() - if status and self.debugmode: - #XXX move this to systray menu indicators - ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read) = status - ts = time.strftime("%a %b %d %X", ts) - self.updateTS.setText(ts) - self.tun_read_bytes.setText(tun_read) - self.tun_write_bytes.setText(tun_write) - - def cleanupAndQuit(self): - """ - cleans state before shutting down app. - """ - # TODO:make sure to shutdown all child process / threads - # in conductor - self.conductor.cleanup() - qApp.quit() + def runchecks_and_eipconnect(self): + self.show_systray_icon() + self.initchecks.begin() diff --git a/src/leap/baseapp/network.py b/src/leap/baseapp/network.py new file mode 100644 index 00000000..077d5164 --- /dev/null +++ b/src/leap/baseapp/network.py @@ -0,0 +1,40 @@ +from __future__ import print_function + +import logging + +logger = logging.getLogger(name=__name__) + +from PyQt4 import QtCore + +from leap.baseapp.dialogs import ErrorDialog +from leap.base.network import NetworkCheckerThread + + +class NetworkCheckerAppMixin(object): + """ + initialize an instance of the Network Checker, + which gathers error and passes them on. + """ + + def __init__(self, *args, **kwargs): + self.network_checker = NetworkCheckerThread( + error_cb=self.networkError.emit, + debug=self.debugmode) + + # XXX move run_checks to slot + self.network_checker.run_checks() + + @QtCore.pyqtSlot(object) + def onNetworkError(self, exc): + """ + slot that receives a network exceptions + and raises a user error message + """ + logger.debug('handling network exception') + logger.error(exc.message) + dialog = ErrorDialog(parent=self) + + if exc.critical: + dialog.criticalMessage(exc.usermessage, "network error") + else: + dialog.warningMessage(exc.usermessage, "network error") diff --git a/src/leap/baseapp/permcheck.py b/src/leap/baseapp/permcheck.py new file mode 100644 index 00000000..6b74cb6e --- /dev/null +++ b/src/leap/baseapp/permcheck.py @@ -0,0 +1,17 @@ +import commands +import os + +from leap.util.fileutil import which + + +def is_pkexec_in_system(): + pkexec_path = which('pkexec') + if not pkexec_path: + return False + return os.access(pkexec_path, os.X_OK) + + +def is_auth_agent_running(): + return bool( + commands.getoutput( + 'ps aux | grep polkit-[g]nome-authentication-agent-1')) diff --git a/src/leap/baseapp/systray.py b/src/leap/baseapp/systray.py new file mode 100644 index 00000000..49f044aa --- /dev/null +++ b/src/leap/baseapp/systray.py @@ -0,0 +1,245 @@ +import logging +import sip +sip.setapi('QString', 2) +sip.setapi('QVariant', 2) + +from PyQt4 import QtCore +from PyQt4 import QtGui + +from leap import __branding as BRANDING +from leap import __version__ as VERSION + +from leap.gui import mainwindow_rc + +logger = logging.getLogger(__name__) + + +class StatusAwareTrayIconMixin(object): + """ + a mix of several functions needed + to create a systray and make it + get updated from conductor status + polling. + """ + states = { + "disconnected": 0, + "connecting": 1, + "connected": 2} + + iconpath = { + "disconnected": ':/images/conn_error.png', + "connecting": ':/images/conn_connecting.png', + "connected": ':/images/conn_connected.png'} + + Icons = { + 'disconnected': lambda self: QtGui.QIcon( + self.iconpath['disconnected']), + 'connecting': lambda self: QtGui.QIcon( + self.iconpath['connecting']), + 'connected': lambda self: QtGui.QIcon( + self.iconpath['connected']) + } + + def __init__(self, *args, **kwargs): + self.createIconGroupBox() + self.createActions() + self.createTrayIcon() + + # not sure if this really belongs here, but... + self.timer = QtCore.QTimer() + + def show_systray_icon(self): + #logger.debug('showing tray icon................') + self.trayIcon.show() + + def createIconGroupBox(self): + """ + dummy icongroupbox + (to be removed from here -- reference only) + """ + con_widgets = { + 'disconnected': QtGui.QLabel(), + 'connecting': QtGui.QLabel(), + 'connected': QtGui.QLabel(), + } + con_widgets['disconnected'].setPixmap( + QtGui.QPixmap( + self.iconpath['disconnected'])) + con_widgets['connecting'].setPixmap( + QtGui.QPixmap( + self.iconpath['connecting'])) + con_widgets['connected'].setPixmap( + QtGui.QPixmap( + self.iconpath['connected'])), + self.ConnectionWidgets = con_widgets + + self.statusIconBox = QtGui.QGroupBox("EIP Connection Status") + statusIconLayout = QtGui.QHBoxLayout() + statusIconLayout.addWidget(self.ConnectionWidgets['disconnected']) + statusIconLayout.addWidget(self.ConnectionWidgets['connecting']) + statusIconLayout.addWidget(self.ConnectionWidgets['connected']) + statusIconLayout.itemAt(1).widget().hide() + statusIconLayout.itemAt(2).widget().hide() + + self.leapConnStatus = QtGui.QLabel("<b>disconnected</b>") + statusIconLayout.addWidget(self.leapConnStatus) + + self.statusIconBox.setLayout(statusIconLayout) + + def createTrayIcon(self): + """ + creates the tray icon + """ + self.trayIconMenu = QtGui.QMenu(self) + + self.trayIconMenu.addAction(self.connAct) + self.trayIconMenu.addSeparator() + self.trayIconMenu.addAction(self.detailsAct) + self.trayIconMenu.addSeparator() + self.trayIconMenu.addAction(self.aboutAct) + # we should get this hidden inside the "about" dialog + # (as a little button maybe) + #self.trayIconMenu.addAction(self.aboutQtAct) + self.trayIconMenu.addSeparator() + self.trayIconMenu.addAction(self.quitAction) + + self.trayIcon = QtGui.QSystemTrayIcon(self) + self.setIcon('disconnected') + self.trayIcon.setContextMenu(self.trayIconMenu) + + #self.trayIconMenu.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) + #self.trayIconMenu.customContextMenuRequested.connect( + #self.on_context_menu) + + def bad(self): + logger.error('this should not be called') + + def createActions(self): + """ + creates actions to be binded to tray icon + """ + # XXX change action name on (dis)connect + self.connAct = QtGui.QAction("Encryption ON turn &off", self, + triggered=lambda: self.start_or_stopVPN()) + + self.detailsAct = QtGui.QAction("&Details...", + self, + triggered=self.detailsWin) + self.aboutAct = QtGui.QAction("&About", self, + triggered=self.about) + self.aboutQtAct = QtGui.QAction("About Q&t", self, + triggered=QtGui.qApp.aboutQt) + self.quitAction = QtGui.QAction("&Quit", self, + triggered=self.cleanupAndQuit) + + def toggleEIPAct(self): + # this is too simple by now. + # XXX get STATUS CONSTANTS INSTEAD + + icon_status = self.conductor.get_icon_name() + if icon_status == "connected": + self.connAct.setEnabled(True) + self.connAct.setText('Encryption ON turn o&ff') + return + if icon_status == "disconnected": + self.connAct.setEnabled(True) + self.connAct.setText('Encryption OFF turn &on') + return + if icon_status == "connecting": + self.connAct.setDisabled(True) + self.connAct.setText('connecting...') + return + + def detailsWin(self): + visible = self.isVisible() + if visible: + self.hide() + else: + self.show() + + def about(self): + # move to widget + flavor = BRANDING.get('short_name', None) + content = ("LEAP client<br>" + "(version <b>%s</b>)<br>" % VERSION) + if flavor: + content = content + ('<br>Flavor: <i>%s</i><br>' % flavor) + content = content + ( + "<br><a href='https://leap.se/'>" + "https://leap.se</a>") + QtGui.QMessageBox.about(self, "About", content) + + def setConnWidget(self, icon_name): + oldlayout = self.statusIconBox.layout() + + for i in range(3): + oldlayout.itemAt(i).widget().hide() + new = self.states[icon_name] + oldlayout.itemAt(new).widget().show() + + def setIcon(self, name): + icon_fun = self.Icons.get(name) + if icon_fun and callable(icon_fun): + icon = icon_fun(self) + self.trayIcon.setIcon(icon) + + def getIcon(self, icon_name): + return self.states.get(icon_name, None) + + def setIconToolTip(self): + """ + get readable status and place it on systray tooltip + """ + status = self.conductor.status.get_readable_status() + self.trayIcon.setToolTip(status) + + def iconActivated(self, reason): + """ + handles left click, left double click + showing the trayicon menu + """ + if reason in (QtGui.QSystemTrayIcon.Trigger, + QtGui.QSystemTrayIcon.DoubleClick): + context_menu = self.trayIcon.contextMenu() + # for some reason, context_menu.show() + # is failing in a way beyond my understanding. + # (not working the first time it's clicked). + # this works however. + context_menu.exec_(self.trayIcon.geometry().center()) + + @QtCore.pyqtSlot() + def onTimerTick(self): + self.statusUpdate() + + @QtCore.pyqtSlot(object) + def onOpenVPNStatusChange(self, status): + """ + updates icon, according to the openvpn status change. + """ + icon_name = self.conductor.get_icon_name() + + # XXX refactor. Use QStateMachine + + if icon_name in ("disconnected", "connected"): + self.eipStatusChange.emit(icon_name) + + if icon_name in ("connecting"): + # let's see how it matches + leap_status_name = self.conductor.get_leap_status() + self.eipStatusChange.emit(leap_status_name) + + self.setIcon(icon_name) + # change connection pixmap widget + self.setConnWidget(icon_name) + + @QtCore.pyqtSlot(str) + def onEIPConnStatusChange(self, newstatus): + """ + slot for EIP status changes + not to be confused with onOpenVPNStatusChange. + this only updates the non-debug LEAP Status line + next to the connection icon. + """ + # XXX move bold to style sheet + self.leapConnStatus.setText( + "<b>%s</b>" % newstatus) diff --git a/src/leap/certs/__init__.py b/src/leap/certs/__init__.py new file mode 100644 index 00000000..c4d009b1 --- /dev/null +++ b/src/leap/certs/__init__.py @@ -0,0 +1,7 @@ +import os + +_where = os.path.split(__file__)[0] + + +def where(filename): + return os.path.join(_where, filename) diff --git a/src/leap/crypto/__init__.py b/src/leap/crypto/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/crypto/__init__.py diff --git a/src/leap/crypto/certs.py b/src/leap/crypto/certs.py new file mode 100644 index 00000000..8908865d --- /dev/null +++ b/src/leap/crypto/certs.py @@ -0,0 +1,71 @@ +import ctypes +import socket + +import gnutls.connection +import gnutls.crypto +import gnutls.library + + +def get_https_cert_from_domain(domain): + """ + @param domain: a domain name to get a certificate from. + """ + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + cred = gnutls.connection.X509Credentials() + + session = gnutls.connection.ClientSession(sock, cred) + session.connect((domain, 443)) + session.handshake() + cert = session.peer_certificate + return cert + + +def get_cert_from_file(filepath): + with open(filepath) as f: + cert = gnutls.crypto.X509Certificate(f.read()) + return cert + + +def get_cert_fingerprint(domain=None, filepath=None, + hash_type="SHA256", sep=":"): + """ + @param domain: a domain name to get a fingerprint from + @type domain: str + @param filepath: path to a file containing a PEM file + @type filepath: str + @param hash_type: the hash function to be used in the fingerprint. + must be one of SHA1, SHA224, SHA256, SHA384, SHA512 + @type hash_type: str + @rparam: hex_fpr, a hexadecimal representation of a bytestring + containing the fingerprint. + @rtype: string + """ + if domain: + cert = get_https_cert_from_domain(domain) + if filepath: + cert = get_cert_from_file(filepath) + + _buffer = ctypes.create_string_buffer(64) + buffer_length = ctypes.c_size_t(64) + + SUPPORTED_DIGEST_FUN = ("SHA1", "SHA224", "SHA256", "SHA384", "SHA512") + if hash_type in SUPPORTED_DIGEST_FUN: + digestfunction = getattr( + gnutls.library.constants, + "GNUTLS_DIG_%s" % hash_type) + else: + # XXX improperlyconfigured or something + raise Exception("digest function not supported") + + gnutls.library.functions.gnutls_x509_crt_get_fingerprint( + cert._c_object, digestfunction, + ctypes.byref(_buffer), ctypes.byref(buffer_length)) + + # deinit + #server_cert._X509Certificate__deinit(server_cert._c_object) + # needed? is segfaulting + + fpr = ctypes.string_at(_buffer, buffer_length.value) + hex_fpr = sep.join(u"%02X" % ord(char) for char in fpr) + + return hex_fpr diff --git a/src/leap/crypto/leapkeyring.py b/src/leap/crypto/leapkeyring.py new file mode 100644 index 00000000..d4be7bf9 --- /dev/null +++ b/src/leap/crypto/leapkeyring.py @@ -0,0 +1,69 @@ +import keyring + +from leap.base.config import get_config_file + +############# +# Disclaimer +############# +# This currently is not a keyring, it's more like a joke. +# No, seriously. +# We're affected by this **bug** + +# https://bitbucket.org/kang/python-keyring-lib/ +# issue/65/dbusexception-method-opensession-with + +# so using the gnome keyring does not seem feasible right now. +# I thought this was the next best option to store secrets in plain sight. + +# in the future we should move to use the gnome/kde/macosx/win keyrings. + + +class LeapCryptedFileKeyring(keyring.backend.CryptedFileKeyring): + + filename = ".secrets" + + @property + def file_path(self): + return get_config_file(self.filename) + + def __init__(self, seed=None): + self.seed = seed + + def _get_new_password(self): + # XXX every time this method is called, + # $deity kills a kitten. + return "secret%s" % self.seed + + def _init_file(self): + self.keyring_key = self._get_new_password() + self.set_password('keyring_setting', 'pass_ref', 'pass_ref_value') + + def _unlock(self): + self.keyring_key = self._get_new_password() + print 'keyring key ', self.keyring_key + try: + ref_pw = self.get_password( + 'keyring_setting', + 'pass_ref') + print 'ref pw ', ref_pw + assert ref_pw == "pass_ref_value" + except AssertionError: + self._lock() + raise ValueError('Incorrect password') + + +def leap_set_password(key, value, seed="xxx"): + keyring.set_keyring(LeapCryptedFileKeyring(seed=seed)) + keyring.set_password('leap', key, value) + + +def leap_get_password(key, seed="xxx"): + keyring.set_keyring(LeapCryptedFileKeyring(seed=seed)) + #import ipdb;ipdb.set_trace() + return keyring.get_password('leap', key) + + +if __name__ == "__main__": + leap_set_password('test', 'bar') + passwd = leap_get_password('test') + assert passwd == 'bar' diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py new file mode 100644 index 00000000..116c535e --- /dev/null +++ b/src/leap/eip/checks.py @@ -0,0 +1,518 @@ +import logging +import ssl +#import platform +import time +import os + +import gnutls.crypto +#import netifaces +#import ping +import requests + +from leap import __branding as BRANDING +from leap import certs as leapcerts +from leap.base.auth import srpauth_protected, magick_srpauth +from leap.base import config as baseconfig +from leap.base import constants as baseconstants +from leap.base import providers +from leap.crypto import certs +from leap.eip import config as eipconfig +from leap.eip import constants as eipconstants +from leap.eip import exceptions as eipexceptions +from leap.eip import specs as eipspecs +from leap.util.fileutil import mkdir_p + +logger = logging.getLogger(name=__name__) + +""" +ProviderCertChecker +------------------- +Checks on certificates. To be moved to base. +docs TBD + +EIPConfigChecker +---------- +It is used from the eip conductor (a instance of EIPConnection that is +managed from the QtApp), running `run_all` method before trying to call +`connect` or any other of the state-changing methods. + +It checks that the needed files are provided or can be discovered over the +net. Much of these tests are not specific to EIP module, and can be splitted +into base.tests to be invoked by the base leap init routines. +However, I'm testing them alltogether for the sake of having the whole unit +reachable and testable as a whole. + +""" + + +def get_branding_ca_cert(domain): + # XXX deprecated + ca_file = BRANDING.get('provider_ca_file') + if ca_file: + return leapcerts.where(ca_file) + + +class ProviderCertChecker(object): + """ + Several checks needed for getting + client certs and checking tls connection + with provider. + """ + def __init__(self, fetcher=requests, + domain=None): + + self.fetcher = fetcher + self.domain = domain + self.cacert = eipspecs.provider_ca_path(domain) + + def run_all( + self, checker=None, + skip_download=False, skip_verify=False): + + if not checker: + checker = self + + do_verify = not skip_verify + logger.debug('do_verify: %s', do_verify) + # checker.download_ca_cert() + + # For MVS+ + # checker.download_ca_signature() + # checker.get_ca_signatures() + # checker.is_there_trust_path() + + # For MVS + checker.is_there_provider_ca() + + # XXX FAKE IT!!! + checker.is_https_working(verify=do_verify, autocacert=True) + checker.check_new_cert_needed(verify=do_verify) + + def download_ca_cert(self, uri=None, verify=True): + req = self.fetcher.get(uri, verify=verify) + req.raise_for_status() + + # should check domain exists + capath = self._get_ca_cert_path(self.domain) + with open(capath, 'w') as f: + f.write(req.content) + + def check_ca_cert_fingerprint( + self, hash_type="SHA256", + fingerprint=None): + """ + compares the fingerprint in + the ca cert with a string + we are passed + returns True if they are equal, False if not. + @param hash_type: digest function + @type hash_type: str + @param fingerprint: the fingerprint to compare with. + @type fingerprint: str (with : separator) + @rtype bool + """ + ca_cert_path = self.ca_cert_path + ca_cert_fpr = certs.get_cert_fingerprint( + filepath=ca_cert_path) + return ca_cert_fpr == fingerprint + + def verify_api_https(self, uri): + assert uri.startswith('https://') + cacert = self.ca_cert_path + verify = cacert and cacert or True + req = self.fetcher.get(uri, verify=verify) + req.raise_for_status() + return True + + def download_ca_signature(self): + # MVS+ + raise NotImplementedError + + def get_ca_signatures(self): + # MVS+ + raise NotImplementedError + + def is_there_trust_path(self): + # MVS+ + raise NotImplementedError + + def is_there_provider_ca(self): + if not self.cacert: + return False + cacert_exists = os.path.isfile(self.cacert) + if cacert_exists: + logger.debug('True') + return True + logger.debug('False!') + return False + + def is_https_working( + self, uri=None, verify=True, + autocacert=False): + if uri is None: + uri = self._get_root_uri() + # XXX raise InsecureURI or something better + try: + assert uri.startswith('https') + except AssertionError: + raise AssertionError( + "uri passed should start with https") + if autocacert and verify is True and self.cacert is not None: + logger.debug('verify cert: %s', self.cacert) + verify = self.cacert + #import pdb4qt; pdb4qt.set_trace() + logger.debug('is https working?') + logger.debug('uri: %s (verify:%s)', uri, verify) + try: + self.fetcher.get(uri, verify=verify) + + except requests.exceptions.SSLError as exc: + logger.error("SSLError") + # XXX RAISE! See #638 + #raise eipexceptions.HttpsBadCertError + logger.warning('BUG #638 CERT VERIFICATION FAILED! ' + '(this should be CRITICAL)') + logger.warning('SSLError: %s', exc.message) + + except requests.exceptions.ConnectionError: + logger.error('ConnectionError') + raise eipexceptions.HttpsNotSupported + + else: + logger.debug('True') + return True + + def check_new_cert_needed(self, skip_download=False, verify=True): + logger.debug('is new cert needed?') + if not self.is_cert_valid(do_raise=False): + logger.debug('True') + self.download_new_client_cert( + skip_download=skip_download, + verify=verify) + return True + logger.debug('False') + return False + + def download_new_client_cert(self, uri=None, verify=True, + skip_download=False, + credentials=None): + logger.debug('download new client cert') + if skip_download: + return True + if uri is None: + uri = self._get_client_cert_uri() + # XXX raise InsecureURI or something better + assert uri.startswith('https') + + if verify is True and self.cacert is not None: + verify = self.cacert + + fgetfn = self.fetcher.get + + if credentials: + user, passwd = credentials + + logger.debug('domain = %s', self.domain) + + @srpauth_protected(user, passwd, + server="https://%s" % self.domain, + verify=verify) + def getfn(*args, **kwargs): + return fgetfn(*args, **kwargs) + + else: + # XXX FIXME fix decorated args + @magick_srpauth(verify) + def getfn(*args, **kwargs): + return fgetfn(*args, **kwargs) + try: + + # XXX FIXME!!!! + # verify=verify + # Workaround for #638. return to verification + # when That's done!!! + #req = self.fetcher.get(uri, verify=False) + req = getfn(uri, verify=False) + req.raise_for_status() + + except requests.exceptions.SSLError: + logger.warning('SSLError while fetching cert. ' + 'Look below for stack trace.') + # XXX raise better exception + raise + try: + pemfile_content = req.content + self.is_valid_pemfile(pemfile_content) + cert_path = self._get_client_cert_path() + self.write_cert(pemfile_content, to=cert_path) + except: + logger.warning('Error while validating cert') + raise + return True + + def is_cert_valid(self, cert_path=None, do_raise=True): + exists = lambda: self.is_certificate_exists() + valid_pemfile = lambda: self.is_valid_pemfile() + not_expired = lambda: self.is_cert_not_expired() + + valid = exists() and valid_pemfile() and not_expired() + if not valid: + if do_raise: + raise Exception('missing valid cert') + else: + return False + return True + + def is_certificate_exists(self, certfile=None): + if certfile is None: + certfile = self._get_client_cert_path() + return os.path.isfile(certfile) + + def is_cert_not_expired(self, certfile=None, now=time.gmtime): + if certfile is None: + certfile = self._get_client_cert_path() + with open(certfile) as cf: + cert_s = cf.read() + cert = gnutls.crypto.X509Certificate(cert_s) + from_ = time.gmtime(cert.activation_time) + to_ = time.gmtime(cert.expiration_time) + return from_ < now() < to_ + + def is_valid_pemfile(self, cert_s=None): + """ + checks that the passed string + is a valid pem certificate + @param cert_s: string containing pem content + @type cert_s: string + @rtype: bool + """ + if cert_s is None: + certfile = self._get_client_cert_path() + with open(certfile) as cf: + cert_s = cf.read() + try: + # XXX get a real cert validation + # so far this is only checking begin/end + # delimiters :) + # XXX use gnutls for get proper + # validation. + # crypto.X509Certificate(cert_s) + sep = "-" * 5 + "BEGIN CERTIFICATE" + "-" * 5 + # we might have private key and cert in the same file + certparts = cert_s.split(sep) + if len(certparts) > 1: + cert_s = sep + certparts[1] + ssl.PEM_cert_to_DER_cert(cert_s) + except: + # XXX raise proper exception + raise + return True + + @property + def ca_cert_path(self): + return self._get_ca_cert_path(self.domain) + + def _get_root_uri(self): + return u"https://%s/" % self.domain + + def _get_client_cert_uri(self): + # XXX get the whole thing from constants + return "https://%s/1/cert" % self.domain + + def _get_client_cert_path(self): + return eipspecs.client_cert_path(domain=self.domain) + + def _get_ca_cert_path(self, domain): + # XXX this folder path will be broken for win + # and this should be moved to eipspecs.ca_path + + # XXX use baseconfig.get_provider_path(folder=Foo) + # !!! + + capath = baseconfig.get_config_file( + 'cacert.pem', + folder='providers/%s/keys/ca' % domain) + folder, fname = os.path.split(capath) + if not os.path.isdir(folder): + mkdir_p(folder) + return capath + + def write_cert(self, pemfile_content, to=None): + folder, filename = os.path.split(to) + if not os.path.isdir(folder): + mkdir_p(folder) + with open(to, 'w') as cert_f: + cert_f.write(pemfile_content) + + +class EIPConfigChecker(object): + """ + Several checks needed + to ensure a EIPConnection + can be sucessfully established. + use run_all to run all checks. + """ + + def __init__(self, fetcher=requests, domain=None): + # we do not want to accept too many + # argument on init. + # we want tests + # to be explicitely run. + + self.fetcher = fetcher + + # if not domain, get from config + self.domain = domain + + self.eipconfig = eipconfig.EIPConfig(domain=domain) + self.defaultprovider = providers.LeapProviderDefinition(domain=domain) + self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain) + + def run_all(self, checker=None, skip_download=False): + """ + runs all checks in a row. + will raise if some error encountered. + catching those exceptions is not + our responsibility at this moment + """ + if not checker: + checker = self + + # let's call all tests + # needed for a sane eip session. + + # TODO: get rid of check_default. + # check_complete should + # be enough. but here to make early tests easier. + checker.check_default_eipconfig() + + checker.check_is_there_default_provider() + checker.fetch_definition(skip_download=skip_download) + checker.fetch_eip_service_config(skip_download=skip_download) + checker.check_complete_eip_config() + #checker.ping_gateway() + + # public checks + + def check_default_eipconfig(self): + """ + checks if default eipconfig exists, + and dumps a default file if not + """ + # XXX ONLY a transient check + # because some old function still checks + # for eip config at the beginning. + + # it *really* does not make sense to + # dump it right now, we can get an in-memory + # config object and dump it to disk in a + # later moment + logger.debug('checking default eip config') + if not self._is_there_default_eipconfig(): + self._dump_default_eipconfig() + + def check_is_there_default_provider(self, config=None): + """ + raises EIPMissingDefaultProvider if no + default provider found on eip config. + This is catched by ui and runs FirstRunWizard (MVS+) + """ + if config is None: + config = self.eipconfig.config + logger.debug('checking default provider') + provider = config.get('provider', None) + if provider is None: + raise eipexceptions.EIPMissingDefaultProvider + # XXX raise also if malformed ProviderDefinition? + return True + + def fetch_definition(self, skip_download=False, + config=None, uri=None, + domain=None): + """ + fetches a definition file from server + """ + # TODO: + # - Implement diff + # - overwrite only if different. + # (attend to serial field different, for instance) + + logger.debug('fetching definition') + + if skip_download: + logger.debug('(fetching def skipped)') + return True + if config is None: + config = self.defaultprovider.config + if uri is None: + if not domain: + domain = config.get('provider', None) + uri = self._get_provider_definition_uri(domain=domain) + + # FIXME! Pass ca path verify!!! + # BUG #638 + # FIXME FIXME FIXME + self.defaultprovider.load( + from_uri=uri, + fetcher=self.fetcher, + verify=False) + self.defaultprovider.save() + + def fetch_eip_service_config(self, skip_download=False, + config=None, uri=None, domain=None): + if skip_download: + return True + if config is None: + config = self.eipserviceconfig.config + if uri is None: + if not domain: + domain = self.domain or config.get('provider', None) + uri = self._get_eip_service_uri(domain=domain) + + self.eipserviceconfig.load(from_uri=uri, fetcher=self.fetcher) + self.eipserviceconfig.save() + + def check_complete_eip_config(self, config=None): + # TODO check for gateway + if config is None: + config = self.eipconfig.config + try: + 'trying assertions' + assert 'provider' in config + assert config['provider'] is not None + # XXX assert there is gateway !! + except AssertionError: + raise eipexceptions.EIPConfigurationError + + # XXX TODO: + # We should WRITE eip config if missing or + # incomplete at this point + #self.eipconfig.save() + + # + # private helpers + # + + def _is_there_default_eipconfig(self): + return self.eipconfig.exists() + + def _dump_default_eipconfig(self): + self.eipconfig.save() + + def _get_provider_definition_uri(self, domain=None, path=None): + if domain is None: + domain = self.domain or baseconstants.DEFAULT_PROVIDER + if path is None: + path = baseconstants.DEFINITION_EXPECTED_PATH + uri = u"https://%s/%s" % (domain, path) + logger.debug('getting provider definition from %s' % uri) + return uri + + def _get_eip_service_uri(self, domain=None, path=None): + if domain is None: + domain = self.domain or baseconstants.DEFAULT_PROVIDER + if path is None: + path = eipconstants.EIP_SERVICE_EXPECTED_PATH + uri = "https://%s/%s" % (domain, path) + logger.debug('getting eip service file from %s', uri) + return uri diff --git a/src/leap/eip/conductor.py b/src/leap/eip/conductor.py deleted file mode 100644 index e3adadc4..00000000 --- a/src/leap/eip/conductor.py +++ /dev/null @@ -1,272 +0,0 @@ -""" -stablishes a vpn connection and monitors its state -""" -from __future__ import (division, unicode_literals, print_function) -#import threading -from functools import partial -import logging - -from leap.utils.coroutines import spawn_and_watch_process -from leap.baseapp.config import get_config, get_vpn_stdout_mockup -from leap.eip.vpnwatcher import EIPConnectionStatus, status_watcher -from leap.eip.vpnmanager import OpenVPNManager, ConnectionRefusedError - -logger = logging.getLogger(name=__name__) - - -# TODO Move exceptions to their own module - - -class ConnectionError(Exception): - """ - generic connection error - """ - pass - - -class EIPClientError(Exception): - """ - base EIPClient exception - """ - def __str__(self): - if len(self.args) >= 1: - return repr(self.args[0]) - else: - return ConnectionError - - -class UnrecoverableError(EIPClientError): - """ - we cannot do anything about it, sorry - """ - pass - - -class OpenVPNConnection(object): - """ - All related to invocation - of the openvpn binary - """ - # Connection Methods - - def __init__(self, config_file=None, watcher_cb=None): - #XXX FIXME - #change watcher_cb to line_observer - """ - :param config_file: configuration file to read from - :param watcher_cb: callback to be \ -called for each line in watched stdout - :param signal_map: dictionary of signal names and callables \ -to be triggered for each one of them. - :type config_file: str - :type watcher_cb: function - :type signal_map: dict - """ - # XXX get host/port from config - self.manager = OpenVPNManager() - - self.config_file = config_file - self.watcher_cb = watcher_cb - #self.signal_maps = signal_maps - - self.subp = None - self.watcher = None - - self.server = None - self.port = None - self.proto = None - - self.autostart = True - - self._get_config() - - def _set_command_mockup(self): - """ - sets command and args for a command mockup - that just mimics the output from the real thing - """ - command, args = get_vpn_stdout_mockup() - self.command, self.args = command, args - - def _get_config(self): - """ - retrieves the config options from defaults or - home file, or config file passed in command line. - """ - config = get_config(config_file=self.config_file) - self.config = config - - if config.has_option('openvpn', 'command'): - commandline = config.get('openvpn', 'command') - if commandline == "mockup": - self._set_command_mockup() - return - command_split = commandline.split(' ') - command = command_split[0] - if len(command_split) > 1: - args = command_split[1:] - else: - args = [] - self.command = command - #print("debug: command = %s" % command) - self.args = args - else: - self._set_command_mockup() - - if config.has_option('openvpn', 'autostart'): - autostart = config.get('openvpn', 'autostart') - self.autostart = autostart - - def _launch_openvpn(self): - """ - invocation of openvpn binaries in a subprocess. - """ - #XXX TODO: - #deprecate watcher_cb, - #use _only_ signal_maps instead - - if self.watcher_cb is not None: - linewrite_callback = self.watcher_cb - else: - #XXX get logger instead - linewrite_callback = lambda line: print('watcher: %s' % line) - - observers = (linewrite_callback, - partial(status_watcher, self.status)) - subp, watcher = spawn_and_watch_process( - self.command, - self.args, - observers=observers) - self.subp = subp - self.watcher = watcher - - conn_result = self.status.CONNECTED - return conn_result - - def _try_connection(self): - """ - attempts to connect - """ - if self.subp is not None: - print('cowardly refusing to launch subprocess again') - return - self._launch_openvpn() - - def cleanup(self): - """ - terminates child subprocess - """ - if self.subp: - self.subp.terminate() - - -class EIPConductor(OpenVPNConnection): - """ - Manages the execution of the OpenVPN process, auto starts, monitors the - network connection, handles configuration, fixes leaky hosts, handles - errors, etc. - Preferences will be stored via the Storage API. (TBD) - Status updates (connected, bandwidth, etc) are signaled to the GUI. - """ - - def __init__(self, *args, **kwargs): - self.settingsfile = kwargs.get('settingsfile', None) - self.logfile = kwargs.get('logfile', None) - self.error_queue = [] - self.desired_con_state = None # ??? - - status_signals = kwargs.pop('status_signals', None) - self.status = EIPConnectionStatus(callbacks=status_signals) - - super(EIPConductor, self).__init__(*args, **kwargs) - - def connect(self): - """ - entry point for connection process - """ - self.manager.forget_errors() - self._try_connection() - # XXX should capture errors? - - def disconnect(self): - """ - disconnects client - """ - self._disconnect() - self.status.change_to(self.status.DISCONNECTED) - pass - - def shutdown(self): - """ - shutdown and quit - """ - self.desired_con_state = self.status.DISCONNECTED - - def connection_state(self): - """ - returns the current connection state - """ - return self.status.current - - def desired_connection_state(self): - """ - returns the desired_connection state - """ - return self.desired_con_state - - def poll_connection_state(self): - """ - """ - try: - state = self.manager.get_connection_state() - except ConnectionRefusedError: - # connection refused. might be not ready yet. - return - if not state: - return - (ts, status_step, - ok, ip, remote) = state - self.status.set_vpn_state(status_step) - status_step = self.status.get_readable_status() - return (ts, status_step, ok, ip, remote) - - def get_icon_name(self): - """ - get icon name from status object - """ - return self.status.get_state_icon() - - # - # private methods - # - - def _disconnect(self): - """ - private method for disconnecting - """ - if self.subp is not None: - self.subp.terminate() - self.subp = None - # XXX signal state changes! :) - - def _is_alive(self): - """ - don't know yet - """ - pass - - def _connect(self): - """ - entry point for connection cascade methods. - """ - #conn_result = ConState.DISCONNECTED - try: - conn_result = self._try_connection() - except UnrecoverableError as except_msg: - logger.error("FATAL: %s" % unicode(except_msg)) - conn_result = self.status.UNRECOVERABLE - except Exception as except_msg: - self.error_queue.append(except_msg) - logger.error("Failed Connection: %s" % - unicode(except_msg)) - return conn_result diff --git a/src/leap/eip/config.py b/src/leap/eip/config.py new file mode 100644 index 00000000..42c00380 --- /dev/null +++ b/src/leap/eip/config.py @@ -0,0 +1,303 @@ +import logging +import os +import platform +import tempfile + +from leap import __branding as BRANDING +from leap import certs +from leap.util.fileutil import (which, mkdir_p, check_and_fix_urw_only) + +from leap.base import config as baseconfig +from leap.baseapp.permcheck import (is_pkexec_in_system, + is_auth_agent_running) +from leap.eip import exceptions as eip_exceptions +from leap.eip import specs as eipspecs + +logger = logging.getLogger(name=__name__) +provider_ca_file = BRANDING.get('provider_ca_file', None) + + +class EIPConfig(baseconfig.JSONLeapConfig): + spec = eipspecs.eipconfig_spec + + def _get_slug(self): + eipjsonpath = baseconfig.get_config_file( + 'eip.json') + return eipjsonpath + + def _set_slug(self, *args, **kwargs): + raise AttributeError("you cannot set slug") + + slug = property(_get_slug, _set_slug) + + +class EIPServiceConfig(baseconfig.JSONLeapConfig): + spec = eipspecs.eipservice_config_spec + + def _get_slug(self): + domain = getattr(self, 'domain', None) + if domain: + path = baseconfig.get_provider_path(domain) + else: + path = baseconfig.get_default_provider_path() + return baseconfig.get_config_file( + 'eip-service.json', folder=path) + + def _set_slug(self): + raise AttributeError("you cannot set slug") + + slug = property(_get_slug, _set_slug) + + +def get_socket_path(): + socket_path = os.path.join( + tempfile.mkdtemp(prefix="leap-tmp"), + 'openvpn.socket') + logger.debug('socket path: %s', socket_path) + return socket_path + + +def get_eip_gateway(provider=None): + """ + return the first host in eip service config + that matches the name defined in the eip.json config + file. + """ + placeholder = "testprovider.example.org" + # XXX check for null on provider?? + + eipconfig = EIPConfig(domain=provider) + eipconfig.load() + conf = eipconfig.config + + primary_gateway = conf.get('primary_gateway', None) + if not primary_gateway: + return placeholder + + eipserviceconfig = EIPServiceConfig(domain=provider) + eipserviceconfig.load() + eipsconf = eipserviceconfig.get_config() + gateways = eipsconf.get('gateways', None) + if not gateways: + logger.error('missing gateways in eip service config') + return placeholder + if len(gateways) > 0: + for gw in gateways: + name = gw.get('name', None) + if not name: + return + + if name == primary_gateway: + hosts = gw.get('hosts', None) + if not hosts: + logger.error('no hosts') + return + if len(hosts) > 0: + return hosts[0] + else: + logger.error('no hosts') + logger.error('could not find primary gateway in provider' + 'gateway list') + + +def build_ovpn_options(daemon=False, socket_path=None, **kwargs): + """ + build a list of options + to be passed in the + openvpn invocation + @rtype: list + @rparam: options + """ + # XXX review which of the + # options we don't need. + + # TODO pass also the config file, + # since we will need to take some + # things from there if present. + + provider = kwargs.pop('provider', None) + + # get user/group name + # also from config. + user = baseconfig.get_username() + group = baseconfig.get_groupname() + + opts = [] + + opts.append('--client') + + opts.append('--dev') + # XXX same in win? + opts.append('tun') + opts.append('--persist-tun') + opts.append('--persist-key') + + verbosity = kwargs.get('ovpn_verbosity', None) + if verbosity and 1 <= verbosity <= 6: + opts.append('--verb') + opts.append("%s" % verbosity) + + # remote + opts.append('--remote') + gw = get_eip_gateway(provider=provider) + logger.debug('setting eip gateway to %s', gw) + opts.append(str(gw)) + opts.append('1194') + #opts.append('80') + opts.append('udp') + + opts.append('--tls-client') + opts.append('--remote-cert-tls') + opts.append('server') + + # set user and group + opts.append('--user') + opts.append('%s' % user) + opts.append('--group') + opts.append('%s' % group) + + opts.append('--management-client-user') + opts.append('%s' % user) + opts.append('--management-signal') + + # set default options for management + # interface. unix sockets or telnet interface for win. + # XXX take them from the config object. + + ourplatform = platform.system() + if ourplatform in ("Linux", "Mac"): + opts.append('--management') + + if socket_path is None: + socket_path = get_socket_path() + opts.append(socket_path) + opts.append('unix') + + if ourplatform == "Windows": + opts.append('--management') + opts.append('localhost') + # XXX which is a good choice? + opts.append('7777') + + # certs + client_cert_path = eipspecs.client_cert_path(provider) + ca_cert_path = eipspecs.provider_ca_path(provider) + + opts.append('--cert') + opts.append(client_cert_path) + opts.append('--key') + opts.append(client_cert_path) + opts.append('--ca') + opts.append(ca_cert_path) + + # we cannot run in daemon mode + # with the current subp setting. + # see: https://leap.se/code/issues/383 + #if daemon is True: + #opts.append('--daemon') + + logger.debug('vpn options: %s', opts) + return opts + + +def build_ovpn_command(debug=False, do_pkexec_check=True, vpnbin=None, + socket_path=None, **kwargs): + """ + build a string with the + complete openvpn invocation + + @rtype [string, [list of strings]] + @rparam: a list containing the command string + and a list of options. + """ + command = [] + use_pkexec = True + ovpn = None + + # XXX get use_pkexec from config instead. + + if platform.system() == "Linux" and use_pkexec and do_pkexec_check: + + # check for both pkexec + # AND a suitable authentication + # agent running. + logger.info('use_pkexec set to True') + + if not is_pkexec_in_system(): + logger.error('no pkexec in system') + raise eip_exceptions.EIPNoPkexecAvailable + + if not is_auth_agent_running(): + logger.warning( + "no polkit auth agent found. " + "pkexec will use its own text " + "based authentication agent. " + "that's probably a bad idea") + raise eip_exceptions.EIPNoPolkitAuthAgentAvailable + + command.append('pkexec') + if vpnbin is None: + ovpn = which('openvpn') + else: + ovpn = vpnbin + if ovpn: + vpn_command = ovpn + else: + vpn_command = "openvpn" + command.append(vpn_command) + daemon_mode = not debug + + for opt in build_ovpn_options(daemon=daemon_mode, socket_path=socket_path, + **kwargs): + command.append(opt) + + # XXX check len and raise proper error + + return [command[0], command[1:]] + + +def check_vpn_keys(provider=None): + """ + performs an existance and permission check + over the openvpn keys file. + Currently we're expecting a single file + per provider, containing the CA cert, + the provider key, and our client certificate + """ + assert provider is not None + provider_ca = eipspecs.provider_ca_path(provider) + client_cert = eipspecs.client_cert_path(provider) + + logger.debug('provider ca = %s', provider_ca) + logger.debug('client cert = %s', client_cert) + + # if no keys, raise error. + # it's catched by the ui and signal user. + + if not os.path.isfile(provider_ca): + # not there. let's try to copy. + folder, filename = os.path.split(provider_ca) + if not os.path.isdir(folder): + mkdir_p(folder) + if provider_ca_file: + cacert = certs.where(provider_ca_file) + with open(provider_ca, 'w') as pca: + with open(cacert, 'r') as cac: + pca.write(cac.read()) + + if not os.path.isfile(provider_ca): + logger.error('key file %s not found. aborting.', + provider_ca) + raise eip_exceptions.EIPInitNoKeyFileError + + if not os.path.isfile(client_cert): + logger.error('key file %s not found. aborting.', + client_cert) + raise eip_exceptions.EIPInitNoKeyFileError + + for keyfile in (provider_ca, client_cert): + # bad perms? try to fix them + try: + check_and_fix_urw_only(keyfile) + except OSError: + raise eip_exceptions.EIPInitBadKeyFilePermError diff --git a/src/leap/eip/constants.py b/src/leap/eip/constants.py new file mode 100644 index 00000000..9af5a947 --- /dev/null +++ b/src/leap/eip/constants.py @@ -0,0 +1,3 @@ +# not used anymore with the new JSONConfig.slug +EIP_CONFIG = "eip.json" +EIP_SERVICE_EXPECTED_PATH = "1/config/eip-service.json" diff --git a/src/leap/eip/eipconnection.py b/src/leap/eip/eipconnection.py new file mode 100644 index 00000000..7828c864 --- /dev/null +++ b/src/leap/eip/eipconnection.py @@ -0,0 +1,350 @@ +""" +EIP Connection Class +""" +from __future__ import (absolute_import,) +import logging +import Queue +import sys + +from leap.eip.checks import ProviderCertChecker +from leap.eip.checks import EIPConfigChecker +from leap.eip import config as eipconfig +from leap.eip import exceptions as eip_exceptions +from leap.eip.openvpnconnection import OpenVPNConnection + +logger = logging.getLogger(name=__name__) + + +class EIPConnection(OpenVPNConnection): + """ + Manages the execution of the OpenVPN process, auto starts, monitors the + network connection, handles configuration, fixes leaky hosts, handles + errors, etc. + Status updates (connected, bandwidth, etc) are signaled to the GUI. + """ + + def __init__(self, + provider_cert_checker=ProviderCertChecker, + config_checker=EIPConfigChecker, + *args, **kwargs): + self.settingsfile = kwargs.get('settingsfile', None) + self.logfile = kwargs.get('logfile', None) + self.provider = kwargs.pop('provider', None) + self._providercertchecker = provider_cert_checker + self._configchecker = config_checker + + self.error_queue = Queue.Queue() + + status_signals = kwargs.pop('status_signals', None) + self.status = EIPConnectionStatus(callbacks=status_signals) + + checker_signals = kwargs.pop('checker_signals', None) + self.checker_signals = checker_signals + + self.init_checkers() + + host = eipconfig.get_socket_path() + kwargs['host'] = host + + super(EIPConnection, self).__init__(*args, **kwargs) + + def has_errors(self): + return True if self.error_queue.qsize() != 0 else False + + def init_checkers(self): + # initialize checkers + self.provider_cert_checker = self._providercertchecker( + domain=self.provider) + self.config_checker = self._configchecker(domain=self.provider) + + def set_provider_domain(self, domain): + """ + sets the provider domain. + used from the first run wizard when we launch the run_checks + and connect process after having initialized the conductor. + """ + # This looks convoluted, right. + # We have to reinstantiate checkers cause we're passing + # the domain param that we did not know at the beginning + # (only for the firstrunwizard case) + self.provider = domain + self.init_checkers() + + def run_checks(self, skip_download=False, skip_verify=False): + """ + run all eip checks previous to attempting a connection + """ + logger.debug('running conductor checks') + + def push_err(exc): + # keep the original traceback! + exc_traceback = sys.exc_info()[2] + self.error_queue.put((exc, exc_traceback)) + + try: + # network (1) + if self.checker_signals: + for signal in self.checker_signals: + signal('checking encryption keys') + self.provider_cert_checker.run_all(skip_verify=skip_verify) + except Exception as exc: + push_err(exc) + try: + if self.checker_signals: + for signal in self.checker_signals: + signal('checking provider config') + self.config_checker.run_all(skip_download=skip_download) + except Exception as exc: + push_err(exc) + try: + self.run_openvpn_checks() + except Exception as exc: + push_err(exc) + + def connect(self): + """ + entry point for connection process + """ + #self.forget_errors() + self._try_connection() + + def disconnect(self): + """ + disconnects client + """ + self.cleanup() + logger.debug("disconnect: clicked.") + self.status.change_to(self.status.DISCONNECTED) + + #def shutdown(self): + #""" + #shutdown and quit + #""" + #self.desired_con_state = self.status.DISCONNECTED + + def connection_state(self): + """ + returns the current connection state + """ + return self.status.current + + def poll_connection_state(self): + """ + """ + try: + state = self.get_connection_state() + except eip_exceptions.ConnectionRefusedError: + # connection refused. might be not ready yet. + logger.warning('connection refused') + return + if not state: + logger.debug('no state') + return + (ts, status_step, + ok, ip, remote) = state + self.status.set_vpn_state(status_step) + status_step = self.status.get_readable_status() + return (ts, status_step, ok, ip, remote) + + def get_icon_name(self): + """ + get icon name from status object + """ + return self.status.get_state_icon() + + def get_leap_status(self): + return self.status.get_leap_status() + + # + # private methods + # + + #def _disconnect(self): + # """ + # private method for disconnecting + # """ + # if self.subp is not None: + # logger.debug('disconnecting...') + # self.subp.terminate() + # self.subp = None + + #def _is_alive(self): + #""" + #don't know yet + #""" + #pass + + def _connect(self): + """ + entry point for connection cascade methods. + """ + try: + conn_result = self._try_connection() + except eip_exceptions.UnrecoverableError as except_msg: + logger.error("FATAL: %s" % unicode(except_msg)) + conn_result = self.status.UNRECOVERABLE + + # XXX enqueue exceptions themselves instead? + except Exception as except_msg: + self.error_queue.append(except_msg) + logger.error("Failed Connection: %s" % + unicode(except_msg)) + return conn_result + + +class EIPConnectionStatus(object): + """ + Keep track of client (gui) and openvpn + states. + + These are the OpenVPN states: + CONNECTING -- OpenVPN's initial state. + WAIT -- (Client only) Waiting for initial response + from server. + AUTH -- (Client only) Authenticating with server. + GET_CONFIG -- (Client only) Downloading configuration options + from server. + ASSIGN_IP -- Assigning IP address to virtual network + interface. + ADD_ROUTES -- Adding routes to system. + CONNECTED -- Initialization Sequence Completed. + RECONNECTING -- A restart has occurred. + EXITING -- A graceful exit is in progress. + + We add some extra states: + + DISCONNECTED -- GUI initial state. + UNRECOVERABLE -- An unrecoverable error has been raised + while invoking openvpn service. + """ + CONNECTING = 1 + WAIT = 2 + AUTH = 3 + GET_CONFIG = 4 + ASSIGN_IP = 5 + ADD_ROUTES = 6 + CONNECTED = 7 + RECONNECTING = 8 + EXITING = 9 + + # gui specific states: + UNRECOVERABLE = 11 + DISCONNECTED = 0 + + def __init__(self, callbacks=None): + """ + EIPConnectionStatus is initialized with a tuple + of signals to be triggered. + :param callbacks: a tuple of (callable) observers + :type callbacks: tuple + """ + self.current = self.DISCONNECTED + self.previous = None + # (callbacks to connect to signals in Qt-land) + self.callbacks = callbacks + + def get_readable_status(self): + # XXX DRY status / labels a little bit. + # think we'll want to i18n this. + human_status = { + 0: 'disconnected', + 1: 'connecting', + 2: 'waiting', + 3: 'authenticating', + 4: 'getting config', + 5: 'assigning ip', + 6: 'adding routes', + 7: 'connected', + 8: 'reconnecting', + 9: 'exiting', + 11: 'unrecoverable error', + } + return human_status[self.current] + + def get_leap_status(self): + # XXX improve nomenclature + leap_status = { + 0: 'disconnected', + 1: 'connecting to gateway', + 2: 'connecting to gateway', + 3: 'authenticating', + 4: 'establishing network encryption', + 5: 'establishing network encryption', + 6: 'establishing network encryption', + 7: 'connected', + 8: 'reconnecting', + 9: 'exiting', + 11: 'unrecoverable error', + } + return leap_status[self.current] + + def get_state_icon(self): + """ + returns the high level icon + for each fine-grain openvpn state + """ + connecting = (self.CONNECTING, + self.WAIT, + self.AUTH, + self.GET_CONFIG, + self.ASSIGN_IP, + self.ADD_ROUTES) + connected = (self.CONNECTED,) + disconnected = (self.DISCONNECTED, + self.UNRECOVERABLE) + + # this can be made smarter, + # but it's like it'll change, + # so +readability. + + if self.current in connecting: + return "connecting" + if self.current in connected: + return "connected" + if self.current in disconnected: + return "disconnected" + + def set_vpn_state(self, status): + """ + accepts a state string from the management + interface, and sets the internal state. + :param status: openvpn STATE (uppercase). + :type status: str + """ + if hasattr(self, status): + self.change_to(getattr(self, status)) + + def set_current(self, to): + """ + setter for the 'current' property + :param to: destination state + :type to: int + """ + self.current = to + + def change_to(self, to): + """ + :param to: destination state + :type to: int + """ + if to == self.current: + return + changed = False + from_ = self.current + self.current = to + + # We can add transition restrictions + # here to ensure no transitions are + # allowed outside the fsm. + + self.set_current(to) + changed = True + + #trigger signals (as callbacks) + #print('current state: %s' % self.current) + if changed: + self.previous = from_ + if self.callbacks: + for cb in self.callbacks: + if callable(cb): + cb(self) diff --git a/src/leap/eip/exceptions.py b/src/leap/eip/exceptions.py new file mode 100644 index 00000000..41eed77a --- /dev/null +++ b/src/leap/eip/exceptions.py @@ -0,0 +1,156 @@ +""" +Generic error hierarchy +Leap/EIP exceptions used for exception handling, +logging, and notifying user of errors +during leap operation. + +Exception hierarchy +------------------- +All EIP Errors must inherit from EIPClientError (note: move that to +a more generic LEAPClientBaseError). + +Exception attributes and their meaning/uses +------------------------------------------- + +* critical: if True, will abort execution prematurely, + after attempting any cleaning + action. + +* failfirst: breaks any error_check loop that is examining + the error queue. + +* message: the message that will be used in the __repr__ of the exception. + +* usermessage: the message that will be passed to user in ErrorDialogs + in Qt-land. + +TODO: + +* EIPClientError: + Should inherit from LeapException + +* gettext / i18n for user messages. + +""" +from leap.base.exceptions import LeapException + + +# This should inherit from LeapException +class EIPClientError(Exception): + """ + base EIPClient exception + """ + critical = False + failfirst = False + warning = False + + +class CriticalError(EIPClientError): + """ + we cannot do anything about it, sorry + """ + critical = True + failfirst = True + + +class Warning(EIPClientError): + """ + just that, warnings + """ + warning = True + + +class EIPNoPolkitAuthAgentAvailable(CriticalError): + message = "No polkit authentication agent could be found" + usermessage = ("We could not find any authentication " + "agent in your system.<br/>" + "Make sure you have " + "<b>polkit-gnome-authentication-agent-1</b> " + "running and try again.") + + +class EIPNoPkexecAvailable(Warning): + message = "No pkexec binary found" + usermessage = ("We could not find <b>pkexec</b> in your " + "system.<br/> Do you want to try " + "<b>setuid workaround</b>? " + "(<i>DOES NOTHING YET</i>)") + failfirst = True + + +class EIPNoCommandError(EIPClientError): + message = "no suitable openvpn command found" + usermessage = ("No suitable openvpn command found. " + "<br/>(Might be a permissions problem)") + + +class EIPBadCertError(Warning): + # XXX this should be critical and fail close + message = "cert verification failed" + usermessage = "there is a problem with provider certificate" + + +class LeapBadConfigFetchedError(Warning): + message = "provider sent a malformed json file" + usermessage = "an error occurred during configuratio of leap services" + + +class OpenVPNAlreadyRunning(EIPClientError): + message = "Another OpenVPN Process is already running." + usermessage = ("Another OpenVPN Process has been detected." + "Please close it before starting leap-client") + + +class HttpsNotSupported(LeapException): + message = "connection refused while accessing via https" + usermessage = "Server does not allow secure connections." + + +class HttpsBadCertError(LeapException): + message = "verification error on cert" + usermessage = "Server certificate could not be verified." + +# +# errors still needing some love +# + + +class EIPInitNoKeyFileError(CriticalError): + message = "No vpn keys found in the expected path" + usermessage = "We could not find your eip certs in the expected path" + + +class EIPInitBadKeyFilePermError(Warning): + # I don't know if we should be telling user or not, + # we try to fix permissions and should only re-raise + # if permission check failed. + pass + + +class EIPInitNoProviderError(EIPClientError): + pass + + +class EIPInitBadProviderError(EIPClientError): + pass + + +class EIPConfigurationError(EIPClientError): + pass + +# +# Errors that probably we don't need anymore +# chase down for them and check. +# + + +class MissingSocketError(Exception): + pass + + +class ConnectionRefusedError(Exception): + pass + + +class EIPMissingDefaultProvider(Exception): + pass diff --git a/src/leap/eip/openvpnconnection.py b/src/leap/eip/openvpnconnection.py new file mode 100644 index 00000000..859378c0 --- /dev/null +++ b/src/leap/eip/openvpnconnection.py @@ -0,0 +1,460 @@ +""" +OpenVPN Connection +""" +from __future__ import (print_function) +import logging +import os +import psutil +import shutil +import socket +import time +from functools import partial + +logger = logging.getLogger(name=__name__) + +from leap.base.connection import Connection +from leap.util.coroutines import spawn_and_watch_process + +from leap.eip.udstelnet import UDSTelnet +from leap.eip import config as eip_config +from leap.eip import exceptions as eip_exceptions + + +class OpenVPNConnection(Connection): + """ + All related to invocation + of the openvpn binary + """ + + def __init__(self, + watcher_cb=None, + debug=False, + host=None, + port="unix", + password=None, + *args, **kwargs): + """ + :param config_file: configuration file to read from + :param watcher_cb: callback to be \ +called for each line in watched stdout + :param signal_map: dictionary of signal names and callables \ +to be triggered for each one of them. + :type config_file: str + :type watcher_cb: function + :type signal_map: dict + """ + #XXX FIXME + #change watcher_cb to line_observer + + logger.debug('init openvpn connection') + self.debug = debug + # XXX if not host: raise ImproperlyConfigured + self.ovpn_verbosity = kwargs.get('ovpn_verbosity', None) + + #self.config_file = config_file + self.watcher_cb = watcher_cb + #self.signal_maps = signal_maps + + self.subp = None + self.watcher = None + + self.server = None + self.port = None + self.proto = None + + #XXX workaround for signaling + #the ui that we don't know how to + #manage a connection error + #self.with_errors = False + + self.command = None + self.args = None + + # XXX get autostart from config + self.autostart = True + + # + # management init methods + # + + self.host = host + if isinstance(port, str) and port.isdigit(): + port = int(port) + elif port == "unix": + port = "unix" + else: + port = None + self.port = port + self.password = password + + def run_openvpn_checks(self): + logger.debug('running openvpn checks') + self._check_if_running_instance() + self._set_ovpn_command() + self._check_vpn_keys() + + def _set_ovpn_command(self): + # XXX check also for command-line --command flag + try: + command, args = eip_config.build_ovpn_command( + provider=self.provider, + debug=self.debug, + socket_path=self.host, + ovpn_verbosity=self.ovpn_verbosity) + except eip_exceptions.EIPNoPolkitAuthAgentAvailable: + command = args = None + raise + except eip_exceptions.EIPNoPkexecAvailable: + command = args = None + raise + + # XXX if not command, signal error. + self.command = command + self.args = args + + def _check_vpn_keys(self): + """ + checks for correct permissions on vpn keys + """ + try: + eip_config.check_vpn_keys(provider=self.provider) + except eip_exceptions.EIPInitBadKeyFilePermError: + logger.error('Bad VPN Keys permission!') + # do nothing now + # and raise the rest ... + + def _launch_openvpn(self): + """ + invocation of openvpn binaries in a subprocess. + """ + #XXX TODO: + #deprecate watcher_cb, + #use _only_ signal_maps instead + + logger.debug('_launch_openvpn called') + if self.watcher_cb is not None: + linewrite_callback = self.watcher_cb + else: + #XXX get logger instead + linewrite_callback = lambda line: print('watcher: %s' % line) + + # the partial is not + # being applied now because we're not observing the process + # stdout like we did in the early stages. but I leave it + # here since it will be handy for observing patterns in the + # thru-the-manager updates (with regex) + observers = (linewrite_callback, + partial(lambda con_status, line: None, self.status)) + subp, watcher = spawn_and_watch_process( + self.command, + self.args, + observers=observers) + self.subp = subp + self.watcher = watcher + + def _try_connection(self): + """ + attempts to connect + """ + if self.command is None: + raise eip_exceptions.EIPNoCommandError + if self.subp is not None: + logger.debug('cowardly refusing to launch subprocess again') + + self._launch_openvpn() + + def _check_if_running_instance(self): + """ + check if openvpn is already running + """ + for process in psutil.get_process_list(): + if process.name == "openvpn": + logger.debug('an openvpn instance is already running.') + logger.debug('attempting to stop openvpn instance.') + if not self._stop(): + raise eip_exceptions.OpenVPNAlreadyRunning + + logger.debug('no openvpn instance found.') + + def cleanup(self): + """ + terminates openvpn child subprocess + """ + if self.subp: + try: + self._stop() + except eip_exceptions.ConnectionRefusedError: + logger.warning( + 'unable to send sigterm signal to openvpn: ' + 'connection refused.') + + # XXX kali -- + # XXX review-me + # I think this will block if child process + # does not return. + # Maybe we can .poll() for a given + # interval and exit in any case. + + RETCODE = self.subp.wait() + if RETCODE: + logger.error( + 'cannot terminate subprocess! Retcode %s' + '(We might have left openvpn running)' % RETCODE) + + self.cleanup_tempfiles() + + def cleanup_tempfiles(self): + """ + remove all temporal files + we might have left behind + """ + # if self.port is 'unix', we have + # created a temporal socket path that, under + # normal circumstances, we should be able to + # delete + + if self.port == "unix": + logger.debug('cleaning socket file temp folder') + + tempfolder = os.path.split(self.host)[0] + if os.path.isdir(tempfolder): + try: + shutil.rmtree(tempfolder) + except OSError: + logger.error('could not delete tmpfolder %s' % tempfolder) + + def _get_openvpn_process(self): + # plist = [p for p in psutil.get_process_list() if p.name == "openvpn"] + # return plist[0] if plist else None + for process in psutil.get_process_list(): + if process.name == "openvpn": + return process + return None + + # management methods + # + # XXX REVIEW-ME + # REFACTOR INFO: (former "manager". + # Can we move to another + # base class to test independently?) + # + + #def forget_errors(self): + #logger.debug('forgetting errors') + #self.with_errors = False + + def connect_to_management(self): + """Connect to openvpn management interface""" + #logger.debug('connecting socket') + if hasattr(self, 'tn'): + self.close() + self.tn = UDSTelnet(self.host, self.port) + + # XXX make password optional + # specially for win. we should generate + # the pass on the fly when invoking manager + # from conductor + + #self.tn.read_until('ENTER PASSWORD:', 2) + #self.tn.write(self.password + '\n') + #self.tn.read_until('SUCCESS:', 2) + if self.tn: + self._seek_to_eof() + return True + + def _seek_to_eof(self): + """ + Read as much as available. Position seek pointer to end of stream + """ + try: + b = self.tn.read_eager() + except EOFError: + logger.debug("Could not read from socket. Assuming it died.") + return + while b: + try: + b = self.tn.read_eager() + except EOFError: + logger.debug("Could not read from socket. Assuming it died.") + + def connected(self): + """ + Returns True if connected + rtype: bool + """ + return hasattr(self, 'tn') + + def close(self, announce=True): + """ + Close connection to openvpn management interface + """ + logger.debug('closing socket') + if announce: + self.tn.write("quit\n") + self.tn.read_all() + self.tn.get_socket().close() + del self.tn + + def _send_command(self, cmd): + """ + Send a command to openvpn and return response as list + """ + if not self.connected(): + try: + self.connect_to_management() + except eip_exceptions.MissingSocketError: + logger.warning('missing management socket') + return [] + try: + if hasattr(self, 'tn'): + self.tn.write(cmd + "\n") + except socket.error: + logger.error('socket error') + self.close(announce=False) + return [] + buf = self.tn.read_until(b"END", 2) + self._seek_to_eof() + blist = buf.split('\r\n') + if blist[-1].startswith('END'): + del blist[-1] + return blist + else: + return [] + + def _send_short_command(self, cmd): + """ + parse output from commands that are + delimited by "success" instead + """ + if not self.connected(): + self.connect() + self.tn.write(cmd + "\n") + # XXX not working? + buf = self.tn.read_until(b"SUCCESS", 2) + self._seek_to_eof() + blist = buf.split('\r\n') + return blist + + # + # useful vpn commands + # + + def pid(self): + #XXX broken + return self._send_short_command("pid") + + def make_error(self): + """ + capture error and wrap it in an + understandable format + """ + #XXX get helpful error codes + self.with_errors = True + now = int(time.time()) + return '%s,LAUNCHER ERROR,ERROR,-,-' % now + + def state(self): + """ + OpenVPN command: state + """ + state = self._send_command("state") + if not state: + return None + if isinstance(state, str): + return state + if isinstance(state, list): + if len(state) == 1: + return state[0] + else: + return state[-1] + + def vpn_status(self): + """ + OpenVPN command: status + """ + #logger.debug('status called') + status = self._send_command("status") + return status + + def vpn_status2(self): + """ + OpenVPN command: last 2 statuses + """ + return self._send_command("status 2") + + def _stop(self): + """ + stop openvpn process + by sending SIGTERM to the management + interface + """ + logger.debug("disconnecting...") + if self.connected(): + try: + self._send_command("signal SIGTERM\n") + except socket.error: + logger.warning('management socket died') + return + + if self.subp: + # ??? + return True + + #shutting openvpn failured + #try patching in old openvpn host and trying again + process = self._get_openvpn_process() + if process: + logger.debug('process :%s' % process) + cmdline = process.cmdline + + if isinstance(cmdline, list): + _index = cmdline.index("--management") + self.host = cmdline[_index + 1] + self._send_command("signal SIGTERM\n") + + #make sure the process was terminated + process = self._get_openvpn_process() + if not process: + logger.debug("Existing OpenVPN Process Terminated") + return True + else: + logger.error("Unable to terminate existing OpenVPN Process.") + return False + + return True + + # + # parse info + # + + def get_status_io(self): + status = self.vpn_status() + if isinstance(status, str): + lines = status.split('\n') + if isinstance(status, list): + lines = status + try: + (header, when, tun_read, tun_write, + tcp_read, tcp_write, auth_read) = tuple(lines) + except ValueError: + return None + + when_ts = time.strptime(when.split(',')[1], "%a %b %d %H:%M:%S %Y") + sep = ',' + # XXX cleanup! + tun_read = tun_read.split(sep)[1] + tun_write = tun_write.split(sep)[1] + tcp_read = tcp_read.split(sep)[1] + tcp_write = tcp_write.split(sep)[1] + auth_read = auth_read.split(sep)[1] + + # XXX this could be a named tuple. prettier. + return when_ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read) + + def get_connection_state(self): + state = self.state() + if state is not None: + ts, status_step, ok, ip, remote = state.split(',') + ts = time.gmtime(float(ts)) + # XXX this could be a named tuple. prettier. + return ts, status_step, ok, ip, remote diff --git a/src/leap/eip/specs.py b/src/leap/eip/specs.py new file mode 100644 index 00000000..57e7537b --- /dev/null +++ b/src/leap/eip/specs.py @@ -0,0 +1,124 @@ +from __future__ import (unicode_literals) +import os + +from leap import __branding +from leap.base import config as baseconfig + +# XXX move provider stuff to base config + +PROVIDER_CA_CERT = __branding.get( + 'provider_ca_file', + 'cacert.pem') + +provider_ca_path = lambda domain: str(os.path.join( + #baseconfig.get_default_provider_path(), + baseconfig.get_provider_path(domain), + 'keys', 'ca', + 'cacert.pem' +)) if domain else None + +default_provider_ca_path = lambda: str(os.path.join( + baseconfig.get_default_provider_path(), + 'keys', 'ca', + PROVIDER_CA_CERT +)) + +PROVIDER_DOMAIN = __branding.get('provider_domain', 'testprovider.example.org') + + +client_cert_path = lambda domain: unicode(os.path.join( + baseconfig.get_provider_path(domain), + 'keys', 'client', + 'openvpn.pem' +)) if domain else None + +default_client_cert_path = lambda: unicode(os.path.join( + baseconfig.get_default_provider_path(), + 'keys', 'client', + 'openvpn.pem' +)) + +eipconfig_spec = { + 'description': 'sample eipconfig', + 'type': 'object', + 'properties': { + 'provider': { + 'type': unicode, + 'default': u"%s" % PROVIDER_DOMAIN, + 'required': True, + }, + 'transport': { + 'type': unicode, + 'default': u"openvpn", + }, + 'openvpn_protocol': { + 'type': unicode, + 'default': u"tcp" + }, + 'openvpn_port': { + 'type': int, + 'default': 80 + }, + 'openvpn_ca_certificate': { + 'type': unicode, # path + 'default': default_provider_ca_path + }, + 'openvpn_client_certificate': { + 'type': unicode, # path + 'default': default_client_cert_path + }, + 'connect_on_login': { + 'type': bool, + 'default': True + }, + 'block_cleartext_traffic': { + 'type': bool, + 'default': True + }, + 'primary_gateway': { + 'type': unicode, + 'default': u"turkey", + #'required': True + }, + 'secondary_gateway': { + 'type': unicode, + 'default': u"france" + }, + 'management_password': { + 'type': unicode + } + } +} + +eipservice_config_spec = { + 'description': 'sample eip service config', + 'type': 'object', + 'properties': { + 'serial': { + 'type': int, + 'required': True, + 'default': 1 + }, + 'version': { + 'type': unicode, + 'required': True, + 'default': "0.1.0" + }, + 'capabilities': { + 'type': dict, + 'default': { + "transport": ["openvpn"], + "ports": ["80", "53"], + "protocols": ["udp", "tcp"], + "static_ips": True, + "adblock": True} + }, + 'gateways': { + 'type': list, + 'default': [{"country_code": "us", + "label": {"en":"west"}, + "capabilities": {}, + "hosts": ["1.2.3.4", "1.2.3.5"]}] + } + } +} diff --git a/src/leap/eip/tests/__init__.py b/src/leap/eip/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/eip/tests/__init__.py diff --git a/src/leap/eip/tests/data.py b/src/leap/eip/tests/data.py new file mode 100644 index 00000000..cadf720e --- /dev/null +++ b/src/leap/eip/tests/data.py @@ -0,0 +1,48 @@ +from __future__ import unicode_literals +import os + +#from leap import __branding + +# sample data used in tests + +#PROVIDER = __branding.get('provider_domain') +PROVIDER = "testprovider.example.org" + +EIP_SAMPLE_CONFIG = { + "provider": "%s" % PROVIDER, + "transport": "openvpn", + "openvpn_protocol": "tcp", + "openvpn_port": 80, + "openvpn_ca_certificate": os.path.expanduser( + "~/.config/leap/providers/" + "%s/" + "keys/ca/cacert.pem" % PROVIDER), + "openvpn_client_certificate": os.path.expanduser( + "~/.config/leap/providers/" + "%s/" + "keys/client/openvpn.pem" % PROVIDER), + "connect_on_login": True, + "block_cleartext_traffic": True, + "primary_gateway": "turkey", + "secondary_gateway": "france", + #"management_password": "oph7Que1othahwiech6J" +} + +EIP_SAMPLE_SERVICE = { + "serial": 1, + "version": "0.1.0", + "capabilities": { + "transport": ["openvpn"], + "ports": ["80", "53"], + "protocols": ["udp", "tcp"], + "static_ips": True, + "adblock": True + }, + "gateways": [ + {"country_code": "tr", + "name": "turkey", + "label": {"en":"Ankara, Turkey"}, + "capabilities": {}, + "hosts": ["192.0.43.10"]} + ] +} diff --git a/src/leap/eip/tests/test_checks.py b/src/leap/eip/tests/test_checks.py new file mode 100644 index 00000000..1d7bfc17 --- /dev/null +++ b/src/leap/eip/tests/test_checks.py @@ -0,0 +1,367 @@ +from BaseHTTPServer import BaseHTTPRequestHandler +import copy +import json +try: + import unittest2 as unittest +except ImportError: + import unittest +import os +import time +import urlparse + +from mock import (patch, Mock) + +import jsonschema +#import ping +import requests + +from leap.base import config as baseconfig +from leap.base.constants import (DEFAULT_PROVIDER_DEFINITION, + DEFINITION_EXPECTED_PATH) +from leap.eip import checks as eipchecks +from leap.eip import specs as eipspecs +from leap.eip import exceptions as eipexceptions +from leap.eip.tests import data as testdata +from leap.testing.basetest import BaseLeapTest +from leap.testing.https_server import BaseHTTPSServerTestCase +from leap.testing.https_server import where as where_cert + + +class NoLogRequestHandler: + def log_message(self, *args): + # don't write log msg to stderr + pass + + def read(self, n=None): + return '' + + +class EIPCheckTest(BaseLeapTest): + + __name__ = "eip_check_tests" + provider = "testprovider.example.org" + maxDiff = None + + def setUp(self): + pass + + def tearDown(self): + pass + + # test methods are there, and can be called from run_all + + def test_checker_should_implement_check_methods(self): + checker = eipchecks.EIPConfigChecker(domain=self.provider) + + self.assertTrue(hasattr(checker, "check_default_eipconfig"), + "missing meth") + self.assertTrue(hasattr(checker, "check_is_there_default_provider"), + "missing meth") + self.assertTrue(hasattr(checker, "fetch_definition"), "missing meth") + self.assertTrue(hasattr(checker, "fetch_eip_service_config"), + "missing meth") + self.assertTrue(hasattr(checker, "check_complete_eip_config"), + "missing meth") + + def test_checker_should_actually_call_all_tests(self): + checker = eipchecks.EIPConfigChecker(domain=self.provider) + + mc = Mock() + checker.run_all(checker=mc) + self.assertTrue(mc.check_default_eipconfig.called, "not called") + self.assertTrue(mc.check_is_there_default_provider.called, + "not called") + self.assertTrue(mc.fetch_definition.called, + "not called") + self.assertTrue(mc.fetch_eip_service_config.called, + "not called") + self.assertTrue(mc.check_complete_eip_config.called, + "not called") + + # test individual check methods + + def test_check_default_eipconfig(self): + checker = eipchecks.EIPConfigChecker(domain=self.provider) + # no eip config (empty home) + eipconfig_path = checker.eipconfig.filename + self.assertFalse(os.path.isfile(eipconfig_path)) + checker.check_default_eipconfig() + # we've written one, so it should be there. + self.assertTrue(os.path.isfile(eipconfig_path)) + with open(eipconfig_path, 'rb') as fp: + deserialized = json.load(fp) + + # force re-evaluation of the paths + # small workaround for evaluating home dirs correctly + EIP_SAMPLE_CONFIG = copy.copy(testdata.EIP_SAMPLE_CONFIG) + EIP_SAMPLE_CONFIG['openvpn_client_certificate'] = \ + eipspecs.client_cert_path(self.provider) + EIP_SAMPLE_CONFIG['openvpn_ca_certificate'] = \ + eipspecs.provider_ca_path(self.provider) + self.assertEqual(deserialized, EIP_SAMPLE_CONFIG) + + # TODO: shold ALSO run validation methods. + + def test_check_is_there_default_provider(self): + checker = eipchecks.EIPConfigChecker(domain=self.provider) + # we do dump a sample eip config, but lacking a + # default provider entry. + # This error will be possible catched in a different + # place, when JSONConfig does validation of required fields. + + # passing direct config + with self.assertRaises(eipexceptions.EIPMissingDefaultProvider): + checker.check_is_there_default_provider(config={}) + + # ok. now, messing with real files... + # blank out default_provider + sampleconfig = copy.copy(testdata.EIP_SAMPLE_CONFIG) + sampleconfig['provider'] = None + eipcfg_path = checker.eipconfig.filename + with open(eipcfg_path, 'w') as fp: + json.dump(sampleconfig, fp) + #with self.assertRaises(eipexceptions.EIPMissingDefaultProvider): + # XXX we should catch this as one of our errors, but do not + # see how to do it quickly. + with self.assertRaises(jsonschema.ValidationError): + #import ipdb;ipdb.set_trace() + checker.eipconfig.load(fromfile=eipcfg_path) + checker.check_is_there_default_provider() + + sampleconfig = testdata.EIP_SAMPLE_CONFIG + #eipcfg_path = checker._get_default_eipconfig_path() + with open(eipcfg_path, 'w') as fp: + json.dump(sampleconfig, fp) + checker.eipconfig.load() + self.assertTrue(checker.check_is_there_default_provider()) + + def test_fetch_definition(self): + with patch.object(requests, "get") as mocked_get: + mocked_get.return_value.status_code = 200 + mocked_get.return_value.json = DEFAULT_PROVIDER_DEFINITION + checker = eipchecks.EIPConfigChecker(fetcher=requests) + sampleconfig = testdata.EIP_SAMPLE_CONFIG + checker.fetch_definition(config=sampleconfig) + + fn = os.path.join(baseconfig.get_default_provider_path(), + DEFINITION_EXPECTED_PATH) + with open(fn, 'r') as fp: + deserialized = json.load(fp) + self.assertEqual(DEFAULT_PROVIDER_DEFINITION, deserialized) + + # XXX TODO check for ConnectionError, HTTPError, InvalidUrl + # (and proper EIPExceptions are raised). + # Look at base.test_config. + + def test_fetch_eip_service_config(self): + with patch.object(requests, "get") as mocked_get: + mocked_get.return_value.status_code = 200 + mocked_get.return_value.json = testdata.EIP_SAMPLE_SERVICE + checker = eipchecks.EIPConfigChecker(fetcher=requests) + sampleconfig = testdata.EIP_SAMPLE_CONFIG + checker.fetch_eip_service_config(config=sampleconfig) + + def test_check_complete_eip_config(self): + checker = eipchecks.EIPConfigChecker() + with self.assertRaises(eipexceptions.EIPConfigurationError): + sampleconfig = copy.copy(testdata.EIP_SAMPLE_CONFIG) + sampleconfig['provider'] = None + checker.check_complete_eip_config(config=sampleconfig) + with self.assertRaises(eipexceptions.EIPConfigurationError): + sampleconfig = copy.copy(testdata.EIP_SAMPLE_CONFIG) + del sampleconfig['provider'] + checker.check_complete_eip_config(config=sampleconfig) + + # normal case + sampleconfig = copy.copy(testdata.EIP_SAMPLE_CONFIG) + checker.check_complete_eip_config(config=sampleconfig) + + +class ProviderCertCheckerTest(BaseLeapTest): + + __name__ = "provider_cert_checker_tests" + provider = "testprovider.example.org" + + def setUp(self): + pass + + def tearDown(self): + pass + + # test methods are there, and can be called from run_all + + def test_checker_should_implement_check_methods(self): + checker = eipchecks.ProviderCertChecker() + + # For MVS+ + self.assertTrue(hasattr(checker, "download_ca_cert"), + "missing meth") + self.assertTrue(hasattr(checker, "download_ca_signature"), + "missing meth") + self.assertTrue(hasattr(checker, "get_ca_signatures"), "missing meth") + self.assertTrue(hasattr(checker, "is_there_trust_path"), + "missing meth") + + # For MVS + self.assertTrue(hasattr(checker, "is_there_provider_ca"), + "missing meth") + self.assertTrue(hasattr(checker, "is_https_working"), "missing meth") + self.assertTrue(hasattr(checker, "check_new_cert_needed"), + "missing meth") + + def test_checker_should_actually_call_all_tests(self): + checker = eipchecks.ProviderCertChecker() + + mc = Mock() + checker.run_all(checker=mc) + # XXX MVS+ + #self.assertTrue(mc.download_ca_cert.called, "not called") + #self.assertTrue(mc.download_ca_signature.called, "not called") + #self.assertTrue(mc.get_ca_signatures.called, "not called") + #self.assertTrue(mc.is_there_trust_path.called, "not called") + + # For MVS + self.assertTrue(mc.is_there_provider_ca.called, "not called") + self.assertTrue(mc.is_https_working.called, + "not called") + self.assertTrue(mc.check_new_cert_needed.called, + "not called") + + # test individual check methods + + @unittest.skip + def test_is_there_provider_ca(self): + # XXX commenting out this test. + # With the generic client this does not make sense, + # we should dump one there. + # or test conductor logic. + checker = eipchecks.ProviderCertChecker() + self.assertTrue( + checker.is_there_provider_ca()) + + +class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase, BaseLeapTest): + provider = "testprovider.example.org" + + class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler): + responses = { + '/': ['OK', ''], + '/client.cert': [ + # XXX get sample cert + '-----BEGIN CERTIFICATE-----', + '-----END CERTIFICATE-----'], + '/badclient.cert': [ + 'BADCERT']} + + def do_GET(self): + path = urlparse.urlparse(self.path) + message = '\n'.join(self.responses.get( + path.path, None)) + self.send_response(200) + self.end_headers() + self.wfile.write(message) + + def test_is_https_working(self): + fetcher = requests + uri = "https://%s/" % (self.get_server()) + # bare requests call. this should just pass (if there is + # an https service there). + fetcher.get(uri, verify=False) + checker = eipchecks.ProviderCertChecker(fetcher=fetcher) + self.assertTrue(checker.is_https_working(uri=uri, verify=False)) + + # for local debugs, when in doubt + #self.assertTrue(checker.is_https_working(uri="https://github.com", + #verify=True)) + + # for the two checks below, I know they fail because no ca + # cert is passed to them, and I know that's the error that + # requests return with our implementation. + # We're receiving this because our + # server is dying prematurely when the handshake is interrupted on the + # client side. + # Since we have access to the server, we could check that + # the error raised has been: + # SSL23_READ_BYTES: alert bad certificate + with self.assertRaises(requests.exceptions.SSLError) as exc: + fetcher.get(uri, verify=True) + self.assertTrue( + "SSL23_GET_SERVER_HELLO:unknown protocol" in exc.message) + + # XXX FIXME! Uncomment after #638 is done + #with self.assertRaises(eipexceptions.EIPBadCertError) as exc: + #checker.is_https_working(uri=uri, verify=True) + #self.assertTrue( + #"cert verification failed" in exc.message) + + # get cacert from testing.https_server + cacert = where_cert('cacert.pem') + fetcher.get(uri, verify=cacert) + self.assertTrue(checker.is_https_working(uri=uri, verify=cacert)) + + # same, but get cacert from leap.custom + # XXX TODO! + + @unittest.skip + def test_download_new_client_cert(self): + # FIXME + # Magick srp decorator broken right now... + # Have to mock the decorator and inject something that + # can bypass the authentication + + uri = "https://%s/client.cert" % (self.get_server()) + cacert = where_cert('cacert.pem') + checker = eipchecks.ProviderCertChecker(domain=self.provider) + credentials = "testuser", "testpassword" + self.assertTrue(checker.download_new_client_cert( + credentials=credentials, uri=uri, verify=cacert)) + + # now download a malformed cert + uri = "https://%s/badclient.cert" % (self.get_server()) + cacert = where_cert('cacert.pem') + checker = eipchecks.ProviderCertChecker() + with self.assertRaises(ValueError): + self.assertTrue(checker.download_new_client_cert( + credentials=credentials, uri=uri, verify=cacert)) + + # did we write cert to its path? + clientcertfile = eipspecs.client_cert_path() + self.assertTrue(os.path.isfile(clientcertfile)) + certfile = eipspecs.client_cert_path() + with open(certfile, 'r') as cf: + certcontent = cf.read() + self.assertEqual(certcontent, + '\n'.join( + self.request_handler.responses['/client.cert'])) + os.remove(clientcertfile) + + def test_is_cert_valid(self): + checker = eipchecks.ProviderCertChecker() + # TODO: better exception catching + # should raise eipexceptions.BadClientCertificate, and give reasons + # on msg. + with self.assertRaises(Exception) as exc: + self.assertFalse(checker.is_cert_valid()) + exc.message = "missing cert" + + def test_bad_validity_certs(self): + checker = eipchecks.ProviderCertChecker() + certfile = where_cert('leaptestscert.pem') + self.assertFalse(checker.is_cert_not_expired( + certfile=certfile, + now=lambda: time.mktime((2038, 1, 1, 1, 1, 1, 1, 1, 1)))) + self.assertFalse(checker.is_cert_not_expired( + certfile=certfile, + now=lambda: time.mktime((1970, 1, 1, 1, 1, 1, 1, 1, 1)))) + + def test_check_new_cert_needed(self): + # check: missing cert + checker = eipchecks.ProviderCertChecker(domain=self.provider) + self.assertTrue(checker.check_new_cert_needed(skip_download=True)) + # TODO check: malformed cert + # TODO check: expired cert + # TODO check: pass test server uri instead of skip + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/eip/tests/test_config.py b/src/leap/eip/tests/test_config.py new file mode 100644 index 00000000..50538240 --- /dev/null +++ b/src/leap/eip/tests/test_config.py @@ -0,0 +1,153 @@ +import json +import os +import platform +import stat + +try: + import unittest2 as unittest +except ImportError: + import unittest + +#from leap.base import constants +#from leap.eip import config as eip_config +from leap import __branding as BRANDING +from leap.eip import config as eipconfig +from leap.eip.tests.data import EIP_SAMPLE_CONFIG, EIP_SAMPLE_SERVICE +from leap.testing.basetest import BaseLeapTest +from leap.util.fileutil import mkdir_p + +_system = platform.system() + +#PROVIDER = BRANDING.get('provider_domain') +#PROVIDER_SHORTNAME = BRANDING.get('short_name') + + +class EIPConfigTest(BaseLeapTest): + + __name__ = "eip_config_tests" + provider = "testprovider.example.org" + + def setUp(self): + pass + + def tearDown(self): + pass + + # + # helpers + # + + def touch_exec(self): + path = os.path.join( + self.tempdir, 'bin') + mkdir_p(path) + tfile = os.path.join( + path, + 'openvpn') + open(tfile, 'wb').close() + os.chmod(tfile, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) + + def write_sample_eipservice(self): + conf = eipconfig.EIPServiceConfig() + folder, f = os.path.split(conf.filename) + if not os.path.isdir(folder): + mkdir_p(folder) + with open(conf.filename, 'w') as fd: + fd.write(json.dumps(EIP_SAMPLE_SERVICE)) + + def write_sample_eipconfig(self): + conf = eipconfig.EIPConfig() + folder, f = os.path.split(conf.filename) + if not os.path.isdir(folder): + mkdir_p(folder) + with open(conf.filename, 'w') as fd: + fd.write(json.dumps(EIP_SAMPLE_CONFIG)) + + def get_expected_openvpn_args(self): + args = [] + username = self.get_username() + groupname = self.get_groupname() + + args.append('--client') + args.append('--dev') + #does this have to be tap for win?? + args.append('tun') + args.append('--persist-tun') + args.append('--persist-key') + args.append('--remote') + args.append('%s' % eipconfig.get_eip_gateway( + provider=self.provider)) + # XXX get port!? + args.append('1194') + # XXX get proto + args.append('udp') + args.append('--tls-client') + args.append('--remote-cert-tls') + args.append('server') + + args.append('--user') + args.append(username) + args.append('--group') + args.append(groupname) + args.append('--management-client-user') + args.append(username) + args.append('--management-signal') + + args.append('--management') + #XXX hey! + #get platform switches here! + args.append('/tmp/test.socket') + args.append('unix') + + # certs + # XXX get values from specs? + args.append('--cert') + args.append(os.path.join( + self.home, + '.config', 'leap', 'providers', + '%s' % self.provider, + 'keys', 'client', + 'openvpn.pem')) + args.append('--key') + args.append(os.path.join( + self.home, + '.config', 'leap', 'providers', + '%s' % self.provider, + 'keys', 'client', + 'openvpn.pem')) + args.append('--ca') + args.append(os.path.join( + self.home, + '.config', 'leap', 'providers', + '%s' % self.provider, + 'keys', 'ca', + 'cacert.pem')) + return args + + # build command string + # these tests are going to have to check + # many combinations. we should inject some + # params in the function call, to disable + # some checks. + + def test_build_ovpn_command_empty_config(self): + self.touch_exec() + self.write_sample_eipservice() + self.write_sample_eipconfig() + + from leap.eip import config as eipconfig + from leap.util.fileutil import which + path = os.environ['PATH'] + vpnbin = which('openvpn', path=path) + print 'path =', path + print 'vpnbin = ', vpnbin + command, args = eipconfig.build_ovpn_command( + do_pkexec_check=False, vpnbin=vpnbin, + socket_path="/tmp/test.socket", + provider=self.provider) + self.assertEqual(command, self.home + '/bin/openvpn') + self.assertEqual(args, self.get_expected_openvpn_args()) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/eip/tests/test_eipconnection.py b/src/leap/eip/tests/test_eipconnection.py new file mode 100644 index 00000000..aefca36f --- /dev/null +++ b/src/leap/eip/tests/test_eipconnection.py @@ -0,0 +1,191 @@ +import logging +import platform +import os + +logging.basicConfig() +logger = logging.getLogger(name=__name__) + +try: + import unittest2 as unittest +except ImportError: + import unittest + +from mock import Mock, patch # MagicMock + +from leap.eip.eipconnection import EIPConnection +from leap.eip.exceptions import ConnectionRefusedError +from leap.eip import specs as eipspecs +from leap.testing.basetest import BaseLeapTest + +_system = platform.system() + +PROVIDER = "testprovider.example.org" + + +class NotImplementedError(Exception): + pass + + +@patch('OpenVPNConnection._get_or_create_config') +@patch('OpenVPNConnection._set_ovpn_command') +class MockedEIPConnection(EIPConnection): + + def _set_ovpn_command(self): + self.command = "mock_command" + self.args = [1, 2, 3] + + +class EIPConductorTest(BaseLeapTest): + + __name__ = "eip_conductor_tests" + provider = PROVIDER + + def setUp(self): + # XXX there's a conceptual/design + # mistake here. + # If we're testing just attrs after init, + # init shold not be doing so much side effects. + + # for instance: + # We have to TOUCH a keys file because + # we're triggerig the key checks FROM + # the constructor. me not like that, + # key checker should better be called explicitelly. + + # XXX change to keys_checker invocation + # (see config_checker) + + keyfiles = (eipspecs.provider_ca_path(domain=self.provider), + eipspecs.client_cert_path(domain=self.provider)) + for filepath in keyfiles: + self.touch(filepath) + self.chmod600(filepath) + + # we init the manager with only + # some methods mocked + self.manager = Mock(name="openvpnmanager_mock") + self.con = MockedEIPConnection() + self.con.provider = self.provider + self.con.run_openvpn_checks() + + def tearDown(self): + del self.con + + # + # tests + # + + def test_vpnconnection_defaults(self): + """ + default attrs as expected + """ + con = self.con + self.assertEqual(con.autostart, True) + + def test_ovpn_command(self): + """ + set_ovpn_command called + """ + self.assertEqual(self.con.command, + "mock_command") + self.assertEqual(self.con.args, + [1, 2, 3]) + + # config checks + + def test_config_checked_called(self): + # XXX this single test is taking half of the time + # needed to run tests. (roughly 3 secs for this only) + # We should modularize and inject Mocks on more places. + + del(self.con) + config_checker = Mock() + self.con = MockedEIPConnection(config_checker=config_checker) + self.assertTrue(config_checker.called) + self.con.run_checks() + self.con.config_checker.run_all.assert_called_with( + skip_download=False) + + # XXX test for cert_checker also + + # connect/disconnect calls + + def test_disconnect(self): + """ + disconnect method calls private and changes status + """ + self.con._disconnect = Mock( + name="_disconnect") + + # first we set status to connected + self.con.status.set_current(self.con.status.CONNECTED) + self.assertEqual(self.con.status.current, + self.con.status.CONNECTED) + + # disconnect + self.con.cleanup = Mock() + self.con.disconnect() + self.con.cleanup.assert_called_once_with() + + # new status should be disconnected + # XXX this should evolve and check no errors + # during disconnection + self.assertEqual(self.con.status.current, + self.con.status.DISCONNECTED) + + def test_connect(self): + """ + connect calls _launch_openvpn private + """ + self.con._launch_openvpn = Mock() + self.con.connect() + self.con._launch_openvpn.assert_called_once_with() + + # XXX tests breaking here ... + + def test_good_poll_connection_state(self): + """ + """ + #@patch -- + # self.manager.get_connection_state + + #XXX review this set of poll_state tests + #they SHOULD NOT NEED TO MOCK ANYTHING IN THE + #lower layers!! -- status, vpn_manager.. + #right now we're testing implementation, not + #behavior!!! + good_state = ["1345466946", "unknown_state", "ok", + "192.168.1.1", "192.168.1.100"] + self.con.get_connection_state = Mock(return_value=good_state) + self.con.status.set_vpn_state = Mock() + + state = self.con.poll_connection_state() + good_state[1] = "disconnected" + final_state = tuple(good_state) + self.con.status.set_vpn_state.assert_called_with("unknown_state") + self.assertEqual(state, final_state) + + # TODO between "good" and "bad" (exception raised) cases, + # we can still test for malformed states and see that only good + # states do have a change (and from only the expected transition + # states). + + def test_bad_poll_connection_state(self): + """ + get connection state raises ConnectionRefusedError + state is None + """ + self.con.get_connection_state = Mock( + side_effect=ConnectionRefusedError('foo!')) + state = self.con.poll_connection_state() + self.assertEqual(state, None) + + + # XXX more things to test: + # - called config routines during initz. + # - raising proper exceptions with no config + # - called proper checks on config / permissions + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/eip/tests/test_openvpnconnection.py b/src/leap/eip/tests/test_openvpnconnection.py new file mode 100644 index 00000000..0f27facf --- /dev/null +++ b/src/leap/eip/tests/test_openvpnconnection.py @@ -0,0 +1,147 @@ +import logging +import os +import platform +import psutil +import shutil +#import socket + +logging.basicConfig() +logger = logging.getLogger(name=__name__) + +try: + import unittest2 as unittest +except ImportError: + import unittest + +from mock import Mock, patch # MagicMock + +from leap.eip import config as eipconfig +from leap.eip import openvpnconnection +from leap.eip import exceptions as eipexceptions +from leap.eip.udstelnet import UDSTelnet +from leap.testing.basetest import BaseLeapTest + +_system = platform.system() + + +class NotImplementedError(Exception): + pass + + +mock_UDSTelnet = Mock(spec=UDSTelnet) +# XXX cautious!!! +# this might be fragile right now (counting a global +# reference of calls I think. +# investigate this other form instead: +# http://www.voidspace.org.uk/python/mock/patch.html#start-and-stop + +# XXX redo after merge-refactor + + +@patch('openvpnconnection.OpenVPNConnection.connect_to_management') +class MockedOpenVPNConnection(openvpnconnection.OpenVPNConnection): + def __init__(self, *args, **kwargs): + self.mock_UDSTelnet = Mock() + super(MockedOpenVPNConnection, self).__init__( + *args, **kwargs) + self.tn = self.mock_UDSTelnet(self.host, self.port) + + def connect_to_management(self): + #print 'patched connect' + self.tn = mock_UDSTelnet(self.host, port=self.port) + + +class OpenVPNConnectionTest(BaseLeapTest): + + __name__ = "vpnconnection_tests" + + def setUp(self): + # XXX this will have to change for win, host=localhost + host = eipconfig.get_socket_path() + self.manager = MockedOpenVPNConnection(host=host) + + def tearDown(self): + # remove the socket folder. + # XXX only if posix. in win, host is localhost, so nothing + # has to be done. + if self.manager.host: + folder, fpath = os.path.split(self.manager.host) + assert folder.startswith('/tmp/leap-tmp') # safety check + shutil.rmtree(folder) + + del self.manager + + # + # tests + # + + def test_detect_vpn(self): + # XXX review, not sure if captured all the logic + # while fixing. kali. + openvpn_connection = openvpnconnection.OpenVPNConnection() + + with patch.object(psutil, "get_process_list") as mocked_psutil: + mocked_process = Mock() + mocked_process.name = "openvpn" + mocked_psutil.return_value = [mocked_process] + with self.assertRaises(eipexceptions.OpenVPNAlreadyRunning): + openvpn_connection._check_if_running_instance() + + openvpn_connection._check_if_running_instance() + + @unittest.skipIf(_system == "Windows", "lin/mac only") + def test_lin_mac_default_init(self): + """ + check default host for management iface + """ + self.assertTrue(self.manager.host.startswith('/tmp/leap-tmp')) + self.assertEqual(self.manager.port, 'unix') + + @unittest.skipUnless(_system == "Windows", "win only") + def test_win_default_init(self): + """ + check default host for management iface + """ + # XXX should we make the platform specific switch + # here or in the vpn command string building? + self.assertEqual(self.manager.host, 'localhost') + self.assertEqual(self.manager.port, 7777) + + def test_port_types_init(self): + self.manager = MockedOpenVPNConnection(port="42") + self.assertEqual(self.manager.port, 42) + self.manager = MockedOpenVPNConnection() + self.assertEqual(self.manager.port, "unix") + self.manager = MockedOpenVPNConnection(port="bad") + self.assertEqual(self.manager.port, None) + + def test_uds_telnet_called_on_connect(self): + self.manager.connect_to_management() + mock_UDSTelnet.assert_called_with( + self.manager.host, + port=self.manager.port) + + @unittest.skip + def test_connect(self): + raise NotImplementedError + # XXX calls close + # calls UDSTelnet mock. + + # XXX + # tests to write: + # UDSTelnetTest (for real?) + # HAVE A LOOK AT CORE TESTS FOR TELNETLIB. + # very illustrative instead... + + # - raise MissingSocket + # - raise ConnectionRefusedError + # - test send command + # - tries connect + # - ... tries? + # - ... calls _seek_to_eof + # - ... read_until --> return value + # - ... + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/eip/udstelnet.py b/src/leap/eip/udstelnet.py new file mode 100644 index 00000000..18e927c2 --- /dev/null +++ b/src/leap/eip/udstelnet.py @@ -0,0 +1,38 @@ +import os +import socket +import telnetlib + +from leap.eip import exceptions as eip_exceptions + + +class UDSTelnet(telnetlib.Telnet): + """ + a telnet-alike class, that can listen + on unix domain sockets + """ + + def open(self, host, port=23, timeout=socket._GLOBAL_DEFAULT_TIMEOUT): + """Connect to a host. If port is 'unix', it + will open a connection over unix docmain sockets. + + The optional second argument is the port number, which + defaults to the standard telnet port (23). + + Don't try to reopen an already connected instance. + """ + self.eof = 0 + self.host = host + self.port = port + self.timeout = timeout + + if self.port == "unix": + # unix sockets spoken + if not os.path.exists(self.host): + raise eip_exceptions.MissingSocketError + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + self.sock.connect(self.host) + except socket.error: + raise eip_exceptions.ConnectionRefusedError + else: + self.sock = socket.create_connection((host, port), timeout) diff --git a/src/leap/eip/vpnmanager.py b/src/leap/eip/vpnmanager.py deleted file mode 100644 index 78777cfb..00000000 --- a/src/leap/eip/vpnmanager.py +++ /dev/null @@ -1,262 +0,0 @@ -from __future__ import (print_function) -import logging -import os -import socket -import telnetlib -import time - -logger = logging.getLogger(name=__name__) - -TELNET_PORT = 23 - - -class MissingSocketError(Exception): - pass - - -class ConnectionRefusedError(Exception): - pass - - -class UDSTelnet(telnetlib.Telnet): - - def open(self, host, port=0, timeout=socket._GLOBAL_DEFAULT_TIMEOUT): - """Connect to a host. If port is 'unix', it - will open a connection over unix docmain sockets. - - The optional second argument is the port number, which - defaults to the standard telnet port (23). - - Don't try to reopen an already connected instance. - """ - self.eof = 0 - if not port: - port = TELNET_PORT - self.host = host - self.port = port - self.timeout = timeout - - if self.port == "unix": - # unix sockets spoken - if not os.path.exists(self.host): - raise MissingSocketError - self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - try: - self.sock.connect(self.host) - except socket.error: - raise ConnectionRefusedError - else: - self.sock = socket.create_connection((host, port), timeout) - - -# this class based in code from cube-routed project - -class OpenVPNManager(object): - """ - Run commands over OpenVPN management interface - and parses the output. - """ - # XXX might need a lock to avoid - # race conditions here... - - def __init__(self, host="/tmp/.eip.sock", port="unix", password=None): - #XXX hardcoded host here. change. - self.host = host - if isinstance(port, str) and port.isdigit(): - port = int(port) - self.port = port - self.password = password - self.tn = None - - #XXX workaround for signaling - #the ui that we don't know how to - #manage a connection error - self.with_errors = False - - def forget_errors(self): - print('forgetting errors') - self.with_errors = False - - def connect(self): - """Connect to openvpn management interface""" - try: - self.close() - except: - #XXX don't like this general - #catch here. - pass - if self.connected(): - return True - self.tn = UDSTelnet(self.host, self.port) - - # XXX make password optional - # specially for win plat. we should generate - # the pass on the fly when invoking manager - # from conductor - - #self.tn.read_until('ENTER PASSWORD:', 2) - #self.tn.write(self.password + '\n') - #self.tn.read_until('SUCCESS:', 2) - - self._seek_to_eof() - self.forget_errors() - return True - - def _seek_to_eof(self): - """ - Read as much as available. Position seek pointer to end of stream - """ - b = self.tn.read_eager() - while b: - b = self.tn.read_eager() - - def connected(self): - """ - Returns True if connected - rtype: bool - """ - #return bool(getattr(self, 'tn', None)) - try: - assert self.tn - return True - except: - #XXX get rid of - #this pokemon exception!!! - return False - - def close(self, announce=True): - """ - Close connection to openvpn management interface - """ - if announce: - self.tn.write("quit\n") - self.tn.read_all() - self.tn.get_socket().close() - del self.tn - - def _send_command(self, cmd, tries=0): - """ - Send a command to openvpn and return response as list - """ - if tries > 3: - return [] - if not self.connected(): - try: - self.connect() - except MissingSocketError: - #XXX capture more helpful error - #messages - #pass - return self.make_error() - try: - self.tn.write(cmd + "\n") - except socket.error: - logger.error('socket error') - print('socket error!') - self.close(announce=False) - self._send_command(cmd, tries=tries + 1) - return [] - buf = self.tn.read_until(b"END", 2) - self._seek_to_eof() - blist = buf.split('\r\n') - if blist[-1].startswith('END'): - del blist[-1] - return blist - else: - return [] - - def _send_short_command(self, cmd): - """ - parse output from commands that are - delimited by "success" instead - """ - if not self.connected(): - self.connect() - self.tn.write(cmd + "\n") - # XXX not working? - buf = self.tn.read_until(b"SUCCESS", 2) - self._seek_to_eof() - blist = buf.split('\r\n') - return blist - - # - # useful vpn commands - # - - def pid(self): - #XXX broken - return self._send_short_command("pid") - - def make_error(self): - """ - capture error and wrap it in an - understandable format - """ - #XXX get helpful error codes - self.with_errors = True - now = int(time.time()) - return '%s,LAUNCHER ERROR,ERROR,-,-' % now - - def state(self): - """ - OpenVPN command: state - """ - state = self._send_command("state") - if not state: - return None - if isinstance(state, str): - return state - if isinstance(state, list): - if len(state) == 1: - return state[0] - else: - return state[-1] - - def status(self): - """ - OpenVPN command: status - """ - status = self._send_command("status") - return status - - def status2(self): - """ - OpenVPN command: last 2 statuses - """ - return self._send_command("status 2") - - # - # parse info - # - - def get_status_io(self): - status = self.status() - if isinstance(status, str): - lines = status.split('\n') - if isinstance(status, list): - lines = status - try: - (header, when, tun_read, tun_write, - tcp_read, tcp_write, auth_read) = tuple(lines) - except ValueError: - return None - - when_ts = time.strptime(when.split(',')[1], "%a %b %d %H:%M:%S %Y") - sep = ',' - # XXX cleanup! - tun_read = tun_read.split(sep)[1] - tun_write = tun_write.split(sep)[1] - tcp_read = tcp_read.split(sep)[1] - tcp_write = tcp_write.split(sep)[1] - auth_read = auth_read.split(sep)[1] - - # XXX this could be a named tuple. prettier. - return when_ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read) - - def get_connection_state(self): - state = self.state() - if state is not None: - ts, status_step, ok, ip, remote = state.split(',') - ts = time.gmtime(float(ts)) - # XXX this could be a named tuple. prettier. - return ts, status_step, ok, ip, remote diff --git a/src/leap/eip/vpnwatcher.py b/src/leap/eip/vpnwatcher.py deleted file mode 100644 index 09bd5811..00000000 --- a/src/leap/eip/vpnwatcher.py +++ /dev/null @@ -1,169 +0,0 @@ -"""generic watcher object that keeps track of connection status""" -# This should be deprecated in favor of daemon mode + management -# interface. But we can leave it here for debug purposes. - - -class EIPConnectionStatus(object): - """ - Keep track of client (gui) and openvpn - states. - - These are the OpenVPN states: - CONNECTING -- OpenVPN's initial state. - WAIT -- (Client only) Waiting for initial response - from server. - AUTH -- (Client only) Authenticating with server. - GET_CONFIG -- (Client only) Downloading configuration options - from server. - ASSIGN_IP -- Assigning IP address to virtual network - interface. - ADD_ROUTES -- Adding routes to system. - CONNECTED -- Initialization Sequence Completed. - RECONNECTING -- A restart has occurred. - EXITING -- A graceful exit is in progress. - - We add some extra states: - - DISCONNECTED -- GUI initial state. - UNRECOVERABLE -- An unrecoverable error has been raised - while invoking openvpn service. - """ - CONNECTING = 1 - WAIT = 2 - AUTH = 3 - GET_CONFIG = 4 - ASSIGN_IP = 5 - ADD_ROUTES = 6 - CONNECTED = 7 - RECONNECTING = 8 - EXITING = 9 - - # gui specific states: - UNRECOVERABLE = 11 - DISCONNECTED = 0 - - def __init__(self, callbacks=None): - """ - EIPConnectionStatus is initialized with a tuple - of signals to be triggered. - :param callbacks: a tuple of (callable) observers - :type callbacks: tuple - """ - # (callbacks to connect to signals in Qt-land) - self.current = self.DISCONNECTED - self.previous = None - self.callbacks = callbacks - - def get_readable_status(self): - # XXX DRY status / labels a little bit. - # think we'll want to i18n this. - human_status = { - 0: 'disconnected', - 1: 'connecting', - 2: 'waiting', - 3: 'authenticating', - 4: 'getting config', - 5: 'assigning ip', - 6: 'adding routes', - 7: 'connected', - 8: 'reconnecting', - 9: 'exiting', - 11: 'unrecoverable error', - } - return human_status[self.current] - - def get_state_icon(self): - """ - returns the high level icon - for each fine-grain openvpn state - """ - connecting = (self.CONNECTING, - self.WAIT, - self.AUTH, - self.GET_CONFIG, - self.ASSIGN_IP, - self.ADD_ROUTES) - connected = (self.CONNECTED,) - disconnected = (self.DISCONNECTED, - self.UNRECOVERABLE) - - # this can be made smarter, - # but it's like it'll change, - # so +readability. - - if self.current in connecting: - return "connecting" - if self.current in connected: - return "connected" - if self.current in disconnected: - return "disconnected" - - def set_vpn_state(self, status): - """ - accepts a state string from the management - interface, and sets the internal state. - :param status: openvpn STATE (uppercase). - :type status: str - """ - if hasattr(self, status): - self.change_to(getattr(self, status)) - - def set_current(self, to): - """ - setter for the 'current' property - :param to: destination state - :type to: int - """ - self.current = to - - def change_to(self, to): - """ - :param to: destination state - :type to: int - """ - if to == self.current: - return - changed = False - from_ = self.current - self.current = to - - # We can add transition restrictions - # here to ensure no transitions are - # allowed outside the fsm. - - self.set_current(to) - changed = True - - #trigger signals (as callbacks) - #print('current state: %s' % self.current) - if changed: - self.previous = from_ - if self.callbacks: - for cb in self.callbacks: - if callable(cb): - cb(self) - - -def status_watcher(cs, line): - """ - a wrapper that calls to ConnectionStatus object - :param cs: a EIPConnectionStatus instance - :type cs: EIPConnectionStatus object - :param line: a single line of the watched output - :type line: str - """ - #print('status watcher watching') - - # from the mullvad code, should watch for - # things like: - # "Initialization Sequence Completed" - # "With Errors" - # "Tap-Win32" - - if "Completed" in line: - cs.change_to(cs.CONNECTED) - return - - if "Initial packet from" in line: - cs.change_to(cs.CONNECTING) - return diff --git a/src/leap/gui/__init__.py b/src/leap/gui/__init__.py index e69de29b..9b8f8746 100644 --- a/src/leap/gui/__init__.py +++ b/src/leap/gui/__init__.py @@ -0,0 +1,10 @@ +try: + import sip + sip.setapi('QString', 2) + sip.setapi('QVariant', 2) +except ValueError: + pass + +import firstrun + +__all__ = ['firstrun'] diff --git a/src/leap/gui/constants.py b/src/leap/gui/constants.py new file mode 100644 index 00000000..277f3540 --- /dev/null +++ b/src/leap/gui/constants.py @@ -0,0 +1,13 @@ +import time + +APP_LOGO = ':/images/leap-color-small.png' + +# bare is the username portion of a JID +# full includes the "at" and some extra chars +# that can be allowed for fqdn + +BARE_USERNAME_REGEX = r"^[A-Za-z\d_]+$" +FULL_USERNAME_REGEX = r"^[A-Za-z\d_@.-]+$" + +GUI_PAUSE_FOR_USER_SECONDS = 1 +pause_for_user = lambda: time.sleep(GUI_PAUSE_FOR_USER_SECONDS) diff --git a/src/leap/gui/firstrun/__init__.py b/src/leap/gui/firstrun/__init__.py new file mode 100644 index 00000000..8a70d90e --- /dev/null +++ b/src/leap/gui/firstrun/__init__.py @@ -0,0 +1,29 @@ +try: + import sip + sip.setapi('QString', 2) + sip.setapi('QVariant', 2) +except ValueError: + pass + +import connect +import intro +import last +import login +import mixins +import providerinfo +import providerselect +import providersetup +import register +import regvalidation + +__all__ = [ + 'connect', + 'intro', + 'last', + 'login', + 'mixins', + 'providerinfo', + 'providerselect', + 'providersetup', + 'register', + 'regvalidation'] diff --git a/src/leap/gui/firstrun/connect.py b/src/leap/gui/firstrun/connect.py new file mode 100644 index 00000000..a0fe021c --- /dev/null +++ b/src/leap/gui/firstrun/connect.py @@ -0,0 +1,231 @@ +""" +Connecting Page, used in First Run Wizard +""" +# XXX FIXME +# DEPRECATED. All functionality moved to regvalidation +# This file should be removed after checking that one is ok. +# XXX + +import logging + +from PyQt4 import QtGui + +logger = logging.getLogger(__name__) + +from leap.base import auth + +from leap.gui.constants import APP_LOGO +from leap.gui.styles import ErrorLabelStyleSheet + + +class ConnectingPage(QtGui.QWizardPage): + + # XXX change to a ValidationPage + + def __init__(self, parent=None): + super(ConnectingPage, self).__init__(parent) + + self.setTitle("Connecting") + self.setSubTitle('Connecting to provider.') + + self.setPixmap( + QtGui.QWizard.LogoPixmap, + QtGui.QPixmap(APP_LOGO)) + + self.status = QtGui.QLabel("") + self.status.setWordWrap(True) + self.progress = QtGui.QProgressBar() + self.progress.setMaximum(100) + self.progress.hide() + + # for pre-checks + self.status_line_1 = QtGui.QLabel() + self.status_line_2 = QtGui.QLabel() + self.status_line_3 = QtGui.QLabel() + self.status_line_4 = QtGui.QLabel() + + # for connecting signals... + self.status_line_5 = QtGui.QLabel() + + layout = QtGui.QGridLayout() + layout.addWidget(self.status, 0, 1) + layout.addWidget(self.progress, 5, 1) + layout.addWidget(self.status_line_1, 8, 1) + layout.addWidget(self.status_line_2, 9, 1) + layout.addWidget(self.status_line_3, 10, 1) + layout.addWidget(self.status_line_4, 11, 1) + + # XXX to be used? + #self.validation_status = QtGui.QLabel("") + #self.validation_status.setStyleSheet( + #ErrorLabelStyleSheet) + #self.validation_msg = QtGui.QLabel("") + + self.setLayout(layout) + + self.goto_login_again = False + + def set_status(self, status): + self.status.setText(status) + self.status.setWordWrap(True) + + def set_status_line(self, line, status): + line = getattr(self, 'status_line_%s' % line) + if line: + line.setText(status) + + def set_validation_status(self, status): + # Do not remember if we're using + # status lines > 3 now... + # if we are, move below + self.status_line_3.setStyleSheet( + ErrorLabelStyleSheet) + self.status_line_3.setText(status) + + def set_validation_message(self, message): + self.status_line_4.setText(message) + self.status_line_4.setWordWrap(True) + + def get_donemsg(self, msg): + return "%s ... done" % msg + + def run_eip_checks_for_provider_and_connect(self, domain): + wizard = self.wizard() + conductor = wizard.conductor + start_eip_signal = getattr( + wizard, + 'start_eipconnection_signal', None) + + if conductor: + conductor.set_provider_domain(domain) + conductor.run_checks() + self.conductor = conductor + errors = self.eip_error_check() + if not errors and start_eip_signal: + start_eip_signal.emit() + + else: + logger.warning( + "No conductor found. This means that " + "probably the wizard has been launched " + "in an stand-alone way") + + def eip_error_check(self): + """ + a version of the main app error checker, + but integrated within the connecting page of the wizard. + consumes the conductor error queue. + pops errors, and add those to the wizard page + """ + logger.debug('eip error check from connecting page') + errq = self.conductor.error_queue + # XXX missing! + + def fetch_and_validate(self): + # XXX MOVE TO validate function in register-validation + import time + domain = self.field('provider_domain') + wizard = self.wizard() + #pconfig = wizard.providerconfig + eipconfigchecker = wizard.eipconfigchecker() + pCertChecker = wizard.providercertchecker( + domain=domain) + + # username and password are in different fields + # if they were stored in log_in or sign_up pages. + from_login = self.wizard().from_login + unamek_base = 'userName' + passwk_base = 'userPassword' + unamek = 'login_%s' % unamek_base if from_login else unamek_base + passwk = 'login_%s' % passwk_base if from_login else passwk_base + + username = self.field(unamek) + password = self.field(passwk) + credentials = username, password + + self.progress.show() + + fetching_eip_conf_msg = 'Fetching eip service configuration' + self.set_status(fetching_eip_conf_msg) + self.progress.setValue(30) + + # Fetching eip service + eipconfigchecker.fetch_eip_service_config( + domain=domain) + + self.status_line_1.setText( + self.get_donemsg(fetching_eip_conf_msg)) + + getting_client_cert_msg = 'Getting client certificate' + self.set_status(getting_client_cert_msg) + self.progress.setValue(66) + + # Download cert + try: + pCertChecker.download_new_client_cert( + credentials=credentials, + # FIXME FIXME FIXME + # XXX FIX THIS!!!!! + # BUG #638. remove verify + # FIXME FIXME FIXME + verify=False) + except auth.SRPAuthenticationError as exc: + self.set_validation_status( + "Authentication error: %s" % exc.message) + return False + + time.sleep(2) + self.status_line_2.setText( + self.get_donemsg(getting_client_cert_msg)) + + validating_clientcert_msg = 'Validating client certificate' + self.set_status(validating_clientcert_msg) + self.progress.setValue(90) + time.sleep(2) + self.status_line_3.setText( + self.get_donemsg(validating_clientcert_msg)) + + self.progress.setValue(100) + time.sleep(3) + + # here we go! :) + self.run_eip_checks_for_provider_and_connect(domain) + + #self.validation_block = self.wait_for_validation_block() + + # XXX signal timeout! + return True + + # + # wizardpage methods + # + + def nextId(self): + wizard = self.wizard() + # XXX this does not work because + # page login has already been met + #if self.goto_login_again: + #next_ = "login" + #else: + #next_ = "lastpage" + next_ = "lastpage" + return wizard.get_page_index(next_) + + def initializePage(self): + # XXX if we're coming from signup page + # we could say something like + # 'registration successful!' + self.status.setText( + "We have " + "all we need to connect with the provider.<br><br> " + "Click <i>next</i> to continue. ") + self.progress.setValue(0) + self.progress.hide() + self.status_line_1.setText('') + self.status_line_2.setText('') + self.status_line_3.setText('') + + def validatePage(self): + # XXX remove + validated = self.fetch_and_validate() + return validated diff --git a/src/leap/gui/firstrun/constants.py b/src/leap/gui/firstrun/constants.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/gui/firstrun/constants.py diff --git a/src/leap/gui/firstrun/intro.py b/src/leap/gui/firstrun/intro.py new file mode 100644 index 00000000..4bb008c7 --- /dev/null +++ b/src/leap/gui/firstrun/intro.py @@ -0,0 +1,68 @@ +""" +Intro page used in first run wizard +""" + +from PyQt4 import QtGui + +from leap.gui.constants import APP_LOGO + + +class IntroPage(QtGui.QWizardPage): + def __init__(self, parent=None): + super(IntroPage, self).__init__(parent) + + self.setTitle("First run wizard.") + + #self.setPixmap( + #QtGui.QWizard.WatermarkPixmap, + #QtGui.QPixmap(':/images/watermark1.png')) + + self.setPixmap( + QtGui.QWizard.LogoPixmap, + QtGui.QPixmap(APP_LOGO)) + + label = QtGui.QLabel( + "Now we will guide you through " + "some configuration that is needed before you " + "can connect for the first time.<br><br>" + "If you ever need to modify these options again, " + "you can find the wizard in the '<i>Settings</i>' menu from the " + "main window.<br><br>" + "Do you want to <b>sign up</b> for a new account, or <b>log " + "in</b> with an already existing username?<br>") + label.setWordWrap(True) + + radiobuttonGroup = QtGui.QGroupBox() + + self.sign_up = QtGui.QRadioButton( + "Sign up for a new account.") + self.sign_up.setChecked(True) + self.log_in = QtGui.QRadioButton( + "Log In with my credentials.") + + radiobLayout = QtGui.QVBoxLayout() + radiobLayout.addWidget(self.sign_up) + radiobLayout.addWidget(self.log_in) + radiobuttonGroup.setLayout(radiobLayout) + + layout = QtGui.QVBoxLayout() + layout.addWidget(label) + layout.addWidget(radiobuttonGroup) + self.setLayout(layout) + + self.registerField('is_signup', self.sign_up) + + def validatePage(self): + return True + + def nextId(self): + """ + returns next id + in a non-linear wizard + """ + if self.sign_up.isChecked(): + next_ = 'providerselection' + if self.log_in.isChecked(): + next_ = 'login' + wizard = self.wizard() + return wizard.get_page_index(next_) diff --git a/src/leap/gui/firstrun/last.py b/src/leap/gui/firstrun/last.py new file mode 100644 index 00000000..d33d2e77 --- /dev/null +++ b/src/leap/gui/firstrun/last.py @@ -0,0 +1,92 @@ +""" +Last Page, used in First Run Wizard +""" +import logging + +from PyQt4 import QtGui + +from leap.util.coroutines import coroutine +from leap.gui.constants import APP_LOGO + +logger = logging.getLogger(__name__) + + +class LastPage(QtGui.QWizardPage): + def __init__(self, parent=None): + super(LastPage, self).__init__(parent) + + self.setTitle("Connecting to Encrypted Internet Proxy service...") + + self.setPixmap( + QtGui.QWizard.LogoPixmap, + QtGui.QPixmap(APP_LOGO)) + + #self.setPixmap( + #QtGui.QWizard.WatermarkPixmap, + #QtGui.QPixmap(':/images/watermark2.png')) + + self.label = QtGui.QLabel() + self.label.setWordWrap(True) + + # XXX REFACTOR to a Validating Page... + self.status_line_1 = QtGui.QLabel() + self.status_line_2 = QtGui.QLabel() + self.status_line_3 = QtGui.QLabel() + self.status_line_4 = QtGui.QLabel() + + layout = QtGui.QVBoxLayout() + layout.addWidget(self.label) + + # make loop + layout.addWidget(self.status_line_1) + layout.addWidget(self.status_line_2) + layout.addWidget(self.status_line_3) + layout.addWidget(self.status_line_4) + + self.setLayout(layout) + + def set_status_line(self, line, status): + statusline = getattr(self, 'status_line_%s' % line) + if statusline: + statusline.setText(status) + + def set_finished_status(self): + self.setTitle('You are now using an encrypted connection!') + finishText = self.wizard().buttonText( + QtGui.QWizard.FinishButton) + finishText = finishText.replace('&', '') + self.label.setText( + "Click '<i>%s</i>' to end the wizard and " + "save your settings." % finishText) + + @coroutine + def eip_status_handler(self): + # XXX this can be changed to use + # signals. See progress.py + logger.debug('logging status in last page') + self.validation_done = False + status_count = 0 + try: + while True: + status = (yield) + status_count += 1 + # XXX add to line... + logger.debug('status --> %s', status) + self.set_status_line(status_count, status) + if status == "connected": + self.set_finished_status() + break + except GeneratorExit: + pass + except StopIteration: + pass + + def initializePage(self): + wizard = self.wizard() + if not wizard: + return + eip_status_handler = self.eip_status_handler() + eip_statuschange_signal = wizard.eip_statuschange_signal + if eip_statuschange_signal: + eip_statuschange_signal.connect( + lambda status: eip_status_handler.send(status)) diff --git a/src/leap/gui/firstrun/login.py b/src/leap/gui/firstrun/login.py new file mode 100644 index 00000000..02bace86 --- /dev/null +++ b/src/leap/gui/firstrun/login.py @@ -0,0 +1,330 @@ +""" +LogIn Page, used inf First Run Wizard +""" +from PyQt4 import QtCore +from PyQt4 import QtGui + +import requests + +from leap.base import auth +from leap.gui.firstrun.mixins import UserFormMixIn +from leap.gui.progress import InlineValidationPage +from leap.gui import styles + +from leap.gui.constants import APP_LOGO, FULL_USERNAME_REGEX + + +class LogInPage(InlineValidationPage, UserFormMixIn): # InlineValidationPage + + def __init__(self, parent=None): + + super(LogInPage, self).__init__(parent) + self.current_page = "login" + + self.setTitle("Log In") + self.setSubTitle("Log in with your credentials.") + self.current_page = "login" + + self.setPixmap( + QtGui.QWizard.LogoPixmap, + QtGui.QPixmap(APP_LOGO)) + + self.setupSteps() + self.setupUI() + + self.do_confirm_next = False + + def setupUI(self): + userNameLabel = QtGui.QLabel("User &name:") + userNameLineEdit = QtGui.QLineEdit() + userNameLineEdit.cursorPositionChanged.connect( + self.reset_validation_status) + userNameLabel.setBuddy(userNameLineEdit) + + # let's add regex validator + usernameRe = QtCore.QRegExp(FULL_USERNAME_REGEX) + userNameLineEdit.setValidator( + QtGui.QRegExpValidator(usernameRe, self)) + + #userNameLineEdit.setPlaceholderText( + #'username@provider.example.org') + self.userNameLineEdit = userNameLineEdit + + userPasswordLabel = QtGui.QLabel("&Password:") + self.userPasswordLineEdit = QtGui.QLineEdit() + self.userPasswordLineEdit.setEchoMode( + QtGui.QLineEdit.Password) + userPasswordLabel.setBuddy(self.userPasswordLineEdit) + + self.registerField('login_userName*', self.userNameLineEdit) + self.registerField('login_userPassword*', self.userPasswordLineEdit) + + layout = QtGui.QGridLayout() + layout.setColumnMinimumWidth(0, 20) + + validationMsg = QtGui.QLabel("") + validationMsg.setStyleSheet(styles.ErrorLabelStyleSheet) + self.validationMsg = validationMsg + + layout.addWidget(validationMsg, 0, 3) + layout.addWidget(userNameLabel, 1, 0) + layout.addWidget(self.userNameLineEdit, 1, 3) + layout.addWidget(userPasswordLabel, 2, 0) + layout.addWidget(self.userPasswordLineEdit, 2, 3) + + # add validation frame + self.setupValidationFrame() + layout.addWidget(self.valFrame, 4, 2, 4, 2) + self.valFrame.hide() + + self.nextText("Log in") + self.setLayout(layout) + + #self.registerField('is_login_wizard') + + def nextText(self, text): + self.setButtonText( + QtGui.QWizard.NextButton, text) + + def nextFocus(self): + self.wizard().button( + QtGui.QWizard.NextButton).setFocus() + + def disableNextButton(self): + self.wizard().button( + QtGui.QWizard.NextButton).setDisabled(True) + + def onUserNameEdit(self, *args): + if self.initial_username_sample: + self.userNameLineEdit.setText('') + # XXX set regular color + self.initial_username_sample = None + + def disableFields(self): + for field in (self.userNameLineEdit, + self.userPasswordLineEdit): + field.setDisabled(True) + + def populateErrors(self): + # XXX could move this to ValidationMixin + # used in providerselect and register too + + errors = self.wizard().get_validation_error( + self.current_page) + #prev_er = getattr(self, 'prevalidation_error', None) + showerr = self.validationMsg.setText + + #if not errors and prev_er: + #showerr(prev_er) + #return +# + if errors: + bad_str = getattr(self, 'bad_string', None) + cur_str = self.userNameLineEdit.text() + + if bad_str is None: + # first time we fall here. + # save the current bad_string value + self.bad_string = cur_str + showerr(errors) + else: + #if prev_er: + #showerr(prev_er) + #return + # not the first time + if cur_str == bad_str: + showerr(errors) + else: + self.focused_field = False + showerr('') + + def cleanup_errormsg(self): + """ + we reset bad_string to None + should be called before leaving the page + """ + self.bad_string = None + + def paintEvent(self, event): + """ + we hook our populate errors + on paintEvent because we need it to catch + when user enters the page coming from next, + and initializePage does not cover that case. + Maybe there's a better event to hook upon. + """ + super(LogInPage, self).paintEvent(event) + self.populateErrors() + + def set_prevalidation_error(self, error): + self.prevalidation_error = error + + # pagewizard methods + + def nextId(self): + wizard = self.wizard() + if not wizard: + return + if wizard.is_provider_setup is False: + next_ = 'providersetupvalidation' + if wizard.is_provider_setup is True: + # XXX bad name, ok, gonna change that + next_ = 'signupvalidation' + return wizard.get_page_index(next_) + + def initializePage(self): + super(LogInPage, self).initializePage() + username = self.userNameLineEdit + username.setText('username@provider.example.org') + username.cursorPositionChanged.connect( + self.onUserNameEdit) + self.initial_username_sample = True + self.validationMsg.setText('') + self.valFrame.hide() + + def reset_validation_status(self): + """ + empty the validation msg + and clean the inline validation widget. + """ + self.validationMsg.setText('') + self.steps.removeAllSteps() + self.clearTable() + + def validatePage(self): + """ + if not register done, do checks. + if done, wait for click. + """ + self.disableNextButton() + self.cleanup_errormsg() + self.clean_wizard_errors(self.current_page) + + if self.do_confirm_next: + full_username = self.userNameLineEdit.text() + password = self.userPasswordLineEdit.text() + username, domain = full_username.split('@') + self.setField('provider_domain', domain) + self.setField('login_userName', username) + self.setField('login_userPassword', password) + + return True + + if not self.is_done(): + self.reset_validation_status() + self.do_checks() + + return self.is_done() + + def _do_checks(self): + # XXX convert this to inline + + full_username = self.userNameLineEdit.text() + ########################### + # 0) check user@domain form + ########################### + + def checkusername(): + if full_username.count('@') != 1: + return self.fail( + self.tr( + "Username must be in the username@provider form.")) + else: + return True + + yield(("head_sentinel", 0), checkusername) + + # XXX I think this is not needed + # since we're also checking for the is_signup field. + #self.wizard().from_login = True + + username, domain = full_username.split('@') + password = self.userPasswordLineEdit.text() + + # We try a call to an authenticated + # page here as a mean to catch + # srp authentication errors while + wizard = self.wizard() + eipconfigchecker = wizard.eipconfigchecker() + + ######################## + # 1) try name resolution + ######################## + # show the frame before going on... + QtCore.QMetaObject.invokeMethod( + self, "showStepsFrame") + + # Able to contact domain? + # can get definition? + # two-by-one + def resolvedomain(): + try: + eipconfigchecker.fetch_definition(domain=domain) + + # we're using requests here for all + # the possible error cases that it catches. + except requests.exceptions.ConnectionError as exc: + return self.fail(exc.message[1]) + except requests.exceptions.HTTPError as exc: + return self.fail(exc.message) + except Exception as exc: + # XXX get catchall error msg + return self.fail( + exc.message) + + yield((self.tr("resolving domain name"), 20), resolvedomain) + + wizard.set_providerconfig( + eipconfigchecker.defaultprovider.config) + + ######################## + # 2) do authentication + ######################## + credentials = username, password + pCertChecker = wizard.providercertchecker( + domain=domain) + + def validate_credentials(): + ################# + # FIXME #BUG #638 + verify = False + + try: + pCertChecker.download_new_client_cert( + credentials=credentials, + verify=verify) + + except auth.SRPAuthenticationError as exc: + return self.fail( + self.tr("Authentication error: %s" % exc.message)) + + except Exception as exc: + return self.fail(exc.message) + + else: + return True + + yield(('Validating credentials', 20), validate_credentials) + + self.set_done() + yield(("end_sentinel", 0), lambda: None) + + def green_validation_status(self): + val = self.validationMsg + val.setText(self.tr('Credentials validated.')) + val.setStyleSheet(styles.GreenLineEdit) + + def on_checks_validation_ready(self): + """ + after checks + """ + if self.is_done(): + self.disableFields() + self.cleanup_errormsg() + self.clean_wizard_errors(self.current_page) + # make the user confirm the transition + # to next page. + self.nextText('&Next') + self.nextFocus() + self.green_validation_status() + self.do_confirm_next = True diff --git a/src/leap/gui/firstrun/mixins.py b/src/leap/gui/firstrun/mixins.py new file mode 100644 index 00000000..c4731893 --- /dev/null +++ b/src/leap/gui/firstrun/mixins.py @@ -0,0 +1,18 @@ +""" +mixins used in First Run Wizard +""" + + +class UserFormMixIn(object): + + def reset_validation_status(self): + """ + empty the validation msg + """ + self.validationMsg.setText('') + + def set_validation_status(self, msg): + """ + set generic validation status + """ + self.validationMsg.setText(msg) diff --git a/src/leap/gui/firstrun/providerinfo.py b/src/leap/gui/firstrun/providerinfo.py new file mode 100644 index 00000000..c5b2984c --- /dev/null +++ b/src/leap/gui/firstrun/providerinfo.py @@ -0,0 +1,98 @@ +""" +Provider Info Page, used in First run Wizard +""" +import logging + +from PyQt4 import QtGui + +from leap.gui.constants import APP_LOGO + +logger = logging.getLogger(__name__) + + +class ProviderInfoPage(QtGui.QWizardPage): + + def __init__(self, parent=None): + super(ProviderInfoPage, self).__init__(parent) + + self.setTitle(self.tr("Provider Info")) + self.setSubTitle(self.tr( + "This is what provider says.")) + + self.setPixmap( + QtGui.QWizard.LogoPixmap, + QtGui.QPixmap(APP_LOGO)) + + self.create_info_panel() + + def create_info_panel(self): + # Use stacked widget instead + # of reparenting the layout. + + infoWidget = QtGui.QStackedWidget() + + info = QtGui.QWidget() + layout = QtGui.QVBoxLayout() + + displayName = QtGui.QLabel("") + description = QtGui.QLabel("") + enrollment_policy = QtGui.QLabel("") + + # XXX set stylesheet... + # prettify a little bit. + # bigger fonts and so on... + + # We could use a QFrame here + + layout.addWidget(displayName) + layout.addWidget(description) + layout.addWidget(enrollment_policy) + layout.addStretch(1) + + info.setLayout(layout) + infoWidget.addWidget(info) + + pageLayout = QtGui.QVBoxLayout() + pageLayout.addWidget(infoWidget) + self.setLayout(pageLayout) + + # add refs to self to allow for + # updates. + # Watch out! Have to get rid of these references! + # this should be better handled with signals !! + self.displayName = displayName + self.description = description + self.enrollment_policy = enrollment_policy + + def show_provider_info(self): + + # XXX get multilingual objects + # directly from the config object + + lang = "en" + pconfig = self.wizard().providerconfig + + dn = pconfig.get('display_name') + display_name = dn[lang] if dn else '' + domain_name = self.field('provider_domain') + + self.displayName.setText( + "<b>%s</b> https://%s" % (display_name, domain_name)) + + desc = pconfig.get('description') + description_text = desc[lang] if desc else '' + self.description.setText( + "<i>%s</i>" % description_text) + + enroll = pconfig.get('enrollment_policy') + if enroll: + self.enrollment_policy.setText( + 'enrollment policy: %s' % enroll) + + def nextId(self): + wizard = self.wizard() + next_ = "providersetupvalidation" + return wizard.get_page_index(next_) + + def initializePage(self): + self.show_provider_info() diff --git a/src/leap/gui/firstrun/providerselect.py b/src/leap/gui/firstrun/providerselect.py new file mode 100644 index 00000000..a4be51a9 --- /dev/null +++ b/src/leap/gui/firstrun/providerselect.py @@ -0,0 +1,472 @@ +""" +Select Provider Page, used in First Run Wizard +""" +import logging + +import requests + +from PyQt4 import QtCore +from PyQt4 import QtGui + +from leap.base import exceptions as baseexceptions +#from leap.crypto import certs +from leap.eip import exceptions as eipexceptions +from leap.gui.progress import InlineValidationPage +from leap.gui import styles +from leap.gui.utils import delay +from leap.util.web import get_https_domain_and_port + +from leap.gui.constants import APP_LOGO + +logger = logging.getLogger(__name__) + + +class SelectProviderPage(InlineValidationPage): + + launchChecks = QtCore.pyqtSignal() + + def __init__(self, parent=None, providers=None): + super(SelectProviderPage, self).__init__(parent) + self.current_page = 'providerselection' + + self.setTitle(self.tr("Enter Provider")) + self.setSubTitle(self.tr( + "Please enter the domain of the provider you want " + "to use for your connection.") + ) + self.setPixmap( + QtGui.QWizard.LogoPixmap, + QtGui.QPixmap(APP_LOGO)) + + self.did_cert_check = False + + self.is_done = False + + self.setupSteps() + self.setupUI() + + self.launchChecks.connect( + self.launch_checks) + + self.providerNameEdit.editingFinished.connect( + lambda: self.providerCheckButton.setFocus(True)) + + def setupUI(self): + """ + initializes the UI + """ + providerNameLabel = QtGui.QLabel("h&ttps://") + # note that we expect the bare domain name + # we will add the scheme later + providerNameEdit = QtGui.QLineEdit() + providerNameEdit.cursorPositionChanged.connect( + self.reset_validation_status) + providerNameLabel.setBuddy(providerNameEdit) + + # add regex validator + providerDomainRe = QtCore.QRegExp(r"^[a-z\d_-.]+$") + providerNameEdit.setValidator( + QtGui.QRegExpValidator(providerDomainRe, self)) + self.providerNameEdit = providerNameEdit + + # Eventually we will seed a list of + # well known providers here. + + #providercombo = QtGui.QComboBox() + #if providers: + #for provider in providers: + #providercombo.addItem(provider) + #providerNameSelect = providercombo + + self.registerField("provider_domain*", self.providerNameEdit) + #self.registerField('provider_name_index', providerNameSelect) + + validationMsg = QtGui.QLabel("") + validationMsg.setStyleSheet(styles.ErrorLabelStyleSheet) + self.validationMsg = validationMsg + providerCheckButton = QtGui.QPushButton(self.tr("chec&k!")) + self.providerCheckButton = providerCheckButton + + # cert info + + # this is used in the callback + # for the checkbox changes. + # tricky, since the first time came + # from the exception message. + # should get string from exception too! + self.bad_cert_status = self.tr( + "Server certificate could not be verified.") + + self.certInfo = QtGui.QLabel("") + self.certInfo.setWordWrap(True) + self.certWarning = QtGui.QLabel("") + self.trustProviderCertCheckBox = QtGui.QCheckBox( + "&Trust this provider certificate.") + + self.trustProviderCertCheckBox.stateChanged.connect( + self.onTrustCheckChanged) + self.providerNameEdit.textChanged.connect( + self.onProviderChanged) + self.providerCheckButton.clicked.connect( + self.onCheckButtonClicked) + + layout = QtGui.QGridLayout() + layout.addWidget(validationMsg, 0, 2) + layout.addWidget(providerNameLabel, 1, 1) + layout.addWidget(providerNameEdit, 1, 2) + layout.addWidget(providerCheckButton, 1, 3) + + # add certinfo group + # XXX not shown now. should move to validation box. + #layout.addWidget(certinfoGroup, 4, 1, 4, 2) + #self.certinfoGroup = certinfoGroup + #self.certinfoGroup.hide() + + # add validation frame + self.setupValidationFrame() + layout.addWidget(self.valFrame, 4, 2, 4, 2) + self.valFrame.hide() + + self.setLayout(layout) + + # certinfo + + def setupCertInfoGroup(self): + # XXX not used now. + certinfoGroup = QtGui.QGroupBox( + self.tr("Certificate validation")) + certinfoLayout = QtGui.QVBoxLayout() + certinfoLayout.addWidget(self.certInfo) + certinfoLayout.addWidget(self.certWarning) + certinfoLayout.addWidget(self.trustProviderCertCheckBox) + certinfoGroup.setLayout(certinfoLayout) + self.certinfoGroup = self.certinfoGroup + + # progress frame + + def setupValidationFrame(self): + qframe = QtGui.QFrame + valFrame = qframe() + valFrame.setFrameStyle(qframe.NoFrame) + valframeLayout = QtGui.QVBoxLayout() + zeros = (0, 0, 0, 0) + valframeLayout.setContentsMargins(*zeros) + + valframeLayout.addWidget(self.stepsTableWidget) + valFrame.setLayout(valframeLayout) + self.valFrame = valFrame + + @QtCore.pyqtSlot() + def onDisableCheckButton(self): + #print 'CHECK BUTTON DISABLED!!!' + self.providerCheckButton.setDisabled(True) + + @QtCore.pyqtSlot() + def launch_checks(self): + self.do_checks() + + def onCheckButtonClicked(self): + QtCore.QMetaObject.invokeMethod( + self, "onDisableCheckButton") + + QtCore.QMetaObject.invokeMethod( + self, "showStepsFrame") + + delay(self, "launch_checks") + + def _do_checks(self): + """ + generator that yields actual checks + that are executed in a separate thread + """ + + wizard = self.wizard() + full_domain = self.providerNameEdit.text() + + # we check if we have a port in the domain string. + domain, port = get_https_domain_and_port(full_domain) + _domain = u"%s:%s" % (domain, port) if port != 443 else unicode(domain) + + netchecker = wizard.netchecker() + + providercertchecker = wizard.providercertchecker() + eipconfigchecker = wizard.eipconfigchecker(domain=_domain) + + yield(("head_sentinel", 0), lambda: None) + + ######################## + # 1) try name resolution + ######################## + + def namecheck(): + """ + in which we check if + we are able to name resolve + this domain + """ + try: + netchecker.check_name_resolution( + domain) + + except baseexceptions.LeapException as exc: + logger.error(exc.message) + return self.fail(exc.usermessage) + + except Exception as exc: + return self.fail(exc.message) + + else: + return True + + logger.debug('checking name resolution') + yield((self.tr("checking domain name"), 20), namecheck) + + ######################### + # 2) try https connection + ######################### + + def httpscheck(): + """ + in which we check + if the provider + is offering service over + https + """ + try: + providercertchecker.is_https_working( + "https://%s" % _domain, + verify=True) + + except eipexceptions.HttpsBadCertError as exc: + logger.debug('exception') + return self.fail(exc.usermessage) + # XXX skipping for now... + ############################################## + # We had this validation logic + # in the provider selection page before + ############################################## + #if self.trustProviderCertCheckBox.isChecked(): + #pass + #else: + #fingerprint = certs.get_cert_fingerprint( + #domain=domain, sep=" ") + + # it's ok if we've trusted this fgprt before + #trustedcrts = wizard.trusted_certs + #if trustedcrts and \ + # fingerprint.replace(' ', '') in trustedcrts: + #pass + #else: + # let your user face panick :P + #self.add_cert_info(fingerprint) + #self.did_cert_check = True + #self.completeChanged.emit() + #return False + + except baseexceptions.LeapException as exc: + return self.fail(exc.usermessage) + + except Exception as exc: + return self.fail(exc.message) + + else: + return True + + logger.debug('checking https connection') + yield((self.tr("checking https connection"), 40), httpscheck) + + ################################## + # 3) try download provider info... + ################################## + + def fetchinfo(): + try: + # XXX we already set _domain in the initialization + # so it should not be needed here. + eipconfigchecker.fetch_definition(domain=_domain) + wizard.set_providerconfig( + eipconfigchecker.defaultprovider.config) + except requests.exceptions.SSLError: + # XXX we should have catched this before. + # but cert checking is broken. + return self.fail(self.tr( + "Could not get info from provider.")) + except requests.exceptions.ConnectionError: + return self.fail(self.tr( + "Could not download provider info " + "(refused conn.).")) + + except Exception as exc: + return self.fail( + self.tr(exc.message)) + else: + return True + + yield((self.tr("fetching provider info"), 80), fetchinfo) + + # done! + + self.is_done = True + yield(("end_sentinel", 100), lambda: None) + + def on_checks_validation_ready(self): + """ + called after _do_checks has finished. + """ + self.domain_checked = True + self.completeChanged.emit() + # let's set focus... + if self.is_done: + self.wizard().clean_validation_error(self.current_page) + nextbutton = self.wizard().button(QtGui.QWizard.NextButton) + nextbutton.setFocus() + else: + self.providerNameEdit.setFocus() + + # cert trust verification + # (disabled for now) + + def is_insecure_cert_trusted(self): + return self.trustProviderCertCheckBox.isChecked() + + def onTrustCheckChanged(self, state): + checked = False + if state == 2: + checked = True + + if checked: + self.reset_validation_status() + else: + self.set_validation_status(self.bad_cert_status) + + # trigger signal to redraw next button + self.completeChanged.emit() + + def add_cert_info(self, certinfo): + self.certWarning.setText( + "Do you want to <b>trust this provider certificate?</b>") + self.certInfo.setText( + 'SHA-256 fingerprint: <i>%s</i><br>' % certinfo) + self.certInfo.setWordWrap(True) + self.certinfoGroup.show() + + def onProviderChanged(self, text): + self.is_done = False + provider = self.providerNameEdit.text() + if provider: + self.providerCheckButton.setDisabled(False) + else: + self.providerCheckButton.setDisabled(True) + self.completeChanged.emit() + + def reset_validation_status(self): + """ + empty the validation msg + and clean the inline validation widget. + """ + self.validationMsg.setText('') + self.steps.removeAllSteps() + self.clearTable() + self.domain_checked = False + + # pagewizard methods + + def isComplete(self): + provider = self.providerNameEdit.text() + + if not self.is_done: + return False + + if not provider: + return False + else: + if self.is_insecure_cert_trusted(): + return True + if not self.did_cert_check: + if self.is_done: + # XXX sure? + return True + return False + + def populateErrors(self): + # XXX could move this to ValidationMixin + # with some defaults for the validating fields + # (now it only allows one field, manually specified) + + #logger.debug('getting errors') + errors = self.wizard().get_validation_error( + self.current_page) + if errors: + bad_str = getattr(self, 'bad_string', None) + cur_str = self.providerNameEdit.text() + showerr = self.validationMsg.setText + markred = lambda: self.providerNameEdit.setStyleSheet( + styles.ErrorLineEdit) + umarkrd = lambda: self.providerNameEdit.setStyleSheet( + styles.RegularLineEdit) + if bad_str is None: + # first time we fall here. + # save the current bad_string value + self.bad_string = cur_str + showerr(errors) + markred() + else: + # not the first time + # XXX hey, this is getting convoluted. + # roll out this. + # but be careful about all the possibilities + # with going back and forth once you + # enter a domain. + if cur_str == bad_str: + showerr(errors) + markred() + else: + if not getattr(self, 'domain_checked', None): + showerr('') + umarkrd() + else: + self.bad_string = cur_str + showerr(errors) + + def cleanup_errormsg(self): + """ + we reset bad_string to None + should be called before leaving the page + """ + self.bad_string = None + self.domain_checked = False + + def paintEvent(self, event): + """ + we hook our populate errors + on paintEvent because we need it to catch + when user enters the page coming from next, + and initializePage does not cover that case. + Maybe there's a better event to hook upon. + """ + super(SelectProviderPage, self).paintEvent(event) + self.populateErrors() + + def initializePage(self): + self.validationMsg.setText('') + if hasattr(self, 'certinfoGroup'): + # XXX remove ? + self.certinfoGroup.hide() + self.is_done = False + self.providerCheckButton.setDisabled(True) + self.valFrame.hide() + self.steps.removeAllSteps() + self.clearTable() + + def validatePage(self): + # some cleanup before we leave the page + self.cleanup_errormsg() + + # go + return True + + def nextId(self): + wizard = self.wizard() + if not wizard: + return + return wizard.get_page_index('providerinfo') diff --git a/src/leap/gui/firstrun/providersetup.py b/src/leap/gui/firstrun/providersetup.py new file mode 100644 index 00000000..1a362794 --- /dev/null +++ b/src/leap/gui/firstrun/providersetup.py @@ -0,0 +1,174 @@ +""" +Provider Setup Validation Page, +used if First Run Wizard +""" +import logging + +from PyQt4 import QtGui + +from leap.base import exceptions as baseexceptions +from leap.gui.progress import ValidationPage + +from leap.gui.constants import APP_LOGO + +logger = logging.getLogger(__name__) + + +class ProviderSetupValidationPage(ValidationPage): + def __init__(self, parent=None): + super(ProviderSetupValidationPage, self).__init__(parent) + self.current_page = "providersetupvalidation" + + # XXX needed anymore? + is_signup = self.field("is_signup") + self.is_signup = is_signup + + self.setTitle(self.tr("Provider setup")) + self.setSubTitle( + self.tr("Doing autoconfig.")) + + self.setPixmap( + QtGui.QWizard.LogoPixmap, + QtGui.QPixmap(APP_LOGO)) + + def _do_checks(self): + """ + generator that yields actual checks + that are executed in a separate thread + """ + + full_domain = self.field('provider_domain') + wizard = self.wizard() + pconfig = wizard.providerconfig + + #pCertChecker = wizard.providercertchecker + #certchecker = pCertChecker(domain=full_domain) + pCertChecker = wizard.providercertchecker( + domain=full_domain) + + yield(("head_sentinel", 0), lambda: None) + + ######################## + # 1) fetch ca cert + ######################## + + def fetchcacert(): + if pconfig: + ca_cert_uri = pconfig.get('ca_cert_uri').geturl() + else: + ca_cert_uri = None + + # XXX check scheme == "https" + # XXX passing verify == False because + # we have trusted right before. + # We should check it's the same domain!!! + # (Check with the trusted fingerprints dict + # or something smart) + try: + pCertChecker.download_ca_cert( + uri=ca_cert_uri, + verify=False) + + except baseexceptions.LeapException as exc: + logger.error(exc.message) + # XXX this should be _ method + return self.fail(self.tr(exc.usermessage)) + + except Exception as exc: + return self.fail(exc.message) + + else: + return True + + yield((self.tr('Fetching CA certificate'), 30), + fetchcacert) + + ######################### + # 2) check CA fingerprint + ######################### + + def checkcafingerprint(): + # XXX get the real thing!!! + pass + #ca_cert_fingerprint = pconfig.get('ca_cert_fingerprint', None) + + # XXX get fingerprint dict (types) + #sha256_fpr = ca_cert_fingerprint.split('=')[1] + + #validate_fpr = pCertChecker.check_ca_cert_fingerprint( + #fingerprint=sha256_fpr) + #if not validate_fpr: + # XXX update validationMsg + # should catch exception + #return False + + yield((self.tr("Checking CA fingerprint"), 60), + checkcafingerprint) + + ######################### + # 2) check CA fingerprint + ######################### + + def validatecacert(): + pass + #api_uri = pconfig.get('api_uri', None) + #try: + #api_cert_verified = pCertChecker.verify_api_https(api_uri) + #except requests.exceptions.SSLError as exc: + #logger.error('BUG #638. %s' % exc.message) + # XXX RAISE! See #638 + # bypassing until the hostname is fixed. + # We probably should raise yet-another-warning + # here saying user that the hostname "XX.XX.XX.XX' does not + # match 'foo.bar.baz' + #api_cert_verified = True + + #if not api_cert_verified: + # XXX update validationMsg + # should catch exception + #return False + + #??? + #ca_cert_path = checker.ca_cert_path + + yield((self.tr('Validating api certificate'), 90), validatecacert) + + self.set_done() + yield(('end_sentinel', 100), lambda: None) + + def on_checks_validation_ready(self): + """ + called after _do_checks has finished + (connected to checker thread finished signal) + """ + prevpage = "providerselection" if self.is_signup else "login" + wizard = self.wizard() + + if self.errors: + logger.debug('going back with errors') + name, first_error = self.pop_first_error() + wizard.set_validation_error( + prevpage, + first_error) + # XXX don't go back, signal error + #self.go_back() + else: + logger.debug('should be going next, wait on user') + #self.go_next() + + def nextId(self): + wizard = self.wizard() + if not wizard: + return + is_signup = self.field('is_signup') + if is_signup is True: + next_ = 'signup' + if is_signup is False: + # XXX bad name. change to connect again. + next_ = 'signupvalidation' + return wizard.get_page_index(next_) + + def initializePage(self): + super(ProviderSetupValidationPage, self).initializePage() + self.set_undone() + self.completeChanged.emit() diff --git a/src/leap/gui/firstrun/register.py b/src/leap/gui/firstrun/register.py new file mode 100644 index 00000000..e85723cb --- /dev/null +++ b/src/leap/gui/firstrun/register.py @@ -0,0 +1,368 @@ +""" +Register User Page, used in First Run Wizard +""" +import json +import logging +import socket + +import requests + +from PyQt4 import QtCore +from PyQt4 import QtGui + +from leap.gui.firstrun.mixins import UserFormMixIn + +logger = logging.getLogger(__name__) + +from leap.base import auth +from leap.gui import styles +from leap.gui.constants import APP_LOGO, BARE_USERNAME_REGEX +from leap.gui.progress import InlineValidationPage +from leap.gui.styles import ErrorLabelStyleSheet + + +class RegisterUserPage(InlineValidationPage, UserFormMixIn): + + def __init__(self, parent=None): + + super(RegisterUserPage, self).__init__(parent) + self.current_page = "signup" + + self.setTitle(self.tr("Sign Up")) + # subtitle is set in the initializePage + + self.setPixmap( + QtGui.QWizard.LogoPixmap, + QtGui.QPixmap(APP_LOGO)) + + # commit page means there's no way back after this... + # XXX should change the text on the "commit" button... + self.setCommitPage(True) + + self.setupSteps() + self.setupUI() + self.do_confirm_next = False + self.focused_field = False + + def setupUI(self): + userNameLabel = QtGui.QLabel("User &name:") + userNameLineEdit = QtGui.QLineEdit() + userNameLineEdit.cursorPositionChanged.connect( + self.reset_validation_status) + userNameLabel.setBuddy(userNameLineEdit) + + # let's add regex validator + usernameRe = QtCore.QRegExp(BARE_USERNAME_REGEX) + userNameLineEdit.setValidator( + QtGui.QRegExpValidator(usernameRe, self)) + self.userNameLineEdit = userNameLineEdit + + userPasswordLabel = QtGui.QLabel("&Password:") + self.userPasswordLineEdit = QtGui.QLineEdit() + self.userPasswordLineEdit.setEchoMode( + QtGui.QLineEdit.Password) + userPasswordLabel.setBuddy(self.userPasswordLineEdit) + + userPassword2Label = QtGui.QLabel("Password (again):") + self.userPassword2LineEdit = QtGui.QLineEdit() + self.userPassword2LineEdit.setEchoMode( + QtGui.QLineEdit.Password) + userPassword2Label.setBuddy(self.userPassword2LineEdit) + + rememberPasswordCheckBox = QtGui.QCheckBox( + "&Remember username and password.") + rememberPasswordCheckBox.setChecked(True) + + self.registerField('userName*', self.userNameLineEdit) + self.registerField('userPassword*', self.userPasswordLineEdit) + self.registerField('userPassword2*', self.userPassword2LineEdit) + + # XXX missing password confirmation + # XXX validator! + + self.registerField('rememberPassword', rememberPasswordCheckBox) + + layout = QtGui.QGridLayout() + layout.setColumnMinimumWidth(0, 20) + + validationMsg = QtGui.QLabel("") + validationMsg.setStyleSheet(ErrorLabelStyleSheet) + + self.validationMsg = validationMsg + + layout.addWidget(validationMsg, 0, 3) + layout.addWidget(userNameLabel, 1, 0) + layout.addWidget(self.userNameLineEdit, 1, 3) + layout.addWidget(userPasswordLabel, 2, 0) + layout.addWidget(userPassword2Label, 3, 0) + layout.addWidget(self.userPasswordLineEdit, 2, 3) + layout.addWidget(self.userPassword2LineEdit, 3, 3) + layout.addWidget(rememberPasswordCheckBox, 4, 3, 4, 4) + + # add validation frame + self.setupValidationFrame() + layout.addWidget(self.valFrame, 5, 2, 5, 2) + self.valFrame.hide() + + self.setLayout(layout) + self.commitText("Sign up!") + + # commit button + + def commitText(self, text): + # change "commit" button text + self.setButtonText( + QtGui.QWizard.CommitButton, text) + + @property + def commitButton(self): + return self.wizard().button(QtGui.QWizard.CommitButton) + + def commitFocus(self): + self.commitButton.setFocus() + + def disableCommitButton(self): + self.commitButton.setDisabled(True) + + def disableFields(self): + for field in (self.userNameLineEdit, + self.userPasswordLineEdit, + self.userPassword2LineEdit): + field.setDisabled(True) + + # error painting + + def markRedAndGetFocus(self, field): + field.setStyleSheet(styles.ErrorLineEdit) + if not self.focused_field: + self.focused_field = True + field.setFocus(QtCore.Qt.OtherFocusReason) + + def markRegular(self, field): + field.setStyleSheet(styles.RegularLineEdit) + + def populateErrors(self): + def showerr(text): + self.validationMsg.setText(text) + err_lower = text.lower() + if "username" in err_lower: + self.markRedAndGetFocus( + self.userNameLineEdit) + if "password" in err_lower: + self.markRedAndGetFocus( + self.userPasswordLineEdit) + + def unmarkred(): + for field in (self.userNameLineEdit, + self.userPasswordLineEdit, + self.userPassword2LineEdit): + self.markRegular(field) + + errors = self.wizard().get_validation_error( + self.current_page) + if errors: + bad_str = getattr(self, 'bad_string', None) + cur_str = self.userNameLineEdit.text() + #prev_er = getattr(self, 'prevalidation_error', None) + + if bad_str is None: + # first time we fall here. + # save the current bad_string value + self.bad_string = cur_str + showerr(errors) + else: + #if prev_er: + #showerr(prev_er) + #return + # not the first time + if cur_str == bad_str: + showerr(errors) + else: + self.focused_field = False + showerr('') + unmarkred() + else: + # no errors + self.focused_field = False + unmarkred() + + def cleanup_errormsg(self): + """ + we reset bad_string to None + should be called before leaving the page + """ + self.bad_string = None + + def paintEvent(self, event): + """ + we hook our populate errors + on paintEvent because we need it to catch + when user enters the page coming from next, + and initializePage does not cover that case. + Maybe there's a better event to hook upon. + """ + super(RegisterUserPage, self).paintEvent(event) + self.populateErrors() + + def _do_checks(self): + """ + generator that yields actual checks + that are executed in a separate thread + """ + provider = self.field('provider_domain') + username = self.userNameLineEdit.text() + password = self.userPasswordLineEdit.text() + password2 = self.userPassword2LineEdit.text() + + def checkpass(): + # we better have here + # some call to a password checker... + # to assess strenght and avoid silly stuff. + + if password != password2: + return self.fail(self.tr('Password does not match..')) + + if len(password) < 6: + #self.set_prevalidation_error('Password too short.') + return self.fail(self.tr('Password too short.')) + + if password == "123456": + # joking, but not too much. + #self.set_prevalidation_error('Password too obvious.') + return self.fail(self.tr('Password too obvious.')) + + # go + return True + + yield(("head_sentinel", 0), checkpass) + + # XXX should emit signal for .show the frame! + # XXX HERE! + + ################################################## + # 1) register user + ################################################## + + # show the frame before going on... + QtCore.QMetaObject.invokeMethod( + self, "showStepsFrame") + + def register(): + # XXX FIXME! + verify = False + + signup = auth.LeapSRPRegister( + schema="https", + provider=provider, + verify=verify) + try: + ok, req = signup.register_user( + username, password) + + except socket.timeout: + return self.fail( + self.tr("Error connecting to provider (timeout)")) + + except requests.exceptions.ConnectionError as exc: + logger.error(exc.message) + return self.fail( + self.tr('Error Connecting to provider (connerr).')) + except Exception as exc: + return self.fail(exc.message) + + # XXX check for != OK instead??? + + if req.status_code in (404, 500): + return self.fail( + self.tr( + "Error during registration (%s)") % req.status_code) + + validation_msgs = json.loads(req.content) + errors = validation_msgs.get('errors', None) + logger.debug('validation errors: %s' % validation_msgs) + + if errors and errors.get('login', None): + # XXX this sometimes catch the blank username + # but we're not allowing that (soon) + return self.fail( + self.tr('Username not available.')) + + logger.debug('registering user') + yield(("registering with provider", 40), register) + + self.set_done() + yield(("end_sentinel", 0), lambda: None) + + def on_checks_validation_ready(self): + """ + after checks + """ + if self.is_done(): + self.disableFields() + self.cleanup_errormsg() + self.clean_wizard_errors(self.current_page) + # make the user confirm the transition + # to next page. + self.commitText('Connect!') + self.commitFocus() + self.green_validation_status() + self.do_confirm_next = True + + def green_validation_status(self): + val = self.validationMsg + val.setText(self.tr('Registration succeeded!')) + val.setStyleSheet(styles.GreenLineEdit) + + def reset_validation_status(self): + """ + empty the validation msg + and clean the inline validation widget. + """ + self.validationMsg.setText('') + self.steps.removeAllSteps() + self.clearTable() + + # pagewizard methods + + def validatePage(self): + """ + if not register done, do checks. + if done, wait for click. + """ + self.disableCommitButton() + self.cleanup_errormsg() + self.clean_wizard_errors(self.current_page) + + # After a successful validation + # (ie, success register with server) + # we change the commit button text + # and set this flag to True. + if self.do_confirm_next: + return True + + if not self.is_done(): + # calls checks, which after successful + # execution will call on_checks_validation_ready + self.reset_validation_status() + self.do_checks() + + return self.is_done() + + def initializePage(self): + """ + inits wizard page + """ + provider = self.field('provider_domain') + self.setSubTitle( + self.tr("Register a new user with provider %s.") % + provider) + self.validationMsg.setText('') + self.userPassword2LineEdit.setText('') + self.valFrame.hide() + + def nextId(self): + wizard = self.wizard() + if not wizard: + return + # XXX this should be called connect + return wizard.get_page_index('signupvalidation') diff --git a/src/leap/gui/firstrun/regvalidation.py b/src/leap/gui/firstrun/regvalidation.py new file mode 100644 index 00000000..0e67834b --- /dev/null +++ b/src/leap/gui/firstrun/regvalidation.py @@ -0,0 +1,204 @@ +""" +Provider Setup Validation Page, +used in First Run Wizard +""" +# XXX This page is called regvalidation +# but it's implementing functionality in the former +# connect page. +# We should remame it to connect again, when we integrate +# the login branch of the wizard. + +import logging +#import json +#import socket + +from PyQt4 import QtGui + +#import requests + +from leap.gui.progress import ValidationPage +from leap.util.web import get_https_domain_and_port + +from leap.base import auth +from leap.gui.constants import APP_LOGO + +logger = logging.getLogger(__name__) + + +class RegisterUserValidationPage(ValidationPage): + + def __init__(self, parent=None): + super(RegisterUserValidationPage, self).__init__(parent) + + title = "Connecting..." + # XXX uh... really? + subtitle = "Checking connection with provider." + + self.setTitle(title) + self.setSubTitle(subtitle) + + self.setPixmap( + QtGui.QWizard.LogoPixmap, + QtGui.QPixmap(APP_LOGO)) + + def _do_checks(self, update_signal=None): + """ + executes actual checks in a separate thread + + we initialize the srp protocol register + and try to register user. + """ + wizard = self.wizard() + full_domain = self.field('provider_domain') + domain, port = get_https_domain_and_port(full_domain) + _domain = u"%s:%s" % (domain, port) if port != 443 else unicode(domain) + + # FIXME #BUG 638 FIXME FIXME FIXME + verify = False # !!!!!!!!!!!!!!!! + # FIXME #BUG 638 FIXME FIXME FIXME + + ########################################### + # Set Credentials. + # username and password are in different fields + # if they were stored in log_in or sign_up pages. + is_signup = self.field("is_signup") + + unamek_base = 'userName' + passwk_base = 'userPassword' + unamek = 'login_%s' % unamek_base if not is_signup else unamek_base + passwk = 'login_%s' % passwk_base if not is_signup else passwk_base + + username = self.field(unamek) + password = self.field(passwk) + credentials = username, password + + eipconfigchecker = wizard.eipconfigchecker(domain=_domain) + #XXX change for _domain (sanitized) + pCertChecker = wizard.providercertchecker( + domain=full_domain) + + yield(("head_sentinel", 0), lambda: None) + + ################################################## + # 1) fetching eip service config + ################################################## + def fetcheipconf(): + try: + eipconfigchecker.fetch_eip_service_config( + domain=full_domain) + + # XXX get specific exception + except Exception as exc: + return self.fail(exc.message) + + yield((self.tr("Fetching provider config..."), 40), + fetcheipconf) + + ################################################## + # 2) getting client certificate + ################################################## + + def fetcheipcert(): + try: + pCertChecker.download_new_client_cert( + credentials=credentials, + verify=verify) + + except auth.SRPAuthenticationError as exc: + return self.fail(self.tr( + "Authentication error: %s" % exc.message)) + else: + return True + + yield((self.tr("Fetching eip certificate"), 80), + fetcheipcert) + + ################ + # end ! + ################ + self.set_done() + yield(("end_sentinel", 100), lambda: None) + + def on_checks_validation_ready(self): + """ + called after _do_checks has finished + (connected to checker thread finished signal) + """ + # this should be called CONNECT PAGE AGAIN. + # here we go! :) + full_domain = self.field('provider_domain') + domain, port = get_https_domain_and_port(full_domain) + _domain = u"%s:%s" % (domain, port) if port != 443 else unicode(domain) + self.run_eip_checks_for_provider_and_connect(_domain) + + def run_eip_checks_for_provider_and_connect(self, domain): + wizard = self.wizard() + conductor = wizard.conductor + start_eip_signal = getattr( + wizard, + 'start_eipconnection_signal', None) + + if conductor: + conductor.set_provider_domain(domain) + conductor.run_checks() + self.conductor = conductor + errors = self.eip_error_check() + if not errors and start_eip_signal: + start_eip_signal.emit() + + else: + logger.warning( + "No conductor found. This means that " + "probably the wizard has been launched " + "in an stand-alone way.") + + # XXX look for a better place to signal + # we are done. + # We could probably have a fake validatePage + # that checks if the domain transfer has been + # done to conductor object, triggers the start_signal + # and does the go_next() + self.set_done() + + def eip_error_check(self): + """ + a version of the main app error checker, + but integrated within the connecting page of the wizard. + consumes the conductor error queue. + pops errors, and add those to the wizard page + """ + logger.debug('eip error check from connecting page') + errq = self.conductor.error_queue + # XXX missing! + + def _do_validation(self): + """ + called after _do_checks has finished + (connected to checker thread finished signal) + """ + is_signup = self.field("is_signup") + prevpage = "signup" if is_signup else "login" + + wizard = self.wizard() + if self.errors: + logger.debug('going back with errors') + logger.error(self.errors) + name, first_error = self.pop_first_error() + wizard.set_validation_error( + prevpage, + first_error) + self.go_back() + else: + logger.debug('should go next, wait for user to click next') + #self.go_next() + + def nextId(self): + wizard = self.wizard() + if not wizard: + return + return wizard.get_page_index('lastpage') + + def initializePage(self): + super(RegisterUserValidationPage, self).initializePage() + self.set_undone() + self.completeChanged.emit() diff --git a/src/leap/gui/firstrun/tests/integration/fake_provider.py b/src/leap/gui/firstrun/tests/integration/fake_provider.py new file mode 100755 index 00000000..33ee0ee6 --- /dev/null +++ b/src/leap/gui/firstrun/tests/integration/fake_provider.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python +"""A server faking some of the provider resources and apis, +used for testing Leap Client requests + +It needs that you create a subfolder named 'certs', +and that you place the following files: + +[ ] certs/leaptestscert.pem +[ ] certs/leaptestskey.pem +[ ] certs/cacert.pem +[ ] certs/openvpn.pem + +[ ] provider.json +[ ] eip-service.json +""" +# XXX NOTE: intended for manual debug. +# I intend to include this as a regular test after 0.2.0 release +# (so we can add twisted as a dep there) +import binascii +import json +import os +import sys + +# python SRP LIB (! important MUST be >=1.0.1 !) +import srp + +# GnuTLS Example -- is not working as expected +from gnutls import crypto +from gnutls.constants import COMP_LZO, COMP_DEFLATE, COMP_NULL +from gnutls.interfaces.twisted import X509Credentials + +# Going with OpenSSL as a workaround instead +# But we DO NOT want to introduce this dependency. +from OpenSSL import SSL + +from zope.interface import Interface, Attribute, implements + +from twisted.web.server import Site +from twisted.web.static import File +from twisted.web.resource import Resource +from twisted.internet import reactor + +# See +# http://twistedmatrix.com/documents/current/web/howto/web-in-60/index.htmln +# for more examples + +""" +Testing the FAKE_API: +##################### + + 1) register an user + >> curl -d "user[login]=me" -d "user[password_salt]=foo" \ + -d "user[password_verifier]=beef" http://localhost:8000/1/users.json + << {"errors": null} + + 2) check that if you try to register again, it will fail: + >> curl -d "user[login]=me" -d "user[password_salt]=foo" \ + -d "user[password_verifier]=beef" http://localhost:8000/1/users.json + << {"errors": {"login": "already taken!"}} + +""" + +# Globals to mock user/sessiondb + +USERDB = {} +SESSIONDB = {} + + +safe_unhexlify = lambda x: binascii.unhexlify(x) \ + if (len(x) % 2 == 0) else binascii.unhexlify('0' + x) + + +class IUser(Interface): + login = Attribute("User login.") + salt = Attribute("Password salt.") + verifier = Attribute("Password verifier.") + session = Attribute("Session.") + svr = Attribute("Server verifier.") + + +class User(object): + implements(IUser) + + def __init__(self, login, salt, verifier): + self.login = login + self.salt = salt + self.verifier = verifier + self.session = None + + def set_server_verifier(self, svr): + self.svr = svr + + def set_session(self, session): + SESSIONDB[session] = self + self.session = session + + +class FakeUsers(Resource): + def __init__(self, name): + self.name = name + + def render_POST(self, request): + args = request.args + + login = args['user[login]'][0] + salt = args['user[password_salt]'][0] + verifier = args['user[password_verifier]'][0] + + if login in USERDB: + return "%s\n" % json.dumps( + {'errors': {'login': 'already taken!'}}) + + print login, verifier, salt + user = User(login, salt, verifier) + USERDB[login] = user + return json.dumps({'errors': None}) + + +def get_user(request): + login = request.args.get('login') + if login: + user = USERDB.get(login[0], None) + if user: + return user + + session = request.getSession() + user = SESSIONDB.get(session, None) + return user + + +class FakeSession(Resource): + def __init__(self, name): + self.name = name + + def render_GET(self, request): + return "%s\n" % json.dumps({'errors': None}) + + def render_POST(self, request): + + user = get_user(request) + + if not user: + # XXX get real error from demo provider + return json.dumps({'errors': 'no such user'}) + + A = request.args['A'][0] + + _A = safe_unhexlify(A) + _salt = safe_unhexlify(user.salt) + _verifier = safe_unhexlify(user.verifier) + + svr = srp.Verifier( + user.login, + _salt, + _verifier, + _A, + hash_alg=srp.SHA256, + ng_type=srp.NG_1024) + + s, B = svr.get_challenge() + + _B = binascii.hexlify(B) + + print 'login = %s' % user.login + print 'salt = %s' % user.salt + print 'len(_salt) = %s' % len(_salt) + print 'vkey = %s' % user.verifier + print 'len(vkey) = %s' % len(_verifier) + print 's = %s' % binascii.hexlify(s) + print 'B = %s' % _B + print 'len(B) = %s' % len(_B) + + session = request.getSession() + user.set_session(session) + user.set_server_verifier(svr) + + # yep, this is tricky. + # some things are *already* unhexlified. + data = { + 'salt': user.salt, + 'B': _B, + 'errors': None} + + return json.dumps(data) + + def render_PUT(self, request): + + # XXX check session??? + user = get_user(request) + + if not user: + print 'NO USER' + return json.dumps({'errors': 'no such user'}) + + data = request.content.read() + auth = data.split("client_auth=") + M = auth[1] if len(auth) > 1 else None + # if not H, return + if not M: + return json.dumps({'errors': 'no M proof passed by client'}) + + svr = user.svr + HAMK = svr.verify_session(binascii.unhexlify(M)) + if HAMK is None: + print 'verification failed!!!' + raise Exception("Authentication failed!") + #import ipdb;ipdb.set_trace() + + assert svr.authenticated() + print "***" + print 'server authenticated user SRP!' + print "***" + + return json.dumps( + {'M2': binascii.hexlify(HAMK), 'errors': None}) + + +class API_Sessions(Resource): + def getChild(self, name, request): + return FakeSession(name) + + +def get_certs_path(): + script_path = os.path.realpath(os.path.dirname(sys.argv[0])) + certs_path = os.path.join(script_path, 'certs') + return certs_path + + +def get_TLS_credentials(): + # XXX this is giving errors + # XXX REview! We want to use gnutls! + certs_path = get_certs_path() + + cert = crypto.X509Certificate( + open(certs_path + '/leaptestscert.pem').read()) + key = crypto.X509PrivateKey( + open(certs_path + '/leaptestskey.pem').read()) + ca = crypto.X509Certificate( + open(certs_path + '/cacert.pem').read()) + #crl = crypto.X509CRL(open(certs_path + '/crl.pem').read()) + #cred = crypto.X509Credentials(cert, key, [ca], [crl]) + cred = X509Credentials(cert, key, [ca]) + cred.verify_peer = True + cred.session_params.compressions = (COMP_LZO, COMP_DEFLATE, COMP_NULL) + return cred + + +class OpenSSLServerContextFactory: + # XXX workaround for broken TLS interface + # from gnuTLS. + + def getContext(self): + """Create an SSL context. + This is a sample implementation that loads a certificate from a file + called 'server.pem'.""" + certs_path = get_certs_path() + + ctx = SSL.Context(SSL.SSLv23_METHOD) + ctx.use_certificate_file(certs_path + '/leaptestscert.pem') + ctx.use_privatekey_file(certs_path + '/leaptestskey.pem') + return ctx + + +if __name__ == "__main__": + + from twisted.python import log + log.startLogging(sys.stdout) + + root = Resource() + root.putChild("provider.json", File("./provider.json")) + config = Resource() + config.putChild( + "eip-service.json", + File("./eip-service.json")) + apiv1 = Resource() + apiv1.putChild("config", config) + apiv1.putChild("sessions.json", API_Sessions()) + apiv1.putChild("users.json", FakeUsers(None)) + apiv1.putChild("cert", File(get_certs_path() + '/openvpn.pem')) + root.putChild("1", apiv1) + + cred = get_TLS_credentials() + + factory = Site(root) + + # regular http (for debugging with curl) + reactor.listenTCP(8000, factory) + + # TLS with gnutls --- seems broken :( + #reactor.listenTLS(8003, factory, cred) + + # OpenSSL + reactor.listenSSL(8443, factory, OpenSSLServerContextFactory()) + + reactor.run() diff --git a/src/leap/gui/firstrun/wizard.py b/src/leap/gui/firstrun/wizard.py new file mode 100755 index 00000000..9b77b877 --- /dev/null +++ b/src/leap/gui/firstrun/wizard.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python +import logging + +import sip +sip.setapi('QString', 2) +sip.setapi('QVariant', 2) + +from PyQt4 import QtCore +from PyQt4 import QtGui + +from leap.base import checks as basechecks +from leap.crypto import leapkeyring +from leap.eip import checks as eipchecks + +from leap.gui import firstrun + +from leap.gui import mainwindow_rc + +try: + from collections import OrderedDict +except ImportError: + # We must be in 2.6 + from leap.util.dicts import OrderedDict + +logger = logging.getLogger(__name__) + +""" +~~~~~~~~~~~~~~~~~~~~~~~~~~ +Work in progress! +~~~~~~~~~~~~~~~~~~~~~~~~~~ +This wizard still needs to be refactored out. + +TODO-ish: + +[X] Break file in wizard / pages files (and its own folder). +[ ] Separate presentation from logic. +[ ] Have a "manager" class for connections, that can be + dep-injected for testing. +[ ] Document signals used / expected. +[ ] Separate style from widgets. +[ ] Fix TOFU Widget for provider cert. +[X] Refactor widgets out. +[ ] Follow more MVC style. +[ ] Maybe separate "first run wizard" into different wizards + that share some of the pages? +""" + + +class FirstRunWizard(QtGui.QWizard): + + def __init__( + self, + conductor_instance, + parent=None, + eip_username=None, + providers=None, + success_cb=None, is_provider_setup=False, + trusted_certs=None, + netchecker=basechecks.LeapNetworkChecker, + providercertchecker=eipchecks.ProviderCertChecker, + eipconfigchecker=eipchecks.EIPConfigChecker, + start_eipconnection_signal=None, + eip_statuschange_signal=None, + debug_server=None, + quitcallback=None): + super(FirstRunWizard, self).__init__( + parent, + QtCore.Qt.WindowStaysOnTopHint) + + # we keep a reference to the conductor + # to be able to launch eip checks and connection + # in the connection page, before the wizard has ended. + self.conductor = conductor_instance + + self.eip_username = eip_username + self.providers = providers + + # success callback + self.success_cb = success_cb + + # is provider setup? + self.is_provider_setup = is_provider_setup + + # a dict with trusted fingerprints + # in the form {'nospacesfingerprint': ['host1', 'host2']} + self.trusted_certs = trusted_certs + + # Checkers + self.netchecker = netchecker + self.providercertchecker = providercertchecker + self.eipconfigchecker = eipconfigchecker + + # debug server + self.debug_server = debug_server + + # Signals + # will be emitted in connecting page + self.start_eipconnection_signal = start_eipconnection_signal + self.eip_statuschange_signal = eip_statuschange_signal + + if quitcallback is not None: + self.button( + QtGui.QWizard.CancelButton).clicked.connect( + quitcallback) + + self.providerconfig = None + # previously registered + # if True, jumps to LogIn page. + # by setting 1st page?? + #self.is_previously_registered = is_previously_registered + # XXX ??? ^v + self.is_previously_registered = bool(self.eip_username) + self.from_login = False + + pages_dict = OrderedDict(( + ('intro', firstrun.intro.IntroPage), + ('providerselection', + firstrun.providerselect.SelectProviderPage), + ('login', firstrun.login.LogInPage), + ('providerinfo', firstrun.providerinfo.ProviderInfoPage), + ('providersetupvalidation', + firstrun.providersetup.ProviderSetupValidationPage), + ('signup', firstrun.register.RegisterUserPage), + ('signupvalidation', + firstrun.regvalidation.RegisterUserValidationPage), + ('connecting', firstrun.connect.ConnectingPage), + ('lastpage', firstrun.last.LastPage) + )) + self.add_pages_from_dict(pages_dict) + + self.validation_errors = {} + + self.setPixmap( + QtGui.QWizard.BannerPixmap, + QtGui.QPixmap(':/images/banner.png')) + self.setPixmap( + QtGui.QWizard.BackgroundPixmap, + QtGui.QPixmap(':/images/background.png')) + + # set options + self.setOption(QtGui.QWizard.IndependentPages, on=False) + self.setOption(QtGui.QWizard.NoBackButtonOnStartPage, on=True) + + self.setWindowTitle("First Run Wizard") + + # TODO: set style for MAC / windows ... + #self.setWizardStyle() + + def add_pages_from_dict(self, pages_dict): + """ + @param pages_dict: the dictionary with pages, where + values are a tuple of InstanceofWizardPage, kwargs. + @type pages_dict: dict + """ + for name, page in pages_dict.items(): + # XXX check for is_previously registered + # and skip adding the signup branch if so + self.addPage(page()) + self.pages_dict = pages_dict + + def get_page_index(self, page_name): + """ + returns the index of the given page + @param page_name: the name of the desired page + @type page_name: str + @rparam: index of page in wizard + @rtype: int + """ + return self.pages_dict.keys().index(page_name) + + def set_validation_error(self, pagename, error): + self.validation_errors[pagename] = error + + def clean_validation_error(self, pagename): + vald = self.validation_errors + if pagename in vald: + del vald[pagename] + + def get_validation_error(self, pagename): + return self.validation_errors.get(pagename, None) + + def set_providerconfig(self, providerconfig): + self.providerconfig = providerconfig + + def setWindowFlags(self, flags): + logger.debug('setting window flags') + QtGui.QWizard.setWindowFlags(self, flags) + + def focusOutEvent(self, event): + # needed ? + self.setFocus(True) + self.activateWindow() + self.raise_() + self.show() + + def accept(self): + """ + final step in the wizard. + gather the info, update settings + and call the success callback if any has been passed. + """ + super(FirstRunWizard, self).accept() + + # username and password are in different fields + # if they were stored in log_in or sign_up pages. + from_login = self.from_login + unamek_base = 'userName' + passwk_base = 'userPassword' + unamek = 'login_%s' % unamek_base if from_login else unamek_base + passwk = 'login_%s' % passwk_base if from_login else passwk_base + + username = self.field(unamek) + password = self.field(passwk) + provider = self.field('provider_domain') + remember_pass = self.field('rememberPassword') + + logger.debug('chosen provider: %s', provider) + logger.debug('username: %s', username) + logger.debug('remember password: %s', remember_pass) + + # we are assuming here that we only remember one username + # in the form username@provider.domain + # We probably could extend this to support some form of + # profiles. + + settings = QtCore.QSettings() + + settings.setValue("FirstRunWizardDone", True) + settings.setValue("provider_domain", provider) + full_username = "%s@%s" % (username, provider) + + settings.setValue("remember_user_and_pass", remember_pass) + + if remember_pass: + settings.setValue("eip_username", full_username) + seed = self.get_random_str(10) + settings.setValue("%s_seed" % provider, seed) + + # XXX #744: comment out for 0.2.0 release + # if we need to have a version of python-keyring < 0.9 + leapkeyring.leap_set_password( + full_username, password, seed=seed) + + logger.debug('First Run Wizard Done.') + cb = self.success_cb + if cb and callable(cb): + self.success_cb() + + def get_provider_by_index(self): + provider = self.field('provider_index') + return self.providers[provider] + + def get_random_str(self, n): + from string import (ascii_uppercase, ascii_lowercase, digits) + from random import choice + return ''.join(choice( + ascii_uppercase + + ascii_lowercase + + digits) for x in range(n)) + + +if __name__ == '__main__': + # standalone test + # it can be (somehow) run against + # gui/tests/integration/fake_user_signup.py + + import sys + import logging + logging.basicConfig() + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + app = QtGui.QApplication(sys.argv) + server = sys.argv[1] if len(sys.argv) > 1 else None + + trusted_certs = { + "3DF83F316BFA0186" + "0A11A5C9C7FC24B9" + "18C62B941192CC1A" + "49AE62218B2A4B7C": ['springbok']} + + wizard = FirstRunWizard( + None, trusted_certs=trusted_certs, + debug_server=server) + wizard.show() + sys.exit(app.exec_()) diff --git a/src/leap/gui/mainwindow_rc.py b/src/leap/gui/mainwindow_rc.py index e5a671f3..5bee35c7 100644 --- a/src/leap/gui/mainwindow_rc.py +++ b/src/leap/gui/mainwindow_rc.py @@ -2,7 +2,7 @@ # Resource object code # -# Created: Sun Jul 22 17:08:49 2012 +# Created: Wed Nov 21 04:25:36 2012 # by: The Resource Compiler for PyQt (Qt v4.8.2) # # WARNING! All changes made in this file will be lost! @@ -236,6 +236,87 @@ qt_resource_data = "\ \x71\xa4\x40\xda\x14\x7a\xd1\x73\x1f\xf4\x7f\xb7\xf9\x1f\xc2\x26\ \x56\xd5\x70\x45\xfc\x8a\x00\x00\x00\x00\x49\x45\x4e\x44\xae\x42\ \x60\x82\ +\x00\x00\x04\xec\ +\x89\ +\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\ +\x00\x00\x18\x00\x00\x00\x18\x08\x06\x00\x00\x00\xe0\x77\x3d\xf8\ +\x00\x00\x00\x04\x73\x42\x49\x54\x08\x08\x08\x08\x7c\x08\x64\x88\ +\x00\x00\x00\x09\x70\x48\x59\x73\x00\x00\x06\xec\x00\x00\x06\xec\ +\x01\x1e\x75\x38\x35\x00\x00\x00\x19\x74\x45\x58\x74\x53\x6f\x66\ +\x74\x77\x61\x72\x65\x00\x77\x77\x77\x2e\x69\x6e\x6b\x73\x63\x61\ +\x70\x65\x2e\x6f\x72\x67\x9b\xee\x3c\x1a\x00\x00\x00\x13\x74\x45\ +\x58\x74\x41\x75\x74\x68\x6f\x72\x00\x52\x6f\x64\x6e\x65\x79\x20\ +\x44\x61\x77\x65\x73\x0e\xd8\x7e\x1d\x00\x00\x04\x4a\x49\x44\x41\ +\x54\x48\x89\x8d\x96\x5d\x6c\x53\x65\x18\xc7\x7f\xef\x39\x6b\xbb\ +\x7e\x9c\x75\x65\xad\x2b\x9b\xfb\xd0\x31\xdd\x14\xb6\x8c\x19\x44\ +\x90\x44\x63\x82\x42\x88\x5e\x90\x98\xcc\x19\x15\x13\xd4\x18\x76\ +\x61\xd4\x18\xe3\x85\x57\xca\x05\xe1\xc2\x0c\xa3\xa8\x51\xd0\x4c\ +\x12\xe3\x85\x31\x80\x26\x6a\xe2\x85\x23\xb0\x38\xb6\xc1\x1c\xce\ +\xb1\x40\x59\xf6\xe5\xca\xda\xae\xed\xfa\x75\x7a\x5e\x2f\x4e\xd7\ +\x59\xd6\x32\xfe\xc9\x7b\xf3\x9e\xe7\xf9\xff\x9f\xe7\xff\x9e\xf3\ +\x9c\x57\x48\x29\x59\x0f\xbd\x7b\x85\x0d\x17\xed\x1e\xbb\xb2\x07\ +\x20\x94\x30\x7e\x22\xc6\x48\xcf\x59\x99\x5a\x2f\x57\x94\x12\xf8\ +\xec\x55\x61\x71\x65\x6d\x47\xfc\xbe\xda\x47\x9d\x5a\xa5\xbf\xda\ +\x69\xaf\xda\xe0\x28\x2f\x07\x58\x5c\x4e\x26\xe7\xe3\x89\x9b\xf1\ +\x68\x78\x6e\x6e\x61\xfa\x8f\x98\x9a\x7a\xfb\x95\xe3\x32\x73\xc7\ +\x02\x9f\x76\x89\x8e\xba\xda\xda\x2f\xb7\x37\xdf\xdf\xe6\x2a\x13\ +\x8a\x94\x06\x82\xc2\x38\x89\x40\x08\x85\x98\x2e\x8d\xf3\x13\xe3\ +\x97\xa6\xa6\xa7\x5f\x7e\xed\x94\x1c\x5a\x57\xa0\xef\xa0\xfd\x70\ +\x5b\xf3\x96\x03\xcd\xde\x8a\x6a\x61\x64\xd7\x73\xc0\x14\x53\x54\ +\x26\x82\x4b\xf3\x97\x26\x2e\x7f\xd5\xfd\x79\xe2\xdd\x92\x02\x27\ +\x5f\x2a\x7b\xe1\x89\xce\x1d\xc7\xbc\x76\x55\x13\xc5\x98\xac\x4e\ +\x10\x0a\xa4\xa2\x6b\x45\x80\x60\x22\x1b\xfd\x6d\xf0\xdc\xa1\x17\ +\x4f\xe8\x5f\xaf\x11\x38\xfa\x9c\xf0\x6e\xdb\xf4\xc0\xf9\x6d\xf5\ +\xfe\x26\x30\xf2\x89\xca\xc6\x76\xd4\x07\xf7\xa3\xd4\x74\x80\xd5\ +\x65\x6e\xa6\xe3\x64\x03\xfd\x64\x2f\x9e\x40\x46\x67\xff\x27\xa3\ +\x30\x70\x63\x6e\x72\xe0\xea\xd8\xf6\x37\xbf\x95\x41\x73\x27\x87\ +\x06\x8f\xa7\x6f\x6b\x7d\x4d\x01\x39\x80\x52\xff\x08\x4a\xe3\xae\ +\x55\xf2\x5c\x27\x6a\xf3\x6e\x2c\x7b\x8f\x9a\x5d\xe5\x61\xb0\xb5\ +\xbe\xa6\xa9\xc1\xe3\xe9\x5b\x95\x04\x7a\xbb\x44\x47\x5b\x53\xcb\ +\x4e\x15\xbd\x98\x31\xc8\x70\x00\xfd\xfc\xc7\x64\xce\xbc\x81\x7e\ +\xe1\x13\xc8\x75\x2d\xb4\x8d\x28\xb5\x0f\x15\xc4\xaa\xe8\xb4\x35\ +\xb5\xec\xec\xed\x12\x1d\x00\x65\x00\xee\x72\x65\x9f\x5f\x73\x38\ +\x05\x6b\x0f\x35\x3b\xf6\x03\xfa\xc0\xf1\x3c\x29\xb3\xc3\xa8\xf7\ +\x3e\x8e\xf0\xb5\x98\x22\xf6\x0d\x05\xf1\x02\xf0\x6b\x0e\xa7\xbb\ +\x5c\xd9\x07\x0c\x29\x00\x9a\xc3\xd5\x69\x55\xd5\xe2\xd5\x47\xe7\ +\x56\xc9\x01\xe1\xbe\x1b\xe1\xb9\x67\xf5\x79\x70\x7c\x4d\x8e\x55\ +\x55\xd1\x1c\xae\xce\xbc\x45\x15\x6e\x5f\x9d\x90\xc5\xed\x29\xa8\ +\xae\xa2\x06\xcb\x53\x47\xa0\xcc\x66\x76\x37\xfa\x3d\xd9\xa9\x81\ +\xb5\x71\x52\xa7\xc2\xed\xab\x83\x9c\x45\x76\xbb\x56\x25\xa5\xa4\ +\xe8\xab\xb9\x02\x9b\x86\x65\xf7\x87\x08\xcd\x6f\x92\x8f\x9f\x21\ +\xf5\xdd\xf3\xa0\xa7\x10\xe5\x6e\x44\x45\x2d\x38\x7d\x08\x21\x90\ +\xd2\xe4\xcc\x0b\x24\x12\xd1\x9b\x42\xbd\xab\x81\x6c\xba\x28\xb7\ +\x94\x06\x65\xcd\x4f\x22\x2a\x1b\x00\x30\xa6\xff\x24\xd5\xb7\x1f\ +\x74\x73\x14\xc9\x64\x04\x99\x8c\x80\xc5\x8e\xe2\xae\x03\xab\x93\ +\x44\x22\x7a\x33\x6f\xd1\x52\x64\x61\x0a\xb5\xbc\x28\xb1\xb1\x34\ +\x83\x91\xb3\xc1\x98\x1d\xc1\x98\x1d\x41\x3f\xd7\x9b\x27\x2f\x40\ +\x26\x81\x11\xfc\x07\x99\x8a\x99\x9c\x2b\x1d\x44\x97\x63\x83\xc9\ +\xe8\xfc\x33\x36\x23\x05\xaa\x05\xd2\xcb\xc8\x74\xcc\xfc\x88\x72\ +\x5d\xa5\x7f\x3c\x74\x3b\x03\x0b\x90\x52\xed\x44\x97\x63\x83\x79\ +\x81\x48\xd2\x38\x3d\x1b\xcf\xbc\x53\x1f\xb9\xe4\x44\x1a\x45\x93\ +\xac\xcf\x7e\x83\xda\xb8\xcb\x2c\xf4\xd7\xf7\xd1\x2f\x9e\x2c\xce\ +\x2e\x14\xe6\xd2\x65\xf1\x48\xd2\x38\x0d\x39\x8b\x7a\x4e\xc9\xa1\ +\xd1\xc0\xb5\xfe\xac\xb7\xb5\x64\x55\xc2\xe5\x47\x54\x36\x98\xe7\ +\x60\xd3\x4a\xc6\x65\xbd\xad\x8c\x06\xae\xf5\xf7\xe4\x26\x6b\x7e\ +\x54\x04\x42\xa1\xee\xe1\x90\x31\x29\x1c\xde\xd2\xbd\xaf\x03\xe1\ +\xf0\x32\x1c\x32\x26\x03\xa1\x50\x77\x7e\xef\xd6\x69\xfa\x58\x7b\ +\xe7\x31\x5f\x78\x54\x23\xb3\x5c\x90\xac\xf8\xdb\x10\x0e\xf3\xab\ +\x35\x82\x13\xc8\xa5\xe9\x42\x76\x8b\x83\x85\xca\xcd\xd1\xdf\x47\ +\x06\x8b\x4f\xd3\x15\xf4\x1d\xb4\x1f\xde\xd2\xd4\x7a\x60\x93\x1a\ +\xaa\x26\x74\xfd\xce\x4a\xf7\x34\x72\x35\xeb\x99\xbf\x3c\x79\xe5\ +\xf6\xff\x83\x15\x7c\xf0\xb4\xd8\xbe\xb9\xa9\xe6\x8b\x1d\x0d\xd5\ +\xad\xae\xd8\x94\x22\x13\x21\x90\xb7\xcc\x29\xa1\x22\xec\x1e\x62\ +\xae\x3a\xa3\xff\xfa\xfc\xdf\xe7\xc6\x66\x5e\x3f\xf2\x0b\xfd\x52\ +\x16\x8e\x84\x02\x01\x21\x84\x0a\x54\x01\x95\x9a\x1d\xdf\x7b\x7b\ +\xac\x6f\xdd\x57\xb7\xb1\x6d\x83\xbb\xd2\x53\xe3\x10\x2e\x9f\xcd\ +\xb0\x00\xfc\x9b\x54\xf4\x99\x84\x8c\x2d\x86\xc3\xe1\x2b\x81\xd9\ +\xbf\x0e\xff\x9c\xfe\x28\x9e\x22\x08\x84\x80\xb0\x94\x32\x5c\xb2\ +\x03\x21\x84\x13\xf0\x00\xee\xdc\xd2\x5c\x56\x3c\x5b\xeb\x69\x79\ +\xb8\x51\x74\x18\x12\xe5\xc2\x75\x39\x3c\x74\x83\xc9\x78\x86\x10\ +\x10\x03\x96\x80\x48\x6e\x2d\x4a\xb9\x7a\x01\x28\x79\xab\xc8\x89\ +\x59\x00\x2b\x60\xcb\x2d\x0b\xa0\x02\x3a\x90\x02\xd2\x40\x12\xc8\ +\x48\x79\xab\x87\x26\xfe\x03\x26\x93\xd5\x41\x51\x76\x98\xdb\x00\ +\x00\x00\x00\x49\x45\x4e\x44\xae\x42\x60\x82\ \x00\x00\x0b\xd7\ \x89\ \x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\ @@ -744,6 +825,927 @@ qt_resource_data = "\ \x8f\xf3\x2f\x02\x93\x69\x3a\xed\x1c\xe8\xee\xee\x4e\xd2\xa7\x46\ \xff\xff\x67\x8f\x8f\x7b\xf9\x5f\x5a\xf1\x31\x65\xff\xe0\x15\x90\ \x00\x00\x00\x00\x49\x45\x4e\x44\xae\x42\x60\x82\ +\x00\x00\x2e\x85\ +\x89\ +\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\ +\x00\x00\x80\x00\x00\x00\x65\x08\x06\x00\x00\x00\x85\xb7\xeb\xfa\ +\x00\x00\x00\x04\x73\x42\x49\x54\x08\x08\x08\x08\x7c\x08\x64\x88\ +\x00\x00\x00\x09\x70\x48\x59\x73\x00\x00\x02\x4b\x00\x00\x02\x4b\ +\x01\x08\x6c\xbf\x82\x00\x00\x00\x19\x74\x45\x58\x74\x53\x6f\x66\ +\x74\x77\x61\x72\x65\x00\x77\x77\x77\x2e\x69\x6e\x6b\x73\x63\x61\ +\x70\x65\x2e\x6f\x72\x67\x9b\xee\x3c\x1a\x00\x00\x20\x00\x49\x44\ +\x41\x54\x78\x9c\xed\x9d\x77\x7c\x5c\xc5\xb9\xf7\xbf\x73\xb6\xaf\ +\x56\xbb\xd2\xaa\xf7\x2e\xcb\xb2\x6c\xcb\xdd\x32\x6e\x80\x4d\xc0\ +\x38\x40\xe8\x10\x7a\x0b\x24\xe4\xa6\x70\x79\x93\x10\x20\x84\xf4\ +\x0a\xc9\x0d\x29\x70\x81\x84\x40\x6e\x08\x3d\xc4\x74\x03\xae\x72\ +\xef\x2a\x96\x64\x15\xab\x97\x55\xd9\xae\x2d\x67\xde\x3f\x56\x92\ +\x6d\x59\xb6\x31\x2e\x58\xc0\xef\xf3\x39\x20\xcf\x99\x3e\xcf\xce\ +\x99\x79\xaa\x90\x52\xf2\x39\x0e\x40\x08\xa1\x03\x72\x81\x10\xf0\ +\x20\x30\x05\xc8\x06\x1e\x01\x9e\x90\x52\xb6\x7f\x72\xbd\x3b\xf9\ +\x10\x9f\x13\x00\x08\x21\x04\x30\x1d\xb8\x18\xb8\x09\xc8\x00\x82\ +\x80\x6e\x54\x56\x2f\xf0\x3f\x80\x16\x98\x0f\x94\xc9\x71\x3e\x81\ +\x9f\x69\x02\x10\x42\x24\x01\xff\x00\xa6\x01\xb1\x00\xf1\x68\x99\ +\x43\x14\xd3\x45\x14\x6f\xe6\x4f\xc5\xe9\x75\xe0\xf2\x39\x70\x7a\ +\x1d\x78\x07\x07\x18\x35\x5f\x39\x52\xca\xc6\xd3\xdf\xf3\x93\x07\ +\xed\x27\xdd\x81\xd3\x09\x21\xc4\x32\xc0\x0c\xec\x06\xfc\xc0\x7f\ +\x80\x92\xf3\xb1\x72\x03\x71\xcc\x21\x8a\x5c\x0c\x91\xcc\x12\xdc\ +\xd1\x79\xb8\xf3\xce\x1b\x29\x1f\x56\xc3\xd2\xed\xeb\x15\x1b\xaa\ +\x5f\x65\x7f\x77\x25\xc0\x54\xa0\xf1\xf4\x8e\xe2\xe4\x42\xf9\xa4\ +\x3b\x70\x9a\x31\x07\x78\x01\xa8\x06\x1a\x8b\x31\x96\xbc\x49\x01\ +\x6f\x52\xc0\x35\xd8\x0f\x2c\xfe\x10\xca\x7a\x3a\x77\x1f\xfc\x6f\ +\x8d\xa2\x11\xb6\xa8\x04\x96\x4c\xbb\x99\xb4\xb8\x42\x80\xc7\x85\ +\x10\xff\x75\xba\x3a\x7f\x2a\xf0\x99\x20\x00\x21\x84\x46\x08\xf1\ +\x04\xf0\xff\xec\x68\xb9\x06\x3b\xcf\x91\xc3\x4e\x8a\x39\x1f\xeb\ +\x11\xcb\x2d\xec\xed\x71\x8d\x95\xae\x51\xb4\x9c\x37\xfd\x56\x92\ +\x63\x73\x13\x81\x5f\x09\x21\x96\x9c\xa2\xae\x9f\x72\x7c\x26\x08\ +\x40\x4a\x19\x06\xd6\x00\xc6\x07\x16\x69\x79\x52\x97\xc1\xb5\xd8\ +\xd1\x22\x8e\x5a\x2e\xc1\xe3\xce\x3d\xd2\x3b\xad\x46\xcf\x17\x66\ +\xdc\x4e\x82\x2d\x53\x0f\xbc\x2e\x84\xf8\xab\x10\x62\xf9\xc9\xed\ +\xf9\xa9\xc7\x67\x82\x00\x86\xf0\x06\xe0\x34\x67\xf8\xd9\x73\x67\ +\x85\x44\xa7\xfa\x8f\x55\x40\x91\x6a\x72\x92\xc7\xd9\x7c\xa4\xf7\ +\x7a\xad\x91\x0b\x66\x7e\x05\x7b\x74\xaa\x11\xb8\x11\x78\x4a\x08\ +\x11\x73\x12\xfb\x3c\x26\x84\x10\x93\x84\x10\xdf\x15\x42\x4c\x3d\ +\xd1\xba\x3e\x33\x04\x20\xa5\xec\x01\x7a\xfa\xbc\x30\x75\x62\x48\ +\x04\x1e\xde\x53\x81\x56\x0e\x1e\xab\xdc\xac\x8e\xc6\xa6\xa3\xbd\ +\x37\xe8\xcc\x2c\x9b\x75\x17\x16\x53\x2c\x40\x02\xb0\xe8\xe4\xf4\ +\xf8\x50\x08\x21\x66\x0b\x21\xbe\x2f\x84\x58\x07\xec\xd1\xea\x0d\ +\x3f\x03\x76\x08\x21\xde\x14\x42\xcc\xfd\xb8\xf5\x7e\x66\x08\x60\ +\x08\x9b\x7a\x3d\x43\x7f\xd9\x82\x33\x82\x0f\xef\xde\x85\x56\x06\ +\x8e\x56\xa0\xb4\xa3\xe5\x98\x95\x86\xa5\x0e\xbd\x69\x0a\x8a\x62\ +\x00\xb8\x5a\x08\x91\x75\xe2\x5d\x8d\x40\x08\xa1\x08\x21\x7e\x06\ +\x6c\x44\xa3\xfd\x31\xb3\xaf\x9d\x97\x76\xe5\x0f\xf9\xde\xab\x2d\ +\x7c\xf5\xb7\xaf\x91\x98\x91\x7f\x3e\x50\x2e\x84\x78\x45\x08\x91\ +\x73\xbc\xf5\x7f\x66\x08\x40\x08\x61\x04\x96\xf5\x79\x0f\xa4\xc9\ +\x98\xe0\xac\xe0\x43\x7b\x76\xa0\x91\xc1\x23\x95\x4b\x73\xf7\x67\ +\x1e\xab\xee\xea\xbe\x68\x77\xce\x84\xdf\x90\x90\x74\x19\xc0\xd5\ +\xc0\x57\x4f\xbc\xc7\x20\x84\x88\x05\x56\x24\x19\x92\xbf\xfb\xe3\ +\x92\xff\x41\x77\xe7\xcb\xb0\xf4\xdb\x6a\xeb\xb9\x0f\xf2\xa3\xe6\ +\x78\x9e\x10\xf3\x43\x9a\x3b\x5f\x24\xf1\x86\xdf\x23\xd2\xa7\x5c\ +\x02\x54\x09\x21\xae\x3c\x9e\x36\x3e\x33\x04\x00\x7c\x01\xb0\x8e\ +\xec\x00\x43\x90\xf6\xc0\xec\xe0\x43\x7b\xb6\xa1\xc8\xd0\x58\x85\ +\x34\x52\xcd\xb4\xfb\x3d\x9d\x47\xaa\x74\x50\x9a\x1a\x15\xd3\x2c\ +\x13\x80\x46\x17\x83\xd6\x6c\x03\xf0\x9d\x68\x67\x85\x10\xc5\xc0\ +\x96\x7c\x73\xc1\xf9\x6f\xcf\x59\xcd\xa3\xf3\xb2\x43\xc1\xe4\x02\ +\x30\x59\xc2\xa8\xa1\x9d\x00\x41\xb3\x5d\xdb\x1e\x3f\x95\xae\xac\ +\x85\xc8\x70\x10\xa0\x1d\x78\xf7\x78\xda\xf9\x2c\x11\x80\x01\x60\ +\x34\x01\x00\xc8\xb8\xc0\x9c\xe0\x0f\x2a\xb7\x1c\x89\x08\xa6\x77\ +\xec\xaf\x3f\x52\xa5\xfb\xdc\x25\xad\xa0\x68\x00\xa2\xd2\x0b\x48\ +\x5d\x7c\x1d\xc0\x31\xcf\x16\x47\xc3\x10\xc3\xaa\x7c\xa6\x6d\x56\ +\xee\x87\xb3\xd7\x71\x45\xee\x5e\x6f\x4f\x46\x7e\x84\x69\x27\xa5\ +\x2e\xa3\xbd\x59\xbf\xbc\x7c\x93\x5a\xf8\xd2\x2f\xe0\xc7\xd3\xe1\ +\xa7\xb3\xa0\xbd\xca\x0b\x5c\x2a\xa5\xec\x3b\x9e\xb6\x3e\x31\x02\ +\x10\x42\x24\x0a\x21\xfe\x2d\x84\xf8\xa7\x10\xe2\x3b\x42\x88\x23\ +\x5e\xb9\x4e\x12\x06\x00\x0e\xfe\x04\x1c\x0c\x99\xe0\x9f\x1b\x7c\ +\xb0\x72\x33\x8a\x0c\x8f\x7e\x37\xbd\xa3\x79\xcc\x4f\x84\x3b\x6c\ +\xab\xf0\xab\xc6\x79\x00\x61\x5d\x6f\x63\xf2\xf5\x17\xa2\xb7\x25\ +\xc0\xc7\x20\x00\x21\x84\x41\x08\x71\xb1\x10\xe2\xdf\xc0\xeb\x5f\ +\x88\x5f\x66\x7d\x7b\xe6\x07\x5c\x9f\xb6\xb3\xa7\x36\x2f\xdf\xac\ +\x97\xc2\xb3\xa8\x4f\xbf\xf3\xeb\x4d\x51\x7d\xb7\x0d\xe4\x4c\x9c\ +\x65\x9b\xad\x74\x96\x3f\x06\xcd\xdb\x9f\x26\x1c\xbc\x1b\x98\x28\ +\xa5\xdc\x7e\xbc\xed\x7e\x22\xac\xe0\xa1\xed\x6d\x05\x3a\x63\x76\ +\xcc\xbc\x8b\x18\xd8\xf8\xc6\x55\xd2\xef\x7e\x48\x08\xf1\x6b\xe0\ +\x11\x29\x65\xef\x29\x68\xd6\x09\x63\xef\x00\xc3\x90\x89\xfe\xb2\ +\xe0\x03\x95\xeb\x75\x3f\x2a\x9e\x83\x2a\x34\xc3\xe9\x99\x4e\x47\ +\xea\x58\xf9\xf7\x79\x8a\x02\x80\x40\x2f\x76\x47\x5f\x91\x55\x2c\ +\x34\x5a\x74\xd6\x78\x80\xe2\x8f\xd2\x21\x21\x84\x1e\x38\x0f\xb8\ +\x12\xb8\xd8\x6e\x4a\xb2\x9e\x9b\x7b\x15\x4b\x73\xaf\xc5\xee\xf0\ +\x70\x5f\x42\x65\x73\x7d\x5e\x81\xfb\xd6\xde\xd8\xda\x94\xa0\x61\ +\xa6\x4e\x63\x98\xda\x13\xa8\x62\xe5\xae\xfb\xd9\x53\xf3\x3c\x03\ +\xae\xe6\x3e\xe0\xab\x52\xca\x63\x5e\x69\x8f\x84\xd3\x4e\x00\x43\ +\x57\x96\x37\x48\xc8\x8b\x15\x37\x3f\xc3\xec\x8b\xa7\xd2\xb1\x6b\ +\x3b\x55\xcf\xff\xd5\x18\xdc\xf0\xaf\xfb\xf1\xbb\xee\x11\x42\x3c\ +\x0b\x3c\x26\xa5\xdc\x79\x12\x9b\x76\x02\xf4\x1f\x61\x07\x18\x86\ +\x4c\xf2\xcf\x0b\x7e\xbf\x72\x9d\xee\xc7\x93\xca\x90\x91\x1d\x52\ +\xaf\x86\xf3\x2c\x81\x40\x9f\x5b\xaf\x8f\x1d\xce\xd7\x1b\x4c\xdc\ +\x1c\x96\xfa\x59\x68\x45\x9d\xf9\xf2\x98\x4c\xa1\x8f\x10\x8c\x3e\ +\x42\x00\xd7\x0b\x21\xbe\x23\xa5\x74\x8c\xae\x7f\x48\xdc\xbc\x84\ +\xc8\xa2\x5f\x62\xd0\x9a\x63\xe6\x67\x2e\x67\x69\xee\xb5\xcc\x48\ +\x3d\x07\x77\xb0\x1f\x67\xb0\xc6\xdd\x97\xe3\xdb\x65\x2b\x9a\x20\ +\x6e\xf7\x86\x0b\x31\xf6\x15\x69\xa2\x55\xb1\x76\xd3\x63\xbc\xf9\ +\xce\x23\xa8\xaa\xda\x49\x84\xaf\xf1\x9b\x13\x59\x7c\x38\xcd\x04\ +\x30\xc4\x32\x7d\x95\x69\x5f\x8a\xe2\xc6\xa7\xb1\xf8\xf6\x3b\x0c\ +\x96\xa8\xb8\xac\x79\xf3\x49\x98\x58\x2c\xcb\x5f\x38\x5b\x78\xab\ +\x56\x9a\x68\x7a\xee\x76\x1a\x02\xb7\x0b\x21\x3a\x89\x70\xf0\x86\ +\x9f\x9d\x52\x4a\xf5\x63\x36\xef\x04\x70\xfa\x41\x95\xa8\x8a\x38\ +\xf2\xe7\x4f\xa6\xf8\xcf\x0a\xde\x57\xb9\x56\xf7\xd3\xe2\x79\x43\ +\x44\x20\x4a\xbb\x9a\x6b\xd7\xa6\xe7\xcd\x06\x90\x52\xa8\xfb\xbd\ +\x05\x56\x04\x6d\xa6\x2f\xd9\xa2\x85\x41\xd8\x86\xcb\xea\xa3\xe3\ +\x01\xf4\x40\x14\xe0\x18\x1a\xb7\x06\x58\x4c\xe4\x86\x70\xa9\x40\ +\xd8\xa7\x26\xcf\xa7\x2c\xa7\x8c\x82\x84\x5c\xb4\xa6\x06\xec\x09\ +\x4f\xf8\xf4\x59\x5f\x0b\x25\xe8\x5b\xdb\xda\x5a\xca\x1a\x76\x0d\ +\xdc\xf8\x85\x50\xe3\xde\xc1\x9c\x99\x0b\x8c\x9b\x37\xbf\xcd\x8b\ +\x2f\x3e\x8a\xd3\xd9\xdb\x00\x7c\x4d\x4a\xf9\xe6\xc7\x9c\x83\xc3\ +\x70\x5a\xc4\xc1\x42\x08\x0b\xb0\x06\x45\x5b\xca\x65\xbf\x80\x25\ +\xdf\x06\xa0\xac\x98\xca\x78\xeb\xa1\xdb\xe5\xae\x37\xde\xf3\x36\ +\xbb\x3e\x30\xab\xf6\xb7\xa1\x66\x17\x54\x07\xa1\x0a\xe8\x86\xa1\ +\xff\xbe\x49\x44\x8a\xf7\xb6\x94\xd2\x39\x46\x5b\xd1\x80\x2a\xa5\ +\xf4\x8c\x4a\x7f\x08\xf8\x01\x40\xc7\xaf\x09\xc5\x9a\x8f\x4d\xfc\ +\xa2\x39\x6a\x8d\xee\xe7\x45\xf3\x91\x88\xaa\xb8\x94\x55\x8f\xcc\ +\x3a\x77\x11\x40\xe7\x60\xfa\xda\xd6\xc1\x9c\x62\xf3\xc5\xd6\x01\ +\x11\xab\x3d\xe4\xee\x1d\xf2\xb9\x58\xff\x8d\xa9\x3e\x20\x8e\x88\ +\x8e\xc1\xd5\xc0\x15\x40\x52\x3e\x30\xc7\x26\x58\x72\x4b\x3a\x39\ +\x19\xfd\x58\xa2\x0e\x11\x35\xb4\x35\xb7\x88\x7d\x9b\x5a\xbf\x18\ +\xe5\xd0\x5e\x38\x1d\xa0\xa7\xa7\x99\xcd\x9b\x5f\xa4\xb9\xb9\x3a\ +\x08\xfc\x06\x78\x58\x4a\x79\xc2\x37\x8c\x83\x71\xba\x76\x80\x07\ +\xb0\x26\x95\x72\xe7\xcb\x90\x37\x2f\xd2\xb0\x42\xcd\xe8\xc5\x07\ +\x98\xb2\x6c\x89\x39\xcb\x31\xbb\x6b\xb7\x23\x36\xd8\xb7\x68\x7a\ +\x12\xeb\xb6\x68\x31\xec\x86\xde\x10\x54\x91\x40\x25\x37\x50\xcd\ +\x0d\x78\x09\x0a\x21\xd6\x00\xcf\x10\xf9\xb5\x95\x01\x33\x81\x09\ +\x80\x14\x42\xd4\x00\xdb\x80\x5a\x22\xda\x3d\x0f\x0c\xb7\xd1\xed\ +\x42\x1b\x6b\x3e\x76\xa7\x65\x86\x67\x41\xf0\xde\xea\x35\xba\x5f\ +\x15\xcd\xcf\x19\xe8\x4e\x00\x50\x51\x06\xdb\x06\xb3\x52\x74\x8b\ +\xf0\x1c\xbc\xf8\x32\x14\xa4\xaf\x72\x0d\x5d\x9b\x5f\x07\xe8\x05\ +\x36\x01\x25\xd9\xc0\xd9\x43\x4f\x3a\x50\x87\x1c\x9c\x5c\xd4\x7c\ +\xb0\xd8\xd1\xdb\xdf\xc7\xa6\xba\x06\xa5\x60\x87\xfb\xd6\x89\x6a\ +\xd4\xcc\x78\xaf\xd7\xc9\xe6\xcd\xaf\x52\x53\xb3\x3e\x24\xa5\x7c\ +\x96\xc8\x56\xbf\xe7\xa3\x4d\xf5\xf1\xe1\x94\xef\x00\x42\x88\x89\ +\xc0\x4e\xbe\xf2\xb2\x8e\xe8\xac\x30\xdd\xbd\xb5\xf4\xf8\xba\x0c\ +\x83\x4e\xcd\x03\xbf\x9f\x2b\x73\x26\x58\x52\x14\xa4\x0e\x21\xb4\ +\x02\xa9\x97\xa0\x1f\x0c\x4b\xff\x53\x15\xad\x4d\xbb\xf4\x1f\x68\ +\xfa\x35\xfb\xa7\x33\x18\xec\x67\xfb\x56\x37\xbe\xed\x69\x88\xb0\ +\x40\x02\x4d\x44\x76\x86\x4a\xa0\x9e\xc8\x12\x03\x5a\x83\x82\x35\ +\x31\x16\xd5\x15\x8d\x46\xa3\x3f\xec\xb1\x5a\x2d\xdc\x7a\x61\x06\ +\x57\xc5\x3d\xae\xda\xa3\xe4\x47\xba\x05\x89\x7a\xcb\x6a\xdd\xaf\ +\x27\x94\x7d\x7d\xe9\x55\x81\xda\xc0\xc4\x4d\x7d\xd3\x12\x33\x8c\ +\xc5\x09\xf9\x52\x0d\xd3\x5f\x5d\x4e\xf7\xe6\xd7\xe9\xd9\xfe\x0e\ +\x21\xef\x00\x10\x51\x27\x5a\x4c\x64\xd1\xb3\x47\xd5\xa5\xc9\x66\ +\x4d\xd4\xf7\x58\x00\xe0\xf3\xb3\xbe\xae\x8e\x9c\x5e\x97\x5e\x54\ +\x84\xbe\x9d\x28\x0c\xe9\xca\xae\x5d\xef\xb1\x63\xc7\x9b\x04\x83\ +\x83\x1f\x02\x77\x4b\x29\x2b\x4e\x68\x01\x8e\x35\xb6\xd3\x40\x00\ +\x7f\x42\x28\x77\x92\x7f\x17\x58\xa7\x8c\x7e\xe9\xb9\xe2\x8e\x09\ +\x3b\x2e\xbb\xad\xf0\xac\xe1\xa4\xae\x3e\x57\xdf\x33\xbb\x5b\xc2\ +\xd2\x62\x89\x97\x52\xf5\xd6\x98\xdf\xde\xef\xd6\x76\x17\x01\x20\ +\x65\x1b\x75\x5b\x5b\x68\xdd\x34\x03\x54\x45\xe7\x56\x84\x31\xac\ +\xc5\xa6\x18\xc9\x8b\x8e\x25\x23\xdb\x16\x32\xdb\x75\x5a\x55\x4a\ +\x59\xf3\x58\x4a\xd8\xb3\x27\x6d\xd4\x0e\x17\x72\x5d\x7e\xf9\x32\ +\xc5\x62\x89\x8a\x12\xde\x7d\xc1\x0b\xe4\xbd\xee\x74\xab\x2f\x96\ +\x8f\x00\xa5\xce\xb2\xfa\x89\x95\x77\x2a\xff\x38\xe7\x7a\xe9\xcb\ +\xaa\x2d\xf5\xb2\x23\xba\xe5\xa9\x97\x70\xed\xaa\x02\x20\x85\x03\ +\xbf\xf4\xbc\xa3\xd4\xa3\x2b\xe5\x43\xc3\x1d\x24\xee\xdb\x87\xea\ +\x74\x51\xd2\xd0\x6b\xf3\x74\x47\xdf\x1f\xd5\xb8\xbf\x86\x8d\x1b\ +\x5f\xc6\xed\x76\xb4\x02\xf7\x4a\x29\xff\xef\xa3\xf4\xeb\x44\x71\ +\x3a\x08\x20\x16\x58\x83\xa2\x9f\x44\xfe\xd7\x21\xba\xf0\xb0\x3c\ +\xf1\xb6\x1e\xf7\x0f\xfe\x7a\x89\xa5\xa9\xbd\x8d\x77\x5d\x12\x6d\ +\x4c\x0c\x61\x97\x0b\xb5\xb7\x0b\xe9\x6f\x75\xf7\xe4\x6c\x6e\x8d\ +\xb1\xaa\x5e\x7b\xd8\x1b\x65\x93\xfe\xf8\x98\xc0\x80\x3d\x4a\x51\ +\x11\x47\x17\xe7\x06\xeb\xff\x37\xbb\xb9\x6f\x73\x5c\x2e\x80\x94\ +\x78\x97\x2f\x5f\xdc\x9d\x90\x10\x3b\xc2\xa7\xd7\x32\x18\xbc\x44\ +\xf9\xd6\x36\xbb\x68\x98\x33\x56\x05\x12\x42\xbe\x20\xad\xae\x00\ +\xdd\x2e\x3f\xde\xbf\x27\x5f\xd2\xfe\x5e\xe9\xd2\x45\x52\x90\x0c\ +\xe0\xd9\x55\x4f\xd6\xa6\xb7\xb8\xee\x5d\x17\x36\x6f\x84\x55\xa0\ +\x40\x20\x56\x88\x8e\x44\x8d\xa6\x43\x17\xd1\x2b\x94\xaa\xaa\x82\ +\x56\x84\xf5\x59\x7a\x77\x6b\x8e\x53\xdd\x61\xec\x59\x8e\x44\x6c\ +\x6f\xcf\xa1\x43\x73\x25\xe5\xe5\x2f\xd2\xd9\xb9\xcf\x4f\xe4\x3b\ +\xff\xb3\xd1\xe7\x97\x53\x89\xd3\x75\x08\xcc\x00\x56\xa2\xe8\x0a\ +\xd0\xc5\x82\xa2\x8b\x3c\x42\x07\x8a\x16\x84\x0e\xab\xdd\x80\x26\ +\x49\xc1\xd7\xd9\x49\xa0\xa3\x13\x8d\xea\x63\xe9\xcd\x82\x25\x37\ +\xe9\x90\x3a\x2d\xbe\x41\xbd\xcb\x10\x6f\xac\x0b\xea\x8c\x4e\x5f\ +\x50\x87\x3f\xa0\xd7\xfa\xc3\x1a\x53\x28\xac\xb1\xa8\x52\xc4\x48\ +\x29\xec\x1c\x7e\xa6\x09\x34\x3e\x95\x5d\xe3\xd8\x18\x37\x71\xe1\ +\xc2\x99\xbb\x73\x73\x33\x4a\x47\xbd\x97\x52\xd0\x3b\xd7\xf8\xfb\ +\x1d\x49\xbd\x6f\x46\x0f\xf8\xf1\xb9\x03\x08\x5f\x88\xa8\x50\x98\ +\x04\x55\x92\x3a\x5c\x67\x48\x23\xaa\x57\xcc\x89\xcf\xde\x9a\x79\ +\x7f\xb3\x8a\xb6\x40\x4a\xa9\xde\xbf\xe7\x81\x40\x5a\xc8\xa9\x9d\ +\xfd\xbf\x33\x35\x30\x06\x35\x6a\xf0\x91\x2a\x76\x91\x4d\x08\x8b\ +\x98\x8c\xc0\xfa\x56\x6b\x73\x43\xbd\xda\x92\xb3\xbe\x79\x1a\x9b\ +\xea\x0d\xd4\xd6\x6e\x04\xe4\x2b\xc0\x3d\x52\xca\x86\x93\x3e\xf9\ +\xc7\xc0\x69\x53\x0a\x15\x42\x14\x01\xbb\x88\xf0\xc9\x7f\x03\xd8\ +\x89\x68\xd6\x4e\x05\x7a\x80\xf5\xc0\x3a\x22\xea\x5a\xbd\x40\x0e\ +\x70\xb9\xce\xc0\x97\x2e\xbc\x1d\x71\xf5\x77\x21\x3e\xed\xc8\xf5\ +\x4b\x90\x83\x21\x6d\x9f\x3f\xa8\xeb\x73\x0f\x1a\x5c\xde\x80\xde\ +\xeb\xf3\x1b\x3d\x21\x4f\x74\xa8\xe2\x9d\xdb\xfd\xb6\xe2\xd2\x58\ +\xa7\x55\xab\x19\xb0\x69\x8c\x4e\xab\xce\xe2\xb4\x6a\xec\x1e\xb3\ +\x12\x27\x15\xb4\xdf\xee\xa9\x5e\x37\xf3\xf9\x87\x82\x9d\x8d\x55\ +\x8b\x8f\x50\x7d\x68\xcd\x14\x6b\x5d\xaf\x55\x57\x34\x60\xcc\xdb\ +\x53\x13\x7f\xe3\x24\x4b\xd8\xbd\xeb\xb7\xfb\xbe\x33\xd5\xd6\x18\ +\xb3\x3b\xef\xb5\x09\x93\x47\x72\x6a\xf1\x90\xc1\x2e\x32\x15\x88\ +\x62\x0a\x91\x03\xea\x08\xfe\xbc\xa7\x91\x67\xaa\xf5\x6c\xae\x6a\ +\x27\x14\x0a\x54\x00\xdf\x94\x52\xbe\xf7\x71\xe6\xf4\x64\xe0\xb4\ +\x6a\x05\x0b\x21\x2e\x05\xd6\x4a\x29\xbb\x0e\x4a\x33\x1e\x8d\x99\ +\x21\x84\x98\x06\xfc\x5a\x67\xe0\x9c\x0b\x6f\x87\xb1\x08\xc1\x08\ +\x9d\xd1\xd0\x66\x07\xa7\x15\x14\x13\xd8\x74\x90\xd9\xdc\x7c\xf6\ +\xfe\x77\xde\xfd\x83\xf1\xd5\x6b\x0d\x2d\x75\x46\xcf\x39\x63\xd5\ +\x3f\xd9\xdf\xbf\xe6\x5f\xfb\xd7\x2f\x90\x52\xaa\xab\xfe\xef\xe7\ +\xdb\xfc\x1e\xe7\xcc\xd1\x79\x3a\x6c\xba\x55\x1b\x27\x5b\x47\xe4\ +\xfc\x95\x89\x77\xac\x29\xf5\x35\xab\x37\x76\x3e\xbb\x28\x63\x55\ +\xf6\xea\x84\xca\xe4\xa9\x64\xb2\x87\x0c\x45\x83\x49\xa6\x23\x85\ +\x7b\x30\xac\x0e\xf6\xfb\xa4\xae\xd7\xab\xda\x5a\x9d\xc1\xe4\x3d\ +\x5d\x1e\xcd\xa6\xb6\x01\xde\x6d\xe8\xa3\xc7\x13\xec\x27\x72\x25\ +\xfd\xa3\x94\x63\xcb\x1f\x4e\x17\xc6\x8d\x5a\xb8\x10\xe2\x8b\xc0\ +\xcf\x75\x06\x8a\x97\xdf\x06\xb7\x5c\x06\x69\xf1\x90\x9f\x4d\x38\ +\x3a\x1a\xcd\xc1\x79\x03\x01\xdb\xc0\x8a\x15\x7f\xdb\xd9\xdd\x3d\ +\x79\x81\x6a\xd3\x6e\xdf\x70\x95\x5e\xdd\x48\xd7\x61\x0b\x6b\x92\ +\xe1\xea\xf2\xba\xf7\xb2\x0d\x32\x6c\x04\x08\xfa\x7d\x03\xef\x3f\ +\xf7\x93\x7e\xa9\x86\x47\xce\x09\xaa\x42\xd3\x8a\xb9\xf6\x44\x55\ +\x11\xa6\xe1\x34\x5d\x94\x69\xdf\x17\x06\x82\xc1\x94\x40\x67\x51\ +\xe1\xee\x59\x6f\x1b\xd4\x98\x9c\xb0\x14\x49\x7a\x45\x17\x3d\xcc\ +\x60\xaa\xe8\x71\xf1\x46\x5d\x37\x6f\xec\xeb\x62\x5d\x4b\x2f\x41\ +\x55\x42\x84\x31\xf4\x17\x22\xec\xee\x9e\x53\x32\x51\xc7\x89\x71\ +\x43\x00\x30\xc2\x51\xbb\x06\xb8\xfd\xbc\xf3\x58\xf8\xf2\xcb\x10\ +\x15\x75\x68\x9e\x3d\x7b\x6e\x28\x2f\x2f\x7f\x20\x4f\x4a\x25\x11\ +\x60\x60\xbe\xa5\xb2\xbb\x58\x9f\xf8\x3a\xfb\xe3\x0f\xa9\x0b\xe9\ +\xfc\x4f\xc3\xea\xfe\xdc\xa0\xe7\x10\x79\xbf\xd3\xd1\x5a\xb7\xfe\ +\xe5\x3f\xa4\x30\xb4\x75\x6f\x2c\xb6\xec\xe8\xb0\x1b\x4a\x01\x92\ +\xe2\x44\xe5\xac\xa9\xba\xbe\x04\x93\x46\xa3\x7f\xa9\xd5\xae\xaa\ +\xb2\xd0\xba\x6e\x61\x20\xdf\x1a\xad\xf7\x06\xc3\xac\x6c\xec\xe1\ +\x8d\x7d\xdd\xbc\x59\xdf\x45\xd3\xc0\x21\xfc\x9a\xbd\xc0\xa3\xc0\ +\xdf\x4e\x36\x23\xe7\x44\x31\xae\xec\x02\x86\x94\x3b\x9f\x05\x9e\ +\x15\x42\xdc\x71\xee\xb9\xfc\x65\xc5\x0a\x88\x8b\x03\x97\x2b\xbd\ +\xed\xf5\xd7\xff\xaf\xc5\xed\x4e\x2d\x3b\xa8\x88\xcf\x57\x68\x9a\ +\x68\x02\x21\x10\x1d\x12\x99\x3c\xfc\xe2\xc1\xae\x8a\xca\xdc\xa0\ +\xe7\x30\x55\x2a\x6b\x5c\x5a\xfe\xa4\x79\x17\x6d\xa8\x58\xff\xef\ +\xb9\x03\x16\xed\x9a\xce\x38\x43\xd9\xc4\x5c\xcd\xfa\x92\x09\x1a\ +\xab\x5e\x27\x4a\x00\x0a\x1d\xc6\x8a\x26\xad\xf0\x85\x7c\x52\xbe\ +\x5a\xd5\xa5\x7f\xb7\xa1\x8a\x55\xfb\x1d\x0c\x86\x55\x00\x17\xb0\ +\x83\x08\x13\x6a\xfb\xd0\xdf\xbb\xce\x54\x0b\xa2\x71\x45\x00\xa3\ +\xf0\xfa\xc6\x8d\xfc\x65\xe1\x42\x85\x7b\xee\xb9\xbd\x33\x14\xba\ +\xd7\x0a\x62\xf6\xc1\x19\x02\xf1\xda\xfa\xb0\x96\x49\x00\x46\x34\ +\xfb\x7d\x84\x92\x01\xe6\xfa\x1c\xab\xae\xee\xdf\x7f\x44\xdd\xbd\ +\x8c\x49\x65\x73\xbb\x3a\xea\xde\xee\x5d\xd2\xa3\xbd\x3a\x5f\xd3\ +\xa3\x08\x31\x6f\xf8\x9d\x4e\x55\xb6\xc6\x0c\x6a\x67\xd4\x86\x34\ +\x3b\xeb\xf6\xab\xe2\xde\xf7\xab\x21\xb2\xd0\xbf\x03\xca\x81\xda\ +\x33\x75\xb1\xc7\xc2\xb8\x25\x00\x29\x65\xbb\x10\xa2\xcd\x5c\xfa\ +\xa5\xd4\xee\xde\x1b\x63\x63\xad\x42\x3f\x3a\x8f\x67\xaa\x79\x64\ +\xdb\x8f\xc3\xe0\x6d\x21\x44\xb4\x1a\xdc\xf3\x44\xcb\xa6\x79\xa3\ +\xf3\x22\x19\xa0\xbd\xbf\x9a\x95\x95\x3e\x5e\xd9\x96\x36\xad\xba\ +\xf5\xdc\x60\x57\x72\x65\x48\x88\xe4\x83\xb3\x15\xf6\x19\xf5\x00\ +\x4e\xb7\x62\xad\x6e\x1c\x49\xfe\xa1\x94\xf2\xb5\x93\x38\xbc\xd3\ +\x86\x71\x4b\x00\x42\x88\x04\x20\xc5\x36\x3f\x9b\xba\xb3\x5b\x95\ +\x92\x6d\xb6\x75\x26\xb7\x71\x84\xa3\x88\x22\x3c\xbe\x6c\x7d\xd2\ +\xf0\x3f\x93\x30\x9b\x5b\xf1\xf4\xbe\xd2\xb4\xce\xae\x95\x52\x07\ +\xf8\x70\xb8\x2b\x58\x53\x33\x20\x5f\xde\x9a\xc0\x96\xc6\x62\xc2\ +\xea\x08\x43\x48\x00\x53\xa7\x77\x27\x6c\x6d\x48\xea\x43\x44\xec\ +\x06\xf5\xaa\xd8\x1e\x3b\xa8\x9d\x06\xe0\xf2\x2a\xb1\x43\x04\x30\ +\x08\x7c\x62\xd7\xb8\x13\xc5\xb8\x25\x00\xe0\x1c\x40\xd8\x72\x52\ +\x11\x51\x1a\x6d\xc5\x82\xbd\x67\xa5\xd6\x25\xad\x4d\xad\x4d\x9e\ +\x0e\x98\x03\x09\xba\x1a\xa9\x11\xd3\x86\x33\x27\x63\x4a\xf9\xc1\ +\x9e\xf2\x75\x69\x6b\xb7\x9a\xe4\xcb\x5b\x5a\xf8\x70\x6f\x09\xc1\ +\xf0\x61\x37\x83\x83\x61\x6c\x0a\xa7\xe4\x7f\xa5\x7f\x63\xdd\xe3\ +\x31\x73\x00\xf2\xfb\x4d\x23\xb2\x03\x7f\x48\xd8\xf6\x36\x02\xf0\ +\xe1\xe9\xe4\xdc\x9d\x6c\x8c\x67\x02\x98\x05\x60\xcb\x49\x19\x49\ +\x68\xcb\xef\x9c\xdf\x9f\xe8\xac\x2f\xda\x90\x1f\xf6\x96\x18\x74\ +\x9a\x40\xb0\x21\x75\xeb\xb6\xa6\xd2\xa7\x9f\xd3\x4d\x7a\xfe\xa5\ +\x22\xa3\xd3\xf5\xc5\xe3\xfd\x38\x27\x3d\xe1\x9d\xd3\x75\xa3\x79\ +\xb5\x6f\x9e\xc1\x66\xf7\x6b\x47\x0c\x31\x34\x1a\x8d\xa8\xd9\x0f\ +\xc0\x3b\x27\x61\x2c\x9f\x18\xc6\x33\x01\x4c\x35\x27\xd9\xd1\x9a\ +\x8d\x87\x24\x7a\xad\xbe\xdc\x95\xf3\x6b\x3f\x9c\xfd\x97\x8e\x81\ +\xab\xd7\xbe\x17\x6f\x72\x75\x0b\xbd\xa7\x37\x38\x68\x4c\xa8\x94\ +\x98\xf4\xda\x41\x8f\x59\x13\x0a\x58\x14\x35\x64\x43\xca\x58\xc6\ +\x62\xe1\x8e\x42\xd6\x12\xc7\xfc\x3f\x7c\xd7\x8e\x76\x8e\x8e\xa2\ +\x7c\x3d\x51\x66\x05\xa1\x35\xe0\x1f\x74\x01\x6c\x38\x35\xc3\x3b\ +\x3d\x18\xcf\x04\x30\xc5\x96\x7b\x98\xaa\xde\x40\x79\x87\xb5\x72\ +\x4f\x77\xd4\xac\x17\xad\x39\x51\x55\x0b\x72\xd6\x5c\xe1\xeb\x2e\ +\x23\xa2\xa1\x73\x18\x84\x54\xc3\x3a\x9f\xab\xdf\xe8\xe9\xeb\x37\ +\xba\x7b\xdd\x06\x77\x8f\xcf\xe4\xea\x0e\x9a\xdd\xbd\x61\x9d\xbb\ +\x57\x74\x78\x7a\x73\x3a\x7d\xce\xcc\x07\xfd\x1e\x65\xfd\x43\x0e\ +\xee\xc7\x81\x10\x90\x99\xa6\xc3\x16\xad\x40\x44\xd8\xb3\xed\xd4\ +\x0e\xf3\xd4\x62\x5c\x12\x80\x10\xc2\x0e\x24\xda\x72\x0e\x10\x80\ +\x3f\x24\x76\xbe\xd2\x10\x9f\xe0\x0e\x6a\xca\x08\xd1\x06\x44\x3d\ +\x1d\x9d\xbc\x60\x9b\x31\xba\xe2\x27\xbd\xf5\x71\xca\x90\x04\xef\ +\x60\x48\xa1\x68\x02\x66\x5b\x5c\xc0\x6c\x8b\x73\x26\x64\x8f\xa4\ +\xab\xd0\xbe\x4b\x9a\x6a\xfe\xee\x0c\x47\x55\xf5\xb6\xc3\x7f\xee\ +\x02\xf8\x37\xd0\x2a\x25\x8e\xa6\x96\xa0\x81\x08\xa3\xa8\xe7\x44\ +\x75\xf2\x3e\x69\x8c\x4b\x02\x60\xe8\x17\x6d\xcb\x49\x41\xaa\xaa\ +\x5a\xe7\xb4\xac\x5e\xd5\x1e\xbd\x50\x4a\x11\x39\xa4\x85\x35\xee\ +\xe1\x8c\x3b\x75\x51\x93\xae\x4f\x98\xd8\xf3\xa7\xde\xda\x1d\x56\ +\x35\x34\x5a\x1a\x38\x02\x89\xe8\xae\x52\x0d\x55\xff\x54\x63\x62\ +\xb6\x86\xcd\x93\x25\xa4\x60\x04\x9c\x5b\x22\xaf\x23\x3a\xf7\x87\ +\xa9\x8c\x8f\x77\x8c\x57\x02\x88\xc0\x62\xe5\xf5\x6d\xa2\xbf\xd3\ +\x6c\x5d\x7c\x48\x7a\xe0\x50\xcb\x9c\x3e\x45\x1b\xff\xe5\xf8\x89\ +\x31\x3f\xea\x6b\x58\x55\x1a\x74\x8f\x30\x80\xa4\xa4\xaf\x0e\xc3\ +\x9e\x7f\x85\x62\x2c\xe5\x6a\xd4\x14\x15\x16\x1e\xd6\x86\xb3\x05\ +\xa0\xff\xd3\xb8\xf8\x30\x0e\x09\x60\xc8\xa1\xd3\x03\x18\x75\x6c\ +\xdc\x0d\xd8\xfc\x46\x46\x9b\x94\x84\x74\x87\x6d\xcb\x61\xd0\xde\ +\x17\x9b\xb3\xe8\x52\x4f\xf7\xaa\x25\x1e\x87\xe6\xa5\x50\xac\x61\ +\x55\x38\xaa\x34\x84\x58\x70\xd4\x06\x23\x04\xd0\x7d\xb2\xfa\x7f\ +\xa6\x61\xdc\x11\x00\xf0\x4d\xe0\xab\x7c\x69\x3a\x2c\x30\xaa\xfc\ +\xaf\xdb\x4c\x69\xcb\x16\x2e\x4e\xb5\xa0\x57\x22\xaa\x63\x01\x46\ +\x5b\xf2\x78\x09\xb1\x03\xf7\xa0\xe6\x65\x77\xd4\x9c\x97\xe5\xa8\ +\xab\xc3\xd1\xe0\x6c\x85\x88\xe6\xe1\xa7\x12\xe3\x91\x00\xbe\x82\ +\xa2\xc0\x0d\xf3\x24\x69\xd1\x0a\xff\x54\x77\xb2\xa3\x6f\x26\x3b\ +\xfb\x24\x73\xe3\x36\x70\x7e\x4a\x22\x41\xc2\xc0\x20\x61\xb9\x03\ +\x57\x28\x8c\x37\x38\x95\xb0\x3c\x9c\xfd\x7b\x2c\x04\xbd\xd0\xdf\ +\x00\xf0\xe1\x49\x1e\xc3\x19\x83\x71\x65\x1c\x2a\x22\x7c\xf9\x09\ +\x2c\x99\x08\x69\xb1\x91\xfb\xfb\xa5\x9e\xc8\xf7\x5e\x22\x28\x77\ +\xcc\xe5\xa1\x3d\x59\xe6\x4e\x7f\x8f\xe8\x53\x37\xd2\xe6\x2d\xc5\ +\x15\x98\x47\x58\x46\x1d\xa5\xda\x23\xa3\xf6\x0d\x08\xf9\x7d\xc0\ +\x73\x27\x69\x08\x67\x1c\xc6\xdb\x0e\x10\x39\xa4\xdd\x34\xff\x40\ +\xca\x62\xef\x2c\x9e\x8d\x6e\x43\x25\x72\x27\x94\x68\xf2\x97\xe7\ +\x4d\x56\x52\x12\xf4\x7b\x56\xa8\x95\xa1\xf6\x1e\x3f\x5d\x5d\xe0\ +\x1f\x9c\x82\xe0\xf8\x08\xa1\xf2\x05\x80\x67\xce\x14\xe5\x8d\x53\ +\x81\x71\xb5\x03\x00\xd7\x92\x6f\x82\x44\xcb\x81\x14\x81\x86\x79\ +\x83\x35\x23\xff\xd4\x88\xce\xa8\xa2\x84\x42\x93\x8d\xec\x99\x57\ +\x2b\x25\x31\x53\x13\x03\x94\x94\xcc\x65\xfa\x34\x0d\xd9\x59\x1b\ +\x31\x99\xd6\x01\xfd\xc7\x6c\xa9\x63\x3b\xf4\xee\x03\x78\xf2\xe4\ +\x0f\xe3\xcc\xc1\xb8\x21\x80\x21\xf3\xf1\x8b\x98\xee\x83\xbd\x8f\ +\x85\xf0\x39\xca\x47\x5e\x5e\x35\x30\x85\x88\xe3\x47\x4c\x79\x71\ +\xb5\x23\x65\x14\x74\x13\x97\xb0\x68\xc2\x02\xb6\xa3\x51\x9c\xc4\ +\xc5\xcf\xa1\xb8\xf8\x2c\xa6\x4f\x8f\x22\x37\x77\x2b\x66\xf3\x1a\ +\x10\x63\x9f\xf0\x2b\x5e\x00\xd8\x22\xa5\xdc\x7c\x2a\xc7\xf5\x49\ +\x63\xdc\x10\x00\x70\x1b\x20\x28\x05\xb4\x7e\x2d\x3b\x7e\x57\x46\ +\xf3\xbb\xeb\x90\xd2\x8d\x45\xda\xc9\x0c\x6e\x01\x48\x3c\xbf\x50\ +\x33\xba\xa0\x3d\x9b\xe9\x33\x2f\x47\xa3\x8f\x22\xb2\x98\x42\xe8\ +\x88\x8d\x9d\xc1\xc4\x89\x0b\x98\x31\x3d\x8e\x82\xfc\x9d\x58\x2d\ +\xab\x40\xb4\x01\xe0\x75\x40\xc3\xfb\x00\x7f\x38\x5d\x83\xfb\xa4\ +\x30\x6e\x74\x02\x85\x10\x35\x64\x51\xc0\x7d\xa3\x5e\x18\xe3\x9a\ +\x99\x7c\x47\x3f\x8d\xd1\x06\x7e\x12\x97\x37\xf3\x8d\x1b\x5d\xba\ +\x68\xc3\x11\x5d\xb5\x35\x6c\x66\x55\x47\x35\x73\x61\x94\x5b\xd0\ +\x61\xec\xdb\xaa\xb2\xfa\x71\x85\x0d\x8f\xb7\x01\xb9\x52\x1e\xdb\ +\x93\xd8\x78\xc6\xb8\xd8\x01\x86\x4e\xff\xb9\x8c\xe5\x7b\xcb\xef\ +\xc8\x60\xf3\x2f\x8a\x89\xdd\xde\x66\x28\x31\xac\x3c\xda\xe2\x03\ +\xe4\xcc\x62\xd1\xe4\x0b\x68\x12\x0a\x63\xbb\x7d\xc9\x9b\xa1\xd0\ +\xb2\x11\xe0\xd7\x9f\xf6\xc5\x87\xf1\xb5\x03\xfc\x91\x68\xee\xe2\ +\xe7\x8c\x79\x77\x31\x6b\x94\xca\xff\x2e\x9c\x5e\x1b\x35\xe5\xf6\ +\x98\xad\xba\xb3\xf5\x6d\xe4\x64\x86\x85\xf6\x88\xa6\x24\x6a\x18\ +\x6f\xc5\x5b\x6c\x75\xf7\x72\x28\x27\xb0\xab\x0e\x1e\x28\xe8\x06\ +\xb2\xa5\x94\xc7\x70\x27\x31\xfe\x31\x9e\xae\x81\xf5\xb8\x88\xa8\ +\x5f\xce\x3a\x90\x68\x54\x94\xea\x3b\x33\x32\x9d\x4b\xe3\xe3\x66\ +\x83\x4c\x5c\x16\xf8\x4a\x1c\x81\x88\x8c\xbf\x47\xa4\x74\x6f\xd0\ +\x5e\xd0\xb0\x56\x7b\x91\x77\x8f\x52\x66\xed\x53\x12\xf2\xa5\x14\ +\x56\x00\x45\x83\x79\xf2\x85\x2c\xe8\xa8\x61\x43\xc3\x26\x8a\x90\ +\x44\x76\x8e\xed\x2f\x01\xfc\xcf\x89\x2c\xfe\x90\x8f\x82\xb9\x40\ +\x3c\xb0\x0f\xe8\x22\x72\xf3\x18\x38\xd3\x14\x46\xc7\xd3\x0e\xf0\ +\x20\xf0\x43\xce\x05\xae\x04\x83\x50\x6a\x6f\xcd\xc8\x70\x2c\x4b\ +\x88\x9f\x23\x0e\x52\xea\x98\x9b\xbc\xb7\xca\xae\x77\x4f\x1c\xab\ +\x0e\x89\x90\xfb\x34\x93\x1b\xd6\x68\x2f\x6a\xdb\xa0\x5c\xa0\xee\ +\x53\xa6\x24\x78\x85\x25\x7f\xd0\x4d\xf7\xce\x67\x9a\x0d\xe1\xe8\ +\x8c\x38\x7e\x36\x1b\x1a\x37\xdf\x02\x74\x00\xd1\xa3\x1e\x0b\x91\ +\xb3\xc3\xe8\x47\x47\x64\xb1\xe3\x89\x78\x0b\x8d\x61\x6c\x45\x13\ +\x95\x08\x21\x0c\x3f\x03\x43\xff\xff\x83\x94\xf2\xfd\x13\x9d\xa3\ +\x8f\x83\x71\x41\x00\x42\x88\x74\xa0\x86\xe9\x98\x74\x5f\x12\x5c\ +\x91\x95\x5c\x7f\x6d\x71\x6a\x8e\x18\x63\x92\x53\xa3\x7a\x3f\x2c\ +\x8d\x6b\x58\xfc\x51\xeb\x0e\x08\xa3\x7f\xa7\xb2\xa0\xee\xef\xdb\ +\x4a\x33\xdf\x7e\x3d\xc9\xea\x7d\xe3\x05\xa4\x94\x48\x35\x8c\x2a\ +\x55\xa4\x1a\x46\xaa\x2a\x72\xe8\xef\x91\x34\xa9\x0e\xa5\x87\x51\ +\xd5\x03\x69\x8a\xcd\x82\x2d\xa7\x10\x5b\x6e\x11\xae\xac\x74\x2e\ +\x2f\xbb\x8c\x58\x5f\x08\xa7\xa3\x13\x57\x6f\x27\x4e\x47\x27\xce\ +\xde\xc8\xe3\x6a\xac\xa2\x66\xdb\x1a\x86\xb6\x9a\xc7\xa4\x94\x77\ +\x9f\xac\x39\xfb\xa8\x38\xe3\x3f\x01\x42\x88\x42\xe0\xdf\x09\x73\ +\xf4\xa6\x39\x97\xc5\xf0\x95\x45\x19\x28\x9a\xc3\xe4\x7f\x23\xe8\ +\xf2\xd9\xec\x1f\xa9\x62\x95\x36\x3a\x45\x93\x7e\x57\x28\x30\x6b\ +\xfd\xaa\xb8\xf0\xa6\x2d\xe6\x58\xd3\x15\x74\xdc\xf4\xd8\xc7\xea\ +\xa7\x37\xd0\x1f\x6e\xd6\x35\xfa\x4d\x39\x39\x51\x00\xad\x22\xe8\ +\x68\x29\x48\xb7\x14\x28\x39\x06\x0d\x82\xa4\xcc\x43\xcd\xe2\x43\ +\x41\xbf\x6a\x7b\xf4\xd2\x40\x7f\x15\xc6\x9d\x3e\xf8\x07\x7c\x4d\ +\x08\xf1\x3b\x29\x65\xed\x98\x0d\x9c\x22\x9c\xf1\x04\x00\xdc\x17\ +\x6b\xd6\x4e\x78\xfb\xca\x22\x5a\xd4\x70\xd8\xb7\xc3\xdf\x6b\xcd\ +\xd6\xb5\x8a\x58\xad\x2e\x24\x64\x26\x82\xe8\x83\x33\x87\x54\xcd\ +\xa4\xa0\xaa\x19\xd0\x29\x61\xdb\x41\xc9\x3e\xdc\xd4\x52\xa7\xe9\ +\x63\xa3\x62\x60\xb3\x26\x87\x1e\x91\x0a\xa4\x02\xb2\x9b\x82\x35\ +\x01\xf3\xa5\x1a\xf7\xe4\xe2\x6e\x22\x5b\xf8\x71\xc1\x65\xf5\x97\ +\xf7\xd9\x99\x60\x22\xc7\x0e\xd0\x6d\x31\x36\xed\xce\x48\x4d\x32\ +\x2a\x4a\xab\x26\x3c\x86\xff\x43\x41\x38\x23\xd5\xb8\xf5\x82\xb6\ +\xb5\xd1\xd5\xe9\x18\x74\xb5\xe4\xbe\x02\x78\xe0\x5a\xe0\x87\xc7\ +\xdb\xfe\x89\xe0\x8c\x26\x80\x21\xd7\xeb\x57\x2d\x9d\x64\x63\x9a\ +\xcd\xc4\x34\xd0\x00\x09\x34\x93\xc0\x90\x13\x77\x9f\x5e\x76\x0e\ +\x98\x65\x5b\x9f\x19\x97\x33\x4a\xe2\xd6\x4b\x4b\x77\x77\x8c\x23\ +\x75\xb0\xd7\xcc\x0e\x4d\x98\x0d\x9a\x04\x6a\x94\x7c\x42\x4c\x39\ +\xac\x01\x29\x06\x6a\xb4\xcb\xf6\x76\x30\x79\xa1\xaa\x28\x4d\xee\ +\xd4\xdc\xe3\x72\xf2\x1c\xd6\xaa\x3d\x5d\xc9\xee\xba\xa0\x56\x96\ +\x0d\x7f\x8d\x1c\x51\x86\xb6\xad\x99\xf6\x78\xc0\x98\x28\x75\xbd\ +\x70\xf8\x6e\x15\x1f\xc3\x3a\x5b\xf3\x9e\x6c\xad\xd7\x95\x59\x62\ +\x65\x70\x93\x16\x7c\x11\x1b\xe1\x31\xcf\x2e\xa7\x12\x67\x34\x01\ +\x00\xe7\x03\xc6\x0b\x4a\x6c\x47\xcc\x60\x0a\x88\x24\x93\x57\x7a\ +\x92\xfb\x06\x5d\x34\xf4\x6b\x69\xe8\x4f\xc3\x13\x9c\x0e\x47\x17\ +\xf9\xab\x68\x6b\xb6\x69\x6f\x36\x7a\x89\x9b\x0d\xd0\x1a\x9f\xd0\ +\x00\x63\x72\x1a\xc6\x84\xcb\x3a\x58\xde\x67\xf7\x4d\x80\x03\xae\ +\xda\xfb\xcc\xfa\xae\xcd\x59\x76\x2b\x43\x86\xa5\x79\xc2\x7c\x98\ +\x4b\xbb\x18\x0b\xab\x0c\x3a\x16\x65\x57\xfe\xfd\x5d\x4c\x5a\x1b\ +\xbe\x90\xcd\x9e\x80\xf4\xb7\x1f\x5b\x3b\xf9\x54\xe0\x4c\x27\x00\ +\x23\xc0\xbf\x77\xf6\x71\xf5\xec\x38\x8c\x3a\x05\x40\x25\xa4\xd6\ +\xe1\xf0\xb5\x53\xdf\x6f\xa0\x71\xa0\x00\x5f\x28\x97\x31\x7e\x69\ +\x47\x82\x5b\x24\xae\xdb\xa1\x5c\x3f\x5d\x45\x37\x62\xf2\x5d\x91\ +\x57\x98\xfe\x51\xca\x86\x35\xd2\xd1\x99\xec\xae\x0d\xe9\xd4\xb2\ +\x83\xcf\xa0\x03\x26\x9d\x63\x53\x76\x82\x11\xe4\x88\xa4\xaa\x50\ +\x35\x1f\x12\x8f\x26\xca\xc4\x86\x28\x73\x44\xa2\x99\xa1\xac\x2a\ +\xe2\xfa\xa9\x4a\xb0\xc6\xb1\x61\xe6\x07\xcd\x73\x35\x1d\x61\xc2\ +\x92\x34\x21\x84\xe9\x74\x5a\x10\x9f\xe9\x04\xb0\x1b\xd8\xff\xca\ +\xb6\xbe\xcc\x2f\xfc\xae\x92\x6f\x96\x18\xb9\xb8\xc7\xef\x51\x02\ +\x6a\x21\x70\xb8\xb3\xa1\x63\x23\xd0\x22\xe6\x6e\xa8\x57\x16\x1f\ +\xa2\xfb\x17\x56\x34\x75\x0e\xab\x35\xff\x58\x85\x5d\xb6\x40\x79\ +\x5f\xac\xf7\x90\x5f\x3d\x80\xd3\xa0\x1b\xd8\x90\x13\xaf\x91\xc8\ +\x43\x16\x3c\x17\xd3\x08\x51\x19\xf4\xec\x89\x89\xa6\x14\x10\xf1\ +\x3d\x9b\x6a\x45\x38\x54\x00\xa0\x2b\x8c\x9b\x6b\x49\xb7\xc9\x07\ +\x92\x3b\xc5\x4f\xdf\xed\x9a\xef\x0b\xaa\x3b\x85\x10\x37\x4b\x29\ +\xd7\x7d\x8c\xf1\x1d\x37\xce\x68\x56\xb0\x94\x72\x2b\x50\x00\xbc\ +\xb4\xba\xda\x83\x66\xb9\x83\xb7\x26\xf9\xea\x7b\x0d\x72\x2d\xc7\ +\xed\x90\x59\x69\xdb\xa5\x5c\x5b\x3b\x7a\xf1\x01\xf6\x27\x26\x1f\ +\x35\x2a\x44\x58\x23\x1d\x6d\xe9\xee\xf2\xbe\x58\x5f\x19\x88\x43\ +\x6e\x19\x6e\x83\xce\x5d\x9e\x17\x1f\x96\x1c\x1a\x2a\x46\x81\xde\ +\x18\xa9\xb5\x00\x68\x35\x34\xc5\xdb\x48\x65\x68\x47\x33\xb4\x3d\ +\xdd\x14\x24\x38\xa2\x63\x10\x63\xd6\x8a\xfb\x2f\x4e\x63\xf7\x0f\ +\x8a\x39\x7b\x42\x74\x01\xb0\x5a\x08\xf1\xeb\xa1\x18\x07\xa7\x14\ +\x67\x34\x01\x00\x48\x29\x03\xc0\x8b\x00\xe5\xdb\xa0\xec\xbf\xd4\ +\xfc\x0f\x16\xf8\xd3\x9e\x29\xf2\xbb\xeb\x6c\xa1\x0f\xa5\xa0\xe3\ +\x58\x75\x04\x89\xda\x56\xae\xb9\xdb\xd8\x2f\x32\x27\x8d\xf5\xbe\ +\x22\x3f\xef\x88\x91\x36\x5c\xb6\x40\x79\x6b\xfa\x00\x21\x6d\xb8\ +\x6c\xf4\x3b\x8f\x41\xe7\x5b\x9f\x1b\x3f\x28\x47\x11\x05\x40\x34\ +\xda\x66\x00\x45\xa1\x37\xc9\x0e\x08\xec\x00\x01\xd5\x2f\xaf\x70\ +\x3d\x1f\x57\x9d\x55\xbb\x71\x97\x76\x5b\xfd\x2e\x76\xed\xa8\x13\ +\xb5\xeb\xda\xd4\xce\x9d\x66\x4b\xc0\xff\xde\xb7\x0a\x78\xfc\xba\ +\x2c\xc5\x66\xd2\xdc\x03\x6c\x17\x42\x8c\xe9\xc1\xec\xe3\x42\x08\ +\x91\x29\x84\xf8\x82\x10\xe2\x2c\x18\x07\x04\x30\x84\x4d\x00\xe1\ +\x30\xc4\xc6\x12\x75\xd9\xdd\xc4\xd8\x0b\x65\xd7\xfb\x19\xc1\xc5\ +\x4f\x16\xfb\xe2\xcb\x93\x83\xe5\x01\x0d\xbb\xc7\x28\x27\x1d\xa2\ +\xf0\xc3\x0d\x9a\xbb\x4b\x83\x98\xc7\xe4\x0f\x04\x75\xda\xaa\x81\ +\xa8\xe8\xc3\x0e\x7f\xaa\xa2\xf6\x8e\xfc\xea\x85\x88\x1b\xfd\x7e\ +\x50\x75\x07\x36\x78\xfe\xd6\xac\x0e\x54\xec\x66\x0c\x8d\xa1\x34\ +\x69\x70\x0a\x81\x3f\x39\x8e\x56\xc4\x81\xc3\xe5\xbb\xcd\x4f\x56\ +\xd7\x39\xfb\xa6\xfd\xd4\x92\x69\xb5\x94\x98\x1b\x13\x52\x82\xee\ +\x01\xe9\x9c\xd7\xae\xb4\x4c\x6d\x33\xd5\x18\xb7\xc9\x6d\xc1\xd4\ +\x89\x2d\xa1\xdf\x7f\x4d\xcb\xb5\xe7\x88\x22\x45\x61\x9d\x10\xe2\ +\xe7\x43\x9e\xc5\x4f\x08\x42\x88\xf3\x81\x46\x43\x94\xf6\x2d\xe0\ +\x35\x21\x44\xdc\x78\xe1\x04\xde\x0b\xfc\xb2\x72\x2d\x4c\x3c\xf0\ +\xe5\xf7\xbe\xf7\x1c\xd5\xf5\x55\x4c\x1f\x4e\x48\xf1\x2a\x95\xf3\ +\xdb\xf5\xbd\xb1\x3e\x31\x1b\x29\x7c\x43\x57\xbc\xd9\x63\xd5\x39\ +\x8c\xba\x8c\x8c\x55\x1b\x8b\xa7\x1c\xe2\x2c\xc2\x6d\x1d\xdc\xd0\ +\x1b\xeb\x2b\x18\x6b\xe1\x01\x82\x72\x30\xb4\xda\xfb\x97\xb6\x20\ +\x83\x11\xf7\x32\x42\x13\x20\x79\xf1\x66\x52\x96\xd8\xd0\x18\x4b\ +\x00\x2e\x26\x7e\xdd\xcd\xb1\xf1\x1a\xad\x86\x43\xce\x0b\xb7\x7e\ +\x90\x5d\xdd\xe9\x6f\x2a\xb2\xe8\x0d\xed\x5f\x9e\x5c\x62\xfd\x86\ +\xac\xd9\x52\x10\x76\xd9\x6b\xeb\x30\x7a\x3d\x14\x8c\x6e\xeb\xed\ +\xcd\xf0\xc0\x53\xa0\xaa\x6c\x02\x2e\x93\x52\x1e\x3b\x88\xd1\x18\ +\x10\x42\x64\x03\xdb\xa7\x5e\x94\x15\x73\xfe\xf7\xa6\xf1\x8b\xb2\ +\x57\x01\xbe\x7f\xa6\x1f\x02\x87\x71\xc7\x9c\xe9\x91\xc5\x97\x21\ +\x3a\xc5\x00\xd5\x78\x49\x3a\x77\x3e\xd9\xce\x6e\xca\x7b\x7a\x28\ +\x03\x68\x37\xab\xc5\x2f\xe4\xf9\x31\x85\x44\x77\x4a\xd7\xbc\xed\ +\x3a\x67\xc9\x54\x71\x74\xfa\x96\x15\xb9\x05\x23\x93\xae\x2a\x6a\ +\x6f\x47\xaa\x67\x6f\x48\x7b\xe8\x09\xff\x60\x84\x65\x48\x5d\xeb\ +\x7f\x72\x5f\x90\xc1\x09\x07\x6a\x09\xeb\x69\x5f\x79\x16\xed\x2b\ +\x21\x3a\xbb\x9a\xcc\xcb\x7a\x66\x24\x26\x0d\xe8\x34\x72\x3a\x82\ +\x2e\x40\x91\x52\x6a\xba\xfc\x4d\xb2\xd3\xdf\x34\x01\xc0\x1d\x18\ +\x4c\xa9\xef\xeb\x5f\xf5\x68\x6c\xe1\xa2\xeb\x75\x4d\xab\xe6\x17\ +\xf5\xcc\x1b\x70\xb2\xaa\xbe\x9e\x99\x6a\xf8\x80\xee\xe2\x79\x33\ +\xa1\xdf\x17\xcf\xa3\xcf\x0f\xcc\x0e\x86\x82\x5b\x85\x10\xd7\x4a\ +\x29\x57\x7e\x8c\x39\x7c\xdc\x12\x6f\x8c\xb9\xe5\xef\x67\x63\x88\ +\xd6\x63\x8e\xd1\xe3\xed\x0f\x4c\x38\xe3\x3f\x01\x43\x86\x20\x19\ +\x67\x95\x1a\xa8\xdb\x80\x53\xb4\x91\x80\x87\x45\x48\x8a\x84\xc0\ +\xbe\x68\x31\x01\x7b\x02\x6b\x0f\x2e\xe3\xd3\xca\x84\xfa\xd4\x75\ +\xe7\xd5\x4c\xf8\x91\xbd\x2b\xe9\xdd\xf5\x61\x8d\x7f\x4c\x7f\xbb\ +\x01\x9d\x6e\x8f\xdb\x68\x4a\x05\x70\x47\x07\x37\xb4\x66\xba\xd5\ +\xc8\xe2\x8f\x0d\x89\x94\xeb\x7c\x4f\xef\x1e\x54\x3d\x13\x8e\x94\ +\x27\x59\xd6\x1a\x6e\xd0\x7d\x2f\xc6\x37\x78\x4b\x91\x3d\xa6\xdc\ +\x91\x14\xbb\x2d\x31\x29\x76\x5b\x7c\xb2\x7d\x7b\x6c\x49\x4a\x8f\ +\xfd\xb2\xfc\xcb\xaa\x88\x08\x85\xf8\xb0\xa9\x61\xa6\x94\x74\xfe\ +\x5d\x66\x2d\x7a\x85\xd4\x4d\x36\x2b\x67\x95\x96\xe2\x4c\x8c\x3f\ +\x60\x71\x2c\x04\x5c\xb9\xa0\x87\xeb\x2e\x2c\x26\xc6\x96\x96\x08\ +\xbc\x23\x84\xf8\x95\x10\xe2\x23\xb8\xbb\x1e\x99\xc3\x2f\x02\x4b\ +\x97\x3f\x38\x1d\xa3\x55\x8f\x10\x90\x37\x3b\x07\xc0\x7c\xc6\x13\ +\xc0\x90\xf8\xb4\x71\x5b\xfd\x15\xac\xda\x59\xa2\x01\x0e\x31\xd1\ +\x8a\xb3\x53\x96\x92\x42\x66\x42\x0a\xab\x0f\x2b\x2b\xc2\xba\x5e\ +\xfb\xda\x79\xb5\x85\x3f\x9b\xd4\x9c\xf9\x74\x45\xc0\xd0\xbd\x0e\ +\x18\x09\x13\x57\x9f\x9a\xd1\xa7\x2a\x6a\x6f\x5b\xaa\xbb\xbc\x37\ +\xce\x3b\x57\x22\xe3\x47\xd7\x71\x30\x36\xf8\x9f\xdb\xe8\x95\xfd\ +\x63\x06\x6b\x34\xeb\x55\xcf\x25\x13\xfb\xb6\x5e\x34\xa1\x3f\xdb\ +\xa8\x91\x25\x8e\xc1\xe6\xbc\xbf\xd6\x7e\x63\x42\xbd\x73\xcb\xbb\ +\xc3\x6d\x2a\x42\xe1\x7b\x73\xbe\x57\xfc\xbb\xc5\xbf\x73\x68\xd1\ +\xba\x82\x61\x35\x6a\x5b\x6b\x7b\x03\xc0\x5b\xa4\x9c\xf5\x94\xcc\ +\xd9\x21\xc0\x96\x91\xc5\xdc\x29\x93\xd9\x62\x34\xb2\x1f\x22\x44\ +\xf0\xb5\x65\x3b\xb9\x70\x41\x3a\x39\xd9\x0b\x15\x8d\x46\xff\xdf\ +\x40\xa5\x10\xe2\xc2\x63\xcd\x9f\x10\xc2\x00\x3c\x9c\x58\x60\x63\ +\xce\x35\x93\x68\xfc\x67\xc9\xae\x35\x37\x2f\x0b\x87\x5b\xa6\x00\ +\x24\x8d\x97\x33\xc0\x8a\x79\xf3\xae\x5e\x56\x52\x72\x36\xf9\xc9\ +\xb5\x55\xdf\xbc\xf0\x11\x93\x4e\x13\xce\x1e\x7e\xef\xf3\xb1\x61\ +\x4d\x39\x73\x03\x83\xac\x6d\x6b\x61\x38\xc8\xc3\x98\xd0\x05\xa3\ +\xbb\x12\xbb\x96\x55\x5a\x5c\x13\x27\x3e\xb7\x7c\x61\x45\x6b\x0a\ +\x93\x15\x30\xeb\xc0\xad\x93\x8a\x57\x8f\xf0\xe9\x55\x06\x0d\x42\ +\x09\x68\x55\x42\x26\x94\xb0\x5e\x6a\xd4\xd7\xc2\xaf\x85\x76\xab\ +\xdb\x0f\x8b\x11\xac\x55\x24\x33\xa3\xfb\xdb\x4a\xb2\x82\x89\x8a\ +\x4e\x68\xa5\x2a\x19\x68\x1e\xc4\x9a\x6a\x40\xd1\x45\x3e\x23\x16\ +\x4f\x5a\xeb\x05\xd9\x5f\x53\x62\xe3\xe2\x46\xbc\x59\x74\xb8\x3a\ +\xb8\xfc\x8f\x97\x33\x98\xee\xe7\x9a\x9c\x29\x3d\x16\xbb\x2e\x1e\ +\xa0\x48\x3a\x2b\xbe\x25\x6a\x53\x88\x78\x52\xf5\xf7\xf5\xb1\xa1\ +\xb1\x91\x32\x55\x8d\xa8\xb0\x3d\xf6\xe6\xac\xe0\x96\x7d\x85\xba\ +\xfd\xfb\xd7\x30\x30\xb0\x1f\x22\x37\xa4\xff\x92\x52\xb6\x8f\x31\ +\x6f\xa5\xc0\xdf\xf4\x5a\x31\xe5\x6b\x5f\xca\xa3\xdf\xf6\x70\x97\ +\x4e\x29\x8c\x07\x94\x95\x2b\xef\x63\xdf\xbe\x77\x06\xc6\x0b\x01\ +\xfc\x6a\xf6\xec\x4b\xfe\xbb\xb4\xf4\x02\x00\xf4\xda\x80\xf7\xdb\ +\x17\xfe\x66\x6b\x4e\x52\xe3\x88\x36\xcf\xf6\x9d\xec\xec\x76\x30\ +\x35\x14\x62\x63\x6b\x13\xd3\xa4\x1c\xdb\x27\x80\x22\xe8\x88\xb7\ +\xb2\x77\x62\xaa\x2e\x66\x95\x79\xad\x1d\x34\x69\x1c\xe3\x36\x54\ +\xd7\xbb\xb7\xfb\xd7\xf6\x67\x13\x24\x87\xce\x55\xbe\xc6\x29\x67\ +\xa6\x78\x84\xaf\xd3\x4f\xfb\x36\x37\xed\xdb\x5c\x74\xec\xf0\x10\ +\xf4\x86\xd1\x1a\x15\x92\xa6\x58\x48\x9d\x69\x21\x75\x66\x34\xf6\ +\x1c\x0b\xa9\xad\x0b\xb9\x68\xf9\xd5\x00\x3c\x70\xdb\x03\xbc\xf3\ +\xf2\x3b\x70\x31\xa4\xcd\x8d\xe6\xc2\x73\x0f\x9c\x6e\xd3\xf1\xd5\ +\xdf\x47\x95\x51\x83\x4c\x05\x90\x2a\x8d\x0d\x8d\xf4\xf6\xf5\x45\ +\x0e\xbc\x7f\xfd\x60\xa6\xa3\xbc\x76\x7a\x9c\xc3\x51\x43\x73\x73\ +\x39\xe1\xf0\xa0\x13\xf8\x1e\xf0\x67\x29\xa5\x2a\x84\x98\x4e\x24\ +\x3e\xc2\xc5\xcb\x4a\x6c\xe2\x96\x04\x41\x49\xc0\x14\x5e\x75\xc3\ +\xeb\x1a\x97\x8b\xb6\x9d\x3b\xdb\x93\x3f\xf8\xe0\x27\x4a\x4b\xcb\ +\x86\x9a\xf1\x42\x00\x97\x4e\x9b\x76\xc1\x4b\xb3\x66\x5d\x72\x48\ +\xfa\xfc\xa2\xb5\x1b\xaf\x5b\xf0\x6c\x81\x10\xd2\x1e\x0a\x51\xfd\ +\xfe\x6a\x26\x00\x22\xac\xb2\xa3\xa5\x91\x7c\xa9\x62\x01\x10\xd0\ +\x15\x17\x4d\x75\x71\x3a\x31\xc9\x36\x26\x0f\xbb\x19\x7f\xa1\xed\ +\xfe\x56\x19\x7f\xc9\x98\x6a\x63\x0e\x47\x1b\xfb\xf7\x57\xd3\xe4\ +\x69\x62\xd5\xdc\x66\xbc\xf8\x41\x4a\x50\x41\x09\x86\x49\x6b\x6c\ +\x21\xbc\xa9\x83\xae\x9d\x6e\x42\xfe\x11\x96\x7f\x0b\x91\x00\x15\ +\x1d\x44\xa4\x8a\x33\x20\xe2\x60\x4a\x6b\x54\x48\x9e\x6a\x21\x3b\ +\xbb\x00\x06\x4c\xac\x7d\x6b\x6d\x08\xf8\x1b\x70\x39\x93\xb1\x9d\ +\x7f\x77\x3e\x99\x33\x0e\xc8\x3c\x62\xd4\x40\xf7\x8f\x94\x8a\x01\ +\x3d\xea\x08\x87\xd2\xef\xa7\xbc\xa6\x86\xdc\x60\x90\xa4\x97\x36\ +\x4c\xdf\xf7\xce\xae\x99\x79\xc1\xa0\x97\x8e\x8e\x1d\x38\x1c\xb5\ +\x84\xc3\x83\x0d\x44\x42\xe3\x4c\x2d\x4c\x32\xf2\xeb\xcb\xd3\xb1\ +\xed\x73\x1e\x67\x2a\x29\x00\x00\x0d\xf6\x49\x44\x41\x54\xb0\xa0\ +\xbb\x8f\xb6\x84\xc9\x75\x2b\xbe\xf4\x74\x3e\x40\x47\x47\x23\xbf\ +\xfd\xed\xdd\x0c\x0c\x74\x3c\x30\x5e\x08\x60\x49\x49\xc9\x39\xef\ +\xce\x9b\x77\xd5\x61\xef\x6c\xe6\x81\xae\xef\x5f\xfa\x93\x66\x9b\ +\x79\x60\xc6\xbe\x06\xd6\xee\x6b\x60\x3e\x80\x54\xa9\x34\x21\x7a\ +\x52\xa2\xb1\xc7\x47\xc9\x62\x31\xc6\x67\xa1\xd9\x9d\xce\x3a\xed\ +\xcb\xf4\xf4\xb4\xb0\x7f\x7f\x75\x64\xc1\x9b\xaa\x69\x6e\xae\xc6\ +\xe3\x39\x2c\x1a\x0d\x44\x3c\x82\x1c\xac\xd1\xd3\x4b\x84\x5d\xbd\ +\x11\x58\x39\x3a\x66\xdf\x90\x67\xd3\x62\xc0\x46\x44\xa3\xa8\x18\ +\xb8\x85\x08\x0b\xfe\x3e\x29\xe5\x4b\x42\x08\x2b\xf0\x55\x73\xa2\ +\xf6\x9e\xa2\x8b\x12\xe2\x0b\x2f\x8c\xc3\x9a\x11\x51\x58\x36\x0c\ +\x06\x03\x3f\x52\x2a\xf7\xd9\x74\xa1\x83\xa5\x84\xee\xee\x6e\xb6\ +\x36\xef\x67\xfe\x9b\x3b\x4a\xab\x5e\xd9\x34\x6b\x12\x08\xa1\xaa\ +\x61\xfa\xfa\xf6\x71\x7b\x69\x2d\x17\x4d\xb5\x51\x92\x6a\x62\xf3\ +\x4b\xb5\xce\x59\x0e\xa7\x15\xa0\x2e\xf7\x9c\x9d\xef\x2f\xf9\xe5\ +\xd4\x8e\x8e\x46\x7e\xf3\x9b\xaf\xe0\x74\x3a\xb6\x02\xe7\x8c\x17\ +\x02\xf8\xfa\x84\x09\x67\xfd\x7e\xd1\xa2\x1b\x8e\xf0\x5e\xca\xab\ +\xcf\x7a\x7e\xf5\xa2\x89\x1f\xe6\xbc\xbf\x5a\xc6\x85\xc3\x98\xcb\ +\xa6\xe9\xd6\x45\x5b\xc4\xb0\x0d\x59\x48\x86\xa4\x5f\xaa\xd2\x19\ +\x0e\x8a\x8e\x97\x56\x93\xfe\xda\x66\x53\xe2\xce\x5d\x4e\x5a\xfa\ +\xcd\xdd\x2e\x8f\xf7\x7d\x22\x26\xe0\x7d\x63\x3c\xbd\xc3\x7f\x9f\ +\x4a\x25\x51\x21\x84\x89\x08\x71\xdc\x9b\x34\xc5\x92\x55\xb8\x3c\ +\x8e\xbc\xf3\x62\x31\xea\x24\xb7\x75\x57\x76\xcd\x48\x0f\x26\x1e\ +\x9c\x5f\x55\xd9\xbb\x6f\x1f\xc1\x57\xd7\x97\xa8\xcf\xaf\x9f\x57\ +\x0c\x68\x2d\x3a\x5f\xf8\xc3\x9b\x2b\x35\x00\xdd\xaf\xee\xdd\x95\ +\xd0\xe5\x1e\x11\x81\xbf\x9c\x7e\xbe\x7b\xcf\xf4\xdb\x2c\x43\x8b\ +\x5f\x05\x2c\x94\x52\xf6\x7c\x2c\x02\x18\xe2\x51\x5f\xc9\x01\x57\ +\xef\xdb\xa4\x94\x6f\x7f\xec\xd1\x1f\xbb\xbd\x27\xf3\xf2\x66\xde\ +\x72\xee\xb9\xb7\x1f\x35\x5f\x56\x7c\x53\xdd\x79\x13\x9f\xae\xcb\ +\x4f\xe8\x36\x5a\x63\x74\x39\x1e\xad\xb5\xb3\x47\x9f\xea\x6a\x33\ +\xe7\x07\x9b\x2c\x45\xfa\x46\xcb\x24\x6b\x53\x28\x39\x4f\x5a\x6c\ +\xb1\x2b\x96\xde\x41\xdb\x07\x9b\xdf\x05\x96\x7d\xd2\x1e\xbb\x0f\ +\x86\x10\x42\x0b\xdc\x00\xdc\xaf\xd1\x2b\x39\xd9\x8b\x6d\x14\x9e\ +\x1f\xc7\xad\xd6\x66\x79\x59\xe9\xe0\x68\xe6\x84\xf4\x78\x58\xfb\ +\xd4\x6b\x13\x75\x7f\x5b\xb5\x60\xfa\xcc\x54\x57\xed\x9f\x2f\xac\ +\x99\xc8\xab\xd5\xeb\xe8\xf6\x1c\xa2\xed\xfc\x78\xd4\x2c\xee\xdd\ +\x51\x8f\xd3\xe9\xd8\x05\x2c\x1d\xf6\xd8\x7e\x5c\x04\x20\x84\xc8\ +\x07\xee\x04\x6e\x52\x14\x6d\x5c\x5a\xda\x2c\x1c\x8e\x3a\xbc\xde\ +\x6e\x09\xdc\x22\xa5\xfc\xeb\x09\x8c\xfd\x68\xed\x96\x27\x26\x16\ +\xcd\x3d\xf7\xdc\x6f\x06\x16\x2f\xbe\x63\xe0\x48\xf9\x54\x55\x06\ +\xde\xde\xb7\x69\x5b\xf2\x2f\x0d\xe7\xab\x42\xa3\x3b\x52\xbe\xda\ +\x67\x57\xf0\xc1\xf5\xf7\x39\x81\x49\x1f\x97\xb3\x76\xaa\x31\x44\ +\x08\xd7\x03\xdf\x07\xf2\xcc\x71\x3a\xae\x5c\x60\x27\x31\xf9\x3c\ +\x99\x9b\xea\xeb\x9c\x92\xdf\xdb\x93\x6e\x6f\x35\xd9\xa3\xdb\xe3\ +\xcd\x7a\x67\xf8\xb9\xb7\x27\x6e\xec\xaf\xcf\x33\xdf\xdd\xf8\x9e\ +\x8e\x01\xff\x61\xa6\xf0\xdf\x6f\x85\x9f\x76\xd0\x04\x4c\x3f\x38\ +\x30\xe7\x61\x04\x30\xb4\x15\x0d\x6b\xc1\xc6\x12\xb9\x8e\x4c\x05\ +\xae\x02\x66\x68\xb5\x26\x12\x13\x27\x91\x90\x50\x4c\x7a\x7a\x19\ +\x69\x69\xb3\x58\xbb\xf6\x67\x54\x56\xbe\x34\x00\x24\x9f\x0a\xa7\ +\x49\x42\x88\xf6\xb4\xb4\xd9\xc9\xc9\xc9\x53\x29\x28\x28\xfb\x70\ +\xf1\xe2\x9b\x16\x0d\x31\x88\x46\xa0\x4a\xd9\xf9\xbb\x15\x2b\x7a\ +\x1a\x3c\x5d\xa1\x1b\x57\xce\x1c\xf3\xae\x0e\x10\xe8\x77\xf1\x7c\ +\xd1\xc5\xf8\x3a\x1d\x77\x4b\x29\x3f\x9e\x02\xe0\x69\xc4\xd0\x39\ +\xe2\xcb\xc0\xfd\x40\x41\x5c\x5c\x21\x67\x9f\xfd\x30\x76\xfb\x21\ +\xd2\x6b\xaf\x56\x13\xec\xce\xab\x7a\x74\xc7\xf2\xc6\xe7\x63\x53\ +\xa3\x89\x8a\xd2\x93\x2f\x22\x67\x0f\x00\x26\x56\x40\xb5\x9f\xef\ +\x4a\x29\x7f\x71\x70\x41\xed\x50\x23\x93\x80\xff\x22\x12\xec\xea\ +\x50\x39\xbb\xd1\x06\x51\x49\x10\x95\x48\xba\x46\x4b\x62\x6c\x1e\ +\x62\x28\xaa\xaa\xd9\x6c\xaf\x04\x8a\x4b\x4b\x6f\xa6\xa6\xfa\x15\ +\x5b\x48\x55\xbf\x08\xbc\x70\x92\x27\x40\x0f\x24\x19\x8d\x11\x4d\ +\xeb\xda\xda\x0d\x8b\x3b\x3a\x6a\x37\x5e\x7e\xf9\x83\xc5\x3a\x9d\ +\x31\x1a\xc0\x17\x08\x54\x3c\xf4\xc2\x0b\xf1\xfd\x1e\xcf\xa4\xe8\ +\x54\xe3\xc6\xa3\xd5\xb7\xe9\xbe\xdf\xe3\xeb\x74\x54\x00\x7f\x3e\ +\x99\xfd\x3c\x55\x18\xf2\x4d\xf4\x8c\x10\xe2\x39\xe0\x1a\x87\xa3\ +\xe6\xfe\x97\x5f\xbe\x7e\xc2\x8c\x19\xb7\x53\x5a\x7a\x13\x22\xe2\ +\x17\xcb\x1c\x0a\xeb\xb2\xfe\x23\x8b\x1a\x4c\x4d\x4a\xbe\x46\xaa\ +\xa9\x00\x71\x26\x5a\x14\x0b\x51\x2e\x3d\xb1\xd5\x7e\x24\x70\x58\ +\x20\x2a\xed\x50\xa4\xea\x15\xe4\x96\xc1\x45\x0f\xc3\x40\x27\x78\ +\x3c\xa0\x4d\x04\x63\x22\x68\x0f\x98\xd0\xb5\x7a\x7b\x3c\xa6\xc6\ +\x0f\xfa\xac\xbe\xbe\xf4\x48\xe7\x84\x0b\xc0\xe6\xdd\xc6\x65\x13\ +\x92\x79\xbe\xaa\xed\x3a\x4e\x32\x01\x10\x51\xd3\x12\x11\x02\x88\ +\xc0\xe5\x72\xcc\x79\xe6\x99\x7b\x1a\x2e\xbe\xf8\xbb\x5d\x3e\x4c\ +\xed\x3f\x7f\xed\xb5\x59\xa1\x70\xd8\x00\x60\x88\x56\x8e\xf8\x3d\ +\xef\xde\x5c\x41\xd5\x5f\x5e\x04\xf8\xef\xf1\xe6\xf4\x69\xd8\x55\ +\xbe\x10\xe2\x1f\xaa\x1a\xbc\x6a\xf3\xe6\x3f\x3e\xd0\xd4\xb4\x6a\ +\xe2\xe2\xc5\x3f\x24\x26\x26\x1b\x80\x60\x78\x50\x57\x91\x58\xd2\ +\x32\xa5\x73\x57\x2a\x80\xc3\x47\x3a\xbe\x91\x55\xdf\x2e\xa5\xdc\ +\x3f\xba\x5e\x2d\xf0\x17\x4a\x2e\x80\xbb\x57\x44\x78\x8e\x07\xc3\ +\x3b\x30\xc0\x40\xaf\x03\xb7\x57\xc5\xaf\xc6\x49\xe2\x63\x6b\x8b\ +\xaf\xd0\x25\x74\x55\xac\xca\x6c\x5e\xb7\x28\x14\xf2\xab\x7a\x11\ +\xaa\xf9\xe5\x84\x2d\xf9\x1b\xa3\x32\x95\xe7\xab\xda\x2e\x10\x42\ +\x94\x9c\xe4\x20\x87\x4b\x0d\xd8\x98\x53\xf9\x1d\x4c\xd8\x1d\x5a\ +\x8c\x1e\x19\x09\x11\x24\xff\xdd\xb5\xa3\x7e\x63\x5c\xf3\xd2\x83\ +\x33\xeb\xad\x9a\x31\x17\x56\xaa\x2a\x6b\xee\xfc\x11\x52\x55\xdf\ +\x92\x52\xbe\x75\x12\xfb\x77\x5a\x31\x14\x3a\xf7\xff\x84\x10\x2f\ +\x76\x75\x55\xfc\xe2\xa5\x97\xae\xfd\xd6\xac\x59\x5f\x63\xf2\xe4\ +\x6b\x70\xb9\xda\xa6\x6e\x4d\x9d\x61\x9a\xdc\xb9\xbb\x5f\x20\x63\ +\x20\x62\x92\xb4\x26\x52\xf4\xc5\xb1\xea\xd3\x02\x41\xc2\x41\x08\ +\x78\xc1\x30\xca\x81\x86\xd9\x66\xc3\x6c\x3b\xc0\x9d\x90\xb2\x15\ +\xaf\xbf\xa9\x3b\x35\x59\xef\x4a\xce\x7f\x3f\xad\xa3\x92\xff\x67\ +\x7d\x52\x63\x52\xc2\xca\xe2\xcc\x38\xa6\x25\x59\x75\xdb\x3b\x9d\ +\x3b\x84\x10\x4f\x03\x0f\x49\x29\x5b\x4f\xc2\x98\x2f\xcb\x62\x31\ +\x56\x99\x01\x91\x70\xac\x71\x40\xdf\x9b\xb9\x9b\xea\x6b\x62\x5b\ +\x97\x8e\xce\x6c\xb6\x1b\xc6\x3c\xd5\x56\x3c\xf6\x3c\x3d\xdb\xaa\ +\xc2\xc0\x3d\x27\xa1\x4f\x9f\x38\xa4\x94\x41\xe0\xdb\x42\x88\x0f\ +\x36\x6c\x78\xe4\xaf\x8d\x8d\x1f\xda\x6d\xb6\x4c\x8b\xd6\x92\xc4\ +\x8e\xd8\x6c\xef\xb4\xbe\x86\x18\x80\xd7\x18\x11\x9e\x1c\x91\x00\ +\xae\xa1\xea\xbd\x35\xfc\x64\xba\x0e\x7b\x16\x89\xe9\x79\x14\xcd\ +\x5c\x4c\xfe\x94\x32\xf4\x06\x63\x48\x6b\x34\x79\x8c\x46\x8b\x47\ +\x28\x22\xac\x84\x85\x10\x3e\x53\xba\xf4\xa6\x75\x89\xac\x34\x61\ +\xaa\x75\x74\xcb\xfa\x8d\x15\x6a\x4a\xa2\x4b\xea\xf5\x53\x37\xdd\ +\x34\x5f\xbc\x5a\xd3\xa1\xf9\xc3\xd6\xa6\xdb\x56\xed\x77\x7c\x59\ +\x08\xf1\x14\xf0\xb8\x94\x72\xd7\xc7\x19\xa4\x88\xc8\xe3\x17\xe5\ +\x70\x60\x9d\x43\x42\xad\x7b\xae\x78\xa5\xbe\xcf\xe8\x9e\x31\x56\ +\x19\x53\x9c\xee\x30\x39\xae\xb7\xbd\x87\x2d\xf7\xff\x0f\xc0\x53\ +\x52\xca\x4f\x95\xc7\x2f\x29\xe5\xeb\x42\x88\xd2\x8e\x8e\xed\x4f\ +\x76\x74\x6c\x5f\x1a\x1d\x9d\x8a\x3f\x36\x2f\xb5\xa4\xaf\x81\x56\ +\x60\x6b\x24\xdb\xb7\x8e\x64\x70\x22\xa4\x94\xc3\x21\xdd\x7f\x4b\ +\xc4\xf6\x2d\x0d\xb0\x5b\x62\xe2\x29\x5b\x7e\x23\x8b\x97\xdd\x85\ +\xd1\x65\x53\x0d\xaa\x25\xa0\x53\x8c\x23\x3a\x6a\x52\x4a\xb6\x6f\ +\x7f\xb2\x73\x62\xe5\x0f\xf7\xda\x1b\x5b\xbc\x73\x4d\x86\x28\x6b\ +\x61\x4a\x42\xbb\x2d\x21\xab\x30\x25\xce\xb4\xaf\xdf\xcb\x1f\xb6\ +\x34\xf2\x6c\x45\x2b\xde\x60\x78\x33\xf0\x04\xf0\x4f\x29\xa5\x6b\ +\xac\x8e\x8c\xd9\x39\x21\x6e\x01\xf1\xe4\x1d\x54\x60\x26\x1e\xa7\ +\xde\xbf\xf1\xd9\xe2\x77\x27\x05\x34\x21\xcb\x11\x8a\xf4\xcd\xbb\ +\x2f\xb3\xa2\xe4\xd2\xd4\xf9\x07\x27\xae\xbc\xe6\x3b\xec\xfb\xe7\ +\x5b\x3e\xa0\xe0\x24\xed\x4a\x67\x24\x84\x10\x77\x02\xbf\x02\x2c\ +\xb1\x44\x58\x95\x12\x7e\x2a\xa5\xfc\xfe\x11\xcb\x8c\xc5\x07\x18\ +\x22\x88\x5b\x89\x44\xbe\xb6\x58\xad\xc9\xcc\x9d\x7b\x3d\xf3\xe7\ +\xdf\x2a\x93\x92\x26\x08\x80\xde\xde\xda\xba\x86\x86\x0f\xf2\x4d\ +\xfe\xb6\xc6\xe2\xaa\x1f\xa6\xd3\xcf\x3a\x63\x03\x96\xf3\xf2\x30\ +\x27\x44\x6b\xf2\x76\x45\x25\x48\x4f\x5c\x92\x21\x2d\x29\x96\x57\ +\x6a\x3a\x79\x6c\x6b\x13\xf5\xfd\x5e\x0f\xf0\x12\xf0\x14\xb0\xfa\ +\x58\x96\xb2\x42\x88\x7f\x27\x33\xfd\x8b\x57\xf3\x26\xb5\x31\xfb\ +\xd7\xae\xc8\xdb\x76\x16\x07\x34\x35\x82\xe8\xf4\x0d\xc4\x98\xfb\ +\x49\xb2\x6b\x49\x8e\x4d\x21\xfd\x9e\xbd\x67\x9f\x77\x9e\xa5\xa0\ +\x70\xf9\x48\x1c\x80\xd6\xf7\x36\xb0\x62\xe9\x57\x00\x7e\x29\xa5\ +\xfc\xce\x47\x9e\xcd\x71\x0a\x21\x44\x26\x30\x8f\x88\xc1\x6a\x40\ +\x4a\xf9\xfc\x51\xf3\x1f\x6d\x0d\x86\x2c\x73\xee\x06\xbe\x41\xc4\ +\xf2\x95\xdc\xdc\x32\xce\x3a\xeb\x16\xa4\x1c\x08\x99\xcd\xd1\x5a\ +\x80\xe2\xbd\x3f\x5e\x6b\xf2\x36\xcf\x17\x41\x36\xcb\xdd\x14\xc6\ +\xea\xa8\x38\x2f\x8f\x0c\x9b\x91\x0c\x55\x0a\xb9\xdb\x68\xa7\x2f\ +\x2e\x49\xec\x0f\x29\xfc\x63\x6f\x27\xef\xd4\x77\x23\x23\x42\x93\ +\x3f\x13\x89\xa4\xe5\x18\xba\xee\x15\x03\xa5\x43\x4f\x09\x30\x7f\ +\x86\xf8\x86\xc1\x9d\x36\xd5\xb3\x37\x4d\xba\xb0\x9a\x1d\xc4\xc7\ +\x2a\x24\xc7\x24\x63\x8b\x8e\x45\x33\xc4\xde\x17\x7e\x1f\xf6\x8b\ +\x76\xa2\x69\x9e\x7b\xe1\xb2\x3f\x57\xa4\xa5\x97\x4d\x02\x08\x0f\ +\x06\x78\x71\xca\xe5\x0c\xd4\x34\xb5\x02\x13\x8f\x67\xf7\xf9\xac\ +\xe0\x23\x71\x02\x87\xb4\x4f\x6e\x24\xc2\xab\x9e\x29\x84\xc2\x92\ +\x25\x77\x90\x93\x33\x14\x90\xc3\xb5\xdf\x3d\xa3\xee\x27\x06\x40\ +\x27\x24\x0d\xb2\x02\xf0\x91\x96\x1a\x4d\xf9\xd2\x5c\xa6\x18\xb4\ +\x8c\x04\x68\xae\xd6\xdb\xa8\xd0\x58\x78\xa3\xcb\xcf\x8b\x4d\x03\ +\x38\x07\x43\x3e\x22\xec\xe4\x74\x84\x10\xc4\x65\x43\x6a\x31\xa4\ +\x14\x43\x6a\x31\x7a\xf3\x04\x02\x76\x0b\xe8\x8f\xc0\xd8\xd3\x74\ +\x77\x61\xbf\xa0\x07\xe1\x29\x06\xb8\xfc\xf2\x17\x1a\xec\xf6\xc2\ +\x1c\x80\x6d\x0f\xff\x85\x2d\x3f\xf8\x23\x44\x74\xe9\x5e\xfe\x98\ +\x73\xf4\xa9\xc6\x71\xcb\x02\x84\x10\x25\x44\x58\x94\x97\xc7\xc7\ +\x67\xe6\xce\x9c\x79\x31\x99\x99\x25\x98\xb6\x3c\xec\x2d\xd6\xb4\ +\x0e\xab\x29\x39\xa9\x67\x2f\x0e\x66\x01\x03\x45\xf1\xec\x38\x2b\ +\x83\x39\x1a\xe5\x50\x7b\xad\x46\x0c\x3c\xd7\x0b\x0f\x4c\xfc\x16\ +\x32\xb5\x04\x92\x8b\x0e\xbf\x89\x1c\x0d\xba\x3d\xb5\xc4\x5e\x1d\ +\x05\xe1\x11\xbf\xf1\xd7\x5d\xb7\xb2\xc7\x6c\x8e\x8f\x77\xee\x6b\ +\xe6\x85\x92\xcb\x08\xfb\x07\xff\x23\xa5\xfc\xe2\x71\x0d\xf2\x33\ +\x84\x13\x92\x06\x0a\x21\xbe\x00\x3c\x9d\x94\x94\x9b\x32\x63\xe2\ +\x2c\x96\xb9\x9e\x47\x77\x40\xe8\xaa\xd2\xc3\x1a\x1a\x58\x04\xa0\ +\x08\xda\x66\xa4\xd0\x50\x9a\x42\x99\x38\x48\x01\x43\x05\x8a\xe6\ +\xfc\xc4\x55\x9b\x72\x69\xf4\x61\x0d\x1c\x0d\x86\xff\x6c\xc1\xfa\ +\xff\x8a\x10\xf2\x90\x03\xe1\x6d\xb7\x6e\x0e\x2a\x1a\xbd\xee\x8d\ +\xf3\xef\xa2\xe5\xed\xf5\x1e\x22\xfc\xfe\xa6\x8f\x3d\xc8\x4f\x39\ +\x4e\x48\x27\x70\x48\x02\x38\xa5\xb3\xb3\xfe\x5f\x6f\x7c\xf8\x3c\ +\x4f\xd4\x0b\x5c\x07\xf8\x70\x0a\xf1\x2c\x62\x32\xeb\x51\xf0\xa9\ +\x92\xd4\xcd\x6d\x9c\xf5\xf4\x76\xea\xf6\xf5\x0d\xdf\x4e\x22\x1d\ +\x78\x23\xed\xb7\x1f\x28\x42\xfd\xe8\xde\x38\x2d\xbf\x58\x8d\xed\ +\xde\x69\xa3\x17\x1f\xf0\x2a\x1a\xbd\xae\xfe\x85\x77\x68\x79\x7b\ +\x3d\x44\xc2\xb9\x7d\xbe\xf8\x47\xc1\x09\x2b\x85\x4a\x29\x7b\xa4\ +\x94\x57\x01\x97\xef\xec\x93\x5d\x0f\xef\x81\x8a\x83\xe5\x75\x46\ +\xe6\x31\x8d\x26\x61\xa0\x0d\x20\xa4\x52\xb8\xb2\x9e\x19\x7f\xdf\ +\xc5\xb6\x0e\x37\xd5\x00\xb1\x55\x0e\xcb\xf7\x73\xfe\xbc\xf3\xd8\ +\xad\xa9\x2a\x31\x37\xae\xc2\xfc\xd7\x85\xc0\x61\xfe\x00\x45\x50\ +\x09\x05\x5d\x1e\xca\xbf\xf5\x2b\x88\x28\x6a\x3c\x72\xa2\xe3\xfb\ +\xb4\xe3\xa4\x2a\x84\x0c\x31\x6e\x7e\x2f\xe0\xda\x73\x92\xe0\xd2\ +\x0c\xd0\x0e\x5f\xda\x24\x3d\xd4\xd2\xc6\xc0\x21\x76\xfa\x32\xd6\ +\xc8\xfa\xb3\x66\xe3\x35\x3f\x88\x2e\x6f\xeb\x06\x7b\x6f\xc0\x76\ +\xb8\x1d\x3f\x80\xf0\x78\xb0\x2f\xaf\x44\xd3\x31\x6b\xcc\xf7\x80\ +\xa6\x4f\x2f\x8b\x3b\xae\x14\xbb\x1f\x7d\xd6\x0f\xcc\x3a\xc9\x2c\ +\xe9\x4f\x25\x4e\x89\x46\x50\x84\x81\xc3\x9f\xd2\x4c\xe8\x6f\xcb\ +\x83\xd4\x11\x23\x6c\x82\xb2\x9d\x0d\xa2\xe5\x50\xd7\x6c\x42\xa1\ +\x65\xee\x53\x28\xfe\xf4\x49\xee\xd9\xe5\x2f\xe6\x32\xda\x6a\x59\ +\xd3\xda\x8e\x7d\xb9\x0b\xe1\x3f\xb2\x45\xb0\x9f\xb0\x76\xb7\x41\ +\x13\x7e\x3a\x88\x0c\xab\x5f\x95\x52\xfe\xe9\xa4\x0e\xea\x53\x8a\ +\x53\x62\x17\x20\xa5\x7c\x0a\x58\xd8\xea\xa3\xf5\xa7\x95\xf0\x7e\ +\x27\xc3\xfa\xb4\x3a\x91\xc2\x02\x26\xb2\x1a\x71\x20\xa8\x83\x54\ +\x49\xf7\x74\xa3\x9d\x6a\xae\xe8\x5b\x1a\xbf\xfe\x50\xb3\x68\xfd\ +\xd6\x2a\xec\xe7\x29\x63\x2c\x7e\x37\x83\x94\xd3\xc1\x6a\xb6\xd2\ +\xcf\x76\x34\xa1\x7f\x0d\x22\xc3\xea\x2b\x9f\x2f\xfe\x47\xc7\x29\ +\xd5\x09\x14\x42\x24\x11\x91\x46\x9e\x3d\xc9\x06\x37\xe5\x80\x75\ +\xf8\x3a\x1f\x62\xa7\xdc\x4d\xba\x08\x11\x07\x90\x75\x23\x9b\x0a\ +\x96\xa1\x4d\xc8\x32\x4f\x88\x7d\x77\x93\x2b\x8c\x26\x19\xd3\x0b\ +\x1b\x89\x7e\x70\x32\x60\x46\xd2\xce\x20\x0d\xf4\x13\xc6\x41\x3a\ +\x7e\x0e\xb5\xe8\x7d\x0b\xd8\x4f\x33\x30\x75\xb4\x72\xe6\xe7\x38\ +\x32\x4e\xb9\x52\xa8\x88\x68\x2c\x3c\x08\x3c\x10\xad\x45\xb9\x21\ +\x07\xa6\x0c\x89\xf6\xa5\xa4\x5d\x54\x33\x80\x9b\xa2\x98\x52\x56\ +\xe5\xdd\xc5\xc2\x69\xd3\x68\xfa\x63\xf3\x75\x9d\x5f\xaf\x8f\x75\ +\xa1\xfd\x83\x81\x7e\x29\x70\x90\x45\x80\x8c\x23\x36\xb2\x15\xd8\ +\x4a\x18\x38\x5b\x4a\xb9\xe6\x94\x0e\xe8\x53\x86\xd3\xa6\x15\x2c\ +\x84\x38\x87\x48\xe4\x8d\xe4\x45\x89\x70\x45\x06\xe8\x14\x90\xe0\ +\x13\xfb\xd9\xa1\xf3\x61\x98\xf2\x2b\xa6\x67\x66\xb2\x2a\x21\x81\ +\x45\x5f\x7e\x9f\x6d\xff\xa8\x3b\x60\xf9\x7b\x44\xec\x27\xf2\xeb\ +\x8f\x88\x9f\x4f\xab\x87\xad\x4f\x03\x4e\x9b\x6d\xe0\x90\x27\xcc\ +\x52\xe0\x9f\xab\xba\x90\x3f\xae\x80\x66\x2f\x08\x30\x91\x49\x59\ +\x28\x9d\x7e\x24\xb2\xb5\x95\x52\xc0\xfb\xf4\x22\xe2\x0c\x9a\x63\ +\x78\x01\x71\x02\x1f\x00\xb0\x02\xf8\xd1\x29\x1e\xc2\xa7\x12\xa7\ +\xd5\x38\x54\x4a\xd9\x29\xa5\xbc\x06\xc8\xe9\xf0\xf3\xc1\xcf\x2b\ +\xe1\x9d\x8e\xc8\x01\x51\x46\x71\x8e\xd7\x41\x73\x38\x8c\xcd\xe3\ +\x61\xab\x5e\x43\xd6\x33\x67\x1f\x25\x2e\x6f\x08\x78\x17\x18\xa4\ +\x06\xf8\xf2\x90\xa6\xcc\xe7\x38\x4e\x7c\x22\xd6\xc1\x43\xdc\xb9\ +\xa5\x21\xc9\xb7\x5f\x6a\xc6\xf3\xe8\x5e\xe8\x0f\x40\xe5\xf6\x88\ +\x29\x57\x4b\x33\x89\x00\x57\xe6\x32\x67\x82\x6d\xd8\x23\xe0\x28\ +\xac\x05\x1c\x74\x01\x17\x48\x29\x8f\xa8\x2a\xfe\x39\x8e\x8e\x4f\ +\xcc\x3c\x5c\x4a\x19\x96\x52\x3e\x02\x94\x56\x3b\x69\x7f\xb8\x02\ +\xaa\xbb\x22\x7e\x74\xdc\x1e\x26\x84\x54\x76\x03\xc6\xf7\x2f\xa2\ +\xf3\xb0\xc2\x95\x40\x0d\x7e\x60\xb9\x94\x72\x6c\xbf\xff\x9f\xe3\ +\x23\xe1\x13\xf7\x0f\x20\xa5\xac\x03\xce\xf7\x84\x18\x78\xa7\x29\ +\x62\x7f\x09\xd0\xd1\x86\x0b\x20\xd5\xc4\xcc\xbb\x8a\x39\xa0\xea\ +\xdd\x05\xac\x07\x22\x86\x28\x9f\xea\x78\x3e\xa7\x03\x9f\x38\x01\ +\x00\x0c\xe9\x0c\x5e\x5a\xd9\x44\xf0\xae\x47\xa0\xcd\x01\xed\xed\ +\xcc\x0d\x85\x22\xd1\xbd\x7e\x3f\x8f\xcc\x28\x85\x30\x3e\x22\xdf\ +\x7d\x95\xa7\xa5\x94\x87\xe9\xb8\x7f\x8e\xe3\xc7\x19\x41\x00\x30\ +\x72\x4b\xb8\x74\xcb\x5e\x06\xae\xfe\x21\xbc\xb2\x16\xe5\xad\x75\ +\x91\x20\x0e\x5a\x85\x94\x1f\xa4\xa1\x61\x25\xe0\xa1\x9e\x88\x11\ +\xcb\xe7\x38\x09\x38\xe3\xac\x83\x85\x10\x05\xc0\x9f\x80\x73\x67\ +\x4f\x82\xd9\x93\x60\x57\x1d\xec\xaa\x85\x7e\x17\x0d\x44\x0e\x7d\ +\x7b\x3f\xe1\x6e\x7e\x6a\x70\xc6\x11\xc0\x30\x84\x10\x4b\x89\xd8\ +\xc3\x25\x0d\x25\x35\x01\xd7\x0f\x5b\xb5\x7e\x8e\x93\x83\xff\x0f\ +\x92\x04\x28\x92\xfd\x58\xc9\xac\x00\x00\x00\x00\x49\x45\x4e\x44\ +\xae\x42\x60\x82\ +\x00\x00\x05\x24\ +\x89\ +\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\ +\x00\x00\x18\x00\x00\x00\x18\x08\x06\x00\x00\x00\xe0\x77\x3d\xf8\ +\x00\x00\x00\x04\x73\x42\x49\x54\x08\x08\x08\x08\x7c\x08\x64\x88\ +\x00\x00\x00\x09\x70\x48\x59\x73\x00\x00\x06\xec\x00\x00\x06\xec\ +\x01\x1e\x75\x38\x35\x00\x00\x00\x19\x74\x45\x58\x74\x53\x6f\x66\ +\x74\x77\x61\x72\x65\x00\x77\x77\x77\x2e\x69\x6e\x6b\x73\x63\x61\ +\x70\x65\x2e\x6f\x72\x67\x9b\xee\x3c\x1a\x00\x00\x00\x13\x74\x45\ +\x58\x74\x41\x75\x74\x68\x6f\x72\x00\x52\x6f\x64\x6e\x65\x79\x20\ +\x44\x61\x77\x65\x73\x0e\xd8\x7e\x1d\x00\x00\x04\x82\x49\x44\x41\ +\x54\x48\x89\x8d\x96\x6d\x88\x54\x55\x18\xc7\x7f\xe7\xbe\xcd\xdc\ +\xb9\xb3\xb3\x33\xfb\x1a\xeb\xbe\xe9\x6e\xad\x2f\x69\xea\x4a\x68\ +\x8a\x22\x24\x94\x58\x7e\xaa\x48\x23\x30\x41\xa1\xec\x5b\x51\x41\ +\xf8\x31\xa3\x3e\x44\x94\x44\x92\x09\xa5\x46\x54\x84\x21\x59\x44\ +\x18\x96\x19\xe8\xb2\x5a\x9b\xba\xea\xa4\xfb\x66\xee\xce\xce\xec\ +\xec\xcc\xdc\xbd\xf3\x76\xef\xe9\x43\xb3\x63\xe6\xbe\x3d\xf0\x7c\ +\xfb\x9f\xff\xef\x3c\xcf\x3d\xf7\x39\x47\x48\x29\x99\x2d\x36\xbf\ +\x27\x7c\x6a\x40\x7f\xc0\x6f\xaa\x8f\x02\x64\x1d\xf7\x84\x3b\x51\ +\xb8\xf0\xed\x8b\x32\x37\xdb\x5a\x31\x1d\x60\xf7\x01\xa1\x67\x2a\ +\x42\x6f\x37\xd7\x76\xac\xaf\x0a\xd6\xd5\x07\x2c\x7f\xb5\x69\x05\ +\x7c\x52\xba\x38\xb6\x93\xb3\xed\x6c\x62\xdc\x8e\xdd\xea\x8b\x5d\ +\x39\x15\x4c\xa7\x5e\xfe\x70\x97\x2c\xcc\x19\xb0\xf5\xa0\xb1\xa2\ +\xb9\x7e\xfe\xa1\x65\x1d\xab\x96\x7a\x5a\x56\xc9\xbb\x0e\x92\x3b\ +\x75\x02\x81\xa1\x9a\x50\x50\xbd\x9e\x2b\xe7\xff\xe8\x1f\xbe\xb1\ +\xe3\xd8\xce\x7c\xf7\xac\x80\x6d\x9f\x87\xf7\x2d\x69\x59\xf9\x5c\ +\x5d\x7d\x4d\x9d\xe3\xa6\x67\xeb\x00\x00\x01\x35\x44\x7c\x78\x74\ +\xe4\x42\x5f\xd7\xc7\x47\x9f\x4a\xbe\x36\x2d\xe0\x89\xc3\xe6\xb3\ +\x6b\x96\xae\xdf\x6f\x04\xf5\xa0\x27\xdd\x39\x99\x87\xb4\x7a\xaa\ +\xb4\x05\xa4\x8a\x43\x8c\x24\xaf\x65\xce\xf4\xfc\xfa\xc2\x17\xcf\ +\x38\x9f\xdc\x05\x78\xfc\x80\xa8\x59\xdc\xb6\xfc\xb7\xc6\xd6\x79\ +\x6d\x73\x33\x17\x6c\xa8\xdc\x43\x83\xb1\x0c\x9f\x08\x72\xc9\xf9\ +\x9e\x73\x99\x4f\x19\xb8\xd1\x17\xbd\x18\xed\x59\xfd\xcd\x2e\x39\ +\x0a\xa0\x4c\xca\xab\xab\xea\x8e\x34\x35\x37\xcc\xd1\x1c\x56\x5a\ +\x4f\xd2\xea\x5b\x8d\x5f\x54\x10\x2b\x46\x39\x97\x39\x8c\x2b\x5d\ +\x5a\x9a\x5b\xdb\xaa\xab\x6a\x8f\x4c\xea\x14\xf8\xf7\xa3\xde\xd7\ +\xd2\xb1\xd6\x15\x73\x33\xaf\xd1\xdb\xb9\xd7\xdc\x88\x82\x4a\xda\ +\x1d\xe1\x97\xd4\x07\xb8\xb2\x08\x40\x51\xb8\xb4\x37\xcf\x5f\xbb\ +\xf5\xa0\xb1\xa2\x0c\xf0\xfb\xd4\x2d\x56\xd0\xb4\x40\x20\xc4\x54\ +\xa9\x94\x53\x53\xfc\xac\x0b\xed\x22\xa0\x86\xc9\x4b\x87\x4b\xce\ +\x09\xd2\xee\x2d\x14\xa1\xa2\x08\x0d\x21\x54\x02\x96\xdf\xf2\xfb\ +\xd4\x2d\x00\x1a\x80\x65\x55\x76\xaa\x86\x86\xbc\xdd\xb1\x52\x97\ +\xc5\x5d\xbb\x5f\x57\xb1\x9b\x88\xd6\x02\xc0\xad\x42\x0f\x97\xb3\ +\x3f\xa2\x08\xed\x0e\x8d\x66\xe8\x58\x66\xa8\xb3\x0c\x08\x9b\x91\ +\x26\x4f\x4a\x54\x45\x9d\xd1\x7c\x81\xff\x21\x1a\x7d\xcb\xf1\x3c\ +\x8f\x91\x42\x2f\xc7\xe3\x7b\x31\xd5\x10\x86\x12\xb8\x43\x57\x14\ +\x0a\xe1\x40\xa8\xa9\x0c\x30\x0d\xb3\x5a\x11\x3a\x0a\x0a\x4c\x61\ +\x0c\x10\x50\x23\x2c\x31\x1f\x43\xb8\x3a\x49\xf7\x26\xc7\xe3\x7b\ +\x89\xe5\xa3\x00\x58\x6a\x15\x11\xa3\x11\x53\x0d\x23\x10\xe4\xbd\ +\x1c\xa6\xe1\xab\x2e\x03\x9c\xfc\x44\x5c\x57\x8c\x96\xff\x9a\x77\ +\x98\x9b\x08\x28\x61\xba\xed\x2f\xf1\x64\x91\x4e\x73\x1b\x16\xb5\ +\x14\x65\x8e\x0b\xf6\x57\x0c\xe5\xcf\x97\xb5\xb6\x9b\xc0\x76\x12\ +\xf8\x14\x8b\xb0\xde\x48\xde\x4b\xe1\xe4\xb3\xf1\x32\x20\x69\x8f\ +\x0d\x08\xe9\xad\x14\xc2\x00\x60\x51\xe0\x11\x96\x04\x36\x23\x50\ +\x18\x2b\x0c\x52\x70\x73\xdc\xa3\x2d\x06\xa0\x3f\x77\x96\x53\xa9\ +\xfd\x53\x56\x99\xf3\x6c\x86\x73\xbd\x54\xa9\xb5\x24\x27\xc6\x07\ +\xca\x00\xdb\x49\x75\x25\x9d\xa1\xad\xe8\x3a\xa6\x5a\x89\x81\x85\ +\xf0\x34\x3c\xcf\xa3\xcd\xd8\x80\x2e\x4c\x14\xa1\x13\x2b\x5c\xe5\ +\x58\xe2\x55\x60\xe6\x09\xac\xba\x3a\xb6\x93\xe9\x2a\x03\xb2\x39\ +\xf7\x78\x2e\x53\x7c\x65\x3c\xd0\x67\x49\x3c\xbc\xa2\xa4\xb5\x76\ +\x0d\x9a\xf0\x53\x55\x3a\x31\xb6\x1b\xe7\x64\xea\x1d\x26\xbc\xc4\ +\x8c\xe6\x02\x85\x82\x2d\xed\x6c\xce\x3d\x0e\xa5\xff\xe0\xd8\xce\ +\x7c\x77\x74\xa0\xff\x74\x8d\xd2\x00\xc0\x60\xbe\x9b\x44\xb1\xbf\ +\xbc\xc8\x93\x2e\x97\x9c\xef\xb8\xea\xfc\x34\xa3\x39\x40\x8d\x52\ +\x47\x74\x70\xf0\xf4\xe4\x64\x2d\x1f\xfc\x78\x22\xb6\x3d\x31\xe4\ +\x44\x4d\xc5\xc2\xc3\xe5\x77\xfb\x6b\x3c\x59\xa4\x20\x1d\x06\xf2\ +\xe7\xf8\x21\xf9\xe6\xac\xe6\xa6\x12\x20\x31\xe4\x44\xe3\x89\xf8\ +\xf6\x72\x45\xff\x9f\xa6\x9d\x0b\xef\x7f\x3f\xe9\xff\xbb\xc2\x95\ +\x1e\x0b\xcd\x87\xc9\xca\x34\xd7\xb3\x67\x98\xad\xef\x9a\xd0\x08\ +\x3b\xb5\xe9\xae\xde\xde\x3d\x53\x4e\xd3\xc9\xd8\xf6\x59\x78\x5f\ +\xc7\xbc\x96\x1d\x6a\x4d\xa1\x3e\xe9\xc6\x67\xdd\x35\x40\x58\x8d\ +\xe0\x8e\x6a\xc3\xbd\x43\xfd\x87\x8e\x3e\x9d\x9a\xfe\x3e\x98\x8c\ +\x4d\x6f\x19\xab\xdb\xda\xeb\x3f\x5a\xd0\x5e\xb7\x28\xad\x24\x94\ +\xac\x9c\xfa\x46\xf3\x0b\x3f\x15\x6e\xd8\xfb\xeb\xda\xc8\xe5\x8b\ +\x5d\xb1\xe7\x7f\xde\x57\x3c\x2d\x65\x69\xea\x4d\x05\x10\x42\xa8\ +\x40\x35\x10\x36\x2b\xa8\xdd\xf8\x7a\xf0\xa5\x96\xe6\xea\x65\xa1\ +\x50\x20\x12\x08\x2b\x41\xa3\x42\xd1\x01\x72\x29\xaf\x38\x91\x2c\ +\x66\xd2\xa9\x6c\xf2\xfa\xf5\xd8\x9f\x27\xdf\x98\x78\x37\x67\x33\ +\x0a\x8c\x01\x49\x29\x65\x72\xda\x0a\x84\x10\x16\x10\x01\x2a\x4b\ +\x59\x61\x04\x89\x34\x3c\xa8\x2c\x6c\x5a\xa5\xad\x90\x12\x65\xe0\ +\x6c\xf1\xfc\xcd\xb3\x5e\xb4\x60\x33\x06\x64\x80\x14\x30\x5e\xca\ +\x84\x94\xb7\x1f\x00\xd3\xbe\x2a\x4a\x30\x1d\x30\x00\x5f\x29\x75\ +\x40\x05\x8a\x40\x0e\xc8\x03\x59\xa0\x20\xe5\xd4\x37\xd5\x3f\x13\ +\x05\x02\x8c\xec\xcf\x7e\xae\x00\x00\x00\x00\x49\x45\x4e\x44\xae\ +\x42\x60\x82\ +\x00\x00\x05\x64\ +\x89\ +\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\ +\x00\x00\x18\x00\x00\x00\x18\x08\x06\x00\x00\x00\xe0\x77\x3d\xf8\ +\x00\x00\x00\x04\x73\x42\x49\x54\x08\x08\x08\x08\x7c\x08\x64\x88\ +\x00\x00\x00\x09\x70\x48\x59\x73\x00\x00\x0d\xd7\x00\x00\x0d\xd7\ +\x01\x42\x28\x9b\x78\x00\x00\x00\x19\x74\x45\x58\x74\x53\x6f\x66\ +\x74\x77\x61\x72\x65\x00\x77\x77\x77\x2e\x69\x6e\x6b\x73\x63\x61\ +\x70\x65\x2e\x6f\x72\x67\x9b\xee\x3c\x1a\x00\x00\x04\xe1\x49\x44\ +\x41\x54\x48\x89\xb5\x95\x4b\x6c\x54\x55\x18\xc7\x7f\xe7\x9c\x3b\ +\xd3\x99\x32\xb5\xf4\x41\x1f\x38\x02\xa5\x8d\x4c\x54\xb0\x25\x48\ +\xa2\x12\x62\x14\x17\x46\x12\x64\x81\x26\xb0\x31\x18\x12\x36\x04\ +\xd2\x45\xe9\xca\x84\x5d\x29\x1a\x43\x80\x6e\x0c\x2b\xdc\x88\x09\ +\xa0\x21\x04\x35\x18\xa2\x26\x3e\xa3\x62\xa8\x4f\x7c\x61\x95\x81\ +\xb6\x4c\xcb\xbc\x67\xee\x39\x9f\x8b\xdb\xde\x4e\x2b\x10\x37\xde\ +\xe4\x4b\xee\xf3\xf7\xff\xfe\xdf\xf9\xee\x77\x94\x88\xf0\x7f\x1e\ +\xde\xdd\x1e\x9e\x53\xaa\x23\x6a\xcc\x33\xf1\x58\xec\xc9\xc6\xae\ +\x15\x2b\xeb\x97\x76\x2e\xb5\xc5\x62\x29\xfb\xd7\x5f\x63\xb9\x6b\ +\xd7\x7f\x2c\x55\x2a\xe7\x2b\xd6\x7e\xb0\x59\xa4\x70\x27\x86\xba\ +\xad\x03\xa5\xd4\x87\xb1\xd8\xcb\xc9\xc7\x1e\xdd\xdd\xd9\xde\xd6\ +\x11\xab\x8b\xa3\x4a\x45\xc8\xe5\xc1\x18\x48\x2c\xc2\x79\x86\x5c\ +\x21\xef\x7e\x1f\xbd\x7c\x65\xe2\xa7\x5f\xf7\x3c\x59\xad\xbe\xf7\ +\x9f\x04\xce\x2b\xb5\xa2\x75\x69\xe7\x1b\xa9\x0d\x8f\xaf\x4f\x54\ +\x6d\x84\x52\x29\x7c\x26\xc0\xec\xfb\x32\x13\x34\x24\xb8\x91\x9f\ +\x9e\xfa\xe5\xe3\x4f\xdf\xb6\xb7\xb2\xbb\x9f\x10\x29\xd5\xf2\xe6\ +\x09\xbc\xa7\x54\x6a\xc5\xda\xde\xf7\x7b\xba\xbb\x93\x3a\x57\x98\ +\x83\x2c\x80\x2e\x14\x11\xcf\xc3\xc6\x23\xf2\xdd\xa7\x9f\x7d\x91\ +\x1f\xbb\xf6\xf8\x13\x22\xfe\x2c\x53\xcf\x9e\x1c\x50\x4a\xb7\x2e\ +\x5f\x76\xa2\x67\xf9\xb2\x24\xd9\x3c\x4e\x24\x0c\x0b\xd8\x9a\x6b\ +\x07\xc1\x3d\xc0\x01\xce\xf7\x21\x5b\x54\x3d\x0f\xaf\x5e\xa7\xeb\ +\xeb\x5f\xab\x75\x10\x0a\x6c\x8a\xc7\x87\xee\x7f\xe8\x81\x3e\x29\ +\x56\x02\x80\x08\x95\xed\xdb\xa9\x6e\xd9\x32\x27\x52\x03\xb5\x9e\ +\x87\x1d\x18\xc0\xef\xed\x0d\xc5\x8d\x55\x3a\xd9\xf7\xe0\x0b\xef\ +\x46\xa3\xeb\x67\xb9\x1e\xc0\x05\xa5\xee\x5f\xf9\xd8\xfa\x17\xeb\ +\x7c\x67\x1c\x0a\x01\xec\x8e\x1d\xf8\xcf\x3d\x17\x64\xa1\x14\xfa\ +\xf4\xe9\xb0\x3c\x2e\x12\x81\xfd\xfb\xa1\xb7\x17\xfa\xfa\x70\x43\ +\x43\xf0\xf5\xd7\x88\x15\x9a\x1b\x9b\x96\x34\x25\xdb\x5f\x47\xa9\ +\x5e\x44\x44\x03\x78\x9e\xf7\x52\xc7\xe2\xa6\x25\x0e\x15\x64\xe9\ +\x79\xd8\x9e\x9e\xd0\xa6\xdb\xbe\x1d\x7f\xeb\xd6\x20\xd3\x68\x14\ +\x06\x07\x51\x7d\x7d\x28\xa5\x50\x91\x08\x74\x77\x63\x95\xc2\x29\ +\x85\xad\x38\x5a\xdb\x5a\x97\x9d\x85\x9e\xd0\xc1\xa2\xd6\x96\x07\ +\x11\x70\x5a\x07\x59\x3a\x07\x43\x43\xa8\xc1\x41\x58\xbd\x3a\x50\ +\xd9\xb1\x03\x17\x89\xa0\x52\x29\xd4\x9a\x35\xa1\xb8\x3d\x79\x12\ +\xff\xcc\x19\xc4\x18\x10\x41\xb4\xa6\x2e\x16\x5b\x1c\xd1\xfa\x29\ +\xe0\x67\x0d\x10\x6b\x6d\xba\xcf\x39\xc1\x69\x3d\x17\xd6\xe2\x86\ +\x87\x91\xcb\x97\xe7\x16\xec\xf9\xe7\x43\xb8\x88\xe0\xbf\xf9\x26\ +\xd5\x53\xa7\x82\xc4\xb4\xc6\x79\x1e\x62\x0c\x26\x12\xa5\x2e\x11\ +\xdf\x00\xa0\xdf\x52\x2a\x51\xb7\x28\xde\xee\xbc\xc8\x7c\x01\xad\ +\x71\xce\xe1\xbf\xf2\x0a\x32\x3a\x3a\x57\x2e\xe7\xf0\x7d\x9f\xe2\ +\xa1\x43\x64\x07\x06\x28\x5f\xb9\x82\x2d\x95\x70\xc6\x04\xdf\x18\ +\x83\xad\x5a\xe2\x4b\x9a\x97\x87\x25\xc2\x98\x60\xe1\x6a\x7b\x7b\ +\x06\x28\xc6\x20\x4a\x81\x13\x9c\xb8\xf0\x1f\x70\x80\xb5\x16\x3f\ +\x9d\x46\xd2\x69\x74\x4b\x0b\x5e\x32\x89\x6e\x6e\xc6\x96\xcb\x88\ +\xb8\xa0\x8b\xb6\x89\xe4\xbe\x79\x68\xd5\x75\xb9\xd7\xb4\x8b\x95\ +\x00\xae\x54\x00\x8f\x44\x88\xec\xdb\x07\xa9\x14\xd6\xd9\xda\xf6\ +\x26\xd6\xdf\x8f\x03\x0a\x87\x0f\x07\x5d\x37\x31\x41\x65\x7c\x1c\ +\xe2\x71\xbc\xce\x26\xf2\x13\x53\x7f\x84\xff\x41\x71\x32\xf3\xa7\ +\x3f\x39\x81\xd3\x26\xb4\xea\x97\xca\xe8\x5d\xbb\x90\x55\xab\x70\ +\x2e\xc8\xa6\xf0\xea\xab\x54\x3f\xfa\x28\x14\xa9\xef\xef\x27\xb6\ +\x77\x2f\x56\x24\x08\xc0\x2f\x14\xa8\x7a\x9a\x72\xae\xf8\x71\x58\ +\xa2\xc2\x44\x66\xb4\x9c\xcd\x3c\xeb\x46\x7f\x40\x25\x12\x38\xdf\ +\x27\x3e\x32\x82\x5e\xbb\x36\x84\xe5\x87\x87\x29\x1c\x3b\x86\x44\ +\xa3\x34\x1e\x3f\x4e\xdd\xc6\x8d\x00\x24\xfa\xfb\x71\xbe\x4f\xf6\ +\xc8\x91\xb0\xb4\x3e\x4c\x39\xe7\x2e\x84\x0e\xf0\xfd\xe3\x19\xe7\ +\xc6\x2d\x8e\xca\xf8\x38\xd5\x4c\x06\xff\xd2\xa5\x10\x9e\x3b\x78\ +\x90\xdc\xd1\xa3\x41\x96\xe5\x32\x93\x3b\x77\x52\xba\x78\x31\x28\ +\x63\xb9\x4c\xe9\xdb\x6f\xf1\x67\x5c\xe8\xf6\x26\xa6\x27\x32\x57\ +\x37\xc3\x15\xa8\x19\x76\x17\xea\xeb\x87\xbb\xfa\x52\xfd\xf6\xf2\ +\x2f\x46\x24\x58\x8b\xc4\xe0\x20\xf6\xd6\x2d\x72\x23\x23\x41\x8f\ +\xd7\x36\x41\x34\x4a\xeb\xc8\x08\xd9\x13\x27\x28\x5c\xbc\x18\x0c\ +\xc4\xa8\x87\xac\xec\x18\xff\xfb\xab\x9f\x36\x6f\x16\xf9\x7c\x9e\ +\xc0\x01\xa5\xf4\xa6\x64\xfb\x67\x1d\x0d\x0d\xeb\x2a\x63\xd7\x91\ +\x05\xc0\xd9\xde\x9f\x27\x52\x73\x0d\x50\x97\xba\xcf\xa5\xbf\xff\ +\x7d\xe4\xe9\x42\x69\xcf\xac\xfb\x79\xe3\xfa\xac\x52\xa9\xb6\x54\ +\xd7\xfb\x8d\xc6\x24\x2b\x7f\xa4\xff\x05\x58\x08\x0d\x47\xb7\x67\ +\x88\x76\x77\xca\x74\xfa\xe6\x17\x92\x9e\x9c\x37\xae\x6f\xbb\xe1\ +\xc4\xda\x9a\xdf\x58\xb2\x62\xe9\x23\xfe\x6f\xd7\xa2\x36\x5f\x9c\ +\x2f\xb0\x40\xc4\x6b\x6b\x44\x9a\x13\x99\xc9\x1f\xc7\xde\x89\x15\ +\x4a\x77\xdf\x70\x00\x94\x52\x0d\x31\x68\x39\x12\x31\x03\xa9\x55\ +\xcb\xb7\x35\x18\xd3\xaa\x7c\x8b\x9d\xce\x53\x9d\xce\x83\xd1\x98\ +\xc5\x09\x74\x43\x0c\xab\x95\x4c\x66\x72\x63\x17\xae\x5e\x3f\x78\ +\x08\xde\x05\xa6\x80\x29\xb9\x93\x03\xa5\x54\x04\x68\x01\x16\x03\ +\x4d\xbd\xd0\xb5\x49\xeb\x8d\x5d\x9e\xe9\x6b\x69\x69\x68\x6b\x6c\ +\x5c\xd4\x58\xf5\xad\xbd\x99\xc9\x4d\xdf\x98\xce\x5f\x1b\xf5\xed\ +\x97\xe7\xe0\x93\x71\x48\xcf\xc2\x81\x8c\x88\x64\xef\xe8\x60\x46\ +\xa8\x1e\xb8\x67\x26\xea\x81\x38\x50\x07\x18\x40\x11\xec\x3b\x15\ +\xa0\x0c\xe4\x81\x1c\x70\x0b\xc8\xca\xec\x8c\x98\x39\xfe\x01\x76\ +\x95\xba\xf1\x06\x3a\xff\x81\x00\x00\x00\x00\x49\x45\x4e\x44\xae\ +\x42\x60\x82\ " qt_resource_name = "\ @@ -755,6 +1757,11 @@ qt_resource_name = "\ \x05\xcd\xf4\xe7\ \x00\x63\ \x00\x6f\x00\x6e\x00\x6e\x00\x5f\x00\x65\x00\x72\x00\x72\x00\x6f\x00\x72\x00\x2e\x00\x70\x00\x6e\x00\x67\ +\x00\x13\ +\x09\xd2\x6c\x67\ +\x00\x45\ +\x00\x6d\x00\x62\x00\x6c\x00\x65\x00\x6d\x00\x2d\x00\x71\x00\x75\x00\x65\x00\x73\x00\x74\x00\x69\x00\x6f\x00\x6e\x00\x2e\x00\x70\ +\x00\x6e\x00\x67\ \x00\x12\ \x04\xe4\x91\x47\ \x00\x63\ @@ -769,15 +1776,33 @@ qt_resource_name = "\ \x00\x63\ \x00\x6f\x00\x6e\x00\x6e\x00\x5f\x00\x63\x00\x6f\x00\x6e\x00\x6e\x00\x65\x00\x63\x00\x74\x00\x69\x00\x6e\x00\x67\x00\x2e\x00\x70\ \x00\x6e\x00\x67\ +\x00\x14\ +\x00\xe9\x23\x87\ +\x00\x6c\ +\x00\x65\x00\x61\x00\x70\x00\x2d\x00\x63\x00\x6f\x00\x6c\x00\x6f\x00\x72\x00\x2d\x00\x73\x00\x6d\x00\x61\x00\x6c\x00\x6c\x00\x2e\ +\x00\x70\x00\x6e\x00\x67\ +\x00\x11\ +\x06\x1a\x44\xa7\ +\x00\x44\ +\x00\x69\x00\x61\x00\x6c\x00\x6f\x00\x67\x00\x2d\x00\x61\x00\x63\x00\x63\x00\x65\x00\x70\x00\x74\x00\x2e\x00\x70\x00\x6e\x00\x67\ +\ +\x00\x10\ +\x0f\xc3\x90\x67\ +\x00\x44\ +\x00\x69\x00\x61\x00\x6c\x00\x6f\x00\x67\x00\x2d\x00\x65\x00\x72\x00\x72\x00\x6f\x00\x72\x00\x2e\x00\x70\x00\x6e\x00\x67\ " qt_resource_struct = "\ \x00\x00\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x01\ -\x00\x00\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x02\ -\x00\x00\x00\x34\x00\x00\x00\x00\x00\x01\x00\x00\x0d\xf7\ +\x00\x00\x00\x00\x00\x02\x00\x00\x00\x08\x00\x00\x00\x02\ +\x00\x00\x00\xd4\x00\x00\x00\x00\x00\x01\x00\x00\x32\x3e\ +\x00\x00\x00\x60\x00\x00\x00\x00\x00\x01\x00\x00\x12\xe7\ \x00\x00\x00\x12\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\ -\x00\x00\x00\x5e\x00\x00\x00\x00\x00\x01\x00\x00\x19\xd2\ -\x00\x00\x00\x7c\x00\x00\x00\x00\x00\x01\x00\x00\x20\xbd\ +\x00\x00\x01\x02\x00\x00\x00\x00\x00\x01\x00\x00\x60\xc7\ +\x00\x00\x00\x8a\x00\x00\x00\x00\x00\x01\x00\x00\x1e\xc2\ +\x00\x00\x00\x34\x00\x00\x00\x00\x00\x01\x00\x00\x0d\xf7\ +\x00\x00\x00\xa8\x00\x00\x00\x00\x00\x01\x00\x00\x25\xad\ +\x00\x00\x01\x2a\x00\x00\x00\x00\x00\x01\x00\x00\x65\xef\ " def qInitResources(): diff --git a/src/leap/gui/progress.py b/src/leap/gui/progress.py new file mode 100644 index 00000000..64b87b2c --- /dev/null +++ b/src/leap/gui/progress.py @@ -0,0 +1,448 @@ +""" +classes used in progress pages +from first run wizard +""" +try: + from collections import OrderedDict +except ImportError: + # We must be in 2.6 + from leap.util.dicts import OrderedDict + +import logging + +from PyQt4 import QtCore +from PyQt4 import QtGui + +from leap.gui.threads import FunThread + +from leap.gui import mainwindow_rc + +ICON_CHECKMARK = ":/images/Dialog-accept.png" +ICON_FAILED = ":/images/Dialog-error.png" +ICON_WAITING = ":/images/Emblem-question.png" + +logger = logging.getLogger(__name__) + + +class ImgWidget(QtGui.QWidget): + + # XXX move to widgets + + def __init__(self, parent=None, img=None): + super(ImgWidget, self).__init__(parent) + self.pic = QtGui.QPixmap(img) + + def paintEvent(self, event): + painter = QtGui.QPainter(self) + painter.drawPixmap(0, 0, self.pic) + + +class ProgressStep(object): + """ + Data model for sequential steps + to be used in a progress page in + connection wizard + """ + NAME = 0 + DONE = 1 + + def __init__(self, stepname, done, index=None): + """ + @param step: the name of the step + @type step: str + @param done: whether is completed or not + @type done: bool + """ + self.index = int(index) if index else 0 + self.name = unicode(stepname) + self.done = bool(done) + + @classmethod + def columns(self): + return ('name', 'done') + + +class ProgressStepContainer(object): + """ + a container for ProgressSteps objects + access data in the internal dict + """ + + def __init__(self): + self.dirty = False + self.steps = {} + + def step(self, identity): + return self.step.get(identity) + + def addStep(self, step): + self.steps[step.index] = step + + def removeStep(self, step): + del self.steps[step.index] + del step + self.dirty = True + + def removeAllSteps(self): + for item in iter(self): + self.removeStep(item) + + @property + def columns(self): + return ProgressStep.columns() + + def __len__(self): + return len(self.steps) + + def __iter__(self): + for step in self.steps.values(): + yield step + + +class StepsTableWidget(QtGui.QTableWidget): + """ + initializes a TableWidget + suitable for our display purposes, like removing + header info and grid display + """ + + def __init__(self, parent=None): + super(StepsTableWidget, self).__init__(parent) + + # remove headers and all edit/select behavior + self.horizontalHeader().hide() + self.verticalHeader().hide() + self.setEditTriggers( + QtGui.QAbstractItemView.NoEditTriggers) + self.setSelectionMode( + QtGui.QAbstractItemView.NoSelection) + width = self.width() + # WTF? Here init width is 100... + # but on populating is 456... :( + + # XXX do we need this initial? + logger.debug('init table. width=%s' % width) + self.horizontalHeader().resizeSection(0, width * 0.7) + + # this disables the table grid. + # we should add alignment to the ImgWidget (it's top-left now) + self.setShowGrid(False) + self.setFocusPolicy(QtCore.Qt.NoFocus) + #self.setStyleSheet("QTableView{outline: 0;}") + + # XXX change image for done to rc + + # Note about the "done" status painting: + # + # XXX currently we are setting the CellWidget + # for the whole table on a per-row basis + # (on add_status_line method on ValidationPage). + # However, a more generic solution might be + # to implement a custom Delegate that overwrites + # the paint method (so it paints a checked tickmark if + # done is True and some other thing if checking or false). + # What we have now is quick and works because + # I'm supposing that on first fail we will + # go back to previous wizard page to signal the failure. + # A more generic solution could be used for + # some failing tests if they are not critical. + + +class WithStepsMixIn(object): + + # worker threads for checks + + def setupStepsProcessingQueue(self): + self.steps_queue = Queue.Queue() + self.stepscheck_timer = QtCore.QTimer() + self.stepscheck_timer.timeout.connect(self.processStepsQueue) + self.stepscheck_timer.start(100) + # we need to keep a reference to child threads + self.threads = [] + + def do_checks(self): + + # yo dawg, I heard you like checks + # so I put a __do_checks in your do_checks + # for calling others' _do_checks + + def __do_checks(fun=None, queue=None): + + for checkcase in fun(): + checkmsg, checkfun = checkcase + + queue.put(checkmsg) + if checkfun() is False: + queue.put("failed") + break + + t = FunThread(fun=partial( + __do_checks, + fun=self._do_checks, + queue=self.steps_queue)) + t.finished.connect(self.on_checks_validation_ready) + t.begin() + self.threads.append(t) + + def fail(self, err=None): + """ + return failed state + and send error notification as + a nice side effect + """ + wizard = self.wizard() + senderr = lambda err: wizard.set_validation_error( + self.current_page, err) + self.set_undone() + if err: + senderr(err) + return False + + @QtCore.pyqtSlot() + def launch_checks(self): + self.do_checks() + + # slot + #@QtCore.pyqtSlot(str, int) + def onStepStatusChanged(self, status, progress=None): + if status not in ("head_sentinel", "end_sentinel"): + self.add_status_line(status) + if status in ("end_sentinel"): + self.checks_finished = True + self.set_checked_icon() + if progress and hasattr(self, 'progress'): + self.progress.setValue(progress) + self.progress.update() + + def processStepsQueue(self): + """ + consume steps queue + and pass messages + to the ui updater functions + """ + while self.steps_queue.qsize(): + try: + status = self.steps_queue.get(0) + if status == "failed": + self.set_failed_icon() + else: + self.onStepStatusChanged(*status) + except Queue.Empty: + pass + + def setupSteps(self): + self.steps = ProgressStepContainer() + # steps table widget + self.stepsTableWidget = StepsTableWidget(self) + zeros = (0, 0, 0, 0) + self.stepsTableWidget.setContentsMargins(*zeros) + self.errors = OrderedDict() + + def set_error(self, name, error): + self.errors[name] = error + + def pop_first_error(self): + return list(reversed(self.errors.items())).pop() + + def clean_errors(self): + self.errors = OrderedDict() + + def clean_wizard_errors(self, pagename=None): + if pagename is None: + pagename = getattr(self, 'prev_page', None) + if pagename is None: + return + logger.debug('cleaning wizard errors for %s' % pagename) + self.wizard().set_validation_error(pagename, None) + + def populateStepsTable(self): + # from examples, + # but I guess it's not needed to re-populate + # the whole table. + table = self.stepsTableWidget + table.setRowCount(len(self.steps)) + columns = self.steps.columns + table.setColumnCount(len(columns)) + + for row, step in enumerate(self.steps): + item = QtGui.QTableWidgetItem(step.name) + item.setData(QtCore.Qt.UserRole, + long(id(step))) + table.setItem(row, columns.index('name'), item) + table.setItem(row, columns.index('done'), + QtGui.QTableWidgetItem(step.done)) + self.resizeTable() + self.update() + + def clearTable(self): + # ??? -- not sure what's the difference + #self.stepsTableWidget.clear() + self.stepsTableWidget.clearContents() + + def resizeTable(self): + # resize first column to ~80% + table = self.stepsTableWidget + FIRST_COLUMN_PERCENT = 0.70 + width = table.width() + logger.debug('populate table. width=%s' % width) + table.horizontalHeader().resizeSection(0, width * FIRST_COLUMN_PERCENT) + + def set_item_icon(self, img=ICON_CHECKMARK, current=True): + """ + mark the last item + as done + """ + # setting cell widget. + # see note on StepsTableWidget about plans to + # change this for a better solution. + index = len(self.steps) + table = self.stepsTableWidget + _index = index - 1 if current else index - 2 + table.setCellWidget( + _index, + ProgressStep.DONE, + ImgWidget(img=img)) + table.update() + + def set_failed_icon(self): + self.set_item_icon(img=ICON_FAILED, current=True) + + def set_checking_icon(self): + self.set_item_icon(img=ICON_WAITING, current=True) + + def set_checked_icon(self, current=True): + self.set_item_icon(current=current) + + def add_status_line(self, message): + """ + adds a new status line + and mark the next-to-last item + as done + """ + index = len(self.steps) + step = ProgressStep(message, False, index=index) + self.steps.addStep(step) + self.populateStepsTable() + self.set_checking_icon() + self.set_checked_icon(current=False) + + # Sets/unsets done flag + # for isComplete checks + + def set_done(self): + self.done = True + self.completeChanged.emit() + + def set_undone(self): + self.done = False + self.completeChanged.emit() + + def is_done(self): + return self.done + + def go_back(self): + self.wizard().back() + + def go_next(self): + self.wizard().next() + + +""" +We will use one base class for the intermediate pages +and another one for the in-page validations, both sharing the creation +of the tablewidgets. +The logic of this split comes from where I was trying to solve +the ui update using signals, but now that it's working well with +queues I could join them again. +""" + +import Queue +from functools import partial + + +class InlineValidationPage(QtGui.QWizardPage, WithStepsMixIn): + + def __init__(self, parent=None): + super(InlineValidationPage, self).__init__(parent) + self.setupStepsProcessingQueue() + self.done = False + + # slot + + @QtCore.pyqtSlot() + def showStepsFrame(self): + self.valFrame.show() + self.update() + + # progress frame + + def setupValidationFrame(self): + qframe = QtGui.QFrame + valFrame = qframe() + valFrame.setFrameStyle(qframe.NoFrame) + valframeLayout = QtGui.QVBoxLayout() + zeros = (0, 0, 0, 0) + valframeLayout.setContentsMargins(*zeros) + + valframeLayout.addWidget(self.stepsTableWidget) + valFrame.setLayout(valframeLayout) + self.valFrame = valFrame + + +class ValidationPage(QtGui.QWizardPage, WithStepsMixIn): + """ + class to be used as an intermediate + between two pages in a wizard. + shows feedback to the user and goes back if errors, + goes forward if ok. + initializePage triggers a one shot timer + that calls do_checks. + Derived classes should implement + _do_checks and + _do_validation + """ + + # signals + stepChanged = QtCore.pyqtSignal([str, int]) + + def __init__(self, parent=None): + super(ValidationPage, self).__init__(parent) + self.setupSteps() + #self.connect_step_status() + + layout = QtGui.QVBoxLayout() + self.progress = QtGui.QProgressBar(self) + layout.addWidget(self.progress) + layout.addWidget(self.stepsTableWidget) + + self.setLayout(layout) + self.layout = layout + + self.timer = QtCore.QTimer() + self.done = False + + self.setupStepsProcessingQueue() + + def isComplete(self): + return self.is_done() + + ######################## + + def show_progress(self): + self.progress.show() + self.stepsTableWidget.show() + + def hide_progress(self): + self.progress.hide() + self.stepsTableWidget.hide() + + # pagewizard methods. + # if overriden, child classes should call super. + + def initializePage(self): + self.clean_errors() + self.clean_wizard_errors() + self.steps.removeAllSteps() + self.clearTable() + self.resizeTable() + self.timer.singleShot(0, self.do_checks) diff --git a/src/leap/gui/styles.py b/src/leap/gui/styles.py new file mode 100644 index 00000000..b482922e --- /dev/null +++ b/src/leap/gui/styles.py @@ -0,0 +1,16 @@ +GreenLineEdit = "QLabel {color: green; font-weight: bold}" +ErrorLabelStyleSheet = """QLabel { color: red; font-weight: bold }""" +ErrorLineEdit = """QLineEdit { border: 1px solid red; }""" + + +# XXX this is bad. +# and you should feel bad for it. +# The original style has a sort of box color +# white/beige left-top/right-bottom or something like +# that. + +RegularLineEdit = """ +QLineEdit { + border: 1px solid black; +} +""" diff --git a/src/leap/gui/test_mainwindow_rc.py b/src/leap/gui/test_mainwindow_rc.py new file mode 100644 index 00000000..c5abb4aa --- /dev/null +++ b/src/leap/gui/test_mainwindow_rc.py @@ -0,0 +1,29 @@ +import unittest +import hashlib + +try: + import sip + sip.setapi('QVariant', 2) +except ValueError: + pass + +from leap.gui import mainwindow_rc + +# I have to admit that there's something +# perverse in testing this. +# Even though, I still think that it _is_ a good idea +# to put a check to avoid non-updated resources files. + +# so, if you came here because an updated resource +# did break a test, what you have to do is getting +# the md5 hash of your qt_resource_data and change it here. + +# annoying? yep. try making a script for that :P + + +class MainWindowResourcesTest(unittest.TestCase): + + def test_mainwindow_resources_hash(self): + self.assertEqual( + hashlib.md5(mainwindow_rc.qt_resource_data).hexdigest(), + '53e196f29061d8f08f112e5a2e64eb53') diff --git a/src/leap/gui/tests/integration/fake_user_signup.py b/src/leap/gui/tests/integration/fake_user_signup.py new file mode 100644 index 00000000..78873749 --- /dev/null +++ b/src/leap/gui/tests/integration/fake_user_signup.py @@ -0,0 +1,84 @@ +""" +simple server to test registration and +authentication + +To test: + +curl -d login=python_test_user -d password_salt=54321\ + -d password_verifier=12341234 \ + http://localhost:8000/users.json + +""" +from BaseHTTPServer import HTTPServer +from BaseHTTPServer import BaseHTTPRequestHandler +import cgi +import json +import urlparse + +HOST = "localhost" +PORT = 8000 + +LOGIN_ERROR = """{"errors":{"login":["has already been taken"]}}""" + +from leap.base.tests.test_providers import EXPECTED_DEFAULT_CONFIG + + +class request_handler(BaseHTTPRequestHandler): + responses = { + '/': ['ok\n'], + '/users.json': ['ok\n'], + '/timeout': ['ok\n'], + '/provider.json': ['%s\n' % json.dumps(EXPECTED_DEFAULT_CONFIG)] + } + + def do_GET(self): + path = urlparse.urlparse(self.path) + message = '\n'.join( + self.responses.get( + path.path, None)) + self.send_response(200) + self.end_headers() + self.wfile.write(message) + + def do_POST(self): + form = cgi.FieldStorage( + fp=self.rfile, + headers=self.headers, + environ={'REQUEST_METHOD': 'POST', + 'CONTENT_TYPE': self.headers['Content-Type'], + }) + data = dict( + (key, form[key].value) for key in form.keys()) + path = urlparse.urlparse(self.path) + message = '\n'.join( + self.responses.get( + path.path, '')) + + login = data.get('login', None) + #password_salt = data.get('password_salt', None) + #password_verifier = data.get('password_verifier', None) + + if path.geturl() == "/timeout": + print 'timeout' + self.send_response(200) + self.end_headers() + self.wfile.write(message) + import time + time.sleep(10) + return + + ok = True if (login == "python_test_user") else False + if ok: + self.send_response(200) + self.end_headers() + self.wfile.write(message) + + else: + self.send_response(500) + self.end_headers() + self.wfile.write(LOGIN_ERROR) + + +if __name__ == "__main__": + server = HTTPServer((HOST, PORT), request_handler) + server.serve_forever() diff --git a/src/leap/gui/threads.py b/src/leap/gui/threads.py new file mode 100644 index 00000000..8aad8866 --- /dev/null +++ b/src/leap/gui/threads.py @@ -0,0 +1,21 @@ +from PyQt4 import QtCore + + +class FunThread(QtCore.QThread): + + def __init__(self, fun=None, parent=None): + + QtCore.QThread.__init__(self, parent) + self.exiting = False + self.fun = fun + + def __del__(self): + self.exiting = True + self.wait() + + def run(self): + if self.fun: + self.fun() + + def begin(self): + self.start() diff --git a/src/leap/gui/utils.py b/src/leap/gui/utils.py new file mode 100644 index 00000000..f91ac3ef --- /dev/null +++ b/src/leap/gui/utils.py @@ -0,0 +1,34 @@ +""" +utility functions to work with gui objects +""" +from PyQt4 import QtCore + + +def layout_widgets(layout): + """ + return a generator with all widgets in a layout + """ + return (layout.itemAt(i) for i in range(layout.count())) + + +DELAY_MSECS = 50 + + +def delay(obj, method_str=None, call_args=None): + """ + Triggers a function or slot with a small delay. + this is a mainly a hack to get responsiveness in the ui + in cases in which the event loop freezes and the task + is not heavy enough to setup a processing queue. + """ + if callable(obj) and not method_str: + fun = lambda: obj() + + if method_str: + invoke = QtCore.QMetaObject.invokeMethod + if call_args: + fun = lambda: invoke(obj, method_str, call_args) + else: + fun = lambda: invoke(obj, method_str) + + QtCore.QTimer().singleShot(DELAY_MSECS, fun) diff --git a/src/leap/soledad/README b/src/leap/soledad/README new file mode 100644 index 00000000..dc448374 --- /dev/null +++ b/src/leap/soledad/README @@ -0,0 +1,13 @@ +Soledad -- Synchronization Of Locally Encrypted Data Among Devices +================================================================== + +Dependencies +------------ + +Soledad uses the following python libraries: + + * u1db 0.1.4 [1] + * python-swiftclient 1.1.1 [2] + +[1] http://pypi.python.org/pypi/u1db/0.1.4 +[2] https://launchpad.net/python-swiftclient diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py new file mode 100644 index 00000000..3d685635 --- /dev/null +++ b/src/leap/soledad/__init__.py @@ -0,0 +1,164 @@ +# License? + +"""A U1DB implementation that uses OpenStack Swift as its persistence layer.""" + +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db.backends import CommonBackend, CommonSyncTarget +from u1db import ( + Document, + errors, + query_parser, + vectorclock, + ) + +from swiftclient import client + + +class OpenStackDatabase(CommonBackend): + """A U1DB implementation that uses OpenStack as its persistence layer.""" + + def __init__(self, auth_url, user, auth_key): + """Create a new OpenStack data container.""" + self._auth_url = auth_url + self._user = user + self._auth_key = auth_key + self.set_document_factory(LeapDocument) + self._connection = swiftclient.Connection(self._auth_url, self._user, + self._auth_key) + + #------------------------------------------------------------------------- + # implemented methods from Database + #------------------------------------------------------------------------- + + def set_document_factory(self, factory): + self._factory = factory + + def set_document_size_limit(self, limit): + raise NotImplementedError(self.set_document_size_limit) + + def whats_changed(self, old_generation=0): + raise NotImplementedError(self.whats_changed) + + def get_doc(self, doc_id, include_deleted=False): + raise NotImplementedError(self.get_doc) + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + raise NotImplementedError(self.get_all_docs) + + def put_doc(self, doc): + raise NotImplementedError(self.put_doc) + + def delete_doc(self, doc): + raise NotImplementedError(self.delete_doc) + + # start of index-related methods: these are not supported by this backend. + + def create_index(self, index_name, *index_expressions): + return False + + def delete_index(self, index_name): + return False + + def list_indexes(self): + return [] + + def get_from_index(self, index_name, *key_values): + return [] + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + return [] + + def get_index_keys(self, index_name): + return [] + + # end of index-related methods: these are not supported by this backend. + + def get_doc_conflicts(self, doc_id): + return [] + + def resolve_doc(self, doc, conflicted_doc_revs): + raise NotImplementedError(self.resolve_doc) + + def get_sync_target(self): + return OpenStackSyncTarget(self) + + def close(self): + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + raise NotImplementedError(self.close) + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + raise NotImplementedError(self._get_replica_gen_and_trans_id) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + raise NotImplementedError(self._set_replica_gen_and_trans_id) + + #------------------------------------------------------------------------- + # implemented methods from CommonBackend + #------------------------------------------------------------------------- + + def _get_generation(self): + raise NotImplementedError(self._get_generation) + + def _get_generation_info(self): + raise NotImplementedError(self._get_generation_info) + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling.""" + raise NotImplementedError(self._get_doc) + + def _has_conflicts(self, doc_id): + raise NotImplementedError(self._has_conflicts) + + def _get_transaction_log(self): + raise NotImplementedError(self._get_transaction_log) + + def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): + raise NotImplementedError(self._put_and_update_indexes) + + + def _get_trans_id_for_gen(self, generation): + raise NotImplementedError(self._get_trans_id_for_gen) + + #------------------------------------------------------------------------- + # OpenStack specific methods + #------------------------------------------------------------------------- + + def _is_initialized(self, c): + raise NotImplementedError(self._is_initialized) + + def _initialize(self, c): + raise NotImplementedError(self._initialize) + + def _get_auth(self): + self._url, self._auth_token = self._connection.get_auth(self._auth_url, + self._user, + self._auth_key) + return self._url, self.auth_token + + +class LeapDocument(Document): + + def get_content_encrypted(self): + raise NotImplementedError(self.get_content_encrypted) + + def set_content_encrypted(self): + raise NotImplementedError(self.set_content_encrypted) + + +class OpenStackSyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + raise NotImplementedError(self.get_sync_info) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + raise NotImplementedError(self.record_sync_info) diff --git a/src/leap/soledad/swiftclient/__init__.py b/src/leap/soledad/swiftclient/__init__.py new file mode 100644 index 00000000..ba0b41a3 --- /dev/null +++ b/src/leap/soledad/swiftclient/__init__.py @@ -0,0 +1,5 @@ +# -*- encoding: utf-8 -*- +"""" +OpenStack Swift Python client binding. +""" +from client import * diff --git a/src/leap/soledad/swiftclient/client.py b/src/leap/soledad/swiftclient/client.py new file mode 100644 index 00000000..79e6594f --- /dev/null +++ b/src/leap/soledad/swiftclient/client.py @@ -0,0 +1,1056 @@ +# Copyright (c) 2010-2012 OpenStack, LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Cloud Files client library used internally +""" + +import socket +import os +import logging +import httplib + +from urllib import quote as _quote +from urlparse import urlparse, urlunparse, urljoin + +try: + from eventlet.green.httplib import HTTPException, HTTPSConnection +except ImportError: + from httplib import HTTPException, HTTPSConnection + +try: + from eventlet import sleep +except ImportError: + from time import sleep + +try: + from swift.common.bufferedhttp \ + import BufferedHTTPConnection as HTTPConnection +except ImportError: + try: + from eventlet.green.httplib import HTTPConnection + except ImportError: + from httplib import HTTPConnection + +logger = logging.getLogger("swiftclient") + + +def http_log(args, kwargs, resp, body): + if os.environ.get('SWIFTCLIENT_DEBUG', False): + ch = logging.StreamHandler() + logger.setLevel(logging.DEBUG) + logger.addHandler(ch) + elif not logger.isEnabledFor(logging.DEBUG): + return + + string_parts = ['curl -i'] + for element in args: + if element in ('GET', 'POST', 'PUT', 'HEAD'): + string_parts.append(' -X %s' % element) + else: + string_parts.append(' %s' % element) + + if 'headers' in kwargs: + for element in kwargs['headers']: + header = ' -H "%s: %s"' % (element, kwargs['headers'][element]) + string_parts.append(header) + + logger.debug("REQ: %s\n" % "".join(string_parts)) + if 'raw_body' in kwargs: + logger.debug("REQ BODY (RAW): %s\n" % (kwargs['raw_body'])) + if 'body' in kwargs: + logger.debug("REQ BODY: %s\n" % (kwargs['body'])) + + logger.debug("RESP STATUS: %s\n", resp.status) + if body: + logger.debug("RESP BODY: %s\n", body) + + +def quote(value, safe='/'): + """ + Patched version of urllib.quote that encodes utf8 strings before quoting + """ + if isinstance(value, unicode): + value = value.encode('utf8') + return _quote(value, safe) + + +# look for a real json parser first +try: + # simplejson is popular and pretty good + from simplejson import loads as json_loads + from simplejson import dumps as json_dumps +except ImportError: + # 2.6 will have a json module in the stdlib + from json import loads as json_loads + from json import dumps as json_dumps + + +class ClientException(Exception): + + def __init__(self, msg, http_scheme='', http_host='', http_port='', + http_path='', http_query='', http_status=0, http_reason='', + http_device='', http_response_content=''): + Exception.__init__(self, msg) + self.msg = msg + self.http_scheme = http_scheme + self.http_host = http_host + self.http_port = http_port + self.http_path = http_path + self.http_query = http_query + self.http_status = http_status + self.http_reason = http_reason + self.http_device = http_device + self.http_response_content = http_response_content + + def __str__(self): + a = self.msg + b = '' + if self.http_scheme: + b += '%s://' % self.http_scheme + if self.http_host: + b += self.http_host + if self.http_port: + b += ':%s' % self.http_port + if self.http_path: + b += self.http_path + if self.http_query: + b += '?%s' % self.http_query + if self.http_status: + if b: + b = '%s %s' % (b, self.http_status) + else: + b = str(self.http_status) + if self.http_reason: + if b: + b = '%s %s' % (b, self.http_reason) + else: + b = '- %s' % self.http_reason + if self.http_device: + if b: + b = '%s: device %s' % (b, self.http_device) + else: + b = 'device %s' % self.http_device + if self.http_response_content: + if len(self.http_response_content) <= 60: + b += ' %s' % self.http_response_content + else: + b += ' [first 60 chars of response] %s' \ + % self.http_response_content[:60] + return b and '%s: %s' % (a, b) or a + + +def http_connection(url, proxy=None): + """ + Make an HTTPConnection or HTTPSConnection + + :param url: url to connect to + :param proxy: proxy to connect through, if any; None by default; str of the + format 'http://127.0.0.1:8888' to set one + :returns: tuple of (parsed url, connection object) + :raises ClientException: Unable to handle protocol scheme + """ + parsed = urlparse(url) + proxy_parsed = urlparse(proxy) if proxy else None + if parsed.scheme == 'http': + conn = HTTPConnection((proxy_parsed if proxy else parsed).netloc) + elif parsed.scheme == 'https': + conn = HTTPSConnection((proxy_parsed if proxy else parsed).netloc) + else: + raise ClientException('Cannot handle protocol scheme %s for url %s' % + (parsed.scheme, repr(url))) + if proxy: + conn._set_tunnel(parsed.hostname, parsed.port) + return parsed, conn + + +def json_request(method, url, **kwargs): + """Takes a request in json parse it and return in json""" + kwargs.setdefault('headers', {}) + if 'body' in kwargs: + kwargs['headers']['Content-Type'] = 'application/json' + kwargs['body'] = json_dumps(kwargs['body']) + parsed, conn = http_connection(url) + conn.request(method, parsed.path, **kwargs) + resp = conn.getresponse() + body = resp.read() + http_log((url, method,), kwargs, resp, body) + if body: + try: + body = json_loads(body) + except ValueError: + body = None + if not body or resp.status < 200 or resp.status >= 300: + raise ClientException('Auth GET failed', http_scheme=parsed.scheme, + http_host=conn.host, + http_port=conn.port, + http_path=parsed.path, + http_status=resp.status, + http_reason=resp.reason) + return resp, body + + +def _get_auth_v1_0(url, user, key, snet): + parsed, conn = http_connection(url) + method = 'GET' + conn.request(method, parsed.path, '', + {'X-Auth-User': user, 'X-Auth-Key': key}) + resp = conn.getresponse() + body = resp.read() + url = resp.getheader('x-storage-url') + http_log((url, method,), {}, resp, body) + + # There is a side-effect on current Rackspace 1.0 server where a + # bad URL would get you that document page and a 200. We error out + # if we don't have a x-storage-url header and if we get a body. + if resp.status < 200 or resp.status >= 300 or (body and not url): + raise ClientException('Auth GET failed', http_scheme=parsed.scheme, + http_host=conn.host, http_port=conn.port, + http_path=parsed.path, http_status=resp.status, + http_reason=resp.reason) + if snet: + parsed = list(urlparse(url)) + # Second item in the list is the netloc + netloc = parsed[1] + parsed[1] = 'snet-' + netloc + url = urlunparse(parsed) + return url, resp.getheader('x-storage-token', + resp.getheader('x-auth-token')) + + +def _get_auth_v2_0(url, user, tenant_name, key, snet): + body = {'auth': + {'passwordCredentials': {'password': key, 'username': user}, + 'tenantName': tenant_name}} + token_url = urljoin(url, "tokens") + resp, body = json_request("POST", token_url, body=body) + token_id = None + try: + url = None + catalogs = body['access']['serviceCatalog'] + for service in catalogs: + if service['type'] == 'object-store': + url = service['endpoints'][0]['publicURL'] + token_id = body['access']['token']['id'] + if not url: + raise ClientException("There is no object-store endpoint " + "on this auth server.") + except(KeyError, IndexError): + raise ClientException("Error while getting answers from auth server") + + if snet: + parsed = list(urlparse(url)) + # Second item in the list is the netloc + parsed[1] = 'snet-' + parsed[1] + url = urlunparse(parsed) + + return url, token_id + + +def get_auth(url, user, key, snet=False, tenant_name=None, auth_version="1.0"): + """ + Get authentication/authorization credentials. + + The snet parameter is used for Rackspace's ServiceNet internal network + implementation. In this function, it simply adds *snet-* to the beginning + of the host name for the returned storage URL. With Rackspace Cloud Files, + use of this network path causes no bandwidth charges but requires the + client to be running on Rackspace's ServiceNet network. + + :param url: authentication/authorization URL + :param user: user to authenticate as + :param key: key or password for authorization + :param snet: use SERVICENET internal network (see above), default is False + :param auth_version: OpenStack auth version, default is 1.0 + :param tenant_name: The tenant/account name, required when connecting + to a auth 2.0 system. + :returns: tuple of (storage URL, auth token) + :raises: ClientException: HTTP GET request to auth URL failed + """ + if auth_version in ["1.0", "1"]: + return _get_auth_v1_0(url, user, key, snet) + elif auth_version in ["2.0", "2"]: + if not tenant_name and ':' in user: + (tenant_name, user) = user.split(':') + if not tenant_name: + raise ClientException('No tenant specified') + return _get_auth_v2_0(url, user, tenant_name, key, snet) + else: + raise ClientException('Unknown auth_version %s specified.' + % auth_version) + + +def get_account(url, token, marker=None, limit=None, prefix=None, + http_conn=None, full_listing=False): + """ + Get a listing of containers for the account. + + :param url: storage URL + :param token: auth token + :param marker: marker query + :param limit: limit query + :param prefix: prefix query + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :param full_listing: if True, return a full listing, else returns a max + of 10000 listings + :returns: a tuple of (response headers, a list of containers) The response + headers will be a dict and all header names will be lowercase. + :raises ClientException: HTTP GET request failed + """ + if not http_conn: + http_conn = http_connection(url) + if full_listing: + rv = get_account(url, token, marker, limit, prefix, http_conn) + listing = rv[1] + while listing: + marker = listing[-1]['name'] + listing = \ + get_account(url, token, marker, limit, prefix, http_conn)[1] + if listing: + rv[1].extend(listing) + return rv + parsed, conn = http_conn + qs = 'format=json' + if marker: + qs += '&marker=%s' % quote(marker) + if limit: + qs += '&limit=%d' % limit + if prefix: + qs += '&prefix=%s' % quote(prefix) + full_path = '%s?%s' % (parsed.path, qs) + headers = {'X-Auth-Token': token} + conn.request('GET', full_path, '', + headers) + resp = conn.getresponse() + body = resp.read() + http_log(("%s?%s" % (url, qs), 'GET',), {'headers': headers}, resp, body) + + resp_headers = {} + for header, value in resp.getheaders(): + resp_headers[header.lower()] = value + if resp.status < 200 or resp.status >= 300: + raise ClientException('Account GET failed', http_scheme=parsed.scheme, + http_host=conn.host, http_port=conn.port, + http_path=parsed.path, http_query=qs, + http_status=resp.status, http_reason=resp.reason, + http_response_content=body) + if resp.status == 204: + body + return resp_headers, [] + return resp_headers, json_loads(body) + + +def head_account(url, token, http_conn=None): + """ + Get account stats. + + :param url: storage URL + :param token: auth token + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :returns: a dict containing the response's headers (all header names will + be lowercase) + :raises ClientException: HTTP HEAD request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url) + method = "HEAD" + headers = {'X-Auth-Token': token} + conn.request(method, parsed.path, '', headers) + resp = conn.getresponse() + body = resp.read() + http_log((url, method,), {'headers': headers}, resp, body) + if resp.status < 200 or resp.status >= 300: + raise ClientException('Account HEAD failed', http_scheme=parsed.scheme, + http_host=conn.host, http_port=conn.port, + http_path=parsed.path, http_status=resp.status, + http_reason=resp.reason, + http_response_content=body) + resp_headers = {} + for header, value in resp.getheaders(): + resp_headers[header.lower()] = value + return resp_headers + + +def post_account(url, token, headers, http_conn=None): + """ + Update an account's metadata. + + :param url: storage URL + :param token: auth token + :param headers: additional headers to include in the request + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :raises ClientException: HTTP POST request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url) + method = 'POST' + headers['X-Auth-Token'] = token + conn.request(method, parsed.path, '', headers) + resp = conn.getresponse() + body = resp.read() + http_log((url, method,), {'headers': headers}, resp, body) + if resp.status < 200 or resp.status >= 300: + raise ClientException('Account POST failed', + http_scheme=parsed.scheme, + http_host=conn.host, + http_port=conn.port, + http_path=parsed.path, + http_status=resp.status, + http_reason=resp.reason, + http_response_content=body) + + +def get_container(url, token, container, marker=None, limit=None, + prefix=None, delimiter=None, http_conn=None, + full_listing=False): + """ + Get a listing of objects for the container. + + :param url: storage URL + :param token: auth token + :param container: container name to get a listing for + :param marker: marker query + :param limit: limit query + :param prefix: prefix query + :param delimeter: string to delimit the queries on + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :param full_listing: if True, return a full listing, else returns a max + of 10000 listings + :returns: a tuple of (response headers, a list of objects) The response + headers will be a dict and all header names will be lowercase. + :raises ClientException: HTTP GET request failed + """ + if not http_conn: + http_conn = http_connection(url) + if full_listing: + rv = get_container(url, token, container, marker, limit, prefix, + delimiter, http_conn) + listing = rv[1] + while listing: + if not delimiter: + marker = listing[-1]['name'] + else: + marker = listing[-1].get('name', listing[-1].get('subdir')) + listing = get_container(url, token, container, marker, limit, + prefix, delimiter, http_conn)[1] + if listing: + rv[1].extend(listing) + return rv + parsed, conn = http_conn + path = '%s/%s' % (parsed.path, quote(container)) + qs = 'format=json' + if marker: + qs += '&marker=%s' % quote(marker) + if limit: + qs += '&limit=%d' % limit + if prefix: + qs += '&prefix=%s' % quote(prefix) + if delimiter: + qs += '&delimiter=%s' % quote(delimiter) + headers = {'X-Auth-Token': token} + method = 'GET' + conn.request(method, '%s?%s' % (path, qs), '', headers) + resp = conn.getresponse() + body = resp.read() + http_log(('%s?%s' % (url, qs), method,), {'headers': headers}, resp, body) + + if resp.status < 200 or resp.status >= 300: + raise ClientException('Container GET failed', + http_scheme=parsed.scheme, http_host=conn.host, + http_port=conn.port, http_path=path, + http_query=qs, http_status=resp.status, + http_reason=resp.reason, + http_response_content=body) + resp_headers = {} + for header, value in resp.getheaders(): + resp_headers[header.lower()] = value + if resp.status == 204: + return resp_headers, [] + return resp_headers, json_loads(body) + + +def head_container(url, token, container, http_conn=None, headers=None): + """ + Get container stats. + + :param url: storage URL + :param token: auth token + :param container: container name to get stats for + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :returns: a dict containing the response's headers (all header names will + be lowercase) + :raises ClientException: HTTP HEAD request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url) + path = '%s/%s' % (parsed.path, quote(container)) + method = 'HEAD' + req_headers = {'X-Auth-Token': token} + if headers: + req_headers.update(headers) + conn.request(method, path, '', req_headers) + resp = conn.getresponse() + body = resp.read() + http_log(('%s?%s' % (url, path), method,), + {'headers': req_headers}, resp, body) + + if resp.status < 200 or resp.status >= 300: + raise ClientException('Container HEAD failed', + http_scheme=parsed.scheme, http_host=conn.host, + http_port=conn.port, http_path=path, + http_status=resp.status, http_reason=resp.reason, + http_response_content=body) + resp_headers = {} + for header, value in resp.getheaders(): + resp_headers[header.lower()] = value + return resp_headers + + +def put_container(url, token, container, headers=None, http_conn=None): + """ + Create a container + + :param url: storage URL + :param token: auth token + :param container: container name to create + :param headers: additional headers to include in the request + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :raises ClientException: HTTP PUT request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url) + path = '%s/%s' % (parsed.path, quote(container)) + method = 'PUT' + if not headers: + headers = {} + headers['X-Auth-Token'] = token + conn.request(method, path, '', headers) + resp = conn.getresponse() + body = resp.read() + http_log(('%s?%s' % (url, path), method,), + {'headers': headers}, resp, body) + if resp.status < 200 or resp.status >= 300: + raise ClientException('Container PUT failed', + http_scheme=parsed.scheme, http_host=conn.host, + http_port=conn.port, http_path=path, + http_status=resp.status, http_reason=resp.reason, + http_response_content=body) + + +def post_container(url, token, container, headers, http_conn=None): + """ + Update a container's metadata. + + :param url: storage URL + :param token: auth token + :param container: container name to update + :param headers: additional headers to include in the request + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :raises ClientException: HTTP POST request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url) + path = '%s/%s' % (parsed.path, quote(container)) + method = 'POST' + headers['X-Auth-Token'] = token + conn.request(method, path, '', headers) + resp = conn.getresponse() + body = resp.read() + http_log(('%s?%s' % (url, path), method,), + {'headers': headers}, resp, body) + if resp.status < 200 or resp.status >= 300: + raise ClientException('Container POST failed', + http_scheme=parsed.scheme, http_host=conn.host, + http_port=conn.port, http_path=path, + http_status=resp.status, http_reason=resp.reason, + http_response_content=body) + + +def delete_container(url, token, container, http_conn=None): + """ + Delete a container + + :param url: storage URL + :param token: auth token + :param container: container name to delete + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :raises ClientException: HTTP DELETE request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url) + path = '%s/%s' % (parsed.path, quote(container)) + headers = {'X-Auth-Token': token} + method = 'DELETE' + conn.request(method, path, '', headers) + resp = conn.getresponse() + body = resp.read() + http_log(('%s?%s' % (url, path), method,), + {'headers': headers}, resp, body) + if resp.status < 200 or resp.status >= 300: + raise ClientException('Container DELETE failed', + http_scheme=parsed.scheme, http_host=conn.host, + http_port=conn.port, http_path=path, + http_status=resp.status, http_reason=resp.reason, + http_response_content=body) + + +def get_object(url, token, container, name, http_conn=None, + resp_chunk_size=None): + """ + Get an object + + :param url: storage URL + :param token: auth token + :param container: container name that the object is in + :param name: object name to get + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :param resp_chunk_size: if defined, chunk size of data to read. NOTE: If + you specify a resp_chunk_size you must fully read + the object's contents before making another + request. + :returns: a tuple of (response headers, the object's contents) The response + headers will be a dict and all header names will be lowercase. + :raises ClientException: HTTP GET request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url) + path = '%s/%s/%s' % (parsed.path, quote(container), quote(name)) + method = 'GET' + headers = {'X-Auth-Token': token} + conn.request(method, path, '', headers) + resp = conn.getresponse() + if resp.status < 200 or resp.status >= 300: + body = resp.read() + http_log(('%s?%s' % (url, path), 'POST',), + {'headers': headers}, resp, body) + raise ClientException('Object GET failed', http_scheme=parsed.scheme, + http_host=conn.host, http_port=conn.port, + http_path=path, http_status=resp.status, + http_reason=resp.reason, + http_response_content=body) + if resp_chunk_size: + + def _object_body(): + buf = resp.read(resp_chunk_size) + while buf: + yield buf + buf = resp.read(resp_chunk_size) + object_body = _object_body() + else: + object_body = resp.read() + resp_headers = {} + for header, value in resp.getheaders(): + resp_headers[header.lower()] = value + http_log(('%s?%s' % (url, path), 'POST',), + {'headers': headers}, resp, object_body) + return resp_headers, object_body + + +def head_object(url, token, container, name, http_conn=None): + """ + Get object info + + :param url: storage URL + :param token: auth token + :param container: container name that the object is in + :param name: object name to get info for + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :returns: a dict containing the response's headers (all header names will + be lowercase) + :raises ClientException: HTTP HEAD request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url) + path = '%s/%s/%s' % (parsed.path, quote(container), quote(name)) + method = 'HEAD' + headers = {'X-Auth-Token': token} + conn.request(method, path, '', headers) + resp = conn.getresponse() + body = resp.read() + http_log(('%s?%s' % (url, path), 'POST',), + {'headers': headers}, resp, body) + if resp.status < 200 or resp.status >= 300: + raise ClientException('Object HEAD failed', http_scheme=parsed.scheme, + http_host=conn.host, http_port=conn.port, + http_path=path, http_status=resp.status, + http_reason=resp.reason, + http_response_content=body) + resp_headers = {} + for header, value in resp.getheaders(): + resp_headers[header.lower()] = value + return resp_headers + + +def put_object(url, token=None, container=None, name=None, contents=None, + content_length=None, etag=None, chunk_size=65536, + content_type=None, headers=None, http_conn=None, proxy=None): + """ + Put an object + + :param url: storage URL + :param token: auth token; if None, no token will be sent + :param container: container name that the object is in; if None, the + container name is expected to be part of the url + :param name: object name to put; if None, the object name is expected to be + part of the url + :param contents: a string or a file like object to read object data from; + if None, a zero-byte put will be done + :param content_length: value to send as content-length header; also limits + the amount read from contents; if None, it will be + computed via the contents or chunked transfer + encoding will be used + :param etag: etag of contents; if None, no etag will be sent + :param chunk_size: chunk size of data to write; default 65536 + :param content_type: value to send as content-type header; if None, no + content-type will be set (remote end will likely try + to auto-detect it) + :param headers: additional headers to include in the request, if any + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :param proxy: proxy to connect through, if any; None by default; str of the + format 'http://127.0.0.1:8888' to set one + :returns: etag from server response + :raises ClientException: HTTP PUT request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url, proxy=proxy) + path = parsed.path + if container: + path = '%s/%s' % (path.rstrip('/'), quote(container)) + if name: + path = '%s/%s' % (path.rstrip('/'), quote(name)) + if headers: + headers = dict(headers) + else: + headers = {} + if token: + headers['X-Auth-Token'] = token + if etag: + headers['ETag'] = etag.strip('"') + if content_length is not None: + headers['Content-Length'] = str(content_length) + else: + for n, v in headers.iteritems(): + if n.lower() == 'content-length': + content_length = int(v) + if content_type is not None: + headers['Content-Type'] = content_type + if not contents: + headers['Content-Length'] = '0' + if hasattr(contents, 'read'): + conn.putrequest('PUT', path) + for header, value in headers.iteritems(): + conn.putheader(header, value) + if content_length is None: + conn.putheader('Transfer-Encoding', 'chunked') + conn.endheaders() + chunk = contents.read(chunk_size) + while chunk: + conn.send('%x\r\n%s\r\n' % (len(chunk), chunk)) + chunk = contents.read(chunk_size) + conn.send('0\r\n\r\n') + else: + conn.endheaders() + left = content_length + while left > 0: + size = chunk_size + if size > left: + size = left + chunk = contents.read(size) + conn.send(chunk) + left -= len(chunk) + else: + conn.request('PUT', path, contents, headers) + resp = conn.getresponse() + body = resp.read() + headers = {'X-Auth-Token': token} + http_log(('%s?%s' % (url, path), 'PUT',), + {'headers': headers}, resp, body) + if resp.status < 200 or resp.status >= 300: + raise ClientException('Object PUT failed', http_scheme=parsed.scheme, + http_host=conn.host, http_port=conn.port, + http_path=path, http_status=resp.status, + http_reason=resp.reason, + http_response_content=body) + return resp.getheader('etag', '').strip('"') + + +def post_object(url, token, container, name, headers, http_conn=None): + """ + Update object metadata + + :param url: storage URL + :param token: auth token + :param container: container name that the object is in + :param name: name of the object to update + :param headers: additional headers to include in the request + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :raises ClientException: HTTP POST request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url) + path = '%s/%s/%s' % (parsed.path, quote(container), quote(name)) + headers['X-Auth-Token'] = token + conn.request('POST', path, '', headers) + resp = conn.getresponse() + body = resp.read() + http_log(('%s?%s' % (url, path), 'POST',), + {'headers': headers}, resp, body) + if resp.status < 200 or resp.status >= 300: + raise ClientException('Object POST failed', http_scheme=parsed.scheme, + http_host=conn.host, http_port=conn.port, + http_path=path, http_status=resp.status, + http_reason=resp.reason, + http_response_content=body) + + +def delete_object(url, token=None, container=None, name=None, http_conn=None, + headers=None, proxy=None): + """ + Delete object + + :param url: storage URL + :param token: auth token; if None, no token will be sent + :param container: container name that the object is in; if None, the + container name is expected to be part of the url + :param name: object name to delete; if None, the object name is expected to + be part of the url + :param http_conn: HTTP connection object (If None, it will create the + conn object) + :param headers: additional headers to include in the request + :param proxy: proxy to connect through, if any; None by default; str of the + format 'http://127.0.0.1:8888' to set one + :raises ClientException: HTTP DELETE request failed + """ + if http_conn: + parsed, conn = http_conn + else: + parsed, conn = http_connection(url, proxy=proxy) + path = parsed.path + if container: + path = '%s/%s' % (path.rstrip('/'), quote(container)) + if name: + path = '%s/%s' % (path.rstrip('/'), quote(name)) + if headers: + headers = dict(headers) + else: + headers = {} + if token: + headers['X-Auth-Token'] = token + conn.request('DELETE', path, '', headers) + resp = conn.getresponse() + body = resp.read() + http_log(('%s?%s' % (url, path), 'POST',), + {'headers': headers}, resp, body) + if resp.status < 200 or resp.status >= 300: + raise ClientException('Object DELETE failed', + http_scheme=parsed.scheme, http_host=conn.host, + http_port=conn.port, http_path=path, + http_status=resp.status, http_reason=resp.reason, + http_response_content=body) + + +class Connection(object): + """Convenience class to make requests that will also retry the request""" + + def __init__(self, authurl, user, key, retries=5, preauthurl=None, + preauthtoken=None, snet=False, starting_backoff=1, + tenant_name=None, + auth_version="1"): + """ + :param authurl: authentication URL + :param user: user name to authenticate as + :param key: key/password to authenticate with + :param retries: Number of times to retry the request before failing + :param preauthurl: storage URL (if you have already authenticated) + :param preauthtoken: authentication token (if you have already + authenticated) + :param snet: use SERVICENET internal network default is False + :param auth_version: OpenStack auth version, default is 1.0 + :param tenant_name: The tenant/account name, required when connecting + to a auth 2.0 system. + """ + self.authurl = authurl + self.user = user + self.key = key + self.retries = retries + self.http_conn = None + self.url = preauthurl + self.token = preauthtoken + self.attempts = 0 + self.snet = snet + self.starting_backoff = starting_backoff + self.auth_version = auth_version + self.tenant_name = tenant_name + + def get_auth(self): + return get_auth(self.authurl, self.user, + self.key, snet=self.snet, + tenant_name=self.tenant_name, + auth_version=self.auth_version) + + def http_connection(self): + return http_connection(self.url) + + def _retry(self, reset_func, func, *args, **kwargs): + self.attempts = 0 + backoff = self.starting_backoff + while self.attempts <= self.retries: + self.attempts += 1 + try: + if not self.url or not self.token: + self.url, self.token = self.get_auth() + self.http_conn = None + if not self.http_conn: + self.http_conn = self.http_connection() + kwargs['http_conn'] = self.http_conn + rv = func(self.url, self.token, *args, **kwargs) + return rv + except (socket.error, HTTPException): + if self.attempts > self.retries: + raise + self.http_conn = None + except ClientException, err: + if self.attempts > self.retries: + raise + if err.http_status == 401: + self.url = self.token = None + if self.attempts > 1: + raise + elif err.http_status == 408: + self.http_conn = None + elif 500 <= err.http_status <= 599: + pass + else: + raise + sleep(backoff) + backoff *= 2 + if reset_func: + reset_func(func, *args, **kwargs) + + def head_account(self): + """Wrapper for :func:`head_account`""" + return self._retry(None, head_account) + + def get_account(self, marker=None, limit=None, prefix=None, + full_listing=False): + """Wrapper for :func:`get_account`""" + # TODO(unknown): With full_listing=True this will restart the entire + # listing with each retry. Need to make a better version that just + # retries where it left off. + return self._retry(None, get_account, marker=marker, limit=limit, + prefix=prefix, full_listing=full_listing) + + def post_account(self, headers): + """Wrapper for :func:`post_account`""" + return self._retry(None, post_account, headers) + + def head_container(self, container): + """Wrapper for :func:`head_container`""" + return self._retry(None, head_container, container) + + def get_container(self, container, marker=None, limit=None, prefix=None, + delimiter=None, full_listing=False): + """Wrapper for :func:`get_container`""" + # TODO(unknown): With full_listing=True this will restart the entire + # listing with each retry. Need to make a better version that just + # retries where it left off. + return self._retry(None, get_container, container, marker=marker, + limit=limit, prefix=prefix, delimiter=delimiter, + full_listing=full_listing) + + def put_container(self, container, headers=None): + """Wrapper for :func:`put_container`""" + return self._retry(None, put_container, container, headers=headers) + + def post_container(self, container, headers): + """Wrapper for :func:`post_container`""" + return self._retry(None, post_container, container, headers) + + def delete_container(self, container): + """Wrapper for :func:`delete_container`""" + return self._retry(None, delete_container, container) + + def head_object(self, container, obj): + """Wrapper for :func:`head_object`""" + return self._retry(None, head_object, container, obj) + + def get_object(self, container, obj, resp_chunk_size=None): + """Wrapper for :func:`get_object`""" + return self._retry(None, get_object, container, obj, + resp_chunk_size=resp_chunk_size) + + def put_object(self, container, obj, contents, content_length=None, + etag=None, chunk_size=65536, content_type=None, + headers=None): + """Wrapper for :func:`put_object`""" + + def _default_reset(*args, **kwargs): + raise ClientException('put_object(%r, %r, ...) failure and no ' + 'ability to reset contents for reupload.' + % (container, obj)) + + reset_func = _default_reset + tell = getattr(contents, 'tell', None) + seek = getattr(contents, 'seek', None) + if tell and seek: + orig_pos = tell() + reset_func = lambda *a, **k: seek(orig_pos) + elif not contents: + reset_func = lambda *a, **k: None + + return self._retry(reset_func, put_object, container, obj, contents, + content_length=content_length, etag=etag, + chunk_size=chunk_size, content_type=content_type, + headers=headers) + + def post_object(self, container, obj, headers): + """Wrapper for :func:`post_object`""" + return self._retry(None, post_object, container, obj, headers) + + def delete_object(self, container, obj): + """Wrapper for :func:`delete_object`""" + return self._retry(None, delete_object, container, obj) diff --git a/src/leap/soledad/swiftclient/openstack/__init__.py b/src/leap/soledad/swiftclient/openstack/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/soledad/swiftclient/openstack/__init__.py diff --git a/src/leap/soledad/swiftclient/openstack/common/__init__.py b/src/leap/soledad/swiftclient/openstack/common/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/soledad/swiftclient/openstack/common/__init__.py diff --git a/src/leap/soledad/swiftclient/openstack/common/setup.py b/src/leap/soledad/swiftclient/openstack/common/setup.py new file mode 100644 index 00000000..caf06fa5 --- /dev/null +++ b/src/leap/soledad/swiftclient/openstack/common/setup.py @@ -0,0 +1,342 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Utilities with minimum-depends for use in setup.py +""" + +import datetime +import os +import re +import subprocess +import sys + +from setuptools.command import sdist + + +def parse_mailmap(mailmap='.mailmap'): + mapping = {} + if os.path.exists(mailmap): + fp = open(mailmap, 'r') + for l in fp: + l = l.strip() + if not l.startswith('#') and ' ' in l: + canonical_email, alias = l.split(' ') + mapping[alias] = canonical_email + return mapping + + +def canonicalize_emails(changelog, mapping): + """Takes in a string and an email alias mapping and replaces all + instances of the aliases in the string with their real email. + """ + for alias, email in mapping.iteritems(): + changelog = changelog.replace(alias, email) + return changelog + + +# Get requirements from the first file that exists +def get_reqs_from_files(requirements_files): + reqs_in = [] + for requirements_file in requirements_files: + if os.path.exists(requirements_file): + return open(requirements_file, 'r').read().split('\n') + return [] + + +def parse_requirements(requirements_files=['requirements.txt', + 'tools/pip-requires']): + requirements = [] + for line in get_reqs_from_files(requirements_files): + # For the requirements list, we need to inject only the portion + # after egg= so that distutils knows the package it's looking for + # such as: + # -e git://github.com/openstack/nova/master#egg=nova + if re.match(r'\s*-e\s+', line): + requirements.append(re.sub(r'\s*-e\s+.*#egg=(.*)$', r'\1', + line)) + # such as: + # http://github.com/openstack/nova/zipball/master#egg=nova + elif re.match(r'\s*https?:', line): + requirements.append(re.sub(r'\s*https?:.*#egg=(.*)$', r'\1', + line)) + # -f lines are for index locations, and don't get used here + elif re.match(r'\s*-f\s+', line): + pass + # argparse is part of the standard library starting with 2.7 + # adding it to the requirements list screws distro installs + elif line == 'argparse' and sys.version_info >= (2, 7): + pass + else: + requirements.append(line) + + return requirements + + +def parse_dependency_links(requirements_files=['requirements.txt', + 'tools/pip-requires']): + dependency_links = [] + # dependency_links inject alternate locations to find packages listed + # in requirements + for line in get_reqs_from_files(requirements_files): + # skip comments and blank lines + if re.match(r'(\s*#)|(\s*$)', line): + continue + # lines with -e or -f need the whole line, minus the flag + if re.match(r'\s*-[ef]\s+', line): + dependency_links.append(re.sub(r'\s*-[ef]\s+', '', line)) + # lines that are only urls can go in unmolested + elif re.match(r'\s*https?:', line): + dependency_links.append(line) + return dependency_links + + +def write_requirements(): + venv = os.environ.get('VIRTUAL_ENV', None) + if venv is not None: + with open("requirements.txt", "w") as req_file: + output = subprocess.Popen(["pip", "-E", venv, "freeze", "-l"], + stdout=subprocess.PIPE) + requirements = output.communicate()[0].strip() + req_file.write(requirements) + + +def _run_shell_command(cmd): + output = subprocess.Popen(["/bin/sh", "-c", cmd], + stdout=subprocess.PIPE) + out = output.communicate() + if len(out) == 0: + return None + if len(out[0].strip()) == 0: + return None + return out[0].strip() + + +def _get_git_next_version_suffix(branch_name): + datestamp = datetime.datetime.now().strftime('%Y%m%d') + if branch_name == 'milestone-proposed': + revno_prefix = "r" + else: + revno_prefix = "" + _run_shell_command("git fetch origin +refs/meta/*:refs/remotes/meta/*") + milestone_cmd = "git show meta/openstack/release:%s" % branch_name + milestonever = _run_shell_command(milestone_cmd) + if not milestonever: + milestonever = "" + post_version = _get_git_post_version() + revno = post_version.split(".")[-1] + return "%s~%s.%s%s" % (milestonever, datestamp, revno_prefix, revno) + + +def _get_git_current_tag(): + return _run_shell_command("git tag --contains HEAD") + + +def _get_git_tag_info(): + return _run_shell_command("git describe --tags") + + +def _get_git_post_version(): + current_tag = _get_git_current_tag() + if current_tag is not None: + return current_tag + else: + tag_info = _get_git_tag_info() + if tag_info is None: + base_version = "0.0" + cmd = "git --no-pager log --oneline" + out = _run_shell_command(cmd) + revno = len(out.split("\n")) + else: + tag_infos = tag_info.split("-") + base_version = "-".join(tag_infos[:-2]) + revno = tag_infos[-2] + return "%s.%s" % (base_version, revno) + + +def write_git_changelog(): + """Write a changelog based on the git changelog.""" + if os.path.isdir('.git'): + git_log_cmd = 'git log --stat' + changelog = _run_shell_command(git_log_cmd) + mailmap = parse_mailmap() + with open("ChangeLog", "w") as changelog_file: + changelog_file.write(canonicalize_emails(changelog, mailmap)) + + +def generate_authors(): + """Create AUTHORS file using git commits.""" + jenkins_email = 'jenkins@review.openstack.org' + old_authors = 'AUTHORS.in' + new_authors = 'AUTHORS' + if os.path.isdir('.git'): + # don't include jenkins email address in AUTHORS file + git_log_cmd = ("git log --format='%aN <%aE>' | sort -u | " + "grep -v " + jenkins_email) + changelog = _run_shell_command(git_log_cmd) + mailmap = parse_mailmap() + with open(new_authors, 'w') as new_authors_fh: + new_authors_fh.write(canonicalize_emails(changelog, mailmap)) + if os.path.exists(old_authors): + with open(old_authors, "r") as old_authors_fh: + new_authors_fh.write('\n' + old_authors_fh.read()) + +_rst_template = """%(heading)s +%(underline)s + +.. automodule:: %(module)s + :members: + :undoc-members: + :show-inheritance: +""" + + +def read_versioninfo(project): + """Read the versioninfo file. If it doesn't exist, we're in a github + zipball, and there's really know way to know what version we really + are, but that should be ok, because the utility of that should be + just about nil if this code path is in use in the first place.""" + versioninfo_path = os.path.join(project, 'versioninfo') + if os.path.exists(versioninfo_path): + with open(versioninfo_path, 'r') as vinfo: + version = vinfo.read().strip() + else: + version = "0.0.0" + return version + + +def write_versioninfo(project, version): + """Write a simple file containing the version of the package.""" + open(os.path.join(project, 'versioninfo'), 'w').write("%s\n" % version) + + +def get_cmdclass(): + """Return dict of commands to run from setup.py.""" + + cmdclass = dict() + + def _find_modules(arg, dirname, files): + for filename in files: + if filename.endswith('.py') and filename != '__init__.py': + arg["%s.%s" % (dirname.replace('/', '.'), + filename[:-3])] = True + + class LocalSDist(sdist.sdist): + """Builds the ChangeLog and Authors files from VC first.""" + + def run(self): + write_git_changelog() + generate_authors() + # sdist.sdist is an old style class, can't use super() + sdist.sdist.run(self) + + cmdclass['sdist'] = LocalSDist + + # If Sphinx is installed on the box running setup.py, + # enable setup.py to build the documentation, otherwise, + # just ignore it + try: + from sphinx.setup_command import BuildDoc + + class LocalBuildDoc(BuildDoc): + def generate_autoindex(self): + print "**Autodocumenting from %s" % os.path.abspath(os.curdir) + modules = {} + option_dict = self.distribution.get_option_dict('build_sphinx') + source_dir = os.path.join(option_dict['source_dir'][1], 'api') + if not os.path.exists(source_dir): + os.makedirs(source_dir) + for pkg in self.distribution.packages: + if '.' not in pkg: + os.path.walk(pkg, _find_modules, modules) + module_list = modules.keys() + module_list.sort() + autoindex_filename = os.path.join(source_dir, 'autoindex.rst') + with open(autoindex_filename, 'w') as autoindex: + autoindex.write(""".. toctree:: + :maxdepth: 1 + +""") + for module in module_list: + output_filename = os.path.join(source_dir, + "%s.rst" % module) + heading = "The :mod:`%s` Module" % module + underline = "=" * len(heading) + values = dict(module=module, heading=heading, + underline=underline) + + print "Generating %s" % output_filename + with open(output_filename, 'w') as output_file: + output_file.write(_rst_template % values) + autoindex.write(" %s.rst\n" % module) + + def run(self): + if not os.getenv('SPHINX_DEBUG'): + self.generate_autoindex() + + for builder in ['html', 'man']: + self.builder = builder + self.finalize_options() + self.project = self.distribution.get_name() + self.version = self.distribution.get_version() + self.release = self.distribution.get_version() + BuildDoc.run(self) + cmdclass['build_sphinx'] = LocalBuildDoc + except ImportError: + pass + + return cmdclass + + +def get_git_branchname(): + for branch in _run_shell_command("git branch --color=never").split("\n"): + if branch.startswith('*'): + _branch_name = branch.split()[1].strip() + if _branch_name == "(no": + _branch_name = "no-branch" + return _branch_name + + +def get_pre_version(projectname, base_version): + """Return a version which is based""" + if os.path.isdir('.git'): + current_tag = _get_git_current_tag() + if current_tag is not None: + version = current_tag + else: + branch_name = os.getenv('BRANCHNAME', + os.getenv('GERRIT_REFNAME', + get_git_branchname())) + version_suffix = _get_git_next_version_suffix(branch_name) + version = "%s~%s" % (base_version, version_suffix) + write_versioninfo(projectname, version) + return version.split('~')[0] + else: + version = read_versioninfo(projectname) + return version.split('~')[0] + + +def get_post_version(projectname): + """Return a version which is equal to the tag that's on the current + revision if there is one, or tag plus number of additional revisions + if the current revision has no tag.""" + + if os.path.isdir('.git'): + version = _get_git_post_version() + write_versioninfo(projectname, version) + return version + return read_versioninfo(projectname) diff --git a/src/leap/soledad/swiftclient/versioninfo b/src/leap/soledad/swiftclient/versioninfo new file mode 100644 index 00000000..524cb552 --- /dev/null +++ b/src/leap/soledad/swiftclient/versioninfo @@ -0,0 +1 @@ +1.1.1 diff --git a/src/leap/soledad/u1db/__init__.py b/src/leap/soledad/u1db/__init__.py new file mode 100644 index 00000000..ed41bb03 --- /dev/null +++ b/src/leap/soledad/u1db/__init__.py @@ -0,0 +1,697 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""U1DB""" + +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db.errors import InvalidJSON, InvalidContent + +__version_info__ = (0, 1, 4) +__version__ = '.'.join(map(str, __version_info__)) + + +def open(path, create, document_factory=None): + """Open a database at the given location. + + Will raise u1db.errors.DatabaseDoesNotExist if create=False and the + database does not already exist. + + :param path: The filesystem path for the database to open. + :param create: True/False, should the database be created if it doesn't + already exist? + :param document_factory: A function that will be called with the same + parameters as Document.__init__. + :return: An instance of Database. + """ + from u1db.backends import sqlite_backend + return sqlite_backend.SQLiteDatabase.open_database( + path, create=create, document_factory=document_factory) + + +# constraints on database names (relevant for remote access, as regex) +DBNAME_CONSTRAINTS = r"[a-zA-Z0-9][a-zA-Z0-9.-]*" + +# constraints on doc ids (as regex) +# (no slashes, and no characters outside the ascii range) +DOC_ID_CONSTRAINTS = r"[a-zA-Z0-9.%_-]+" + + +class Database(object): + """A JSON Document data store. + + This data store can be synchronized with other u1db.Database instances. + """ + + def set_document_factory(self, factory): + """Set the document factory that will be used to create objects to be + returned as documents by the database. + + :param factory: A function that returns an object which at minimum must + satisfy the same interface as does the class DocumentBase. + Subclassing that class is the easiest way to create such + a function. + """ + raise NotImplementedError(self.set_document_factory) + + def set_document_size_limit(self, limit): + """Set the maximum allowed document size for this database. + + :param limit: Maximum allowed document size in bytes. + """ + raise NotImplementedError(self.set_document_size_limit) + + def whats_changed(self, old_generation=0): + """Return a list of documents that have changed since old_generation. + This allows APPS to only store a db generation before going + 'offline', and then when coming back online they can use this + data to update whatever extra data they are storing. + + :param old_generation: The generation of the database in the old + state. + :return: (generation, trans_id, [(doc_id, generation, trans_id),...]) + The current generation of the database, its associated transaction + id, and a list of of changed documents since old_generation, + represented by tuples with for each document its doc_id and the + generation and transaction id corresponding to the last intervening + change and sorted by generation (old changes first) + """ + raise NotImplementedError(self.whats_changed) + + def get_doc(self, doc_id, include_deleted=False): + """Get the JSON string for the given document. + + :param doc_id: The unique document identifier + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise asking for a deleted + document will return None. + :return: a Document object. + """ + raise NotImplementedError(self.get_doc) + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + """Get the JSON content for many documents. + + :param doc_ids: A list of document identifiers. + :param check_for_conflicts: If set to False, then the conflict check + will be skipped, and 'None' will be returned instead of True/False. + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise deleted documents will not + be included in the results. + :return: iterable giving the Document object for each document id + in matching doc_ids order. + """ + raise NotImplementedError(self.get_docs) + + def get_all_docs(self, include_deleted=False): + """Get the JSON content for all documents in the database. + + :param include_deleted: If set to True, deleted documents will be + returned with empty content. Otherwise deleted documents will not + be included in the results. + :return: (generation, [Document]) + The current generation of the database, followed by a list of all + the documents in the database. + """ + raise NotImplementedError(self.get_all_docs) + + def create_doc(self, content, doc_id=None): + """Create a new document. + + You can optionally specify the document identifier, but the document + must not already exist. See 'put_doc' if you want to override an + existing document. + If the database specifies a maximum document size and the document + exceeds it, create will fail and raise a DocumentTooBig exception. + + :param content: A Python dictionary. + :param doc_id: An optional identifier specifying the document id. + :return: Document + """ + raise NotImplementedError(self.create_doc) + + def create_doc_from_json(self, json, doc_id=None): + """Create a new document. + + You can optionally specify the document identifier, but the document + must not already exist. See 'put_doc' if you want to override an + existing document. + If the database specifies a maximum document size and the document + exceeds it, create will fail and raise a DocumentTooBig exception. + + :param json: The JSON document string + :param doc_id: An optional identifier specifying the document id. + :return: Document + """ + raise NotImplementedError(self.create_doc_from_json) + + def put_doc(self, doc): + """Update a document. + If the document currently has conflicts, put will fail. + If the database specifies a maximum document size and the document + exceeds it, put will fail and raise a DocumentTooBig exception. + + :param doc: A Document with new content. + :return: new_doc_rev - The new revision identifier for the document. + The Document object will also be updated. + """ + raise NotImplementedError(self.put_doc) + + def delete_doc(self, doc): + """Mark a document as deleted. + Will abort if the current revision doesn't match doc.rev. + This will also set doc.content to None. + """ + raise NotImplementedError(self.delete_doc) + + def create_index(self, index_name, *index_expressions): + """Create an named index, which can then be queried for future lookups. + Creating an index which already exists is not an error, and is cheap. + Creating an index which does not match the index_expressions of the + existing index is an error. + Creating an index will block until the expressions have been evaluated + and the index generated. + + :param index_name: A unique name which can be used as a key prefix + :param index_expressions: index expressions defining the index + information. + + Examples: + + "fieldname", or "fieldname.subfieldname" to index alphabetically + sorted on the contents of a field. + + "number(fieldname, width)", "lower(fieldname)" + """ + raise NotImplementedError(self.create_index) + + def delete_index(self, index_name): + """Remove a named index. + + :param index_name: The name of the index we are removing + """ + raise NotImplementedError(self.delete_index) + + def list_indexes(self): + """List the definitions of all known indexes. + + :return: A list of [('index-name', ['field', 'field2'])] definitions. + """ + raise NotImplementedError(self.list_indexes) + + def get_from_index(self, index_name, *key_values): + """Return documents that match the keys supplied. + + You must supply exactly the same number of values as have been defined + in the index. It is possible to do a prefix match by using '*' to + indicate a wildcard match. You can only supply '*' to trailing entries, + (eg 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) + It is also possible to append a '*' to the last supplied value (eg + 'val*', '*', '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + + :param index_name: The index to query + :param key_values: values to match. eg, if you have + an index with 3 fields then you would have: + get_from_index(index_name, val1, val2, val3) + :return: List of [Document] + """ + raise NotImplementedError(self.get_from_index) + + def get_range_from_index(self, index_name, start_value, end_value): + """Return documents that fall within the specified range. + + Both ends of the range are inclusive. For both start_value and + end_value, one must supply exactly the same number of values as have + been defined in the index, or pass None. In case of a single column + index, a string is accepted as an alternative for a tuple with a single + value. It is possible to do a prefix match by using '*' to indicate + a wildcard match. You can only supply '*' to trailing entries, (eg + 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) It is also + possible to append a '*' to the last supplied value (eg 'val*', '*', + '*' or 'val', 'val*', '*', but not 'val*', 'val', '*') + + :param index_name: The index to query + :param start_values: tuples of values that define the lower bound of + the range. eg, if you have an index with 3 fields then you would + have: (val1, val2, val3) + :param end_values: tuples of values that define the upper bound of the + range. eg, if you have an index with 3 fields then you would have: + (val1, val2, val3) + :return: List of [Document] + """ + raise NotImplementedError(self.get_range_from_index) + + def get_index_keys(self, index_name): + """Return all keys under which documents are indexed in this index. + + :param index_name: The index to query + :return: [] A list of tuples of indexed keys. + """ + raise NotImplementedError(self.get_index_keys) + + def get_doc_conflicts(self, doc_id): + """Get the list of conflicts for the given document. + + The order of the conflicts is such that the first entry is the value + that would be returned by "get_doc". + + :return: [doc] A list of the Document entries that are conflicted. + """ + raise NotImplementedError(self.get_doc_conflicts) + + def resolve_doc(self, doc, conflicted_doc_revs): + """Mark a document as no longer conflicted. + + We take the list of revisions that the client knows about that it is + superseding. This may be a different list from the actual current + conflicts, in which case only those are removed as conflicted. This + may fail if the conflict list is significantly different from the + supplied information. (sync could have happened in the background from + the time you GET_DOC_CONFLICTS until the point where you RESOLVE) + + :param doc: A Document with the new content to be inserted. + :param conflicted_doc_revs: A list of revisions that the new content + supersedes. + """ + raise NotImplementedError(self.resolve_doc) + + def get_sync_target(self): + """Return a SyncTarget object, for another u1db to synchronize with. + + :return: An instance of SyncTarget. + """ + raise NotImplementedError(self.get_sync_target) + + def close(self): + """Release any resources associated with this database.""" + raise NotImplementedError(self.close) + + def sync(self, url, creds=None, autocreate=True): + """Synchronize documents with remote replica exposed at url. + + :param url: the url of the target replica to sync with. + :param creds: optional dictionary giving credentials + to authorize the operation with the server. For using OAuth + the form of creds is: + {'oauth': { + 'consumer_key': ..., + 'consumer_secret': ..., + 'token_key': ..., + 'token_secret': ... + }} + :param autocreate: ask the target to create the db if non-existent. + :return: local_gen_before_sync The local generation before the + synchronisation was performed. This is useful to pass into + whatschanged, if an application wants to know which documents were + affected by a synchronisation. + """ + from u1db.sync import Synchronizer + from u1db.remote.http_target import HTTPSyncTarget + return Synchronizer(self, HTTPSyncTarget(url, creds=creds)).sync( + autocreate=autocreate) + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + """Return the last known generation and transaction id for the other db + replica. + + When you do a synchronization with another replica, the Database keeps + track of what generation the other database replica was at, and what + the associated transaction id was. This is used to determine what data + needs to be sent, and if two databases are claiming to be the same + replica. + + :param other_replica_uid: The identifier for the other replica. + :return: (gen, trans_id) The generation and transaction id we + encountered during synchronization. If we've never synchronized + with the replica, this is (0, ''). + """ + raise NotImplementedError(self._get_replica_gen_and_trans_id) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + """Set the last-known generation and transaction id for the other + database replica. + + We have just performed some synchronization, and we want to track what + generation the other replica was at. See also + _get_replica_gen_and_trans_id. + :param other_replica_uid: The U1DB identifier for the other replica. + :param other_generation: The generation number for the other replica. + :param other_transaction_id: The transaction id associated with the + generation. + """ + raise NotImplementedError(self._set_replica_gen_and_trans_id) + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, + replica_trans_id=''): + """Insert/update document into the database with a given revision. + + This api is used during synchronization operations. + + If a document would conflict and save_conflict is set to True, the + content will be selected as the 'current' content for doc.doc_id, + even though doc.rev doesn't supersede the currently stored revision. + The currently stored document will be added to the list of conflict + alternatives for the given doc_id. + + This forces the new content to be 'current' so that we get convergence + after synchronizing, even if people don't resolve conflicts. Users can + then notice that their content is out of date, update it, and + synchronize again. (The alternative is that users could synchronize and + think the data has propagated, but their local copy looks fine, and the + remote copy is never updated again.) + + :param doc: A Document object + :param save_conflict: If this document is a conflict, do you want to + save it as a conflict, or just ignore it. + :param replica_uid: A unique replica identifier. + :param replica_gen: The generation of the replica corresponding to the + this document. The replica arguments are optional, but are used + during synchronization. + :param replica_trans_id: The transaction_id associated with the + generation. + :return: (state, at_gen) - If we don't have doc_id already, + or if doc_rev supersedes the existing document revision, + then the content will be inserted, and state is 'inserted'. + If doc_rev is less than or equal to the existing revision, + then the put is ignored and state is respecitvely 'superseded' + or 'converged'. + If doc_rev is not strictly superseded or supersedes, then + state is 'conflicted'. The document will not be inserted if + save_conflict is False. + For 'inserted' or 'converged', at_gen is the insertion/current + generation. + """ + raise NotImplementedError(self._put_doc_if_newer) + + +class DocumentBase(object): + """Container for handling a single document. + + :ivar doc_id: Unique identifier for this document. + :ivar rev: The revision identifier of the document. + :ivar json_string: The JSON string for this document. + :ivar has_conflicts: Boolean indicating if this document has conflicts + """ + + def __init__(self, doc_id, rev, json_string, has_conflicts=False): + self.doc_id = doc_id + self.rev = rev + if json_string is not None: + try: + value = json.loads(json_string) + except json.JSONDecodeError: + raise InvalidJSON + if not isinstance(value, dict): + raise InvalidJSON + self._json = json_string + self.has_conflicts = has_conflicts + + def same_content_as(self, other): + """Compare the content of two documents.""" + if self._json: + c1 = json.loads(self._json) + else: + c1 = None + if other._json: + c2 = json.loads(other._json) + else: + c2 = None + return c1 == c2 + + def __repr__(self): + if self.has_conflicts: + extra = ', conflicted' + else: + extra = '' + return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id, + self.rev, extra, self.get_json()) + + def __hash__(self): + raise NotImplementedError(self.__hash__) + + def __eq__(self, other): + if not isinstance(other, Document): + return NotImplemented + return ( + self.doc_id == other.doc_id and self.rev == other.rev and + self.same_content_as(other) and self.has_conflicts == + other.has_conflicts) + + def __lt__(self, other): + """This is meant for testing, not part of the official api. + + It is implemented so that sorted([Document, Document]) can be used. + It doesn't imply that users would want their documents to be sorted in + this order. + """ + # Since this is just for testing, we don't worry about comparing + # against things that aren't a Document. + return ((self.doc_id, self.rev, self.get_json()) + < (other.doc_id, other.rev, other.get_json())) + + def get_json(self): + """Get the json serialization of this document.""" + if self._json is not None: + return self._json + return None + + def get_size(self): + """Calculate the total size of the document.""" + size = 0 + json = self.get_json() + if json: + size += len(json) + if self.rev: + size += len(self.rev) + if self.doc_id: + size += len(self.doc_id) + return size + + def set_json(self, json_string): + """Set the json serialization of this document.""" + if json_string is not None: + try: + value = json.loads(json_string) + except json.JSONDecodeError: + raise InvalidJSON + if not isinstance(value, dict): + raise InvalidJSON + self._json = json_string + + def make_tombstone(self): + """Make this document into a tombstone.""" + self._json = None + + def is_tombstone(self): + """Return True if the document is a tombstone, False otherwise.""" + if self._json is not None: + return False + return True + + +class Document(DocumentBase): + """Container for handling a single document. + + :ivar doc_id: Unique identifier for this document. + :ivar rev: The revision identifier of the document. + :ivar json: The JSON string for this document. + :ivar has_conflicts: Boolean indicating if this document has conflicts + """ + + # The following part of the API is optional: no implementation is forced to + # have it but if the language supports dictionaries/hashtables, it makes + # Documents a lot more user friendly. + + def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False): + # TODO: We convert the json in the superclass to check its validity so + # we might as well set _content here directly since the price is + # already being paid. + super(Document, self).__init__(doc_id, rev, json, has_conflicts) + self._content = None + + def same_content_as(self, other): + """Compare the content of two documents.""" + if self._json: + c1 = json.loads(self._json) + else: + c1 = self._content + if other._json: + c2 = json.loads(other._json) + else: + c2 = other._content + return c1 == c2 + + def get_json(self): + """Get the json serialization of this document.""" + json_string = super(Document, self).get_json() + if json_string is not None: + return json_string + if self._content is not None: + return json.dumps(self._content) + return None + + def set_json(self, json): + """Set the json serialization of this document.""" + self._content = None + super(Document, self).set_json(json) + + def make_tombstone(self): + """Make this document into a tombstone.""" + self._content = None + super(Document, self).make_tombstone() + + def is_tombstone(self): + """Return True if the document is a tombstone, False otherwise.""" + if self._content is not None: + return False + return super(Document, self).is_tombstone() + + def _get_content(self): + """Get the dictionary representing this document.""" + if self._json is not None: + self._content = json.loads(self._json) + self._json = None + if self._content is not None: + return self._content + return None + + def _set_content(self, content): + """Set the dictionary representing this document.""" + try: + tmp = json.dumps(content) + except TypeError: + raise InvalidContent( + "Can not be converted to JSON: %r" % (content,)) + if not tmp.startswith('{'): + raise InvalidContent( + "Can not be converted to a JSON object: %r." % (content,)) + # We might as well store the JSON at this point since we did the work + # of encoding it, and it doesn't lose any information. + self._json = tmp + self._content = None + + content = property( + _get_content, _set_content, doc="Content of the Document.") + + # End of optional part. + + +class SyncTarget(object): + """Functionality for using a Database as a synchronization target.""" + + def get_sync_info(self, source_replica_uid): + """Return information about known state. + + Return the replica_uid and the current database generation of this + database, and the last-seen database generation for source_replica_uid + + :param source_replica_uid: Another replica which we might have + synchronized with in the past. + :return: (target_replica_uid, target_replica_generation, + target_trans_id, source_replica_last_known_generation, + source_replica_last_known_transaction_id) + """ + raise NotImplementedError(self.get_sync_info) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + """Record tip information for another replica. + + After sync_exchange has been processed, the caller will have + received new content from this replica. This call allows the + source replica instigating the sync to inform us what their + generation became after applying the documents we returned. + + This is used to allow future sync operations to not need to repeat data + that we just talked about. It also means that if this is called at the + wrong time, there can be database records that will never be + synchronized. + + :param source_replica_uid: The identifier for the source replica. + :param source_replica_generation: + The database generation for the source replica. + :param source_replica_transaction_id: The transaction id associated + with the source replica generation. + """ + raise NotImplementedError(self.record_sync_info) + + def sync_exchange(self, docs_by_generation, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + """Incorporate the documents sent from the source replica. + + This is not meant to be called by client code directly, but is used as + part of sync(). + + This adds docs to the local store, and determines documents that need + to be returned to the source replica. + + Documents must be supplied in docs_by_generation paired with + the generation of their latest change in order from the oldest + change to the newest, that means from the oldest generation to + the newest. + + Documents are also returned paired with the generation of + their latest change in order from the oldest change to the + newest. + + :param docs_by_generation: A list of [(Document, generation, + transaction_id)] tuples indicating documents which should be + updated on this replica paired with the generation and transaction + id of their latest change. + :param source_replica_uid: The source replica's identifier + :param last_known_generation: The last generation that the source + replica knows about this target replica + :param last_known_trans_id: The last transaction id that the source + replica knows about this target replica + :param: return_doc_cb(doc, gen): is a callback + used to return documents to the source replica, it will + be invoked in turn with Documents that have changed since + last_known_generation together with the generation of + their last change. + :param: ensure_callback(replica_uid): if set the target may create + the target db if not yet existent, the callback can then + be used to inform of the created db replica uid. + :return: new_generation - After applying docs_by_generation, this is + the current generation for this replica + """ + raise NotImplementedError(self.sync_exchange) + + def _set_trace_hook(self, cb): + """Set a callback that will be invoked to trace database actions. + + The callback will be passed a string indicating the current state, and + the sync target object. Implementations do not have to implement this + api, it is used by the test suite. + + :param cb: A callable that takes cb(state) + """ + raise NotImplementedError(self._set_trace_hook) + + def _set_trace_hook_shallow(self, cb): + """Set a callback that will be invoked to trace database actions. + + Similar to _set_trace_hook, for implementations that don't offer + state changes from the inner working of sync_exchange(). + + :param cb: A callable that takes cb(state) + """ + self._set_trace_hook(cb) diff --git a/src/leap/soledad/u1db/backends/__init__.py b/src/leap/soledad/u1db/backends/__init__.py new file mode 100644 index 00000000..c8e5adc6 --- /dev/null +++ b/src/leap/soledad/u1db/backends/__init__.py @@ -0,0 +1,211 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Abstract classes and common implementations for the backends.""" + +import re +try: + import simplejson as json +except ImportError: + import json # noqa +import uuid + +import u1db +from u1db import ( + errors, +) +import u1db.sync +from u1db.vectorclock import VectorClockRev + + +check_doc_id_re = re.compile("^" + u1db.DOC_ID_CONSTRAINTS + "$", re.UNICODE) + + +class CommonSyncTarget(u1db.sync.LocalSyncTarget): + pass + + +class CommonBackend(u1db.Database): + + document_size_limit = 0 + + def _allocate_doc_id(self): + """Generate a unique identifier for this document.""" + return 'D-' + uuid.uuid4().hex # 'D-' stands for document + + def _allocate_transaction_id(self): + return 'T-' + uuid.uuid4().hex # 'T-' stands for transaction + + def _allocate_doc_rev(self, old_doc_rev): + vcr = VectorClockRev(old_doc_rev) + vcr.increment(self._replica_uid) + return vcr.as_str() + + def _check_doc_id(self, doc_id): + if not check_doc_id_re.match(doc_id): + raise errors.InvalidDocId() + + def _check_doc_size(self, doc): + if not self.document_size_limit: + return + if doc.get_size() > self.document_size_limit: + raise errors.DocumentTooBig + + def _get_generation(self): + """Return the current generation. + + """ + raise NotImplementedError(self._get_generation) + + def _get_generation_info(self): + """Return the current generation and transaction id. + + """ + raise NotImplementedError(self._get_generation_info) + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Extract the document from storage. + + This can return None if the document doesn't exist. + """ + raise NotImplementedError(self._get_doc) + + def _has_conflicts(self, doc_id): + """Return True if the doc has conflicts, False otherwise.""" + raise NotImplementedError(self._has_conflicts) + + def create_doc(self, content, doc_id=None): + json_string = json.dumps(content) + if doc_id is None: + doc_id = self._allocate_doc_id() + doc = self._factory(doc_id, None, json_string) + self.put_doc(doc) + return doc + + def create_doc_from_json(self, json, doc_id=None): + if doc_id is None: + doc_id = self._allocate_doc_id() + doc = self._factory(doc_id, None, json) + self.put_doc(doc) + return doc + + def _get_transaction_log(self): + """This is only for the test suite, it is not part of the api.""" + raise NotImplementedError(self._get_transaction_log) + + def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content): + raise NotImplementedError(self._put_and_update_indexes) + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + for doc_id in doc_ids: + doc = self._get_doc( + doc_id, check_for_conflicts=check_for_conflicts) + if doc.is_tombstone() and not include_deleted: + continue + yield doc + + def _get_trans_id_for_gen(self, generation): + """Get the transaction id corresponding to a particular generation. + + Raises an InvalidGeneration when the generation does not exist. + + """ + raise NotImplementedError(self._get_trans_id_for_gen) + + def validate_gen_and_trans_id(self, generation, trans_id): + """Validate the generation and transaction id. + + Raises an InvalidGeneration when the generation does not exist, and an + InvalidTransactionId when it does but with a different transaction id. + + """ + if generation == 0: + return + known_trans_id = self._get_trans_id_for_gen(generation) + if known_trans_id != trans_id: + raise errors.InvalidTransactionId + + def _validate_source(self, other_replica_uid, other_generation, + other_transaction_id): + """Validate the new generation and transaction id. + + other_generation must be greater than what we have stored for this + replica, *or* it must be the same and the transaction_id must be the + same as well. + """ + (old_generation, + old_transaction_id) = self._get_replica_gen_and_trans_id( + other_replica_uid) + if other_generation < old_generation: + raise errors.InvalidGeneration + if other_generation > old_generation: + return + if other_transaction_id == old_transaction_id: + return + raise errors.InvalidTransactionId + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen, + replica_trans_id=''): + cur_doc = self._get_doc(doc.doc_id) + doc_vcr = VectorClockRev(doc.rev) + if cur_doc is None: + cur_vcr = VectorClockRev(None) + else: + cur_vcr = VectorClockRev(cur_doc.rev) + self._validate_source(replica_uid, replica_gen, replica_trans_id) + if doc_vcr.is_newer(cur_vcr): + rev = doc.rev + self._prune_conflicts(doc, doc_vcr) + if doc.rev != rev: + # conflicts have been autoresolved + state = 'superseded' + else: + state = 'inserted' + self._put_and_update_indexes(cur_doc, doc) + elif doc.rev == cur_doc.rev: + # magical convergence + state = 'converged' + elif cur_vcr.is_newer(doc_vcr): + # Don't add this to seen_ids, because we have something newer, + # so we should send it back, and we should not generate a + # conflict + state = 'superseded' + elif cur_doc.same_content_as(doc): + # the documents have been edited to the same thing at both ends + doc_vcr.maximize(cur_vcr) + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + self._put_and_update_indexes(cur_doc, doc) + state = 'superseded' + else: + state = 'conflicted' + if save_conflict: + self._force_doc_sync_conflict(doc) + if replica_uid is not None and replica_gen is not None: + self._do_set_replica_gen_and_trans_id( + replica_uid, replica_gen, replica_trans_id) + return state, self._get_generation() + + def _ensure_maximal_rev(self, cur_rev, extra_revs): + vcr = VectorClockRev(cur_rev) + for rev in extra_revs: + vcr.maximize(VectorClockRev(rev)) + vcr.increment(self._replica_uid) + return vcr.as_str() + + def set_document_size_limit(self, limit): + self.document_size_limit = limit diff --git a/src/leap/soledad/u1db/backends/dbschema.sql b/src/leap/soledad/u1db/backends/dbschema.sql new file mode 100644 index 00000000..ae027fc5 --- /dev/null +++ b/src/leap/soledad/u1db/backends/dbschema.sql @@ -0,0 +1,42 @@ +-- Database schema +CREATE TABLE transaction_log ( + generation INTEGER PRIMARY KEY AUTOINCREMENT, + doc_id TEXT NOT NULL, + transaction_id TEXT NOT NULL +); +CREATE TABLE document ( + doc_id TEXT PRIMARY KEY, + doc_rev TEXT NOT NULL, + content TEXT +); +CREATE TABLE document_fields ( + doc_id TEXT NOT NULL, + field_name TEXT NOT NULL, + value TEXT +); +CREATE INDEX document_fields_field_value_doc_idx + ON document_fields(field_name, value, doc_id); + +CREATE TABLE sync_log ( + replica_uid TEXT PRIMARY KEY, + known_generation INTEGER, + known_transaction_id TEXT +); +CREATE TABLE conflicts ( + doc_id TEXT, + doc_rev TEXT, + content TEXT, + CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev) +); +CREATE TABLE index_definitions ( + name TEXT, + offset INT, + field TEXT, + CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset) +); +create index index_definitions_field on index_definitions(field); +CREATE TABLE u1db_config ( + name TEXT PRIMARY KEY, + value TEXT +); +INSERT INTO u1db_config VALUES ('sql_schema', '0'); diff --git a/src/leap/soledad/u1db/backends/inmemory.py b/src/leap/soledad/u1db/backends/inmemory.py new file mode 100644 index 00000000..a271bb37 --- /dev/null +++ b/src/leap/soledad/u1db/backends/inmemory.py @@ -0,0 +1,469 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""The in-memory Database class for U1DB.""" + +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db import ( + Document, + errors, + query_parser, + vectorclock, + ) +from u1db.backends import CommonBackend, CommonSyncTarget + + +def get_prefix(value): + key_prefix = '\x01'.join(value) + return key_prefix.rstrip('*') + + +class InMemoryDatabase(CommonBackend): + """A database that only stores the data internally.""" + + def __init__(self, replica_uid, document_factory=None): + self._transaction_log = [] + self._docs = {} + # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner' + self._conflicts = {} + self._other_generations = {} + self._indexes = {} + self._replica_uid = replica_uid + self._factory = document_factory or Document + + def _set_replica_uid(self, replica_uid): + """Force the replica_uid to be set.""" + self._replica_uid = replica_uid + + def set_document_factory(self, factory): + self._factory = factory + + def close(self): + # This is a no-op, We don't want to free the data because one client + # may be closing it, while another wants to inspect the results. + pass + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + return self._other_generations.get(other_replica_uid, (0, '')) + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + self._do_set_replica_gen_and_trans_id( + other_replica_uid, other_generation, other_transaction_id) + + def _do_set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, + other_transaction_id): + # TODO: to handle race conditions, we may want to check if the current + # value is greater than this new value. + self._other_generations[other_replica_uid] = (other_generation, + other_transaction_id) + + def get_sync_target(self): + return InMemorySyncTarget(self) + + def _get_transaction_log(self): + # snapshot! + return self._transaction_log[:] + + def _get_generation(self): + return len(self._transaction_log) + + def _get_generation_info(self): + if not self._transaction_log: + return 0, '' + return len(self._transaction_log), self._transaction_log[-1][1] + + def _get_trans_id_for_gen(self, generation): + if generation == 0: + return '' + if generation > len(self._transaction_log): + raise errors.InvalidGeneration + return self._transaction_log[generation - 1][1] + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _put_and_update_indexes(self, old_doc, doc): + for index in self._indexes.itervalues(): + if old_doc is not None and not old_doc.is_tombstone(): + index.remove_json(old_doc.doc_id, old_doc.get_json()) + if not doc.is_tombstone(): + index.add_json(doc.doc_id, doc.get_json()) + trans_id = self._allocate_transaction_id() + self._docs[doc.doc_id] = (doc.rev, doc.get_json()) + self._transaction_log.append((doc.doc_id, trans_id)) + + def _get_doc(self, doc_id, check_for_conflicts=False): + try: + doc_rev, content = self._docs[doc_id] + except KeyError: + return None + doc = self._factory(doc_id, doc_rev, content) + if check_for_conflicts: + doc.has_conflicts = (doc.doc_id in self._conflicts) + return doc + + def _has_conflicts(self, doc_id): + return doc_id in self._conflicts + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def get_all_docs(self, include_deleted=False): + """Return all documents in the database.""" + generation = self._get_generation() + results = [] + for doc_id, (doc_rev, content) in self._docs.items(): + if content is None and not include_deleted: + continue + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = self._has_conflicts(doc_id) + results.append(doc) + return (generation, results) + + def get_doc_conflicts(self, doc_id): + if doc_id not in self._conflicts: + return [] + result = [self._get_doc(doc_id)] + result[0].has_conflicts = True + result.extend([self._factory(doc_id, rev, content) + for rev, content in self._conflicts[doc_id]]) + return result + + def _replace_conflicts(self, doc, conflicts): + if not conflicts: + del self._conflicts[doc.doc_id] + else: + self._conflicts[doc.doc_id] = conflicts + doc.has_conflicts = bool(conflicts) + + def _prune_conflicts(self, doc, doc_vcr): + if self._has_conflicts(doc.doc_id): + autoresolved = False + remaining_conflicts = [] + cur_conflicts = self._conflicts[doc.doc_id] + for c_rev, c_doc in cur_conflicts: + c_vcr = vectorclock.VectorClockRev(c_rev) + if doc_vcr.is_newer(c_vcr): + continue + if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)): + doc_vcr.maximize(c_vcr) + autoresolved = True + continue + remaining_conflicts.append((c_rev, c_doc)) + if autoresolved: + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + self._replace_conflicts(doc, remaining_conflicts) + + def resolve_doc(self, doc, conflicted_doc_revs): + cur_doc = self._get_doc(doc.doc_id) + if cur_doc is None: + cur_rev = None + else: + cur_rev = cur_doc.rev + new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs) + superseded_revs = set(conflicted_doc_revs) + remaining_conflicts = [] + cur_conflicts = self._conflicts[doc.doc_id] + for c_rev, c_doc in cur_conflicts: + if c_rev in superseded_revs: + continue + remaining_conflicts.append((c_rev, c_doc)) + doc.rev = new_rev + if cur_rev in superseded_revs: + self._put_and_update_indexes(cur_doc, doc) + else: + remaining_conflicts.append((new_rev, doc.get_json())) + self._replace_conflicts(doc, remaining_conflicts) + + def delete_doc(self, doc): + if doc.doc_id not in self._docs: + raise errors.DocumentDoesNotExist + if self._docs[doc.doc_id][1] in ('null', None): + raise errors.DocumentAlreadyDeleted + doc.make_tombstone() + self.put_doc(doc) + + def create_index(self, index_name, *index_expressions): + if index_name in self._indexes: + if self._indexes[index_name]._definition == list( + index_expressions): + return + raise errors.IndexNameTakenError + index = InMemoryIndex(index_name, list(index_expressions)) + for doc_id, (doc_rev, doc) in self._docs.iteritems(): + if doc is not None: + index.add_json(doc_id, doc) + self._indexes[index_name] = index + + def delete_index(self, index_name): + del self._indexes[index_name] + + def list_indexes(self): + definitions = [] + for idx in self._indexes.itervalues(): + definitions.append((idx._name, idx._definition)) + return definitions + + def get_from_index(self, index_name, *key_values): + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + doc_ids = index.lookup(key_values) + result = [] + for doc_id in doc_ids: + result.append(self._get_doc(doc_id, check_for_conflicts=True)) + return result + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + """Return all documents with key values in the specified range.""" + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + if isinstance(start_value, basestring): + start_value = (start_value,) + if isinstance(end_value, basestring): + end_value = (end_value,) + doc_ids = index.lookup_range(start_value, end_value) + result = [] + for doc_id in doc_ids: + result.append(self._get_doc(doc_id, check_for_conflicts=True)) + return result + + def get_index_keys(self, index_name): + try: + index = self._indexes[index_name] + except KeyError: + raise errors.IndexDoesNotExist + keys = index.keys() + # XXX inefficiency warning + return list(set([tuple(key.split('\x01')) for key in keys])) + + def whats_changed(self, old_generation=0): + changes = [] + relevant_tail = self._transaction_log[old_generation:] + # We don't use len(self._transaction_log) because _transaction_log may + # get mutated by a concurrent operation. + cur_generation = old_generation + len(relevant_tail) + last_trans_id = '' + if relevant_tail: + last_trans_id = relevant_tail[-1][1] + elif self._transaction_log: + last_trans_id = self._transaction_log[-1][1] + seen = set() + generation = cur_generation + for doc_id, trans_id in reversed(relevant_tail): + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + generation -= 1 + changes.reverse() + return (cur_generation, last_trans_id, changes) + + def _force_doc_sync_conflict(self, doc): + my_doc = self._get_doc(doc.doc_id) + self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) + self._conflicts.setdefault(doc.doc_id, []).append( + (my_doc.rev, my_doc.get_json())) + doc.has_conflicts = True + self._put_and_update_indexes(my_doc, doc) + + +class InMemoryIndex(object): + """Interface for managing an Index.""" + + def __init__(self, index_name, index_definition): + self._name = index_name + self._definition = index_definition + self._values = {} + parser = query_parser.Parser() + self._getters = parser.parse_all(self._definition) + + def evaluate_json(self, doc): + """Determine the 'key' after applying this index to the doc.""" + raw = json.loads(doc) + return self.evaluate(raw) + + def evaluate(self, obj): + """Evaluate a dict object, applying this definition.""" + all_rows = [[]] + for getter in self._getters: + new_rows = [] + keys = getter.get(obj) + if not keys: + return [] + for key in keys: + new_rows.extend([row + [key] for row in all_rows]) + all_rows = new_rows + all_rows = ['\x01'.join(row) for row in all_rows] + return all_rows + + def add_json(self, doc_id, doc): + """Add this json doc to the index.""" + keys = self.evaluate_json(doc) + if not keys: + return + for key in keys: + self._values.setdefault(key, []).append(doc_id) + + def remove_json(self, doc_id, doc): + """Remove this json doc from the index.""" + keys = self.evaluate_json(doc) + if keys: + for key in keys: + doc_ids = self._values[key] + doc_ids.remove(doc_id) + if not doc_ids: + del self._values[key] + + def _find_non_wildcards(self, values): + """Check if this should be a wildcard match. + + Further, this will raise an exception if the syntax is improperly + defined. + + :return: The offset of the last value we need to match against. + """ + if len(values) != len(self._definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + last = 0 + for idx, val in enumerate(values): + if val.endswith('*'): + if val != '*': + # We have an 'x*' style wildcard + if is_wildcard: + # We were already in wildcard mode, so this is invalid + raise errors.InvalidGlobbing + last = idx + 1 + is_wildcard = True + else: + if is_wildcard: + # We were in wildcard mode, we can't follow that with + # non-wildcard + raise errors.InvalidGlobbing + last = idx + 1 + if not is_wildcard: + return -1 + return last + + def lookup(self, values): + """Find docs that match the values.""" + last = self._find_non_wildcards(values) + if last == -1: + return self._lookup_exact(values) + else: + return self._lookup_prefix(values[:last]) + + def lookup_range(self, start_values, end_values): + """Find docs within the range.""" + # TODO: Wildly inefficient, which is unlikely to be a problem for the + # inmemory implementation. + if start_values: + self._find_non_wildcards(start_values) + start_values = get_prefix(start_values) + if end_values: + if self._find_non_wildcards(end_values) == -1: + exact = True + else: + exact = False + end_values = get_prefix(end_values) + found = [] + for key, doc_ids in sorted(self._values.iteritems()): + if start_values and start_values > key: + continue + if end_values and end_values < key: + if exact: + break + else: + if not key.startswith(end_values): + break + found.extend(doc_ids) + return found + + def keys(self): + """Find the indexed keys.""" + return self._values.keys() + + def _lookup_prefix(self, value): + """Find docs that match the prefix string in values.""" + # TODO: We need a different data structure to make prefix style fast, + # some sort of sorted list would work, but a plain dict doesn't. + key_prefix = get_prefix(value) + all_doc_ids = [] + for key, doc_ids in sorted(self._values.iteritems()): + if key.startswith(key_prefix): + all_doc_ids.extend(doc_ids) + return all_doc_ids + + def _lookup_exact(self, value): + """Find docs that match exactly.""" + key = '\x01'.join(value) + if key in self._values: + return self._values[key] + return () + + +class InMemorySyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_transaction_id) diff --git a/src/leap/soledad/u1db/backends/sqlite_backend.py b/src/leap/soledad/u1db/backends/sqlite_backend.py new file mode 100644 index 00000000..773213b5 --- /dev/null +++ b/src/leap/soledad/u1db/backends/sqlite_backend.py @@ -0,0 +1,926 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""A U1DB implementation that uses SQLite as its persistence layer.""" + +import errno +import os +try: + import simplejson as json +except ImportError: + import json # noqa +from sqlite3 import dbapi2 +import sys +import time +import uuid + +import pkg_resources + +from u1db.backends import CommonBackend, CommonSyncTarget +from u1db import ( + Document, + errors, + query_parser, + vectorclock, + ) + + +class SQLiteDatabase(CommonBackend): + """A U1DB implementation that uses SQLite as its persistence layer.""" + + _sqlite_registry = {} + + def __init__(self, sqlite_file, document_factory=None): + """Create a new sqlite file.""" + self._db_handle = dbapi2.connect(sqlite_file) + self._real_replica_uid = None + self._ensure_schema() + self._factory = document_factory or Document + + def set_document_factory(self, factory): + self._factory = factory + + def get_sync_target(self): + return SQLiteSyncTarget(self) + + @classmethod + def _which_index_storage(cls, c): + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'index_storage'") + except dbapi2.OperationalError, e: + # The table does not exist yet + return None, e + else: + return c.fetchone()[0], None + + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5 + + @classmethod + def _open_database(cls, sqlite_file, document_factory=None): + if not os.path.isfile(sqlite_file): + raise errors.DatabaseDoesNotExist() + tries = 2 + while True: + # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6) + # where without re-opening the database on Windows, it + # doesn't see the transaction that was just committed + db_handle = dbapi2.connect(sqlite_file) + c = db_handle.cursor() + v, err = cls._which_index_storage(c) + db_handle.close() + if v is not None: + break + # possibly another process is initializing it, wait for it to be + # done + if tries == 0: + raise err # go for the richest error? + tries -= 1 + time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL) + return SQLiteDatabase._sqlite_registry[v]( + sqlite_file, document_factory=document_factory) + + @classmethod + def open_database(cls, sqlite_file, create, backend_cls=None, + document_factory=None): + try: + return cls._open_database( + sqlite_file, document_factory=document_factory) + except errors.DatabaseDoesNotExist: + if not create: + raise + if backend_cls is None: + # default is SQLitePartialExpandDatabase + backend_cls = SQLitePartialExpandDatabase + return backend_cls(sqlite_file, document_factory=document_factory) + + @staticmethod + def delete_database(sqlite_file): + try: + os.unlink(sqlite_file) + except OSError as ex: + if ex.errno == errno.ENOENT: + raise errors.DatabaseDoesNotExist() + raise + + @staticmethod + def register_implementation(klass): + """Register that we implement an SQLiteDatabase. + + The attribute _index_storage_value will be used as the lookup key. + """ + SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass + + def _get_sqlite_handle(self): + """Get access to the underlying sqlite database. + + This should only be used by the test suite, etc, for examining the + state of the underlying database. + """ + return self._db_handle + + def _close_sqlite_handle(self): + """Release access to the underlying sqlite database.""" + self._db_handle.close() + + def close(self): + self._close_sqlite_handle() + + def _is_initialized(self, c): + """Check if this database has been initialized.""" + c.execute("PRAGMA case_sensitive_like=ON") + try: + c.execute("SELECT value FROM u1db_config" + " WHERE name = 'sql_schema'") + except dbapi2.OperationalError: + # The table does not exist yet + val = None + else: + val = c.fetchone() + if val is not None: + return True + return False + + def _initialize(self, c): + """Create the schema in the database.""" + #read the script with sql commands + # TODO: Change how we set up the dependency. Most likely use something + # like lp:dirspec to grab the file from a common resource + # directory. Doesn't specifically need to be handled until we get + # to the point of packaging this. + schema_content = pkg_resources.resource_string( + __name__, 'dbschema.sql') + # Note: We'd like to use c.executescript() here, but it seems that + # executescript always commits, even if you set + # isolation_level = None, so if we want to properly handle + # exclusive locking and rollbacks between processes, we need + # to execute it line-by-line + for line in schema_content.split(';'): + if not line: + continue + c.execute(line) + #add extra fields + self._extra_schema_init(c) + # A unique identifier should be set for this replica. Implementations + # don't have to strictly use uuid here, but we do want the uid to be + # unique amongst all databases that will sync with each other. + # We might extend this to using something with hostname for easier + # debugging. + self._set_replica_uid_in_transaction(uuid.uuid4().hex) + c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)", + (self._index_storage_value,)) + + def _ensure_schema(self): + """Ensure that the database schema has been created.""" + old_isolation_level = self._db_handle.isolation_level + c = self._db_handle.cursor() + if self._is_initialized(c): + return + try: + # autocommit/own mgmt of transactions + self._db_handle.isolation_level = None + with self._db_handle: + # only one execution path should initialize the db + c.execute("begin exclusive") + if self._is_initialized(c): + return + self._initialize(c) + finally: + self._db_handle.isolation_level = old_isolation_level + + def _extra_schema_init(self, c): + """Add any extra fields, etc to the basic table definitions.""" + + def _parse_index_definition(self, index_field): + """Parse a field definition for an index, returning a Getter.""" + # Note: We may want to keep a Parser object around, and cache the + # Getter objects for a greater length of time. Specifically, if + # you create a bunch of indexes, and then insert 50k docs, you'll + # re-parse the indexes between puts. The time to insert the docs + # is still likely to dominate put_doc time, though. + parser = query_parser.Parser() + getter = parser.parse(index_field) + return getter + + def _update_indexes(self, doc_id, raw_doc, getters, db_cursor): + """Update document_fields for a single document. + + :param doc_id: Identifier for this document + :param raw_doc: The python dict representation of the document. + :param getters: A list of [(field_name, Getter)]. Getter.get will be + called to evaluate the index definition for this document, and the + results will be inserted into the db. + :param db_cursor: An sqlite Cursor. + :return: None + """ + values = [] + for field_name, getter in getters: + for idx_value in getter.get(raw_doc): + values.append((doc_id, field_name, idx_value)) + if values: + db_cursor.executemany( + "INSERT INTO document_fields VALUES (?, ?, ?)", values) + + def _set_replica_uid(self, replica_uid): + """Force the replica_uid to be set.""" + with self._db_handle: + self._set_replica_uid_in_transaction(replica_uid) + + def _set_replica_uid_in_transaction(self, replica_uid): + """Set the replica_uid. A transaction should already be held.""" + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO u1db_config" + " VALUES ('replica_uid', ?)", + (replica_uid,)) + self._real_replica_uid = replica_uid + + def _get_replica_uid(self): + if self._real_replica_uid is not None: + return self._real_replica_uid + c = self._db_handle.cursor() + c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'") + val = c.fetchone() + if val is None: + return None + self._real_replica_uid = val[0] + return self._real_replica_uid + + _replica_uid = property(_get_replica_uid) + + def _get_generation(self): + c = self._db_handle.cursor() + c.execute('SELECT max(generation) FROM transaction_log') + val = c.fetchone()[0] + if val is None: + return 0 + return val + + def _get_generation_info(self): + c = self._db_handle.cursor() + c.execute( + 'SELECT max(generation), transaction_id FROM transaction_log ') + val = c.fetchone() + if val[0] is None: + return(0, '') + return val + + def _get_trans_id_for_gen(self, generation): + if generation == 0: + return '' + c = self._db_handle.cursor() + c.execute( + 'SELECT transaction_id FROM transaction_log WHERE generation = ?', + (generation,)) + val = c.fetchone() + if val is None: + raise errors.InvalidGeneration + return val[0] + + def _get_transaction_log(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, transaction_id FROM transaction_log" + " ORDER BY generation") + return c.fetchall() + + def _get_doc(self, doc_id, check_for_conflicts=False): + """Get just the document content, without fancy handling.""" + c = self._db_handle.cursor() + if check_for_conflicts: + c.execute( + "SELECT document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN " + "conflicts ON conflicts.doc_id = document.doc_id WHERE " + "document.doc_id = ? GROUP BY document.doc_id, " + "document.doc_rev, document.content;", (doc_id,)) + else: + c.execute( + "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?", + (doc_id,)) + val = c.fetchone() + if val is None: + return None + doc_rev, content, conflicts = val + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + return doc + + def _has_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1", + (doc_id,)) + val = c.fetchone() + if val is None: + return False + else: + return True + + def get_doc(self, doc_id, include_deleted=False): + doc = self._get_doc(doc_id, check_for_conflicts=True) + if doc is None: + return None + if doc.is_tombstone() and not include_deleted: + return None + return doc + + def get_all_docs(self, include_deleted=False): + """Get all documents from the database.""" + generation = self._get_generation() + results = [] + c = self._db_handle.cursor() + c.execute( + "SELECT document.doc_id, document.doc_rev, document.content, " + "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts " + "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, " + "document.doc_rev, document.content;") + rows = c.fetchall() + for doc_id, doc_rev, content, conflicts in rows: + if content is None and not include_deleted: + continue + doc = self._factory(doc_id, doc_rev, content) + doc.has_conflicts = conflicts > 0 + results.append(doc) + return (generation, results) + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + self._check_doc_id(doc.doc_id) + self._check_doc_size(doc) + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc and old_doc.has_conflicts: + raise errors.ConflictedDoc() + if old_doc and doc.rev is None and old_doc.is_tombstone(): + new_rev = self._allocate_doc_rev(old_doc.rev) + else: + if old_doc is not None: + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + else: + if doc.rev is not None: + raise errors.RevisionConflict() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none): + """Convert a dict representation into named fields. + + So something like: {'key1': 'val1', 'key2': 'val2'} + gets converted into: [(doc_id, 'key1', 'val1', 0) + (doc_id, 'key2', 'val2', 0)] + :param doc_id: Just added to every record. + :param base_field: if set, these are nested keys, so each field should + be appropriately prefixed. + :param raw_doc: The python dictionary. + """ + # TODO: Handle lists + values = [] + for field_name, value in raw_doc.iteritems(): + if value is None and not save_none: + continue + if base_field: + full_name = base_field + '.' + field_name + else: + full_name = field_name + if value is None or isinstance(value, (int, float, basestring)): + values.append((doc_id, full_name, value, len(values))) + else: + subvalues = self._expand_to_fields(doc_id, full_name, value, + save_none) + for _, subfield_name, val, _ in subvalues: + values.append((doc_id, subfield_name, val, len(values))) + return values + + def _put_and_update_indexes(self, old_doc, doc): + """Actually insert a document into the database. + + This both updates the existing documents content, and any indexes that + refer to this document. + """ + raise NotImplementedError(self._put_and_update_indexes) + + def whats_changed(self, old_generation=0): + c = self._db_handle.cursor() + c.execute("SELECT generation, doc_id, transaction_id" + " FROM transaction_log" + " WHERE generation > ? ORDER BY generation DESC", + (old_generation,)) + results = c.fetchall() + cur_gen = old_generation + seen = set() + changes = [] + newest_trans_id = '' + for generation, doc_id, trans_id in results: + if doc_id not in seen: + changes.append((doc_id, generation, trans_id)) + seen.add(doc_id) + if changes: + cur_gen = changes[0][1] # max generation + newest_trans_id = changes[0][2] + changes.reverse() + else: + c.execute("SELECT generation, transaction_id" + " FROM transaction_log ORDER BY generation DESC LIMIT 1") + results = c.fetchone() + if not results: + cur_gen = 0 + newest_trans_id = '' + else: + cur_gen, newest_trans_id = results + + return cur_gen, newest_trans_id, changes + + def delete_doc(self, doc): + with self._db_handle: + old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True) + if old_doc is None: + raise errors.DocumentDoesNotExist + if old_doc.rev != doc.rev: + raise errors.RevisionConflict() + if old_doc.is_tombstone(): + raise errors.DocumentAlreadyDeleted + if old_doc.has_conflicts: + raise errors.ConflictedDoc() + new_rev = self._allocate_doc_rev(doc.rev) + doc.rev = new_rev + doc.make_tombstone() + self._put_and_update_indexes(old_doc, doc) + return new_rev + + def _get_conflicts(self, doc_id): + c = self._db_handle.cursor() + c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?", + (doc_id,)) + return [self._factory(doc_id, doc_rev, content) + for doc_rev, content in c.fetchall()] + + def get_doc_conflicts(self, doc_id): + with self._db_handle: + conflict_docs = self._get_conflicts(doc_id) + if not conflict_docs: + return [] + this_doc = self._get_doc(doc_id) + this_doc.has_conflicts = True + return [this_doc] + conflict_docs + + def _get_replica_gen_and_trans_id(self, other_replica_uid): + c = self._db_handle.cursor() + c.execute("SELECT known_generation, known_transaction_id FROM sync_log" + " WHERE replica_uid = ?", + (other_replica_uid,)) + val = c.fetchone() + if val is None: + other_gen = 0 + trans_id = '' + else: + other_gen = val[0] + trans_id = val[1] + return other_gen, trans_id + + def _set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, other_transaction_id): + with self._db_handle: + self._do_set_replica_gen_and_trans_id( + other_replica_uid, other_generation, other_transaction_id) + + def _do_set_replica_gen_and_trans_id(self, other_replica_uid, + other_generation, + other_transaction_id): + c = self._db_handle.cursor() + c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)", + (other_replica_uid, other_generation, + other_transaction_id)) + + def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None, + replica_gen=None, replica_trans_id=None): + with self._db_handle: + return super(SQLiteDatabase, self)._put_doc_if_newer(doc, + save_conflict=save_conflict, + replica_uid=replica_uid, replica_gen=replica_gen, + replica_trans_id=replica_trans_id) + + def _add_conflict(self, c, doc_id, my_doc_rev, my_content): + c.execute("INSERT INTO conflicts VALUES (?, ?, ?)", + (doc_id, my_doc_rev, my_content)) + + def _delete_conflicts(self, c, doc, conflict_revs): + deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs] + c.executemany("DELETE FROM conflicts" + " WHERE doc_id=? AND doc_rev=?", deleting) + doc.has_conflicts = self._has_conflicts(doc.doc_id) + + def _prune_conflicts(self, doc, doc_vcr): + if self._has_conflicts(doc.doc_id): + autoresolved = False + c_revs_to_prune = [] + for c_doc in self._get_conflicts(doc.doc_id): + c_vcr = vectorclock.VectorClockRev(c_doc.rev) + if doc_vcr.is_newer(c_vcr): + c_revs_to_prune.append(c_doc.rev) + elif doc.same_content_as(c_doc): + c_revs_to_prune.append(c_doc.rev) + doc_vcr.maximize(c_vcr) + autoresolved = True + if autoresolved: + doc_vcr.increment(self._replica_uid) + doc.rev = doc_vcr.as_str() + c = self._db_handle.cursor() + self._delete_conflicts(c, doc, c_revs_to_prune) + + def _force_doc_sync_conflict(self, doc): + my_doc = self._get_doc(doc.doc_id) + c = self._db_handle.cursor() + self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev)) + self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json()) + doc.has_conflicts = True + self._put_and_update_indexes(my_doc, doc) + + def resolve_doc(self, doc, conflicted_doc_revs): + with self._db_handle: + cur_doc = self._get_doc(doc.doc_id) + # TODO: https://bugs.launchpad.net/u1db/+bug/928274 + # I think we have a logic bug in resolve_doc + # Specifically, cur_doc.rev is always in the final vector + # clock of revisions that we supersede, even if it wasn't in + # conflicted_doc_revs. We still add it as a conflict, but the + # fact that _put_doc_if_newer propagates resolutions means I + # think that conflict could accidentally be resolved. We need + # to add a test for this case first. (create a rev, create a + # conflict, create another conflict, resolve the first rev + # and first conflict, then make sure that the resolved + # rev doesn't supersede the second conflict rev.) It *might* + # not matter, because the superseding rev is in as a + # conflict, but it does seem incorrect + new_rev = self._ensure_maximal_rev(cur_doc.rev, + conflicted_doc_revs) + superseded_revs = set(conflicted_doc_revs) + c = self._db_handle.cursor() + doc.rev = new_rev + if cur_doc.rev in superseded_revs: + self._put_and_update_indexes(cur_doc, doc) + else: + self._add_conflict(c, doc.doc_id, new_rev, doc.get_json()) + # TODO: Is there some way that we could construct a rev that would + # end up in superseded_revs, such that we add a conflict, and + # then immediately delete it? + self._delete_conflicts(c, doc, superseded_revs) + + def list_indexes(self): + """Return the list of indexes and their definitions.""" + c = self._db_handle.cursor() + # TODO: How do we test the ordering? + c.execute("SELECT name, field FROM index_definitions" + " ORDER BY name, offset") + definitions = [] + cur_name = None + for name, field in c.fetchall(): + if cur_name != name: + definitions.append((name, [])) + cur_name = name + definitions[-1][-1].append(field) + return definitions + + def _get_index_definition(self, index_name): + """Return the stored definition for a given index_name.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions" + " WHERE name = ? ORDER BY offset", (index_name,)) + fields = [x[0] for x in c.fetchall()] + if not fields: + raise errors.IndexDoesNotExist + return fields + + @staticmethod + def _strip_glob(value): + """Remove the trailing * from a value.""" + assert value[-1] == '*' + return value[:-1] + + def _format_query(self, definition, key_values): + # First, build the definition. We join the document_fields table + # against itself, as many times as the 'width' of our definition. + # We then do a query for each key_value, one-at-a-time. + # Note: All of these strings are static, we could cache them, etc. + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = ["d.doc_id = d%d.doc_id" + " AND d%d.field_name = ?" + % (i, i) for i in range(len(definition))] + wildcard_where = [novalue_where[i] + + (" AND d%d.value NOT NULL" % (i,)) + for i in range(len(definition))] + exact_where = [novalue_where[i] + + (" AND d%d.value = ?" % (i,)) + for i in range(len(definition))] + like_where = [novalue_where[i] + + (" AND d%d.value GLOB ?" % (i,)) + for i in range(len(definition))] + is_wildcard = False + # Merge the lists together, so that: + # [field1, field2, field3], [val1, val2, val3] + # Becomes: + # (field1, val1, field2, val2, field3, val3) + args = [] + where = [] + for idx, (field, value) in enumerate(zip(definition, key_values)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(exact_where[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_from_index(self, index_name, *key_values): + definition = self._get_index_definition(index_name) + if len(key_values) != len(definition): + raise errors.InvalidValueForIndex() + statement, args = self._format_query(definition, key_values) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def _format_range_query(self, definition, start_value, end_value): + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in + range(len(definition))] + wildcard_where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + like_where = [ + novalue_where[i] + ( + " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in + range(len(definition))] + range_where_lower = [ + novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in + range(len(definition))] + range_where_upper = [ + novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in + range(len(definition))] + args = [] + where = [] + if start_value: + if isinstance(start_value, basestring): + start_value = (start_value,) + if len(start_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, start_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(self._strip_glob(value)) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_lower[idx]) + args.append(value) + if end_value: + if isinstance(end_value, basestring): + end_value = (end_value,) + if len(end_value) != len(definition): + raise errors.InvalidValueForIndex() + is_wildcard = False + for idx, (field, value) in enumerate(zip(definition, end_value)): + args.append(field) + if value.endswith('*'): + if value == '*': + where.append(wildcard_where[idx]) + else: + # This is a glob match + if is_wildcard: + # We can't have a partial wildcard following + # another wildcard + raise errors.InvalidGlobbing + where.append(like_where[idx]) + args.append(self._strip_glob(value)) + args.append(value) + is_wildcard = True + else: + if is_wildcard: + raise errors.InvalidGlobbing + where.append(range_where_upper[idx]) + args.append(value) + statement = ( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = " + "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER " + "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join( + ['d%d.value' % i for i in range(len(definition))]))) + return statement, args + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + """Return all documents with key values in the specified range.""" + definition = self._get_index_definition(index_name) + statement, args = self._format_range_query( + definition, start_value, end_value) + c = self._db_handle.cursor() + try: + c.execute(statement, tuple(args)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, args)) + res = c.fetchall() + results = [] + for row in res: + doc = self._factory(row[0], row[1], row[2]) + doc.has_conflicts = row[3] > 0 + results.append(doc) + return results + + def get_index_keys(self, index_name): + c = self._db_handle.cursor() + definition = self._get_index_definition(index_name) + value_fields = ', '.join([ + 'd%d.value' % i for i in range(len(definition))]) + tables = ["document_fields d%d" % i for i in range(len(definition))] + novalue_where = [ + "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in + range(len(definition))] + where = [ + novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in + range(len(definition))] + statement = ( + "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % ( + value_fields, ', '.join(tables), ' AND '.join(where), + value_fields)) + try: + c.execute(statement, tuple(definition)) + except dbapi2.OperationalError, e: + raise dbapi2.OperationalError(str(e) + + '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition))) + return c.fetchall() + + def delete_index(self, index_name): + with self._db_handle: + c = self._db_handle.cursor() + c.execute("DELETE FROM index_definitions WHERE name = ?", + (index_name,)) + c.execute( + "DELETE FROM document_fields WHERE document_fields.field_name " + " NOT IN (SELECT field from index_definitions)") + + +class SQLiteSyncTarget(CommonSyncTarget): + + def get_sync_info(self, source_replica_uid): + source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id( + source_replica_uid) + my_gen, my_trans_id = self._db._get_generation_info() + return ( + self._db._replica_uid, my_gen, my_trans_id, source_gen, + source_trans_id) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_replica_transaction_id): + if self._trace_hook: + self._trace_hook('record_sync_info') + self._db._set_replica_gen_and_trans_id( + source_replica_uid, source_replica_generation, + source_replica_transaction_id) + + +class SQLitePartialExpandDatabase(SQLiteDatabase): + """An SQLite Backend that expands documents into a document_field table. + + It stores the original document text in document.doc. For fields that are + indexed, the data goes into document_fields. + """ + + _index_storage_value = 'expand referenced' + + def _get_indexed_fields(self): + """Determine what fields are indexed.""" + c = self._db_handle.cursor() + c.execute("SELECT field FROM index_definitions") + return set([x[0] for x in c.fetchall()]) + + def _evaluate_index(self, raw_doc, field): + parser = query_parser.Parser() + getter = parser.parse(field) + return getter.get(raw_doc) + + def _put_and_update_indexes(self, old_doc, doc): + c = self._db_handle.cursor() + if doc and not doc.is_tombstone(): + raw_doc = json.loads(doc.get_json()) + else: + raw_doc = {} + if old_doc is not None: + c.execute("UPDATE document SET doc_rev=?, content=?" + " WHERE doc_id = ?", + (doc.rev, doc.get_json(), doc.doc_id)) + c.execute("DELETE FROM document_fields WHERE doc_id = ?", + (doc.doc_id,)) + else: + c.execute("INSERT INTO document (doc_id, doc_rev, content)" + " VALUES (?, ?, ?)", + (doc.doc_id, doc.rev, doc.get_json())) + indexed_fields = self._get_indexed_fields() + if indexed_fields: + # It is expected that len(indexed_fields) is shorter than + # len(raw_doc) + getters = [(field, self._parse_index_definition(field)) + for field in indexed_fields] + self._update_indexes(doc.doc_id, raw_doc, getters, c) + trans_id = self._allocate_transaction_id() + c.execute("INSERT INTO transaction_log(doc_id, transaction_id)" + " VALUES (?, ?)", (doc.doc_id, trans_id)) + + def create_index(self, index_name, *index_expressions): + with self._db_handle: + c = self._db_handle.cursor() + cur_fields = self._get_indexed_fields() + definition = [(index_name, idx, field) + for idx, field in enumerate(index_expressions)] + try: + c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", + definition) + except dbapi2.IntegrityError as e: + stored_def = self._get_index_definition(index_name) + if stored_def == [x[-1] for x in definition]: + return + raise errors.IndexNameTakenError, e, sys.exc_info()[2] + new_fields = set( + [f for f in index_expressions if f not in cur_fields]) + if new_fields: + self._update_all_indexes(new_fields) + + def _iter_all_docs(self): + c = self._db_handle.cursor() + c.execute("SELECT doc_id, content FROM document") + while True: + next_rows = c.fetchmany() + if not next_rows: + break + for row in next_rows: + yield row + + def _update_all_indexes(self, new_fields): + """Iterate all the documents, and add content to document_fields. + + :param new_fields: The index definitions that need to be added. + """ + getters = [(field, self._parse_index_definition(field)) + for field in new_fields] + c = self._db_handle.cursor() + for doc_id, doc in self._iter_all_docs(): + if doc is None: + continue + raw_doc = json.loads(doc) + self._update_indexes(doc_id, raw_doc, getters, c) + +SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase) diff --git a/src/leap/soledad/u1db/commandline/__init__.py b/src/leap/soledad/u1db/commandline/__init__.py new file mode 100644 index 00000000..3f32e381 --- /dev/null +++ b/src/leap/soledad/u1db/commandline/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. diff --git a/src/leap/soledad/u1db/commandline/client.py b/src/leap/soledad/u1db/commandline/client.py new file mode 100644 index 00000000..15bf8561 --- /dev/null +++ b/src/leap/soledad/u1db/commandline/client.py @@ -0,0 +1,497 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Commandline bindings for the u1db-client program.""" + +import argparse +import os +try: + import simplejson as json +except ImportError: + import json # noqa +import sys + +from u1db import ( + Document, + open as u1db_open, + sync, + errors, + ) +from u1db.commandline import command +from u1db.remote import ( + http_database, + http_target, + ) + + +client_commands = command.CommandGroup() + + +def set_oauth_credentials(client): + keys = os.environ.get('OAUTH_CREDENTIALS', None) + if keys is not None: + consumer_key, consumer_secret, \ + token_key, token_secret = keys.split(":") + client.set_oauth_credentials(consumer_key, consumer_secret, + token_key, token_secret) + + +class OneDbCmd(command.Command): + """Base class for commands operating on one local or remote database.""" + + def _open(self, database, create): + if database.startswith(('http://', 'https://')): + db = http_database.HTTPDatabase(database) + set_oauth_credentials(db) + db.open(create) + return db + else: + return u1db_open(database, create) + + +class CmdCreate(OneDbCmd): + """Create a new document from scratch""" + + name = 'create' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to update', + metavar='database-path-or-url') + parser.add_argument('infile', nargs='?', default=None, + help='The file to read content from.') + parser.add_argument('--id', dest='doc_id', default=None, + help='Set the document identifier') + + def run(self, database, infile, doc_id): + if infile is None: + infile = self.stdin + db = self._open(database, create=False) + doc = db.create_doc_from_json(infile.read(), doc_id=doc_id) + self.stderr.write('id: %s\nrev: %s\n' % (doc.doc_id, doc.rev)) + +client_commands.register(CmdCreate) + + +class CmdDelete(OneDbCmd): + """Delete a document from the database""" + + name = 'delete' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to update', + metavar='database-path-or-url') + parser.add_argument('doc_id', help='The document id to retrieve') + parser.add_argument('doc_rev', + help='The revision of the document (which is being superseded.)') + + def run(self, database, doc_id, doc_rev): + db = self._open(database, create=False) + doc = Document(doc_id, doc_rev, None) + db.delete_doc(doc) + self.stderr.write('rev: %s\n' % (doc.rev,)) + +client_commands.register(CmdDelete) + + +class CmdGet(OneDbCmd): + """Extract a document from the database""" + + name = 'get' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to query', + metavar='database-path-or-url') + parser.add_argument('doc_id', help='The document id to retrieve.') + parser.add_argument('outfile', nargs='?', default=None, + help='The file to write the document to', + type=argparse.FileType('wb')) + + def run(self, database, doc_id, outfile): + if outfile is None: + outfile = self.stdout + try: + db = self._open(database, create=False) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + doc = db.get_doc(doc_id) + if doc is None: + self.stderr.write('Document not found (id: %s)\n' % (doc_id,)) + return 1 # failed + if doc.is_tombstone(): + outfile.write('[document deleted]\n') + else: + outfile.write(doc.get_json() + '\n') + self.stderr.write('rev: %s\n' % (doc.rev,)) + if doc.has_conflicts: + self.stderr.write("Document has conflicts.\n") + +client_commands.register(CmdGet) + + +class CmdGetDocConflicts(OneDbCmd): + """Get the conflicts from a document""" + + name = 'get-doc-conflicts' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local database to query', + metavar='database-path') + parser.add_argument('doc_id', help='The document id to retrieve.') + + def run(self, database, doc_id): + try: + db = self._open(database, False) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + conflicts = db.get_doc_conflicts(doc_id) + if not conflicts: + if db.get_doc(doc_id) is None: + self.stderr.write("Document does not exist.\n") + return 1 + self.stdout.write("[") + for i, doc in enumerate(conflicts): + if i: + self.stdout.write(",") + self.stdout.write( + json.dumps(dict(rev=doc.rev, content=doc.content), indent=4)) + self.stdout.write("]\n") + +client_commands.register(CmdGetDocConflicts) + + +class CmdInitDB(OneDbCmd): + """Create a new database""" + + name = 'init-db' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to create', + metavar='database-path-or-url') + parser.add_argument('--replica-uid', default=None, + help='The unique identifier for this database (not for remote)') + + def run(self, database, replica_uid): + db = self._open(database, create=True) + if replica_uid is not None: + db._set_replica_uid(replica_uid) + +client_commands.register(CmdInitDB) + + +class CmdPut(OneDbCmd): + """Add a document to the database""" + + name = 'put' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to update', + metavar='database-path-or-url'), + parser.add_argument('doc_id', help='The document id to retrieve') + parser.add_argument('doc_rev', + help='The revision of the document (which is being superseded.)') + parser.add_argument('infile', nargs='?', default=None, + help='The filename of the document that will be used for content', + type=argparse.FileType('rb')) + + def run(self, database, doc_id, doc_rev, infile): + if infile is None: + infile = self.stdin + try: + db = self._open(database, create=False) + doc = Document(doc_id, doc_rev, infile.read()) + doc_rev = db.put_doc(doc) + self.stderr.write('rev: %s\n' % (doc_rev,)) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + except errors.RevisionConflict: + if db.get_doc(doc_id) is None: + self.stderr.write("Document does not exist.\n") + else: + self.stderr.write("Given revision is not current.\n") + except errors.ConflictedDoc: + self.stderr.write( + "Document has conflicts.\n" + "Inspect with get-doc-conflicts, then resolve.\n") + else: + return + return 1 + +client_commands.register(CmdPut) + + +class CmdResolve(OneDbCmd): + """Resolve a conflicted document""" + + name = 'resolve-doc' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', + help='The local or remote database to update', + metavar='database-path-or-url'), + parser.add_argument('doc_id', help='The conflicted document id') + parser.add_argument('doc_revs', metavar="doc-rev", nargs="+", + help='The revisions that the new content supersedes') + parser.add_argument('--infile', nargs='?', default=None, + help='The filename of the document that will be used for content', + type=argparse.FileType('rb')) + + def run(self, database, doc_id, doc_revs, infile): + if infile is None: + infile = self.stdin + try: + db = self._open(database, create=False) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + doc = db.get_doc(doc_id) + if doc is None: + self.stderr.write("Document does not exist.\n") + return 1 + doc.set_json(infile.read()) + db.resolve_doc(doc, doc_revs) + self.stderr.write("rev: %s\n" % db.get_doc(doc_id).rev) + if doc.has_conflicts: + self.stderr.write("Document still has conflicts.\n") + +client_commands.register(CmdResolve) + + +class CmdSync(command.Command): + """Synchronize two databases""" + + name = 'sync' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('source', help='database to sync from') + parser.add_argument('target', help='database to sync to') + + def _open_target(self, target): + if target.startswith(('http://', 'https://')): + st = http_target.HTTPSyncTarget.connect(target) + set_oauth_credentials(st) + else: + db = u1db_open(target, create=True) + st = db.get_sync_target() + return st + + def run(self, source, target): + """Start a Sync request.""" + source_db = u1db_open(source, create=False) + st = self._open_target(target) + syncer = sync.Synchronizer(source_db, st) + syncer.sync() + source_db.close() + +client_commands.register(CmdSync) + + +class CmdCreateIndex(OneDbCmd): + """Create an index""" + + name = "create-index" + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', help='The local database to update', + metavar='database-path') + parser.add_argument('index', help='the name of the index') + parser.add_argument('expression', help='an index expression', + nargs='+') + + def run(self, database, index, expression): + try: + db = self._open(database, create=False) + db.create_index(index, *expression) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + except errors.IndexNameTakenError: + self.stderr.write("There is already a different index named %r.\n" + % (index,)) + return 1 + except errors.IndexDefinitionParseError: + self.stderr.write("Bad index expression.\n") + return 1 + +client_commands.register(CmdCreateIndex) + + +class CmdListIndexes(OneDbCmd): + """List existing indexes""" + + name = "list-indexes" + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', help='The local database to query', + metavar='database-path') + + def run(self, database): + try: + db = self._open(database, create=False) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + for (index, expression) in db.list_indexes(): + self.stdout.write("%s: %s\n" % (index, ", ".join(expression))) + +client_commands.register(CmdListIndexes) + + +class CmdDeleteIndex(OneDbCmd): + """Delete an index""" + + name = "delete-index" + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', help='The local database to update', + metavar='database-path') + parser.add_argument('index', help='the name of the index') + + def run(self, database, index): + try: + db = self._open(database, create=False) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + return 1 + db.delete_index(index) + +client_commands.register(CmdDeleteIndex) + + +class CmdGetIndexKeys(OneDbCmd): + """Get the index's keys""" + + name = "get-index-keys" + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', help='The local database to query', + metavar='database-path') + parser.add_argument('index', help='the name of the index') + + def run(self, database, index): + try: + db = self._open(database, create=False) + for key in db.get_index_keys(index): + self.stdout.write("%s\n" % (", ".join( + [i.encode('utf-8') for i in key],))) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + except errors.IndexDoesNotExist: + self.stderr.write("Index does not exist.\n") + else: + return + return 1 + +client_commands.register(CmdGetIndexKeys) + + +class CmdGetFromIndex(OneDbCmd): + """Find documents by searching an index""" + + name = "get-from-index" + argv = None + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('database', help='The local database to query', + metavar='database-path') + parser.add_argument('index', help='the name of the index') + parser.add_argument('values', metavar="value", + help='the value to look up (one per index column)', + nargs="+") + + def run(self, database, index, values): + try: + db = self._open(database, create=False) + docs = db.get_from_index(index, *values) + except errors.DatabaseDoesNotExist: + self.stderr.write("Database does not exist.\n") + except errors.IndexDoesNotExist: + self.stderr.write("Index does not exist.\n") + except errors.InvalidValueForIndex: + index_def = db._get_index_definition(index) + len_diff = len(index_def) - len(values) + if len_diff == 0: + # can't happen (HAH) + raise + argv = self.argv if self.argv is not None else sys.argv + self.stderr.write( + "Invalid query: " + "index %r requires %d query expression%s%s.\n" + "For example, the following would be valid:\n" + " %s %s %r %r %s\n" + % (index, + len(index_def), + "s" if len(index_def) > 1 else "", + ", not %d" % len(values) if len(values) else "", + argv[0], argv[1], database, index, + " ".join(map(repr, + values[:len(index_def)] + + ["*" for i in range(len_diff)])), + )) + except errors.InvalidGlobbing: + argv = self.argv if self.argv is not None else sys.argv + fixed = [] + for (i, v) in enumerate(values): + fixed.append(v) + if v.endswith('*'): + break + # values has at least one element, so i is defined + fixed.extend('*' * (len(values) - i - 1)) + self.stderr.write( + "Invalid query: a star can only be followed by stars.\n" + "For example, the following would be valid:\n" + " %s %s %r %r %s\n" + % (argv[0], argv[1], database, index, + " ".join(map(repr, fixed)))) + + else: + self.stdout.write("[") + for i, doc in enumerate(docs): + if i: + self.stdout.write(",") + self.stdout.write( + json.dumps( + dict(id=doc.doc_id, rev=doc.rev, content=doc.content), + indent=4)) + self.stdout.write("]\n") + return + return 1 + +client_commands.register(CmdGetFromIndex) + + +def main(args): + return client_commands.run_argv(args, sys.stdin, sys.stdout, sys.stderr) diff --git a/src/leap/soledad/u1db/commandline/command.py b/src/leap/soledad/u1db/commandline/command.py new file mode 100644 index 00000000..eace0560 --- /dev/null +++ b/src/leap/soledad/u1db/commandline/command.py @@ -0,0 +1,80 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Command infrastructure for u1db""" + +import argparse +import inspect + + +class CommandGroup(object): + """A collection of commands.""" + + def __init__(self, description=None): + self.commands = {} + self.description = description + + def register(self, cmd): + """Register a new command to be incorporated with this group.""" + self.commands[cmd.name] = cmd + + def make_argparser(self): + """Create an argparse.ArgumentParser""" + parser = argparse.ArgumentParser(description=self.description) + subs = parser.add_subparsers(title='commands') + for name, cmd in sorted(self.commands.iteritems()): + sub = subs.add_parser(name, help=cmd.__doc__) + sub.set_defaults(subcommand=cmd) + cmd._populate_subparser(sub) + return parser + + def run_argv(self, argv, stdin, stdout, stderr): + """Run a command, from a sys.argv[1:] style input.""" + parser = self.make_argparser() + args = parser.parse_args(argv) + cmd = args.subcommand(stdin, stdout, stderr) + params, _, _, _ = inspect.getargspec(cmd.run) + vals = [] + for param in params[1:]: + vals.append(getattr(args, param)) + return cmd.run(*vals) + + +class Command(object): + """Definition of a Command that can be run. + + :cvar name: The name of the command, so that you can run + 'u1db-client <name>'. + """ + + name = None + + def __init__(self, stdin, stdout, stderr): + self.stdin = stdin + self.stdout = stdout + self.stderr = stderr + + @classmethod + def _populate_subparser(cls, parser): + """Child classes should override this to provide their arguments.""" + raise NotImplementedError(cls._populate_subparser) + + def run(self, *args): + """This is where the magic happens. + + Subclasses should implement this, requesting their specific arguments. + """ + raise NotImplementedError(self.run) diff --git a/src/leap/soledad/u1db/commandline/serve.py b/src/leap/soledad/u1db/commandline/serve.py new file mode 100644 index 00000000..0bb0e641 --- /dev/null +++ b/src/leap/soledad/u1db/commandline/serve.py @@ -0,0 +1,34 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Build server for u1db-serve.""" + +from paste import httpserver + +from u1db.remote import ( + http_app, + server_state, + ) + + +def make_server(host, port, working_dir): + """Make a server on host and port exposing dbs living in working_dir.""" + state = server_state.ServerState() + state.set_workingdir(working_dir) + application = http_app.HTTPApp(state) + server = httpserver.WSGIServer(application, (host, port), + httpserver.WSGIHandler) + return server diff --git a/src/leap/soledad/u1db/errors.py b/src/leap/soledad/u1db/errors.py new file mode 100644 index 00000000..967c7c38 --- /dev/null +++ b/src/leap/soledad/u1db/errors.py @@ -0,0 +1,189 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""A list of errors that u1db can raise.""" + + +class U1DBError(Exception): + """Generic base class for U1DB errors.""" + + # description/tag for identifying the error during transmission (http,...) + wire_description = "error" + + def __init__(self, message=None): + self.message = message + + +class RevisionConflict(U1DBError): + """The document revisions supplied does not match the current version.""" + + wire_description = "revision conflict" + + +class InvalidJSON(U1DBError): + """Content was not valid json.""" + + +class InvalidContent(U1DBError): + """Content was not a python dictionary.""" + + +class InvalidDocId(U1DBError): + """A document was requested with an invalid document identifier.""" + + wire_description = "invalid document id" + + +class MissingDocIds(U1DBError): + """Needs document ids.""" + + wire_description = "missing document ids" + + +class DocumentTooBig(U1DBError): + """Document exceeds the maximum document size for this database.""" + + wire_description = "document too big" + + +class UserQuotaExceeded(U1DBError): + """Document exceeds the maximum document size for this database.""" + + wire_description = "user quota exceeded" + + +class SubscriptionNeeded(U1DBError): + """User needs a subscription to be able to use this replica..""" + + wire_description = "user needs subscription" + + +class InvalidTransactionId(U1DBError): + """Invalid transaction for generation.""" + + wire_description = "invalid transaction id" + + +class InvalidGeneration(U1DBError): + """Generation was previously synced with a different transaction id.""" + + wire_description = "invalid generation" + + +class ConflictedDoc(U1DBError): + """The document is conflicted, you must call resolve before put()""" + + +class InvalidValueForIndex(U1DBError): + """The values supplied does not match the index definition.""" + + +class InvalidGlobbing(U1DBError): + """Raised if wildcard matches are not strictly at the tail of the request. + """ + + +class DocumentDoesNotExist(U1DBError): + """The document does not exist.""" + + wire_description = "document does not exist" + + +class DocumentAlreadyDeleted(U1DBError): + """The document was already deleted.""" + + wire_description = "document already deleted" + + +class DatabaseDoesNotExist(U1DBError): + """The database does not exist.""" + + wire_description = "database does not exist" + + +class IndexNameTakenError(U1DBError): + """The given index name is already taken.""" + + +class IndexDefinitionParseError(U1DBError): + """The index definition cannot be parsed.""" + + +class IndexDoesNotExist(U1DBError): + """No index of that name exists.""" + + +class Unauthorized(U1DBError): + """Request wasn't authorized properly.""" + + wire_description = "unauthorized" + + +class HTTPError(U1DBError): + """Unspecific HTTP errror.""" + + wire_description = None + + def __init__(self, status, message=None, headers={}): + self.status = status + self.message = message + self.headers = headers + + def __str__(self): + if not self.message: + return "HTTPError(%d)" % self.status + else: + return "HTTPError(%d, %r)" % (self.status, self.message) + + +class Unavailable(HTTPError): + """Server not available not serve request.""" + + wire_description = "unavailable" + + def __init__(self, message=None, headers={}): + super(Unavailable, self).__init__(503, message, headers) + + def __str__(self): + if not self.message: + return "Unavailable()" + else: + return "Unavailable(%r)" % self.message + + +class BrokenSyncStream(U1DBError): + """Unterminated or otherwise broken sync exchange stream.""" + + wire_description = None + + +class UnknownAuthMethod(U1DBError): + """Unknown auhorization method.""" + + wire_description = None + + +# mapping wire (transimission) descriptions/tags for errors to the exceptions +wire_description_to_exc = dict( + (x.wire_description, x) for x in globals().values() + if getattr(x, 'wire_description', None) not in (None, "error") +) +wire_description_to_exc["error"] = U1DBError + + +# +# wire error descriptions not corresponding to an exception +DOCUMENT_DELETED = "document deleted" diff --git a/src/leap/soledad/u1db/query_parser.py b/src/leap/soledad/u1db/query_parser.py new file mode 100644 index 00000000..f564821f --- /dev/null +++ b/src/leap/soledad/u1db/query_parser.py @@ -0,0 +1,370 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Code for parsing Index definitions.""" + +import re +from u1db import ( + errors, + ) + + +class Getter(object): + """Get values from a document based on a specification.""" + + def get(self, raw_doc): + """Get a value from the document. + + :param raw_doc: a python dictionary to get the value from. + :return: A list of values that match the description. + """ + raise NotImplementedError(self.get) + + +class StaticGetter(Getter): + """A getter that returns a defined value (independent of the doc).""" + + def __init__(self, value): + """Create a StaticGetter. + + :param value: the value to return when get is called. + """ + if value is None: + self.value = [] + elif isinstance(value, list): + self.value = value + else: + self.value = [value] + + def get(self, raw_doc): + return self.value + + +def extract_field(raw_doc, subfields, index=0): + if not isinstance(raw_doc, dict): + return [] + val = raw_doc.get(subfields[index]) + if val is None: + return [] + if index < len(subfields) - 1: + if isinstance(val, list): + results = [] + for item in val: + results.extend(extract_field(item, subfields, index + 1)) + return results + if isinstance(val, dict): + return extract_field(val, subfields, index + 1) + return [] + if isinstance(val, dict): + return [] + if isinstance(val, list): + # Strip anything in the list that isn't a simple type + return [v for v in val if not isinstance(v, (dict, list))] + return [val] + + +class ExtractField(Getter): + """Extract a field from the document.""" + + def __init__(self, field): + """Create an ExtractField object. + + When a document is passed to get() this will return a value + from the document based on the field specifier passed to + the constructor. + + None will be returned if the field is nonexistant, or refers to an + object, rather than a simple type or list of simple types. + + :param field: a specifier for the field to return. + This is either a field name, or a dotted field name. + """ + self.field = field.split('.') + + def get(self, raw_doc): + return extract_field(raw_doc, self.field) + + +class Transformation(Getter): + """A transformation on a value from another Getter.""" + + name = None + arity = 1 + args = ['expression'] + + def __init__(self, inner): + """Create a transformation. + + :param inner: the argument(s) to the transformation. + """ + self.inner = inner + + def get(self, raw_doc): + inner_values = self.inner.get(raw_doc) + assert isinstance(inner_values, list),\ + 'get() should always return a list' + return self.transform(inner_values) + + def transform(self, values): + """Transform the values. + + This should be implemented by subclasses to transform the + value when get() is called. + + :param values: the values from the other Getter + :return: the transformed values. + """ + raise NotImplementedError(self.transform) + + +class Lower(Transformation): + """Lowercase a string. + + This transformation will return None for non-string inputs. However, + it will lowercase any strings in a list, dropping any elements + that are not strings. + """ + + name = "lower" + + def _can_transform(self, val): + return isinstance(val, basestring) + + def transform(self, values): + if not values: + return [] + return [val.lower() for val in values if self._can_transform(val)] + + +class Number(Transformation): + """Convert an integer to a zero padded string. + + This transformation will return None for non-integer inputs. However, it + will transform any integers in a list, dropping any elements that are not + integers. + """ + + name = 'number' + arity = 2 + args = ['expression', int] + + def __init__(self, inner, number): + super(Number, self).__init__(inner) + self.padding = "%%0%sd" % number + + def _can_transform(self, val): + return isinstance(val, int) and not isinstance(val, bool) + + def transform(self, values): + """Transform any integers in values into zero padded strings.""" + if not values: + return [] + return [self.padding % (v,) for v in values if self._can_transform(v)] + + +class Bool(Transformation): + """Convert bool to string.""" + + name = "bool" + args = ['expression'] + + def _can_transform(self, val): + return isinstance(val, bool) + + def transform(self, values): + """Transform any booleans in values into strings.""" + if not values: + return [] + return [('1' if v else '0') for v in values if self._can_transform(v)] + + +class SplitWords(Transformation): + """Split a string on whitespace. + + This Getter will return [] for non-string inputs. It will however + split any strings in an input list, discarding any elements that + are not strings. + """ + + name = "split_words" + + def _can_transform(self, val): + return isinstance(val, basestring) + + def transform(self, values): + if not values: + return [] + result = set() + for value in values: + if self._can_transform(value): + for word in value.split(): + result.add(word) + return list(result) + + +class Combine(Transformation): + """Combine multiple expressions into a single index.""" + + name = "combine" + # variable number of args + arity = -1 + + def __init__(self, *inner): + super(Combine, self).__init__(inner) + + def get(self, raw_doc): + inner_values = [] + for inner in self.inner: + inner_values.extend(inner.get(raw_doc)) + return self.transform(inner_values) + + def transform(self, values): + return values + + +class IsNull(Transformation): + """Indicate whether the input is None. + + This Getter returns a bool indicating whether the input is nil. + """ + + name = "is_null" + + def transform(self, values): + return [len(values) == 0] + + +def check_fieldname(fieldname): + if fieldname.endswith('.'): + raise errors.IndexDefinitionParseError( + "Fieldname cannot end in '.':%s^" % (fieldname,)) + + +class Parser(object): + """Parse an index expression into a sequence of transformations.""" + + _transformations = {} + _delimiters = re.compile("\(|\)|,") + + def __init__(self): + self._tokens = [] + + def _set_expression(self, expression): + self._open_parens = 0 + self._tokens = [] + expression = expression.strip() + while expression: + delimiter = self._delimiters.search(expression) + if delimiter: + idx = delimiter.start() + if idx == 0: + result, expression = (expression[:1], expression[1:]) + self._tokens.append(result) + else: + result, expression = (expression[:idx], expression[idx:]) + result = result.strip() + if result: + self._tokens.append(result) + else: + expression = expression.strip() + if expression: + self._tokens.append(expression) + expression = None + + def _get_token(self): + if self._tokens: + return self._tokens.pop(0) + + def _peek_token(self): + if self._tokens: + return self._tokens[0] + + @staticmethod + def _to_getter(term): + if isinstance(term, Getter): + return term + check_fieldname(term) + return ExtractField(term) + + def _parse_op(self, op_name): + self._get_token() # '(' + op = self._transformations.get(op_name, None) + if op is None: + raise errors.IndexDefinitionParseError( + "Unknown operation: %s" % op_name) + args = [] + while True: + args.append(self._parse_term()) + sep = self._get_token() + if sep == ')': + break + if sep != ',': + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' in parentheses." % (sep,)) + parsed = [] + for i, arg in enumerate(args): + arg_type = op.args[i % len(op.args)] + if arg_type == 'expression': + inner = self._to_getter(arg) + else: + try: + inner = arg_type(arg) + except ValueError, e: + raise errors.IndexDefinitionParseError( + "Invalid value %r for argument type %r " + "(%r)." % (arg, arg_type, e)) + parsed.append(inner) + return op(*parsed) + + def _parse_term(self): + term = self._get_token() + if term is None: + raise errors.IndexDefinitionParseError( + "Unexpected end of index definition.") + if term in (',', ')', '('): + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' at start of expression." % (term,)) + next_token = self._peek_token() + if next_token == '(': + return self._parse_op(term) + return term + + def parse(self, expression): + self._set_expression(expression) + term = self._to_getter(self._parse_term()) + if self._peek_token(): + raise errors.IndexDefinitionParseError( + "Unexpected token '%s' after end of expression." + % (self._peek_token(),)) + return term + + def parse_all(self, fields): + return [self.parse(field) for field in fields] + + @classmethod + def register_transormation(cls, transform): + assert transform.name not in cls._transformations, ( + "Transform %s already registered for %s" + % (transform.name, cls._transformations[transform.name])) + cls._transformations[transform.name] = transform + + +Parser.register_transormation(SplitWords) +Parser.register_transormation(Lower) +Parser.register_transormation(Number) +Parser.register_transormation(Bool) +Parser.register_transormation(IsNull) +Parser.register_transormation(Combine) diff --git a/src/leap/soledad/u1db/remote/__init__.py b/src/leap/soledad/u1db/remote/__init__.py new file mode 100644 index 00000000..3f32e381 --- /dev/null +++ b/src/leap/soledad/u1db/remote/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. diff --git a/src/leap/soledad/u1db/remote/basic_auth_middleware.py b/src/leap/soledad/u1db/remote/basic_auth_middleware.py new file mode 100644 index 00000000..a2cbff62 --- /dev/null +++ b/src/leap/soledad/u1db/remote/basic_auth_middleware.py @@ -0,0 +1,68 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. +"""U1DB Basic Auth authorisation WSGI middleware.""" +import httplib +try: + import simplejson as json +except ImportError: + import json # noqa +from wsgiref.util import shift_path_info + + +class Unauthorized(Exception): + """User authorization failed.""" + + +class BasicAuthMiddleware(object): + """U1DB Basic Auth Authorisation WSGI middleware.""" + + def __init__(self, app, prefix): + self.app = app + self.prefix = prefix + + def _error(self, start_response, status, description, message=None): + start_response("%d %s" % (status, httplib.responses[status]), + [('content-type', 'application/json')]) + err = {"error": description} + if message: + err['message'] = message + return [json.dumps(err)] + + def __call__(self, environ, start_response): + if self.prefix and not environ['PATH_INFO'].startswith(self.prefix): + return self._error(start_response, 400, "bad request") + auth = environ.get('HTTP_AUTHORIZATION') + if not auth: + return self._error(start_response, 401, "unauthorized", + "Missing Basic Authentication.") + scheme, encoded = auth.split(None, 1) + if scheme.lower() != 'basic': + return self._error( + start_response, 401, "unauthorized", + "Missing Basic Authentication") + user, password = encoded.decode('base64').split(':', 1) + try: + self.verify_user(environ, user, password) + except Unauthorized: + return self._error( + start_response, 401, "unauthorized", + "Incorrect password or login.") + del environ['HTTP_AUTHORIZATION'] + shift_path_info(environ) + return self.app(environ, start_response) + + def verify_user(self, environ, username, password): + raise NotImplementedError(self.verify_user) diff --git a/src/leap/soledad/u1db/remote/http_app.py b/src/leap/soledad/u1db/remote/http_app.py new file mode 100644 index 00000000..3d7d4248 --- /dev/null +++ b/src/leap/soledad/u1db/remote/http_app.py @@ -0,0 +1,629 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""HTTP Application exposing U1DB.""" + +import functools +import httplib +import inspect +try: + import simplejson as json +except ImportError: + import json # noqa +import sys +import urlparse + +import routes.mapper + +from u1db import ( + __version__ as _u1db_version, + DBNAME_CONSTRAINTS, + Document, + errors, + sync, + ) +from u1db.remote import ( + http_errors, + utils, + ) + + +def parse_bool(expression): + """Parse boolean querystring parameter.""" + if expression == 'true': + return True + return False + + +def parse_list(expression): + if expression is None: + return [] + return [t.strip() for t in expression.split(',')] + + +def none_or_str(expression): + if expression is None: + return None + return str(expression) + + +class BadRequest(Exception): + """Bad request.""" + + +class _FencedReader(object): + """Read and get lines from a file but not past a given length.""" + + MAXCHUNK = 8192 + + def __init__(self, rfile, total, max_entry_size): + self.rfile = rfile + self.remaining = total + self.max_entry_size = max_entry_size + self._kept = None + + def read_chunk(self, atmost): + if self._kept is not None: + # ignore atmost, kept data should be a subchunk anyway + kept, self._kept = self._kept, None + return kept + if self.remaining == 0: + return '' + data = self.rfile.read(min(self.remaining, atmost)) + self.remaining -= len(data) + return data + + def getline(self): + line_parts = [] + size = 0 + while True: + chunk = self.read_chunk(self.MAXCHUNK) + if chunk == '': + break + nl = chunk.find("\n") + if nl != -1: + size += nl + 1 + if size > self.max_entry_size: + raise BadRequest + line_parts.append(chunk[:nl + 1]) + rest = chunk[nl + 1:] + self._kept = rest or None + break + else: + size += len(chunk) + if size > self.max_entry_size: + raise BadRequest + line_parts.append(chunk) + return ''.join(line_parts) + + +def http_method(**control): + """Decoration for handling of query arguments and content for a HTTP + method. + + args and content here are the query arguments and body of the incoming + HTTP requests. + + Match query arguments to python method arguments: + w = http_method()(f) + w(self, args, content) => args["content"]=content; + f(self, **args) + + JSON deserialize content to arguments: + w = http_method(content_as_args=True,...)(f) + w(self, args, content) => args.update(json.loads(content)); + f(self, **args) + + Support conversions (e.g int): + w = http_method(Arg=Conv,...)(f) + w(self, args, content) => args["Arg"]=Conv(args["Arg"]); + f(self, **args) + + Enforce no use of query arguments: + w = http_method(no_query=True,...)(f) + w(self, args, content) raises BadRequest if args is not empty + + Argument mismatches, deserialisation failures produce BadRequest. + """ + content_as_args = control.pop('content_as_args', False) + no_query = control.pop('no_query', False) + conversions = control.items() + + def wrap(f): + argspec = inspect.getargspec(f) + assert argspec.args[0] == "self" + nargs = len(argspec.args) + ndefaults = len(argspec.defaults or ()) + required_args = set(argspec.args[1:nargs - ndefaults]) + all_args = set(argspec.args) + + @functools.wraps(f) + def wrapper(self, args, content): + if no_query and args: + raise BadRequest() + if content is not None: + if content_as_args: + try: + args.update(json.loads(content)) + except ValueError: + raise BadRequest() + else: + args["content"] = content + if not (required_args <= set(args) <= all_args): + raise BadRequest("Missing required arguments.") + for name, conv in conversions: + if name not in args: + continue + try: + args[name] = conv(args[name]) + except ValueError: + raise BadRequest() + return f(self, **args) + + return wrapper + + return wrap + + +class URLToResource(object): + """Mappings from URLs to resources.""" + + def __init__(self): + self._map = routes.mapper.Mapper(controller_scan=None) + + def register(self, resource_cls): + # register + self._map.connect(None, resource_cls.url_pattern, + resource_cls=resource_cls, + requirements={"dbname": DBNAME_CONSTRAINTS}) + self._map.create_regs() + return resource_cls + + def match(self, path): + params = self._map.match(path) + if params is None: + return None, None + resource_cls = params.pop('resource_cls') + return resource_cls, params + +url_to_resource = URLToResource() + + +@url_to_resource.register +class GlobalResource(object): + """Global (root) resource.""" + + url_pattern = "/" + + def __init__(self, state, responder): + self.responder = responder + + @http_method() + def get(self): + self.responder.send_response_json(version=_u1db_version) + + +@url_to_resource.register +class DatabaseResource(object): + """Database resource.""" + + url_pattern = "/{dbname}" + + def __init__(self, dbname, state, responder): + self.dbname = dbname + self.state = state + self.responder = responder + + @http_method() + def get(self): + self.state.check_database(self.dbname) + self.responder.send_response_json(200) + + @http_method(content_as_args=True) + def put(self): + self.state.ensure_database(self.dbname) + self.responder.send_response_json(200, ok=True) + + @http_method() + def delete(self): + self.state.delete_database(self.dbname) + self.responder.send_response_json(200, ok=True) + + +@url_to_resource.register +class DocsResource(object): + """Documents resource.""" + + url_pattern = "/{dbname}/docs" + + def __init__(self, dbname, state, responder): + self.responder = responder + self.db = state.open_database(dbname) + + @http_method(doc_ids=parse_list, check_for_conflicts=parse_bool, + include_deleted=parse_bool) + def get(self, doc_ids=None, check_for_conflicts=True, + include_deleted=False): + if doc_ids is None: + raise errors.MissingDocIds + docs = self.db.get_docs(doc_ids, include_deleted=include_deleted) + self.responder.content_type = 'application/json' + self.responder.start_response(200) + self.responder.start_stream(), + for doc in docs: + entry = dict( + doc_id=doc.doc_id, doc_rev=doc.rev, content=doc.get_json(), + has_conflicts=doc.has_conflicts) + self.responder.stream_entry(entry) + self.responder.end_stream() + self.responder.finish_response() + + +@url_to_resource.register +class DocResource(object): + """Document resource.""" + + url_pattern = "/{dbname}/doc/{id:.*}" + + def __init__(self, dbname, id, state, responder): + self.id = id + self.responder = responder + self.db = state.open_database(dbname) + + @http_method(old_rev=str) + def put(self, content, old_rev=None): + doc = Document(self.id, old_rev, content) + doc_rev = self.db.put_doc(doc) + if old_rev is None: + status = 201 # created + else: + status = 200 + self.responder.send_response_json(status, rev=doc_rev) + + @http_method(old_rev=str) + def delete(self, old_rev=None): + doc = Document(self.id, old_rev, None) + self.db.delete_doc(doc) + self.responder.send_response_json(200, rev=doc.rev) + + @http_method(include_deleted=parse_bool) + def get(self, include_deleted=False): + doc = self.db.get_doc(self.id, include_deleted=include_deleted) + if doc is None: + wire_descr = errors.DocumentDoesNotExist.wire_description + self.responder.send_response_json( + http_errors.wire_description_to_status[wire_descr], + error=wire_descr, + headers={ + 'x-u1db-rev': '', + 'x-u1db-has-conflicts': 'false' + }) + return + headers = { + 'x-u1db-rev': doc.rev, + 'x-u1db-has-conflicts': json.dumps(doc.has_conflicts) + } + if doc.is_tombstone(): + self.responder.send_response_json( + http_errors.wire_description_to_status[ + errors.DOCUMENT_DELETED], + error=errors.DOCUMENT_DELETED, + headers=headers) + else: + self.responder.send_response_content( + doc.get_json(), headers=headers) + + +@url_to_resource.register +class SyncResource(object): + """Sync endpoint resource.""" + + # maximum allowed request body size + max_request_size = 15 * 1024 * 1024 # 15Mb + # maximum allowed entry/line size in request body + max_entry_size = 10 * 1024 * 1024 # 10Mb + + url_pattern = "/{dbname}/sync-from/{source_replica_uid}" + + # pluggable + sync_exchange_class = sync.SyncExchange + + def __init__(self, dbname, source_replica_uid, state, responder): + self.source_replica_uid = source_replica_uid + self.responder = responder + self.state = state + self.dbname = dbname + self.replica_uid = None + + def get_target(self): + return self.state.open_database(self.dbname).get_sync_target() + + @http_method() + def get(self): + result = self.get_target().get_sync_info(self.source_replica_uid) + self.responder.send_response_json( + target_replica_uid=result[0], target_replica_generation=result[1], + target_replica_transaction_id=result[2], + source_replica_uid=self.source_replica_uid, + source_replica_generation=result[3], + source_transaction_id=result[4]) + + @http_method(generation=int, + content_as_args=True, no_query=True) + def put(self, generation, transaction_id): + self.get_target().record_sync_info(self.source_replica_uid, + generation, + transaction_id) + self.responder.send_response_json(ok=True) + + # Implements the same logic as LocalSyncTarget.sync_exchange + + @http_method(last_known_generation=int, last_known_trans_id=none_or_str, + content_as_args=True) + def post_args(self, last_known_generation, last_known_trans_id=None, + ensure=False): + if ensure: + db, self.replica_uid = self.state.ensure_database(self.dbname) + else: + db = self.state.open_database(self.dbname) + db.validate_gen_and_trans_id( + last_known_generation, last_known_trans_id) + self.sync_exch = self.sync_exchange_class( + db, self.source_replica_uid, last_known_generation) + + @http_method(content_as_args=True) + def post_stream_entry(self, id, rev, content, gen, trans_id): + doc = Document(id, rev, content) + self.sync_exch.insert_doc_from_source(doc, gen, trans_id) + + def post_end(self): + + def send_doc(doc, gen, trans_id): + entry = dict(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), + gen=gen, trans_id=trans_id) + self.responder.stream_entry(entry) + + new_gen = self.sync_exch.find_changes_to_return() + self.responder.content_type = 'application/x-u1db-sync-stream' + self.responder.start_response(200) + self.responder.start_stream(), + header = {"new_generation": new_gen, + "new_transaction_id": self.sync_exch.new_trans_id} + if self.replica_uid is not None: + header['replica_uid'] = self.replica_uid + self.responder.stream_entry(header) + self.sync_exch.return_docs(send_doc) + self.responder.end_stream() + self.responder.finish_response() + + +class HTTPResponder(object): + """Encode responses from the server back to the client.""" + + # a multi document response will put args and documents + # each on one line of the response body + + def __init__(self, start_response): + self._started = False + self._stream_state = -1 + self._no_initial_obj = True + self.sent_response = False + self._start_response = start_response + self._write = None + self.content_type = 'application/json' + self.content = [] + + def start_response(self, status, obj_dic=None, headers={}): + """start sending response with optional first json object.""" + if self._started: + return + self._started = True + status_text = httplib.responses[status] + self._write = self._start_response('%d %s' % (status, status_text), + [('content-type', self.content_type), + ('cache-control', 'no-cache')] + + headers.items()) + # xxx version in headers + if obj_dic is not None: + self._no_initial_obj = False + self._write(json.dumps(obj_dic) + "\r\n") + + def finish_response(self): + """finish sending response.""" + self.sent_response = True + + def send_response_json(self, status=200, headers={}, **kwargs): + """send and finish response with json object body from keyword args.""" + content = json.dumps(kwargs) + "\r\n" + self.send_response_content(content, headers=headers, status=status) + + def send_response_content(self, content, status=200, headers={}): + """send and finish response with content""" + headers['content-length'] = str(len(content)) + self.start_response(status, headers=headers) + if self._stream_state == 1: + self.content = [',\r\n', content] + else: + self.content = [content] + self.finish_response() + + def start_stream(self): + "start stream (array) as part of the response." + assert self._started and self._no_initial_obj + self._stream_state = 0 + self._write("[") + + def stream_entry(self, entry): + "send stream entry as part of the response." + assert self._stream_state != -1 + if self._stream_state == 0: + self._stream_state = 1 + self._write('\r\n') + else: + self._write(',\r\n') + self._write(json.dumps(entry)) + + def end_stream(self): + "end stream (array)." + assert self._stream_state != -1 + self._write("\r\n]\r\n") + + +class HTTPInvocationByMethodWithBody(object): + """Invoke methods on a resource.""" + + def __init__(self, resource, environ, parameters): + self.resource = resource + self.environ = environ + self.max_request_size = getattr( + resource, 'max_request_size', parameters.max_request_size) + self.max_entry_size = getattr( + resource, 'max_entry_size', parameters.max_entry_size) + + def _lookup(self, method): + try: + return getattr(self.resource, method) + except AttributeError: + raise BadRequest() + + def __call__(self): + args = urlparse.parse_qsl(self.environ['QUERY_STRING'], + strict_parsing=False) + try: + args = dict( + (k.decode('utf-8'), v.decode('utf-8')) for k, v in args) + except ValueError: + raise BadRequest() + method = self.environ['REQUEST_METHOD'].lower() + if method in ('get', 'delete'): + meth = self._lookup(method) + return meth(args, None) + else: + # we expect content-length > 0, reconsider if we move + # to support chunked enconding + try: + content_length = int(self.environ['CONTENT_LENGTH']) + except (ValueError, KeyError): + raise BadRequest + if content_length <= 0: + raise BadRequest + if content_length > self.max_request_size: + raise BadRequest + reader = _FencedReader(self.environ['wsgi.input'], content_length, + self.max_entry_size) + content_type = self.environ.get('CONTENT_TYPE') + if content_type == 'application/json': + meth = self._lookup(method) + body = reader.read_chunk(sys.maxint) + return meth(args, body) + elif content_type == 'application/x-u1db-sync-stream': + meth_args = self._lookup('%s_args' % method) + meth_entry = self._lookup('%s_stream_entry' % method) + meth_end = self._lookup('%s_end' % method) + body_getline = reader.getline + if body_getline().strip() != '[': + raise BadRequest() + line = body_getline() + line, comma = utils.check_and_strip_comma(line.strip()) + meth_args(args, line) + while True: + line = body_getline() + entry = line.strip() + if entry == ']': + break + if not entry or not comma: # empty or no prec comma + raise BadRequest + entry, comma = utils.check_and_strip_comma(entry) + meth_entry({}, entry) + if comma or body_getline(): # extra comma or data + raise BadRequest + return meth_end() + else: + raise BadRequest() + + +class HTTPApp(object): + + # maximum allowed request body size + max_request_size = 15 * 1024 * 1024 # 15Mb + # maximum allowed entry/line size in request body + max_entry_size = 10 * 1024 * 1024 # 10Mb + + def __init__(self, state): + self.state = state + + def _lookup_resource(self, environ, responder): + resource_cls, params = url_to_resource.match(environ['PATH_INFO']) + if resource_cls is None: + raise BadRequest # 404 instead? + resource = resource_cls( + state=self.state, responder=responder, **params) + return resource + + def __call__(self, environ, start_response): + responder = HTTPResponder(start_response) + self.request_begin(environ) + try: + resource = self._lookup_resource(environ, responder) + HTTPInvocationByMethodWithBody(resource, environ, self)() + except errors.U1DBError, e: + self.request_u1db_error(environ, e) + status = http_errors.wire_description_to_status.get( + e.wire_description, 500) + responder.send_response_json(status, error=e.wire_description) + except BadRequest: + self.request_bad_request(environ) + responder.send_response_json(400, error="bad request") + except KeyboardInterrupt: + raise + except: + self.request_failed(environ) + raise + else: + self.request_done(environ) + return responder.content + + # hooks for tracing requests + + def request_begin(self, environ): + """Hook called at the beginning of processing a request.""" + pass + + def request_done(self, environ): + """Hook called when done processing a request.""" + pass + + def request_u1db_error(self, environ, exc): + """Hook called when processing a request resulted in a U1DBError. + + U1DBError passed as exc. + """ + pass + + def request_bad_request(self, environ): + """Hook called when processing a bad request. + + No actual processing was done. + """ + pass + + def request_failed(self, environ): + """Hook called when processing a request failed unexpectedly. + + Invoked from an except block, so there's interpreter exception + information available. + """ + pass diff --git a/src/leap/soledad/u1db/remote/http_client.py b/src/leap/soledad/u1db/remote/http_client.py new file mode 100644 index 00000000..decddda3 --- /dev/null +++ b/src/leap/soledad/u1db/remote/http_client.py @@ -0,0 +1,218 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Base class to make requests to a remote HTTP server.""" + +import httplib +from oauth import oauth +try: + import simplejson as json +except ImportError: + import json # noqa +import socket +import ssl +import sys +import urlparse +import urllib + +from time import sleep +from u1db import ( + errors, + ) +from u1db.remote import ( + http_errors, + ) + +from u1db.remote.ssl_match_hostname import ( # noqa + CertificateError, + match_hostname, + ) + +# Ubuntu/debian +# XXX other... +CA_CERTS = "/etc/ssl/certs/ca-certificates.crt" + + +def _encode_query_parameter(value): + """Encode query parameter.""" + if isinstance(value, bool): + if value: + value = 'true' + else: + value = 'false' + return unicode(value).encode('utf-8') + + +class _VerifiedHTTPSConnection(httplib.HTTPSConnection): + """HTTPSConnection verifying server side certificates.""" + # derived from httplib.py + + def connect(self): + "Connect to a host on a given (SSL) port." + + sock = socket.create_connection((self.host, self.port), + self.timeout, self.source_address) + if self._tunnel_host: + self.sock = sock + self._tunnel() + if sys.platform.startswith('linux'): + cert_opts = { + 'cert_reqs': ssl.CERT_REQUIRED, + 'ca_certs': CA_CERTS + } + else: + # XXX no cert verification implemented elsewhere for now + cert_opts = {} + self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, + ssl_version=ssl.PROTOCOL_SSLv3, + **cert_opts + ) + if cert_opts: + match_hostname(self.sock.getpeercert(), self.host) + + +class HTTPClientBase(object): + """Base class to make requests to a remote HTTP server.""" + + # by default use HMAC-SHA1 OAuth signature method to not disclose + # tokens + # NB: given that the content bodies are not covered by the + # signatures though, to achieve security (against man-in-the-middle + # attacks for example) one would need HTTPS + oauth_signature_method = oauth.OAuthSignatureMethod_HMAC_SHA1() + + # Will use these delays to retry on 503 befor finally giving up. The final + # 0 is there to not wait after the final try fails. + _delays = (1, 1, 2, 4, 0) + + def __init__(self, url, creds=None): + self._url = urlparse.urlsplit(url) + self._conn = None + self._creds = {} + if creds is not None: + if len(creds) != 1: + raise errors.UnknownAuthMethod() + auth_meth, credentials = creds.items()[0] + try: + set_creds = getattr(self, 'set_%s_credentials' % auth_meth) + except AttributeError: + raise errors.UnknownAuthMethod(auth_meth) + set_creds(**credentials) + + def set_oauth_credentials(self, consumer_key, consumer_secret, + token_key, token_secret): + self._creds = {'oauth': ( + oauth.OAuthConsumer(consumer_key, consumer_secret), + oauth.OAuthToken(token_key, token_secret))} + + def _ensure_connection(self): + if self._conn is not None: + return + if self._url.scheme == 'https': + connClass = _VerifiedHTTPSConnection + else: + connClass = httplib.HTTPConnection + self._conn = connClass(self._url.hostname, self._url.port) + + def close(self): + if self._conn: + self._conn.close() + self._conn = None + + # xxx retry mechanism? + + def _error(self, respdic): + descr = respdic.get("error") + exc_cls = errors.wire_description_to_exc.get(descr) + if exc_cls is not None: + message = respdic.get("message") + raise exc_cls(message) + + def _response(self): + resp = self._conn.getresponse() + body = resp.read() + headers = dict(resp.getheaders()) + if resp.status in (200, 201): + return body, headers + elif resp.status in http_errors.ERROR_STATUSES: + try: + respdic = json.loads(body) + except ValueError: + pass + else: + self._error(respdic) + # special case + if resp.status == 503: + raise errors.Unavailable(body, headers) + raise errors.HTTPError(resp.status, body, headers) + + def _sign_request(self, method, url_query, params): + if 'oauth' in self._creds: + consumer, token = self._creds['oauth'] + full_url = "%s://%s%s" % (self._url.scheme, self._url.netloc, + url_query) + oauth_req = oauth.OAuthRequest.from_consumer_and_token( + consumer, token, + http_method=method, + parameters=params, + http_url=full_url + ) + oauth_req.sign_request( + self.oauth_signature_method, consumer, token) + # Authorization: OAuth ... + return oauth_req.to_header().items() + else: + return [] + + def _request(self, method, url_parts, params=None, body=None, + content_type=None): + self._ensure_connection() + unquoted_url = url_query = self._url.path + if url_parts: + if not url_query.endswith('/'): + url_query += '/' + unquoted_url = url_query + url_query += '/'.join(urllib.quote(part, safe='') + for part in url_parts) + # oauth performs its own quoting + unquoted_url += '/'.join(url_parts) + encoded_params = {} + if params: + for key, value in params.items(): + key = unicode(key).encode('utf-8') + encoded_params[key] = _encode_query_parameter(value) + url_query += ('?' + urllib.urlencode(encoded_params)) + if body is not None and not isinstance(body, basestring): + body = json.dumps(body) + content_type = 'application/json' + headers = {} + if content_type: + headers['content-type'] = content_type + headers.update( + self._sign_request(method, unquoted_url, encoded_params)) + for delay in self._delays: + try: + self._conn.request(method, url_query, body, headers) + return self._response() + except errors.Unavailable, e: + sleep(delay) + raise e + + def _request_json(self, method, url_parts, params=None, body=None, + content_type=None): + res, headers = self._request(method, url_parts, params, body, + content_type) + return json.loads(res), headers diff --git a/src/leap/soledad/u1db/remote/http_database.py b/src/leap/soledad/u1db/remote/http_database.py new file mode 100644 index 00000000..6901baad --- /dev/null +++ b/src/leap/soledad/u1db/remote/http_database.py @@ -0,0 +1,143 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""HTTPDatabase to access a remote db over the HTTP API.""" + +try: + import simplejson as json +except ImportError: + import json # noqa +import uuid + +from u1db import ( + Database, + Document, + errors, + ) +from u1db.remote import ( + http_client, + http_errors, + http_target, + ) + + +DOCUMENT_DELETED_STATUS = http_errors.wire_description_to_status[ + errors.DOCUMENT_DELETED] + + +class HTTPDatabase(http_client.HTTPClientBase, Database): + """Implement the Database API to a remote HTTP server.""" + + def __init__(self, url, document_factory=None, creds=None): + super(HTTPDatabase, self).__init__(url, creds=creds) + self._factory = document_factory or Document + + def set_document_factory(self, factory): + self._factory = factory + + @staticmethod + def open_database(url, create): + db = HTTPDatabase(url) + db.open(create) + return db + + @staticmethod + def delete_database(url): + db = HTTPDatabase(url) + db._delete() + db.close() + + def open(self, create): + if create: + self._ensure() + else: + self._check() + + def _check(self): + return self._request_json('GET', [])[0] + + def _ensure(self): + self._request_json('PUT', [], {}, {}) + + def _delete(self): + self._request_json('DELETE', [], {}, {}) + + def put_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + params = {} + if doc.rev is not None: + params['old_rev'] = doc.rev + res, headers = self._request_json('PUT', ['doc', doc.doc_id], params, + doc.get_json(), 'application/json') + doc.rev = res['rev'] + return res['rev'] + + def get_doc(self, doc_id, include_deleted=False): + try: + res, headers = self._request( + 'GET', ['doc', doc_id], {"include_deleted": include_deleted}) + except errors.DocumentDoesNotExist: + return None + except errors.HTTPError, e: + if (e.status == DOCUMENT_DELETED_STATUS and + 'x-u1db-rev' in e.headers): + res = None + headers = e.headers + else: + raise + doc_rev = headers['x-u1db-rev'] + has_conflicts = json.loads(headers['x-u1db-has-conflicts']) + doc = self._factory(doc_id, doc_rev, res) + doc.has_conflicts = has_conflicts + return doc + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + if not doc_ids: + return + doc_ids = ','.join(doc_ids) + res, headers = self._request( + 'GET', ['docs'], { + "doc_ids": doc_ids, "include_deleted": include_deleted, + "check_for_conflicts": check_for_conflicts}) + for doc_dict in json.loads(res): + doc = self._factory( + doc_dict['doc_id'], doc_dict['doc_rev'], doc_dict['content']) + doc.has_conflicts = doc_dict['has_conflicts'] + yield doc + + def create_doc_from_json(self, content, doc_id=None): + if doc_id is None: + doc_id = 'D-%s' % (uuid.uuid4().hex,) + res, headers = self._request_json('PUT', ['doc', doc_id], {}, + content, 'application/json') + new_doc = self._factory(doc_id, res['rev'], content) + return new_doc + + def delete_doc(self, doc): + if doc.doc_id is None: + raise errors.InvalidDocId() + params = {'old_rev': doc.rev} + res, headers = self._request_json('DELETE', + ['doc', doc.doc_id], params) + doc.make_tombstone() + doc.rev = res['rev'] + + def get_sync_target(self): + st = http_target.HTTPSyncTarget(self._url.geturl()) + st._creds = self._creds + return st diff --git a/src/leap/soledad/u1db/remote/http_errors.py b/src/leap/soledad/u1db/remote/http_errors.py new file mode 100644 index 00000000..2039c5b2 --- /dev/null +++ b/src/leap/soledad/u1db/remote/http_errors.py @@ -0,0 +1,46 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Information about the encoding of errors over HTTP.""" + +from u1db import ( + errors, + ) + + +# error wire descriptions mapping to HTTP status codes +wire_description_to_status = dict([ + (errors.InvalidDocId.wire_description, 400), + (errors.MissingDocIds.wire_description, 400), + (errors.Unauthorized.wire_description, 401), + (errors.DocumentTooBig.wire_description, 403), + (errors.UserQuotaExceeded.wire_description, 403), + (errors.SubscriptionNeeded.wire_description, 403), + (errors.DatabaseDoesNotExist.wire_description, 404), + (errors.DocumentDoesNotExist.wire_description, 404), + (errors.DocumentAlreadyDeleted.wire_description, 404), + (errors.RevisionConflict.wire_description, 409), + (errors.InvalidGeneration.wire_description, 409), + (errors.InvalidTransactionId.wire_description, 409), + (errors.Unavailable.wire_description, 503), +# without matching exception + (errors.DOCUMENT_DELETED, 404) +]) + + +ERROR_STATUSES = set(wire_description_to_status.values()) +# 400 included explicitly for tests +ERROR_STATUSES.add(400) diff --git a/src/leap/soledad/u1db/remote/http_target.py b/src/leap/soledad/u1db/remote/http_target.py new file mode 100644 index 00000000..1028963e --- /dev/null +++ b/src/leap/soledad/u1db/remote/http_target.py @@ -0,0 +1,135 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""SyncTarget API implementation to a remote HTTP server.""" + +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db import ( + Document, + SyncTarget, + ) +from u1db.errors import ( + BrokenSyncStream, + ) +from u1db.remote import ( + http_client, + utils, + ) + + +class HTTPSyncTarget(http_client.HTTPClientBase, SyncTarget): + """Implement the SyncTarget api to a remote HTTP server.""" + + @staticmethod + def connect(url): + return HTTPSyncTarget(url) + + def get_sync_info(self, source_replica_uid): + self._ensure_connection() + res, _ = self._request_json('GET', ['sync-from', source_replica_uid]) + return (res['target_replica_uid'], res['target_replica_generation'], + res['target_replica_transaction_id'], + res['source_replica_generation'], res['source_transaction_id']) + + def record_sync_info(self, source_replica_uid, source_replica_generation, + source_transaction_id): + self._ensure_connection() + if self._trace_hook: # for tests + self._trace_hook('record_sync_info') + self._request_json('PUT', ['sync-from', source_replica_uid], {}, + {'generation': source_replica_generation, + 'transaction_id': source_transaction_id}) + + def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None): + parts = data.splitlines() # one at a time + if not parts or parts[0] != '[': + raise BrokenSyncStream + data = parts[1:-1] + comma = False + if data: + line, comma = utils.check_and_strip_comma(data[0]) + res = json.loads(line) + if ensure_callback and 'replica_uid' in res: + ensure_callback(res['replica_uid']) + for entry in data[1:]: + if not comma: # missing in between comma + raise BrokenSyncStream + line, comma = utils.check_and_strip_comma(entry) + entry = json.loads(line) + doc = Document(entry['id'], entry['rev'], entry['content']) + return_doc_cb(doc, entry['gen'], entry['trans_id']) + if parts[-1] != ']': + try: + partdic = json.loads(parts[-1]) + except ValueError: + pass + else: + if isinstance(partdic, dict): + self._error(partdic) + raise BrokenSyncStream + if not data or comma: # no entries or bad extra comma + raise BrokenSyncStream + return res + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + self._ensure_connection() + if self._trace_hook: # for tests + self._trace_hook('sync_exchange') + url = '%s/sync-from/%s' % (self._url.path, source_replica_uid) + self._conn.putrequest('POST', url) + self._conn.putheader('content-type', 'application/x-u1db-sync-stream') + for header_name, header_value in self._sign_request('POST', url, {}): + self._conn.putheader(header_name, header_value) + entries = ['['] + size = 1 + + def prepare(**dic): + entry = comma + '\r\n' + json.dumps(dic) + entries.append(entry) + return len(entry) + + comma = '' + size += prepare( + last_known_generation=last_known_generation, + last_known_trans_id=last_known_trans_id, + ensure=ensure_callback is not None) + comma = ',' + for doc, gen, trans_id in docs_by_generations: + size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_json(), + gen=gen, trans_id=trans_id) + entries.append('\r\n]') + size += len(entries[-1]) + self._conn.putheader('content-length', str(size)) + self._conn.endheaders() + for entry in entries: + self._conn.send(entry) + entries = None + data, _ = self._response() + res = self._parse_sync_stream(data, return_doc_cb, ensure_callback) + data = None + return res['new_generation'], res['new_transaction_id'] + + # for tests + _trace_hook = None + + def _set_trace_hook_shallow(self, cb): + self._trace_hook = cb diff --git a/src/leap/soledad/u1db/remote/oauth_middleware.py b/src/leap/soledad/u1db/remote/oauth_middleware.py new file mode 100644 index 00000000..5772580a --- /dev/null +++ b/src/leap/soledad/u1db/remote/oauth_middleware.py @@ -0,0 +1,89 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. +"""U1DB OAuth authorisation WSGI middleware.""" +import httplib +from oauth import oauth +try: + import simplejson as json +except ImportError: + import json # noqa +from urllib import quote +from wsgiref.util import shift_path_info + + +sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1() +sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT() + + +class OAuthMiddleware(object): + """U1DB OAuth Authorisation WSGI middleware.""" + + # max seconds the request timestamp is allowed to be shifted + # from arrival time + timestamp_threshold = 300 + + def __init__(self, app, base_url, prefix='/~/'): + self.app = app + self.base_url = base_url + self.prefix = prefix + + def get_oauth_data_store(self): + """Provide a oauth.OAuthDataStore.""" + raise NotImplementedError(self.get_oauth_data_store) + + def _error(self, start_response, status, description, message=None): + start_response("%d %s" % (status, httplib.responses[status]), + [('content-type', 'application/json')]) + err = {"error": description} + if message: + err['message'] = message + return [json.dumps(err)] + + def __call__(self, environ, start_response): + if self.prefix and not environ['PATH_INFO'].startswith(self.prefix): + return self._error(start_response, 400, "bad request") + headers = {} + if 'HTTP_AUTHORIZATION' in environ: + headers['Authorization'] = environ['HTTP_AUTHORIZATION'] + oauth_req = oauth.OAuthRequest.from_request( + http_method=environ['REQUEST_METHOD'], + http_url=self.base_url + environ['PATH_INFO'], + headers=headers, + query_string=environ['QUERY_STRING'] + ) + if oauth_req is None: + return self._error(start_response, 401, "unauthorized", + "Missing OAuth.") + try: + self.verify(environ, oauth_req) + except oauth.OAuthError, e: + return self._error(start_response, 401, "unauthorized", + e.message) + shift_path_info(environ) + return self.app(environ, start_response) + + def verify(self, environ, oauth_req): + """Verify OAuth request, put user_id in the environ.""" + oauth_server = oauth.OAuthServer(self.get_oauth_data_store()) + oauth_server.timestamp_threshold = self.timestamp_threshold + oauth_server.add_signature_method(sign_meth_HMAC_SHA1) + oauth_server.add_signature_method(sign_meth_PLAINTEXT) + consumer, token, parameters = oauth_server.verify_request(oauth_req) + # filter out oauth bits + environ['QUERY_STRING'] = '&'.join("%s=%s" % (quote(k, safe=''), + quote(v, safe='')) + for k, v in parameters.iteritems()) + return consumer, token diff --git a/src/leap/soledad/u1db/remote/server_state.py b/src/leap/soledad/u1db/remote/server_state.py new file mode 100644 index 00000000..96581359 --- /dev/null +++ b/src/leap/soledad/u1db/remote/server_state.py @@ -0,0 +1,67 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""State for servers exposing a set of U1DB databases.""" +import os +import errno + +class ServerState(object): + """Passed to a Request when it is instantiated. + + This is used to track server-side state, such as working-directory, open + databases, etc. + """ + + def __init__(self): + self._workingdir = None + + def set_workingdir(self, path): + self._workingdir = path + + def _relpath(self, relpath): + # Note: We don't want to allow absolute paths here, because we + # don't want to expose the filesystem. We should also check that + # relpath doesn't have '..' in it, etc. + return self._workingdir + '/' + relpath + + def open_database(self, path): + """Open a database at the given location.""" + from u1db.backends import sqlite_backend + full_path = self._relpath(path) + return sqlite_backend.SQLiteDatabase.open_database(full_path, + create=False) + + def check_database(self, path): + """Check if the database at the given location exists. + + Simply returns if it does or raises DatabaseDoesNotExist. + """ + db = self.open_database(path) + db.close() + + def ensure_database(self, path): + """Ensure database at the given location.""" + from u1db.backends import sqlite_backend + full_path = self._relpath(path) + db = sqlite_backend.SQLiteDatabase.open_database(full_path, + create=True) + return db, db._replica_uid + + def delete_database(self, path): + """Delete database at the given location.""" + from u1db.backends import sqlite_backend + full_path = self._relpath(path) + sqlite_backend.SQLiteDatabase.delete_database(full_path) diff --git a/src/leap/soledad/u1db/remote/ssl_match_hostname.py b/src/leap/soledad/u1db/remote/ssl_match_hostname.py new file mode 100644 index 00000000..fbabc177 --- /dev/null +++ b/src/leap/soledad/u1db/remote/ssl_match_hostname.py @@ -0,0 +1,64 @@ +"""The match_hostname() function from Python 3.2, essential when using SSL.""" +# XXX put it here until it's packaged + +import re + +__version__ = '3.2a3' + + +class CertificateError(ValueError): + pass + + +def _dnsname_to_pat(dn): + pats = [] + for frag in dn.split(r'.'): + if frag == '*': + # When '*' is a fragment by itself, it matches a non-empty dotless + # fragment. + pats.append('[^.]+') + else: + # Otherwise, '*' matches any dotless fragment. + frag = re.escape(frag) + pats.append(frag.replace(r'\*', '[^.]*')) + return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) + + +def match_hostname(cert, hostname): + """Verify that *cert* (in decoded format as returned by + SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules + are mostly followed, but IP addresses are not accepted for *hostname*. + + CertificateError is raised on failure. On success, the function + returns nothing. + """ + if not cert: + raise ValueError("empty or no certificate") + dnsnames = [] + san = cert.get('subjectAltName', ()) + for key, value in san: + if key == 'DNS': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if not san: + # The subject is only checked when subjectAltName is empty + for sub in cert.get('subject', ()): + for key, value in sub: + # XXX according to RFC 2818, the most specific Common Name + # must be used. + if key == 'commonName': + if _dnsname_to_pat(value).match(hostname): + return + dnsnames.append(value) + if len(dnsnames) > 1: + raise CertificateError("hostname %r " + "doesn't match either of %s" + % (hostname, ', '.join(map(repr, dnsnames)))) + elif len(dnsnames) == 1: + raise CertificateError("hostname %r " + "doesn't match %r" + % (hostname, dnsnames[0])) + else: + raise CertificateError("no appropriate commonName or " + "subjectAltName fields were found") diff --git a/src/leap/soledad/u1db/remote/utils.py b/src/leap/soledad/u1db/remote/utils.py new file mode 100644 index 00000000..14cedea9 --- /dev/null +++ b/src/leap/soledad/u1db/remote/utils.py @@ -0,0 +1,23 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Utilities for details of the procotol.""" + + +def check_and_strip_comma(line): + if line and line[-1] == ',': + return line[:-1], True + return line, False diff --git a/src/leap/soledad/u1db/sync.py b/src/leap/soledad/u1db/sync.py new file mode 100644 index 00000000..3375d097 --- /dev/null +++ b/src/leap/soledad/u1db/sync.py @@ -0,0 +1,304 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""The synchronization utilities for U1DB.""" +from itertools import izip + +import u1db +from u1db import errors + + +class Synchronizer(object): + """Collect the state around synchronizing 2 U1DB replicas. + + Synchronization is bi-directional, in that new items in the source are sent + to the target, and new items in the target are returned to the source. + However, it still recognizes that one side is initiating the request. Also, + at the moment, conflicts are only created in the source. + """ + + def __init__(self, source, sync_target): + """Create a new Synchronization object. + + :param source: A Database + :param sync_target: A SyncTarget + """ + self.source = source + self.sync_target = sync_target + self.target_replica_uid = None + self.num_inserted = 0 + + def _insert_doc_from_target(self, doc, replica_gen, trans_id): + """Try to insert synced document from target. + + Implements TAKE OTHER semantics: any document from the target + that is in conflict will be taken as the new official value, + while the current conflicting value will be stored alongside + as a conflict. In the process indexes will be updated etc. + + :return: None + """ + # Increases self.num_inserted depending whether the document + # was effectively inserted. + state, _ = self.source._put_doc_if_newer(doc, save_conflict=True, + replica_uid=self.target_replica_uid, replica_gen=replica_gen, + replica_trans_id=trans_id) + if state == 'inserted': + self.num_inserted += 1 + elif state == 'converged': + # magical convergence + pass + elif state == 'superseded': + # we have something newer, will be taken care of at the next sync + pass + else: + assert state == 'conflicted' + # The doc was saved as a conflict, so the database was updated + self.num_inserted += 1 + + def _record_sync_info_with_the_target(self, start_generation): + """Record our new after sync generation with the target if gapless. + + Any documents received from the target will cause the local + database to increment its generation. We do not want to send + them back to the target in a future sync. However, there could + also be concurrent updates from another process doing eg + 'put_doc' while the sync was running. And we do want to + synchronize those documents. We can tell if there was a + concurrent update by comparing our new generation number + versus the generation we started, and how many documents we + inserted from the target. If it matches exactly, then we can + record with the target that they are fully up to date with our + new generation. + """ + cur_gen, trans_id = self.source._get_generation_info() + if (cur_gen == start_generation + self.num_inserted + and self.num_inserted > 0): + self.sync_target.record_sync_info( + self.source._replica_uid, cur_gen, trans_id) + + def sync(self, callback=None, autocreate=False): + """Synchronize documents between source and target.""" + sync_target = self.sync_target + # get target identifier, its current generation, + # and its last-seen database generation for this source + try: + (self.target_replica_uid, target_gen, target_trans_id, + target_my_gen, target_my_trans_id) = sync_target.get_sync_info( + self.source._replica_uid) + except errors.DatabaseDoesNotExist: + if not autocreate: + raise + # will try to ask sync_exchange() to create the db + self.target_replica_uid = None + target_gen, target_trans_id = 0, '' + target_my_gen, target_my_trans_id = 0, '' + def ensure_callback(replica_uid): + self.target_replica_uid = replica_uid + else: + ensure_callback = None + # validate the generation and transaction id the target knows about us + self.source.validate_gen_and_trans_id( + target_my_gen, target_my_trans_id) + # what's changed since that generation and this current gen + my_gen, _, changes = self.source.whats_changed(target_my_gen) + + # this source last-seen database generation for the target + if self.target_replica_uid is None: + target_last_known_gen, target_last_known_trans_id = 0, '' + else: + target_last_known_gen, target_last_known_trans_id = \ + self.source._get_replica_gen_and_trans_id(self.target_replica_uid) + if not changes and target_last_known_gen == target_gen: + if target_trans_id != target_last_known_trans_id: + raise errors.InvalidTransactionId + return my_gen + changed_doc_ids = [doc_id for doc_id, _, _ in changes] + # prepare to send all the changed docs + docs_to_send = self.source.get_docs(changed_doc_ids, + check_for_conflicts=False, include_deleted=True) + # TODO: there must be a way to not iterate twice + docs_by_generation = zip( + docs_to_send, (gen for _, gen, _ in changes), + (trans for _, _, trans in changes)) + + # exchange documents and try to insert the returned ones with + # the target, return target synced-up-to gen + new_gen, new_trans_id = sync_target.sync_exchange( + docs_by_generation, self.source._replica_uid, + target_last_known_gen, target_last_known_trans_id, + self._insert_doc_from_target, ensure_callback=ensure_callback) + # record target synced-up-to generation including applying what we sent + self.source._set_replica_gen_and_trans_id( + self.target_replica_uid, new_gen, new_trans_id) + + # if gapless record current reached generation with target + self._record_sync_info_with_the_target(my_gen) + + return my_gen + + +class SyncExchange(object): + """Steps and state for carrying through a sync exchange on a target.""" + + def __init__(self, db, source_replica_uid, last_known_generation): + self._db = db + self.source_replica_uid = source_replica_uid + self.source_last_known_generation = last_known_generation + self.seen_ids = {} # incoming ids not superseded + self.changes_to_return = None + self.new_gen = None + self.new_trans_id = None + # for tests + self._incoming_trace = [] + self._trace_hook = None + self._db._last_exchange_log = { + 'receive': {'docs': self._incoming_trace}, + 'return': None + } + + def _set_trace_hook(self, cb): + self._trace_hook = cb + + def _trace(self, state): + if not self._trace_hook: + return + self._trace_hook(state) + + def insert_doc_from_source(self, doc, source_gen, trans_id): + """Try to insert synced document from source. + + Conflicting documents are not inserted but will be sent over + to the sync source. + + It keeps track of progress by storing the document source + generation as well. + + The 1st step of a sync exchange is to call this repeatedly to + try insert all incoming documents from the source. + + :param doc: A Document object. + :param source_gen: The source generation of doc. + :return: None + """ + state, at_gen = self._db._put_doc_if_newer(doc, save_conflict=False, + replica_uid=self.source_replica_uid, replica_gen=source_gen, + replica_trans_id=trans_id) + if state == 'inserted': + self.seen_ids[doc.doc_id] = at_gen + elif state == 'converged': + # magical convergence + self.seen_ids[doc.doc_id] = at_gen + elif state == 'superseded': + # we have something newer that we will return + pass + else: + # conflict that we will returne + assert state == 'conflicted' + # for tests + self._incoming_trace.append((doc.doc_id, doc.rev)) + self._db._last_exchange_log['receive'].update({ + 'source_uid': self.source_replica_uid, + 'source_gen': source_gen + }) + + def find_changes_to_return(self): + """Find changes to return. + + Find changes since last_known_generation in db generation + order using whats_changed. It excludes documents ids that have + already been considered (superseded by the sender, etc). + + :return: new_generation - the generation of this database + which the caller can consider themselves to be synchronized after + processing the returned documents. + """ + self._db._last_exchange_log['receive'].update({ # for tests + 'last_known_gen': self.source_last_known_generation + }) + self._trace('before whats_changed') + gen, trans_id, changes = self._db.whats_changed( + self.source_last_known_generation) + self._trace('after whats_changed') + self.new_gen = gen + self.new_trans_id = trans_id + seen_ids = self.seen_ids + # changed docs that weren't superseded by or converged with + self.changes_to_return = [ + (doc_id, gen, trans_id) for (doc_id, gen, trans_id) in changes + # there was a subsequent update + if doc_id not in seen_ids or seen_ids.get(doc_id) < gen] + return self.new_gen + + def return_docs(self, return_doc_cb): + """Return the changed documents and their last change generation + repeatedly invoking the callback return_doc_cb. + + The final step of a sync exchange. + + :param: return_doc_cb(doc, gen, trans_id): is a callback + used to return the documents with their last change generation + to the target replica. + :return: None + """ + changes_to_return = self.changes_to_return + # return docs, including conflicts + changed_doc_ids = [doc_id for doc_id, _, _ in changes_to_return] + self._trace('before get_docs') + docs = self._db.get_docs( + changed_doc_ids, check_for_conflicts=False, include_deleted=True) + + docs_by_gen = izip( + docs, (gen for _, gen, _ in changes_to_return), + (trans_id for _, _, trans_id in changes_to_return)) + _outgoing_trace = [] # for tests + for doc, gen, trans_id in docs_by_gen: + return_doc_cb(doc, gen, trans_id) + _outgoing_trace.append((doc.doc_id, doc.rev)) + # for tests + self._db._last_exchange_log['return'] = { + 'docs': _outgoing_trace, + 'last_gen': self.new_gen + } + + +class LocalSyncTarget(u1db.SyncTarget): + """Common sync target implementation logic for all local sync targets.""" + + def __init__(self, db): + self._db = db + self._trace_hook = None + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + self._db.validate_gen_and_trans_id( + last_known_generation, last_known_trans_id) + sync_exch = SyncExchange( + self._db, source_replica_uid, last_known_generation) + if self._trace_hook: + sync_exch._set_trace_hook(self._trace_hook) + # 1st step: try to insert incoming docs and record progress + for doc, doc_gen, trans_id in docs_by_generations: + sync_exch.insert_doc_from_source(doc, doc_gen, trans_id) + # 2nd step: find changed documents (including conflicts) to return + new_gen = sync_exch.find_changes_to_return() + # final step: return docs and record source replica sync point + sync_exch.return_docs(return_doc_cb) + return new_gen, sync_exch.new_trans_id + + def _set_trace_hook(self, cb): + self._trace_hook = cb diff --git a/src/leap/soledad/u1db/tests/__init__.py b/src/leap/soledad/u1db/tests/__init__.py new file mode 100644 index 00000000..b8e16b15 --- /dev/null +++ b/src/leap/soledad/u1db/tests/__init__.py @@ -0,0 +1,463 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Test infrastructure for U1DB""" + +import copy +import shutil +import socket +import tempfile +import threading + +try: + import simplejson as json +except ImportError: + import json # noqa + +from wsgiref import simple_server + +from oauth import oauth +from sqlite3 import dbapi2 +from StringIO import StringIO + +import testscenarios +import testtools + +from u1db import ( + errors, + Document, + ) +from u1db.backends import ( + inmemory, + sqlite_backend, + ) +from u1db.remote import ( + server_state, + ) + +try: + from u1db.tests import c_backend_wrapper + c_backend_error = None +except ImportError, e: + c_backend_wrapper = None # noqa + c_backend_error = e + +# Setting this means that failing assertions will not include this module in +# their traceback. However testtools doesn't seem to set it, and we don't want +# this level to be omitted, but the lower levels to be shown. +# __unittest = 1 + + +class TestCase(testtools.TestCase): + + def createTempDir(self, prefix='u1db-tmp-'): + """Create a temporary directory to do some work in. + + This directory will be scheduled for cleanup when the test ends. + """ + tempdir = tempfile.mkdtemp(prefix=prefix) + self.addCleanup(shutil.rmtree, tempdir) + return tempdir + + def make_document(self, doc_id, doc_rev, content, has_conflicts=False): + return self.make_document_for_test( + self, doc_id, doc_rev, content, has_conflicts) + + def make_document_for_test(self, test, doc_id, doc_rev, content, + has_conflicts): + return make_document_for_test( + test, doc_id, doc_rev, content, has_conflicts) + + def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id)) + + def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content, + has_conflicts): + """Assert that the document in the database looks correct.""" + exp_doc = self.make_document(doc_id, doc_rev, content, + has_conflicts=has_conflicts) + self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True)) + + def assertGetDocConflicts(self, db, doc_id, conflicts): + """Assert what conflicts are stored for a given doc_id. + + :param conflicts: A list of (doc_rev, content) pairs. + The first item must match the first item returned from the + database, however the rest can be returned in any order. + """ + if conflicts: + conflicts = [(rev, (json.loads(cont) if isinstance(cont, basestring) + else cont)) for (rev, cont) in conflicts] + conflicts = conflicts[:1] + sorted(conflicts[1:]) + actual = db.get_doc_conflicts(doc_id) + if actual: + actual = [(doc.rev, (json.loads(doc.get_json()) + if doc.get_json() is not None else None)) for doc in actual] + actual = actual[:1] + sorted(actual[1:]) + self.assertEqual(conflicts, actual) + + +def multiply_scenarios(a_scenarios, b_scenarios): + """Create the cross-product of scenarios.""" + + all_scenarios = [] + for a_name, a_attrs in a_scenarios: + for b_name, b_attrs in b_scenarios: + name = '%s,%s' % (a_name, b_name) + attrs = dict(a_attrs) + attrs.update(b_attrs) + all_scenarios.append((name, attrs)) + return all_scenarios + + +simple_doc = '{"key": "value"}' +nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}' + + +def make_memory_database_for_test(test, replica_uid): + return inmemory.InMemoryDatabase(replica_uid) + + +def copy_memory_database_for_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. + new_db = inmemory.InMemoryDatabase(db._replica_uid) + new_db._transaction_log = db._transaction_log[:] + new_db._docs = copy.deepcopy(db._docs) + new_db._conflicts = copy.deepcopy(db._conflicts) + new_db._indexes = copy.deepcopy(db._indexes) + new_db._factory = db._factory + return new_db + + +def make_sqlite_partial_expanded_for_test(test, replica_uid): + db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + db._set_replica_uid(replica_uid) + return db + + +def copy_sqlite_partial_expanded_for_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. + new_db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + tmpfile = StringIO() + for line in db._db_handle.iterdump(): + if not 'sqlite_sequence' in line: # work around bug in iterdump + tmpfile.write('%s\n' % line) + tmpfile.seek(0) + new_db._db_handle = dbapi2.connect(':memory:') + new_db._db_handle.cursor().executescript(tmpfile.read()) + new_db._db_handle.commit() + new_db._set_replica_uid(db._replica_uid) + new_db._factory = db._factory + return new_db + + +def make_document_for_test(test, doc_id, rev, content, has_conflicts=False): + return Document(doc_id, rev, content, has_conflicts=has_conflicts) + + +def make_c_database_for_test(test, replica_uid): + if c_backend_wrapper is None: + test.skipTest('c_backend_wrapper is not available') + db = c_backend_wrapper.CDatabase(':memory:') + db._set_replica_uid(replica_uid) + return db + + +def copy_c_database_for_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. + if c_backend_wrapper is None: + test.skipTest('c_backend_wrapper is not available') + new_db = db._copy(db) + return new_db + + +def make_c_document_for_test(test, doc_id, rev, content, has_conflicts=False): + if c_backend_wrapper is None: + test.skipTest('c_backend_wrapper is not available') + return c_backend_wrapper.make_document( + doc_id, rev, content, has_conflicts=has_conflicts) + + +LOCAL_DATABASES_SCENARIOS = [ + ('mem', {'make_database_for_test': make_memory_database_for_test, + 'copy_database_for_test': copy_memory_database_for_test, + 'make_document_for_test': make_document_for_test}), + ('sql', {'make_database_for_test': + make_sqlite_partial_expanded_for_test, + 'copy_database_for_test': + copy_sqlite_partial_expanded_for_test, + 'make_document_for_test': make_document_for_test}), + ] + + +C_DATABASE_SCENARIOS = [ + ('c', {'make_database_for_test': make_c_database_for_test, + 'copy_database_for_test': copy_c_database_for_test, + 'make_document_for_test': make_c_document_for_test})] + + +class DatabaseBaseTests(TestCase): + + accept_fixed_trans_id = False # set to True assertTransactionLog + # is happy with all trans ids = '' + + scenarios = LOCAL_DATABASES_SCENARIOS + + def create_database(self, replica_uid): + return self.make_database_for_test(self, replica_uid) + + def copy_database(self, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES + # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST + # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS + # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND + # NINJA TO YOUR HOUSE. + return self.copy_database_for_test(self, db) + + def setUp(self): + super(DatabaseBaseTests, self).setUp() + self.db = self.create_database('test') + + def tearDown(self): + # TODO: Add close_database parameterization + # self.close_database(self.db) + super(DatabaseBaseTests, self).tearDown() + + def assertTransactionLog(self, doc_ids, db): + """Assert that the given docs are in the transaction log.""" + log = db._get_transaction_log() + just_ids = [] + seen_transactions = set() + for doc_id, transaction_id in log: + just_ids.append(doc_id) + self.assertIsNot(None, transaction_id, + "Transaction id should not be None") + if transaction_id == '' and self.accept_fixed_trans_id: + continue + self.assertNotEqual('', transaction_id, + "Transaction id should be a unique string") + self.assertTrue(transaction_id.startswith('T-')) + self.assertNotIn(transaction_id, seen_transactions) + seen_transactions.add(transaction_id) + self.assertEqual(doc_ids, just_ids) + + def getLastTransId(self, db): + """Return the transaction id for the last database update.""" + return self.db._get_transaction_log()[-1][-1] + + +class ServerStateForTests(server_state.ServerState): + """Used in the test suite, so we don't have to touch disk, etc.""" + + def __init__(self): + super(ServerStateForTests, self).__init__() + self._dbs = {} + + def open_database(self, path): + try: + return self._dbs[path] + except KeyError: + raise errors.DatabaseDoesNotExist + + def check_database(self, path): + # cares only about the possible exception + self.open_database(path) + + def ensure_database(self, path): + try: + db = self.open_database(path) + except errors.DatabaseDoesNotExist: + db = self._create_database(path) + return db, db._replica_uid + + def _copy_database(self, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES + # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST + # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS + # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND + # NINJA TO YOUR HOUSE. + new_db = copy_memory_database_for_test(None, db) + path = db._replica_uid + while path in self._dbs: + path += 'copy' + self._dbs[path] = new_db + return new_db + + def _create_database(self, path): + db = inmemory.InMemoryDatabase(path) + self._dbs[path] = db + return db + + def delete_database(self, path): + del self._dbs[path] + + +class ResponderForTests(object): + """Responder for tests.""" + _started = False + sent_response = False + status = None + + def start_response(self, status='success', **kwargs): + self._started = True + self.status = status + self.kwargs = kwargs + + def send_response(self, status='success', **kwargs): + self.start_response(status, **kwargs) + self.finish_response() + + def finish_response(self): + self.sent_response = True + + +class TestCaseWithServer(TestCase): + + @staticmethod + def server_def(): + # hook point + # should return (ServerClass, "shutdown method name", "url_scheme") + class _RequestHandler(simple_server.WSGIRequestHandler): + def log_request(*args): + pass # suppress + + def make_server(host_port, application): + assert application, "forgot to override make_app(_with_state)?" + srv = simple_server.WSGIServer(host_port, _RequestHandler) + # patch the value in if it's None + if getattr(application, 'base_url', 1) is None: + application.base_url = "http://%s:%s" % srv.server_address + srv.set_app(application) + return srv + + return make_server, "shutdown", "http" + + @staticmethod + def make_app_with_state(state): + # hook point + return None + + def make_app(self): + # potential hook point + self.request_state = ServerStateForTests() + return self.make_app_with_state(self.request_state) + + def setUp(self): + super(TestCaseWithServer, self).setUp() + self.server = self.server_thread = None + + @property + def url_scheme(self): + return self.server_def()[-1] + + def startServer(self): + server_def = self.server_def() + server_class, shutdown_meth, _ = server_def + application = self.make_app() + self.server = server_class(('127.0.0.1', 0), application) + self.server_thread = threading.Thread(target=self.server.serve_forever, + kwargs=dict(poll_interval=0.01)) + self.server_thread.start() + self.addCleanup(self.server_thread.join) + self.addCleanup(getattr(self.server, shutdown_meth)) + + def getURL(self, path=None): + host, port = self.server.server_address + if path is None: + path = '' + return '%s://%s:%s/%s' % (self.url_scheme, host, port, path) + + +def socket_pair(): + """Return a pair of TCP sockets connected to each other. + + Unlike socket.socketpair, this should work on Windows. + """ + sock_pair = getattr(socket, 'socket_pair', None) + if sock_pair: + return sock_pair(socket.AF_INET, socket.SOCK_STREAM) + listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_sock.bind(('127.0.0.1', 0)) + listen_sock.listen(1) + client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_sock.connect(listen_sock.getsockname()) + server_sock, addr = listen_sock.accept() + listen_sock.close() + return server_sock, client_sock + + +# OAuth related testing + +consumer1 = oauth.OAuthConsumer('K1', 'S1') +token1 = oauth.OAuthToken('kkkk1', 'XYZ') +consumer2 = oauth.OAuthConsumer('K2', 'S2') +token2 = oauth.OAuthToken('kkkk2', 'ZYX') +token3 = oauth.OAuthToken('kkkk3', 'ZYX') + + +class TestingOAuthDataStore(oauth.OAuthDataStore): + """In memory predefined OAuthDataStore for testing.""" + + consumers = { + consumer1.key: consumer1, + consumer2.key: consumer2, + } + + tokens = { + token1.key: token1, + token2.key: token2 + } + + def lookup_consumer(self, key): + return self.consumers.get(key) + + def lookup_token(self, token_type, token_token): + return self.tokens.get(token_token) + + def lookup_nonce(self, oauth_consumer, oauth_token, nonce): + return None + +testingOAuthStore = TestingOAuthDataStore() + +sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1() +sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT() + + +def load_with_scenarios(loader, standard_tests, pattern): + """Load the tests in a given module. + + This just applies testscenarios.generate_scenarios to all the tests that + are present. We do it at load time rather than at run time, because it + plays nicer with various tools. + """ + suite = loader.suiteClass() + suite.addTests(testscenarios.generate_scenarios(standard_tests)) + return suite diff --git a/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx b/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx new file mode 100644 index 00000000..8a4b600d --- /dev/null +++ b/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx @@ -0,0 +1,1541 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. +# +"""A Cython wrapper around the C implementation of U1DB Database backend.""" + +cdef extern from "Python.h": + object PyString_FromStringAndSize(char *s, Py_ssize_t n) + int PyString_AsStringAndSize(object o, char **buf, Py_ssize_t *length + ) except -1 + char *PyString_AsString(object) except NULL + char *PyString_AS_STRING(object) + char *strdup(char *) + void *calloc(size_t, size_t) + void free(void *) + ctypedef struct FILE: + pass + fprintf(FILE *, char *, ...) + FILE *stderr + size_t strlen(char *) + +cdef extern from "stdarg.h": + ctypedef struct va_list: + pass + void va_start(va_list, void*) + void va_start_int "va_start" (va_list, int) + void va_end(va_list) + +cdef extern from "u1db/u1db.h": + ctypedef struct u1database: + pass + ctypedef struct u1db_document: + char *doc_id + size_t doc_id_len + char *doc_rev + size_t doc_rev_len + char *json + size_t json_len + int has_conflicts + # Note: u1query is actually defined in u1db_internal.h, and in u1db.h it is + # just an opaque pointer. However, older versions of Cython don't let + # you have a forward declaration and a full declaration, so we just + # expose the whole thing here. + ctypedef struct u1query: + char *index_name + int num_fields + char **fields + cdef struct u1db_oauth_creds: + int auth_kind + char *consumer_key + char *consumer_secret + char *token_key + char *token_secret + ctypedef union u1db_creds + ctypedef u1db_creds* const_u1db_creds_ptr "const u1db_creds *" + + ctypedef char* const_char_ptr "const char*" + ctypedef int (*u1db_doc_callback)(void *context, u1db_document *doc) + ctypedef int (*u1db_key_callback)(void *context, int num_fields, + const_char_ptr *key) + ctypedef int (*u1db_doc_gen_callback)(void *context, + u1db_document *doc, int gen, const_char_ptr trans_id) + ctypedef int (*u1db_trans_info_callback)(void *context, + const_char_ptr doc_id, int gen, const_char_ptr trans_id) + + u1database * u1db_open(char *fname) + void u1db_free(u1database **) + int u1db_set_replica_uid(u1database *, char *replica_uid) + int u1db_set_document_size_limit(u1database *, int limit) + int u1db_get_replica_uid(u1database *, const_char_ptr *replica_uid) + int u1db_create_doc_from_json(u1database *db, char *json, char *doc_id, + u1db_document **doc) + int u1db_delete_doc(u1database *db, u1db_document *doc) + int u1db_get_doc(u1database *db, char *doc_id, int include_deleted, + u1db_document **doc) + int u1db_get_docs(u1database *db, int n_doc_ids, const_char_ptr *doc_ids, + int check_for_conflicts, int include_deleted, + void *context, u1db_doc_callback cb) + int u1db_get_all_docs(u1database *db, int include_deleted, int *generation, + void *context, u1db_doc_callback cb) + int u1db_put_doc(u1database *db, u1db_document *doc) + int u1db__validate_source(u1database *db, const_char_ptr replica_uid, + int replica_gen, const_char_ptr replica_trans_id) + int u1db__put_doc_if_newer(u1database *db, u1db_document *doc, + int save_conflict, char *replica_uid, + int replica_gen, char *replica_trans_id, + int *state, int *at_gen) + int u1db_resolve_doc(u1database *db, u1db_document *doc, + int n_revs, const_char_ptr *revs) + int u1db_delete_doc(u1database *db, u1db_document *doc) + int u1db_whats_changed(u1database *db, int *gen, char **trans_id, + void *context, u1db_trans_info_callback cb) + int u1db__get_transaction_log(u1database *db, void *context, + u1db_trans_info_callback cb) + int u1db_get_doc_conflicts(u1database *db, char *doc_id, void *context, + u1db_doc_callback cb) + int u1db_sync(u1database *db, const_char_ptr url, + const_u1db_creds_ptr creds, int *local_gen) nogil + int u1db_create_index_list(u1database *db, char *index_name, + int n_expressions, const_char_ptr *expressions) + int u1db_create_index(u1database *db, char *index_name, int n_expressions, + ...) + int u1db_get_from_index_list(u1database *db, u1query *query, void *context, + u1db_doc_callback cb, int n_values, + const_char_ptr *values) + int u1db_get_from_index(u1database *db, u1query *query, void *context, + u1db_doc_callback cb, int n_values, char *val0, + ...) + int u1db_get_range_from_index(u1database *db, u1query *query, + void *context, u1db_doc_callback cb, + int n_values, const_char_ptr *start_values, + const_char_ptr *end_values) + int u1db_delete_index(u1database *db, char *index_name) + int u1db_list_indexes(u1database *db, void *context, + int (*cb)(void *context, const_char_ptr index_name, + int n_expressions, const_char_ptr *expressions)) + int u1db_get_index_keys(u1database *db, char *index_name, void *context, + u1db_key_callback cb) + int u1db_simple_lookup1(u1database *db, char *index_name, char *val1, + void *context, u1db_doc_callback cb) + int u1db_query_init(u1database *db, char *index_name, u1query **query) + void u1db_free_query(u1query **query) + + int U1DB_OK + int U1DB_INVALID_PARAMETER + int U1DB_REVISION_CONFLICT + int U1DB_INVALID_DOC_ID + int U1DB_DOCUMENT_ALREADY_DELETED + int U1DB_DOCUMENT_DOES_NOT_EXIST + int U1DB_NOT_IMPLEMENTED + int U1DB_INVALID_JSON + int U1DB_DOCUMENT_TOO_BIG + int U1DB_USER_QUOTA_EXCEEDED + int U1DB_INVALID_VALUE_FOR_INDEX + int U1DB_INVALID_FIELD_SPECIFIER + int U1DB_INVALID_GLOBBING + int U1DB_BROKEN_SYNC_STREAM + int U1DB_DUPLICATE_INDEX_NAME + int U1DB_INDEX_DOES_NOT_EXIST + int U1DB_INVALID_GENERATION + int U1DB_INVALID_TRANSACTION_ID + int U1DB_INVALID_TRANSFORMATION_FUNCTION + int U1DB_UNKNOWN_OPERATION + int U1DB_INTERNAL_ERROR + int U1DB_TARGET_UNAVAILABLE + + int U1DB_INSERTED + int U1DB_SUPERSEDED + int U1DB_CONVERGED + int U1DB_CONFLICTED + + int U1DB_OAUTH_AUTH + + void u1db_free_doc(u1db_document **doc) + int u1db_doc_set_json(u1db_document *doc, char *json) + int u1db_doc_get_size(u1db_document *doc) + + +cdef extern from "u1db/u1db_internal.h": + ctypedef struct u1db_row: + u1db_row *next + int num_columns + int *column_sizes + unsigned char **columns + + ctypedef struct u1db_table: + int status + u1db_row *first_row + + ctypedef struct u1db_record: + u1db_record *next + char *doc_id + char *doc_rev + char *doc + + ctypedef struct u1db_sync_exchange: + int target_gen + int num_doc_ids + char **doc_ids_to_return + int *gen_for_doc_ids + const_char_ptr *trans_ids_for_doc_ids + + ctypedef int (*u1db__trace_callback)(void *context, const_char_ptr state) + ctypedef struct u1db_sync_target: + int (*get_sync_info)(u1db_sync_target *st, char *source_replica_uid, + const_char_ptr *st_replica_uid, int *st_gen, + char **st_trans_id, int *source_gen, + char **source_trans_id) nogil + int (*record_sync_info)(u1db_sync_target *st, + char *source_replica_uid, int source_gen, char *trans_id) nogil + int (*sync_exchange)(u1db_sync_target *st, + char *source_replica_uid, int n_docs, + u1db_document **docs, int *generations, + const_char_ptr *trans_ids, + int *target_gen, char **target_trans_id, + void *context, u1db_doc_gen_callback cb, + void *ensure_callback) nogil + int (*sync_exchange_doc_ids)(u1db_sync_target *st, + u1database *source_db, int n_doc_ids, + const_char_ptr *doc_ids, int *generations, + const_char_ptr *trans_ids, + int *target_gen, char **target_trans_id, + void *context, + u1db_doc_gen_callback cb, + void *ensure_callback) nogil + int (*get_sync_exchange)(u1db_sync_target *st, + char *source_replica_uid, + int last_known_source_gen, + u1db_sync_exchange **exchange) nogil + void (*finalize_sync_exchange)(u1db_sync_target *st, + u1db_sync_exchange **exchange) nogil + int (*_set_trace_hook)(u1db_sync_target *st, + void *context, u1db__trace_callback cb) nogil + + + void u1db__set_zero_delays() + int u1db__get_generation(u1database *, int *db_rev) + int u1db__get_document_size_limit(u1database *, int *limit) + int u1db__get_generation_info(u1database *, int *db_rev, char **trans_id) + int u1db__get_trans_id_for_gen(u1database *, int db_rev, char **trans_id) + int u1db_validate_gen_and_trans_id(u1database *, int db_rev, + const_char_ptr trans_id) + char *u1db__allocate_doc_id(u1database *) + int u1db__sql_close(u1database *) + u1database *u1db__copy(u1database *) + int u1db__sql_is_open(u1database *) + u1db_table *u1db__sql_run(u1database *, char *sql, size_t n) + void u1db__free_table(u1db_table **table) + u1db_record *u1db__create_record(char *doc_id, char *doc_rev, char *doc) + void u1db__free_records(u1db_record **) + + int u1db__allocate_document(char *doc_id, char *revision, char *content, + int has_conflicts, u1db_document **result) + int u1db__generate_hex_uuid(char *) + + int u1db__get_replica_gen_and_trans_id(u1database *db, char *replica_uid, + int *generation, char **trans_id) + int u1db__set_replica_gen_and_trans_id(u1database *db, char *replica_uid, + int generation, char *trans_id) + int u1db__sync_get_machine_info(u1database *db, char *other_replica_uid, + int *other_db_rev, char **my_replica_uid, + int *my_db_rev) + int u1db__sync_record_machine_info(u1database *db, char *replica_uid, + int db_rev) + int u1db__sync_exchange_seen_ids(u1db_sync_exchange *se, int *n_ids, + const_char_ptr **doc_ids) + int u1db__format_query(int n_fields, const_char_ptr *values, char **buf, + int *wildcard) + int u1db__get_sync_target(u1database *db, u1db_sync_target **sync_target) + int u1db__free_sync_target(u1db_sync_target **sync_target) + int u1db__sync_db_to_target(u1database *db, u1db_sync_target *target, + int *local_gen_before_sync) nogil + + int u1db__sync_exchange_insert_doc_from_source(u1db_sync_exchange *se, + u1db_document *doc, int source_gen, const_char_ptr trans_id) + int u1db__sync_exchange_find_doc_ids_to_return(u1db_sync_exchange *se) + int u1db__sync_exchange_return_docs(u1db_sync_exchange *se, void *context, + int (*cb)(void *context, + u1db_document *doc, int gen, + const_char_ptr trans_id)) + int u1db__create_http_sync_target(char *url, u1db_sync_target **target) + int u1db__create_oauth_http_sync_target(char *url, + char *consumer_key, char *consumer_secret, + char *token_key, char *token_secret, + u1db_sync_target **target) + +cdef extern from "u1db/u1db_http_internal.h": + int u1db__format_sync_url(u1db_sync_target *st, + const_char_ptr source_replica_uid, char **sync_url) + int u1db__get_oauth_authorization(u1db_sync_target *st, + char *http_method, char *url, + char **oauth_authorization) + + +cdef extern from "u1db/u1db_vectorclock.h": + ctypedef struct u1db_vectorclock_item: + char *replica_uid + int generation + + ctypedef struct u1db_vectorclock: + int num_items + u1db_vectorclock_item *items + + u1db_vectorclock *u1db__vectorclock_from_str(char *s) + void u1db__free_vectorclock(u1db_vectorclock **clock) + int u1db__vectorclock_increment(u1db_vectorclock *clock, char *replica_uid) + int u1db__vectorclock_maximize(u1db_vectorclock *clock, + u1db_vectorclock *other) + int u1db__vectorclock_as_str(u1db_vectorclock *clock, char **result) + int u1db__vectorclock_is_newer(u1db_vectorclock *maybe_newer, + u1db_vectorclock *older) + +from u1db import errors +from sqlite3 import dbapi2 + + +cdef int _append_trans_info_to_list(void *context, const_char_ptr doc_id, + int generation, + const_char_ptr trans_id) with gil: + a_list = <object>(context) + doc = doc_id + a_list.append((doc, generation, trans_id)) + return 0 + + +cdef int _append_doc_to_list(void *context, u1db_document *doc) with gil: + a_list = <object>context + pydoc = CDocument() + pydoc._doc = doc + a_list.append(pydoc) + return 0 + +cdef int _append_key_to_list(void *context, int num_fields, + const_char_ptr *key) with gil: + a_list = <object>(context) + field_list = [] + for i from 0 <= i < num_fields: + field = key[i] + field_list.append(field.decode('utf-8')) + a_list.append(tuple(field_list)) + return 0 + +cdef _list_to_array(lst, const_char_ptr **res, int *count): + cdef const_char_ptr *tmp + count[0] = len(lst) + tmp = <const_char_ptr*>calloc(sizeof(char*), count[0]) + for idx, x in enumerate(lst): + tmp[idx] = x + res[0] = tmp + +cdef _list_to_str_array(lst, const_char_ptr **res, int *count): + cdef const_char_ptr *tmp + count[0] = len(lst) + tmp = <const_char_ptr*>calloc(sizeof(char*), count[0]) + new_objs = [] + for idx, x in enumerate(lst): + if isinstance(x, unicode): + x = x.encode('utf-8') + new_objs.append(x) + tmp[idx] = x + res[0] = tmp + return new_objs + + +cdef int _append_index_definition_to_list(void *context, + const_char_ptr index_name, int n_expressions, + const_char_ptr *expressions) with gil: + cdef int i + + a_list = <object>(context) + exp_list = [] + for i from 0 <= i < n_expressions: + s = expressions[i] + exp_list.append(s.decode('utf-8')) + a_list.append((index_name, exp_list)) + return 0 + + +cdef int return_doc_cb_wrapper(void *context, u1db_document *doc, + int gen, const_char_ptr trans_id) with gil: + cdef CDocument pydoc + user_cb = <object>context + pydoc = CDocument() + pydoc._doc = doc + try: + user_cb(pydoc, gen, trans_id) + except Exception, e: + # We suppress the exception here, because intermediating through the C + # layer gets a bit crazy + return U1DB_INVALID_PARAMETER + return U1DB_OK + + +cdef int _trace_hook(void *context, const_char_ptr state) with gil: + if context == NULL: + return U1DB_INVALID_PARAMETER + ctx = <object>context + try: + ctx(state) + except: + # Note: It would be nice if we could map the Python exception into + # something in C + return U1DB_INTERNAL_ERROR + return U1DB_OK + + +cdef char *_ensure_str(object obj, object extra_objs) except NULL: + """Ensure that we have the UTF-8 representation of a parameter. + + :param obj: A Unicode or String object. + :param extra_objs: This should be a Python list. If we have to convert obj + from being a Unicode object, this will hold the PyString object so that + we know the char* lifetime will be correct. + :return: A C pointer to the UTF-8 representation. + """ + if isinstance(obj, unicode): + obj = obj.encode('utf-8') + extra_objs.append(obj) + return PyString_AsString(obj) + + +def _format_query(fields): + """Wrapper around u1db__format_query for testing.""" + cdef int status + cdef char *buf + cdef int wildcard[10] + cdef const_char_ptr *values + cdef int n_values + + # keep a reference to new_objs so that the pointers in expressions + # remain valid. + new_objs = _list_to_str_array(fields, &values, &n_values) + try: + status = u1db__format_query(n_values, values, &buf, wildcard) + finally: + free(<void*>values) + handle_status("format_query", status) + if buf == NULL: + res = None + else: + res = buf + free(buf) + w = [] + for i in range(len(fields)): + w.append(wildcard[i]) + return res, w + + +def make_document(doc_id, rev, content, has_conflicts=False): + cdef u1db_document *doc + cdef char *c_content = NULL, *c_rev = NULL, *c_doc_id = NULL + cdef int conflict + + if has_conflicts: + conflict = 1 + else: + conflict = 0 + if doc_id is None: + c_doc_id = NULL + else: + c_doc_id = doc_id + if content is None: + c_content = NULL + else: + c_content = content + if rev is None: + c_rev = NULL + else: + c_rev = rev + handle_status( + "make_document", + u1db__allocate_document(c_doc_id, c_rev, c_content, conflict, &doc)) + pydoc = CDocument() + pydoc._doc = doc + return pydoc + + +def generate_hex_uuid(): + uuid = PyString_FromStringAndSize(NULL, 32) + handle_status( + "Failed to generate uuid", + u1db__generate_hex_uuid(PyString_AS_STRING(uuid))) + return uuid + + +cdef class CDocument(object): + """A thin wrapper around the C Document struct.""" + + cdef u1db_document *_doc + + def __init__(self): + self._doc = NULL + + def __dealloc__(self): + u1db_free_doc(&self._doc) + + property doc_id: + def __get__(self): + if self._doc.doc_id == NULL: + return None + return PyString_FromStringAndSize( + self._doc.doc_id, self._doc.doc_id_len) + + property rev: + def __get__(self): + if self._doc.doc_rev == NULL: + return None + return PyString_FromStringAndSize( + self._doc.doc_rev, self._doc.doc_rev_len) + + def get_json(self): + if self._doc.json == NULL: + return None + return PyString_FromStringAndSize( + self._doc.json, self._doc.json_len) + + def set_json(self, val): + u1db_doc_set_json(self._doc, val) + + def get_size(self): + return u1db_doc_get_size(self._doc) + + property has_conflicts: + def __get__(self): + if self._doc.has_conflicts: + return True + return False + + def __repr__(self): + if self._doc.has_conflicts: + extra = ', conflicted' + else: + extra = '' + return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id, + self.rev, extra, self.get_json()) + + def __hash__(self): + raise NotImplementedError(self.__hash__) + + def __richcmp__(self, other, int t): + try: + if t == 0: # Py_LT < + return ((self.doc_id, self.rev, self.get_json()) + < (other.doc_id, other.rev, other.get_json())) + elif t == 2: # Py_EQ == + return (self.doc_id == other.doc_id + and self.rev == other.rev + and self.get_json() == other.get_json() + and self.has_conflicts == other.has_conflicts) + except AttributeError: + # Fall through to NotImplemented + pass + + return NotImplemented + + +cdef object safe_str(const_char_ptr s): + if s == NULL: + return None + return s + + +cdef class CQuery: + + cdef u1query *_query + + def __init__(self): + self._query = NULL + + def __dealloc__(self): + u1db_free_query(&self._query) + + def _check(self): + if self._query == NULL: + raise RuntimeError("No valid _query.") + + property index_name: + def __get__(self): + self._check() + return safe_str(self._query.index_name) + + property num_fields: + def __get__(self): + self._check() + return self._query.num_fields + + property fields: + def __get__(self): + cdef int i + self._check() + fields = [] + for i from 0 <= i < self._query.num_fields: + fields.append(safe_str(self._query.fields[i])) + return fields + + +cdef handle_status(context, int status): + if status == U1DB_OK: + return + if status == U1DB_REVISION_CONFLICT: + raise errors.RevisionConflict() + if status == U1DB_INVALID_DOC_ID: + raise errors.InvalidDocId() + if status == U1DB_DOCUMENT_ALREADY_DELETED: + raise errors.DocumentAlreadyDeleted() + if status == U1DB_DOCUMENT_DOES_NOT_EXIST: + raise errors.DocumentDoesNotExist() + if status == U1DB_INVALID_PARAMETER: + raise RuntimeError('Bad parameters supplied') + if status == U1DB_NOT_IMPLEMENTED: + raise NotImplementedError("Functionality not implemented yet: %s" + % (context,)) + if status == U1DB_INVALID_VALUE_FOR_INDEX: + raise errors.InvalidValueForIndex() + if status == U1DB_INVALID_GLOBBING: + raise errors.InvalidGlobbing() + if status == U1DB_INTERNAL_ERROR: + raise errors.U1DBError("internal error") + if status == U1DB_BROKEN_SYNC_STREAM: + raise errors.BrokenSyncStream() + if status == U1DB_CONFLICTED: + raise errors.ConflictedDoc() + if status == U1DB_DUPLICATE_INDEX_NAME: + raise errors.IndexNameTakenError() + if status == U1DB_INDEX_DOES_NOT_EXIST: + raise errors.IndexDoesNotExist + if status == U1DB_INVALID_GENERATION: + raise errors.InvalidGeneration + if status == U1DB_INVALID_TRANSACTION_ID: + raise errors.InvalidTransactionId + if status == U1DB_TARGET_UNAVAILABLE: + raise errors.Unavailable + if status == U1DB_INVALID_JSON: + raise errors.InvalidJSON + if status == U1DB_DOCUMENT_TOO_BIG: + raise errors.DocumentTooBig + if status == U1DB_USER_QUOTA_EXCEEDED: + raise errors.UserQuotaExceeded + if status == U1DB_INVALID_TRANSFORMATION_FUNCTION: + raise errors.IndexDefinitionParseError + if status == U1DB_UNKNOWN_OPERATION: + raise errors.IndexDefinitionParseError + if status == U1DB_INVALID_FIELD_SPECIFIER: + raise errors.IndexDefinitionParseError() + raise RuntimeError('%s (status: %s)' % (context, status)) + + +cdef class CDatabase +cdef class CSyncTarget + +cdef class CSyncExchange(object): + + cdef u1db_sync_exchange *_exchange + cdef CSyncTarget _target + + def __init__(self, CSyncTarget target, source_replica_uid, source_gen): + self._target = target + assert self._target._st.get_sync_exchange != NULL, \ + "get_sync_exchange is NULL?" + handle_status("get_sync_exchange", + self._target._st.get_sync_exchange(self._target._st, + source_replica_uid, source_gen, &self._exchange)) + + def __dealloc__(self): + if self._target is not None and self._target._st != NULL: + self._target._st.finalize_sync_exchange(self._target._st, + &self._exchange) + + def _check(self): + if self._exchange == NULL: + raise RuntimeError("self._exchange is NULL") + + property target_gen: + def __get__(self): + self._check() + return self._exchange.target_gen + + def insert_doc_from_source(self, CDocument doc, source_gen, + source_trans_id): + self._check() + handle_status("insert_doc_from_source", + u1db__sync_exchange_insert_doc_from_source(self._exchange, + doc._doc, source_gen, source_trans_id)) + + def find_doc_ids_to_return(self): + self._check() + handle_status("find_doc_ids_to_return", + u1db__sync_exchange_find_doc_ids_to_return(self._exchange)) + + def return_docs(self, return_doc_cb): + self._check() + handle_status("return_docs", + u1db__sync_exchange_return_docs(self._exchange, + <void *>return_doc_cb, &return_doc_cb_wrapper)) + + def get_seen_ids(self): + cdef const_char_ptr *seen_ids + cdef int i, n_ids + self._check() + handle_status("sync_exchange_seen_ids", + u1db__sync_exchange_seen_ids(self._exchange, &n_ids, &seen_ids)) + res = [] + for i from 0 <= i < n_ids: + res.append(seen_ids[i]) + if (seen_ids != NULL): + free(<void*>seen_ids) + return res + + def get_doc_ids_to_return(self): + self._check() + res = [] + if (self._exchange.num_doc_ids > 0 + and self._exchange.doc_ids_to_return != NULL): + for i from 0 <= i < self._exchange.num_doc_ids: + res.append( + (self._exchange.doc_ids_to_return[i], + self._exchange.gen_for_doc_ids[i], + self._exchange.trans_ids_for_doc_ids[i])) + return res + + +cdef class CSyncTarget(object): + + cdef u1db_sync_target *_st + cdef CDatabase _db + + def __init__(self): + self._db = None + self._st = NULL + u1db__set_zero_delays() + + def __dealloc__(self): + u1db__free_sync_target(&self._st) + + def _check(self): + if self._st == NULL: + raise RuntimeError("self._st is NULL") + + def get_sync_info(self, source_replica_uid): + cdef const_char_ptr st_replica_uid = NULL + cdef int st_gen = 0, source_gen = 0, status + cdef char *trans_id = NULL + cdef char *st_trans_id = NULL + cdef char *c_source_replica_uid = NULL + + self._check() + assert self._st.get_sync_info != NULL, "get_sync_info is NULL?" + c_source_replica_uid = source_replica_uid + with nogil: + status = self._st.get_sync_info(self._st, c_source_replica_uid, + &st_replica_uid, &st_gen, &st_trans_id, &source_gen, &trans_id) + handle_status("get_sync_info", status) + res_trans_id = None + res_st_trans_id = None + if trans_id != NULL: + res_trans_id = trans_id + free(trans_id) + if st_trans_id != NULL: + res_st_trans_id = st_trans_id + free(st_trans_id) + return ( + safe_str(st_replica_uid), st_gen, res_st_trans_id, source_gen, + res_trans_id) + + def record_sync_info(self, source_replica_uid, source_gen, source_trans_id): + cdef int status + cdef int c_source_gen + cdef char *c_source_replica_uid = NULL + cdef char *c_source_trans_id = NULL + + self._check() + assert self._st.record_sync_info != NULL, "record_sync_info is NULL?" + c_source_replica_uid = source_replica_uid + c_source_gen = source_gen + c_source_trans_id = source_trans_id + with nogil: + status = self._st.record_sync_info( + self._st, c_source_replica_uid, c_source_gen, + c_source_trans_id) + handle_status("record_sync_info", status) + + def _get_sync_exchange(self, source_replica_uid, source_gen): + self._check() + return CSyncExchange(self, source_replica_uid, source_gen) + + def sync_exchange_doc_ids(self, source_db, doc_id_generations, + last_known_generation, last_known_trans_id, + return_doc_cb): + cdef const_char_ptr *doc_ids + cdef int *generations + cdef int num_doc_ids + cdef int target_gen + cdef char *target_trans_id = NULL + cdef int status + cdef CDatabase sdb + + self._check() + assert self._st.sync_exchange_doc_ids != NULL, "sync_exchange_doc_ids is NULL?" + sdb = source_db + num_doc_ids = len(doc_id_generations) + doc_ids = <const_char_ptr *>calloc(num_doc_ids, sizeof(char *)) + if doc_ids == NULL: + raise MemoryError + generations = <int *>calloc(num_doc_ids, sizeof(int)) + if generations == NULL: + free(<void *>doc_ids) + raise MemoryError + trans_ids = <const_char_ptr*>calloc(num_doc_ids, sizeof(char *)) + if trans_ids == NULL: + raise MemoryError + res_trans_id = '' + try: + for i, (doc_id, gen, trans_id) in enumerate(doc_id_generations): + doc_ids[i] = PyString_AsString(doc_id) + generations[i] = gen + trans_ids[i] = trans_id + target_gen = last_known_generation + if last_known_trans_id is not None: + target_trans_id = last_known_trans_id + with nogil: + status = self._st.sync_exchange_doc_ids(self._st, sdb._db, + num_doc_ids, doc_ids, generations, trans_ids, + &target_gen, &target_trans_id, + <void*>return_doc_cb, return_doc_cb_wrapper, NULL) + handle_status("sync_exchange_doc_ids", status) + if target_trans_id != NULL: + res_trans_id = target_trans_id + finally: + if target_trans_id != NULL: + free(target_trans_id) + if doc_ids != NULL: + free(<void *>doc_ids) + if generations != NULL: + free(generations) + if trans_ids != NULL: + free(trans_ids) + return target_gen, res_trans_id + + def sync_exchange(self, docs_by_generations, source_replica_uid, + last_known_generation, last_known_trans_id, + return_doc_cb, ensure_callback=None): + cdef CDocument cur_doc + cdef u1db_document **docs = NULL + cdef int *generations = NULL + cdef const_char_ptr *trans_ids = NULL + cdef char *target_trans_id = NULL + cdef char *c_source_replica_uid = NULL + cdef int i, count, status, target_gen + assert ensure_callback is None # interface difference + + self._check() + assert self._st.sync_exchange != NULL, "sync_exchange is NULL?" + count = len(docs_by_generations) + res_trans_id = '' + try: + docs = <u1db_document **>calloc(count, sizeof(u1db_document*)) + if docs == NULL: + raise MemoryError + generations = <int*>calloc(count, sizeof(int)) + if generations == NULL: + raise MemoryError + trans_ids = <const_char_ptr*>calloc(count, sizeof(char*)) + if trans_ids == NULL: + raise MemoryError + for i from 0 <= i < count: + cur_doc = docs_by_generations[i][0] + generations[i] = docs_by_generations[i][1] + trans_ids[i] = docs_by_generations[i][2] + docs[i] = cur_doc._doc + target_gen = last_known_generation + if last_known_trans_id is not None: + target_trans_id = last_known_trans_id + c_source_replica_uid = source_replica_uid + with nogil: + status = self._st.sync_exchange( + self._st, c_source_replica_uid, count, docs, generations, + trans_ids, &target_gen, &target_trans_id, + <void *>return_doc_cb, return_doc_cb_wrapper, NULL) + handle_status("sync_exchange", status) + finally: + if docs != NULL: + free(docs) + if generations != NULL: + free(generations) + if trans_ids != NULL: + free(trans_ids) + if target_trans_id != NULL: + res_trans_id = target_trans_id + free(target_trans_id) + return target_gen, res_trans_id + + def _set_trace_hook(self, cb): + self._check() + assert self._st._set_trace_hook != NULL, "_set_trace_hook is NULL?" + handle_status("_set_trace_hook", + self._st._set_trace_hook(self._st, <void*>cb, _trace_hook)) + + _set_trace_hook_shallow = _set_trace_hook + + +cdef class CDatabase(object): + """A thin wrapper/shim to interact with the C implementation. + + Functionality should not be written here. It is only provided as a way to + expose the C API to the python test suite. + """ + + cdef public object _filename + cdef u1database *_db + cdef public object _supports_indexes + + def __init__(self, filename): + self._supports_indexes = False + self._filename = filename + self._db = u1db_open(self._filename) + + def __dealloc__(self): + u1db_free(&self._db) + + def close(self): + return u1db__sql_close(self._db) + + def _copy(self, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. + new_db = CDatabase(':memory:') + u1db_free(&new_db._db) + new_db._db = u1db__copy(self._db) + return new_db + + def _sql_is_open(self): + if self._db == NULL: + return True + return u1db__sql_is_open(self._db) + + property _replica_uid: + def __get__(self): + cdef const_char_ptr val + cdef int status + status = u1db_get_replica_uid(self._db, &val) + if status != 0: + if val != NULL: + err = str(val) + else: + err = "<unknown>" + raise RuntimeError("Failed to get_replica_uid: %d %s" + % (status, err)) + if val == NULL: + return None + return str(val) + + def _set_replica_uid(self, replica_uid): + cdef int status + status = u1db_set_replica_uid(self._db, replica_uid) + if status != 0: + raise RuntimeError('replica_uid could not be set to %s, error: %d' + % (replica_uid, status)) + + property document_size_limit: + def __get__(self): + cdef int limit + handle_status("document_size_limit", + u1db__get_document_size_limit(self._db, &limit)) + return limit + + def set_document_size_limit(self, limit): + cdef int status + status = u1db_set_document_size_limit(self._db, limit) + if status != 0: + raise RuntimeError( + "document_size_limit could not be set to %d, error: %d", + (limit, status)) + + def _allocate_doc_id(self): + cdef char *val + val = u1db__allocate_doc_id(self._db) + if val == NULL: + raise RuntimeError("Failed to allocate document id") + s = str(val) + free(val) + return s + + def _run_sql(self, sql): + cdef u1db_table *tbl + cdef u1db_row *cur_row + cdef size_t n + cdef int i + + if self._db == NULL: + raise RuntimeError("called _run_sql with a NULL pointer.") + tbl = u1db__sql_run(self._db, sql, len(sql)) + if tbl == NULL: + raise MemoryError("Failed to allocate table memory.") + try: + if tbl.status != 0: + raise RuntimeError("Status was not 0: %d" % (tbl.status,)) + # Now convert the table into python + res = [] + cur_row = tbl.first_row + while cur_row != NULL: + row = [] + for i from 0 <= i < cur_row.num_columns: + row.append(PyString_FromStringAndSize( + <char*>(cur_row.columns[i]), cur_row.column_sizes[i])) + res.append(tuple(row)) + cur_row = cur_row.next + return res + finally: + u1db__free_table(&tbl) + + def create_doc_from_json(self, json, doc_id=None): + cdef u1db_document *doc = NULL + cdef char *c_doc_id + + if doc_id is None: + c_doc_id = NULL + else: + c_doc_id = doc_id + handle_status('Failed to create_doc', + u1db_create_doc_from_json(self._db, json, c_doc_id, &doc)) + pydoc = CDocument() + pydoc._doc = doc + return pydoc + + def put_doc(self, CDocument doc): + handle_status("Failed to put_doc", + u1db_put_doc(self._db, doc._doc)) + return doc.rev + + def _validate_source(self, replica_uid, replica_gen, replica_trans_id): + cdef const_char_ptr c_uid, c_trans_id + cdef int c_gen = 0 + + c_uid = replica_uid + c_trans_id = replica_trans_id + c_gen = replica_gen + handle_status( + "invalid generation or transaction id", + u1db__validate_source(self._db, c_uid, c_gen, c_trans_id)) + + def _put_doc_if_newer(self, CDocument doc, save_conflict, replica_uid=None, + replica_gen=None, replica_trans_id=None): + cdef char *c_uid, *c_trans_id + cdef int gen, state = 0, at_gen = -1 + + if replica_uid is None: + c_uid = NULL + else: + c_uid = replica_uid + if replica_trans_id is None: + c_trans_id = NULL + else: + c_trans_id = replica_trans_id + if replica_gen is None: + gen = 0 + else: + gen = replica_gen + handle_status("Failed to _put_doc_if_newer", + u1db__put_doc_if_newer(self._db, doc._doc, save_conflict, + c_uid, gen, c_trans_id, &state, &at_gen)) + if state == U1DB_INSERTED: + return 'inserted', at_gen + elif state == U1DB_SUPERSEDED: + return 'superseded', at_gen + elif state == U1DB_CONVERGED: + return 'converged', at_gen + elif state == U1DB_CONFLICTED: + return 'conflicted', at_gen + else: + raise RuntimeError("Unknown _put_doc_if_newer state: %d" % (state,)) + + def get_doc(self, doc_id, include_deleted=False): + cdef u1db_document *doc = NULL + deleted = 1 if include_deleted else 0 + handle_status("get_doc failed", + u1db_get_doc(self._db, doc_id, deleted, &doc)) + if doc == NULL: + return None + pydoc = CDocument() + pydoc._doc = doc + return pydoc + + def get_docs(self, doc_ids, check_for_conflicts=True, + include_deleted=False): + cdef int n_doc_ids, conflicts + cdef const_char_ptr *c_doc_ids + + _list_to_array(doc_ids, &c_doc_ids, &n_doc_ids) + deleted = 1 if include_deleted else 0 + conflicts = 1 if check_for_conflicts else 0 + a_list = [] + handle_status("get_docs", + u1db_get_docs(self._db, n_doc_ids, c_doc_ids, + conflicts, deleted, <void*>a_list, _append_doc_to_list)) + free(<void*>c_doc_ids) + return a_list + + def get_all_docs(self, include_deleted=False): + cdef int c_generation + + a_list = [] + deleted = 1 if include_deleted else 0 + generation = 0 + c_generation = generation + handle_status( + "get_all_docs", u1db_get_all_docs( + self._db, deleted, &c_generation, <void*>a_list, + _append_doc_to_list)) + return (c_generation, a_list) + + def resolve_doc(self, CDocument doc, conflicted_doc_revs): + cdef const_char_ptr *revs + cdef int n_revs + + _list_to_array(conflicted_doc_revs, &revs, &n_revs) + handle_status("resolve_doc", + u1db_resolve_doc(self._db, doc._doc, n_revs, revs)) + free(<void*>revs) + + def get_doc_conflicts(self, doc_id): + conflict_docs = [] + handle_status("get_doc_conflicts", + u1db_get_doc_conflicts(self._db, doc_id, <void*>conflict_docs, + _append_doc_to_list)) + return conflict_docs + + def delete_doc(self, CDocument doc): + handle_status( + "Failed to delete %s" % (doc,), + u1db_delete_doc(self._db, doc._doc)) + + def whats_changed(self, generation=0): + cdef int c_generation + cdef int status + cdef char *trans_id = NULL + + a_list = [] + c_generation = generation + res_trans_id = '' + status = u1db_whats_changed(self._db, &c_generation, &trans_id, + <void*>a_list, _append_trans_info_to_list) + try: + handle_status("whats_changed", status) + finally: + if trans_id != NULL: + res_trans_id = trans_id + free(trans_id) + return c_generation, res_trans_id, a_list + + def _get_transaction_log(self): + a_list = [] + handle_status("_get_transaction_log", + u1db__get_transaction_log(self._db, <void*>a_list, + _append_trans_info_to_list)) + return [(doc_id, trans_id) for doc_id, gen, trans_id in a_list] + + def _get_generation(self): + cdef int generation + handle_status("get_generation", + u1db__get_generation(self._db, &generation)) + return generation + + def _get_generation_info(self): + cdef int generation + cdef char *trans_id + handle_status("get_generation_info", + u1db__get_generation_info(self._db, &generation, &trans_id)) + raw_trans_id = None + if trans_id != NULL: + raw_trans_id = trans_id + free(trans_id) + return generation, raw_trans_id + + def validate_gen_and_trans_id(self, generation, trans_id): + handle_status( + "validate_gen_and_trans_id", + u1db_validate_gen_and_trans_id(self._db, generation, trans_id)) + + def _get_trans_id_for_gen(self, generation): + cdef char *trans_id = NULL + + handle_status( + "_get_trans_id_for_gen", + u1db__get_trans_id_for_gen(self._db, generation, &trans_id)) + raw_trans_id = None + if trans_id != NULL: + raw_trans_id = trans_id + free(trans_id) + return raw_trans_id + + def _get_replica_gen_and_trans_id(self, replica_uid): + cdef int generation, status + cdef char *trans_id = NULL + + status = u1db__get_replica_gen_and_trans_id( + self._db, replica_uid, &generation, &trans_id) + handle_status("_get_replica_gen_and_trans_id", status) + raw_trans_id = None + if trans_id != NULL: + raw_trans_id = trans_id + free(trans_id) + return generation, raw_trans_id + + def _set_replica_gen_and_trans_id(self, replica_uid, generation, trans_id): + handle_status("_set_replica_gen_and_trans_id", + u1db__set_replica_gen_and_trans_id( + self._db, replica_uid, generation, trans_id)) + + def create_index_list(self, index_name, index_expressions): + cdef const_char_ptr *expressions + cdef int n_expressions + + # keep a reference to new_objs so that the pointers in expressions + # remain valid. + new_objs = _list_to_str_array( + index_expressions, &expressions, &n_expressions) + try: + status = u1db_create_index_list( + self._db, index_name, n_expressions, expressions) + finally: + free(<void*>expressions) + handle_status("create_index", status) + + def create_index(self, index_name, *index_expressions): + extra = [] + if len(index_expressions) == 0: + status = u1db_create_index(self._db, index_name, 0, NULL) + elif len(index_expressions) == 1: + status = u1db_create_index( + self._db, index_name, 1, + _ensure_str(index_expressions[0], extra)) + elif len(index_expressions) == 2: + status = u1db_create_index( + self._db, index_name, 2, + _ensure_str(index_expressions[0], extra), + _ensure_str(index_expressions[1], extra)) + elif len(index_expressions) == 3: + status = u1db_create_index( + self._db, index_name, 3, + _ensure_str(index_expressions[0], extra), + _ensure_str(index_expressions[1], extra), + _ensure_str(index_expressions[2], extra)) + elif len(index_expressions) == 4: + status = u1db_create_index( + self._db, index_name, 4, + _ensure_str(index_expressions[0], extra), + _ensure_str(index_expressions[1], extra), + _ensure_str(index_expressions[2], extra), + _ensure_str(index_expressions[3], extra)) + else: + status = U1DB_NOT_IMPLEMENTED + handle_status("create_index", status) + + def sync(self, url, creds=None): + cdef const_char_ptr c_url + cdef int local_gen = 0 + cdef u1db_oauth_creds _oauth_creds + cdef u1db_creds *_creds = NULL + c_url = url + if creds is not None: + _oauth_creds.auth_kind = U1DB_OAUTH_AUTH + _oauth_creds.consumer_key = creds['oauth']['consumer_key'] + _oauth_creds.consumer_secret = creds['oauth']['consumer_secret'] + _oauth_creds.token_key = creds['oauth']['token_key'] + _oauth_creds.token_secret = creds['oauth']['token_secret'] + _creds = <u1db_creds *>&_oauth_creds + with nogil: + status = u1db_sync(self._db, c_url, _creds, &local_gen) + handle_status("sync", status) + return local_gen + + def list_indexes(self): + a_list = [] + handle_status("list_indexes", + u1db_list_indexes(self._db, <void *>a_list, + _append_index_definition_to_list)) + return a_list + + def delete_index(self, index_name): + handle_status("delete_index", + u1db_delete_index(self._db, index_name)) + + def get_from_index_list(self, index_name, key_values): + cdef const_char_ptr *values + cdef int n_values + cdef CQuery query + + query = self._query_init(index_name) + res = [] + # keep a reference to new_objs so that the pointers in expressions + # remain valid. + new_objs = _list_to_str_array(key_values, &values, &n_values) + try: + handle_status( + "get_from_index", u1db_get_from_index_list( + self._db, query._query, <void*>res, _append_doc_to_list, + n_values, values)) + finally: + free(<void*>values) + return res + + def get_from_index(self, index_name, *key_values): + cdef CQuery query + cdef int status + + extra = [] + query = self._query_init(index_name) + res = [] + status = U1DB_OK + if len(key_values) == 0: + status = u1db_get_from_index(self._db, query._query, + <void*>res, _append_doc_to_list, 0, NULL) + elif len(key_values) == 1: + status = u1db_get_from_index(self._db, query._query, + <void*>res, _append_doc_to_list, 1, + _ensure_str(key_values[0], extra)) + elif len(key_values) == 2: + status = u1db_get_from_index(self._db, query._query, + <void*>res, _append_doc_to_list, 2, + _ensure_str(key_values[0], extra), + _ensure_str(key_values[1], extra)) + elif len(key_values) == 3: + status = u1db_get_from_index(self._db, query._query, + <void*>res, _append_doc_to_list, 3, + _ensure_str(key_values[0], extra), + _ensure_str(key_values[1], extra), + _ensure_str(key_values[2], extra)) + elif len(key_values) == 4: + status = u1db_get_from_index(self._db, query._query, + <void*>res, _append_doc_to_list, 4, + _ensure_str(key_values[0], extra), + _ensure_str(key_values[1], extra), + _ensure_str(key_values[2], extra), + _ensure_str(key_values[3], extra)) + else: + status = U1DB_NOT_IMPLEMENTED + handle_status("get_from_index", status) + return res + + def get_range_from_index(self, index_name, start_value=None, + end_value=None): + cdef CQuery query + cdef const_char_ptr *start_values + cdef int n_values + cdef const_char_ptr *end_values + + if start_value is not None: + if isinstance(start_value, basestring): + start_value = (start_value,) + new_objs_1 = _list_to_str_array( + start_value, &start_values, &n_values) + else: + n_values = 0 + start_values = NULL + if end_value is not None: + if isinstance(end_value, basestring): + end_value = (end_value,) + new_objs_2 = _list_to_str_array( + end_value, &end_values, &n_values) + else: + end_values = NULL + query = self._query_init(index_name) + res = [] + try: + handle_status("get_range_from_index", + u1db_get_range_from_index( + self._db, query._query, <void*>res, _append_doc_to_list, + n_values, start_values, end_values)) + finally: + if start_values != NULL: + free(<void*>start_values) + if end_values != NULL: + free(<void*>end_values) + return res + + def get_index_keys(self, index_name): + cdef int status + keys = [] + status = U1DB_OK + status = u1db_get_index_keys( + self._db, index_name, <void*>keys, _append_key_to_list) + handle_status("get_index_keys", status) + return keys + + def _query_init(self, index_name): + cdef CQuery query + query = CQuery() + handle_status("query_init", + u1db_query_init(self._db, index_name, &query._query)) + return query + + def get_sync_target(self): + cdef CSyncTarget target + target = CSyncTarget() + target._db = self + handle_status("get_sync_target", + u1db__get_sync_target(target._db._db, &target._st)) + return target + + +cdef class VectorClockRev: + + cdef u1db_vectorclock *_clock + + def __init__(self, s): + if s is None: + self._clock = u1db__vectorclock_from_str(NULL) + else: + self._clock = u1db__vectorclock_from_str(s) + + def __dealloc__(self): + u1db__free_vectorclock(&self._clock) + + def __repr__(self): + cdef int status + cdef char *res + if self._clock == NULL: + return '%s(None)' % (self.__class__.__name__,) + status = u1db__vectorclock_as_str(self._clock, &res) + if status != U1DB_OK: + return '%s(<failure: %d>)' % (status,) + if res == NULL: + val = '%s(NULL)' % (self.__class__.__name__,) + else: + val = '%s(%s)' % (self.__class__.__name__, res) + free(res) + return val + + def as_dict(self): + cdef u1db_vectorclock *cur + cdef int i + cdef int gen + if self._clock == NULL: + return None + res = {} + for i from 0 <= i < self._clock.num_items: + gen = self._clock.items[i].generation + res[self._clock.items[i].replica_uid] = gen + return res + + def as_str(self): + cdef int status + cdef char *res + + status = u1db__vectorclock_as_str(self._clock, &res) + if status != U1DB_OK: + raise RuntimeError("Failed to VectorClockRev.as_str(): %d" % (status,)) + if res == NULL: + s = None + else: + s = res + free(res) + return s + + def increment(self, replica_uid): + cdef int status + + status = u1db__vectorclock_increment(self._clock, replica_uid) + if status != U1DB_OK: + raise RuntimeError("Failed to increment: %d" % (status,)) + + def maximize(self, vcr): + cdef int status + cdef VectorClockRev other + + other = vcr + status = u1db__vectorclock_maximize(self._clock, other._clock) + if status != U1DB_OK: + raise RuntimeError("Failed to maximize: %d" % (status,)) + + def is_newer(self, vcr): + cdef int is_newer + cdef VectorClockRev other + + other = vcr + is_newer = u1db__vectorclock_is_newer(self._clock, other._clock) + if is_newer == 0: + return False + elif is_newer == 1: + return True + else: + raise RuntimeError("Failed to is_newer: %d" % (is_newer,)) + + +def sync_db_to_target(db, target): + """Sync the data between a CDatabase and a CSyncTarget""" + cdef CDatabase cdb + cdef CSyncTarget ctarget + cdef int local_gen = 0, status + + cdb = db + ctarget = target + with nogil: + status = u1db__sync_db_to_target(cdb._db, ctarget._st, &local_gen) + handle_status("sync_db_to_target", status) + return local_gen + + +def create_http_sync_target(url): + cdef CSyncTarget target + + target = CSyncTarget() + handle_status("create_http_sync_target", + u1db__create_http_sync_target(url, &target._st)) + return target + + +def create_oauth_http_sync_target(url, consumer_key, consumer_secret, + token_key, token_secret): + cdef CSyncTarget target + + target = CSyncTarget() + handle_status("create_http_sync_target", + u1db__create_oauth_http_sync_target(url, consumer_key, consumer_secret, + token_key, token_secret, + &target._st)) + return target + + +def _format_sync_url(target, source_replica_uid): + cdef CSyncTarget st + cdef char *sync_url = NULL + cdef object res + st = target + handle_status("format_sync_url", + u1db__format_sync_url(st._st, source_replica_uid, &sync_url)) + if sync_url == NULL: + res = None + else: + res = sync_url + free(sync_url) + return res + + +def _get_oauth_authorization(target, method, url): + cdef CSyncTarget st + cdef char *auth = NULL + + st = target + handle_status("get_oauth_authorization", + u1db__get_oauth_authorization(st._st, method, url, &auth)) + res = None + if auth != NULL: + res = auth + free(auth) + return res diff --git a/src/leap/soledad/u1db/tests/commandline/__init__.py b/src/leap/soledad/u1db/tests/commandline/__init__.py new file mode 100644 index 00000000..007cecd3 --- /dev/null +++ b/src/leap/soledad/u1db/tests/commandline/__init__.py @@ -0,0 +1,47 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +import errno +import time + + +def safe_close(process, timeout=0.1): + """Shutdown the process in the nicest fashion you can manage. + + :param process: A subprocess.Popen object. + :param timeout: We'll try to send 'SIGTERM' but if the process is alive + longer that 'timeout', we'll send SIGKILL. + """ + if process.poll() is not None: + return + try: + process.terminate() + except OSError, e: + if e.errno in (errno.ESRCH,): + # Process has exited + return + tend = time.time() + timeout + while time.time() < tend: + if process.poll() is not None: + return + time.sleep(0.01) + try: + process.kill() + except OSError, e: + if e.errno in (errno.ESRCH,): + # Process has exited + return + process.wait() diff --git a/src/leap/soledad/u1db/tests/commandline/test_client.py b/src/leap/soledad/u1db/tests/commandline/test_client.py new file mode 100644 index 00000000..78ca21eb --- /dev/null +++ b/src/leap/soledad/u1db/tests/commandline/test_client.py @@ -0,0 +1,916 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +import cStringIO +import os +import sys +try: + import simplejson as json +except ImportError: + import json # noqa +import subprocess + +from u1db import ( + errors, + open as u1db_open, + tests, + vectorclock, + ) +from u1db.commandline import ( + client, + serve, + ) +from u1db.tests.commandline import safe_close +from u1db.tests import test_remote_sync_target + + +class TestArgs(tests.TestCase): + """These tests are meant to test just the argument parsing. + + Each Command should have at least one test, possibly more if it allows + optional arguments, etc. + """ + + def setUp(self): + super(TestArgs, self).setUp() + self.parser = client.client_commands.make_argparser() + + def parse_args(self, args): + # ArgumentParser.parse_args doesn't play very nicely with a test suite, + # so we trap SystemExit in case something is wrong with the args we're + # parsing. + try: + return self.parser.parse_args(args) + except SystemExit: + raise AssertionError('got SystemExit') + + def test_create(self): + args = self.parse_args(['create', 'test.db']) + self.assertEqual(client.CmdCreate, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual(None, args.doc_id) + self.assertEqual(None, args.infile) + + def test_create_custom_doc_id(self): + args = self.parse_args(['create', '--id', 'xyz', 'test.db']) + self.assertEqual(client.CmdCreate, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('xyz', args.doc_id) + self.assertEqual(None, args.infile) + + def test_delete(self): + args = self.parse_args(['delete', 'test.db', 'doc-id', 'doc-rev']) + self.assertEqual(client.CmdDelete, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('doc-id', args.doc_id) + self.assertEqual('doc-rev', args.doc_rev) + + def test_get(self): + args = self.parse_args(['get', 'test.db', 'doc-id']) + self.assertEqual(client.CmdGet, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('doc-id', args.doc_id) + self.assertEqual(None, args.outfile) + + def test_get_dash(self): + args = self.parse_args(['get', 'test.db', 'doc-id', '-']) + self.assertEqual(client.CmdGet, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('doc-id', args.doc_id) + self.assertEqual(sys.stdout, args.outfile) + + def test_init_db(self): + args = self.parse_args( + ['init-db', 'test.db', '--replica-uid=replica-uid']) + self.assertEqual(client.CmdInitDB, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('replica-uid', args.replica_uid) + + def test_init_db_no_replica(self): + args = self.parse_args(['init-db', 'test.db']) + self.assertEqual(client.CmdInitDB, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertIs(None, args.replica_uid) + + def test_put(self): + args = self.parse_args(['put', 'test.db', 'doc-id', 'old-doc-rev']) + self.assertEqual(client.CmdPut, args.subcommand) + self.assertEqual('test.db', args.database) + self.assertEqual('doc-id', args.doc_id) + self.assertEqual('old-doc-rev', args.doc_rev) + self.assertEqual(None, args.infile) + + def test_sync(self): + args = self.parse_args(['sync', 'source', 'target']) + self.assertEqual(client.CmdSync, args.subcommand) + self.assertEqual('source', args.source) + self.assertEqual('target', args.target) + + def test_create_index(self): + args = self.parse_args(['create-index', 'db', 'index', 'expression']) + self.assertEqual(client.CmdCreateIndex, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('index', args.index) + self.assertEqual(['expression'], args.expression) + + def test_create_index_multi_expression(self): + args = self.parse_args(['create-index', 'db', 'index', 'e1', 'e2']) + self.assertEqual(client.CmdCreateIndex, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('index', args.index) + self.assertEqual(['e1', 'e2'], args.expression) + + def test_list_indexes(self): + args = self.parse_args(['list-indexes', 'db']) + self.assertEqual(client.CmdListIndexes, args.subcommand) + self.assertEqual('db', args.database) + + def test_delete_index(self): + args = self.parse_args(['delete-index', 'db', 'index']) + self.assertEqual(client.CmdDeleteIndex, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('index', args.index) + + def test_get_index_keys(self): + args = self.parse_args(['get-index-keys', 'db', 'index']) + self.assertEqual(client.CmdGetIndexKeys, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('index', args.index) + + def test_get_from_index(self): + args = self.parse_args(['get-from-index', 'db', 'index', 'foo']) + self.assertEqual(client.CmdGetFromIndex, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('index', args.index) + self.assertEqual(['foo'], args.values) + + def test_get_doc_conflicts(self): + args = self.parse_args(['get-doc-conflicts', 'db', 'doc-id']) + self.assertEqual(client.CmdGetDocConflicts, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('doc-id', args.doc_id) + + def test_resolve(self): + args = self.parse_args( + ['resolve-doc', 'db', 'doc-id', 'rev:1', 'other:1']) + self.assertEqual(client.CmdResolve, args.subcommand) + self.assertEqual('db', args.database) + self.assertEqual('doc-id', args.doc_id) + self.assertEqual(['rev:1', 'other:1'], args.doc_revs) + self.assertEqual(None, args.infile) + + +class TestCaseWithDB(tests.TestCase): + """These next tests are meant to have one class per Command. + + It is meant to test the inner workings of each command. The detailed + testing should happen in these classes. Stuff like how it handles errors, + etc. should be done here. + """ + + def setUp(self): + super(TestCaseWithDB, self).setUp() + self.working_dir = self.createTempDir() + self.db_path = self.working_dir + '/test.db' + self.db = u1db_open(self.db_path, create=True) + self.db._set_replica_uid('test') + self.addCleanup(self.db.close) + + def make_command(self, cls, stdin_content=''): + inf = cStringIO.StringIO(stdin_content) + out = cStringIO.StringIO() + err = cStringIO.StringIO() + return cls(inf, out, err) + + +class TestCmdCreate(TestCaseWithDB): + + def test_create(self): + cmd = self.make_command(client.CmdCreate) + inf = cStringIO.StringIO(tests.simple_doc) + cmd.run(self.db_path, inf, 'test-id') + doc = self.db.get_doc('test-id') + self.assertEqual(tests.simple_doc, doc.get_json()) + self.assertFalse(doc.has_conflicts) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('id: test-id\nrev: %s\n' % (doc.rev,), + cmd.stderr.getvalue()) + + +class TestCmdDelete(TestCaseWithDB): + + def test_delete(self): + doc = self.db.create_doc_from_json(tests.simple_doc) + cmd = self.make_command(client.CmdDelete) + cmd.run(self.db_path, doc.doc_id, doc.rev) + doc2 = self.db.get_doc(doc.doc_id, include_deleted=True) + self.assertEqual(doc.doc_id, doc2.doc_id) + self.assertNotEqual(doc.rev, doc2.rev) + self.assertIs(None, doc2.get_json()) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('rev: %s\n' % (doc2.rev,), cmd.stderr.getvalue()) + + def test_delete_fails_if_nonexistent(self): + doc = self.db.create_doc_from_json(tests.simple_doc) + db2_path = self.db_path + '.typo' + cmd = self.make_command(client.CmdDelete) + # TODO: We should really not be showing a traceback here. But we need + # to teach the commandline infrastructure how to handle + # exceptions. + # However, we *do* want to test that the db doesn't get created + # by accident. + self.assertRaises(errors.DatabaseDoesNotExist, + cmd.run, db2_path, doc.doc_id, doc.rev) + self.assertFalse(os.path.exists(db2_path)) + + def test_delete_no_such_doc(self): + cmd = self.make_command(client.CmdDelete) + # TODO: We should really not be showing a traceback here. But we need + # to teach the commandline infrastructure how to handle + # exceptions. + self.assertRaises(errors.DocumentDoesNotExist, + cmd.run, self.db_path, 'no-doc-id', 'no-rev') + + def test_delete_bad_rev(self): + doc = self.db.create_doc_from_json(tests.simple_doc) + cmd = self.make_command(client.CmdDelete) + self.assertRaises(errors.RevisionConflict, + cmd.run, self.db_path, doc.doc_id, 'not-the-actual-doc-rev:1') + # TODO: Test that we get a pretty output. + + +class TestCmdGet(TestCaseWithDB): + + def setUp(self): + super(TestCmdGet, self).setUp() + self.doc = self.db.create_doc_from_json( + tests.simple_doc, doc_id='my-test-doc') + + def test_get_simple(self): + cmd = self.make_command(client.CmdGet) + cmd.run(self.db_path, 'my-test-doc', None) + self.assertEqual(tests.simple_doc + "\n", cmd.stdout.getvalue()) + self.assertEqual('rev: %s\n' % (self.doc.rev,), + cmd.stderr.getvalue()) + + def test_get_conflict(self): + doc = self.make_document('my-test-doc', 'other:1', '{}', False) + self.db._put_doc_if_newer( + doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + cmd = self.make_command(client.CmdGet) + cmd.run(self.db_path, 'my-test-doc', None) + self.assertEqual('{}\n', cmd.stdout.getvalue()) + self.assertEqual('rev: %s\nDocument has conflicts.\n' % (doc.rev,), + cmd.stderr.getvalue()) + + def test_get_fail(self): + cmd = self.make_command(client.CmdGet) + result = cmd.run(self.db_path, 'doc-not-there', None) + self.assertEqual(1, result) + self.assertEqual("", cmd.stdout.getvalue()) + self.assertTrue("not found" in cmd.stderr.getvalue()) + + def test_get_no_database(self): + cmd = self.make_command(client.CmdGet) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc", None) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + +class TestCmdGetDocConflicts(TestCaseWithDB): + + def setUp(self): + super(TestCmdGetDocConflicts, self).setUp() + self.doc1 = self.db.create_doc_from_json( + tests.simple_doc, doc_id='my-doc') + self.doc2 = self.make_document('my-doc', 'other:1', '{}', False) + self.db._put_doc_if_newer( + self.doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + + def test_get_doc_conflicts_none(self): + self.db.create_doc_from_json(tests.simple_doc, doc_id='a-doc') + cmd = self.make_command(client.CmdGetDocConflicts) + cmd.run(self.db_path, 'a-doc') + self.assertEqual([], json.loads(cmd.stdout.getvalue())) + self.assertEqual('', cmd.stderr.getvalue()) + + def test_get_doc_conflicts_simple(self): + cmd = self.make_command(client.CmdGetDocConflicts) + cmd.run(self.db_path, 'my-doc') + self.assertEqual( + [dict(rev=self.doc2.rev, content=self.doc2.content), + dict(rev=self.doc1.rev, content=self.doc1.content)], + json.loads(cmd.stdout.getvalue())) + self.assertEqual('', cmd.stderr.getvalue()) + + def test_get_doc_conflicts_no_db(self): + cmd = self.make_command(client.CmdGetDocConflicts) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc") + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_get_doc_conflicts_no_doc(self): + cmd = self.make_command(client.CmdGetDocConflicts) + retval = cmd.run(self.db_path, "some-doc") + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Document does not exist.\n') + + +class TestCmdInit(TestCaseWithDB): + + def test_init_new(self): + path = self.working_dir + '/test2.db' + self.assertFalse(os.path.exists(path)) + cmd = self.make_command(client.CmdInitDB) + cmd.run(path, 'test-uid') + self.assertTrue(os.path.exists(path)) + db = u1db_open(path, create=False) + self.assertEqual('test-uid', db._replica_uid) + + def test_init_no_uid(self): + path = self.working_dir + '/test2.db' + cmd = self.make_command(client.CmdInitDB) + cmd.run(path, None) + self.assertTrue(os.path.exists(path)) + db = u1db_open(path, create=False) + self.assertIsNot(None, db._replica_uid) + + +class TestCmdPut(TestCaseWithDB): + + def setUp(self): + super(TestCmdPut, self).setUp() + self.doc = self.db.create_doc_from_json( + tests.simple_doc, doc_id='my-test-doc') + + def test_put_simple(self): + cmd = self.make_command(client.CmdPut) + inf = cStringIO.StringIO(tests.nested_doc) + cmd.run(self.db_path, 'my-test-doc', self.doc.rev, inf) + doc = self.db.get_doc('my-test-doc') + self.assertNotEqual(self.doc.rev, doc.rev) + self.assertGetDoc(self.db, 'my-test-doc', doc.rev, + tests.nested_doc, False) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('rev: %s\n' % (doc.rev,), + cmd.stderr.getvalue()) + + def test_put_no_db(self): + cmd = self.make_command(client.CmdPut) + inf = cStringIO.StringIO(tests.nested_doc) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", + 'my-test-doc', self.doc.rev, inf) + self.assertEqual(retval, 1) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('Database does not exist.\n', cmd.stderr.getvalue()) + + def test_put_no_doc(self): + cmd = self.make_command(client.CmdPut) + inf = cStringIO.StringIO(tests.nested_doc) + retval = cmd.run(self.db_path, 'no-such-doc', 'wut:1', inf) + self.assertEqual(1, retval) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('Document does not exist.\n', cmd.stderr.getvalue()) + + def test_put_doc_old_rev(self): + rev = self.doc.rev + doc = self.make_document('my-test-doc', rev, '{}', False) + self.db.put_doc(doc) + cmd = self.make_command(client.CmdPut) + inf = cStringIO.StringIO(tests.nested_doc) + retval = cmd.run(self.db_path, 'my-test-doc', rev, inf) + self.assertEqual(1, retval) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('Given revision is not current.\n', + cmd.stderr.getvalue()) + + def test_put_doc_w_conflicts(self): + doc = self.make_document('my-test-doc', 'other:1', '{}', False) + self.db._put_doc_if_newer( + doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + cmd = self.make_command(client.CmdPut) + inf = cStringIO.StringIO(tests.nested_doc) + retval = cmd.run(self.db_path, 'my-test-doc', 'other:1', inf) + self.assertEqual(1, retval) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('Document has conflicts.\n' + 'Inspect with get-doc-conflicts, then resolve.\n', + cmd.stderr.getvalue()) + + +class TestCmdResolve(TestCaseWithDB): + + def setUp(self): + super(TestCmdResolve, self).setUp() + self.doc1 = self.db.create_doc_from_json( + tests.simple_doc, doc_id='my-doc') + self.doc2 = self.make_document('my-doc', 'other:1', '{}', False) + self.db._put_doc_if_newer( + self.doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + + def test_resolve_simple(self): + self.assertTrue(self.db.get_doc('my-doc').has_conflicts) + cmd = self.make_command(client.CmdResolve) + inf = cStringIO.StringIO(tests.nested_doc) + cmd.run(self.db_path, 'my-doc', [self.doc1.rev, self.doc2.rev], inf) + doc = self.db.get_doc('my-doc') + vec = vectorclock.VectorClockRev(doc.rev) + self.assertTrue( + vec.is_newer(vectorclock.VectorClockRev(self.doc1.rev))) + self.assertTrue( + vec.is_newer(vectorclock.VectorClockRev(self.doc2.rev))) + self.assertGetDoc(self.db, 'my-doc', doc.rev, tests.nested_doc, False) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual('rev: %s\n' % (doc.rev,), + cmd.stderr.getvalue()) + + def test_resolve_double(self): + moar = '{"x": 42}' + doc3 = self.make_document('my-doc', 'third:1', moar, False) + self.db._put_doc_if_newer( + doc3, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + cmd = self.make_command(client.CmdResolve) + inf = cStringIO.StringIO(tests.nested_doc) + cmd.run(self.db_path, 'my-doc', [self.doc1.rev, self.doc2.rev], inf) + doc = self.db.get_doc('my-doc') + self.assertGetDoc(self.db, 'my-doc', doc.rev, moar, True) + self.assertEqual('', cmd.stdout.getvalue()) + self.assertEqual( + 'rev: %s\nDocument still has conflicts.\n' % (doc.rev,), + cmd.stderr.getvalue()) + + def test_resolve_no_db(self): + cmd = self.make_command(client.CmdResolve) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc", [], None) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_resolve_no_doc(self): + cmd = self.make_command(client.CmdResolve) + retval = cmd.run(self.db_path, "foo", [], None) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Document does not exist.\n') + + +class TestCmdSync(TestCaseWithDB): + + def setUp(self): + super(TestCmdSync, self).setUp() + self.db2_path = self.working_dir + '/test2.db' + self.db2 = u1db_open(self.db2_path, create=True) + self.addCleanup(self.db2.close) + self.db2._set_replica_uid('test2') + self.doc = self.db.create_doc_from_json( + tests.simple_doc, doc_id='test-id') + self.doc2 = self.db2.create_doc_from_json( + tests.nested_doc, doc_id='my-test-id') + + def test_sync(self): + cmd = self.make_command(client.CmdSync) + cmd.run(self.db_path, self.db2_path) + self.assertGetDoc(self.db2, 'test-id', self.doc.rev, tests.simple_doc, + False) + self.assertGetDoc(self.db, 'my-test-id', self.doc2.rev, + tests.nested_doc, False) + + +class TestCmdSyncRemote(tests.TestCaseWithServer, TestCaseWithDB): + + make_app_with_state = \ + staticmethod(test_remote_sync_target.make_http_app) + + def setUp(self): + super(TestCmdSyncRemote, self).setUp() + self.startServer() + self.db2 = self.request_state._create_database('test2.db') + + def test_sync_remote(self): + doc1 = self.db.create_doc_from_json(tests.simple_doc) + doc2 = self.db2.create_doc_from_json(tests.nested_doc) + db2_url = self.getURL('test2.db') + self.assertTrue(db2_url.startswith('http://')) + self.assertTrue(db2_url.endswith('/test2.db')) + cmd = self.make_command(client.CmdSync) + cmd.run(self.db_path, db2_url) + self.assertGetDoc(self.db2, doc1.doc_id, doc1.rev, tests.simple_doc, + False) + self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc, + False) + + +class TestCmdCreateIndex(TestCaseWithDB): + + def test_create_index(self): + cmd = self.make_command(client.CmdCreateIndex) + retval = cmd.run(self.db_path, "foo", ["bar", "baz"]) + self.assertEqual(self.db.list_indexes(), [('foo', ['bar', "baz"])]) + self.assertEqual(retval, None) # conveniently mapped to 0 + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_create_index_no_db(self): + cmd = self.make_command(client.CmdCreateIndex) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo", ["bar"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_create_dupe_index(self): + self.db.create_index("foo", "bar") + cmd = self.make_command(client.CmdCreateIndex) + retval = cmd.run(self.db_path, "foo", ["bar"]) + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_create_dupe_index_different_expression(self): + self.db.create_index("foo", "bar") + cmd = self.make_command(client.CmdCreateIndex) + retval = cmd.run(self.db_path, "foo", ["baz"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), + "There is already a different index named 'foo'.\n") + + def test_create_index_bad_expression(self): + cmd = self.make_command(client.CmdCreateIndex) + retval = cmd.run(self.db_path, "foo", ["WAT()"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), + 'Bad index expression.\n') + + +class TestCmdListIndexes(TestCaseWithDB): + + def test_list_no_indexes(self): + cmd = self.make_command(client.CmdListIndexes) + retval = cmd.run(self.db_path) + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_list_indexes(self): + self.db.create_index("foo", "bar", "baz") + cmd = self.make_command(client.CmdListIndexes) + retval = cmd.run(self.db_path) + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), 'foo: bar, baz\n') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_list_several_indexes(self): + self.db.create_index("foo", "bar", "baz") + self.db.create_index("bar", "baz", "foo") + self.db.create_index("baz", "foo", "bar") + cmd = self.make_command(client.CmdListIndexes) + retval = cmd.run(self.db_path) + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), + 'bar: baz, foo\n' + 'baz: foo, bar\n' + 'foo: bar, baz\n' + ) + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_list_indexes_no_db(self): + cmd = self.make_command(client.CmdListIndexes) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST") + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + +class TestCmdDeleteIndex(TestCaseWithDB): + + def test_delete_index(self): + self.db.create_index("foo", "bar", "baz") + cmd = self.make_command(client.CmdDeleteIndex) + retval = cmd.run(self.db_path, "foo") + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + self.assertEqual([], self.db.list_indexes()) + + def test_delete_index_no_db(self): + cmd = self.make_command(client.CmdDeleteIndex) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo") + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_delete_index_no_index(self): + cmd = self.make_command(client.CmdDeleteIndex) + retval = cmd.run(self.db_path, "foo") + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + + +class TestCmdGetIndexKeys(TestCaseWithDB): + + def test_get_index_keys(self): + self.db.create_index("foo", "bar") + self.db.create_doc_from_json('{"bar": 42}') + cmd = self.make_command(client.CmdGetIndexKeys) + retval = cmd.run(self.db_path, "foo") + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '42\n') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_get_index_keys_nonascii(self): + self.db.create_index("foo", "bar") + self.db.create_doc_from_json('{"bar": "\u00a4"}') + cmd = self.make_command(client.CmdGetIndexKeys) + retval = cmd.run(self.db_path, "foo") + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '\xc2\xa4\n') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_get_index_keys_empty(self): + self.db.create_index("foo", "bar") + cmd = self.make_command(client.CmdGetIndexKeys) + retval = cmd.run(self.db_path, "foo") + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_get_index_keys_no_db(self): + cmd = self.make_command(client.CmdGetIndexKeys) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo") + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_get_index_keys_no_index(self): + cmd = self.make_command(client.CmdGetIndexKeys) + retval = cmd.run(self.db_path, "foo") + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Index does not exist.\n') + + +class TestCmdGetFromIndex(TestCaseWithDB): + + def test_get_from_index(self): + self.db.create_index("index", "key") + doc1 = self.db.create_doc_from_json(tests.simple_doc) + doc2 = self.db.create_doc_from_json(tests.nested_doc) + cmd = self.make_command(client.CmdGetFromIndex) + retval = cmd.run(self.db_path, "index", ["value"]) + self.assertEqual(retval, None) + self.assertEqual(sorted(json.loads(cmd.stdout.getvalue())), + sorted([dict(id=doc1.doc_id, + rev=doc1.rev, + content=doc1.content), + dict(id=doc2.doc_id, + rev=doc2.rev, + content=doc2.content), + ])) + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_get_from_index_empty(self): + self.db.create_index("index", "key") + cmd = self.make_command(client.CmdGetFromIndex) + retval = cmd.run(self.db_path, "index", ["value"]) + self.assertEqual(retval, None) + self.assertEqual(cmd.stdout.getvalue(), '[]\n') + self.assertEqual(cmd.stderr.getvalue(), '') + + def test_get_from_index_no_db(self): + cmd = self.make_command(client.CmdGetFromIndex) + retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo", []) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n') + + def test_get_from_index_no_index(self): + cmd = self.make_command(client.CmdGetFromIndex) + retval = cmd.run(self.db_path, "foo", []) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual(cmd.stderr.getvalue(), 'Index does not exist.\n') + + def test_get_from_index_two_expr_instead_of_one(self): + self.db.create_index("index", "key1") + cmd = self.make_command(client.CmdGetFromIndex) + cmd.argv = ["XX", "YY"] + retval = cmd.run(self.db_path, "index", ["value1", "value2"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual("Invalid query: index 'index' requires" + " 1 query expression, not 2.\n" + "For example, the following would be valid:\n" + " XX YY %r 'index' 'value1'\n" + % self.db_path, cmd.stderr.getvalue()) + + def test_get_from_index_three_expr_instead_of_two(self): + self.db.create_index("index", "key1", "key2") + cmd = self.make_command(client.CmdGetFromIndex) + cmd.argv = ["XX", "YY"] + retval = cmd.run(self.db_path, "index", ["value1", "value2", "value3"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual("Invalid query: index 'index' requires" + " 2 query expressions, not 3.\n" + "For example, the following would be valid:\n" + " XX YY %r 'index' 'value1' 'value2'\n" + % self.db_path, cmd.stderr.getvalue()) + + def test_get_from_index_one_expr_instead_of_two(self): + self.db.create_index("index", "key1", "key2") + cmd = self.make_command(client.CmdGetFromIndex) + cmd.argv = ["XX", "YY"] + retval = cmd.run(self.db_path, "index", ["value1"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual("Invalid query: index 'index' requires" + " 2 query expressions, not 1.\n" + "For example, the following would be valid:\n" + " XX YY %r 'index' 'value1' '*'\n" + % self.db_path, cmd.stderr.getvalue()) + + def test_get_from_index_cant_bad_glob(self): + self.db.create_index("index", "key1", "key2") + cmd = self.make_command(client.CmdGetFromIndex) + cmd.argv = ["XX", "YY"] + retval = cmd.run(self.db_path, "index", ["value1*", "value2"]) + self.assertEqual(retval, 1) + self.assertEqual(cmd.stdout.getvalue(), '') + self.assertEqual("Invalid query:" + " a star can only be followed by stars.\n" + "For example, the following would be valid:\n" + " XX YY %r 'index' 'value1*' '*'\n" + % self.db_path, cmd.stderr.getvalue()) + + +class RunMainHelper(object): + + def run_main(self, args, stdin=None): + if stdin is not None: + self.patch(sys, 'stdin', cStringIO.StringIO(stdin)) + stdout = cStringIO.StringIO() + stderr = cStringIO.StringIO() + self.patch(sys, 'stdout', stdout) + self.patch(sys, 'stderr', stderr) + try: + ret = client.main(args) + except SystemExit, e: + self.fail("Intercepted SystemExit: %s" % (e,)) + if ret is None: + ret = 0 + return ret, stdout.getvalue(), stderr.getvalue() + + +class TestCommandLine(TestCaseWithDB, RunMainHelper): + """These are meant to test that the infrastructure is fully connected. + + Each command is likely to only have one test here. Something that ensures + 'main()' knows about and can run the command correctly. Most logic-level + testing of the Command should go into its own test class above. + """ + + def _get_u1db_client_path(self): + from u1db import __path__ as u1db_path + u1db_parent_dir = os.path.dirname(u1db_path[0]) + return os.path.join(u1db_parent_dir, 'u1db-client') + + def runU1DBClient(self, args): + command = [sys.executable, self._get_u1db_client_path()] + command.extend(args) + p = subprocess.Popen(command, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self.addCleanup(safe_close, p) + return p + + def test_create_subprocess(self): + p = self.runU1DBClient(['create', '--id', 'test-id', self.db_path]) + stdout, stderr = p.communicate(tests.simple_doc) + self.assertEqual(0, p.returncode) + self.assertEqual('', stdout) + doc = self.db.get_doc('test-id') + self.assertEqual(tests.simple_doc, doc.get_json()) + self.assertFalse(doc.has_conflicts) + expected = 'id: test-id\nrev: %s\n' % (doc.rev,) + stripped = stderr.replace('\r\n', '\n') + if expected != stripped: + # When run under python-dbg, it prints out the refs after the + # actual content, so match it if we need to. + expected_re = expected + '\[\d+ refs\]\n' + self.assertRegexpMatches(stripped, expected_re) + + def test_get(self): + doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') + ret, stdout, stderr = self.run_main(['get', self.db_path, 'test-id']) + self.assertEqual(0, ret) + self.assertEqual(tests.simple_doc + "\n", stdout) + self.assertEqual('rev: %s\n' % (doc.rev,), stderr) + ret, stdout, stderr = self.run_main(['get', self.db_path, 'not-there']) + self.assertEqual(1, ret) + + def test_delete(self): + doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') + ret, stdout, stderr = self.run_main( + ['delete', self.db_path, 'test-id', doc.rev]) + doc = self.db.get_doc('test-id', include_deleted=True) + self.assertEqual(0, ret) + self.assertEqual('', stdout) + self.assertEqual('rev: %s\n' % (doc.rev,), stderr) + + def test_init_db(self): + path = self.working_dir + '/test2.db' + ret, stdout, stderr = self.run_main(['init-db', path]) + u1db_open(path, create=False) + + def test_put(self): + doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') + ret, stdout, stderr = self.run_main( + ['put', self.db_path, 'test-id', doc.rev], + stdin=tests.nested_doc) + doc = self.db.get_doc('test-id') + self.assertFalse(doc.has_conflicts) + self.assertEqual(tests.nested_doc, doc.get_json()) + self.assertEqual(0, ret) + self.assertEqual('', stdout) + self.assertEqual('rev: %s\n' % (doc.rev,), stderr) + + def test_sync(self): + doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id') + self.db2_path = self.working_dir + '/test2.db' + self.db2 = u1db_open(self.db2_path, create=True) + self.addCleanup(self.db2.close) + ret, stdout, stderr = self.run_main( + ['sync', self.db_path, self.db2_path]) + self.assertEqual(0, ret) + self.assertEqual('', stdout) + self.assertEqual('', stderr) + self.assertGetDoc( + self.db2, 'test-id', doc.rev, tests.simple_doc, False) + + +class TestHTTPIntegration(tests.TestCaseWithServer, RunMainHelper): + """Meant to test the cases where commands operate over http.""" + + def server_def(self): + def make_server(host_port, _application): + return serve.make_server(host_port[0], host_port[1], + self.working_dir) + return make_server, "shutdown", "http" + + def setUp(self): + super(TestHTTPIntegration, self).setUp() + self.working_dir = self.createTempDir(prefix='u1db-http-server-') + self.startServer() + + def getPath(self, dbname): + return os.path.join(self.working_dir, dbname) + + def test_init_db(self): + url = self.getURL('new.db') + ret, stdout, stderr = self.run_main(['init-db', url]) + u1db_open(self.getPath('new.db'), create=False) + + def test_create_get_put_delete(self): + db = u1db_open(self.getPath('test.db'), create=True) + url = self.getURL('test.db') + doc_id = '%abcd' + ret, stdout, stderr = self.run_main(['create', url, '--id', doc_id], + stdin=tests.simple_doc) + self.assertEqual(0, ret) + ret, stdout, stderr = self.run_main(['get', url, doc_id]) + self.assertEqual(0, ret) + self.assertTrue(stderr.startswith('rev: ')) + doc_rev = stderr[len('rev: '):].rstrip() + ret, stdout, stderr = self.run_main(['put', url, doc_id, doc_rev], + stdin=tests.nested_doc) + self.assertEqual(0, ret) + self.assertTrue(stderr.startswith('rev: ')) + doc_rev1 = stderr[len('rev: '):].rstrip() + self.assertGetDoc(db, doc_id, doc_rev1, tests.nested_doc, False) + ret, stdout, stderr = self.run_main(['delete', url, doc_id, doc_rev1]) + self.assertEqual(0, ret) + self.assertTrue(stderr.startswith('rev: ')) + doc_rev2 = stderr[len('rev: '):].rstrip() + self.assertGetDocIncludeDeleted(db, doc_id, doc_rev2, None, False) diff --git a/src/leap/soledad/u1db/tests/commandline/test_command.py b/src/leap/soledad/u1db/tests/commandline/test_command.py new file mode 100644 index 00000000..43580f23 --- /dev/null +++ b/src/leap/soledad/u1db/tests/commandline/test_command.py @@ -0,0 +1,105 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +import cStringIO +import argparse + +from u1db import ( + tests, + ) +from u1db.commandline import ( + command, + ) + + +class MyTestCommand(command.Command): + """Help String""" + + name = 'mycmd' + + @classmethod + def _populate_subparser(cls, parser): + parser.add_argument('foo') + parser.add_argument('--bar', dest='nbar', type=int) + + def run(self, foo, nbar): + self.stdout.write('foo: %s nbar: %d' % (foo, nbar)) + return 0 + + +def make_stdin_out_err(): + return cStringIO.StringIO(), cStringIO.StringIO(), cStringIO.StringIO() + + +class TestCommandGroup(tests.TestCase): + + def trap_system_exit(self, func, *args, **kwargs): + try: + return func(*args, **kwargs) + except SystemExit, e: + self.fail('Got SystemExit trying to run: %s' % (func,)) + + def parse_args(self, parser, args): + return self.trap_system_exit(parser.parse_args, args) + + def test_register(self): + group = command.CommandGroup() + self.assertEqual({}, group.commands) + group.register(MyTestCommand) + self.assertEqual({'mycmd': MyTestCommand}, + group.commands) + + def test_make_argparser(self): + group = command.CommandGroup(description='test-foo') + parser = group.make_argparser() + self.assertIsInstance(parser, argparse.ArgumentParser) + + def test_make_argparser_with_command(self): + group = command.CommandGroup(description='test-foo') + group.register(MyTestCommand) + parser = group.make_argparser() + args = self.parse_args(parser, ['mycmd', 'foozizle', '--bar=10']) + self.assertEqual('foozizle', args.foo) + self.assertEqual(10, args.nbar) + self.assertEqual(MyTestCommand, args.subcommand) + + def test_run_argv(self): + group = command.CommandGroup() + group.register(MyTestCommand) + stdin, stdout, stderr = make_stdin_out_err() + ret = self.trap_system_exit(group.run_argv, + ['mycmd', 'foozizle', '--bar=10'], + stdin, stdout, stderr) + self.assertEqual(0, ret) + + +class TestCommand(tests.TestCase): + + def make_command(self): + stdin, stdout, stderr = make_stdin_out_err() + return command.Command(stdin, stdout, stderr) + + def test__init__(self): + cmd = self.make_command() + self.assertIsNot(None, cmd.stdin) + self.assertIsNot(None, cmd.stdout) + self.assertIsNot(None, cmd.stderr) + + def test_run_args(self): + stdin, stdout, stderr = make_stdin_out_err() + cmd = MyTestCommand(stdin, stdout, stderr) + res = cmd.run(foo='foozizle', nbar=10) + self.assertEqual('foo: foozizle nbar: 10', stdout.getvalue()) diff --git a/src/leap/soledad/u1db/tests/commandline/test_serve.py b/src/leap/soledad/u1db/tests/commandline/test_serve.py new file mode 100644 index 00000000..6397eabe --- /dev/null +++ b/src/leap/soledad/u1db/tests/commandline/test_serve.py @@ -0,0 +1,101 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +import os +import socket +import subprocess +import sys + +from u1db import ( + __version__ as _u1db_version, + open as u1db_open, + tests, + ) +from u1db.remote import http_client +from u1db.tests.commandline import safe_close + + +class TestU1DBServe(tests.TestCase): + + def _get_u1db_serve_path(self): + from u1db import __path__ as u1db_path + u1db_parent_dir = os.path.dirname(u1db_path[0]) + return os.path.join(u1db_parent_dir, 'u1db-serve') + + def startU1DBServe(self, args): + command = [sys.executable, self._get_u1db_serve_path()] + command.extend(args) + p = subprocess.Popen(command, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self.addCleanup(safe_close, p) + return p + + def test_help(self): + p = self.startU1DBServe(['--help']) + stdout, stderr = p.communicate() + if stderr != '': + # stderr should normally be empty, but if we are running under + # python-dbg, it contains the following string + self.assertRegexpMatches(stderr, r'\[\d+ refs\]') + self.assertEqual(0, p.returncode) + self.assertIn('Run the U1DB server', stdout) + + def test_bind_to_port(self): + p = self.startU1DBServe([]) + starts = 'listening on:' + x = p.stdout.readline() + self.assertTrue(x.startswith(starts)) + port = int(x[len(starts):].split(":")[1]) + url = "http://127.0.0.1:%s/" % port + c = http_client.HTTPClientBase(url) + self.addCleanup(c.close) + res, _ = c._request_json('GET', []) + self.assertEqual({'version': _u1db_version}, res) + + def test_supply_port(self): + s = socket.socket() + s.bind(('127.0.0.1', 0)) + host, port = s.getsockname() + s.close() + p = self.startU1DBServe(['--port', str(port)]) + x = p.stdout.readline().strip() + self.assertEqual('listening on: 127.0.0.1:%s' % (port,), x) + url = "http://127.0.0.1:%s/" % port + c = http_client.HTTPClientBase(url) + self.addCleanup(c.close) + res, _ = c._request_json('GET', []) + self.assertEqual({'version': _u1db_version}, res) + + def test_bind_to_host(self): + p = self.startU1DBServe(["--host", "localhost"]) + starts = 'listening on: 127.0.0.1:' + x = p.stdout.readline() + self.assertTrue(x.startswith(starts)) + + def test_supply_working_dir(self): + tmp_dir = self.createTempDir('u1db-serve-test') + db = u1db_open(os.path.join(tmp_dir, 'landmark.db'), create=True) + db.close() + p = self.startU1DBServe(['--working-dir', tmp_dir]) + starts = 'listening on:' + x = p.stdout.readline() + self.assertTrue(x.startswith(starts)) + port = int(x[len(starts):].split(":")[1]) + url = "http://127.0.0.1:%s/landmark.db" % port + c = http_client.HTTPClientBase(url) + self.addCleanup(c.close) + res, _ = c._request_json('GET', []) + self.assertEqual({}, res) diff --git a/src/leap/soledad/u1db/tests/test_auth_middleware.py b/src/leap/soledad/u1db/tests/test_auth_middleware.py new file mode 100644 index 00000000..e765f8a7 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_auth_middleware.py @@ -0,0 +1,309 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Test OAuth wsgi middleware""" +import paste.fixture +from oauth import oauth +try: + import simplejson as json +except ImportError: + import json # noqa +import time + +from u1db import tests + +from u1db.remote.oauth_middleware import OAuthMiddleware +from u1db.remote.basic_auth_middleware import BasicAuthMiddleware, Unauthorized + + +BASE_URL = 'https://example.net' + + +class TestBasicAuthMiddleware(tests.TestCase): + + def setUp(self): + super(TestBasicAuthMiddleware, self).setUp() + self.got = [] + + def witness_app(environ, start_response): + start_response("200 OK", [("content-type", "text/plain")]) + self.got.append(( + environ['user_id'], environ['PATH_INFO'], + environ['QUERY_STRING'])) + return ["ok"] + + class MyAuthMiddleware(BasicAuthMiddleware): + + def verify_user(self, environ, user, password): + if user != "correct_user": + raise Unauthorized + if password != "correct_password": + raise Unauthorized + environ['user_id'] = user + + self.auth_midw = MyAuthMiddleware(witness_app, prefix="/pfx/") + self.app = paste.fixture.TestApp(self.auth_midw) + + def test_expect_prefix(self): + url = BASE_URL + '/foo/doc/doc-id' + resp = self.app.delete(url, expect_errors=True) + self.assertEqual(400, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual('{"error": "bad request"}', resp.body) + + def test_missing_auth(self): + url = BASE_URL + '/pfx/foo/doc/doc-id' + resp = self.app.delete(url, expect_errors=True) + self.assertEqual(401, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": "unauthorized", + "message": "Missing Basic Authentication."}, + json.loads(resp.body)) + + def test_correct_auth(self): + user = "correct_user" + password = "correct_password" + params = {'old_rev': 'old-rev'} + url = BASE_URL + '/pfx/foo/doc/doc-id?%s' % ( + '&'.join("%s=%s" % (k, v) for k, v in params.items())) + auth = '%s:%s' % (user, password) + headers = { + 'Authorization': 'Basic %s' % (auth.encode('base64'),)} + resp = self.app.delete(url, headers=headers) + self.assertEqual(200, resp.status) + self.assertEqual( + [('correct_user', '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) + + def test_incorrect_auth(self): + user = "correct_user" + password = "incorrect_password" + params = {'old_rev': 'old-rev'} + url = BASE_URL + '/pfx/foo/doc/doc-id?%s' % ( + '&'.join("%s=%s" % (k, v) for k, v in params.items())) + auth = '%s:%s' % (user, password) + headers = { + 'Authorization': 'Basic %s' % (auth.encode('base64'),)} + resp = self.app.delete(url, headers=headers, expect_errors=True) + self.assertEqual(401, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": "unauthorized", + "message": "Incorrect password or login."}, + json.loads(resp.body)) + + +class TestOAuthMiddlewareDefaultPrefix(tests.TestCase): + def setUp(self): + + super(TestOAuthMiddlewareDefaultPrefix, self).setUp() + self.got = [] + + def witness_app(environ, start_response): + start_response("200 OK", [("content-type", "text/plain")]) + self.got.append((environ['token_key'], environ['PATH_INFO'], + environ['QUERY_STRING'])) + return ["ok"] + + class MyOAuthMiddleware(OAuthMiddleware): + get_oauth_data_store = lambda self: tests.testingOAuthStore + + def verify(self, environ, oauth_req): + consumer, token = super(MyOAuthMiddleware, self).verify( + environ, oauth_req) + environ['token_key'] = token.key + + self.oauth_midw = MyOAuthMiddleware(witness_app, BASE_URL) + self.app = paste.fixture.TestApp(self.oauth_midw) + + def test_expect_tilde(self): + url = BASE_URL + '/foo/doc/doc-id' + resp = self.app.delete(url, expect_errors=True) + self.assertEqual(400, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual('{"error": "bad request"}', resp.body) + + def test_oauth_in_header(self): + url = BASE_URL + '/~/foo/doc/doc-id' + params = {'old_rev': 'old-rev'} + oauth_req = oauth.OAuthRequest.from_consumer_and_token( + tests.consumer2, + tests.token2, + parameters=params, + http_url=url, + http_method='DELETE' + ) + url = oauth_req.get_normalized_http_url() + '?' + ( + '&'.join("%s=%s" % (k, v) for k, v in params.items())) + oauth_req.sign_request(tests.sign_meth_HMAC_SHA1, + tests.consumer2, tests.token2) + resp = self.app.delete(url, headers=oauth_req.to_header()) + self.assertEqual(200, resp.status) + self.assertEqual([(tests.token2.key, + '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) + + def test_oauth_in_query_string(self): + url = BASE_URL + '/~/foo/doc/doc-id' + params = {'old_rev': 'old-rev'} + oauth_req = oauth.OAuthRequest.from_consumer_and_token( + tests.consumer1, + tests.token1, + parameters=params, + http_url=url, + http_method='DELETE' + ) + oauth_req.sign_request(tests.sign_meth_HMAC_SHA1, + tests.consumer1, tests.token1) + resp = self.app.delete(oauth_req.to_url()) + self.assertEqual(200, resp.status) + self.assertEqual([(tests.token1.key, + '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) + + +class TestOAuthMiddleware(tests.TestCase): + + def setUp(self): + super(TestOAuthMiddleware, self).setUp() + self.got = [] + + def witness_app(environ, start_response): + start_response("200 OK", [("content-type", "text/plain")]) + self.got.append((environ['token_key'], environ['PATH_INFO'], + environ['QUERY_STRING'])) + return ["ok"] + + class MyOAuthMiddleware(OAuthMiddleware): + get_oauth_data_store = lambda self: tests.testingOAuthStore + + def verify(self, environ, oauth_req): + consumer, token = super(MyOAuthMiddleware, self).verify( + environ, oauth_req) + environ['token_key'] = token.key + + self.oauth_midw = MyOAuthMiddleware( + witness_app, BASE_URL, prefix='/pfx/') + self.app = paste.fixture.TestApp(self.oauth_midw) + + def test_expect_prefix(self): + url = BASE_URL + '/foo/doc/doc-id' + resp = self.app.delete(url, expect_errors=True) + self.assertEqual(400, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual('{"error": "bad request"}', resp.body) + + def test_missing_oauth(self): + url = BASE_URL + '/pfx/foo/doc/doc-id' + resp = self.app.delete(url, expect_errors=True) + self.assertEqual(401, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": "unauthorized", "message": "Missing OAuth."}, + json.loads(resp.body)) + + def test_oauth_in_query_string(self): + url = BASE_URL + '/pfx/foo/doc/doc-id' + params = {'old_rev': 'old-rev'} + oauth_req = oauth.OAuthRequest.from_consumer_and_token( + tests.consumer1, + tests.token1, + parameters=params, + http_url=url, + http_method='DELETE' + ) + oauth_req.sign_request(tests.sign_meth_HMAC_SHA1, + tests.consumer1, tests.token1) + resp = self.app.delete(oauth_req.to_url()) + self.assertEqual(200, resp.status) + self.assertEqual([(tests.token1.key, + '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) + + def test_oauth_invalid(self): + url = BASE_URL + '/pfx/foo/doc/doc-id' + params = {'old_rev': 'old-rev'} + oauth_req = oauth.OAuthRequest.from_consumer_and_token( + tests.consumer1, + tests.token3, + parameters=params, + http_url=url, + http_method='DELETE' + ) + oauth_req.sign_request(tests.sign_meth_HMAC_SHA1, + tests.consumer1, tests.token3) + resp = self.app.delete(oauth_req.to_url(), + expect_errors=True) + self.assertEqual(401, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + err = json.loads(resp.body) + self.assertEqual({"error": "unauthorized", + "message": err['message']}, + err) + + def test_oauth_in_header(self): + url = BASE_URL + '/pfx/foo/doc/doc-id' + params = {'old_rev': 'old-rev'} + oauth_req = oauth.OAuthRequest.from_consumer_and_token( + tests.consumer2, + tests.token2, + parameters=params, + http_url=url, + http_method='DELETE' + ) + url = oauth_req.get_normalized_http_url() + '?' + ( + '&'.join("%s=%s" % (k, v) for k, v in params.items())) + oauth_req.sign_request(tests.sign_meth_HMAC_SHA1, + tests.consumer2, tests.token2) + resp = self.app.delete(url, headers=oauth_req.to_header()) + self.assertEqual(200, resp.status) + self.assertEqual([(tests.token2.key, + '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) + + def test_oauth_plain_text(self): + url = BASE_URL + '/pfx/foo/doc/doc-id' + params = {'old_rev': 'old-rev'} + oauth_req = oauth.OAuthRequest.from_consumer_and_token( + tests.consumer1, + tests.token1, + parameters=params, + http_url=url, + http_method='DELETE' + ) + oauth_req.sign_request(tests.sign_meth_PLAINTEXT, + tests.consumer1, tests.token1) + resp = self.app.delete(oauth_req.to_url()) + self.assertEqual(200, resp.status) + self.assertEqual([(tests.token1.key, + '/foo/doc/doc-id', 'old_rev=old-rev')], self.got) + + def test_oauth_timestamp_threshold(self): + url = BASE_URL + '/pfx/foo/doc/doc-id' + params = {'old_rev': 'old-rev'} + oauth_req = oauth.OAuthRequest.from_consumer_and_token( + tests.consumer1, + tests.token1, + parameters=params, + http_url=url, + http_method='DELETE' + ) + oauth_req.set_parameter('oauth_timestamp', int(time.time()) - 5) + oauth_req.sign_request(tests.sign_meth_PLAINTEXT, + tests.consumer1, tests.token1) + # tweak threshold + self.oauth_midw.timestamp_threshold = 1 + resp = self.app.delete(oauth_req.to_url(), expect_errors=True) + self.assertEqual(401, resp.status) + err = json.loads(resp.body) + self.assertIn('Expired timestamp', err['message']) + self.assertIn('threshold 1', err['message']) diff --git a/src/leap/soledad/u1db/tests/test_backends.py b/src/leap/soledad/u1db/tests/test_backends.py new file mode 100644 index 00000000..7a3c9e5c --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_backends.py @@ -0,0 +1,1895 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""The backend class for U1DB. This deals with hiding storage details.""" + +try: + import simplejson as json +except ImportError: + import json # noqa +from u1db import ( + DocumentBase, + errors, + tests, + vectorclock, + ) + +simple_doc = tests.simple_doc +nested_doc = tests.nested_doc + +from u1db.tests.test_remote_sync_target import ( + make_http_app, + make_oauth_http_app, +) + +from u1db.remote import ( + http_database, + ) + +try: + from u1db.tests import c_backend_wrapper +except ImportError: + c_backend_wrapper = None # noqa + + +def make_http_database_for_test(test, replica_uid, path='test'): + test.startServer() + test.request_state._create_database(replica_uid) + return http_database.HTTPDatabase(test.getURL(path)) + + +def copy_http_database_for_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. + return test.request_state._copy_database(db) + + +def make_oauth_http_database_for_test(test, replica_uid): + http_db = make_http_database_for_test(test, replica_uid, '~/test') + http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return http_db + + +def copy_oauth_http_database_for_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR + # HOUSE. + http_db = test.request_state._copy_database(db) + http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return http_db + + +class TestAlternativeDocument(DocumentBase): + """A (not very) alternative implementation of Document.""" + + +class AllDatabaseTests(tests.DatabaseBaseTests, tests.TestCaseWithServer): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + [ + ('http', {'make_database_for_test': make_http_database_for_test, + 'copy_database_for_test': copy_http_database_for_test, + 'make_document_for_test': tests.make_document_for_test, + 'make_app_with_state': make_http_app}), + ('oauth_http', {'make_database_for_test': + make_oauth_http_database_for_test, + 'copy_database_for_test': + copy_oauth_http_database_for_test, + 'make_document_for_test': tests.make_document_for_test, + 'make_app_with_state': make_oauth_http_app}) + ] + tests.C_DATABASE_SCENARIOS + + def test_close(self): + self.db.close() + + def test_create_doc_allocating_doc_id(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertNotEqual(None, doc.doc_id) + self.assertNotEqual(None, doc.rev) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + + def test_create_doc_different_ids_same_db(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertNotEqual(doc1.doc_id, doc2.doc_id) + + def test_create_doc_with_id(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my-id') + self.assertEqual('my-id', doc.doc_id) + self.assertNotEqual(None, doc.rev) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + + def test_create_doc_existing_id(self): + doc = self.db.create_doc_from_json(simple_doc) + new_content = '{"something": "else"}' + self.assertRaises( + errors.RevisionConflict, self.db.create_doc_from_json, + new_content, doc.doc_id) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + + def test_put_doc_creating_initial(self): + doc = self.make_document('my_doc_id', None, simple_doc) + new_rev = self.db.put_doc(doc) + self.assertIsNot(None, new_rev) + self.assertGetDoc(self.db, 'my_doc_id', new_rev, simple_doc, False) + + def test_put_doc_space_in_id(self): + doc = self.make_document('my doc id', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_update(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + orig_rev = doc.rev + doc.set_json('{"updated": "stuff"}') + new_rev = self.db.put_doc(doc) + self.assertNotEqual(new_rev, orig_rev) + self.assertGetDoc(self.db, 'my_doc_id', new_rev, + '{"updated": "stuff"}', False) + self.assertEqual(doc.rev, new_rev) + + def test_put_non_ascii_key(self): + content = json.dumps({u'key\xe5': u'val'}) + doc = self.db.create_doc_from_json(content, doc_id='my_doc') + self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) + + def test_put_non_ascii_value(self): + content = json.dumps({'key': u'\xe5'}) + doc = self.db.create_doc_from_json(content, doc_id='my_doc') + self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False) + + def test_put_doc_refuses_no_id(self): + doc = self.make_document(None, None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + doc = self.make_document("", None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_refuses_slashes(self): + doc = self.make_document('a/b', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + doc = self.make_document(r'\b', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_doc_url_quoting_is_fine(self): + doc_id = "%2F%2Ffoo%2Fbar" + doc = self.make_document(doc_id, None, simple_doc) + new_rev = self.db.put_doc(doc) + self.assertGetDoc(self.db, doc_id, new_rev, simple_doc, False) + + def test_put_doc_refuses_non_existing_old_rev(self): + doc = self.make_document('doc-id', 'test:4', simple_doc) + self.assertRaises(errors.RevisionConflict, self.db.put_doc, doc) + + def test_put_doc_refuses_non_ascii_doc_id(self): + doc = self.make_document('d\xc3\xa5c-id', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_put_fails_with_bad_old_rev(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + old_rev = doc.rev + bad_doc = self.make_document(doc.doc_id, 'other:1', + '{"something": "else"}') + self.assertRaises(errors.RevisionConflict, self.db.put_doc, bad_doc) + self.assertGetDoc(self.db, 'my_doc_id', old_rev, simple_doc, False) + + def test_create_succeeds_after_delete(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) + deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) + new_doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.assertGetDoc(self.db, 'my_doc_id', new_doc.rev, simple_doc, False) + new_vc = vectorclock.VectorClockRev(new_doc.rev) + self.assertTrue( + new_vc.is_newer(deleted_vc), + "%s does not supersede %s" % (new_doc.rev, deleted_doc.rev)) + + def test_put_succeeds_after_delete(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True) + deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev) + doc2 = self.make_document('my_doc_id', None, simple_doc) + self.db.put_doc(doc2) + self.assertGetDoc(self.db, 'my_doc_id', doc2.rev, simple_doc, False) + new_vc = vectorclock.VectorClockRev(doc2.rev) + self.assertTrue( + new_vc.is_newer(deleted_vc), + "%s does not supersede %s" % (doc2.rev, deleted_doc.rev)) + + def test_get_doc_after_put(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.assertGetDoc(self.db, 'my_doc_id', doc.rev, simple_doc, False) + + def test_get_doc_nonexisting(self): + self.assertIs(None, self.db.get_doc('non-existing')) + + def test_get_doc_deleted(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + self.assertIs(None, self.db.get_doc('my_doc_id')) + + def test_get_doc_include_deleted(self): + doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id') + self.db.delete_doc(doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + + def test_get_docs(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual([doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + + def test_get_docs_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc1) + self.assertEqual([doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + + def test_get_docs_include_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc1) + self.assertEqual( + [doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id], + include_deleted=True))) + + def test_get_docs_request_ordered(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual([doc1, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id]))) + self.assertEqual([doc2, doc1], + list(self.db.get_docs([doc2.doc_id, doc1.doc_id]))) + + def test_get_docs_empty_list(self): + self.assertEqual([], list(self.db.get_docs([]))) + + def test_handles_nested_content(self): + doc = self.db.create_doc_from_json(nested_doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) + + def test_handles_doc_with_null(self): + doc = self.db.create_doc_from_json('{"key": null}') + self.assertGetDoc(self.db, doc.doc_id, doc.rev, '{"key": null}', False) + + def test_delete_doc(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + orig_rev = doc.rev + self.db.delete_doc(doc) + self.assertNotEqual(orig_rev, doc.rev) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + self.assertIs(None, self.db.get_doc(doc.doc_id)) + + def test_delete_doc_non_existent(self): + doc = self.make_document('non-existing', 'other:1', simple_doc) + self.assertRaises(errors.DocumentDoesNotExist, self.db.delete_doc, doc) + + def test_delete_doc_already_deleted(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertRaises(errors.DocumentAlreadyDeleted, + self.db.delete_doc, doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + + def test_delete_doc_bad_rev(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + doc2 = self.make_document(doc1.doc_id, 'other:1', simple_doc) + self.assertRaises(errors.RevisionConflict, self.db.delete_doc, doc2) + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + + def test_delete_doc_sets_content_to_None(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertIs(None, doc.get_json()) + + def test_delete_doc_rev_supersedes(self): + doc = self.db.create_doc_from_json(simple_doc) + doc.set_json(nested_doc) + self.db.put_doc(doc) + doc.set_json('{"fishy": "content"}') + self.db.put_doc(doc) + old_rev = doc.rev + self.db.delete_doc(doc) + cur_vc = vectorclock.VectorClockRev(old_rev) + deleted_vc = vectorclock.VectorClockRev(doc.rev) + self.assertTrue(deleted_vc.is_newer(cur_vc), + "%s does not supersede %s" % (doc.rev, old_rev)) + + def test_delete_then_put(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, doc.rev, None, False) + doc.set_json(nested_doc) + self.db.put_doc(doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False) + + +class DocumentSizeTests(tests.DatabaseBaseTests): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def test_put_doc_refuses_oversized_documents(self): + self.db.set_document_size_limit(1) + doc = self.make_document('doc-id', None, simple_doc) + self.assertRaises(errors.DocumentTooBig, self.db.put_doc, doc) + + def test_create_doc_refuses_oversized_documents(self): + self.db.set_document_size_limit(1) + self.assertRaises( + errors.DocumentTooBig, self.db.create_doc_from_json, simple_doc, + doc_id='my_doc_id') + + def test_set_document_size_limit_zero(self): + self.db.set_document_size_limit(0) + self.assertEqual(0, self.db.document_size_limit) + + def test_set_document_size_limit(self): + self.db.set_document_size_limit(1000000) + self.assertEqual(1000000, self.db.document_size_limit) + + +class LocalDatabaseTests(tests.DatabaseBaseTests): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def test_create_doc_different_ids_diff_db(self): + doc1 = self.db.create_doc_from_json(simple_doc) + db2 = self.create_database('other-uid') + doc2 = db2.create_doc_from_json(simple_doc) + self.assertNotEqual(doc1.doc_id, doc2.doc_id) + + def test_put_doc_refuses_slashes_picky(self): + doc = self.make_document('/a', None, simple_doc) + self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc) + + def test_get_all_docs_empty(self): + self.assertEqual([], list(self.db.get_all_docs()[1])) + + def test_get_all_docs(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual( + sorted([doc1, doc2]), sorted(list(self.db.get_all_docs()[1]))) + + def test_get_all_docs_exclude_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc2) + self.assertEqual([doc1], list(self.db.get_all_docs()[1])) + + def test_get_all_docs_include_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.db.delete_doc(doc2) + self.assertEqual( + sorted([doc1, doc2]), + sorted(list(self.db.get_all_docs(include_deleted=True)[1]))) + + def test_get_all_docs_generation(self): + self.db.create_doc_from_json(simple_doc) + self.db.create_doc_from_json(nested_doc) + self.assertEqual(2, self.db.get_all_docs()[0]) + + def test_simple_put_doc_if_newer(self): + doc = self.make_document('my-doc-id', 'test:1', simple_doc) + state_at_gen = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(('inserted', 1), state_at_gen) + self.assertGetDoc(self.db, 'my-doc-id', 'test:1', simple_doc, False) + + def test_simple_put_doc_if_newer_deleted(self): + self.db.create_doc_from_json('{}', doc_id='my-doc-id') + doc = self.make_document('my-doc-id', 'test:2', None) + state_at_gen = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(('inserted', 2), state_at_gen) + self.assertGetDocIncludeDeleted( + self.db, 'my-doc-id', 'test:2', None, False) + + def test_put_doc_if_newer_already_superseded(self): + orig_doc = '{"new": "doc"}' + doc1 = self.db.create_doc_from_json(orig_doc) + doc1_rev1 = doc1.rev + doc1.set_json(simple_doc) + self.db.put_doc(doc1) + doc1_rev2 = doc1.rev + # Nothing is inserted, because the document is already superseded + doc = self.make_document(doc1.doc_id, doc1_rev1, orig_doc) + state, _ = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual('superseded', state) + self.assertGetDoc(self.db, doc1.doc_id, doc1_rev2, simple_doc, False) + + def test_put_doc_if_newer_autoresolve(self): + doc1 = self.db.create_doc_from_json(simple_doc) + rev = doc1.rev + doc = self.make_document(doc1.doc_id, "whatever:1", doc1.get_json()) + state, _ = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual('superseded', state) + doc2 = self.db.get_doc(doc1.doc_id) + v2 = vectorclock.VectorClockRev(doc2.rev) + self.assertTrue(v2.is_newer(vectorclock.VectorClockRev("whatever:1"))) + self.assertTrue(v2.is_newer(vectorclock.VectorClockRev(rev))) + # strictly newer locally + self.assertTrue(rev not in doc2.rev) + + def test_put_doc_if_newer_already_converged(self): + orig_doc = '{"new": "doc"}' + doc1 = self.db.create_doc_from_json(orig_doc) + state_at_gen = self.db._put_doc_if_newer( + doc1, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(('converged', 1), state_at_gen) + + def test_put_doc_if_newer_conflicted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + # Nothing is inserted, the document id is returned as would-conflict + alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + state, _ = self.db._put_doc_if_newer( + alt_doc, save_conflict=False, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual('conflicted', state) + # The database wasn't altered + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + + def test_put_doc_if_newer_newer_generation(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + doc = self.make_document('doc_id', 'other:2', simple_doc) + state, _ = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='other', replica_gen=2, + replica_trans_id='T-irrelevant') + self.assertEqual('inserted', state) + + def test_put_doc_if_newer_same_generation_same_txid(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + doc = self.db.create_doc_from_json(simple_doc) + self.make_document(doc.doc_id, 'other:1', simple_doc) + state, _ = self.db._put_doc_if_newer( + doc, save_conflict=False, replica_uid='other', replica_gen=1, + replica_trans_id='T-sid') + self.assertEqual('converged', state) + + def test_put_doc_if_newer_wrong_transaction_id(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + doc = self.make_document('doc_id', 'other:1', simple_doc) + self.assertRaises( + errors.InvalidTransactionId, + self.db._put_doc_if_newer, doc, save_conflict=False, + replica_uid='other', replica_gen=1, replica_trans_id='T-sad') + + def test_put_doc_if_newer_old_generation_older_doc(self): + orig_doc = '{"new": "doc"}' + doc = self.db.create_doc_from_json(orig_doc) + doc_rev1 = doc.rev + doc.set_json(simple_doc) + self.db.put_doc(doc) + self.db._set_replica_gen_and_trans_id('other', 3, 'T-sid') + older_doc = self.make_document(doc.doc_id, doc_rev1, simple_doc) + state, _ = self.db._put_doc_if_newer( + older_doc, save_conflict=False, replica_uid='other', replica_gen=8, + replica_trans_id='T-irrelevant') + self.assertEqual('superseded', state) + + def test_put_doc_if_newer_old_generation_newer_doc(self): + self.db._set_replica_gen_and_trans_id('other', 5, 'T-sid') + doc = self.make_document('doc_id', 'other:1', simple_doc) + self.assertRaises( + errors.InvalidGeneration, + self.db._put_doc_if_newer, doc, save_conflict=False, + replica_uid='other', replica_gen=1, replica_trans_id='T-sad') + + def test_put_doc_if_newer_replica_uid(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1', + nested_doc) + self.assertEqual('inserted', + self.db._put_doc_if_newer(doc2, save_conflict=False, + replica_uid='other', replica_gen=2, + replica_trans_id='T-id2')[0]) + self.assertEqual((2, 'T-id2'), self.db._get_replica_gen_and_trans_id( + 'other')) + # Compare to the old rev, should be superseded + doc2 = self.make_document(doc1.doc_id, doc1.rev, nested_doc) + self.assertEqual('superseded', + self.db._put_doc_if_newer(doc2, save_conflict=False, + replica_uid='other', replica_gen=3, + replica_trans_id='T-id3')[0]) + self.assertEqual( + (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other')) + # A conflict that isn't saved still records the sync gen, because we + # don't need to see it again + doc2 = self.make_document(doc1.doc_id, doc1.rev + '|fourth:1', + '{}') + self.assertEqual('conflicted', + self.db._put_doc_if_newer(doc2, save_conflict=False, + replica_uid='other', replica_gen=4, + replica_trans_id='T-id4')[0]) + self.assertEqual( + (4, 'T-id4'), self.db._get_replica_gen_and_trans_id('other')) + + def test__get_replica_gen_and_trans_id(self): + self.assertEqual( + (0, ''), self.db._get_replica_gen_and_trans_id('other-db')) + self.db._set_replica_gen_and_trans_id('other-db', 2, 'T-transaction') + self.assertEqual( + (2, 'T-transaction'), + self.db._get_replica_gen_and_trans_id('other-db')) + + def test_put_updates_transaction_log(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + doc.set_json('{"something": "else"}') + self.db.put_doc(doc) + self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]), + self.db.whats_changed()) + + def test_delete_updates_transaction_log(self): + doc = self.db.create_doc_from_json(simple_doc) + db_gen, _, _ = self.db.whats_changed() + self.db.delete_doc(doc) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]), + self.db.whats_changed(db_gen)) + + def test_whats_changed_initial_database(self): + self.assertEqual((0, '', []), self.db.whats_changed()) + + def test_whats_changed_returns_one_id_for_multiple_changes(self): + doc = self.db.create_doc_from_json(simple_doc) + doc.set_json('{"new": "contents"}') + self.db.put_doc(doc) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]), + self.db.whats_changed()) + self.assertEqual((2, last_trans_id, []), self.db.whats_changed(2)) + + def test_whats_changed_returns_last_edits_ascending(self): + doc = self.db.create_doc_from_json(simple_doc) + doc1 = self.db.create_doc_from_json(simple_doc) + doc.set_json('{"new": "contents"}') + self.db.delete_doc(doc1) + delete_trans_id = self.getLastTransId(self.db) + self.db.put_doc(doc) + put_trans_id = self.getLastTransId(self.db) + self.assertEqual((4, put_trans_id, + [(doc1.doc_id, 3, delete_trans_id), + (doc.doc_id, 4, put_trans_id)]), + self.db.whats_changed()) + + def test_whats_changed_doesnt_include_old_gen(self): + self.db.create_doc_from_json(simple_doc) + self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(simple_doc) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual((3, last_trans_id, [(doc2.doc_id, 3, last_trans_id)]), + self.db.whats_changed(2)) + + +class LocalDatabaseValidateGenNTransIdTests(tests.DatabaseBaseTests): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def test_validate_gen_and_trans_id(self): + self.db.create_doc_from_json(simple_doc) + gen, trans_id = self.db._get_generation_info() + self.db.validate_gen_and_trans_id(gen, trans_id) + + def test_validate_gen_and_trans_id_invalid_txid(self): + self.db.create_doc_from_json(simple_doc) + gen, _ = self.db._get_generation_info() + self.assertRaises( + errors.InvalidTransactionId, + self.db.validate_gen_and_trans_id, gen, 'wrong') + + def test_validate_gen_and_trans_id_invalid_gen(self): + self.db.create_doc_from_json(simple_doc) + gen, trans_id = self.db._get_generation_info() + self.assertRaises( + errors.InvalidGeneration, + self.db.validate_gen_and_trans_id, gen + 1, trans_id) + + +class LocalDatabaseValidateSourceGenTests(tests.DatabaseBaseTests): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def test_validate_source_gen_and_trans_id_same(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + self.db._validate_source('other', 1, 'T-sid') + + def test_validate_source_gen_newer(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + self.db._validate_source('other', 2, 'T-whatevs') + + def test_validate_source_wrong_txid(self): + self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid') + self.assertRaises( + errors.InvalidTransactionId, + self.db._validate_source, 'other', 1, 'T-sad') + + +class LocalDatabaseWithConflictsTests(tests.DatabaseBaseTests): + # test supporting/functionality around storing conflicts + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def test_get_docs_conflicted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual([doc2], list(self.db.get_docs([doc1.doc_id]))) + + def test_get_docs_conflicts_ignored(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + no_conflict_doc = self.make_document(doc1.doc_id, 'alternate:1', + nested_doc) + self.assertEqual([no_conflict_doc, doc2], + list(self.db.get_docs([doc1.doc_id, doc2.doc_id], + check_for_conflicts=False))) + + def test_get_doc_conflicts(self): + doc = self.db.create_doc_from_json(simple_doc) + alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual([alt_doc, doc], + self.db.get_doc_conflicts(doc.doc_id)) + + def test_get_all_docs_sees_conflicts(self): + doc = self.db.create_doc_from_json(simple_doc) + alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + _, docs = self.db.get_all_docs() + self.assertTrue(list(docs)[0].has_conflicts) + + def test_get_doc_conflicts_unconflicted(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertEqual([], self.db.get_doc_conflicts(doc.doc_id)) + + def test_get_doc_conflicts_no_such_id(self): + self.assertEqual([], self.db.get_doc_conflicts('doc-id')) + + def test_resolve_doc(self): + doc = self.db.create_doc_from_json(simple_doc) + alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDocConflicts(self.db, doc.doc_id, + [('alternate:1', nested_doc), (doc.rev, simple_doc)]) + orig_rev = doc.rev + self.db.resolve_doc(doc, [alt_doc.rev, doc.rev]) + self.assertNotEqual(orig_rev, doc.rev) + self.assertFalse(doc.has_conflicts) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + self.assertGetDocConflicts(self.db, doc.doc_id, []) + + def test_resolve_doc_picks_biggest_vcr(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc2.rev, nested_doc), + (doc1.rev, simple_doc)]) + orig_doc1_rev = doc1.rev + self.db.resolve_doc(doc1, [doc2.rev, doc1.rev]) + self.assertFalse(doc1.has_conflicts) + self.assertNotEqual(orig_doc1_rev, doc1.rev) + self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False) + self.assertGetDocConflicts(self.db, doc1.doc_id, []) + vcr_1 = vectorclock.VectorClockRev(orig_doc1_rev) + vcr_2 = vectorclock.VectorClockRev(doc2.rev) + vcr_new = vectorclock.VectorClockRev(doc1.rev) + self.assertTrue(vcr_new.is_newer(vcr_1)) + self.assertTrue(vcr_new.is_newer(vcr_2)) + + def test_resolve_doc_partial_not_winning(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc2.rev, nested_doc), + (doc1.rev, simple_doc)]) + content3 = '{"key": "valin3"}' + doc3 = self.make_document(doc1.doc_id, 'third:1', content3) + self.db._put_doc_if_newer( + doc3, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='bar') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc3.rev, content3), + (doc1.rev, simple_doc), + (doc2.rev, nested_doc)]) + self.db.resolve_doc(doc1, [doc2.rev, doc1.rev]) + self.assertTrue(doc1.has_conflicts) + self.assertGetDoc(self.db, doc1.doc_id, doc3.rev, content3, True) + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc3.rev, content3), + (doc1.rev, simple_doc)]) + + def test_resolve_doc_partial_winning(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + content3 = '{"key": "valin3"}' + doc3 = self.make_document(doc1.doc_id, 'third:1', content3) + self.db._put_doc_if_newer( + doc3, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='bar') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc3.rev, content3), + (doc1.rev, simple_doc), + (doc2.rev, nested_doc)]) + self.db.resolve_doc(doc1, [doc3.rev, doc1.rev]) + self.assertTrue(doc1.has_conflicts) + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc1.rev, simple_doc), + (doc2.rev, nested_doc)]) + + def test_resolve_doc_with_delete_conflict(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc1) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc2.rev, nested_doc), + (doc1.rev, None)]) + self.db.resolve_doc(doc2, [doc1.rev, doc2.rev]) + self.assertGetDocConflicts(self.db, doc1.doc_id, []) + self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, False) + + def test_resolve_doc_with_delete_to_delete(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc1) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [(doc2.rev, nested_doc), + (doc1.rev, None)]) + self.db.resolve_doc(doc1, [doc1.rev, doc2.rev]) + self.assertGetDocConflicts(self.db, doc1.doc_id, []) + self.assertGetDocIncludeDeleted( + self.db, doc1.doc_id, doc1.rev, None, False) + + def test_put_doc_if_newer_save_conflicted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + # Document is inserted as a conflict + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + state, _ = self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual('conflicted', state) + # The database was updated + self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, nested_doc, True) + + def test_force_doc_conflict_supersedes_properly(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', '{"b": 1}') + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + doc3 = self.make_document(doc1.doc_id, 'altalt:1', '{"c": 1}') + self.db._put_doc_if_newer( + doc3, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='bar') + doc22 = self.make_document(doc1.doc_id, 'alternate:2', '{"b": 2}') + self.db._put_doc_if_newer( + doc22, save_conflict=True, replica_uid='r', replica_gen=3, + replica_trans_id='zed') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [('alternate:2', doc22.get_json()), + ('altalt:1', doc3.get_json()), + (doc1.rev, simple_doc)]) + + def test_put_doc_if_newer_save_conflict_was_deleted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc1) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertTrue(doc2.has_conflicts) + self.assertGetDoc( + self.db, doc1.doc_id, 'alternate:1', nested_doc, True) + self.assertGetDocConflicts(self.db, doc1.doc_id, + [('alternate:1', nested_doc), (doc1.rev, None)]) + + def test_put_doc_if_newer_propagates_full_resolution(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + resolved_vcr = vectorclock.VectorClockRev(doc1.rev) + vcr_2 = vectorclock.VectorClockRev(doc2.rev) + resolved_vcr.maximize(vcr_2) + resolved_vcr.increment('alternate') + doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(), + '{"good": 1}') + state, _ = self.db._put_doc_if_newer( + doc_resolved, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='foo2') + self.assertEqual('inserted', state) + self.assertFalse(doc_resolved.has_conflicts) + self.assertGetDocConflicts(self.db, doc1.doc_id, []) + doc3 = self.db.get_doc(doc1.doc_id) + self.assertFalse(doc3.has_conflicts) + + def test_put_doc_if_newer_propagates_partial_resolution(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'altalt:1', '{}') + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + doc3 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc3, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='foo2') + self.assertGetDocConflicts(self.db, doc1.doc_id, + [('alternate:1', nested_doc), ('test:1', simple_doc), + ('altalt:1', '{}')]) + resolved_vcr = vectorclock.VectorClockRev(doc1.rev) + vcr_3 = vectorclock.VectorClockRev(doc3.rev) + resolved_vcr.maximize(vcr_3) + resolved_vcr.increment('alternate') + doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(), + '{"good": 1}') + state, _ = self.db._put_doc_if_newer( + doc_resolved, save_conflict=True, replica_uid='r', replica_gen=3, + replica_trans_id='foo3') + self.assertEqual('inserted', state) + self.assertTrue(doc_resolved.has_conflicts) + doc4 = self.db.get_doc(doc1.doc_id) + self.assertTrue(doc4.has_conflicts) + self.assertGetDocConflicts(self.db, doc1.doc_id, + [('alternate:2|test:1', '{"good": 1}'), ('altalt:1', '{}')]) + + def test_put_doc_if_newer_replica_uid(self): + doc1 = self.db.create_doc_from_json(simple_doc) + self.db._set_replica_gen_and_trans_id('other', 1, 'T-id') + doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1', + nested_doc) + self.db._put_doc_if_newer(doc2, save_conflict=True, + replica_uid='other', replica_gen=2, + replica_trans_id='T-id2') + # Conflict vs the current update + doc2 = self.make_document(doc1.doc_id, doc1.rev + '|third:3', + '{}') + self.assertEqual('conflicted', + self.db._put_doc_if_newer(doc2, save_conflict=True, + replica_uid='other', replica_gen=3, + replica_trans_id='T-id3')[0]) + self.assertEqual( + (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other')) + + def test_put_doc_if_newer_autoresolve_2(self): + # this is an ordering variant of _3, but that already works + # adding the test explicitly to catch the regression easily + doc_a1 = self.db.create_doc_from_json(simple_doc) + doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', "{}") + doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', + '{"a":"42"}') + doc_a3 = self.make_document(doc_a1.doc_id, 'test:2|other:1', "{}") + state, _ = self.db._put_doc_if_newer( + doc_a2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(state, 'inserted') + state, _ = self.db._put_doc_if_newer( + doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='foo2') + self.assertEqual(state, 'conflicted') + state, _ = self.db._put_doc_if_newer( + doc_a3, save_conflict=True, replica_uid='r', replica_gen=3, + replica_trans_id='foo3') + self.assertEqual(state, 'inserted') + self.assertFalse(self.db.get_doc(doc_a1.doc_id).has_conflicts) + + def test_put_doc_if_newer_autoresolve_3(self): + doc_a1 = self.db.create_doc_from_json(simple_doc) + doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', "{}") + doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}') + doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', "{}") + state, _ = self.db._put_doc_if_newer( + doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(state, 'inserted') + state, _ = self.db._put_doc_if_newer( + doc_a2, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='foo2') + self.assertEqual(state, 'conflicted') + state, _ = self.db._put_doc_if_newer( + doc_a3, save_conflict=True, replica_uid='r', replica_gen=3, + replica_trans_id='foo3') + self.assertEqual(state, 'superseded') + doc = self.db.get_doc(doc_a1.doc_id, True) + self.assertFalse(doc.has_conflicts) + rev = vectorclock.VectorClockRev(doc.rev) + rev_a3 = vectorclock.VectorClockRev('test:3') + rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1') + self.assertTrue(rev.is_newer(rev_a3)) + self.assertTrue('test:4' in doc.rev) # locally increased + self.assertTrue(rev.is_newer(rev_a1b1)) + + def test_put_doc_if_newer_autoresolve_4(self): + doc_a1 = self.db.create_doc_from_json(simple_doc) + doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', None) + doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}') + doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', None) + state, _ = self.db._put_doc_if_newer( + doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertEqual(state, 'inserted') + state, _ = self.db._put_doc_if_newer( + doc_a2, save_conflict=True, replica_uid='r', replica_gen=2, + replica_trans_id='foo2') + self.assertEqual(state, 'conflicted') + state, _ = self.db._put_doc_if_newer( + doc_a3, save_conflict=True, replica_uid='r', replica_gen=3, + replica_trans_id='foo3') + self.assertEqual(state, 'superseded') + doc = self.db.get_doc(doc_a1.doc_id, True) + self.assertFalse(doc.has_conflicts) + rev = vectorclock.VectorClockRev(doc.rev) + rev_a3 = vectorclock.VectorClockRev('test:3') + rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1') + self.assertTrue(rev.is_newer(rev_a3)) + self.assertTrue('test:4' in doc.rev) # locally increased + self.assertTrue(rev.is_newer(rev_a1b1)) + + def test_put_refuses_to_update_conflicted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + content2 = '{"key": "altval"}' + doc2 = self.make_document(doc1.doc_id, 'altrev:1', content2) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, content2, True) + content3 = '{"key": "local"}' + doc2.set_json(content3) + self.assertRaises(errors.ConflictedDoc, self.db.put_doc, doc2) + + def test_delete_refuses_for_conflicted(self): + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.make_document(doc1.doc_id, 'altrev:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, True) + self.assertRaises(errors.ConflictedDoc, self.db.delete_doc, doc2) + + +class DatabaseIndexTests(tests.DatabaseBaseTests): + + scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS + + def assertParseError(self, definition): + self.db.create_doc_from_json(nested_doc) + self.assertRaises( + errors.IndexDefinitionParseError, self.db.create_index, 'idx', + definition) + + def assertIndexCreatable(self, definition): + name = "idx" + self.db.create_doc_from_json(nested_doc) + self.db.create_index(name, definition) + self.assertEqual( + [(name, [definition])], self.db.list_indexes()) + + def test_create_index(self): + self.db.create_index('test-idx', 'name') + self.assertEqual([('test-idx', ['name'])], + self.db.list_indexes()) + + def test_create_index_on_non_ascii_field_name(self): + doc = self.db.create_doc_from_json(json.dumps({u'\xe5': 'value'})) + self.db.create_index('test-idx', u'\xe5') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_list_indexes_with_non_ascii_field_names(self): + self.db.create_index('test-idx', u'\xe5') + self.assertEqual( + [('test-idx', [u'\xe5'])], self.db.list_indexes()) + + def test_create_index_evaluates_it(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_wildcard_matches_unicode_value(self): + doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"})) + self.db.create_index('test-idx', 'key') + self.assertEqual([doc], self.db.get_from_index('test-idx', '*')) + + def test_retrieve_unicode_value_from_index(self): + doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"})) + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc], self.db.get_from_index('test-idx', u"valu\xe5")) + + def test_create_index_fails_if_name_taken(self): + self.db.create_index('test-idx', 'key') + self.assertRaises(errors.IndexNameTakenError, + self.db.create_index, + 'test-idx', 'stuff') + + def test_create_index_does_not_fail_if_name_taken_with_same_index(self): + self.db.create_index('test-idx', 'key') + self.db.create_index('test-idx', 'key') + self.assertEqual([('test-idx', ['key'])], self.db.list_indexes()) + + def test_create_index_does_not_duplicate_indexed_fields(self): + self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.db.delete_index('test-idx') + self.db.create_index('test-idx', 'key') + self.assertEqual(1, len(self.db.get_from_index('test-idx', 'value'))) + + def test_delete_index_does_not_remove_fields_from_other_indexes(self): + self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.db.create_index('test-idx2', 'key') + self.db.delete_index('test-idx') + self.assertEqual(1, len(self.db.get_from_index('test-idx2', 'value'))) + + def test_create_index_after_deleting_document(self): + doc = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc2) + self.db.create_index('test-idx', 'key') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_delete_index(self): + self.db.create_index('test-idx', 'key') + self.assertEqual([('test-idx', ['key'])], self.db.list_indexes()) + self.db.delete_index('test-idx') + self.assertEqual([], self.db.list_indexes()) + + def test_create_adds_to_index(self): + self.db.create_index('test-idx', 'key') + doc = self.db.create_doc_from_json(simple_doc) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_get_from_index_unmatched(self): + self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.assertEqual([], self.db.get_from_index('test-idx', 'novalue')) + + def test_create_index_multiple_exact_matches(self): + doc = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.assertEqual( + sorted([doc, doc2]), + sorted(self.db.get_from_index('test-idx', 'value'))) + + def test_get_from_index(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_get_from_index_multi(self): + content = '{"key": "value", "key2": "value2"}' + doc = self.db.create_doc_from_json(content) + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc], self.db.get_from_index('test-idx', 'value', 'value2')) + + def test_get_from_index_multi_list(self): + doc = self.db.create_doc_from_json( + '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc], self.db.get_from_index('test-idx', 'value', 'value2-1')) + self.assertEqual( + [doc], self.db.get_from_index('test-idx', 'value', 'value2-2')) + self.assertEqual( + [doc], self.db.get_from_index('test-idx', 'value', 'value2-3')) + self.assertEqual( + [('value', 'value2-1'), ('value', 'value2-2'), + ('value', 'value2-3')], + sorted(self.db.get_index_keys('test-idx'))) + + def test_get_from_index_sees_conflicts(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key', 'key2') + alt_doc = self.make_document( + doc.doc_id, 'alternate:1', + '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}') + self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + docs = self.db.get_from_index('test-idx', 'value', 'value2-1') + self.assertTrue(docs[0].has_conflicts) + + def test_get_index_keys_multi_list_list(self): + self.db.create_doc_from_json( + '{"key": "value1-1 value1-2 value1-3", ' + '"key2": ["value2-1", "value2-2", "value2-3"]}') + self.db.create_index('test-idx', 'split_words(key)', 'key2') + self.assertEqual( + [(u'value1-1', u'value2-1'), (u'value1-1', u'value2-2'), + (u'value1-1', u'value2-3'), (u'value1-2', u'value2-1'), + (u'value1-2', u'value2-2'), (u'value1-2', u'value2-3'), + (u'value1-3', u'value2-1'), (u'value1-3', u'value2-2'), + (u'value1-3', u'value2-3')], + sorted(self.db.get_index_keys('test-idx'))) + + def test_get_from_index_multi_ordered(self): + doc1 = self.db.create_doc_from_json( + '{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value3"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + doc4 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc4, doc3, doc2, doc1], + self.db.get_from_index('test-idx', 'v*', '*')) + + def test_get_range_from_index_start_end(self): + doc1 = self.db.create_doc_from_json('{"key": "value3"}') + doc2 = self.db.create_doc_from_json('{"key": "value2"}') + self.db.create_doc_from_json('{"key": "value4"}') + self.db.create_doc_from_json('{"key": "value1"}') + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc2, doc1], + self.db.get_range_from_index('test-idx', 'value2', 'value3')) + + def test_get_range_from_index_start(self): + doc1 = self.db.create_doc_from_json('{"key": "value3"}') + doc2 = self.db.create_doc_from_json('{"key": "value2"}') + doc3 = self.db.create_doc_from_json('{"key": "value4"}') + self.db.create_doc_from_json('{"key": "value1"}') + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc2, doc1, doc3], + self.db.get_range_from_index('test-idx', 'value2')) + + def test_get_range_from_index_sees_conflicts(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + alt_doc = self.make_document( + doc.doc_id, 'alternate:1', '{"key": "valuedepalue"}') + self.db._put_doc_if_newer( + alt_doc, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + docs = self.db.get_range_from_index('test-idx', 'a') + self.assertTrue(docs[0].has_conflicts) + + def test_get_range_from_index_end(self): + self.db.create_doc_from_json('{"key": "value3"}') + doc2 = self.db.create_doc_from_json('{"key": "value2"}') + self.db.create_doc_from_json('{"key": "value4"}') + doc4 = self.db.create_doc_from_json('{"key": "value1"}') + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc4, doc2], + self.db.get_range_from_index('test-idx', None, 'value2')) + + def test_get_wildcard_range_from_index_start(self): + doc1 = self.db.create_doc_from_json('{"key": "value4"}') + doc2 = self.db.create_doc_from_json('{"key": "value23"}') + doc3 = self.db.create_doc_from_json('{"key": "value2"}') + doc4 = self.db.create_doc_from_json('{"key": "value22"}') + self.db.create_doc_from_json('{"key": "value1"}') + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc3, doc4, doc2, doc1], + self.db.get_range_from_index('test-idx', 'value2*')) + + def test_get_wildcard_range_from_index_end(self): + self.db.create_doc_from_json('{"key": "value4"}') + doc2 = self.db.create_doc_from_json('{"key": "value23"}') + doc3 = self.db.create_doc_from_json('{"key": "value2"}') + doc4 = self.db.create_doc_from_json('{"key": "value22"}') + doc5 = self.db.create_doc_from_json('{"key": "value1"}') + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc5, doc3, doc4, doc2], + self.db.get_range_from_index('test-idx', None, 'value2*')) + + def test_get_wildcard_range_from_index_start_end(self): + self.db.create_doc_from_json('{"key": "a"}') + self.db.create_doc_from_json('{"key": "boo3"}') + doc3 = self.db.create_doc_from_json('{"key": "catalyst"}') + doc4 = self.db.create_doc_from_json('{"key": "whaever"}') + self.db.create_doc_from_json('{"key": "zerg"}') + self.db.create_index('test-idx', 'key') + self.assertEqual( + [doc3, doc4], + self.db.get_range_from_index('test-idx', 'cat*', 'zap*')) + + def test_get_range_from_index_multi_column_start_end(self): + self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value3"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc3, doc2], + self.db.get_range_from_index( + 'test-idx', ('value2', 'value2'), ('value2', 'value3'))) + + def test_get_range_from_index_multi_column_start(self): + doc1 = self.db.create_doc_from_json( + '{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value3"}') + self.db.create_doc_from_json('{"key": "value2", "key2": "value2"}') + self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc2, doc1], + self.db.get_range_from_index('test-idx', ('value2', 'value3'))) + + def test_get_range_from_index_multi_column_end(self): + self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value3"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + doc4 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc4, doc3, doc2], + self.db.get_range_from_index( + 'test-idx', None, ('value2', 'value3'))) + + def test_get_wildcard_range_from_index_multi_column_start(self): + doc1 = self.db.create_doc_from_json( + '{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value23"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc3, doc2, doc1], + self.db.get_range_from_index('test-idx', ('value2', 'value2*'))) + + def test_get_wildcard_range_from_index_multi_column_end(self): + self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value23"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + doc4 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc4, doc3, doc2], + self.db.get_range_from_index( + 'test-idx', None, ('value2', 'value2*'))) + + def test_get_glob_range_from_index_multi_column_start(self): + doc1 = self.db.create_doc_from_json( + '{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value23"}') + self.db.create_doc_from_json('{"key": "value1", "key2": "value2"}') + self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc2, doc1], + self.db.get_range_from_index('test-idx', ('value2', '*'))) + + def test_get_glob_range_from_index_multi_column_end(self): + self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value23"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value2"}') + doc4 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc4, doc3, doc2], + self.db.get_range_from_index('test-idx', None, ('value2', '*'))) + + def test_get_range_from_index_illegal_wildcard_order(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_range_from_index, 'test-idx', ('*', 'v2')) + + def test_get_range_from_index_illegal_glob_after_wildcard(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_range_from_index, 'test-idx', ('*', 'v*')) + + def test_get_range_from_index_illegal_wildcard_order_end(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_range_from_index, 'test-idx', None, ('*', 'v2')) + + def test_get_range_from_index_illegal_glob_after_wildcard_end(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_range_from_index, 'test-idx', None, ('*', 'v*')) + + def test_get_from_index_fails_if_no_index(self): + self.assertRaises( + errors.IndexDoesNotExist, self.db.get_from_index, 'foo') + + def test_get_index_keys_fails_if_no_index(self): + self.assertRaises(errors.IndexDoesNotExist, + self.db.get_index_keys, + 'foo') + + def test_get_index_keys_works_if_no_docs(self): + self.db.create_index('test-idx', 'key') + self.assertEqual([], self.db.get_index_keys('test-idx')) + + def test_put_updates_index(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + new_content = '{"key": "altval"}' + doc.set_json(new_content) + self.db.put_doc(doc) + self.assertEqual([], self.db.get_from_index('test-idx', 'value')) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'altval')) + + def test_delete_updates_index(self): + doc = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(simple_doc) + self.db.create_index('test-idx', 'key') + self.assertEqual( + sorted([doc, doc2]), + sorted(self.db.get_from_index('test-idx', 'value'))) + self.db.delete_doc(doc) + self.assertEqual([doc2], self.db.get_from_index('test-idx', 'value')) + + def test_get_from_index_illegal_number_of_entries(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidValueForIndex, self.db.get_from_index, 'test-idx') + self.assertRaises( + errors.InvalidValueForIndex, + self.db.get_from_index, 'test-idx', 'v1') + self.assertRaises( + errors.InvalidValueForIndex, + self.db.get_from_index, 'test-idx', 'v1', 'v2', 'v3') + + def test_get_from_index_illegal_wildcard_order(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_from_index, 'test-idx', '*', 'v2') + + def test_get_from_index_illegal_glob_after_wildcard(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_from_index, 'test-idx', '*', 'v*') + + def test_get_all_from_index(self): + self.db.create_index('test-idx', 'key') + doc1 = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + # This one should not be in the index + self.db.create_doc_from_json('{"no": "key"}') + diff_value_doc = '{"key": "diff value"}' + doc4 = self.db.create_doc_from_json(diff_value_doc) + # This is essentially a 'prefix' match, but we match every entry. + self.assertEqual( + sorted([doc1, doc2, doc4]), + sorted(self.db.get_from_index('test-idx', '*'))) + + def test_get_all_from_index_ordered(self): + self.db.create_index('test-idx', 'key') + doc1 = self.db.create_doc_from_json('{"key": "value x"}') + doc2 = self.db.create_doc_from_json('{"key": "value b"}') + doc3 = self.db.create_doc_from_json('{"key": "value a"}') + doc4 = self.db.create_doc_from_json('{"key": "value m"}') + # This is essentially a 'prefix' match, but we match every entry. + self.assertEqual( + [doc3, doc2, doc4, doc1], self.db.get_from_index('test-idx', '*')) + + def test_put_updates_when_adding_key(self): + doc = self.db.create_doc_from_json("{}") + self.db.create_index('test-idx', 'key') + self.assertEqual([], self.db.get_from_index('test-idx', '*')) + doc.set_json(simple_doc) + self.db.put_doc(doc) + self.assertEqual([doc], self.db.get_from_index('test-idx', '*')) + + def test_get_from_index_empty_string(self): + self.db.create_index('test-idx', 'key') + doc1 = self.db.create_doc_from_json(simple_doc) + content2 = '{"key": ""}' + doc2 = self.db.create_doc_from_json(content2) + self.assertEqual([doc2], self.db.get_from_index('test-idx', '')) + # Empty string matches the wildcard. + self.assertEqual( + sorted([doc1, doc2]), + sorted(self.db.get_from_index('test-idx', '*'))) + + def test_get_from_index_not_null(self): + self.db.create_index('test-idx', 'key') + doc1 = self.db.create_doc_from_json(simple_doc) + self.db.create_doc_from_json('{"key": null}') + self.assertEqual([doc1], self.db.get_from_index('test-idx', '*')) + + def test_get_partial_from_index(self): + content1 = '{"k1": "v1", "k2": "v2"}' + content2 = '{"k1": "v1", "k2": "x2"}' + content3 = '{"k1": "v1", "k2": "y2"}' + # doc4 has a different k1 value, so it doesn't match the prefix. + content4 = '{"k1": "NN", "k2": "v2"}' + doc1 = self.db.create_doc_from_json(content1) + doc2 = self.db.create_doc_from_json(content2) + doc3 = self.db.create_doc_from_json(content3) + self.db.create_doc_from_json(content4) + self.db.create_index('test-idx', 'k1', 'k2') + self.assertEqual( + sorted([doc1, doc2, doc3]), + sorted(self.db.get_from_index('test-idx', "v1", "*"))) + + def test_get_glob_match(self): + # Note: the exact glob syntax is probably subject to change + content1 = '{"k1": "v1", "k2": "v1"}' + content2 = '{"k1": "v1", "k2": "v2"}' + content3 = '{"k1": "v1", "k2": "v3"}' + # doc4 has a different k2 prefix value, so it doesn't match + content4 = '{"k1": "v1", "k2": "ZZ"}' + self.db.create_index('test-idx', 'k1', 'k2') + doc1 = self.db.create_doc_from_json(content1) + doc2 = self.db.create_doc_from_json(content2) + doc3 = self.db.create_doc_from_json(content3) + self.db.create_doc_from_json(content4) + self.assertEqual( + sorted([doc1, doc2, doc3]), + sorted(self.db.get_from_index('test-idx', "v1", "v*"))) + + def test_nested_index(self): + doc = self.db.create_doc_from_json(nested_doc) + self.db.create_index('test-idx', 'sub.doc') + self.assertEqual( + [doc], self.db.get_from_index('test-idx', 'underneath')) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertEqual( + sorted([doc, doc2]), + sorted(self.db.get_from_index('test-idx', 'underneath'))) + + def test_nested_nonexistent(self): + self.db.create_doc_from_json(nested_doc) + # sub exists, but sub.foo does not: + self.db.create_index('test-idx', 'sub.foo') + self.assertEqual([], self.db.get_from_index('test-idx', '*')) + + def test_nested_nonexistent2(self): + self.db.create_doc_from_json(nested_doc) + self.db.create_index('test-idx', 'sub.foo.bar.baz.qux.fnord') + self.assertEqual([], self.db.get_from_index('test-idx', '*')) + + def test_nested_traverses_lists(self): + # subpath finds dicts in list + doc = self.db.create_doc_from_json( + '{"foo": [{"zap": "bar"}, {"zap": "baz"}]}') + # subpath only finds dicts in list + self.db.create_doc_from_json('{"foo": ["zap", "baz"]}') + self.db.create_index('test-idx', 'foo.zap') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'bar')) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'baz')) + + def test_nested_list_traversal(self): + # subpath finds dicts in list + doc = self.db.create_doc_from_json( + '{"foo": [{"zap": [{"qux": "fnord"}, {"qux": "zombo"}]},' + '{"zap": "baz"}]}') + # subpath only finds dicts in list + self.db.create_index('test-idx', 'foo.zap.qux') + self.assertEqual([doc], self.db.get_from_index('test-idx', 'fnord')) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'zombo')) + + def test_index_list1(self): + self.db.create_index("index", "name") + content = '{"name": ["foo", "bar"]}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "bar") + self.assertEqual([doc], rows) + + def test_index_list2(self): + self.db.create_index("index", "name") + content = '{"name": ["foo", "bar"]}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_get_from_index_case_sensitive(self): + self.db.create_index('test-idx', 'key') + doc1 = self.db.create_doc_from_json(simple_doc) + self.assertEqual([], self.db.get_from_index('test-idx', 'V*')) + self.assertEqual([doc1], self.db.get_from_index('test-idx', 'v*')) + + def test_get_from_index_illegal_glob_before_value(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_from_index, 'test-idx', 'v*', 'v2') + + def test_get_from_index_illegal_glob_after_glob(self): + self.db.create_index('test-idx', 'k1', 'k2') + self.assertRaises( + errors.InvalidGlobbing, + self.db.get_from_index, 'test-idx', 'v*', 'v*') + + def test_get_from_index_with_sql_wildcards(self): + self.db.create_index('test-idx', 'key') + content1 = '{"key": "va%lue"}' + content2 = '{"key": "value"}' + content3 = '{"key": "va_lue"}' + doc1 = self.db.create_doc_from_json(content1) + self.db.create_doc_from_json(content2) + doc3 = self.db.create_doc_from_json(content3) + # The '%' in the search should be treated literally, not as a sql + # globbing character. + self.assertEqual([doc1], self.db.get_from_index('test-idx', 'va%*')) + # Same for '_' + self.assertEqual([doc3], self.db.get_from_index('test-idx', 'va_*')) + + def test_get_from_index_with_lower(self): + self.db.create_index("index", "lower(name)") + content = '{"name": "Foo"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_get_from_index_with_lower_matches_same_case(self): + self.db.create_index("index", "lower(name)") + content = '{"name": "foo"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_index_lower_doesnt_match_different_case(self): + self.db.create_index("index", "lower(name)") + content = '{"name": "Foo"}' + self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "Foo") + self.assertEqual([], rows) + + def test_index_lower_doesnt_match_other_index(self): + self.db.create_index("index", "lower(name)") + self.db.create_index("other_index", "name") + content = '{"name": "Foo"}' + self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "Foo") + self.assertEqual(0, len(rows)) + + def test_index_split_words_match_first(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": "foo bar"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_index_split_words_match_second(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": "foo bar"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "bar") + self.assertEqual([doc], rows) + + def test_index_split_words_match_both(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": "foo foo"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_index_split_words_double_space(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": "foo bar"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "bar") + self.assertEqual([doc], rows) + + def test_index_split_words_leading_space(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": " foo bar"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "foo") + self.assertEqual([doc], rows) + + def test_index_split_words_trailing_space(self): + self.db.create_index("index", "split_words(name)") + content = '{"name": "foo bar "}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "bar") + self.assertEqual([doc], rows) + + def test_get_from_index_with_number(self): + self.db.create_index("index", "number(foo, 5)") + content = '{"foo": 12}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "00012") + self.assertEqual([doc], rows) + + def test_get_from_index_with_number_bigger_than_padding(self): + self.db.create_index("index", "number(foo, 5)") + content = '{"foo": 123456}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "123456") + self.assertEqual([doc], rows) + + def test_number_mapping_ignores_non_numbers(self): + self.db.create_index("index", "number(foo, 5)") + content = '{"foo": 56}' + doc1 = self.db.create_doc_from_json(content) + content = '{"foo": "this is not a maigret painting"}' + self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "*") + self.assertEqual([doc1], rows) + + def test_get_from_index_with_bool(self): + self.db.create_index("index", "bool(foo)") + content = '{"foo": true}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "1") + self.assertEqual([doc], rows) + + def test_get_from_index_with_bool_false(self): + self.db.create_index("index", "bool(foo)") + content = '{"foo": false}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "0") + self.assertEqual([doc], rows) + + def test_get_from_index_with_non_bool(self): + self.db.create_index("index", "bool(foo)") + content = '{"foo": 42}' + self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "*") + self.assertEqual([], rows) + + def test_get_from_index_with_combine(self): + self.db.create_index("index", "combine(foo, bar)") + content = '{"foo": "value1", "bar": "value2"}' + doc = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "value1") + self.assertEqual([doc], rows) + rows = self.db.get_from_index("index", "value2") + self.assertEqual([doc], rows) + + def test_get_complex_combine(self): + self.db.create_index( + "index", "combine(number(foo, 5), lower(bar), split_words(baz))") + content = '{"foo": 12, "bar": "ALLCAPS", "baz": "qux nox"}' + doc = self.db.create_doc_from_json(content) + content = '{"foo": "not a number", "bar": "something"}' + doc2 = self.db.create_doc_from_json(content) + rows = self.db.get_from_index("index", "00012") + self.assertEqual([doc], rows) + rows = self.db.get_from_index("index", "allcaps") + self.assertEqual([doc], rows) + rows = self.db.get_from_index("index", "nox") + self.assertEqual([doc], rows) + rows = self.db.get_from_index("index", "something") + self.assertEqual([doc2], rows) + + def test_get_index_keys_from_index(self): + self.db.create_index('test-idx', 'key') + content1 = '{"key": "value1"}' + content2 = '{"key": "value2"}' + content3 = '{"key": "value2"}' + self.db.create_doc_from_json(content1) + self.db.create_doc_from_json(content2) + self.db.create_doc_from_json(content3) + self.assertEqual( + [('value1',), ('value2',)], + sorted(self.db.get_index_keys('test-idx'))) + + def test_get_index_keys_from_multicolumn_index(self): + self.db.create_index('test-idx', 'key1', 'key2') + content1 = '{"key1": "value1", "key2": "val2-1"}' + content2 = '{"key1": "value2", "key2": "val2-2"}' + content3 = '{"key1": "value2", "key2": "val2-2"}' + content4 = '{"key1": "value2", "key2": "val3"}' + self.db.create_doc_from_json(content1) + self.db.create_doc_from_json(content2) + self.db.create_doc_from_json(content3) + self.db.create_doc_from_json(content4) + self.assertEqual([ + ('value1', 'val2-1'), + ('value2', 'val2-2'), + ('value2', 'val3')], + sorted(self.db.get_index_keys('test-idx'))) + + def test_empty_expr(self): + self.assertParseError('') + + def test_nested_unknown_operation(self): + self.assertParseError('unknown_operation(field1)') + + def test_parse_missing_close_paren(self): + self.assertParseError("lower(a") + + def test_parse_trailing_close_paren(self): + self.assertParseError("lower(ab))") + + def test_parse_trailing_chars(self): + self.assertParseError("lower(ab)adsf") + + def test_parse_empty_op(self): + self.assertParseError("(ab)") + + def test_parse_top_level_commas(self): + self.assertParseError("a, b") + + def test_invalid_field_name(self): + self.assertParseError("a.") + + def test_invalid_inner_field_name(self): + self.assertParseError("lower(a.)") + + def test_gobbledigook(self): + self.assertParseError("(@#@cc @#!*DFJSXV(()jccd") + + def test_leading_space(self): + self.assertIndexCreatable(" lower(a)") + + def test_trailing_space(self): + self.assertIndexCreatable("lower(a) ") + + def test_spaces_before_open_paren(self): + self.assertIndexCreatable("lower (a)") + + def test_spaces_after_open_paren(self): + self.assertIndexCreatable("lower( a)") + + def test_spaces_before_close_paren(self): + self.assertIndexCreatable("lower(a )") + + def test_spaces_before_comma(self): + self.assertIndexCreatable("combine(a , b , c)") + + def test_spaces_after_comma(self): + self.assertIndexCreatable("combine(a, b, c)") + + def test_all_together_now(self): + self.assertParseError(' (a) ') + + def test_all_together_now2(self): + self.assertParseError('combine(lower(x)x,foo)') + + +class PythonBackendTests(tests.DatabaseBaseTests): + + def setUp(self): + super(PythonBackendTests, self).setUp() + self.simple_doc = json.loads(simple_doc) + + def test_create_doc_with_factory(self): + self.db.set_document_factory(TestAlternativeDocument) + doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id') + self.assertTrue(isinstance(doc, TestAlternativeDocument)) + + def test_get_doc_after_put_with_factory(self): + doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id') + self.db.set_document_factory(TestAlternativeDocument) + result = self.db.get_doc('my_doc_id') + self.assertTrue(isinstance(result, TestAlternativeDocument)) + self.assertEqual(doc.doc_id, result.doc_id) + self.assertEqual(doc.rev, result.rev) + self.assertEqual(doc.get_json(), result.get_json()) + self.assertEqual(False, result.has_conflicts) + + def test_get_doc_nonexisting_with_factory(self): + self.db.set_document_factory(TestAlternativeDocument) + self.assertIs(None, self.db.get_doc('non-existing')) + + def test_get_all_docs_with_factory(self): + self.db.set_document_factory(TestAlternativeDocument) + self.db.create_doc(self.simple_doc) + self.assertTrue(isinstance( + list(self.db.get_all_docs()[1])[0], TestAlternativeDocument)) + + def test_get_docs_conflicted_with_factory(self): + self.db.set_document_factory(TestAlternativeDocument) + doc1 = self.db.create_doc(self.simple_doc) + doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc) + self.db._put_doc_if_newer( + doc2, save_conflict=True, replica_uid='r', replica_gen=1, + replica_trans_id='foo') + self.assertTrue( + isinstance( + list(self.db.get_docs([doc1.doc_id]))[0], + TestAlternativeDocument)) + + def test_get_from_index_with_factory(self): + self.db.set_document_factory(TestAlternativeDocument) + self.db.create_doc(self.simple_doc) + self.db.create_index('test-idx', 'key') + self.assertTrue( + isinstance( + self.db.get_from_index('test-idx', 'value')[0], + TestAlternativeDocument)) + + def test_sync_exchange_updates_indexes(self): + doc = self.db.create_doc(self.simple_doc) + self.db.create_index('test-idx', 'key') + new_content = '{"key": "altval"}' + other_rev = 'test:1|z:2' + st = self.db.get_sync_target() + + def ignore(doc_id, doc_rev, doc): + pass + + doc_other = self.make_document(doc.doc_id, other_rev, new_content) + docs_by_gen = [(doc_other, 10, 'T-sid')] + st.sync_exchange( + docs_by_gen, 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=ignore) + self.assertGetDoc(self.db, doc.doc_id, other_rev, new_content, False) + self.assertEqual( + [doc_other], self.db.get_from_index('test-idx', 'altval')) + self.assertEqual([], self.db.get_from_index('test-idx', 'value')) + + +# Use a custom loader to apply the scenarios at load time. +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_c_backend.py b/src/leap/soledad/u1db/tests/test_c_backend.py new file mode 100644 index 00000000..bdd2aec7 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_c_backend.py @@ -0,0 +1,634 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +try: + import simplejson as json +except ImportError: + import json # noqa +from u1db import ( + Document, + errors, + tests, + ) +from u1db.tests import c_backend_wrapper, c_backend_error +from u1db.tests.test_remote_sync_target import ( + make_http_app, + make_oauth_http_app + ) + + +class TestCDatabaseExists(tests.TestCase): + + def test_c_backend_compiled(self): + if c_backend_wrapper is None: + self.fail("Could not import the c_backend_wrapper module." + " Was it compiled properly?\n%s" % (c_backend_error,)) + + +# Rather than lots of failing tests, we have the above check to test that the +# module exists, and all these tests just get skipped +class BackendTests(tests.TestCase): + + def setUp(self): + super(BackendTests, self).setUp() + if c_backend_wrapper is None: + self.skipTest("The c_backend_wrapper could not be imported") + + +class TestCDatabase(BackendTests): + + def test_exists(self): + if c_backend_wrapper is None: + self.fail("Could not import the c_backend_wrapper module." + " Was it compiled properly?") + db = c_backend_wrapper.CDatabase(':memory:') + self.assertEqual(':memory:', db._filename) + + def test__is_closed(self): + db = c_backend_wrapper.CDatabase(':memory:') + self.assertTrue(db._sql_is_open()) + db.close() + self.assertFalse(db._sql_is_open()) + + def test__run_sql(self): + db = c_backend_wrapper.CDatabase(':memory:') + self.assertTrue(db._sql_is_open()) + self.assertEqual([], db._run_sql('CREATE TABLE test (id INTEGER)')) + self.assertEqual([], db._run_sql('INSERT INTO test VALUES (1)')) + self.assertEqual([('1',)], db._run_sql('SELECT * FROM test')) + + def test__get_generation(self): + db = c_backend_wrapper.CDatabase(':memory:') + self.assertEqual(0, db._get_generation()) + db.create_doc_from_json(tests.simple_doc) + self.assertEqual(1, db._get_generation()) + + def test__get_generation_info(self): + db = c_backend_wrapper.CDatabase(':memory:') + self.assertEqual((0, ''), db._get_generation_info()) + db.create_doc_from_json(tests.simple_doc) + info = db._get_generation_info() + self.assertEqual(1, info[0]) + self.assertTrue(info[1].startswith('T-')) + + def test__set_replica_uid(self): + db = c_backend_wrapper.CDatabase(':memory:') + self.assertIsNot(None, db._replica_uid) + db._set_replica_uid('foo') + self.assertEqual([('foo',)], db._run_sql( + "SELECT value FROM u1db_config WHERE name='replica_uid'")) + + def test_default_replica_uid(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.assertIsNot(None, self.db._replica_uid) + self.assertEqual(32, len(self.db._replica_uid)) + # casting to an int from the uid *is* the check for correct behavior. + int(self.db._replica_uid, 16) + + def test_get_conflicts_with_borked_data(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + # We add an entry to conflicts, but not to documents, which is an + # invalid situation + self.db._run_sql("INSERT INTO conflicts" + " VALUES ('doc-id', 'doc-rev', '{}')") + self.assertRaises(Exception, self.db.get_doc_conflicts, 'doc-id') + + def test_create_index_list(self): + # We manually poke data into the DB, so that we test just the "get_doc" + # code, rather than also testing the index management code. + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.simple_doc) + self.db.create_index_list("key-idx", ["key"]) + docs = self.db.get_from_index('key-idx', 'value') + self.assertEqual([doc], docs) + + def test_create_index_list_on_non_ascii_field_name(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(json.dumps({u'\xe5': 'value'})) + self.db.create_index_list('test-idx', [u'\xe5']) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_list_indexes_with_non_ascii_field_names(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.db.create_index_list('test-idx', [u'\xe5']) + self.assertEqual( + [('test-idx', [u'\xe5'])], self.db.list_indexes()) + + def test_create_index_evaluates_it(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.simple_doc) + self.db.create_index_list('test-idx', ['key']) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_wildcard_matches_unicode_value(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"})) + self.db.create_index_list('test-idx', ['key']) + self.assertEqual([doc], self.db.get_from_index('test-idx', '*')) + + def test_create_index_fails_if_name_taken(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.db.create_index_list('test-idx', ['key']) + self.assertRaises(errors.IndexNameTakenError, + self.db.create_index_list, + 'test-idx', ['stuff']) + + def test_create_index_does_not_fail_if_name_taken_with_same_index(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.db.create_index_list('test-idx', ['key']) + self.db.create_index_list('test-idx', ['key']) + self.assertEqual([('test-idx', ['key'])], self.db.list_indexes()) + + def test_create_index_after_deleting_document(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.simple_doc) + doc2 = self.db.create_doc_from_json(tests.simple_doc) + self.db.delete_doc(doc2) + self.db.create_index_list('test-idx', ['key']) + self.assertEqual([doc], self.db.get_from_index('test-idx', 'value')) + + def test_get_from_index(self): + # We manually poke data into the DB, so that we test just the "get_doc" + # code, rather than also testing the index management code. + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.simple_doc) + self.db.create_index("key-idx", "key") + docs = self.db.get_from_index('key-idx', 'value') + self.assertEqual([doc], docs) + + def test_get_from_index_list(self): + # We manually poke data into the DB, so that we test just the "get_doc" + # code, rather than also testing the index management code. + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.simple_doc) + self.db.create_index("key-idx", "key") + docs = self.db.get_from_index_list('key-idx', ['value']) + self.assertEqual([doc], docs) + + def test_get_from_index_list_multi(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + content = '{"key": "value", "key2": "value2"}' + doc = self.db.create_doc_from_json(content) + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc], + self.db.get_from_index_list('test-idx', ['value', 'value2'])) + + def test_get_from_index_list_multi_ordered(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc1 = self.db.create_doc_from_json( + '{"key": "value3", "key2": "value4"}') + doc2 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value3"}') + doc3 = self.db.create_doc_from_json( + '{"key": "value2", "key2": "value2"}') + doc4 = self.db.create_doc_from_json( + '{"key": "value1", "key2": "value1"}') + self.db.create_index('test-idx', 'key', 'key2') + self.assertEqual( + [doc4, doc3, doc2, doc1], + self.db.get_from_index_list('test-idx', ['v*', '*'])) + + def test_get_from_index_2(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + doc = self.db.create_doc_from_json(tests.nested_doc) + self.db.create_index("multi-idx", "key", "sub.doc") + docs = self.db.get_from_index('multi-idx', 'value', 'underneath') + self.assertEqual([doc], docs) + + def test_get_index_keys(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.db.create_doc_from_json(tests.simple_doc) + self.db.create_index("key-idx", "key") + keys = self.db.get_index_keys('key-idx') + self.assertEqual([("value",)], keys) + + def test__query_init_one_field(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.db.create_index("key-idx", "key") + query = self.db._query_init("key-idx") + self.assertEqual("key-idx", query.index_name) + self.assertEqual(1, query.num_fields) + self.assertEqual(["key"], query.fields) + + def test__query_init_two_fields(self): + self.db = c_backend_wrapper.CDatabase(':memory:') + self.db.create_index("two-idx", "key", "key2") + query = self.db._query_init("two-idx") + self.assertEqual("two-idx", query.index_name) + self.assertEqual(2, query.num_fields) + self.assertEqual(["key", "key2"], query.fields) + + def assertFormatQueryEquals(self, expected, wildcards, fields): + val, w = c_backend_wrapper._format_query(fields) + self.assertEqual(expected, val) + self.assertEqual(wildcards, w) + + def test__format_query(self): + self.assertFormatQueryEquals( + "SELECT d0.doc_id FROM document_fields d0" + " WHERE d0.field_name = ? AND d0.value = ? ORDER BY d0.value", + [0], ["1"]) + self.assertFormatQueryEquals( + "SELECT d0.doc_id" + " FROM document_fields d0, document_fields d1" + " WHERE d0.field_name = ? AND d0.value = ?" + " AND d0.doc_id = d1.doc_id" + " AND d1.field_name = ? AND d1.value = ?" + " ORDER BY d0.value, d1.value", + [0, 0], ["1", "2"]) + self.assertFormatQueryEquals( + "SELECT d0.doc_id" + " FROM document_fields d0, document_fields d1, document_fields d2" + " WHERE d0.field_name = ? AND d0.value = ?" + " AND d0.doc_id = d1.doc_id" + " AND d1.field_name = ? AND d1.value = ?" + " AND d0.doc_id = d2.doc_id" + " AND d2.field_name = ? AND d2.value = ?" + " ORDER BY d0.value, d1.value, d2.value", + [0, 0, 0], ["1", "2", "3"]) + + def test__format_query_wildcard(self): + self.assertFormatQueryEquals( + "SELECT d0.doc_id FROM document_fields d0" + " WHERE d0.field_name = ? AND d0.value NOT NULL ORDER BY d0.value", + [1], ["*"]) + self.assertFormatQueryEquals( + "SELECT d0.doc_id" + " FROM document_fields d0, document_fields d1" + " WHERE d0.field_name = ? AND d0.value = ?" + " AND d0.doc_id = d1.doc_id" + " AND d1.field_name = ? AND d1.value NOT NULL" + " ORDER BY d0.value, d1.value", + [0, 1], ["1", "*"]) + + def test__format_query_glob(self): + self.assertFormatQueryEquals( + "SELECT d0.doc_id FROM document_fields d0" + " WHERE d0.field_name = ? AND d0.value GLOB ? ORDER BY d0.value", + [2], ["1*"]) + + +class TestCSyncTarget(BackendTests): + + def setUp(self): + super(TestCSyncTarget, self).setUp() + self.db = c_backend_wrapper.CDatabase(':memory:') + self.st = self.db.get_sync_target() + + def test_attached_to_db(self): + self.assertEqual( + self.db._replica_uid, self.st.get_sync_info("misc")[0]) + + def test_get_sync_exchange(self): + exc = self.st._get_sync_exchange("source-uid", 10) + self.assertIsNot(None, exc) + + def test_sync_exchange_insert_doc_from_source(self): + exc = self.st._get_sync_exchange("source-uid", 5) + doc = c_backend_wrapper.make_document('doc-id', 'replica:1', + tests.simple_doc) + self.assertEqual([], exc.get_seen_ids()) + exc.insert_doc_from_source(doc, 10, 'T-sid') + self.assertGetDoc(self.db, 'doc-id', 'replica:1', tests.simple_doc, + False) + self.assertEqual( + (10, 'T-sid'), self.db._get_replica_gen_and_trans_id('source-uid')) + self.assertEqual(['doc-id'], exc.get_seen_ids()) + + def test_sync_exchange_conflicted_doc(self): + doc = self.db.create_doc_from_json(tests.simple_doc) + exc = self.st._get_sync_exchange("source-uid", 5) + doc2 = c_backend_wrapper.make_document(doc.doc_id, 'replica:1', + tests.nested_doc) + self.assertEqual([], exc.get_seen_ids()) + # The insert should be rejected and the doc_id not considered 'seen' + exc.insert_doc_from_source(doc2, 10, 'T-sid') + self.assertGetDoc( + self.db, doc.doc_id, doc.rev, tests.simple_doc, False) + self.assertEqual([], exc.get_seen_ids()) + + def test_sync_exchange_find_doc_ids(self): + doc = self.db.create_doc_from_json(tests.simple_doc) + exc = self.st._get_sync_exchange("source-uid", 0) + self.assertEqual(0, exc.target_gen) + exc.find_doc_ids_to_return() + doc_id = exc.get_doc_ids_to_return()[0] + self.assertEqual( + (doc.doc_id, 1), doc_id[:-1]) + self.assertTrue(doc_id[-1].startswith('T-')) + self.assertEqual(1, exc.target_gen) + + def test_sync_exchange_find_doc_ids_not_including_recently_inserted(self): + doc1 = self.db.create_doc_from_json(tests.simple_doc) + doc2 = self.db.create_doc_from_json(tests.nested_doc) + exc = self.st._get_sync_exchange("source-uid", 0) + doc3 = c_backend_wrapper.make_document(doc1.doc_id, + doc1.rev + "|zreplica:2", tests.simple_doc) + exc.insert_doc_from_source(doc3, 10, 'T-sid') + exc.find_doc_ids_to_return() + self.assertEqual( + (doc2.doc_id, 2), exc.get_doc_ids_to_return()[0][:-1]) + self.assertEqual(3, exc.target_gen) + + def test_sync_exchange_return_docs(self): + returned = [] + + def return_doc_cb(doc, gen, trans_id): + returned.append((doc, gen, trans_id)) + + doc1 = self.db.create_doc_from_json(tests.simple_doc) + exc = self.st._get_sync_exchange("source-uid", 0) + exc.find_doc_ids_to_return() + exc.return_docs(return_doc_cb) + self.assertEqual((doc1, 1), returned[0][:-1]) + + def test_sync_exchange_doc_ids(self): + doc1 = self.db.create_doc_from_json(tests.simple_doc, doc_id='doc-1') + db2 = c_backend_wrapper.CDatabase(':memory:') + doc2 = db2.create_doc_from_json(tests.nested_doc, doc_id='doc-2') + returned = [] + + def return_doc_cb(doc, gen, trans_id): + returned.append((doc, gen, trans_id)) + + val = self.st.sync_exchange_doc_ids( + db2, [(doc2.doc_id, 1, 'T-sid')], 0, None, return_doc_cb) + last_trans_id = self.db._get_transaction_log()[-1][1] + self.assertEqual(2, self.db._get_generation()) + self.assertEqual((2, last_trans_id), val) + self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc, + False) + self.assertEqual((doc1, 1), returned[0][:-1]) + + +class TestCHTTPSyncTarget(BackendTests): + + def test_format_sync_url(self): + target = c_backend_wrapper.create_http_sync_target("http://base_url") + self.assertEqual("http://base_url/sync-from/replica-uid", + c_backend_wrapper._format_sync_url(target, "replica-uid")) + + def test_format_sync_url_escapes(self): + # The base_url should not get munged (we assume it is already a + # properly formed URL), but the replica-uid should get properly escaped + target = c_backend_wrapper.create_http_sync_target( + "http://host/base%2Ctest/") + self.assertEqual("http://host/base%2Ctest/sync-from/replica%2Cuid", + c_backend_wrapper._format_sync_url(target, "replica,uid")) + + def test_format_refuses_non_http(self): + db = c_backend_wrapper.CDatabase(':memory:') + target = db.get_sync_target() + self.assertRaises(RuntimeError, + c_backend_wrapper._format_sync_url, target, 'replica,uid') + + def test_oauth_credentials(self): + target = c_backend_wrapper.create_oauth_http_sync_target( + "http://host/base%2Ctest/", + 'consumer-key', 'consumer-secret', 'token-key', 'token-secret') + auth = c_backend_wrapper._get_oauth_authorization(target, + "GET", "http://host/base%2Ctest/sync-from/abcd-efg") + self.assertIsNot(None, auth) + self.assertTrue(auth.startswith('Authorization: OAuth realm="", ')) + self.assertNotIn('http://host/base', auth) + self.assertIn('oauth_nonce="', auth) + self.assertIn('oauth_timestamp="', auth) + self.assertIn('oauth_consumer_key="consumer-key"', auth) + self.assertIn('oauth_signature_method="HMAC-SHA1"', auth) + self.assertIn('oauth_version="1.0"', auth) + self.assertIn('oauth_token="token-key"', auth) + self.assertIn('oauth_signature="', auth) + + +class TestSyncCtoHTTPViaC(tests.TestCaseWithServer): + + make_app_with_state = staticmethod(make_http_app) + + def setUp(self): + super(TestSyncCtoHTTPViaC, self).setUp() + if c_backend_wrapper is None: + self.skipTest("The c_backend_wrapper could not be imported") + self.startServer() + + def test_trivial_sync(self): + mem_db = self.request_state._create_database('test.db') + mem_doc = mem_db.create_doc_from_json(tests.nested_doc) + url = self.getURL('test.db') + target = c_backend_wrapper.create_http_sync_target(url) + db = c_backend_wrapper.CDatabase(':memory:') + doc = db.create_doc_from_json(tests.simple_doc) + c_backend_wrapper.sync_db_to_target(db, target) + self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) + self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), + False) + + def test_unavailable(self): + mem_db = self.request_state._create_database('test.db') + mem_db.create_doc_from_json(tests.nested_doc) + tries = [] + + def wrapper(instance, *args, **kwargs): + tries.append(None) + raise errors.Unavailable + + mem_db.whats_changed = wrapper + url = self.getURL('test.db') + target = c_backend_wrapper.create_http_sync_target(url) + db = c_backend_wrapper.CDatabase(':memory:') + db.create_doc_from_json(tests.simple_doc) + self.assertRaises( + errors.Unavailable, c_backend_wrapper.sync_db_to_target, db, + target) + self.assertEqual(5, len(tries)) + + def test_unavailable_then_available(self): + mem_db = self.request_state._create_database('test.db') + mem_doc = mem_db.create_doc_from_json(tests.nested_doc) + orig_whatschanged = mem_db.whats_changed + tries = [] + + def wrapper(instance, *args, **kwargs): + if len(tries) < 1: + tries.append(None) + raise errors.Unavailable + return orig_whatschanged(instance, *args, **kwargs) + + mem_db.whats_changed = wrapper + url = self.getURL('test.db') + target = c_backend_wrapper.create_http_sync_target(url) + db = c_backend_wrapper.CDatabase(':memory:') + doc = db.create_doc_from_json(tests.simple_doc) + c_backend_wrapper.sync_db_to_target(db, target) + self.assertEqual(1, len(tries)) + self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) + self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), + False) + + def test_db_sync(self): + mem_db = self.request_state._create_database('test.db') + mem_doc = mem_db.create_doc_from_json(tests.nested_doc) + url = self.getURL('test.db') + db = c_backend_wrapper.CDatabase(':memory:') + doc = db.create_doc_from_json(tests.simple_doc) + local_gen_before_sync = db.sync(url) + gen, _, changes = db.whats_changed(local_gen_before_sync) + self.assertEqual(1, len(changes)) + self.assertEqual(mem_doc.doc_id, changes[0][0]) + self.assertEqual(1, gen - local_gen_before_sync) + self.assertEqual(1, local_gen_before_sync) + self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) + self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), + False) + + +class TestSyncCtoOAuthHTTPViaC(tests.TestCaseWithServer): + + make_app_with_state = staticmethod(make_oauth_http_app) + + def setUp(self): + super(TestSyncCtoOAuthHTTPViaC, self).setUp() + if c_backend_wrapper is None: + self.skipTest("The c_backend_wrapper could not be imported") + self.startServer() + + def test_trivial_sync(self): + mem_db = self.request_state._create_database('test.db') + mem_doc = mem_db.create_doc_from_json(tests.nested_doc) + url = self.getURL('~/test.db') + target = c_backend_wrapper.create_oauth_http_sync_target(url, + tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + db = c_backend_wrapper.CDatabase(':memory:') + doc = db.create_doc_from_json(tests.simple_doc) + c_backend_wrapper.sync_db_to_target(db, target) + self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False) + self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(), + False) + + +class TestVectorClock(BackendTests): + + def create_vcr(self, rev): + return c_backend_wrapper.VectorClockRev(rev) + + def test_parse_empty(self): + self.assertEqual('VectorClockRev()', + repr(self.create_vcr(''))) + + def test_parse_invalid(self): + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('x'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('x:a'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:a'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('x:a|y:1'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:2a'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1||'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:2|'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:2|:'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:2|m:'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|x:|m:3'))) + self.assertEqual('VectorClockRev(None)', + repr(self.create_vcr('y:1|:|m:3'))) + + def test_parse_single(self): + self.assertEqual('VectorClockRev(test:1)', + repr(self.create_vcr('test:1'))) + + def test_parse_multi(self): + self.assertEqual('VectorClockRev(test:1|z:2)', + repr(self.create_vcr('test:1|z:2'))) + self.assertEqual('VectorClockRev(ab:1|bc:2|cd:3|de:4|ef:5)', + repr(self.create_vcr('ab:1|bc:2|cd:3|de:4|ef:5'))) + self.assertEqual('VectorClockRev(a:2|b:1)', + repr(self.create_vcr('b:1|a:2'))) + + +class TestCDocument(BackendTests): + + def make_document(self, *args, **kwargs): + return c_backend_wrapper.make_document(*args, **kwargs) + + def test_create(self): + self.make_document('doc-id', 'uid:1', tests.simple_doc) + + def assertPyDocEqualCDoc(self, *args, **kwargs): + cdoc = self.make_document(*args, **kwargs) + pydoc = Document(*args, **kwargs) + self.assertEqual(pydoc, cdoc) + self.assertEqual(cdoc, pydoc) + + def test_cmp_to_pydoc_equal(self): + self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc) + self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc, + has_conflicts=False) + self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc, + has_conflicts=True) + + def test_cmp_to_pydoc_not_equal_conflicts(self): + cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + pydoc = Document('doc-id', 'uid:1', tests.simple_doc, + has_conflicts=True) + self.assertNotEqual(cdoc, pydoc) + self.assertNotEqual(pydoc, cdoc) + + def test_cmp_to_pydoc_not_equal_doc_id(self): + cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + pydoc = Document('doc2-id', 'uid:1', tests.simple_doc) + self.assertNotEqual(cdoc, pydoc) + self.assertNotEqual(pydoc, cdoc) + + def test_cmp_to_pydoc_not_equal_doc_rev(self): + cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + pydoc = Document('doc-id', 'uid:2', tests.simple_doc) + self.assertNotEqual(cdoc, pydoc) + self.assertNotEqual(pydoc, cdoc) + + def test_cmp_to_pydoc_not_equal_content(self): + cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + pydoc = Document('doc-id', 'uid:1', tests.nested_doc) + self.assertNotEqual(cdoc, pydoc) + self.assertNotEqual(pydoc, cdoc) + + +class TestUUID(BackendTests): + + def test_uuid4_conformance(self): + uuids = set() + for i in range(20): + uuid = c_backend_wrapper.generate_hex_uuid() + self.assertIsInstance(uuid, str) + self.assertEqual(32, len(uuid)) + # This will raise ValueError if it isn't a valid hex string + long(uuid, 16) + # Version 4 uuids have 2 other requirements, the high 4 bits of the + # seventh byte are always '0x4', and the middle bits of byte 9 are + # always set + self.assertEqual('4', uuid[12]) + self.assertTrue(uuid[16] in '89ab') + self.assertTrue(uuid not in uuids) + uuids.add(uuid) diff --git a/src/leap/soledad/u1db/tests/test_common_backend.py b/src/leap/soledad/u1db/tests/test_common_backend.py new file mode 100644 index 00000000..8c7c7ed9 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_common_backend.py @@ -0,0 +1,33 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Test common backend bits.""" + +from u1db import ( + backends, + tests, + ) + + +class TestCommonBackendImpl(tests.TestCase): + + def test__allocate_doc_id(self): + db = backends.CommonBackend() + doc_id1 = db._allocate_doc_id() + self.assertTrue(doc_id1.startswith('D-')) + self.assertEqual(34, len(doc_id1)) + int(doc_id1[len('D-'):], 16) + self.assertNotEqual(doc_id1, db._allocate_doc_id()) diff --git a/src/leap/soledad/u1db/tests/test_document.py b/src/leap/soledad/u1db/tests/test_document.py new file mode 100644 index 00000000..20f254b9 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_document.py @@ -0,0 +1,148 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + + +from u1db import errors, tests + + +class TestDocument(tests.TestCase): + + scenarios = ([( + 'py', {'make_document_for_test': tests.make_document_for_test})] + + tests.C_DATABASE_SCENARIOS) + + def test_create_doc(self): + doc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + self.assertEqual('doc-id', doc.doc_id) + self.assertEqual('uid:1', doc.rev) + self.assertEqual(tests.simple_doc, doc.get_json()) + self.assertFalse(doc.has_conflicts) + + def test__repr__(self): + doc = self.make_document('doc-id', 'uid:1', tests.simple_doc) + self.assertEqual( + '%s(doc-id, uid:1, \'{"key": "value"}\')' + % (doc.__class__.__name__,), + repr(doc)) + + def test__repr__conflicted(self): + doc = self.make_document('doc-id', 'uid:1', tests.simple_doc, + has_conflicts=True) + self.assertEqual( + '%s(doc-id, uid:1, conflicted, \'{"key": "value"}\')' + % (doc.__class__.__name__,), + repr(doc)) + + def test__lt__(self): + doc_a = self.make_document('a', 'b', '{}') + doc_b = self.make_document('b', 'b', '{}') + self.assertTrue(doc_a < doc_b) + self.assertTrue(doc_b > doc_a) + doc_aa = self.make_document('a', 'a', '{}') + self.assertTrue(doc_aa < doc_a) + + def test__eq__(self): + doc_a = self.make_document('a', 'b', '{}') + doc_b = self.make_document('a', 'b', '{}') + self.assertTrue(doc_a == doc_b) + doc_b = self.make_document('a', 'b', '{}', has_conflicts=True) + self.assertFalse(doc_a == doc_b) + + def test_non_json_dict(self): + self.assertRaises( + errors.InvalidJSON, self.make_document, 'id', 'uid:1', + '"not a json dictionary"') + + def test_non_json(self): + self.assertRaises( + errors.InvalidJSON, self.make_document, 'id', 'uid:1', + 'not a json dictionary') + + def test_get_size(self): + doc_a = self.make_document('a', 'b', '{"some": "content"}') + self.assertEqual( + len('a' + 'b' + '{"some": "content"}'), doc_a.get_size()) + + def test_get_size_empty_document(self): + doc_a = self.make_document('a', 'b', None) + self.assertEqual(len('a' + 'b'), doc_a.get_size()) + + +class TestPyDocument(tests.TestCase): + + scenarios = ([( + 'py', {'make_document_for_test': tests.make_document_for_test})]) + + def test_get_content(self): + doc = self.make_document('id', 'rev', '{"content":""}') + self.assertEqual({"content": ""}, doc.content) + doc.set_json('{"content": "new"}') + self.assertEqual({"content": "new"}, doc.content) + + def test_set_content(self): + doc = self.make_document('id', 'rev', '{"content":""}') + doc.content = {"content": "new"} + self.assertEqual('{"content": "new"}', doc.get_json()) + + def test_set_bad_content(self): + doc = self.make_document('id', 'rev', '{"content":""}') + self.assertRaises( + errors.InvalidContent, setattr, doc, 'content', + '{"content": "new"}') + + def test_is_tombstone(self): + doc_a = self.make_document('a', 'b', '{}') + self.assertFalse(doc_a.is_tombstone()) + doc_a.set_json(None) + self.assertTrue(doc_a.is_tombstone()) + + def test_make_tombstone(self): + doc_a = self.make_document('a', 'b', '{}') + self.assertFalse(doc_a.is_tombstone()) + doc_a.make_tombstone() + self.assertTrue(doc_a.is_tombstone()) + + def test_same_content_as(self): + doc_a = self.make_document('a', 'b', '{}') + doc_b = self.make_document('d', 'e', '{}') + self.assertTrue(doc_a.same_content_as(doc_b)) + doc_b = self.make_document('p', 'q', '{}', has_conflicts=True) + self.assertTrue(doc_a.same_content_as(doc_b)) + doc_b.content['key'] = 'value' + self.assertFalse(doc_a.same_content_as(doc_b)) + + def test_same_content_as_json_order(self): + doc_a = self.make_document( + 'a', 'b', '{"key1": "val1", "key2": "val2"}') + doc_b = self.make_document( + 'c', 'd', '{"key2": "val2", "key1": "val1"}') + self.assertTrue(doc_a.same_content_as(doc_b)) + + def test_set_json(self): + doc = self.make_document('id', 'rev', '{"content":""}') + doc.set_json('{"content": "new"}') + self.assertEqual('{"content": "new"}', doc.get_json()) + + def test_set_json_non_dict(self): + doc = self.make_document('id', 'rev', '{"content":""}') + self.assertRaises(errors.InvalidJSON, doc.set_json, '"is not a dict"') + + def test_set_json_error(self): + doc = self.make_document('id', 'rev', '{"content":""}') + self.assertRaises(errors.InvalidJSON, doc.set_json, 'is not json') + + +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_errors.py b/src/leap/soledad/u1db/tests/test_errors.py new file mode 100644 index 00000000..0e089ede --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_errors.py @@ -0,0 +1,61 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Tests error infrastructure.""" + +from u1db import ( + errors, + tests, + ) + + +class TestError(tests.TestCase): + + def test_error_base(self): + err = errors.U1DBError() + self.assertEqual("error", err.wire_description) + self.assertIs(None, err.message) + + err = errors.U1DBError("Message.") + self.assertEqual("error", err.wire_description) + self.assertEqual("Message.", err.message) + + def test_HTTPError(self): + err = errors.HTTPError(500) + self.assertEqual(500, err.status) + self.assertIs(None, err.wire_description) + self.assertIs(None, err.message) + + err = errors.HTTPError(500, "Crash.") + self.assertEqual(500, err.status) + self.assertIs(None, err.wire_description) + self.assertEqual("Crash.", err.message) + + def test_HTTPError_str(self): + err = errors.HTTPError(500) + self.assertEqual("HTTPError(500)", str(err)) + + err = errors.HTTPError(500, "ERROR") + self.assertEqual("HTTPError(500, 'ERROR')", str(err)) + + def test_Unvailable(self): + err = errors.Unavailable() + self.assertEqual(503, err.status) + self.assertEqual("Unavailable()", str(err)) + + err = errors.Unavailable("DOWN") + self.assertEqual("DOWN", err.message) + self.assertEqual("Unavailable('DOWN')", str(err)) diff --git a/src/leap/soledad/u1db/tests/test_http_app.py b/src/leap/soledad/u1db/tests/test_http_app.py new file mode 100644 index 00000000..13522693 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_http_app.py @@ -0,0 +1,1133 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Test the WSGI app.""" + +import paste.fixture +import sys +try: + import simplejson as json +except ImportError: + import json # noqa +import StringIO + +from u1db import ( + __version__ as _u1db_version, + errors, + sync, + tests, + ) + +from u1db.remote import ( + http_app, + http_errors, + ) + + +class TestFencedReader(tests.TestCase): + + def test_init(self): + reader = http_app._FencedReader(StringIO.StringIO(""), 25, 100) + self.assertEqual(25, reader.remaining) + + def test_read_chunk(self): + inp = StringIO.StringIO("abcdef") + reader = http_app._FencedReader(inp, 5, 10) + data = reader.read_chunk(2) + self.assertEqual("ab", data) + self.assertEqual(2, inp.tell()) + self.assertEqual(3, reader.remaining) + + def test_read_chunk_remaining(self): + inp = StringIO.StringIO("abcdef") + reader = http_app._FencedReader(inp, 4, 10) + data = reader.read_chunk(9999) + self.assertEqual("abcd", data) + self.assertEqual(4, inp.tell()) + self.assertEqual(0, reader.remaining) + + def test_read_chunk_nothing_left(self): + inp = StringIO.StringIO("abc") + reader = http_app._FencedReader(inp, 2, 10) + reader.read_chunk(2) + self.assertEqual(2, inp.tell()) + self.assertEqual(0, reader.remaining) + data = reader.read_chunk(2) + self.assertEqual("", data) + self.assertEqual(2, inp.tell()) + self.assertEqual(0, reader.remaining) + + def test_read_chunk_kept(self): + inp = StringIO.StringIO("abcde") + reader = http_app._FencedReader(inp, 4, 10) + reader._kept = "xyz" + data = reader.read_chunk(2) # atmost ignored + self.assertEqual("xyz", data) + self.assertEqual(0, inp.tell()) + self.assertEqual(4, reader.remaining) + self.assertIsNone(reader._kept) + + def test_getline(self): + inp = StringIO.StringIO("abc\r\nde") + reader = http_app._FencedReader(inp, 6, 10) + reader.MAXCHUNK = 6 + line = reader.getline() + self.assertEqual("abc\r\n", line) + self.assertEqual("d", reader._kept) + + def test_getline_exact(self): + inp = StringIO.StringIO("abcd\r\nef") + reader = http_app._FencedReader(inp, 6, 10) + reader.MAXCHUNK = 6 + line = reader.getline() + self.assertEqual("abcd\r\n", line) + self.assertIs(None, reader._kept) + + def test_getline_no_newline(self): + inp = StringIO.StringIO("abcd") + reader = http_app._FencedReader(inp, 4, 10) + reader.MAXCHUNK = 6 + line = reader.getline() + self.assertEqual("abcd", line) + + def test_getline_many_chunks(self): + inp = StringIO.StringIO("abcde\r\nf") + reader = http_app._FencedReader(inp, 8, 10) + reader.MAXCHUNK = 4 + line = reader.getline() + self.assertEqual("abcde\r\n", line) + self.assertEqual("f", reader._kept) + line = reader.getline() + self.assertEqual("f", line) + + def test_getline_empty(self): + inp = StringIO.StringIO("") + reader = http_app._FencedReader(inp, 0, 10) + reader.MAXCHUNK = 4 + line = reader.getline() + self.assertEqual("", line) + line = reader.getline() + self.assertEqual("", line) + + def test_getline_just_newline(self): + inp = StringIO.StringIO("\r\n") + reader = http_app._FencedReader(inp, 2, 10) + reader.MAXCHUNK = 4 + line = reader.getline() + self.assertEqual("\r\n", line) + line = reader.getline() + self.assertEqual("", line) + + def test_getline_too_large(self): + inp = StringIO.StringIO("x" * 50) + reader = http_app._FencedReader(inp, 50, 25) + reader.MAXCHUNK = 4 + self.assertRaises(http_app.BadRequest, reader.getline) + + def test_getline_too_large_complete(self): + inp = StringIO.StringIO("x" * 25 + "\r\n") + reader = http_app._FencedReader(inp, 50, 25) + reader.MAXCHUNK = 4 + self.assertRaises(http_app.BadRequest, reader.getline) + + +class TestHTTPMethodDecorator(tests.TestCase): + + def test_args(self): + @http_app.http_method() + def f(self, a, b): + return self, a, b + res = f("self", {"a": "x", "b": "y"}, None) + self.assertEqual(("self", "x", "y"), res) + + def test_args_missing(self): + @http_app.http_method() + def f(self, a, b): + return a, b + self.assertRaises(http_app.BadRequest, f, "self", {"a": "x"}, None) + + def test_args_unexpected(self): + @http_app.http_method() + def f(self, a): + return a + self.assertRaises(http_app.BadRequest, f, "self", + {"a": "x", "c": "z"}, None) + + def test_args_default(self): + @http_app.http_method() + def f(self, a, b="z"): + return a, b + res = f("self", {"a": "x"}, None) + self.assertEqual(("x", "z"), res) + + def test_args_conversion(self): + @http_app.http_method(b=int) + def f(self, a, b): + return self, a, b + res = f("self", {"a": "x", "b": "2"}, None) + self.assertEqual(("self", "x", 2), res) + + self.assertRaises(http_app.BadRequest, f, "self", + {"a": "x", "b": "foo"}, None) + + def test_args_conversion_with_default(self): + @http_app.http_method(b=str) + def f(self, a, b=None): + return self, a, b + res = f("self", {"a": "x"}, None) + self.assertEqual(("self", "x", None), res) + + def test_args_content(self): + @http_app.http_method() + def f(self, a, content): + return a, content + res = f(self, {"a": "x"}, "CONTENT") + self.assertEqual(("x", "CONTENT"), res) + + def test_args_content_as_args(self): + @http_app.http_method(b=int, content_as_args=True) + def f(self, a, b): + return self, a, b + res = f("self", {"a": "x"}, '{"b": "2"}') + self.assertEqual(("self", "x", 2), res) + + self.assertRaises(http_app.BadRequest, f, "self", {}, 'not-json') + + def test_args_content_no_query(self): + @http_app.http_method(no_query=True, + content_as_args=True) + def f(self, a='a', b='b'): + return a, b + res = f("self", {}, '{"b": "y"}') + self.assertEqual(('a', 'y'), res) + + self.assertRaises(http_app.BadRequest, f, "self", {'a': 'x'}, + '{"b": "y"}') + + +class TestResource(object): + + @http_app.http_method() + def get(self, a, b): + self.args = dict(a=a, b=b) + return 'Get' + + @http_app.http_method() + def put(self, a, content): + self.args = dict(a=a) + self.content = content + return 'Put' + + @http_app.http_method(content_as_args=True) + def put_args(self, a, b): + self.args = dict(a=a, b=b) + self.order = ['a'] + self.entries = [] + + @http_app.http_method() + def put_stream_entry(self, content): + self.entries.append(content) + self.order.append('s') + + def put_end(self): + self.order.append('e') + return "Put/end" + + +class parameters: + max_request_size = 200000 + max_entry_size = 100000 + + +class TestHTTPInvocationByMethodWithBody(tests.TestCase): + + def test_get(self): + resource = TestResource() + environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'GET'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + res = invoke() + self.assertEqual('Get', res) + self.assertEqual({'a': '1', 'b': '2'}, resource.args) + + def test_put_json(self): + resource = TestResource() + body = '{"body": true}' + environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO(body), + 'CONTENT_LENGTH': str(len(body)), + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + res = invoke() + self.assertEqual('Put', res) + self.assertEqual({'a': '1'}, resource.args) + self.assertEqual('{"body": true}', resource.content) + + def test_put_sync_stream(self): + resource = TestResource() + body = ( + '[\r\n' + '{"b": 2},\r\n' # args + '{"entry": "x"},\r\n' # stream entry + '{"entry": "y"}\r\n' # stream entry + ']' + ) + environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO(body), + 'CONTENT_LENGTH': str(len(body)), + 'CONTENT_TYPE': 'application/x-u1db-sync-stream'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + res = invoke() + self.assertEqual('Put/end', res) + self.assertEqual({'a': '1', 'b': 2}, resource.args) + self.assertEqual( + ['{"entry": "x"}', '{"entry": "y"}'], resource.entries) + self.assertEqual(['a', 's', 's', 'e'], resource.order) + + def _put_sync_stream(self, body): + resource = TestResource() + environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO(body), + 'CONTENT_LENGTH': str(len(body)), + 'CONTENT_TYPE': 'application/x-u1db-sync-stream'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + invoke() + + def test_put_sync_stream_wrong_start(self): + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "{}\r\n]") + + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "\r\n{}\r\n]") + + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "") + + def test_put_sync_stream_wrong_end(self): + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n{}") + + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n") + + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n{}\r\n]\r\n...") + + def test_put_sync_stream_missing_comma(self): + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n{}\r\n{}\r\n]") + + def test_put_sync_stream_extra_comma(self): + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n{},\r\n]") + + self.assertRaises(http_app.BadRequest, + self._put_sync_stream, "[\r\n{},\r\n{},\r\n]") + + def test_bad_request_decode_failure(self): + resource = TestResource() + environ = {'QUERY_STRING': 'a=\xff', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('{}'), + 'CONTENT_LENGTH': '2', + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_unsupported_content_type(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('{}'), + 'CONTENT_LENGTH': '2', + 'CONTENT_TYPE': 'text/plain'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_content_length_too_large(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('{}'), + 'CONTENT_LENGTH': '10000', + 'CONTENT_TYPE': 'text/plain'} + + resource.max_request_size = 5000 + resource.max_entry_size = sys.maxint # we don't get to use this + + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_no_content_length(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('a'), + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_invalid_content_length(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('abc'), + 'CONTENT_LENGTH': '1unk', + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_empty_body(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO(''), + 'CONTENT_LENGTH': '0', + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_unsupported_method_get_like(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'DELETE'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_unsupported_method_put_like(self): + resource = TestResource() + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT', + 'wsgi.input': StringIO.StringIO('{}'), + 'CONTENT_LENGTH': '2', + 'CONTENT_TYPE': 'application/json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + def test_bad_request_unsupported_method_put_like_multi_json(self): + resource = TestResource() + body = '{}\r\n{}\r\n' + environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'POST', + 'wsgi.input': StringIO.StringIO(body), + 'CONTENT_LENGTH': str(len(body)), + 'CONTENT_TYPE': 'application/x-u1db-multi-json'} + invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ, + parameters) + self.assertRaises(http_app.BadRequest, invoke) + + +class TestHTTPResponder(tests.TestCase): + + def start_response(self, status, headers): + self.status = status + self.headers = dict(headers) + self.response_body = [] + + def write(data): + self.response_body.append(data) + + return write + + def test_send_response_content_w_headers(self): + responder = http_app.HTTPResponder(self.start_response) + responder.send_response_content('foo', headers={'x-a': '1'}) + self.assertEqual('200 OK', self.status) + self.assertEqual({'content-type': 'application/json', + 'cache-control': 'no-cache', + 'x-a': '1', 'content-length': '3'}, self.headers) + self.assertEqual([], self.response_body) + self.assertEqual(['foo'], responder.content) + + def test_send_response_json(self): + responder = http_app.HTTPResponder(self.start_response) + responder.send_response_json(value='success') + self.assertEqual('200 OK', self.status) + expected_body = '{"value": "success"}\r\n' + self.assertEqual({'content-type': 'application/json', + 'content-length': str(len(expected_body)), + 'cache-control': 'no-cache'}, self.headers) + self.assertEqual([], self.response_body) + self.assertEqual([expected_body], responder.content) + + def test_send_response_json_status_fail(self): + responder = http_app.HTTPResponder(self.start_response) + responder.send_response_json(400) + self.assertEqual('400 Bad Request', self.status) + expected_body = '{}\r\n' + self.assertEqual({'content-type': 'application/json', + 'content-length': str(len(expected_body)), + 'cache-control': 'no-cache'}, self.headers) + self.assertEqual([], self.response_body) + self.assertEqual([expected_body], responder.content) + + def test_start_finish_response_status_fail(self): + responder = http_app.HTTPResponder(self.start_response) + responder.start_response(404, {'error': 'not found'}) + responder.finish_response() + self.assertEqual('404 Not Found', self.status) + self.assertEqual({'content-type': 'application/json', + 'cache-control': 'no-cache'}, self.headers) + self.assertEqual(['{"error": "not found"}\r\n'], self.response_body) + self.assertEqual([], responder.content) + + def test_send_stream_entry(self): + responder = http_app.HTTPResponder(self.start_response) + responder.content_type = "application/x-u1db-multi-json" + responder.start_response(200) + responder.start_stream() + responder.stream_entry({'entry': 1}) + responder.stream_entry({'entry': 2}) + responder.end_stream() + responder.finish_response() + self.assertEqual('200 OK', self.status) + self.assertEqual({'content-type': 'application/x-u1db-multi-json', + 'cache-control': 'no-cache'}, self.headers) + self.assertEqual(['[', + '\r\n', '{"entry": 1}', + ',\r\n', '{"entry": 2}', + '\r\n]\r\n'], self.response_body) + self.assertEqual([], responder.content) + + def test_send_stream_w_error(self): + responder = http_app.HTTPResponder(self.start_response) + responder.content_type = "application/x-u1db-multi-json" + responder.start_response(200) + responder.start_stream() + responder.stream_entry({'entry': 1}) + responder.send_response_json(503, error="unavailable") + self.assertEqual('200 OK', self.status) + self.assertEqual({'content-type': 'application/x-u1db-multi-json', + 'cache-control': 'no-cache'}, self.headers) + self.assertEqual(['[', + '\r\n', '{"entry": 1}'], self.response_body) + self.assertEqual([',\r\n', '{"error": "unavailable"}\r\n'], + responder.content) + + +class TestHTTPApp(tests.TestCase): + + def setUp(self): + super(TestHTTPApp, self).setUp() + self.state = tests.ServerStateForTests() + self.http_app = http_app.HTTPApp(self.state) + self.app = paste.fixture.TestApp(self.http_app) + self.db0 = self.state._create_database('db0') + + def test_bad_request_broken(self): + resp = self.app.put('/db0/doc/doc1', params='{"x": 1}', + headers={'content-type': 'application/foo'}, + expect_errors=True) + self.assertEqual(400, resp.status) + + def test_bad_request_dispatch(self): + resp = self.app.put('/db0/foo/doc1', params='{"x": 1}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(400, resp.status) + + def test_version(self): + resp = self.app.get('/') + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({"version": _u1db_version}, json.loads(resp.body)) + + def test_create_database(self): + resp = self.app.put('/db1', params='{}', + headers={'content-type': 'application/json'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'ok': True}, json.loads(resp.body)) + + resp = self.app.put('/db1', params='{}', + headers={'content-type': 'application/json'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'ok': True}, json.loads(resp.body)) + + def test_delete_database(self): + resp = self.app.delete('/db0') + self.assertEqual(200, resp.status) + self.assertRaises(errors.DatabaseDoesNotExist, + self.state.check_database, 'db0') + + def test_get_database(self): + resp = self.app.get('/db0') + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({}, json.loads(resp.body)) + + def test_valid_database_names(self): + resp = self.app.get('/a-database', expect_errors=True) + self.assertEqual(404, resp.status) + + resp = self.app.get('/db1', expect_errors=True) + self.assertEqual(404, resp.status) + + resp = self.app.get('/0', expect_errors=True) + self.assertEqual(404, resp.status) + + resp = self.app.get('/0-0', expect_errors=True) + self.assertEqual(404, resp.status) + + resp = self.app.get('/org.future', expect_errors=True) + self.assertEqual(404, resp.status) + + def test_invalid_database_names(self): + resp = self.app.get('/.a', expect_errors=True) + self.assertEqual(400, resp.status) + + resp = self.app.get('/-a', expect_errors=True) + self.assertEqual(400, resp.status) + + resp = self.app.get('/_a', expect_errors=True) + self.assertEqual(400, resp.status) + + def test_put_doc_create(self): + resp = self.app.put('/db0/doc/doc1', params='{"x": 1}', + headers={'content-type': 'application/json'}) + doc = self.db0.get_doc('doc1') + self.assertEqual(201, resp.status) # created + self.assertEqual('{"x": 1}', doc.get_json()) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'rev': doc.rev}, json.loads(resp.body)) + + def test_put_doc(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev, + params='{"x": 2}', + headers={'content-type': 'application/json'}) + doc = self.db0.get_doc('doc1') + self.assertEqual(200, resp.status) + self.assertEqual('{"x": 2}', doc.get_json()) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'rev': doc.rev}, json.loads(resp.body)) + + def test_put_doc_too_large(self): + self.http_app.max_request_size = 15000 + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev, + params='{"%s": 2}' % ('z' * 16000), + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(400, resp.status) + + def test_delete_doc(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + resp = self.app.delete('/db0/doc/doc1?old_rev=%s' % doc.rev) + doc = self.db0.get_doc('doc1', include_deleted=True) + self.assertEqual(None, doc.content) + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'rev': doc.rev}, json.loads(resp.body)) + + def test_get_doc(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + resp = self.app.get('/db0/doc/%s' % doc.doc_id) + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual('{"x": 1}', resp.body) + self.assertEqual(doc.rev, resp.header('x-u1db-rev')) + self.assertEqual('false', resp.header('x-u1db-has-conflicts')) + + def test_get_doc_non_existing(self): + resp = self.app.get('/db0/doc/not-there', expect_errors=True) + self.assertEqual(404, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": "document does not exist"}, json.loads(resp.body)) + self.assertEqual('', resp.header('x-u1db-rev')) + self.assertEqual('false', resp.header('x-u1db-has-conflicts')) + + def test_get_doc_deleted(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + self.db0.delete_doc(doc) + resp = self.app.get('/db0/doc/doc1', expect_errors=True) + self.assertEqual(404, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": errors.DocumentDoesNotExist.wire_description}, + json.loads(resp.body)) + + def test_get_doc_deleted_explicit_exclude(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + self.db0.delete_doc(doc) + resp = self.app.get( + '/db0/doc/doc1?include_deleted=false', expect_errors=True) + self.assertEqual(404, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": errors.DocumentDoesNotExist.wire_description}, + json.loads(resp.body)) + + def test_get_deleted_doc(self): + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + self.db0.delete_doc(doc) + resp = self.app.get( + '/db0/doc/doc1?include_deleted=true', expect_errors=True) + self.assertEqual(404, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": errors.DOCUMENT_DELETED}, json.loads(resp.body)) + self.assertEqual(doc.rev, resp.header('x-u1db-rev')) + self.assertEqual('false', resp.header('x-u1db-has-conflicts')) + + def test_get_doc_non_existing_dabase(self): + resp = self.app.get('/not-there/doc/doc1', expect_errors=True) + self.assertEqual(404, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": "database does not exist"}, json.loads(resp.body)) + + def test_get_docs(self): + doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') + ids = ','.join([doc1.doc_id, doc2.doc_id]) + resp = self.app.get('/db0/docs?doc_ids=%s' % ids) + self.assertEqual(200, resp.status) + self.assertEqual( + 'application/json', resp.header('content-type')) + expected = [ + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1", + "has_conflicts": False}, + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2", + "has_conflicts": False}] + self.assertEqual(expected, json.loads(resp.body)) + + def test_get_docs_missing_doc_ids(self): + resp = self.app.get('/db0/docs', expect_errors=True) + self.assertEqual(400, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": "missing document ids"}, json.loads(resp.body)) + + def test_get_docs_empty_doc_ids(self): + resp = self.app.get('/db0/docs?doc_ids=', expect_errors=True) + self.assertEqual(400, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual( + {"error": "missing document ids"}, json.loads(resp.body)) + + def test_get_docs_percent(self): + doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc%1') + doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') + ids = ','.join([doc1.doc_id, doc2.doc_id]) + resp = self.app.get('/db0/docs?doc_ids=%s' % ids) + self.assertEqual(200, resp.status) + self.assertEqual( + 'application/json', resp.header('content-type')) + expected = [ + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc%1", + "has_conflicts": False}, + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2", + "has_conflicts": False}] + self.assertEqual(expected, json.loads(resp.body)) + + def test_get_docs_deleted(self): + doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') + self.db0.delete_doc(doc2) + ids = ','.join([doc1.doc_id, doc2.doc_id]) + resp = self.app.get('/db0/docs?doc_ids=%s' % ids) + self.assertEqual(200, resp.status) + self.assertEqual( + 'application/json', resp.header('content-type')) + expected = [ + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1", + "has_conflicts": False}] + self.assertEqual(expected, json.loads(resp.body)) + + def test_get_docs_include_deleted(self): + doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2') + self.db0.delete_doc(doc2) + ids = ','.join([doc1.doc_id, doc2.doc_id]) + resp = self.app.get('/db0/docs?doc_ids=%s&include_deleted=true' % ids) + self.assertEqual(200, resp.status) + self.assertEqual( + 'application/json', resp.header('content-type')) + expected = [ + {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1", + "has_conflicts": False}, + {"content": None, "doc_rev": "db0:2", "doc_id": "doc2", + "has_conflicts": False}] + self.assertEqual(expected, json.loads(resp.body)) + + def test_get_sync_info(self): + self.db0._set_replica_gen_and_trans_id('other-id', 1, 'T-transid') + resp = self.app.get('/db0/sync-from/other-id') + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual(dict(target_replica_uid='db0', + target_replica_generation=0, + target_replica_transaction_id='', + source_replica_uid='other-id', + source_replica_generation=1, + source_transaction_id='T-transid'), + json.loads(resp.body)) + + def test_record_sync_info(self): + resp = self.app.put('/db0/sync-from/other-id', + params='{"generation": 2, "transaction_id": "T-transid"}', + headers={'content-type': 'application/json'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({'ok': True}, json.loads(resp.body)) + self.assertEqual( + (2, 'T-transid'), + self.db0._get_replica_gen_and_trans_id('other-id')) + + def test_sync_exchange_send(self): + entries = { + 10: {'id': 'doc-here', 'rev': 'replica:1', 'content': + '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'}, + 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content': + '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'} + } + + gens = [] + _do_set_replica_gen_and_trans_id = \ + self.db0._do_set_replica_gen_and_trans_id + + def set_sync_generation_witness(other_uid, other_gen, other_trans_id): + gens.append((other_uid, other_gen)) + _do_set_replica_gen_and_trans_id( + other_uid, other_gen, other_trans_id) + self.assertGetDoc(self.db0, entries[other_gen]['id'], + entries[other_gen]['rev'], + entries[other_gen]['content'], False) + + self.patch( + self.db0, '_do_set_replica_gen_and_trans_id', + set_sync_generation_witness) + + args = dict(last_known_generation=0) + body = ("[\r\n" + + "%s,\r\n" % json.dumps(args) + + "%s,\r\n" % json.dumps(entries[10]) + + "%s\r\n" % json.dumps(entries[11]) + + "]\r\n") + resp = self.app.post('/db0/sync-from/replica', + params=body, + headers={'content-type': + 'application/x-u1db-sync-stream'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/x-u1db-sync-stream', + resp.header('content-type')) + bits = resp.body.split('\r\n') + self.assertEqual('[', bits[0]) + last_trans_id = self.db0._get_transaction_log()[-1][1] + self.assertEqual({'new_generation': 2, + 'new_transaction_id': last_trans_id}, + json.loads(bits[1])) + self.assertEqual(']', bits[2]) + self.assertEqual('', bits[3]) + self.assertEqual([('replica', 10), ('replica', 11)], gens) + + def test_sync_exchange_send_ensure(self): + entries = { + 10: {'id': 'doc-here', 'rev': 'replica:1', 'content': + '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'}, + 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content': + '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'} + } + + args = dict(last_known_generation=0, ensure=True) + body = ("[\r\n" + + "%s,\r\n" % json.dumps(args) + + "%s,\r\n" % json.dumps(entries[10]) + + "%s\r\n" % json.dumps(entries[11]) + + "]\r\n") + resp = self.app.post('/dbnew/sync-from/replica', + params=body, + headers={'content-type': + 'application/x-u1db-sync-stream'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/x-u1db-sync-stream', + resp.header('content-type')) + bits = resp.body.split('\r\n') + self.assertEqual('[', bits[0]) + dbnew = self.state.open_database("dbnew") + last_trans_id = dbnew._get_transaction_log()[-1][1] + self.assertEqual({'new_generation': 2, + 'new_transaction_id': last_trans_id, + 'replica_uid': dbnew._replica_uid}, + json.loads(bits[1])) + self.assertEqual(']', bits[2]) + self.assertEqual('', bits[3]) + + def test_sync_exchange_send_entry_too_large(self): + self.patch(http_app.SyncResource, 'max_request_size', 20000) + self.patch(http_app.SyncResource, 'max_entry_size', 10000) + entries = { + 10: {'id': 'doc-here', 'rev': 'replica:1', 'content': + '{"value": "%s"}' % ('H' * 11000), 'gen': 10}, + } + args = dict(last_known_generation=0) + body = ("[\r\n" + + "%s,\r\n" % json.dumps(args) + + "%s\r\n" % json.dumps(entries[10]) + + "]\r\n") + resp = self.app.post('/db0/sync-from/replica', + params=body, + headers={'content-type': + 'application/x-u1db-sync-stream'}, + expect_errors=True) + self.assertEqual(400, resp.status) + + def test_sync_exchange_receive(self): + doc = self.db0.create_doc_from_json('{"value": "there"}') + doc2 = self.db0.create_doc_from_json('{"value": "there2"}') + args = dict(last_known_generation=0) + body = "[\r\n%s\r\n]" % json.dumps(args) + resp = self.app.post('/db0/sync-from/replica', + params=body, + headers={'content-type': + 'application/x-u1db-sync-stream'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/x-u1db-sync-stream', + resp.header('content-type')) + parts = resp.body.splitlines() + self.assertEqual(5, len(parts)) + self.assertEqual('[', parts[0]) + last_trans_id = self.db0._get_transaction_log()[-1][1] + self.assertEqual({'new_generation': 2, + 'new_transaction_id': last_trans_id}, + json.loads(parts[1].rstrip(","))) + part2 = json.loads(parts[2].rstrip(",")) + self.assertTrue(part2['trans_id'].startswith('T-')) + self.assertEqual('{"value": "there"}', part2['content']) + self.assertEqual(doc.rev, part2['rev']) + self.assertEqual(doc.doc_id, part2['id']) + self.assertEqual(1, part2['gen']) + part3 = json.loads(parts[3].rstrip(",")) + self.assertTrue(part3['trans_id'].startswith('T-')) + self.assertEqual('{"value": "there2"}', part3['content']) + self.assertEqual(doc2.rev, part3['rev']) + self.assertEqual(doc2.doc_id, part3['id']) + self.assertEqual(2, part3['gen']) + self.assertEqual(']', parts[4]) + + def test_sync_exchange_error_in_stream(self): + args = dict(last_known_generation=0) + body = "[\r\n%s\r\n]" % json.dumps(args) + + def boom(self, return_doc_cb): + raise errors.Unavailable + + self.patch(sync.SyncExchange, 'return_docs', + boom) + resp = self.app.post('/db0/sync-from/replica', + params=body, + headers={'content-type': + 'application/x-u1db-sync-stream'}) + self.assertEqual(200, resp.status) + self.assertEqual('application/x-u1db-sync-stream', + resp.header('content-type')) + parts = resp.body.splitlines() + self.assertEqual(3, len(parts)) + self.assertEqual('[', parts[0]) + self.assertEqual({'new_generation': 0, 'new_transaction_id': ''}, + json.loads(parts[1].rstrip(","))) + self.assertEqual({'error': 'unavailable'}, json.loads(parts[2])) + + +class TestRequestHooks(tests.TestCase): + + def setUp(self): + super(TestRequestHooks, self).setUp() + self.state = tests.ServerStateForTests() + self.http_app = http_app.HTTPApp(self.state) + self.app = paste.fixture.TestApp(self.http_app) + self.db0 = self.state._create_database('db0') + + def test_begin_and_done(self): + calls = [] + + def begin(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append('begin') + + def done(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append('done') + + self.http_app.request_begin = begin + self.http_app.request_done = done + + doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1') + self.app.get('/db0/doc/%s' % doc.doc_id) + + self.assertEqual(['begin', 'done'], calls) + + def test_bad_request(self): + calls = [] + + def begin(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append('begin') + + def bad_request(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append('bad-request') + + self.http_app.request_begin = begin + self.http_app.request_bad_request = bad_request + # shouldn't be called + self.http_app.request_done = lambda env: 1 / 0 + + resp = self.app.put('/db0/foo/doc1', params='{"x": 1}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(400, resp.status) + self.assertEqual(['begin', 'bad-request'], calls) + + +class TestHTTPErrors(tests.TestCase): + + def test_wire_description_to_status(self): + self.assertNotIn("error", http_errors.wire_description_to_status) + + +class TestHTTPAppErrorHandling(tests.TestCase): + + def setUp(self): + super(TestHTTPAppErrorHandling, self).setUp() + self.exc = None + self.state = tests.ServerStateForTests() + + class ErroringResource(object): + + def post(_, args, content): + raise self.exc + + def lookup_resource(environ, responder): + return ErroringResource() + + self.http_app = http_app.HTTPApp(self.state) + self.http_app._lookup_resource = lookup_resource + self.app = paste.fixture.TestApp(self.http_app) + + def test_RevisionConflict_etc(self): + self.exc = errors.RevisionConflict() + resp = self.app.post('/req', params='{}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(409, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({"error": "revision conflict"}, + json.loads(resp.body)) + + def test_Unavailable(self): + self.exc = errors.Unavailable + resp = self.app.post('/req', params='{}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(503, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({"error": "unavailable"}, + json.loads(resp.body)) + + def test_generic_u1db_errors(self): + self.exc = errors.U1DBError() + resp = self.app.post('/req', params='{}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(500, resp.status) + self.assertEqual('application/json', resp.header('content-type')) + self.assertEqual({"error": "error"}, + json.loads(resp.body)) + + def test_generic_u1db_errors_hooks(self): + calls = [] + + def begin(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append('begin') + + def u1db_error(environ, exc): + self.assertTrue('PATH_INFO' in environ) + calls.append(('error', exc)) + + self.http_app.request_begin = begin + self.http_app.request_u1db_error = u1db_error + # shouldn't be called + self.http_app.request_done = lambda env: 1 / 0 + + self.exc = errors.U1DBError() + resp = self.app.post('/req', params='{}', + headers={'content-type': 'application/json'}, + expect_errors=True) + self.assertEqual(500, resp.status) + self.assertEqual(['begin', ('error', self.exc)], calls) + + def test_failure(self): + class Failure(Exception): + pass + self.exc = Failure() + self.assertRaises(Failure, self.app.post, '/req', params='{}', + headers={'content-type': 'application/json'}) + + def test_failure_hooks(self): + class Failure(Exception): + pass + calls = [] + + def begin(environ): + calls.append('begin') + + def failed(environ): + self.assertTrue('PATH_INFO' in environ) + calls.append(('failed', sys.exc_info())) + + self.http_app.request_begin = begin + self.http_app.request_failed = failed + # shouldn't be called + self.http_app.request_done = lambda env: 1 / 0 + + self.exc = Failure() + self.assertRaises(Failure, self.app.post, '/req', params='{}', + headers={'content-type': 'application/json'}) + + self.assertEqual(2, len(calls)) + self.assertEqual('begin', calls[0]) + marker, (exc_type, exc, tb) = calls[1] + self.assertEqual('failed', marker) + self.assertEqual(self.exc, exc) + + +class TestPluggableSyncExchange(tests.TestCase): + + def setUp(self): + super(TestPluggableSyncExchange, self).setUp() + self.state = tests.ServerStateForTests() + self.state.ensure_database('foo') + + def test_plugging(self): + + class MySyncExchange(object): + def __init__(self, db, source_replica_uid, last_known_generation): + pass + + class MySyncResource(http_app.SyncResource): + sync_exchange_class = MySyncExchange + + sync_res = MySyncResource('foo', 'src', self.state, None) + sync_res.post_args( + {'last_known_generation': 0, 'last_known_trans_id': None}, '{}') + self.assertIsInstance(sync_res.sync_exch, MySyncExchange) diff --git a/src/leap/soledad/u1db/tests/test_http_client.py b/src/leap/soledad/u1db/tests/test_http_client.py new file mode 100644 index 00000000..115c8aaa --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_http_client.py @@ -0,0 +1,361 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Tests for HTTPDatabase""" + +from oauth import oauth +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db import ( + errors, + tests, + ) +from u1db.remote import ( + http_client, + ) + + +class TestEncoder(tests.TestCase): + + def test_encode_string(self): + self.assertEqual("foo", http_client._encode_query_parameter("foo")) + + def test_encode_true(self): + self.assertEqual("true", http_client._encode_query_parameter(True)) + + def test_encode_false(self): + self.assertEqual("false", http_client._encode_query_parameter(False)) + + +class TestHTTPClientBase(tests.TestCaseWithServer): + + def setUp(self): + super(TestHTTPClientBase, self).setUp() + self.errors = 0 + + def app(self, environ, start_response): + if environ['PATH_INFO'].endswith('echo'): + start_response("200 OK", [('Content-Type', 'application/json')]) + ret = {} + for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'): + ret[name] = environ[name] + if environ['REQUEST_METHOD'] in ('PUT', 'POST'): + ret['CONTENT_TYPE'] = environ['CONTENT_TYPE'] + content_length = int(environ['CONTENT_LENGTH']) + ret['body'] = environ['wsgi.input'].read(content_length) + return [json.dumps(ret)] + elif environ['PATH_INFO'].endswith('error_then_accept'): + if self.errors >= 3: + start_response( + "200 OK", [('Content-Type', 'application/json')]) + ret = {} + for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'): + ret[name] = environ[name] + if environ['REQUEST_METHOD'] in ('PUT', 'POST'): + ret['CONTENT_TYPE'] = environ['CONTENT_TYPE'] + content_length = int(environ['CONTENT_LENGTH']) + ret['body'] = '{"oki": "doki"}' + return [json.dumps(ret)] + self.errors += 1 + content_length = int(environ['CONTENT_LENGTH']) + error = json.loads( + environ['wsgi.input'].read(content_length)) + response = error['response'] + # In debug mode, wsgiref has an assertion that the status parameter + # is a 'str' object. However error['status'] returns a unicode + # object. + status = str(error['status']) + if isinstance(response, unicode): + response = str(response) + if isinstance(response, str): + start_response(status, [('Content-Type', 'text/plain')]) + return [str(response)] + else: + start_response(status, [('Content-Type', 'application/json')]) + return [json.dumps(response)] + elif environ['PATH_INFO'].endswith('error'): + self.errors += 1 + content_length = int(environ['CONTENT_LENGTH']) + error = json.loads( + environ['wsgi.input'].read(content_length)) + response = error['response'] + # In debug mode, wsgiref has an assertion that the status parameter + # is a 'str' object. However error['status'] returns a unicode + # object. + status = str(error['status']) + if isinstance(response, unicode): + response = str(response) + if isinstance(response, str): + start_response(status, [('Content-Type', 'text/plain')]) + return [str(response)] + else: + start_response(status, [('Content-Type', 'application/json')]) + return [json.dumps(response)] + elif '/oauth' in environ['PATH_INFO']: + base_url = self.getURL('').rstrip('/') + oauth_req = oauth.OAuthRequest.from_request( + http_method=environ['REQUEST_METHOD'], + http_url=base_url + environ['PATH_INFO'], + headers={'Authorization': environ['HTTP_AUTHORIZATION']}, + query_string=environ['QUERY_STRING'] + ) + oauth_server = oauth.OAuthServer(tests.testingOAuthStore) + oauth_server.add_signature_method(tests.sign_meth_HMAC_SHA1) + try: + consumer, token, params = oauth_server.verify_request( + oauth_req) + except oauth.OAuthError, e: + start_response("401 Unauthorized", + [('Content-Type', 'application/json')]) + return [json.dumps({"error": "unauthorized", + "message": e.message})] + start_response("200 OK", [('Content-Type', 'application/json')]) + return [json.dumps([environ['PATH_INFO'], token.key, params])] + + def make_app(self): + return self.app + + def getClient(self, **kwds): + self.startServer() + return http_client.HTTPClientBase(self.getURL('dbase'), **kwds) + + def test_construct(self): + self.startServer() + url = self.getURL() + cli = http_client.HTTPClientBase(url) + self.assertEqual(url, cli._url.geturl()) + self.assertIs(None, cli._conn) + + def test_parse_url(self): + cli = http_client.HTTPClientBase( + '%s://127.0.0.1:12345/' % self.url_scheme) + self.assertEqual(self.url_scheme, cli._url.scheme) + self.assertEqual('127.0.0.1', cli._url.hostname) + self.assertEqual(12345, cli._url.port) + self.assertEqual('/', cli._url.path) + + def test__ensure_connection(self): + cli = self.getClient() + self.assertIs(None, cli._conn) + cli._ensure_connection() + self.assertIsNot(None, cli._conn) + conn = cli._conn + cli._ensure_connection() + self.assertIs(conn, cli._conn) + + def test_close(self): + cli = self.getClient() + cli._ensure_connection() + cli.close() + self.assertIs(None, cli._conn) + + def test__request(self): + cli = self.getClient() + res, headers = cli._request('PUT', ['echo'], {}, {}) + self.assertEqual({'CONTENT_TYPE': 'application/json', + 'PATH_INFO': '/dbase/echo', + 'QUERY_STRING': '', + 'body': '{}', + 'REQUEST_METHOD': 'PUT'}, json.loads(res)) + + res, headers = cli._request('GET', ['doc', 'echo'], {'a': 1}) + self.assertEqual({'PATH_INFO': '/dbase/doc/echo', + 'QUERY_STRING': 'a=1', + 'REQUEST_METHOD': 'GET'}, json.loads(res)) + + res, headers = cli._request('GET', ['doc', '%FFFF', 'echo'], {'a': 1}) + self.assertEqual({'PATH_INFO': '/dbase/doc/%FFFF/echo', + 'QUERY_STRING': 'a=1', + 'REQUEST_METHOD': 'GET'}, json.loads(res)) + + res, headers = cli._request('POST', ['echo'], {'b': 2}, 'Body', + 'application/x-test') + self.assertEqual({'CONTENT_TYPE': 'application/x-test', + 'PATH_INFO': '/dbase/echo', + 'QUERY_STRING': 'b=2', + 'body': 'Body', + 'REQUEST_METHOD': 'POST'}, json.loads(res)) + + def test__request_json(self): + cli = self.getClient() + res, headers = cli._request_json( + 'POST', ['echo'], {'b': 2}, {'a': 'x'}) + self.assertEqual('application/json', headers['content-type']) + self.assertEqual({'CONTENT_TYPE': 'application/json', + 'PATH_INFO': '/dbase/echo', + 'QUERY_STRING': 'b=2', + 'body': '{"a": "x"}', + 'REQUEST_METHOD': 'POST'}, res) + + def test_unspecified_http_error(self): + cli = self.getClient() + self.assertRaises(errors.HTTPError, + cli._request_json, 'POST', ['error'], {}, + {'status': "500 Internal Error", + 'response': "Crash."}) + try: + cli._request_json('POST', ['error'], {}, + {'status': "500 Internal Error", + 'response': "Fail."}) + except errors.HTTPError, e: + pass + + self.assertEqual(500, e.status) + self.assertEqual("Fail.", e.message) + self.assertTrue("content-type" in e.headers) + + def test_revision_conflict(self): + cli = self.getClient() + self.assertRaises(errors.RevisionConflict, + cli._request_json, 'POST', ['error'], {}, + {'status': "409 Conflict", + 'response': {"error": "revision conflict"}}) + + def test_unavailable_proper(self): + cli = self.getClient() + cli._delays = (0, 0, 0, 0, 0) + self.assertRaises(errors.Unavailable, + cli._request_json, 'POST', ['error'], {}, + {'status': "503 Service Unavailable", + 'response': {"error": "unavailable"}}) + self.assertEqual(5, self.errors) + + def test_unavailable_then_available(self): + cli = self.getClient() + cli._delays = (0, 0, 0, 0, 0) + res, headers = cli._request_json( + 'POST', ['error_then_accept'], {'b': 2}, + {'status': "503 Service Unavailable", + 'response': {"error": "unavailable"}}) + self.assertEqual('application/json', headers['content-type']) + self.assertEqual({'CONTENT_TYPE': 'application/json', + 'PATH_INFO': '/dbase/error_then_accept', + 'QUERY_STRING': 'b=2', + 'body': '{"oki": "doki"}', + 'REQUEST_METHOD': 'POST'}, res) + self.assertEqual(3, self.errors) + + def test_unavailable_random_source(self): + cli = self.getClient() + cli._delays = (0, 0, 0, 0, 0) + try: + cli._request_json('POST', ['error'], {}, + {'status': "503 Service Unavailable", + 'response': "random unavailable."}) + except errors.Unavailable, e: + pass + + self.assertEqual(503, e.status) + self.assertEqual("random unavailable.", e.message) + self.assertTrue("content-type" in e.headers) + self.assertEqual(5, self.errors) + + def test_document_too_big(self): + cli = self.getClient() + self.assertRaises(errors.DocumentTooBig, + cli._request_json, 'POST', ['error'], {}, + {'status': "403 Forbidden", + 'response': {"error": "document too big"}}) + + def test_user_quota_exceeded(self): + cli = self.getClient() + self.assertRaises(errors.UserQuotaExceeded, + cli._request_json, 'POST', ['error'], {}, + {'status': "403 Forbidden", + 'response': {"error": "user quota exceeded"}}) + + def test_user_needs_subscription(self): + cli = self.getClient() + self.assertRaises(errors.SubscriptionNeeded, + cli._request_json, 'POST', ['error'], {}, + {'status': "403 Forbidden", + 'response': {"error": "user needs subscription"}}) + + def test_generic_u1db_error(self): + cli = self.getClient() + self.assertRaises(errors.U1DBError, + cli._request_json, 'POST', ['error'], {}, + {'status': "400 Bad Request", + 'response': {"error": "error"}}) + try: + cli._request_json('POST', ['error'], {}, + {'status': "400 Bad Request", + 'response': {"error": "error"}}) + except errors.U1DBError, e: + pass + self.assertIs(e.__class__, errors.U1DBError) + + def test_unspecified_bad_request(self): + cli = self.getClient() + self.assertRaises(errors.HTTPError, + cli._request_json, 'POST', ['error'], {}, + {'status': "400 Bad Request", + 'response': "<Bad Request>"}) + try: + cli._request_json('POST', ['error'], {}, + {'status': "400 Bad Request", + 'response': "<Bad Request>"}) + except errors.HTTPError, e: + pass + + self.assertEqual(400, e.status) + self.assertEqual("<Bad Request>", e.message) + self.assertTrue("content-type" in e.headers) + + def test_oauth(self): + cli = self.getClient() + cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + params = {'x': u'\xf0', 'y': "foo"} + res, headers = cli._request('GET', ['doc', 'oauth'], params) + self.assertEqual( + ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res)) + + # oauth does its own internal quoting + params = {'x': u'\xf0', 'y': "foo"} + res, headers = cli._request('GET', ['doc', 'oauth', 'foo bar'], params) + self.assertEqual( + ['/dbase/doc/oauth/foo bar', tests.token1.key, params], + json.loads(res)) + + def test_oauth_ctr_creds(self): + cli = self.getClient(creds={'oauth': { + 'consumer_key': tests.consumer1.key, + 'consumer_secret': tests.consumer1.secret, + 'token_key': tests.token1.key, + 'token_secret': tests.token1.secret, + }}) + params = {'x': u'\xf0', 'y': "foo"} + res, headers = cli._request('GET', ['doc', 'oauth'], params) + self.assertEqual( + ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res)) + + def test_unknown_creds(self): + self.assertRaises(errors.UnknownAuthMethod, + self.getClient, creds={'foo': {}}) + self.assertRaises(errors.UnknownAuthMethod, + self.getClient, creds={}) + + def test_oauth_Unauthorized(self): + cli = self.getClient() + cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, "WRONG") + params = {'y': 'foo'} + self.assertRaises(errors.Unauthorized, cli._request, 'GET', + ['doc', 'oauth'], params) diff --git a/src/leap/soledad/u1db/tests/test_http_database.py b/src/leap/soledad/u1db/tests/test_http_database.py new file mode 100644 index 00000000..c8e7eb76 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_http_database.py @@ -0,0 +1,256 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Tests for HTTPDatabase""" + +import inspect +try: + import simplejson as json +except ImportError: + import json # noqa + +from u1db import ( + errors, + Document, + tests, + ) +from u1db.remote import ( + http_database, + http_target, + ) +from u1db.tests.test_remote_sync_target import ( + make_http_app, +) + + +class TestHTTPDatabaseSimpleOperations(tests.TestCase): + + def setUp(self): + super(TestHTTPDatabaseSimpleOperations, self).setUp() + self.db = http_database.HTTPDatabase('dbase') + self.db._conn = object() # crash if used + self.got = None + self.response_val = None + + def _request(method, url_parts, params=None, body=None, + content_type=None): + self.got = method, url_parts, params, body, content_type + if isinstance(self.response_val, Exception): + raise self.response_val + return self.response_val + + def _request_json(method, url_parts, params=None, body=None, + content_type=None): + self.got = method, url_parts, params, body, content_type + if isinstance(self.response_val, Exception): + raise self.response_val + return self.response_val + + self.db._request = _request + self.db._request_json = _request_json + + def test__sanity_same_signature(self): + my_request_sig = inspect.getargspec(self.db._request) + my_request_sig = (['self'] + my_request_sig[0],) + my_request_sig[1:] + self.assertEqual(my_request_sig, + inspect.getargspec(http_database.HTTPDatabase._request)) + my_request_json_sig = inspect.getargspec(self.db._request_json) + my_request_json_sig = ((['self'] + my_request_json_sig[0],) + + my_request_json_sig[1:]) + self.assertEqual(my_request_json_sig, + inspect.getargspec(http_database.HTTPDatabase._request_json)) + + def test__ensure(self): + self.response_val = {'ok': True}, {} + self.db._ensure() + self.assertEqual(('PUT', [], {}, {}, None), self.got) + + def test__delete(self): + self.response_val = {'ok': True}, {} + self.db._delete() + self.assertEqual(('DELETE', [], {}, {}, None), self.got) + + def test__check(self): + self.response_val = {}, {} + res = self.db._check() + self.assertEqual({}, res) + self.assertEqual(('GET', [], None, None, None), self.got) + + def test_put_doc(self): + self.response_val = {'rev': 'doc-rev'}, {} + doc = Document('doc-id', None, '{"v": 1}') + res = self.db.put_doc(doc) + self.assertEqual('doc-rev', res) + self.assertEqual('doc-rev', doc.rev) + self.assertEqual(('PUT', ['doc', 'doc-id'], {}, + '{"v": 1}', 'application/json'), self.got) + + self.response_val = {'rev': 'doc-rev-2'}, {} + doc.content = {"v": 2} + res = self.db.put_doc(doc) + self.assertEqual('doc-rev-2', res) + self.assertEqual('doc-rev-2', doc.rev) + self.assertEqual(('PUT', ['doc', 'doc-id'], {'old_rev': 'doc-rev'}, + '{"v": 2}', 'application/json'), self.got) + + def test_get_doc(self): + self.response_val = '{"v": 2}', {'x-u1db-rev': 'doc-rev', + 'x-u1db-has-conflicts': 'false'} + self.assertGetDoc(self.db, 'doc-id', 'doc-rev', '{"v": 2}', False) + self.assertEqual( + ('GET', ['doc', 'doc-id'], {'include_deleted': False}, None, None), + self.got) + + def test_get_doc_non_existing(self): + self.response_val = errors.DocumentDoesNotExist() + self.assertIs(None, self.db.get_doc('not-there')) + self.assertEqual( + ('GET', ['doc', 'not-there'], {'include_deleted': False}, None, + None), self.got) + + def test_get_doc_deleted(self): + self.response_val = errors.DocumentDoesNotExist() + self.assertIs(None, self.db.get_doc('deleted')) + self.assertEqual( + ('GET', ['doc', 'deleted'], {'include_deleted': False}, None, + None), self.got) + + def test_get_doc_deleted_include_deleted(self): + self.response_val = errors.HTTPError(404, + json.dumps( + {"error": errors.DOCUMENT_DELETED} + ), + {'x-u1db-rev': 'doc-rev-gone', + 'x-u1db-has-conflicts': 'false'}) + doc = self.db.get_doc('deleted', include_deleted=True) + self.assertEqual('deleted', doc.doc_id) + self.assertEqual('doc-rev-gone', doc.rev) + self.assertIs(None, doc.content) + self.assertEqual( + ('GET', ['doc', 'deleted'], {'include_deleted': True}, None, None), + self.got) + + def test_get_doc_pass_through_errors(self): + self.response_val = errors.HTTPError(500, 'Crash.') + self.assertRaises(errors.HTTPError, + self.db.get_doc, 'something-something') + + def test_create_doc_with_id(self): + self.response_val = {'rev': 'doc-rev'}, {} + new_doc = self.db.create_doc_from_json('{"v": 1}', doc_id='doc-id') + self.assertEqual('doc-rev', new_doc.rev) + self.assertEqual('doc-id', new_doc.doc_id) + self.assertEqual('{"v": 1}', new_doc.get_json()) + self.assertEqual(('PUT', ['doc', 'doc-id'], {}, + '{"v": 1}', 'application/json'), self.got) + + def test_create_doc_without_id(self): + self.response_val = {'rev': 'doc-rev-2'}, {} + new_doc = self.db.create_doc_from_json('{"v": 3}') + self.assertEqual('D-', new_doc.doc_id[:2]) + self.assertEqual('doc-rev-2', new_doc.rev) + self.assertEqual('{"v": 3}', new_doc.get_json()) + self.assertEqual(('PUT', ['doc', new_doc.doc_id], {}, + '{"v": 3}', 'application/json'), self.got) + + def test_delete_doc(self): + self.response_val = {'rev': 'doc-rev-gone'}, {} + doc = Document('doc-id', 'doc-rev', None) + self.db.delete_doc(doc) + self.assertEqual('doc-rev-gone', doc.rev) + self.assertEqual(('DELETE', ['doc', 'doc-id'], {'old_rev': 'doc-rev'}, + None, None), self.got) + + def test_get_sync_target(self): + st = self.db.get_sync_target() + self.assertIsInstance(st, http_target.HTTPSyncTarget) + self.assertEqual(st._url, self.db._url) + + def test_get_sync_target_inherits_oauth_credentials(self): + self.db.set_oauth_credentials(tests.consumer1.key, + tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + st = self.db.get_sync_target() + self.assertEqual(self.db._creds, st._creds) + + +class TestHTTPDatabaseCtrWithCreds(tests.TestCase): + + def test_ctr_with_creds(self): + db1 = http_database.HTTPDatabase('http://dbs/db', creds={'oauth': { + 'consumer_key': tests.consumer1.key, + 'consumer_secret': tests.consumer1.secret, + 'token_key': tests.token1.key, + 'token_secret': tests.token1.secret + }}) + self.assertIn('oauth', db1._creds) + + +class TestHTTPDatabaseIntegration(tests.TestCaseWithServer): + + make_app_with_state = staticmethod(make_http_app) + + def setUp(self): + super(TestHTTPDatabaseIntegration, self).setUp() + self.startServer() + + def test_non_existing_db(self): + db = http_database.HTTPDatabase(self.getURL('not-there')) + self.assertRaises(errors.DatabaseDoesNotExist, db.get_doc, 'doc1') + + def test__ensure(self): + db = http_database.HTTPDatabase(self.getURL('new')) + db._ensure() + self.assertIs(None, db.get_doc('doc1')) + + def test__delete(self): + self.request_state._create_database('db0') + db = http_database.HTTPDatabase(self.getURL('db0')) + db._delete() + self.assertRaises(errors.DatabaseDoesNotExist, + self.request_state.check_database, 'db0') + + def test_open_database_existing(self): + self.request_state._create_database('db0') + db = http_database.HTTPDatabase.open_database(self.getURL('db0'), + create=False) + self.assertIs(None, db.get_doc('doc1')) + + def test_open_database_non_existing(self): + self.assertRaises(errors.DatabaseDoesNotExist, + http_database.HTTPDatabase.open_database, + self.getURL('not-there'), + create=False) + + def test_open_database_create(self): + db = http_database.HTTPDatabase.open_database(self.getURL('new'), + create=True) + self.assertIs(None, db.get_doc('doc1')) + + def test_delete_database_existing(self): + self.request_state._create_database('db0') + http_database.HTTPDatabase.delete_database(self.getURL('db0')) + self.assertRaises(errors.DatabaseDoesNotExist, + self.request_state.check_database, 'db0') + + def test_doc_ids_needing_quoting(self): + db0 = self.request_state._create_database('db0') + db = http_database.HTTPDatabase.open_database(self.getURL('db0'), + create=False) + doc = Document('%fff', None, '{}') + db.put_doc(doc) + self.assertGetDoc(db0, '%fff', doc.rev, '{}', False) + self.assertGetDoc(db, '%fff', doc.rev, '{}', False) diff --git a/src/leap/soledad/u1db/tests/test_https.py b/src/leap/soledad/u1db/tests/test_https.py new file mode 100644 index 00000000..67681c8a --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_https.py @@ -0,0 +1,117 @@ +"""Test support for client-side https support.""" + +import os +import ssl +import sys + +from paste import httpserver + +from u1db import ( + tests, + ) +from u1db.remote import ( + http_client, + http_target, + ) + +from u1db.tests.test_remote_sync_target import ( + make_oauth_http_app, + ) + + +def https_server_def(): + def make_server(host_port, application): + from OpenSSL import SSL + cert_file = os.path.join(os.path.dirname(__file__), 'testing-certs', + 'testing.cert') + key_file = os.path.join(os.path.dirname(__file__), 'testing-certs', + 'testing.key') + ssl_context = SSL.Context(SSL.SSLv23_METHOD) + ssl_context.use_privatekey_file(key_file) + ssl_context.use_certificate_chain_file(cert_file) + srv = httpserver.WSGIServerBase(application, host_port, + httpserver.WSGIHandler, + ssl_context=ssl_context + ) + + def shutdown_request(req): + req.shutdown() + srv.close_request(req) + + srv.shutdown_request = shutdown_request + application.base_url = "https://localhost:%s" % srv.server_address[1] + return srv + return make_server, "shutdown", "https" + + +def oauth_https_sync_target(test, host, path): + _, port = test.server.server_address + st = http_target.HTTPSyncTarget('https://%s:%d/~/%s' % (host, port, path)) + st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return st + + +class TestHttpSyncTargetHttpsSupport(tests.TestCaseWithServer): + + scenarios = [ + ('oauth_https', {'server_def': https_server_def, + 'make_app_with_state': make_oauth_http_app, + 'make_document_for_test': tests.make_document_for_test, + 'sync_target': oauth_https_sync_target + }), + ] + + def setUp(self): + try: + import OpenSSL # noqa + except ImportError: + self.skipTest("Requires pyOpenSSL") + self.cacert_pem = os.path.join(os.path.dirname(__file__), + 'testing-certs', 'cacert.pem') + super(TestHttpSyncTargetHttpsSupport, self).setUp() + + def getSyncTarget(self, host, path=None): + if self.server is None: + self.startServer() + return self.sync_target(self, host, path) + + def test_working(self): + self.startServer() + db = self.request_state._create_database('test') + self.patch(http_client, 'CA_CERTS', self.cacert_pem) + remote_target = self.getSyncTarget('localhost', 'test') + remote_target.record_sync_info('other-id', 2, 'T-id') + self.assertEqual( + (2, 'T-id'), db._get_replica_gen_and_trans_id('other-id')) + + def test_cannot_verify_cert(self): + if not sys.platform.startswith('linux'): + self.skipTest( + "XXX certificate verification happens on linux only for now") + self.startServer() + # don't print expected traceback server-side + self.server.handle_error = lambda req, cli_addr: None + self.request_state._create_database('test') + remote_target = self.getSyncTarget('localhost', 'test') + try: + remote_target.record_sync_info('other-id', 2, 'T-id') + except ssl.SSLError, e: + self.assertIn("certificate verify failed", str(e)) + else: + self.fail("certificate verification should have failed.") + + def test_host_mismatch(self): + if not sys.platform.startswith('linux'): + self.skipTest( + "XXX certificate verification happens on linux only for now") + self.startServer() + self.request_state._create_database('test') + self.patch(http_client, 'CA_CERTS', self.cacert_pem) + remote_target = self.getSyncTarget('127.0.0.1', 'test') + self.assertRaises( + http_client.CertificateError, remote_target.record_sync_info, + 'other-id', 2, 'T-id') + + +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_inmemory.py b/src/leap/soledad/u1db/tests/test_inmemory.py new file mode 100644 index 00000000..255a1e08 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_inmemory.py @@ -0,0 +1,128 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Test in-memory backend internals.""" + +from u1db import ( + errors, + tests, + ) +from u1db.backends import inmemory + + +simple_doc = '{"key": "value"}' + + +class TestInMemoryDatabaseInternals(tests.TestCase): + + def setUp(self): + super(TestInMemoryDatabaseInternals, self).setUp() + self.db = inmemory.InMemoryDatabase('test') + + def test__allocate_doc_rev_from_None(self): + self.assertEqual('test:1', self.db._allocate_doc_rev(None)) + + def test__allocate_doc_rev_incremental(self): + self.assertEqual('test:2', self.db._allocate_doc_rev('test:1')) + + def test__allocate_doc_rev_other(self): + self.assertEqual('replica:1|test:1', + self.db._allocate_doc_rev('replica:1')) + + def test__get_replica_uid(self): + self.assertEqual('test', self.db._replica_uid) + + +class TestInMemoryIndex(tests.TestCase): + + def test_has_name_and_definition(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + self.assertEqual('idx-name', idx._name) + self.assertEqual(['key'], idx._definition) + + def test_evaluate_json(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + self.assertEqual(['value'], idx.evaluate_json(simple_doc)) + + def test_evaluate_json_field_None(self): + idx = inmemory.InMemoryIndex('idx-name', ['missing']) + self.assertEqual([], idx.evaluate_json(simple_doc)) + + def test_evaluate_json_subfield_None(self): + idx = inmemory.InMemoryIndex('idx-name', ['key', 'missing']) + self.assertEqual([], idx.evaluate_json(simple_doc)) + + def test_evaluate_multi_index(self): + doc = '{"key": "value", "key2": "value2"}' + idx = inmemory.InMemoryIndex('idx-name', ['key', 'key2']) + self.assertEqual(['value\x01value2'], + idx.evaluate_json(doc)) + + def test_update_ignores_None(self): + idx = inmemory.InMemoryIndex('idx-name', ['nokey']) + idx.add_json('doc-id', simple_doc) + self.assertEqual({}, idx._values) + + def test_update_adds_entry(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + self.assertEqual({'value': ['doc-id']}, idx._values) + + def test_remove_json(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + self.assertEqual({'value': ['doc-id']}, idx._values) + idx.remove_json('doc-id', simple_doc) + self.assertEqual({}, idx._values) + + def test_remove_json_multiple(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + idx.add_json('doc2-id', simple_doc) + self.assertEqual({'value': ['doc-id', 'doc2-id']}, idx._values) + idx.remove_json('doc-id', simple_doc) + self.assertEqual({'value': ['doc2-id']}, idx._values) + + def test_keys(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + self.assertEqual(['value'], idx.keys()) + + def test_lookup(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + self.assertEqual(['doc-id'], idx.lookup(['value'])) + + def test_lookup_multi(self): + idx = inmemory.InMemoryIndex('idx-name', ['key']) + idx.add_json('doc-id', simple_doc) + idx.add_json('doc2-id', simple_doc) + self.assertEqual(['doc-id', 'doc2-id'], idx.lookup(['value'])) + + def test__find_non_wildcards(self): + idx = inmemory.InMemoryIndex('idx-name', ['k1', 'k2', 'k3']) + self.assertEqual(-1, idx._find_non_wildcards(('a', 'b', 'c'))) + self.assertEqual(2, idx._find_non_wildcards(('a', 'b', '*'))) + self.assertEqual(3, idx._find_non_wildcards(('a', 'b', 'c*'))) + self.assertEqual(2, idx._find_non_wildcards(('a', 'b*', '*'))) + self.assertEqual(0, idx._find_non_wildcards(('*', '*', '*'))) + self.assertEqual(1, idx._find_non_wildcards(('a*', '*', '*'))) + self.assertRaises(errors.InvalidValueForIndex, + idx._find_non_wildcards, ('a', 'b')) + self.assertRaises(errors.InvalidValueForIndex, + idx._find_non_wildcards, ('a', 'b', 'c', 'd')) + self.assertRaises(errors.InvalidGlobbing, + idx._find_non_wildcards, ('*', 'b', 'c')) diff --git a/src/leap/soledad/u1db/tests/test_open.py b/src/leap/soledad/u1db/tests/test_open.py new file mode 100644 index 00000000..fbeb0cfd --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_open.py @@ -0,0 +1,69 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Test u1db.open""" + +import os + +from u1db import ( + errors, + open as u1db_open, + tests, + ) +from u1db.backends import sqlite_backend +from u1db.tests.test_backends import TestAlternativeDocument + + +class TestU1DBOpen(tests.TestCase): + + def setUp(self): + super(TestU1DBOpen, self).setUp() + tmpdir = self.createTempDir() + self.db_path = tmpdir + '/test.db' + + def test_open_no_create(self): + self.assertRaises(errors.DatabaseDoesNotExist, + u1db_open, self.db_path, create=False) + self.assertFalse(os.path.exists(self.db_path)) + + def test_open_create(self): + db = u1db_open(self.db_path, create=True) + self.addCleanup(db.close) + self.assertTrue(os.path.exists(self.db_path)) + self.assertIsInstance(db, sqlite_backend.SQLiteDatabase) + + def test_open_with_factory(self): + db = u1db_open(self.db_path, create=True, + document_factory=TestAlternativeDocument) + self.addCleanup(db.close) + self.assertEqual(TestAlternativeDocument, db._factory) + + def test_open_existing(self): + db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path) + self.addCleanup(db.close) + doc = db.create_doc_from_json(tests.simple_doc) + # Even though create=True, we shouldn't wipe the db + db2 = u1db_open(self.db_path, create=True) + self.addCleanup(db2.close) + doc2 = db2.get_doc(doc.doc_id) + self.assertEqual(doc, doc2) + + def test_open_existing_no_create(self): + db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path) + self.addCleanup(db.close) + db2 = u1db_open(self.db_path, create=False) + self.addCleanup(db2.close) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) diff --git a/src/leap/soledad/u1db/tests/test_query_parser.py b/src/leap/soledad/u1db/tests/test_query_parser.py new file mode 100644 index 00000000..ee374267 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_query_parser.py @@ -0,0 +1,443 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +from u1db import ( + errors, + query_parser, + tests, + ) + + +trivial_raw_doc = {} + + +class TestFieldName(tests.TestCase): + + def test_check_fieldname_valid(self): + self.assertIsNone(query_parser.check_fieldname("foo")) + + def test_check_fieldname_invalid(self): + self.assertRaises( + errors.IndexDefinitionParseError, query_parser.check_fieldname, + "foo.") + + +class TestMakeTree(tests.TestCase): + + def setUp(self): + super(TestMakeTree, self).setUp() + self.parser = query_parser.Parser() + + def assertParseError(self, definition): + self.assertRaises( + errors.IndexDefinitionParseError, self.parser.parse, + definition) + + def test_single_field(self): + self.assertIsInstance( + self.parser.parse('f'), query_parser.ExtractField) + + def test_single_mapping(self): + self.assertIsInstance( + self.parser.parse('bool(field1)'), query_parser.Bool) + + def test_nested_mapping(self): + self.assertIsInstance( + self.parser.parse('lower(split_words(field1))'), + query_parser.Lower) + + def test_nested_branching_mapping(self): + self.assertIsInstance( + self.parser.parse( + 'combine(lower(field1), split_words(field2), ' + 'number(field3, 5))'), query_parser.Combine) + + def test_single_mapping_multiple_fields(self): + self.assertIsInstance( + self.parser.parse('number(field1, 5)'), query_parser.Number) + + def test_unknown_mapping(self): + self.assertParseError('mapping(whatever)') + + def test_parse_missing_close_paren(self): + self.assertParseError("lower(a") + + def test_parse_trailing_chars(self): + self.assertParseError("lower(ab))") + + def test_parse_empty_op(self): + self.assertParseError("(ab)") + + def test_parse_top_level_commas(self): + self.assertParseError("a, b") + + def test_invalid_field_name(self): + self.assertParseError("a.") + + def test_invalid_inner_field_name(self): + self.assertParseError("lower(a.)") + + def test_gobbledigook(self): + self.assertParseError("(@#@cc @#!*DFJSXV(()jccd") + + def test_leading_space(self): + self.assertIsInstance( + self.parser.parse(" lower(a)"), query_parser.Lower) + + def test_trailing_space(self): + self.assertIsInstance( + self.parser.parse("lower(a) "), query_parser.Lower) + + def test_spaces_before_open_paren(self): + self.assertIsInstance( + self.parser.parse("lower (a)"), query_parser.Lower) + + def test_spaces_after_open_paren(self): + self.assertIsInstance( + self.parser.parse("lower( a)"), query_parser.Lower) + + def test_spaces_before_close_paren(self): + self.assertIsInstance( + self.parser.parse("lower(a )"), query_parser.Lower) + + def test_spaces_before_comma(self): + self.assertIsInstance( + self.parser.parse("number(a , 5)"), query_parser.Number) + + def test_spaces_after_comma(self): + self.assertIsInstance( + self.parser.parse("number(a, 5)"), query_parser.Number) + + +class TestStaticGetter(tests.TestCase): + + def test_returns_string(self): + getter = query_parser.StaticGetter('foo') + self.assertEqual(['foo'], getter.get(trivial_raw_doc)) + + def test_returns_int(self): + getter = query_parser.StaticGetter(9) + self.assertEqual([9], getter.get(trivial_raw_doc)) + + def test_returns_float(self): + getter = query_parser.StaticGetter(9.2) + self.assertEqual([9.2], getter.get(trivial_raw_doc)) + + def test_returns_None(self): + getter = query_parser.StaticGetter(None) + self.assertEqual([], getter.get(trivial_raw_doc)) + + def test_returns_list(self): + getter = query_parser.StaticGetter(['a', 'b']) + self.assertEqual(['a', 'b'], getter.get(trivial_raw_doc)) + + +class TestExtractField(tests.TestCase): + + def assertExtractField(self, expected, field_name, raw_doc): + getter = query_parser.ExtractField(field_name) + self.assertEqual(expected, getter.get(raw_doc)) + + def test_get_value(self): + self.assertExtractField(['bar'], 'foo', {'foo': 'bar'}) + + def test_get_value_None(self): + self.assertExtractField([], 'foo', {'foo': None}) + + def test_get_value_missing_key(self): + self.assertExtractField([], 'foo', {}) + + def test_get_value_subfield(self): + self.assertExtractField(['bar'], 'foo.baz', {'foo': {'baz': 'bar'}}) + + def test_get_value_subfield_missing(self): + self.assertExtractField([], 'foo.baz', {'foo': 'bar'}) + + def test_get_value_dict(self): + self.assertExtractField([], 'foo', {'foo': {'baz': 'bar'}}) + + def test_get_value_list(self): + self.assertExtractField(['bar', 'zap'], 'foo', {'foo': ['bar', 'zap']}) + + def test_get_value_mixed_list(self): + self.assertExtractField(['bar', 'zap'], 'foo', + {'foo': ['bar', ['baa'], 'zap', {'bing': 9}]}) + + def test_get_value_list_of_dicts(self): + self.assertExtractField([], 'foo', {'foo': [{'zap': 'bar'}]}) + + def test_get_value_list_of_dicts2(self): + self.assertExtractField( + ['bar', 'baz'], 'foo.zap', + {'foo': [{'zap': 'bar'}, {'zap': 'baz'}]}) + + def test_get_value_int(self): + self.assertExtractField([9], 'foo', {'foo': 9}) + + def test_get_value_float(self): + self.assertExtractField([9.2], 'foo', {'foo': 9.2}) + + def test_get_value_bool(self): + self.assertExtractField([True], 'foo', {'foo': True}) + self.assertExtractField([False], 'foo', {'foo': False}) + + +class TestLower(tests.TestCase): + + def assertLowerGets(self, expected, input_val): + getter = query_parser.Lower(query_parser.StaticGetter(input_val)) + out_val = getter.get(trivial_raw_doc) + self.assertEqual(sorted(expected), sorted(out_val)) + + def test_inner_returns_None(self): + self.assertLowerGets([], None) + + def test_inner_returns_string(self): + self.assertLowerGets(['foo'], 'fOo') + + def test_inner_returns_list(self): + self.assertLowerGets(['foo', 'bar'], ['fOo', 'bAr']) + + def test_inner_returns_int(self): + self.assertLowerGets([], 9) + + def test_inner_returns_float(self): + self.assertLowerGets([], 9.0) + + def test_inner_returns_bool(self): + self.assertLowerGets([], True) + + def test_inner_returns_list_containing_int(self): + self.assertLowerGets(['foo', 'bar'], ['fOo', 9, 'bAr']) + + def test_inner_returns_list_containing_float(self): + self.assertLowerGets(['foo', 'bar'], ['fOo', 9.2, 'bAr']) + + def test_inner_returns_list_containing_bool(self): + self.assertLowerGets(['foo', 'bar'], ['fOo', True, 'bAr']) + + def test_inner_returns_list_containing_list(self): + # TODO: Should this be unfolding the inner list? + self.assertLowerGets(['foo', 'bar'], ['fOo', ['bAa'], 'bAr']) + + def test_inner_returns_list_containing_dict(self): + self.assertLowerGets(['foo', 'bar'], ['fOo', {'baa': 'xam'}, 'bAr']) + + +class TestSplitWords(tests.TestCase): + + def assertSplitWords(self, expected, value): + getter = query_parser.SplitWords(query_parser.StaticGetter(value)) + self.assertEqual(sorted(expected), sorted(getter.get(trivial_raw_doc))) + + def test_inner_returns_None(self): + self.assertSplitWords([], None) + + def test_inner_returns_string(self): + self.assertSplitWords(['foo', 'bar'], 'foo bar') + + def test_inner_returns_list(self): + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo baz', 'bar sux']) + + def test_deduplicates(self): + self.assertSplitWords(['bar'], ['bar', 'bar', 'bar']) + + def test_inner_returns_int(self): + self.assertSplitWords([], 9) + + def test_inner_returns_float(self): + self.assertSplitWords([], 9.2) + + def test_inner_returns_bool(self): + self.assertSplitWords([], True) + + def test_inner_returns_list_containing_int(self): + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo baz', 9, 'bar sux']) + + def test_inner_returns_list_containing_float(self): + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo baz', 9.2, 'bar sux']) + + def test_inner_returns_list_containing_bool(self): + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo baz', True, 'bar sux']) + + def test_inner_returns_list_containing_list(self): + # TODO: Expand sub-lists? + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo baz', ['baa'], 'bar sux']) + + def test_inner_returns_list_containing_dict(self): + self.assertSplitWords(['foo', 'baz', 'bar', 'sux'], + ['foo baz', {'baa': 'xam'}, 'bar sux']) + + +class TestNumber(tests.TestCase): + + def assertNumber(self, expected, value, padding=5): + """Assert number transformation produced expected values.""" + getter = query_parser.Number(query_parser.StaticGetter(value), padding) + self.assertEqual(expected, getter.get(trivial_raw_doc)) + + def test_inner_returns_None(self): + """None is thrown away.""" + self.assertNumber([], None) + + def test_inner_returns_int(self): + """A single integer is converted to zero padded strings.""" + self.assertNumber(['00009'], 9) + + def test_inner_returns_list(self): + """Integers are converted to zero padded strings.""" + self.assertNumber(['00009', '00235'], [9, 235]) + + def test_inner_returns_string(self): + """A string is thrown away.""" + self.assertNumber([], 'foo bar') + + def test_inner_returns_float(self): + """A float is thrown away.""" + self.assertNumber([], 9.2) + + def test_inner_returns_bool(self): + """A boolean is thrown away.""" + self.assertNumber([], True) + + def test_inner_returns_list_containing_strings(self): + """Strings in a list are thrown away.""" + self.assertNumber(['00009'], ['foo baz', 9, 'bar sux']) + + def test_inner_returns_list_containing_float(self): + """Floats in a list are thrown away.""" + self.assertNumber( + ['00083', '00073'], [83, 9.2, 73]) + + def test_inner_returns_list_containing_bool(self): + """Booleans in a list are thrown away.""" + self.assertNumber( + ['00083', '00073'], [83, True, 73]) + + def test_inner_returns_list_containing_list(self): + """Lists in a list are thrown away.""" + # TODO: Expand sub-lists? + self.assertNumber( + ['00012', '03333'], [12, [29], 3333]) + + def test_inner_returns_list_containing_dict(self): + """Dicts in a list are thrown away.""" + self.assertNumber( + ['00012', '00001'], [12, {54: 89}, 1]) + + +class TestIsNull(tests.TestCase): + + def assertIsNull(self, value): + getter = query_parser.IsNull(query_parser.StaticGetter(value)) + self.assertEqual([True], getter.get(trivial_raw_doc)) + + def assertIsNotNull(self, value): + getter = query_parser.IsNull(query_parser.StaticGetter(value)) + self.assertEqual([False], getter.get(trivial_raw_doc)) + + def test_inner_returns_None(self): + self.assertIsNull(None) + + def test_inner_returns_string(self): + self.assertIsNotNull('foo') + + def test_inner_returns_list(self): + self.assertIsNotNull(['foo', 'bar']) + + def test_inner_returns_empty_list(self): + # TODO: is this the behavior we want? + self.assertIsNull([]) + + def test_inner_returns_int(self): + self.assertIsNotNull(9) + + def test_inner_returns_float(self): + self.assertIsNotNull(9.2) + + def test_inner_returns_bool(self): + self.assertIsNotNull(True) + + # TODO: What about a dict? Inner is likely to return None, even though the + # attribute does exist... + + +class TestParser(tests.TestCase): + + def parse(self, spec): + parser = query_parser.Parser() + return parser.parse(spec) + + def parse_all(self, specs): + parser = query_parser.Parser() + return parser.parse_all(specs) + + def assertParseError(self, definition): + self.assertRaises(errors.IndexDefinitionParseError, self.parse, + definition) + + def test_parse_empty_string(self): + self.assertRaises(errors.IndexDefinitionParseError, self.parse, "") + + def test_parse_field(self): + getter = self.parse("a") + self.assertIsInstance(getter, query_parser.ExtractField) + self.assertEqual(["a"], getter.field) + + def test_parse_dotted_field(self): + getter = self.parse("a.b") + self.assertIsInstance(getter, query_parser.ExtractField) + self.assertEqual(["a", "b"], getter.field) + + def test_parse_dotted_field_nothing_after_dot(self): + self.assertParseError("a.") + + def test_parse_missing_close_on_transformation(self): + self.assertParseError("lower(a") + + def test_parse_missing_field_in_transformation(self): + self.assertParseError("lower()") + + def test_parse_trailing_chars(self): + self.assertParseError("lower(ab))") + + def test_parse_empty_op(self): + self.assertParseError("(ab)") + + def test_parse_unknown_op(self): + self.assertParseError("no_such_operation(field)") + + def test_parse_wrong_arg_type(self): + self.assertParseError("number(field, fnord)") + + def test_parse_transformation(self): + getter = self.parse("lower(a)") + self.assertIsInstance(getter, query_parser.Lower) + self.assertIsInstance(getter.inner, query_parser.ExtractField) + self.assertEqual(["a"], getter.inner.field) + + def test_parse_all(self): + getters = self.parse_all(["a", "b"]) + self.assertEqual(2, len(getters)) + self.assertIsInstance(getters[0], query_parser.ExtractField) + self.assertEqual(["a"], getters[0].field) + self.assertIsInstance(getters[1], query_parser.ExtractField) + self.assertEqual(["b"], getters[1].field) diff --git a/src/leap/soledad/u1db/tests/test_remote_sync_target.py b/src/leap/soledad/u1db/tests/test_remote_sync_target.py new file mode 100644 index 00000000..3e0d8995 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_remote_sync_target.py @@ -0,0 +1,314 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Tests for the remote sync targets""" + +import cStringIO + +from u1db import ( + errors, + tests, + ) +from u1db.remote import ( + http_app, + http_target, + oauth_middleware, + ) + + +class TestHTTPSyncTargetBasics(tests.TestCase): + + def test_parse_url(self): + remote_target = http_target.HTTPSyncTarget('http://127.0.0.1:12345/') + self.assertEqual('http', remote_target._url.scheme) + self.assertEqual('127.0.0.1', remote_target._url.hostname) + self.assertEqual(12345, remote_target._url.port) + self.assertEqual('/', remote_target._url.path) + + +class TestParsingSyncStream(tests.TestCase): + + def test_wrong_start(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "{}\r\n]", None) + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "\r\n{}\r\n]", None) + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "", None) + + def test_wrong_end(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "[\r\n{}", None) + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "[\r\n", None) + + def test_missing_comma(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, + '[\r\n{}\r\n{"id": "i", "rev": "r", ' + '"content": "c", "gen": 3}\r\n]', None) + + def test_no_entries(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "[\r\n]", None) + + def test_extra_comma(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, "[\r\n{},\r\n]", None) + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, + '[\r\n{},\r\n{"id": "i", "rev": "r", ' + '"content": "{}", "gen": 3, "trans_id": "T-sid"}' + ',\r\n]', + lambda doc, gen, trans_id: None) + + def test_error_in_stream(self): + tgt = http_target.HTTPSyncTarget("http://foo/foo") + + self.assertRaises(errors.Unavailable, + tgt._parse_sync_stream, + '[\r\n{"new_generation": 0},' + '\r\n{"error": "unavailable"}\r\n', None) + + self.assertRaises(errors.Unavailable, + tgt._parse_sync_stream, + '[\r\n{"error": "unavailable"}\r\n', None) + + self.assertRaises(errors.BrokenSyncStream, + tgt._parse_sync_stream, + '[\r\n{"error": "?"}\r\n', None) + + +def make_http_app(state): + return http_app.HTTPApp(state) + + +def http_sync_target(test, path): + return http_target.HTTPSyncTarget(test.getURL(path)) + + +def make_oauth_http_app(state): + app = http_app.HTTPApp(state) + application = oauth_middleware.OAuthMiddleware(app, None, prefix='/~/') + application.get_oauth_data_store = lambda: tests.testingOAuthStore + return application + + +def oauth_http_sync_target(test, path): + st = http_sync_target(test, '~/' + path) + st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return st + + +class TestRemoteSyncTargets(tests.TestCaseWithServer): + + scenarios = [ + ('http', {'make_app_with_state': make_http_app, + 'make_document_for_test': tests.make_document_for_test, + 'sync_target': http_sync_target}), + ('oauth_http', {'make_app_with_state': make_oauth_http_app, + 'make_document_for_test': tests.make_document_for_test, + 'sync_target': oauth_http_sync_target}), + ] + + def getSyncTarget(self, path=None): + if self.server is None: + self.startServer() + return self.sync_target(self, path) + + def test_get_sync_info(self): + self.startServer() + db = self.request_state._create_database('test') + db._set_replica_gen_and_trans_id('other-id', 1, 'T-transid') + remote_target = self.getSyncTarget('test') + self.assertEqual(('test', 0, '', 1, 'T-transid'), + remote_target.get_sync_info('other-id')) + + def test_record_sync_info(self): + self.startServer() + db = self.request_state._create_database('test') + remote_target = self.getSyncTarget('test') + remote_target.record_sync_info('other-id', 2, 'T-transid') + self.assertEqual( + (2, 'T-transid'), db._get_replica_gen_and_trans_id('other-id')) + + def test_sync_exchange_send(self): + self.startServer() + db = self.request_state._create_database('test') + remote_target = self.getSyncTarget('test') + other_docs = [] + + def receive_doc(doc): + other_docs.append((doc.doc_id, doc.rev, doc.get_json())) + + doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}') + new_gen, trans_id = remote_target.sync_exchange( + [(doc, 10, 'T-sid')], 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=receive_doc) + self.assertEqual(1, new_gen) + self.assertGetDoc( + db, 'doc-here', 'replica:1', '{"value": "here"}', False) + + def test_sync_exchange_send_failure_and_retry_scenario(self): + self.startServer() + + def blackhole_getstderr(inst): + return cStringIO.StringIO() + + self.patch(self.server.RequestHandlerClass, 'get_stderr', + blackhole_getstderr) + db = self.request_state._create_database('test') + _put_doc_if_newer = db._put_doc_if_newer + trigger_ids = ['doc-here2'] + + def bomb_put_doc_if_newer(doc, save_conflict, + replica_uid=None, replica_gen=None, + replica_trans_id=None): + if doc.doc_id in trigger_ids: + raise Exception + return _put_doc_if_newer(doc, save_conflict=save_conflict, + replica_uid=replica_uid, replica_gen=replica_gen, + replica_trans_id=replica_trans_id) + self.patch(db, '_put_doc_if_newer', bomb_put_doc_if_newer) + remote_target = self.getSyncTarget('test') + other_changes = [] + + def receive_doc(doc, gen, trans_id): + other_changes.append( + (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) + + doc1 = self.make_document('doc-here', 'replica:1', '{"value": "here"}') + doc2 = self.make_document('doc-here2', 'replica:1', + '{"value": "here2"}') + self.assertRaises( + errors.HTTPError, + remote_target.sync_exchange, + [(doc1, 10, 'T-sid'), (doc2, 11, 'T-sud')], + 'replica', last_known_generation=0, last_known_trans_id=None, + return_doc_cb=receive_doc) + self.assertGetDoc(db, 'doc-here', 'replica:1', '{"value": "here"}', + False) + self.assertEqual( + (10, 'T-sid'), db._get_replica_gen_and_trans_id('replica')) + self.assertEqual([], other_changes) + # retry + trigger_ids = [] + new_gen, trans_id = remote_target.sync_exchange( + [(doc2, 11, 'T-sud')], 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=receive_doc) + self.assertGetDoc(db, 'doc-here2', 'replica:1', '{"value": "here2"}', + False) + self.assertEqual( + (11, 'T-sud'), db._get_replica_gen_and_trans_id('replica')) + self.assertEqual(2, new_gen) + # bounced back to us + self.assertEqual( + ('doc-here', 'replica:1', '{"value": "here"}', 1), + other_changes[0][:-1]) + + def test_sync_exchange_in_stream_error(self): + self.startServer() + + def blackhole_getstderr(inst): + return cStringIO.StringIO() + + self.patch(self.server.RequestHandlerClass, 'get_stderr', + blackhole_getstderr) + db = self.request_state._create_database('test') + doc = db.create_doc_from_json('{"value": "there"}') + + def bomb_get_docs(doc_ids, check_for_conflicts=None, + include_deleted=False): + yield doc + # delayed failure case + raise errors.Unavailable + + self.patch(db, 'get_docs', bomb_get_docs) + remote_target = self.getSyncTarget('test') + other_changes = [] + + def receive_doc(doc, gen, trans_id): + other_changes.append( + (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) + + self.assertRaises( + errors.Unavailable, remote_target.sync_exchange, [], 'replica', + last_known_generation=0, last_known_trans_id=None, + return_doc_cb=receive_doc) + self.assertEqual( + (doc.doc_id, doc.rev, '{"value": "there"}', 1), + other_changes[0][:-1]) + + def test_sync_exchange_receive(self): + self.startServer() + db = self.request_state._create_database('test') + doc = db.create_doc_from_json('{"value": "there"}') + remote_target = self.getSyncTarget('test') + other_changes = [] + + def receive_doc(doc, gen, trans_id): + other_changes.append( + (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) + + new_gen, trans_id = remote_target.sync_exchange( + [], 'replica', last_known_generation=0, last_known_trans_id=None, + return_doc_cb=receive_doc) + self.assertEqual(1, new_gen) + self.assertEqual( + (doc.doc_id, doc.rev, '{"value": "there"}', 1), + other_changes[0][:-1]) + + def test_sync_exchange_send_ensure_callback(self): + self.startServer() + remote_target = self.getSyncTarget('test') + other_docs = [] + replica_uid_box = [] + + def receive_doc(doc): + other_docs.append((doc.doc_id, doc.rev, doc.get_json())) + + def ensure_cb(replica_uid): + replica_uid_box.append(replica_uid) + + doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}') + new_gen, trans_id = remote_target.sync_exchange( + [(doc, 10, 'T-sid')], 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=receive_doc, + ensure_callback=ensure_cb) + self.assertEqual(1, new_gen) + db = self.request_state.open_database('test') + self.assertEqual(1, len(replica_uid_box)) + self.assertEqual(db._replica_uid, replica_uid_box[0]) + self.assertGetDoc( + db, 'doc-here', 'replica:1', '{"value": "here"}', False) + + +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_remote_utils.py b/src/leap/soledad/u1db/tests/test_remote_utils.py new file mode 100644 index 00000000..959cd882 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_remote_utils.py @@ -0,0 +1,36 @@ +# Copyright 2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Tests for protocol details utils.""" + +from u1db.tests import TestCase +from u1db.remote import utils + + +class TestUtils(TestCase): + + def test_check_and_strip_comma(self): + line, comma = utils.check_and_strip_comma("abc,") + self.assertTrue(comma) + self.assertEqual("abc", line) + + line, comma = utils.check_and_strip_comma("abc") + self.assertFalse(comma) + self.assertEqual("abc", line) + + line, comma = utils.check_and_strip_comma("") + self.assertFalse(comma) + self.assertEqual("", line) diff --git a/src/leap/soledad/u1db/tests/test_server_state.py b/src/leap/soledad/u1db/tests/test_server_state.py new file mode 100644 index 00000000..fc3f1282 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_server_state.py @@ -0,0 +1,93 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Tests for server state object.""" + +import os + +from u1db import ( + errors, + tests, + ) +from u1db.remote import ( + server_state, + ) +from u1db.backends import sqlite_backend + + +class TestServerState(tests.TestCase): + + def setUp(self): + super(TestServerState, self).setUp() + self.state = server_state.ServerState() + + def test_set_workingdir(self): + tempdir = self.createTempDir() + self.state.set_workingdir(tempdir) + self.assertTrue(self.state._relpath('path').startswith(tempdir)) + + def test_open_database(self): + tempdir = self.createTempDir() + self.state.set_workingdir(tempdir) + path = tempdir + '/test.db' + self.assertFalse(os.path.exists(path)) + # Create the db, but don't do anything with it + sqlite_backend.SQLitePartialExpandDatabase(path) + db = self.state.open_database('test.db') + self.assertIsInstance(db, sqlite_backend.SQLitePartialExpandDatabase) + + def test_check_database(self): + tempdir = self.createTempDir() + self.state.set_workingdir(tempdir) + path = tempdir + '/test.db' + self.assertFalse(os.path.exists(path)) + + # doesn't exist => raises + self.assertRaises(errors.DatabaseDoesNotExist, + self.state.check_database, 'test.db') + + # Create the db, but don't do anything with it + sqlite_backend.SQLitePartialExpandDatabase(path) + # exists => returns + res = self.state.check_database('test.db') + self.assertIsNone(res) + + def test_ensure_database(self): + tempdir = self.createTempDir() + self.state.set_workingdir(tempdir) + path = tempdir + '/test.db' + self.assertFalse(os.path.exists(path)) + db, replica_uid = self.state.ensure_database('test.db') + self.assertIsInstance(db, sqlite_backend.SQLitePartialExpandDatabase) + self.assertEqual(db._replica_uid, replica_uid) + self.assertTrue(os.path.exists(path)) + db2 = self.state.open_database('test.db') + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + + def test_delete_database(self): + tempdir = self.createTempDir() + self.state.set_workingdir(tempdir) + path = tempdir + '/test.db' + db, _ = self.state.ensure_database('test.db') + db.close() + self.state.delete_database('test.db') + self.assertFalse(os.path.exists(path)) + + def test_delete_database_DoesNotExist(self): + tempdir = self.createTempDir() + self.state.set_workingdir(tempdir) + self.assertRaises(errors.DatabaseDoesNotExist, + self.state.delete_database, 'test.db') diff --git a/src/leap/soledad/u1db/tests/test_sqlite_backend.py b/src/leap/soledad/u1db/tests/test_sqlite_backend.py new file mode 100644 index 00000000..73330789 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_sqlite_backend.py @@ -0,0 +1,493 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Test sqlite backend internals.""" + +import os +import time +import threading + +from sqlite3 import dbapi2 + +from u1db import ( + errors, + tests, + query_parser, + ) +from u1db.backends import sqlite_backend +from u1db.tests.test_backends import TestAlternativeDocument + + +simple_doc = '{"key": "value"}' +nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}' + + +class TestSQLiteDatabase(tests.TestCase): + + def test_atomic_initialize(self): + tmpdir = self.createTempDir() + dbname = os.path.join(tmpdir, 'atomic.db') + + t2 = None # will be a thread + + class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + _index_storage_value = "testing" + + def __init__(self, dbname, ntry): + self._try = ntry + self._is_initialized_invocations = 0 + super(SQLiteDatabaseTesting, self).__init__(dbname) + + def _is_initialized(self, c): + res = super(SQLiteDatabaseTesting, self)._is_initialized(c) + if self._try == 1: + self._is_initialized_invocations += 1 + if self._is_initialized_invocations == 2: + t2.start() + # hard to do better and have a generic test + time.sleep(0.05) + return res + + outcome2 = [] + + def second_try(): + try: + db2 = SQLiteDatabaseTesting(dbname, 2) + except Exception, e: + outcome2.append(e) + else: + outcome2.append(db2) + + t2 = threading.Thread(target=second_try) + db1 = SQLiteDatabaseTesting(dbname, 1) + t2.join() + + self.assertIsInstance(outcome2[0], SQLiteDatabaseTesting) + db2 = outcome2[0] + self.assertTrue(db2._is_initialized(db1._get_sqlite_handle().cursor())) + + +class TestSQLitePartialExpandDatabase(tests.TestCase): + + def setUp(self): + super(TestSQLitePartialExpandDatabase, self).setUp() + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.db._set_replica_uid('test') + + def test_create_database(self): + raw_db = self.db._get_sqlite_handle() + self.assertNotEqual(None, raw_db) + + def test_default_replica_uid(self): + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.assertIsNot(None, self.db._replica_uid) + self.assertEqual(32, len(self.db._replica_uid)) + int(self.db._replica_uid, 16) + + def test__close_sqlite_handle(self): + raw_db = self.db._get_sqlite_handle() + self.db._close_sqlite_handle() + self.assertRaises(dbapi2.ProgrammingError, + raw_db.cursor) + + def test_create_database_initializes_schema(self): + raw_db = self.db._get_sqlite_handle() + c = raw_db.cursor() + c.execute("SELECT * FROM u1db_config") + config = dict([(r[0], r[1]) for r in c.fetchall()]) + self.assertEqual({'sql_schema': '0', 'replica_uid': 'test', + 'index_storage': 'expand referenced'}, config) + + # These tables must exist, though we don't care what is in them yet + c.execute("SELECT * FROM transaction_log") + c.execute("SELECT * FROM document") + c.execute("SELECT * FROM document_fields") + c.execute("SELECT * FROM sync_log") + c.execute("SELECT * FROM conflicts") + c.execute("SELECT * FROM index_definitions") + + def test__parse_index(self): + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + g = self.db._parse_index_definition('fieldname') + self.assertIsInstance(g, query_parser.ExtractField) + self.assertEqual(['fieldname'], g.field) + + def test__update_indexes(self): + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + g = self.db._parse_index_definition('fieldname') + c = self.db._get_sqlite_handle().cursor() + self.db._update_indexes('doc-id', {'fieldname': 'val'}, + [('fieldname', g)], c) + c.execute('SELECT doc_id, field_name, value FROM document_fields') + self.assertEqual([('doc-id', 'fieldname', 'val')], + c.fetchall()) + + def test__set_replica_uid(self): + # Start from scratch, so that replica_uid isn't set. + self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:') + self.assertIsNot(None, self.db._real_replica_uid) + self.assertIsNot(None, self.db._replica_uid) + self.db._set_replica_uid('foo') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT value FROM u1db_config WHERE name='replica_uid'") + self.assertEqual(('foo',), c.fetchone()) + self.assertEqual('foo', self.db._real_replica_uid) + self.assertEqual('foo', self.db._replica_uid) + self.db._close_sqlite_handle() + self.assertEqual('foo', self.db._replica_uid) + + def test__get_generation(self): + self.assertEqual(0, self.db._get_generation()) + + def test__get_generation_info(self): + self.assertEqual((0, ''), self.db._get_generation_info()) + + def test_create_index(self): + self.db.create_index('test-idx', "key") + self.assertEqual([('test-idx', ["key"])], self.db.list_indexes()) + + def test_create_index_multiple_fields(self): + self.db.create_index('test-idx', "key", "key2") + self.assertEqual([('test-idx', ["key", "key2"])], + self.db.list_indexes()) + + def test__get_index_definition(self): + self.db.create_index('test-idx', "key", "key2") + # TODO: How would you test that an index is getting used for an SQL + # request? + self.assertEqual(["key", "key2"], + self.db._get_index_definition('test-idx')) + + def test_list_index_mixed(self): + # Make sure that we properly order the output + c = self.db._get_sqlite_handle().cursor() + # We intentionally insert the data in weird ordering, to make sure the + # query still gets it back correctly. + c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)", + [('idx-1', 0, 'key10'), + ('idx-2', 2, 'key22'), + ('idx-1', 1, 'key11'), + ('idx-2', 0, 'key20'), + ('idx-2', 1, 'key21')]) + self.assertEqual([('idx-1', ['key10', 'key11']), + ('idx-2', ['key20', 'key21', 'key22'])], + self.db.list_indexes()) + + def test_no_indexes_no_document_fields(self): + self.db.create_doc_from_json( + '{"key1": "val1", "key2": "val2"}') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([], c.fetchall()) + + def test_create_extracts_fields(self): + doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + doc2 = self.db.create_doc_from_json('{"key1": "valx", "key2": "valy"}') + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([], c.fetchall()) + self.db.create_index('test', 'key1', 'key2') + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual(sorted( + [(doc1.doc_id, "key1", "val1"), + (doc1.doc_id, "key2", "val2"), + (doc2.doc_id, "key1", "valx"), + (doc2.doc_id, "key2", "valy"), + ]), sorted(c.fetchall())) + + def test_put_updates_fields(self): + self.db.create_index('test', 'key1', 'key2') + doc1 = self.db.create_doc_from_json( + '{"key1": "val1", "key2": "val2"}') + doc1.content = {"key1": "val1", "key2": "valy"} + self.db.put_doc(doc1) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, "key1", "val1"), + (doc1.doc_id, "key2", "valy"), + ], c.fetchall()) + + def test_put_updates_nested_fields(self): + self.db.create_index('test', 'key', 'sub.doc') + doc1 = self.db.create_doc_from_json(nested_doc) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, "key", "value"), + (doc1.doc_id, "sub.doc", "underneath"), + ], c.fetchall()) + + def test__ensure_schema_rollback(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/rollback.db' + + class SQLitePartialExpandDbTesting( + sqlite_backend.SQLitePartialExpandDatabase): + + def _set_replica_uid_in_transaction(self, uid): + super(SQLitePartialExpandDbTesting, + self)._set_replica_uid_in_transaction(uid) + if fail: + raise Exception() + + db = SQLitePartialExpandDbTesting.__new__(SQLitePartialExpandDbTesting) + db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed + fail = True + self.assertRaises(Exception, db._ensure_schema) + fail = False + db._initialize(db._db_handle.cursor()) + + def test__open_database(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/test.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase._open_database(path) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + + def test__open_database_with_factory(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/test.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase._open_database( + path, document_factory=TestAlternativeDocument) + self.assertEqual(TestAlternativeDocument, db2._factory) + + def test__open_database_non_existent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase._open_database, path) + + def test__open_database_during_init(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/initialised.db' + db = sqlite_backend.SQLitePartialExpandDatabase.__new__( + sqlite_backend.SQLitePartialExpandDatabase) + db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed + self.addCleanup(db.close) + observed = [] + + class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 + + @classmethod + def _which_index_storage(cls, c): + res = super(SQLiteDatabaseTesting, cls)._which_index_storage(c) + db._ensure_schema() # init db + observed.append(res[0]) + return res + + db2 = SQLiteDatabaseTesting._open_database(path) + self.addCleanup(db2.close) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + self.assertEqual([None, + sqlite_backend.SQLitePartialExpandDatabase._index_storage_value], + observed) + + def test__open_database_invalid(self): + class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase): + WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1 + temp_dir = self.createTempDir(prefix='u1db-test-') + path1 = temp_dir + '/invalid1.db' + with open(path1, 'wb') as f: + f.write("") + self.assertRaises(dbapi2.OperationalError, + SQLiteDatabaseTesting._open_database, path1) + with open(path1, 'wb') as f: + f.write("invalid") + self.assertRaises(dbapi2.DatabaseError, + SQLiteDatabaseTesting._open_database, path1) + + def test_open_database_existing(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/existing.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + + def test_open_database_with_factory(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/existing.sqlite' + sqlite_backend.SQLitePartialExpandDatabase(path) + db2 = sqlite_backend.SQLiteDatabase.open_database( + path, create=False, document_factory=TestAlternativeDocument) + self.assertEqual(TestAlternativeDocument, db2._factory) + + def test_open_database_create(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/new.sqlite' + sqlite_backend.SQLiteDatabase.open_database(path, create=True) + db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False) + self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase) + + def test_open_database_non_existent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase.open_database, path, + create=False) + + def test_delete_database_existent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/new.sqlite' + db = sqlite_backend.SQLiteDatabase.open_database(path, create=True) + db.close() + sqlite_backend.SQLiteDatabase.delete_database(path) + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase.open_database, path, + create=False) + + def test_delete_database_nonexistent(self): + temp_dir = self.createTempDir(prefix='u1db-test-') + path = temp_dir + '/non-existent.sqlite' + self.assertRaises(errors.DatabaseDoesNotExist, + sqlite_backend.SQLiteDatabase.delete_database, path) + + def test__get_indexed_fields(self): + self.db.create_index('idx1', 'a', 'b') + self.assertEqual(set(['a', 'b']), self.db._get_indexed_fields()) + self.db.create_index('idx2', 'b', 'c') + self.assertEqual(set(['a', 'b', 'c']), self.db._get_indexed_fields()) + + def test_indexed_fields_expanded(self): + self.db.create_index('idx1', 'key1') + doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + self.assertEqual(set(['key1']), self.db._get_indexed_fields()) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall()) + + def test_create_index_updates_fields(self): + doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}') + self.db.create_index('idx1', 'key1') + self.assertEqual(set(['key1']), self.db._get_indexed_fields()) + c = self.db._get_sqlite_handle().cursor() + c.execute("SELECT doc_id, field_name, value FROM document_fields" + " ORDER BY doc_id, field_name, value") + self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall()) + + def assertFormatQueryEquals(self, exp_statement, exp_args, definition, + values): + statement, args = self.db._format_query(definition, values) + self.assertEqual(exp_statement, statement) + self.assertEqual(exp_args, args) + + def test__format_query(self): + self.assertFormatQueryEquals( + "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM " + "document d, document_fields d0 LEFT OUTER JOIN conflicts c ON " + "c.doc_id = d.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name " + "= ? AND d0.value = ? GROUP BY d.doc_id, d.doc_rev, d.content " + "ORDER BY d0.value;", ["key1", "a"], + ["key1"], ["a"]) + + def test__format_query2(self): + self.assertFormatQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value = ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value = ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ["key1", "a", "key2", "b", "key3", "c"], + ["key1", "key2", "key3"], ["a", "b", "c"]) + + def test__format_query_wildcard(self): + self.assertFormatQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value GLOB ? AND d.doc_id = d2.doc_id AND d2.field_name = ? ' + 'AND d2.value NOT NULL GROUP BY d.doc_id, d.doc_rev, d.content ' + 'ORDER BY d0.value, d1.value, d2.value;', + ["key1", "a", "key2", "b*", "key3"], ["key1", "key2", "key3"], + ["a", "b*", "*"]) + + def assertFormatRangeQueryEquals(self, exp_statement, exp_args, definition, + start_value, end_value): + statement, args = self.db._format_range_query( + definition, start_value, end_value) + self.assertEqual(exp_statement, statement) + self.assertEqual(exp_args, args) + + def test__format_range_query(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value >= ? AND d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c', 'key1', 'p', 'key2', 'q', + 'key3', 'r'], + ["key1", "key2", "key3"], ["a", "b", "c"], ["p", "q", "r"]) + + def test__format_range_query_no_start(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c'], + ["key1", "key2", "key3"], None, ["a", "b", "c"]) + + def test__format_range_query_no_end(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value >= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY ' + 'd0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'c'], + ["key1", "key2", "key3"], ["a", "b", "c"], None) + + def test__format_range_query_wildcard(self): + self.assertFormatRangeQueryEquals( + 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM ' + 'document d, document_fields d0, document_fields d1, ' + 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = ' + 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND ' + 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND ' + 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND ' + 'd2.value NOT NULL AND d.doc_id = d0.doc_id AND d0.field_name = ? ' + 'AND d0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? ' + 'AND (d1.value < ? OR d1.value GLOB ?) AND d.doc_id = d2.doc_id ' + 'AND d2.field_name = ? AND d2.value NOT NULL GROUP BY d.doc_id, ' + 'd.doc_rev, d.content ORDER BY d0.value, d1.value, d2.value;', + ['key1', 'a', 'key2', 'b', 'key3', 'key1', 'p', 'key2', 'q', 'q*', + 'key3'], + ["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"]) diff --git a/src/leap/soledad/u1db/tests/test_sync.py b/src/leap/soledad/u1db/tests/test_sync.py new file mode 100644 index 00000000..f2a925f0 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_sync.py @@ -0,0 +1,1285 @@ +# Copyright 2011-2012 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""The Synchronization class for U1DB.""" + +import os +from wsgiref import simple_server + +from u1db import ( + errors, + sync, + tests, + vectorclock, + SyncTarget, + ) +from u1db.backends import ( + inmemory, + ) +from u1db.remote import ( + http_target, + ) + +from u1db.tests.test_remote_sync_target import ( + make_http_app, + make_oauth_http_app, + ) + +simple_doc = tests.simple_doc +nested_doc = tests.nested_doc + + +def _make_local_db_and_target(test): + db = test.create_database('test') + st = db.get_sync_target() + return db, st + + +def _make_local_db_and_http_target(test, path='test'): + test.startServer() + db = test.request_state._create_database(os.path.basename(path)) + st = http_target.HTTPSyncTarget.connect(test.getURL(path)) + return db, st + + +def _make_c_db_and_c_http_target(test, path='test'): + test.startServer() + db = test.request_state._create_database(os.path.basename(path)) + url = test.getURL(path) + st = tests.c_backend_wrapper.create_http_sync_target(url) + return db, st + + +def _make_local_db_and_oauth_http_target(test): + db, st = _make_local_db_and_http_target(test, '~/test') + st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return db, st + + +def _make_c_db_and_oauth_http_target(test, path='~/test'): + test.startServer() + db = test.request_state._create_database(os.path.basename(path)) + url = test.getURL(path) + st = tests.c_backend_wrapper.create_oauth_http_sync_target(url, + tests.consumer1.key, tests.consumer1.secret, + tests.token1.key, tests.token1.secret) + return db, st + + +target_scenarios = [ + ('local', {'create_db_and_target': _make_local_db_and_target}), + ('http', {'create_db_and_target': _make_local_db_and_http_target, + 'make_app_with_state': make_http_app}), + ('oauth_http', {'create_db_and_target': + _make_local_db_and_oauth_http_target, + 'make_app_with_state': make_oauth_http_app}), + ] + +c_db_scenarios = [ + ('local,c', {'create_db_and_target': _make_local_db_and_target, + 'make_database_for_test': tests.make_c_database_for_test, + 'copy_database_for_test': tests.copy_c_database_for_test, + 'make_document_for_test': tests.make_c_document_for_test, + 'whitebox': False}), + ('http,c', {'create_db_and_target': _make_c_db_and_c_http_target, + 'make_database_for_test': tests.make_c_database_for_test, + 'copy_database_for_test': tests.copy_c_database_for_test, + 'make_document_for_test': tests.make_c_document_for_test, + 'make_app_with_state': make_http_app, + 'whitebox': False}), + ('oauth_http,c', {'create_db_and_target': _make_c_db_and_oauth_http_target, + 'make_database_for_test': tests.make_c_database_for_test, + 'copy_database_for_test': tests.copy_c_database_for_test, + 'make_document_for_test': tests.make_c_document_for_test, + 'make_app_with_state': make_oauth_http_app, + 'whitebox': False}), + ] + + +class DatabaseSyncTargetTests(tests.DatabaseBaseTests, + tests.TestCaseWithServer): + + scenarios = (tests.multiply_scenarios(tests.DatabaseBaseTests.scenarios, + target_scenarios) + + c_db_scenarios) + # whitebox true means self.db is the actual local db object + # against which the sync is performed + whitebox = True + + def setUp(self): + super(DatabaseSyncTargetTests, self).setUp() + self.db, self.st = self.create_db_and_target(self) + self.other_changes = [] + + def tearDown(self): + # We delete them explicitly, so that connections are cleanly closed + del self.st + self.db.close() + del self.db + super(DatabaseSyncTargetTests, self).tearDown() + + def receive_doc(self, doc, gen, trans_id): + self.other_changes.append( + (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id)) + + def set_trace_hook(self, callback, shallow=False): + setter = (self.st._set_trace_hook if not shallow else + self.st._set_trace_hook_shallow) + try: + setter(callback) + except NotImplementedError: + self.skipTest("%s does not implement _set_trace_hook" + % (self.st.__class__.__name__,)) + + def test_get_sync_target(self): + self.assertIsNot(None, self.st) + + def test_get_sync_info(self): + self.assertEqual( + ('test', 0, '', 0, ''), self.st.get_sync_info('other')) + + def test_create_doc_updates_sync_info(self): + self.assertEqual( + ('test', 0, '', 0, ''), self.st.get_sync_info('other')) + self.db.create_doc_from_json(simple_doc) + self.assertEqual(1, self.st.get_sync_info('other')[1]) + + def test_record_sync_info(self): + self.st.record_sync_info('replica', 10, 'T-transid') + self.assertEqual( + ('test', 0, '', 10, 'T-transid'), self.st.get_sync_info('replica')) + + def test_sync_exchange(self): + docs_by_gen = [ + (self.make_document('doc-id', 'replica:1', simple_doc), 10, + 'T-sid')] + new_gen, trans_id = self.st.sync_exchange( + docs_by_gen, 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False) + self.assertTransactionLog(['doc-id'], self.db) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual(([], 1, last_trans_id), + (self.other_changes, new_gen, last_trans_id)) + self.assertEqual(10, self.st.get_sync_info('replica')[3]) + + def test_sync_exchange_deleted(self): + doc = self.db.create_doc_from_json('{}') + edit_rev = 'replica:1|' + doc.rev + docs_by_gen = [ + (self.make_document(doc.doc_id, edit_rev, None), 10, 'T-sid')] + new_gen, trans_id = self.st.sync_exchange( + docs_by_gen, 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertGetDocIncludeDeleted( + self.db, doc.doc_id, edit_rev, None, False) + self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual(([], 2, last_trans_id), + (self.other_changes, new_gen, trans_id)) + self.assertEqual(10, self.st.get_sync_info('replica')[3]) + + def test_sync_exchange_push_many(self): + docs_by_gen = [ + (self.make_document('doc-id', 'replica:1', simple_doc), 10, 'T-1'), + (self.make_document('doc-id2', 'replica:1', nested_doc), 11, + 'T-2')] + new_gen, trans_id = self.st.sync_exchange( + docs_by_gen, 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False) + self.assertGetDoc(self.db, 'doc-id2', 'replica:1', nested_doc, False) + self.assertTransactionLog(['doc-id', 'doc-id2'], self.db) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual(([], 2, last_trans_id), + (self.other_changes, new_gen, trans_id)) + self.assertEqual(11, self.st.get_sync_info('replica')[3]) + + def test_sync_exchange_refuses_conflicts(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + new_doc = '{"key": "altval"}' + docs_by_gen = [ + (self.make_document(doc.doc_id, 'replica:1', new_doc), 10, + 'T-sid')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id], self.db) + self.assertEqual( + (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1]) + self.assertEqual(1, new_gen) + if self.whitebox: + self.assertEqual(self.db._last_exchange_log['return'], + {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]}) + + def test_sync_exchange_ignores_convergence(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + gen, txid = self.db._get_generation_info() + docs_by_gen = [ + (self.make_document(doc.doc_id, doc.rev, simple_doc), 10, 'T-sid')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'replica', last_known_generation=gen, + last_known_trans_id=txid, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id], self.db) + self.assertEqual(([], 1), (self.other_changes, new_gen)) + + def test_sync_exchange_returns_new_docs(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + new_gen, _ = self.st.sync_exchange( + [], 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id], self.db) + self.assertEqual( + (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1]) + self.assertEqual(1, new_gen) + if self.whitebox: + self.assertEqual(self.db._last_exchange_log['return'], + {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]}) + + def test_sync_exchange_returns_deleted_docs(self): + doc = self.db.create_doc_from_json(simple_doc) + self.db.delete_doc(doc) + self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) + new_gen, _ = self.st.sync_exchange( + [], 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) + self.assertEqual( + (doc.doc_id, doc.rev, None, 2), self.other_changes[0][:-1]) + self.assertEqual(2, new_gen) + if self.whitebox: + self.assertEqual(self.db._last_exchange_log['return'], + {'last_gen': 2, 'docs': [(doc.doc_id, doc.rev)]}) + + def test_sync_exchange_returns_many_new_docs(self): + doc = self.db.create_doc_from_json(simple_doc) + doc2 = self.db.create_doc_from_json(nested_doc) + self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db) + new_gen, _ = self.st.sync_exchange( + [], 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db) + self.assertEqual(2, new_gen) + self.assertEqual( + [(doc.doc_id, doc.rev, simple_doc, 1), + (doc2.doc_id, doc2.rev, nested_doc, 2)], + [c[:-1] for c in self.other_changes]) + if self.whitebox: + self.assertEqual( + self.db._last_exchange_log['return'], + {'last_gen': 2, 'docs': + [(doc.doc_id, doc.rev), (doc2.doc_id, doc2.rev)]}) + + def test_sync_exchange_getting_newer_docs(self): + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + new_doc = '{"key": "altval"}' + docs_by_gen = [ + (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10, + 'T-sid')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db) + self.assertEqual(([], 2), (self.other_changes, new_gen)) + + def test_sync_exchange_with_concurrent_updates_of_synced_doc(self): + expected = [] + + def before_whatschanged_cb(state): + if state != 'before whats_changed': + return + cont = '{"key": "cuncurrent"}' + conc_rev = self.db.put_doc( + self.make_document(doc.doc_id, 'test:1|z:2', cont)) + expected.append((doc.doc_id, conc_rev, cont, 3)) + + self.set_trace_hook(before_whatschanged_cb) + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + new_doc = '{"key": "altval"}' + docs_by_gen = [ + (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10, + 'T-sid')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertEqual(expected, [c[:-1] for c in self.other_changes]) + self.assertEqual(3, new_gen) + + def test_sync_exchange_with_concurrent_updates(self): + + def after_whatschanged_cb(state): + if state != 'after whats_changed': + return + self.db.create_doc_from_json('{"new": "doc"}') + + self.set_trace_hook(after_whatschanged_cb) + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + new_doc = '{"key": "altval"}' + docs_by_gen = [ + (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10, + 'T-sid')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertEqual(([], 2), (self.other_changes, new_gen)) + + def test_sync_exchange_converged_handling(self): + doc = self.db.create_doc_from_json(simple_doc) + docs_by_gen = [ + (self.make_document('new', 'other:1', '{}'), 4, 'T-foo'), + (self.make_document(doc.doc_id, doc.rev, doc.get_json()), 5, + 'T-bar')] + new_gen, _ = self.st.sync_exchange( + docs_by_gen, 'other-replica', last_known_generation=0, + last_known_trans_id=None, return_doc_cb=self.receive_doc) + self.assertEqual(([], 2), (self.other_changes, new_gen)) + + def test_sync_exchange_detect_incomplete_exchange(self): + def before_get_docs_explode(state): + if state != 'before get_docs': + return + raise errors.U1DBError("fail") + self.set_trace_hook(before_get_docs_explode) + # suppress traceback printing in the wsgiref server + self.patch(simple_server.ServerHandler, + 'log_exception', lambda h, exc_info: None) + doc = self.db.create_doc_from_json(simple_doc) + self.assertTransactionLog([doc.doc_id], self.db) + self.assertRaises( + (errors.U1DBError, errors.BrokenSyncStream), + self.st.sync_exchange, [], 'other-replica', + last_known_generation=0, last_known_trans_id=None, + return_doc_cb=self.receive_doc) + + def test_sync_exchange_doc_ids(self): + sync_exchange_doc_ids = getattr(self.st, 'sync_exchange_doc_ids', None) + if sync_exchange_doc_ids is None: + self.skipTest("sync_exchange_doc_ids not implemented") + db2 = self.create_database('test2') + doc = db2.create_doc_from_json(simple_doc) + new_gen, trans_id = sync_exchange_doc_ids( + db2, [(doc.doc_id, 10, 'T-sid')], 0, None, + return_doc_cb=self.receive_doc) + self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False) + self.assertTransactionLog([doc.doc_id], self.db) + last_trans_id = self.getLastTransId(self.db) + self.assertEqual(([], 1, last_trans_id), + (self.other_changes, new_gen, trans_id)) + self.assertEqual(10, self.st.get_sync_info(db2._replica_uid)[3]) + + def test__set_trace_hook(self): + called = [] + + def cb(state): + called.append(state) + + self.set_trace_hook(cb) + self.st.sync_exchange([], 'replica', 0, None, self.receive_doc) + self.st.record_sync_info('replica', 0, 'T-sid') + self.assertEqual(['before whats_changed', + 'after whats_changed', + 'before get_docs', + 'record_sync_info', + ], + called) + + def test__set_trace_hook_shallow(self): + if (self.st._set_trace_hook_shallow == self.st._set_trace_hook + or self.st._set_trace_hook_shallow.im_func == + SyncTarget._set_trace_hook_shallow.im_func): + # shallow same as full + expected = ['before whats_changed', + 'after whats_changed', + 'before get_docs', + 'record_sync_info', + ] + else: + expected = ['sync_exchange', 'record_sync_info'] + + called = [] + + def cb(state): + called.append(state) + + self.set_trace_hook(cb, shallow=True) + self.st.sync_exchange([], 'replica', 0, None, self.receive_doc) + self.st.record_sync_info('replica', 0, 'T-sid') + self.assertEqual(expected, called) + + +def sync_via_synchronizer(test, db_source, db_target, trace_hook=None, + trace_hook_shallow=None): + target = db_target.get_sync_target() + trace_hook = trace_hook or trace_hook_shallow + if trace_hook: + target._set_trace_hook(trace_hook) + return sync.Synchronizer(db_source, target).sync() + + +sync_scenarios = [] +for name, scenario in tests.LOCAL_DATABASES_SCENARIOS: + scenario = dict(scenario) + scenario['do_sync'] = sync_via_synchronizer + sync_scenarios.append((name, scenario)) + scenario = dict(scenario) + + +def make_database_for_http_test(test, replica_uid): + if test.server is None: + test.startServer() + db = test.request_state._create_database(replica_uid) + try: + http_at = test._http_at + except AttributeError: + http_at = test._http_at = {} + http_at[db] = replica_uid + return db + + +def copy_database_for_http_test(test, db): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS + # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE + # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN + # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR HOUSE. + if test.server is None: + test.startServer() + new_db = test.request_state._copy_database(db) + try: + http_at = test._http_at + except AttributeError: + http_at = test._http_at = {} + path = db._replica_uid + while path in http_at.values(): + path += 'copy' + http_at[new_db] = path + return new_db + + +def sync_via_synchronizer_and_http(test, db_source, db_target, + trace_hook=None, trace_hook_shallow=None): + if trace_hook: + test.skipTest("full trace hook unsupported over http") + path = test._http_at[db_target] + target = http_target.HTTPSyncTarget.connect(test.getURL(path)) + if trace_hook_shallow: + target._set_trace_hook_shallow(trace_hook_shallow) + return sync.Synchronizer(db_source, target).sync() + + +sync_scenarios.append(('pyhttp', { + 'make_database_for_test': make_database_for_http_test, + 'copy_database_for_test': copy_database_for_http_test, + 'make_document_for_test': tests.make_document_for_test, + 'make_app_with_state': make_http_app, + 'do_sync': sync_via_synchronizer_and_http + })) + + +if tests.c_backend_wrapper is not None: + # TODO: We should hook up sync tests with an HTTP target + def sync_via_c_sync(test, db_source, db_target, trace_hook=None, + trace_hook_shallow=None): + target = db_target.get_sync_target() + trace_hook = trace_hook or trace_hook_shallow + if trace_hook: + target._set_trace_hook(trace_hook) + return tests.c_backend_wrapper.sync_db_to_target(db_source, target) + + for name, scenario in tests.C_DATABASE_SCENARIOS: + scenario = dict(scenario) + scenario['do_sync'] = sync_via_synchronizer + sync_scenarios.append((name + ',pysync', scenario)) + scenario = dict(scenario) + scenario['do_sync'] = sync_via_c_sync + sync_scenarios.append((name + ',csync', scenario)) + + +class DatabaseSyncTests(tests.DatabaseBaseTests, + tests.TestCaseWithServer): + + scenarios = sync_scenarios + do_sync = None # set by scenarios + + def create_database(self, replica_uid, sync_role=None): + if replica_uid == 'test' and sync_role is None: + # created up the chain by base class but unused + return None + db = self.create_database_for_role(replica_uid, sync_role) + if sync_role: + self._use_tracking[db] = (replica_uid, sync_role) + return db + + def create_database_for_role(self, replica_uid, sync_role): + # hook point for reuse + return super(DatabaseSyncTests, self).create_database(replica_uid) + + def copy_database(self, db, sync_role=None): + # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES + # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST + # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS + # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND + # NINJA TO YOUR HOUSE. + db_copy = super(DatabaseSyncTests, self).copy_database(db) + name, orig_sync_role = self._use_tracking[db] + self._use_tracking[db_copy] = (name + '(copy)', sync_role + or orig_sync_role) + return db_copy + + def sync(self, db_from, db_to, trace_hook=None, + trace_hook_shallow=None): + from_name, from_sync_role = self._use_tracking[db_from] + to_name, to_sync_role = self._use_tracking[db_to] + if from_sync_role not in ('source', 'both'): + raise Exception("%s marked for %s use but used as source" % + (from_name, from_sync_role)) + if to_sync_role not in ('target', 'both'): + raise Exception("%s marked for %s use but used as target" % + (to_name, to_sync_role)) + return self.do_sync(self, db_from, db_to, trace_hook, + trace_hook_shallow) + + def setUp(self): + self._use_tracking = {} + super(DatabaseSyncTests, self).setUp() + + def assertLastExchangeLog(self, db, expected): + log = getattr(db, '_last_exchange_log', None) + if log is None: + return + self.assertEqual(expected, log) + + def test_sync_tracks_db_generation_of_other(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.assertEqual(0, self.sync(self.db1, self.db2)) + self.assertEqual( + (0, ''), self.db1._get_replica_gen_and_trans_id('test2')) + self.assertEqual( + (0, ''), self.db2._get_replica_gen_and_trans_id('test1')) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [], 'last_known_gen': 0}, + 'return': {'docs': [], 'last_gen': 0}}) + + def test_sync_autoresolves(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc1 = self.db1.create_doc_from_json(simple_doc, doc_id='doc') + rev1 = doc1.rev + doc2 = self.db2.create_doc_from_json(simple_doc, doc_id='doc') + rev2 = doc2.rev + self.sync(self.db1, self.db2) + doc = self.db1.get_doc('doc') + self.assertFalse(doc.has_conflicts) + self.assertEqual(doc.rev, self.db2.get_doc('doc').rev) + v = vectorclock.VectorClockRev(doc.rev) + self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev1))) + self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev2))) + + def test_sync_autoresolves_moar(self): + # here we test that when a database that has a conflicted document is + # the source of a sync, and the target database has a revision of the + # conflicted document that is newer than the source database's, and + # that target's database's document's content is the same as the + # source's document's conflict's, the source's document's conflict gets + # autoresolved, and the source's document's revision bumped. + # + # idea is as follows: + # A B + # a1 - + # `-------> + # a1 a1 + # v v + # a2 a1b1 + # `-------> + # a1b1+a2 a1b1 + # v + # a1b1+a2 a1b2 (a1b2 has same content as a2) + # `-------> + # a3b2 a1b2 (autoresolved) + # `-------> + # a3b2 a3b2 + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(simple_doc, doc_id='doc') + self.sync(self.db1, self.db2) + for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]: + doc = db.get_doc('doc') + doc.set_json(content) + db.put_doc(doc) + self.sync(self.db1, self.db2) + # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict + doc = self.db1.get_doc('doc') + rev1 = doc.rev + self.assertTrue(doc.has_conflicts) + # set db2 to have a doc of {} (same as db1 before the conflict) + doc = self.db2.get_doc('doc') + doc.set_json('{}') + self.db2.put_doc(doc) + rev2 = doc.rev + # sync it across + self.sync(self.db1, self.db2) + # tadaa! + doc = self.db1.get_doc('doc') + self.assertFalse(doc.has_conflicts) + vec1 = vectorclock.VectorClockRev(rev1) + vec2 = vectorclock.VectorClockRev(rev2) + vec3 = vectorclock.VectorClockRev(doc.rev) + self.assertTrue(vec3.is_newer(vec1)) + self.assertTrue(vec3.is_newer(vec2)) + # because the conflict is on the source, sync it another time + self.sync(self.db1, self.db2) + # make sure db2 now has the exact same thing + self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc')) + + def test_sync_autoresolves_moar_backwards(self): + # here we test that when a database that has a conflicted document is + # the target of a sync, and the source database has a revision of the + # conflicted document that is newer than the target database's, and + # that source's database's document's content is the same as the + # target's document's conflict's, the target's document's conflict gets + # autoresolved, and the document's revision bumped. + # + # idea is as follows: + # A B + # a1 - + # `-------> + # a1 a1 + # v v + # a2 a1b1 + # `-------> + # a1b1+a2 a1b1 + # v + # a1b1+a2 a1b2 (a1b2 has same content as a2) + # <-------' + # a3b2 a3b2 (autoresolved and propagated) + self.db1 = self.create_database('test1', 'both') + self.db2 = self.create_database('test2', 'both') + self.db1.create_doc_from_json(simple_doc, doc_id='doc') + self.sync(self.db1, self.db2) + for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]: + doc = db.get_doc('doc') + doc.set_json(content) + db.put_doc(doc) + self.sync(self.db1, self.db2) + # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict + doc = self.db1.get_doc('doc') + rev1 = doc.rev + self.assertTrue(doc.has_conflicts) + revc = self.db1.get_doc_conflicts('doc')[-1].rev + # set db2 to have a doc of {} (same as db1 before the conflict) + doc = self.db2.get_doc('doc') + doc.set_json('{}') + self.db2.put_doc(doc) + rev2 = doc.rev + # sync it across + self.sync(self.db2, self.db1) + # tadaa! + doc = self.db1.get_doc('doc') + self.assertFalse(doc.has_conflicts) + vec1 = vectorclock.VectorClockRev(rev1) + vec2 = vectorclock.VectorClockRev(rev2) + vec3 = vectorclock.VectorClockRev(doc.rev) + vecc = vectorclock.VectorClockRev(revc) + self.assertTrue(vec3.is_newer(vec1)) + self.assertTrue(vec3.is_newer(vec2)) + self.assertTrue(vec3.is_newer(vecc)) + # make sure db2 now has the exact same thing + self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc')) + + def test_sync_autoresolves_moar_backwards_three(self): + # same as autoresolves_moar_backwards, but with three databases (note + # all the syncs go in the same direction -- this is a more natural + # scenario): + # + # A B C + # a1 - - + # `-------> + # a1 a1 - + # `-------> + # a1 a1 a1 + # v v + # a2 a1b1 a1 + # `-------------------> + # a2 a1b1 a2 + # `-------> + # a2+a1b1 a2 + # v + # a2 a2+a1b1 a2c1 (same as a1b1) + # `-------------------> + # a2c1 a2+a1b1 a2c1 + # `-------> + # a2b2c1 a2b2c1 a2c1 + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'both') + self.db3 = self.create_database('test3', 'target') + self.db1.create_doc_from_json(simple_doc, doc_id='doc') + self.sync(self.db1, self.db2) + self.sync(self.db2, self.db3) + for db, content in [(self.db2, '{"hi": 42}'), + (self.db1, '{}'), + ]: + doc = db.get_doc('doc') + doc.set_json(content) + db.put_doc(doc) + self.sync(self.db1, self.db3) + self.sync(self.db2, self.db3) + # db2 and db3 now both have a doc of {}, but db2 has a + # conflict + doc = self.db2.get_doc('doc') + self.assertTrue(doc.has_conflicts) + revc = self.db2.get_doc_conflicts('doc')[-1].rev + self.assertEqual('{}', doc.get_json()) + self.assertEqual(self.db3.get_doc('doc').get_json(), doc.get_json()) + self.assertEqual(self.db3.get_doc('doc').rev, doc.rev) + # set db3 to have a doc of {hi:42} (same as db2 before the conflict) + doc = self.db3.get_doc('doc') + doc.set_json('{"hi": 42}') + self.db3.put_doc(doc) + rev3 = doc.rev + # sync it across to db1 + self.sync(self.db1, self.db3) + # db1 now has hi:42, with a rev that is newer than db2's doc + doc = self.db1.get_doc('doc') + rev1 = doc.rev + self.assertFalse(doc.has_conflicts) + self.assertEqual('{"hi": 42}', doc.get_json()) + VCR = vectorclock.VectorClockRev + self.assertTrue(VCR(rev1).is_newer(VCR(self.db2.get_doc('doc').rev))) + # so sync it to db2 + self.sync(self.db1, self.db2) + # tadaa! + doc = self.db2.get_doc('doc') + self.assertFalse(doc.has_conflicts) + # db2's revision of the document is strictly newer than db1's before + # the sync, and db3's before that sync way back when + self.assertTrue(VCR(doc.rev).is_newer(VCR(rev1))) + self.assertTrue(VCR(doc.rev).is_newer(VCR(rev3))) + self.assertTrue(VCR(doc.rev).is_newer(VCR(revc))) + # make sure both dbs now have the exact same thing + self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc')) + + def test_sync_puts_changes(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc = self.db1.create_doc_from_json(simple_doc) + self.assertEqual(1, self.sync(self.db1, self.db2)) + self.assertGetDoc(self.db2, doc.doc_id, doc.rev, simple_doc, False) + self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0]) + self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0]) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [(doc.doc_id, doc.rev)], + 'source_uid': 'test1', + 'source_gen': 1, 'last_known_gen': 0}, + 'return': {'docs': [], 'last_gen': 1}}) + + def test_sync_pulls_changes(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc = self.db2.create_doc_from_json(simple_doc) + self.db1.create_index('test-idx', 'key') + self.assertEqual(0, self.sync(self.db1, self.db2)) + self.assertGetDoc(self.db1, doc.doc_id, doc.rev, simple_doc, False) + self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0]) + self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0]) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [], 'last_known_gen': 0}, + 'return': {'docs': [(doc.doc_id, doc.rev)], + 'last_gen': 1}}) + self.assertEqual([doc], self.db1.get_from_index('test-idx', 'value')) + + def test_sync_pulling_doesnt_update_other_if_changed(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc = self.db2.create_doc_from_json(simple_doc) + # After the local side has sent its list of docs, before we start + # receiving the "targets" response, we update the local database with a + # new record. + # When we finish synchronizing, we can notice that something locally + # was updated, and we cannot tell c2 our new updated generation + + def before_get_docs(state): + if state != 'before get_docs': + return + self.db1.create_doc_from_json(simple_doc) + + self.assertEqual(0, self.sync(self.db1, self.db2, + trace_hook=before_get_docs)) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [], 'last_known_gen': 0}, + 'return': {'docs': [(doc.doc_id, doc.rev)], + 'last_gen': 1}}) + self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0]) + # c2 should not have gotten a '_record_sync_info' call, because the + # local database had been updated more than just by the messages + # returned from c2. + self.assertEqual( + (0, ''), self.db2._get_replica_gen_and_trans_id('test1')) + + def test_sync_doesnt_update_other_if_nothing_pulled(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(simple_doc) + + def no_record_sync_info(state): + if state != 'record_sync_info': + return + self.fail('SyncTarget.record_sync_info was called') + self.assertEqual(1, self.sync(self.db1, self.db2, + trace_hook_shallow=no_record_sync_info)) + self.assertEqual( + 1, + self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid)[0]) + + def test_sync_ignores_convergence(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'both') + doc = self.db1.create_doc_from_json(simple_doc) + self.db3 = self.create_database('test3', 'target') + self.assertEqual(1, self.sync(self.db1, self.db3)) + self.assertEqual(0, self.sync(self.db2, self.db3)) + self.assertEqual(1, self.sync(self.db1, self.db2)) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [(doc.doc_id, doc.rev)], + 'source_uid': 'test1', + 'source_gen': 1, 'last_known_gen': 0}, + 'return': {'docs': [], 'last_gen': 1}}) + + def test_sync_ignores_superseded(self): + self.db1 = self.create_database('test1', 'both') + self.db2 = self.create_database('test2', 'both') + doc = self.db1.create_doc_from_json(simple_doc) + doc_rev1 = doc.rev + self.db3 = self.create_database('test3', 'target') + self.sync(self.db1, self.db3) + self.sync(self.db2, self.db3) + new_content = '{"key": "altval"}' + doc.set_json(new_content) + self.db1.put_doc(doc) + doc_rev2 = doc.rev + self.sync(self.db2, self.db1) + self.assertLastExchangeLog(self.db1, + {'receive': {'docs': [(doc.doc_id, doc_rev1)], + 'source_uid': 'test2', + 'source_gen': 1, 'last_known_gen': 0}, + 'return': {'docs': [(doc.doc_id, doc_rev2)], + 'last_gen': 2}}) + self.assertGetDoc(self.db1, doc.doc_id, doc_rev2, new_content, False) + + def test_sync_sees_remote_conflicted(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc1 = self.db1.create_doc_from_json(simple_doc) + doc_id = doc1.doc_id + doc1_rev = doc1.rev + self.db1.create_index('test-idx', 'key') + new_doc = '{"key": "altval"}' + doc2 = self.db2.create_doc_from_json(new_doc, doc_id=doc_id) + doc2_rev = doc2.rev + self.assertTransactionLog([doc1.doc_id], self.db1) + self.sync(self.db1, self.db2) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [(doc_id, doc1_rev)], + 'source_uid': 'test1', + 'source_gen': 1, 'last_known_gen': 0}, + 'return': {'docs': [(doc_id, doc2_rev)], + 'last_gen': 1}}) + self.assertTransactionLog([doc_id, doc_id], self.db1) + self.assertGetDoc(self.db1, doc_id, doc2_rev, new_doc, True) + self.assertGetDoc(self.db2, doc_id, doc2_rev, new_doc, False) + from_idx = self.db1.get_from_index('test-idx', 'altval')[0] + self.assertEqual(doc2.doc_id, from_idx.doc_id) + self.assertEqual(doc2.rev, from_idx.rev) + self.assertTrue(from_idx.has_conflicts) + self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) + + def test_sync_sees_remote_delete_conflicted(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc1 = self.db1.create_doc_from_json(simple_doc) + doc_id = doc1.doc_id + self.db1.create_index('test-idx', 'key') + self.sync(self.db1, self.db2) + doc2 = self.make_document(doc1.doc_id, doc1.rev, doc1.get_json()) + new_doc = '{"key": "altval"}' + doc1.set_json(new_doc) + self.db1.put_doc(doc1) + self.db2.delete_doc(doc2) + self.assertTransactionLog([doc_id, doc_id], self.db1) + self.sync(self.db1, self.db2) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [(doc_id, doc1.rev)], + 'source_uid': 'test1', + 'source_gen': 2, 'last_known_gen': 1}, + 'return': {'docs': [(doc_id, doc2.rev)], + 'last_gen': 2}}) + self.assertTransactionLog([doc_id, doc_id, doc_id], self.db1) + self.assertGetDocIncludeDeleted(self.db1, doc_id, doc2.rev, None, True) + self.assertGetDocIncludeDeleted( + self.db2, doc_id, doc2.rev, None, False) + self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) + + def test_sync_local_race_conflicted(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + doc = self.db1.create_doc_from_json(simple_doc) + doc_id = doc.doc_id + doc1_rev = doc.rev + self.db1.create_index('test-idx', 'key') + self.sync(self.db1, self.db2) + content1 = '{"key": "localval"}' + content2 = '{"key": "altval"}' + doc.set_json(content2) + self.db2.put_doc(doc) + doc2_rev2 = doc.rev + triggered = [] + + def after_whatschanged(state): + if state != 'after whats_changed': + return + triggered.append(True) + doc = self.make_document(doc_id, doc1_rev, content1) + self.db1.put_doc(doc) + + self.sync(self.db1, self.db2, trace_hook=after_whatschanged) + self.assertEqual([True], triggered) + self.assertGetDoc(self.db1, doc_id, doc2_rev2, content2, True) + from_idx = self.db1.get_from_index('test-idx', 'altval')[0] + self.assertEqual(doc.doc_id, from_idx.doc_id) + self.assertEqual(doc.rev, from_idx.rev) + self.assertTrue(from_idx.has_conflicts) + self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) + self.assertEqual([], self.db1.get_from_index('test-idx', 'localval')) + + def test_sync_propagates_deletes(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'both') + doc1 = self.db1.create_doc_from_json(simple_doc) + doc_id = doc1.doc_id + self.db1.create_index('test-idx', 'key') + self.sync(self.db1, self.db2) + self.db2.create_index('test-idx', 'key') + self.db3 = self.create_database('test3', 'target') + self.sync(self.db1, self.db3) + self.db1.delete_doc(doc1) + deleted_rev = doc1.rev + self.sync(self.db1, self.db2) + self.assertLastExchangeLog(self.db2, + {'receive': {'docs': [(doc_id, deleted_rev)], + 'source_uid': 'test1', + 'source_gen': 2, 'last_known_gen': 1}, + 'return': {'docs': [], 'last_gen': 2}}) + self.assertGetDocIncludeDeleted( + self.db1, doc_id, deleted_rev, None, False) + self.assertGetDocIncludeDeleted( + self.db2, doc_id, deleted_rev, None, False) + self.assertEqual([], self.db1.get_from_index('test-idx', 'value')) + self.assertEqual([], self.db2.get_from_index('test-idx', 'value')) + self.sync(self.db2, self.db3) + self.assertLastExchangeLog(self.db3, + {'receive': {'docs': [(doc_id, deleted_rev)], + 'source_uid': 'test2', + 'source_gen': 2, 'last_known_gen': 0}, + 'return': {'docs': [], 'last_gen': 2}}) + self.assertGetDocIncludeDeleted( + self.db3, doc_id, deleted_rev, None, False) + + def test_sync_propagates_resolution(self): + self.db1 = self.create_database('test1', 'both') + self.db2 = self.create_database('test2', 'both') + doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc') + db3 = self.create_database('test3', 'both') + self.sync(self.db2, self.db1) + self.assertEqual( + self.db1._get_generation_info(), + self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid)) + self.assertEqual( + self.db2._get_generation_info(), + self.db1._get_replica_gen_and_trans_id(self.db2._replica_uid)) + self.sync(db3, self.db1) + # update on 2 + doc2 = self.make_document('the-doc', doc1.rev, '{"a": 2}') + self.db2.put_doc(doc2) + self.sync(self.db2, db3) + self.assertEqual(db3.get_doc('the-doc').rev, doc2.rev) + # update on 1 + doc1.set_json('{"a": 3}') + self.db1.put_doc(doc1) + # conflicts + self.sync(self.db2, self.db1) + self.sync(db3, self.db1) + self.assertTrue(self.db2.get_doc('the-doc').has_conflicts) + self.assertTrue(db3.get_doc('the-doc').has_conflicts) + # resolve + conflicts = self.db2.get_doc_conflicts('the-doc') + doc4 = self.make_document('the-doc', None, '{"a": 4}') + revs = [doc.rev for doc in conflicts] + self.db2.resolve_doc(doc4, revs) + doc2 = self.db2.get_doc('the-doc') + self.assertEqual(doc4.get_json(), doc2.get_json()) + self.assertFalse(doc2.has_conflicts) + self.sync(self.db2, db3) + doc3 = db3.get_doc('the-doc') + self.assertEqual(doc4.get_json(), doc3.get_json()) + self.assertFalse(doc3.has_conflicts) + + def test_sync_supersedes_conflicts(self): + self.db1 = self.create_database('test1', 'both') + self.db2 = self.create_database('test2', 'target') + db3 = self.create_database('test3', 'both') + doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc') + self.db2.create_doc_from_json('{"b": 1}', doc_id='the-doc') + db3.create_doc_from_json('{"c": 1}', doc_id='the-doc') + self.sync(db3, self.db1) + self.assertEqual( + self.db1._get_generation_info(), + db3._get_replica_gen_and_trans_id(self.db1._replica_uid)) + self.assertEqual( + db3._get_generation_info(), + self.db1._get_replica_gen_and_trans_id(db3._replica_uid)) + self.sync(db3, self.db2) + self.assertEqual( + self.db2._get_generation_info(), + db3._get_replica_gen_and_trans_id(self.db2._replica_uid)) + self.assertEqual( + db3._get_generation_info(), + self.db2._get_replica_gen_and_trans_id(db3._replica_uid)) + self.assertEqual(3, len(db3.get_doc_conflicts('the-doc'))) + doc1.set_json('{"a": 2}') + self.db1.put_doc(doc1) + self.sync(db3, self.db1) + # original doc1 should have been removed from conflicts + self.assertEqual(3, len(db3.get_doc_conflicts('the-doc'))) + + def test_sync_stops_after_get_sync_info(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(tests.simple_doc) + self.sync(self.db1, self.db2) + + def put_hook(state): + self.fail("Tracehook triggered for %s" % (state,)) + + self.sync(self.db1, self.db2, trace_hook_shallow=put_hook) + + def test_sync_detects_rollback_in_source(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1') + self.sync(self.db1, self.db2) + db1_copy = self.copy_database(self.db1) + self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2') + self.sync(self.db1, self.db2) + self.assertRaises( + errors.InvalidGeneration, self.sync, db1_copy, self.db2) + + def test_sync_detects_rollback_in_target(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") + self.sync(self.db1, self.db2) + db2_copy = self.copy_database(self.db2) + self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2') + self.sync(self.db1, self.db2) + self.assertRaises( + errors.InvalidGeneration, self.sync, self.db1, db2_copy) + + def test_sync_detects_diverged_source(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + db3 = self.copy_database(self.db1) + self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") + db3.create_doc_from_json(tests.simple_doc, doc_id="divergent") + self.sync(self.db1, self.db2) + self.assertRaises( + errors.InvalidTransactionId, self.sync, db3, self.db2) + + def test_sync_detects_diverged_target(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + db3 = self.copy_database(self.db2) + db3.create_doc_from_json(tests.nested_doc, doc_id="divergent") + self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") + self.sync(self.db1, self.db2) + self.assertRaises( + errors.InvalidTransactionId, self.sync, self.db1, db3) + + def test_sync_detects_rollback_and_divergence_in_source(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1') + self.sync(self.db1, self.db2) + db1_copy = self.copy_database(self.db1) + self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2') + self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc3') + self.sync(self.db1, self.db2) + db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2') + db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3') + self.assertRaises( + errors.InvalidTransactionId, self.sync, db1_copy, self.db2) + + def test_sync_detects_rollback_and_divergence_in_target(self): + self.db1 = self.create_database('test1', 'source') + self.db2 = self.create_database('test2', 'target') + self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent") + self.sync(self.db1, self.db2) + db2_copy = self.copy_database(self.db2) + self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2') + self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc3') + self.sync(self.db1, self.db2) + db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2') + db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3') + self.assertRaises( + errors.InvalidTransactionId, self.sync, self.db1, db2_copy) + + +class TestDbSync(tests.TestCaseWithServer): + """Test db.sync remote sync shortcut""" + + scenarios = [ + ('py-http', { + 'make_app_with_state': make_http_app, + 'make_database_for_test': tests.make_memory_database_for_test, + }), + ('c-http', { + 'make_app_with_state': make_http_app, + 'make_database_for_test': tests.make_c_database_for_test + }), + ('py-oauth-http', { + 'make_app_with_state': make_oauth_http_app, + 'make_database_for_test': tests.make_memory_database_for_test, + 'oauth': True + }), + ('c-oauth-http', { + 'make_app_with_state': make_oauth_http_app, + 'make_database_for_test': tests.make_c_database_for_test, + 'oauth': True + }), + ] + + oauth = False + + def do_sync(self, target_name): + if self.oauth: + path = '~/' + target_name + extra = dict(creds={'oauth': { + 'consumer_key': tests.consumer1.key, + 'consumer_secret': tests.consumer1.secret, + 'token_key': tests.token1.key, + 'token_secret': tests.token1.secret + }}) + else: + path = target_name + extra = {} + target_url = self.getURL(path) + return self.db.sync(target_url, **extra) + + def setUp(self): + super(TestDbSync, self).setUp() + self.startServer() + self.db = self.make_database_for_test(self, 'test1') + self.db2 = self.request_state._create_database('test2.db') + + def test_db_sync(self): + doc1 = self.db.create_doc_from_json(tests.simple_doc) + doc2 = self.db2.create_doc_from_json(tests.nested_doc) + local_gen_before_sync = self.do_sync('test2.db') + gen, _, changes = self.db.whats_changed(local_gen_before_sync) + self.assertEqual(1, len(changes)) + self.assertEqual(doc2.doc_id, changes[0][0]) + self.assertEqual(1, gen - local_gen_before_sync) + self.assertGetDoc(self.db2, doc1.doc_id, doc1.rev, tests.simple_doc, + False) + self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc, + False) + + def test_db_sync_autocreate(self): + doc1 = self.db.create_doc_from_json(tests.simple_doc) + local_gen_before_sync = self.do_sync('test3.db') + gen, _, changes = self.db.whats_changed(local_gen_before_sync) + self.assertEqual(0, gen - local_gen_before_sync) + db3 = self.request_state.open_database('test3.db') + gen, _, changes = db3.whats_changed() + self.assertEqual(1, len(changes)) + self.assertEqual(doc1.doc_id, changes[0][0]) + self.assertGetDoc(db3, doc1.doc_id, doc1.rev, tests.simple_doc, + False) + t_gen, _ = self.db._get_replica_gen_and_trans_id('test3.db') + s_gen, _ = db3._get_replica_gen_and_trans_id('test1') + self.assertEqual(1, t_gen) + self.assertEqual(1, s_gen) + + +class TestRemoteSyncIntegration(tests.TestCaseWithServer): + """Integration tests for the most common sync scenario local -> remote""" + + make_app_with_state = staticmethod(make_http_app) + + def setUp(self): + super(TestRemoteSyncIntegration, self).setUp() + self.startServer() + self.db1 = inmemory.InMemoryDatabase('test1') + self.db2 = self.request_state._create_database('test2') + + def test_sync_tracks_generations_incrementally(self): + doc11 = self.db1.create_doc_from_json('{"a": 1}') + doc12 = self.db1.create_doc_from_json('{"a": 2}') + doc21 = self.db2.create_doc_from_json('{"b": 1}') + doc22 = self.db2.create_doc_from_json('{"b": 2}') + #sanity + self.assertEqual(2, len(self.db1._get_transaction_log())) + self.assertEqual(2, len(self.db2._get_transaction_log())) + progress1 = [] + progress2 = [] + _do_set_replica_gen_and_trans_id = \ + self.db1._do_set_replica_gen_and_trans_id + + def set_sync_generation_witness1(other_uid, other_gen, trans_id): + progress1.append((other_uid, other_gen, + [d for d, t in self.db1._get_transaction_log()[2:]])) + _do_set_replica_gen_and_trans_id(other_uid, other_gen, trans_id) + self.patch(self.db1, '_do_set_replica_gen_and_trans_id', + set_sync_generation_witness1) + _do_set_replica_gen_and_trans_id2 = \ + self.db2._do_set_replica_gen_and_trans_id + + def set_sync_generation_witness2(other_uid, other_gen, trans_id): + progress2.append((other_uid, other_gen, + [d for d, t in self.db2._get_transaction_log()[2:]])) + _do_set_replica_gen_and_trans_id2(other_uid, other_gen, trans_id) + self.patch(self.db2, '_do_set_replica_gen_and_trans_id', + set_sync_generation_witness2) + + db2_url = self.getURL('test2') + self.db1.sync(db2_url) + + self.assertEqual([('test2', 1, [doc21.doc_id]), + ('test2', 2, [doc21.doc_id, doc22.doc_id]), + ('test2', 4, [doc21.doc_id, doc22.doc_id])], + progress1) + self.assertEqual([('test1', 1, [doc11.doc_id]), + ('test1', 2, [doc11.doc_id, doc12.doc_id]), + ('test1', 4, [doc11.doc_id, doc12.doc_id])], + progress2) + + +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/test_test_infrastructure.py b/src/leap/soledad/u1db/tests/test_test_infrastructure.py new file mode 100644 index 00000000..b79e0516 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_test_infrastructure.py @@ -0,0 +1,41 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""Tests for test infrastructure bits""" + +from wsgiref import simple_server + +from u1db import ( + tests, + ) + + +class TestTestCaseWithServer(tests.TestCaseWithServer): + + def make_app(self): + return "app" + + @staticmethod + def server_def(): + def make_server(host_port, application): + assert application == "app" + return simple_server.WSGIServer(host_port, None) + return (make_server, "shutdown", "http") + + def test_getURL(self): + self.startServer() + url = self.getURL() + self.assertTrue(url.startswith('http://127.0.0.1:')) diff --git a/src/leap/soledad/u1db/tests/test_vectorclock.py b/src/leap/soledad/u1db/tests/test_vectorclock.py new file mode 100644 index 00000000..72baf246 --- /dev/null +++ b/src/leap/soledad/u1db/tests/test_vectorclock.py @@ -0,0 +1,121 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""VectorClockRev helper class tests.""" + +from u1db import tests, vectorclock + +try: + from u1db.tests import c_backend_wrapper +except ImportError: + c_backend_wrapper = None + + +c_vectorclock_scenarios = [] +if c_backend_wrapper is not None: + c_vectorclock_scenarios.append( + ('c', {'create_vcr': c_backend_wrapper.VectorClockRev})) + + +class TestVectorClockRev(tests.TestCase): + + scenarios = [('py', {'create_vcr': vectorclock.VectorClockRev}) + ] + c_vectorclock_scenarios + + def assertIsNewer(self, newer_rev, older_rev): + new_vcr = self.create_vcr(newer_rev) + old_vcr = self.create_vcr(older_rev) + self.assertTrue(new_vcr.is_newer(old_vcr)) + self.assertFalse(old_vcr.is_newer(new_vcr)) + + def assertIsConflicted(self, rev_a, rev_b): + vcr_a = self.create_vcr(rev_a) + vcr_b = self.create_vcr(rev_b) + self.assertFalse(vcr_a.is_newer(vcr_b)) + self.assertFalse(vcr_b.is_newer(vcr_a)) + + def assertRoundTrips(self, rev): + self.assertEqual(rev, self.create_vcr(rev).as_str()) + + def test__is_newer_doc_rev(self): + self.assertIsNewer('test:1', None) + self.assertIsNewer('test:2', 'test:1') + self.assertIsNewer('other:2|test:1', 'other:1|test:1') + self.assertIsNewer('other:1|test:1', 'other:1') + self.assertIsNewer('a:2|b:1', 'b:1') + self.assertIsNewer('a:1|b:2', 'a:1') + self.assertIsConflicted('other:2|test:1', 'other:1|test:2') + self.assertIsConflicted('other:1|test:1', 'other:2') + self.assertIsConflicted('test:1', 'test:1') + + def test_None(self): + vcr = self.create_vcr(None) + self.assertEqual('', vcr.as_str()) + + def test_round_trips(self): + self.assertRoundTrips('test:1') + self.assertRoundTrips('a:1|b:2') + self.assertRoundTrips('alternate:2|test:1') + + def test_handles_sort_order(self): + self.assertEqual('a:1|b:2', self.create_vcr('b:2|a:1').as_str()) + # Last one out of place + self.assertEqual('a:1|b:2|c:3|d:4|e:5|f:6', + self.create_vcr('f:6|a:1|b:2|c:3|d:4|e:5').as_str()) + # Fully reversed + self.assertEqual('a:1|b:2|c:3|d:4|e:5|f:6', + self.create_vcr('f:6|e:5|d:4|c:3|b:2|a:1').as_str()) + + def assertIncrement(self, original, replica_uid, after_increment): + vcr = self.create_vcr(original) + vcr.increment(replica_uid) + self.assertEqual(after_increment, vcr.as_str()) + + def test_increment(self): + self.assertIncrement(None, 'test', 'test:1') + self.assertIncrement('test:1', 'test', 'test:2') + + def test_increment_adds_uid(self): + self.assertIncrement('other:1', 'test', 'other:1|test:1') + self.assertIncrement('a:1|ab:2', 'aa', 'a:1|aa:1|ab:2') + + def test_increment_update_partial(self): + self.assertIncrement('a:1|ab:2', 'a', 'a:2|ab:2') + self.assertIncrement('a:2|ab:2', 'ab', 'a:2|ab:3') + + def test_increment_appends_uid(self): + self.assertIncrement('b:2', 'c', 'b:2|c:1') + + def assertMaximize(self, rev1, rev2, maximized): + vcr1 = self.create_vcr(rev1) + vcr2 = self.create_vcr(rev2) + vcr1.maximize(vcr2) + self.assertEqual(maximized, vcr1.as_str()) + # reset vcr1 to maximize the other way + vcr1 = self.create_vcr(rev1) + vcr2.maximize(vcr1) + self.assertEqual(maximized, vcr2.as_str()) + + def test_maximize(self): + self.assertMaximize(None, None, '') + self.assertMaximize(None, 'x:1', 'x:1') + self.assertMaximize('x:1', 'y:1', 'x:1|y:1') + self.assertMaximize('x:2', 'x:1', 'x:2') + self.assertMaximize('x:2', 'x:1|y:2', 'x:2|y:2') + self.assertMaximize('a:1|c:2|e:3', 'b:3|d:4|f:5', + 'a:1|b:3|c:2|d:4|e:3|f:5') + +load_tests = tests.load_with_scenarios diff --git a/src/leap/soledad/u1db/tests/testing-certs/Makefile b/src/leap/soledad/u1db/tests/testing-certs/Makefile new file mode 100644 index 00000000..2385e75b --- /dev/null +++ b/src/leap/soledad/u1db/tests/testing-certs/Makefile @@ -0,0 +1,35 @@ +CATOP=./demoCA +ORIG_CONF=/usr/lib/ssl/openssl.cnf +ELEVEN_YEARS=-days 4015 + +init: + cp $(ORIG_CONF) ca.conf + install -d $(CATOP) + install -d $(CATOP)/certs + install -d $(CATOP)/crl + install -d $(CATOP)/newcerts + install -d $(CATOP)/private + touch $(CATOP)/index.txt + echo 01>$(CATOP)/crlnumber + @echo '**** Making CA certificate ...' + openssl req -nodes -new \ + -newkey rsa -keyout $(CATOP)/private/cakey.pem \ + -out $(CATOP)/careq.pem \ + -multivalue-rdn \ + -subj "/C=UK/ST=-/O=u1db LOCAL TESTING ONLY, DO NO TRUST/CN=u1db testing CA" + openssl ca -config ./ca.conf -create_serial \ + -out $(CATOP)/cacert.pem $(ELEVEN_YEARS) -batch \ + -keyfile $(CATOP)/private/cakey.pem -selfsign \ + -extensions v3_ca -infiles $(CATOP)/careq.pem + +pems: + cp ./demoCA/cacert.pem . + openssl req -new -config ca.conf \ + -multivalue-rdn \ + -subj "/O=u1db LOCAL TESTING ONLY, DO NOT TRUST/CN=localhost" \ + -nodes -keyout testing.key -out newreq.pem $(ELEVEN_YEARS) + openssl ca -batch -config ./ca.conf $(ELEVEN_YEARS) \ + -policy policy_anything \ + -out testing.cert -infiles newreq.pem + +.PHONY: init pems diff --git a/src/leap/soledad/u1db/tests/testing-certs/cacert.pem b/src/leap/soledad/u1db/tests/testing-certs/cacert.pem new file mode 100644 index 00000000..c019a730 --- /dev/null +++ b/src/leap/soledad/u1db/tests/testing-certs/cacert.pem @@ -0,0 +1,58 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + e4:de:01:76:c4:78:78:7e + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA + Validity + Not Before: May 3 11:11:11 2012 GMT + Not After : May 1 11:11:11 2023 GMT + Subject: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (1024 bit) + Modulus: + 00:bc:91:a5:7f:7d:37:f7:06:c7:db:5b:83:6a:6b: + 63:c3:8b:5c:f7:84:4d:97:6d:d4:be:bf:e7:79:a8: + c1:03:57:ec:90:d4:20:e7:02:95:d9:a6:49:e3:f9: + 9a:ea:37:b9:b2:02:62:ab:40:d3:42:bb:4a:4e:a2: + 47:71:0f:1d:a2:c5:94:a1:cf:35:d3:23:32:42:c0: + 1e:8d:cb:08:58:fb:8a:5c:3e:ea:eb:d5:2c:ed:d6: + aa:09:b4:b5:7d:e3:45:c9:ae:c2:82:b2:ae:c0:81: + bc:24:06:65:a9:e7:e0:61:ac:25:ee:53:d3:d7:be: + 22:f7:00:a2:ad:c6:0e:3a:39 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D + X509v3 Authority Key Identifier: + keyid:DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D + + X509v3 Basic Constraints: + CA:TRUE + Signature Algorithm: sha1WithRSAEncryption + 72:9b:c1:f7:07:65:83:36:25:4e:01:2f:b7:4a:f2:a4:00:28: + 80:c7:56:2c:32:39:90:13:61:4b:bb:12:c5:44:9d:42:57:85: + 28:19:70:69:e1:43:c8:bd:11:f6:94:df:91:2d:c3:ea:82:8d: + b4:8f:5d:47:a3:00:99:53:29:93:27:6c:c5:da:c1:20:6f:ab: + ec:4a:be:34:f3:8f:02:e5:0c:c0:03:ac:2b:33:41:71:4f:0a: + 72:5a:b4:26:1a:7f:81:bc:c0:95:8a:06:87:a8:11:9f:5c:73: + 38:df:5a:69:40:21:29:ad:46:23:56:75:e1:e9:8b:10:18:4c: + 7b:54 +-----BEGIN CERTIFICATE----- +MIICkjCCAfugAwIBAgIJAOTeAXbEeHh+MA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV +BAYTAlVLMQowCAYDVQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcg +T05MWSwgRE8gTk8gVFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTAeFw0x +MjA1MDMxMTExMTFaFw0yMzA1MDExMTExMTFaMGIxCzAJBgNVBAYTAlVLMQowCAYD +VQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcgT05MWSwgRE8gTk8g +VFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTCBnzANBgkqhkiG9w0BAQEF +AAOBjQAwgYkCgYEAvJGlf3039wbH21uDamtjw4tc94RNl23Uvr/neajBA1fskNQg +5wKV2aZJ4/ma6je5sgJiq0DTQrtKTqJHcQ8dosWUoc810yMyQsAejcsIWPuKXD7q +69Us7daqCbS1feNFya7CgrKuwIG8JAZlqefgYawl7lPT174i9wCircYOOjkCAwEA +AaNQME4wHQYDVR0OBBYEFNs9k1FsMhVUjxBQ/ElPNhUou5VtMB8GA1UdIwQYMBaA +FNs9k1FsMhVUjxBQ/ElPNhUou5VtMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF +BQADgYEAcpvB9wdlgzYlTgEvt0rypAAogMdWLDI5kBNhS7sSxUSdQleFKBlwaeFD +yL0R9pTfkS3D6oKNtI9dR6MAmVMpkydsxdrBIG+r7Eq+NPOPAuUMwAOsKzNBcU8K +clq0Jhp/gbzAlYoGh6gRn1xzON9aaUAhKa1GI1Z14emLEBhMe1Q= +-----END CERTIFICATE----- diff --git a/src/leap/soledad/u1db/tests/testing-certs/testing.cert b/src/leap/soledad/u1db/tests/testing-certs/testing.cert new file mode 100644 index 00000000..985684fb --- /dev/null +++ b/src/leap/soledad/u1db/tests/testing-certs/testing.cert @@ -0,0 +1,61 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + e4:de:01:76:c4:78:78:7f + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA + Validity + Not Before: May 3 11:11:14 2012 GMT + Not After : May 1 11:11:14 2023 GMT + Subject: O=u1db LOCAL TESTING ONLY, DO NOT TRUST, CN=localhost + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (1024 bit) + Modulus: + 00:c6:1d:72:d3:c5:e4:fc:d1:4c:d9:e4:08:3e:90: + 10:ce:3f:1f:87:4a:1d:4f:7f:2a:5a:52:c9:65:4f: + d9:2c:bf:69:75:18:1a:b5:c9:09:32:00:47:f5:60: + aa:c6:dd:3a:87:37:5f:16:be:de:29:b5:ea:fc:41: + 7e:eb:77:bb:df:63:c3:06:1e:ed:e9:a0:67:1a:f1: + ec:e1:9d:f7:9c:8f:1c:fa:c3:66:7b:39:dc:70:ae: + 09:1b:9c:c0:9a:c4:90:77:45:8e:39:95:a9:2f:92: + 43:bd:27:07:5a:99:51:6e:76:a0:af:dd:b1:2c:8f: + ca:8b:8c:47:0d:f6:6e:fc:69 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + Netscape Comment: + OpenSSL Generated Certificate + X509v3 Subject Key Identifier: + 1C:63:85:E1:1D:F3:89:2E:6C:4E:3F:FB:D0:10:64:5A:C1:22:6A:2A + X509v3 Authority Key Identifier: + keyid:DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D + + Signature Algorithm: sha1WithRSAEncryption + 1d:6d:3e:bd:93:fd:bd:3e:17:b8:9f:f0:99:7f:db:50:5c:b2: + 01:42:03:b5:d5:94:05:d3:f6:8e:80:82:55:47:1f:58:f2:18: + 6c:ab:ef:43:2c:2f:10:e1:7c:c4:5c:cc:ac:50:50:22:42:aa: + 35:33:f5:b9:f3:a6:66:55:d9:36:f4:f2:e4:d4:d9:b5:2c:52: + 66:d4:21:17:97:22:b8:9b:d7:0e:7c:3d:ce:85:19:ca:c4:d2: + 58:62:31:c6:18:3e:44:fc:f4:30:b6:95:87:ee:21:4a:08:f0: + af:3c:8f:c4:ba:5e:a1:5c:37:1a:7d:7b:fe:66:ae:62:50:17: + 31:ca +-----BEGIN CERTIFICATE----- +MIICnzCCAgigAwIBAgIJAOTeAXbEeHh/MA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV +BAYTAlVLMQowCAYDVQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcg +T05MWSwgRE8gTk8gVFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTAeFw0x +MjA1MDMxMTExMTRaFw0yMzA1MDExMTExMTRaMEQxLjAsBgNVBAoMJXUxZGIgTE9D +QUwgVEVTVElORyBPTkxZLCBETyBOT1QgVFJVU1QxEjAQBgNVBAMMCWxvY2FsaG9z +dDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAxh1y08Xk/NFM2eQIPpAQzj8f +h0odT38qWlLJZU/ZLL9pdRgatckJMgBH9WCqxt06hzdfFr7eKbXq/EF+63e732PD +Bh7t6aBnGvHs4Z33nI8c+sNmeznccK4JG5zAmsSQd0WOOZWpL5JDvScHWplRbnag +r92xLI/Ki4xHDfZu/GkCAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0E +HxYdT3BlblNTTCBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFBxjheEd +84kubE4/+9AQZFrBImoqMB8GA1UdIwQYMBaAFNs9k1FsMhVUjxBQ/ElPNhUou5Vt +MA0GCSqGSIb3DQEBBQUAA4GBAB1tPr2T/b0+F7if8Jl/21BcsgFCA7XVlAXT9o6A +glVHH1jyGGyr70MsLxDhfMRczKxQUCJCqjUz9bnzpmZV2Tb08uTU2bUsUmbUIReX +Irib1w58Pc6FGcrE0lhiMcYYPkT89DC2lYfuIUoI8K88j8S6XqFcNxp9e/5mrmJQ +FzHK +-----END CERTIFICATE----- diff --git a/src/leap/soledad/u1db/tests/testing-certs/testing.key b/src/leap/soledad/u1db/tests/testing-certs/testing.key new file mode 100644 index 00000000..d83d4920 --- /dev/null +++ b/src/leap/soledad/u1db/tests/testing-certs/testing.key @@ -0,0 +1,16 @@ +-----BEGIN PRIVATE KEY----- +MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMYdctPF5PzRTNnk +CD6QEM4/H4dKHU9/KlpSyWVP2Sy/aXUYGrXJCTIAR/VgqsbdOoc3Xxa+3im16vxB +fut3u99jwwYe7emgZxrx7OGd95yPHPrDZns53HCuCRucwJrEkHdFjjmVqS+SQ70n +B1qZUW52oK/dsSyPyouMRw32bvxpAgMBAAECgYBs3lXxhjg1rhabTjIxnx19GTcM +M3Az9V+izweZQu3HJ1CeZiaXauhAr+LbNsniCkRVddotN6oCJdQB10QVxXBZc9Jz +HPJ4zxtZfRZlNMTMmG7eLWrfxpgWnb/BUjDb40yy1nhr9yhDUnI/8RoHDRHnAEHZ +/CnHGUrqcVcrY5zJAQJBAPLhBJg9W88JVmcOKdWxRgs7dLHnZb999Kv1V5mczmAi +jvGvbUmucqOqke6pTUHNYyNHqU6pySzGUi2cH+BAkFECQQDQ0VoAOysg6FVoT15v +tGh57t5sTiCZZ7PS8jwvtThsgA+vcf6c16XWzXgjGXSap4r2QDOY2rI5lsWLaQ8T ++fyZAkAfyFJRmbXp4c7srW3MCOahkaYzoZQu+syJtBFCiMJ40gzik5I5khpuUGPI +V19EvRu8AiSlppIsycb3MPb64XgBAkEAy7DrUf5le5wmc7G4NM6OeyJ+5LbxJbL6 +vnJ8My1a9LuWkVVpQCU7J+UVo2dZTuLPspW9vwTVhUeFOxAoHRxlQQJAFem93f7m +el2BkB2EFqU3onPejkZ5UrDmfmeOQR1axMQNSXqSxcJxqa16Ru1BWV2gcWRbwajQ +oc+kuJThu/r/Ug== +-----END PRIVATE KEY----- diff --git a/src/leap/soledad/u1db/vectorclock.py b/src/leap/soledad/u1db/vectorclock.py new file mode 100644 index 00000000..42bceaa8 --- /dev/null +++ b/src/leap/soledad/u1db/vectorclock.py @@ -0,0 +1,89 @@ +# Copyright 2011 Canonical Ltd. +# +# This file is part of u1db. +# +# u1db is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# u1db is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with u1db. If not, see <http://www.gnu.org/licenses/>. + +"""VectorClockRev helper class.""" + + +class VectorClockRev(object): + """Track vector clocks for multiple replica ids. + + This allows simple comparison to determine if one VectorClockRev is + newer/older/in-conflict-with another VectorClockRev without having to + examine history. Every replica has a strictly increasing revision. When + creating a new revision, they include all revisions for all other replicas + which the new revision dominates, and increment their own revision to + something greater than the current value. + """ + + def __init__(self, value): + self._values = self._expand(value) + + def __repr__(self): + s = self.as_str() + return '%s(%s)' % (self.__class__.__name__, s) + + def as_str(self): + s = '|'.join(['%s:%d' % (m, r) for m, r + in sorted(self._values.items())]) + return s + + def _expand(self, value): + result = {} + if value is None: + return result + for replica_info in value.split('|'): + replica_uid, counter = replica_info.split(':') + counter = int(counter) + result[replica_uid] = counter + return result + + def is_newer(self, other): + """Is this VectorClockRev strictly newer than other. + """ + if not self._values: + return False + if not other._values: + return True + this_is_newer = False + other_expand = dict(other._values) + for key, value in self._values.iteritems(): + if key in other_expand: + other_value = other_expand.pop(key) + if other_value > value: + return False + elif other_value < value: + this_is_newer = True + else: + this_is_newer = True + if other_expand: + return False + return this_is_newer + + def increment(self, replica_uid): + """Increase the 'replica_uid' section of this vector clock. + + :return: A string representing the new vector clock value + """ + self._values[replica_uid] = self._values.get(replica_uid, 0) + 1 + + def maximize(self, other_vcr): + for replica_uid, counter in other_vcr._values.iteritems(): + if replica_uid not in self._values: + self._values[replica_uid] = counter + else: + this_counter = self._values[replica_uid] + if this_counter < counter: + self._values[replica_uid] = counter diff --git a/src/leap/testing/__init__.py b/src/leap/testing/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/testing/__init__.py diff --git a/src/leap/testing/basetest.py b/src/leap/testing/basetest.py new file mode 100644 index 00000000..3186e1eb --- /dev/null +++ b/src/leap/testing/basetest.py @@ -0,0 +1,85 @@ +import os +import platform +import shutil +import tempfile + +try: + import unittest2 as unittest +except ImportError: + import unittest + +from leap.base.config import get_username, get_groupname +from leap.util.fileutil import mkdir_p, check_and_fix_urw_only + +_system = platform.system() + + +class BaseLeapTest(unittest.TestCase): + + __name__ = "leap_test" + + @classmethod + def setUpClass(cls): + cls.old_path = os.environ['PATH'] + cls.old_home = os.environ['HOME'] + cls.tempdir = tempfile.mkdtemp(prefix="leap_tests-") + cls.home = cls.tempdir + bin_tdir = os.path.join( + cls.tempdir, + 'bin') + os.environ["PATH"] = bin_tdir + os.environ["HOME"] = cls.tempdir + + @classmethod + def tearDownClass(cls): + os.environ["PATH"] = cls.old_path + os.environ["HOME"] = cls.old_home + # safety check + assert cls.tempdir.startswith('/tmp/leap_tests-') + shutil.rmtree(cls.tempdir) + + # you have to override these methods + # this way we ensure we did not put anything + # here that you can forget to call. + + def setUp(self): + raise NotImplementedError("abstract base class") + + def tearDown(self): + raise NotImplementedError("abstract base class") + + # + # helper methods + # + + def get_tempfile(self, filename): + return os.path.join(self.tempdir, filename) + + def get_username(self): + return get_username() + + def get_groupname(self): + return get_groupname() + + def _missing_test_for_plat(self, do_raise=False): + if do_raise: + raise NotImplementedError( + "This test is not implemented " + "for the running platform: %s" % + _system) + + def touch(self, filepath): + folder, filename = os.path.split(filepath) + if not os.path.isdir(folder): + mkdir_p(folder) + # XXX should move to test_basetest + self.assertTrue(os.path.isdir(folder)) + + with open(filepath, 'w') as fp: + fp.write(' ') + + # XXX should move to test_basetest + self.assertTrue(os.path.isfile(filepath)) + + def chmod600(self, filepath): + check_and_fix_urw_only(filepath) diff --git a/src/leap/testing/cacert.pem b/src/leap/testing/cacert.pem new file mode 100644 index 00000000..6989c480 --- /dev/null +++ b/src/leap/testing/cacert.pem @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID1TCCAr2gAwIBAgIJAOv0BS09D8byMA0GCSqGSIb3DQEBBQUAMIGAMQswCQYD +VQQGEwJVUzETMBEGA1UECAwKY3liZXJzcGFjZTEnMCUGA1UECgweTEVBUCBFbmNy +eXB0aW9uIEFjY2VzcyBQcm9qZWN0MRYwFAYDVQQDDA10ZXN0cy1sZWFwLnNlMRsw +GQYJKoZIhvcNAQkBFgxpbmZvQGxlYXAuc2UwHhcNMTIwODMxMTYyNjMwWhcNMTUw +ODMxMTYyNjMwWjCBgDELMAkGA1UEBhMCVVMxEzARBgNVBAgMCmN5YmVyc3BhY2Ux +JzAlBgNVBAoMHkxFQVAgRW5jcnlwdGlvbiBBY2Nlc3MgUHJvamVjdDEWMBQGA1UE +AwwNdGVzdHMtbGVhcC5zZTEbMBkGCSqGSIb3DQEJARYMaW5mb0BsZWFwLnNlMIIB +IjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1pU7OU+abrUXFZwp6X0LlF0f +xQvC1Nmr5sFH7N9RTu3bdwY2t57ECP2TPkH6+x7oOvCTgAMxIE1scWEEkfgKViqW +FH/Om1UW1PMaiDYGtFuqEuxM95FvaYxp2K6rzA37WNsedA28sCYzhRD+/5HqbCNT +3rRS2cPaVO8kXI/5bgd8bUk3009pWTg4SvTtOW/9MWJbBH5f5JWmMn7Ayt6hIdT/ +E6npofEK/UCqAlEscARYFXSB/F8nK1whjo9mGFjMUd7d/25UbFHqOk4K7ishD4DH +F7LaS84rS+Sjwn3YtDdDQblGghJfz8X1AfPSGivGnvLVdkmMF9Y2hJlSQ7+C5wID +AQABo1AwTjAdBgNVHQ4EFgQUnpJEv4FnlqKbfm7mprudKdrnOAowHwYDVR0jBBgw +FoAUnpJEv4FnlqKbfm7mprudKdrnOAowDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0B +AQUFAAOCAQEAGW66qwdK/ATRVZkTpI2sgi+2dWD5tY4VyZuJIrRwfXsGPeVvmdsa +zDmwW5dMkth1Of5yO6o7ijvUvfnw/UCLNLNICKZhH5G0DHstfBeFc0jnP2MqOZCp +puRGPBlO2nxUCvoGcPRUKGQK9XSYmxcmaSFyzKVDMLnmH+Lakj5vaY9a8ZAcZTz7 +T5qePxKAxg+RIlH8Ftc485QP3fhqPYPrRsL3g6peiqCvIRshoP1MSoh19boI+1uX +wHQ/NyDkL5ErKC5JCSpaeF8VG1ek570kKWQLuQAbnlXZw+Sqfu35CIdizHaYGEcx +xA8oXH4L2JaT2x9GKDSpCmB2xXy/NVamUg== +-----END CERTIFICATE----- diff --git a/src/leap/testing/https_server.py b/src/leap/testing/https_server.py new file mode 100644 index 00000000..21191c32 --- /dev/null +++ b/src/leap/testing/https_server.py @@ -0,0 +1,68 @@ +from BaseHTTPServer import HTTPServer +import os +import ssl +import SocketServer +import threading +import unittest + +_where = os.path.split(__file__)[0] + + +def where(filename): + return os.path.join(_where, filename) + + +class HTTPSServer(HTTPServer): + def server_bind(self): + SocketServer.TCPServer.server_bind(self) + self.socket = ssl.wrap_socket( + self.socket, server_side=True, + certfile=where("leaptestscert.pem"), + keyfile=where("leaptestskey.pem"), + ca_certs=where("cacert.pem"), + ssl_version=ssl.PROTOCOL_SSLv23) + + +class TestServerThread(threading.Thread): + def __init__(self, test_object, request_handler): + threading.Thread.__init__(self) + self.request_handler = request_handler + self.test_object = test_object + + def run(self): + self.server = HTTPSServer(('localhost', 0), self.request_handler) + host, port = self.server.socket.getsockname() + self.test_object.HOST, self.test_object.PORT = host, port + self.test_object.server_started.set() + self.test_object = None + try: + self.server.serve_forever(0.05) + finally: + self.server.server_close() + + def stop(self): + self.server.shutdown() + + +class BaseHTTPSServerTestCase(unittest.TestCase): + """ + derived classes need to implement a request_handler + """ + def setUp(self): + self.server_started = threading.Event() + self.thread = TestServerThread(self, self.request_handler) + self.thread.start() + self.server_started.wait() + + def tearDown(self): + self.thread.stop() + + def get_server(self): + host, port = self.HOST, self.PORT + if host == "127.0.0.1": + host = "localhost" + return "%s:%s" % (host, port) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/testing/leaptestscert.pem b/src/leap/testing/leaptestscert.pem new file mode 100644 index 00000000..65596b1a --- /dev/null +++ b/src/leap/testing/leaptestscert.pem @@ -0,0 +1,84 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + eb:f4:05:2d:3d:0f:c6:f3 + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=US, ST=cyberspace, O=LEAP Encryption Access Project, CN=tests-leap.se/emailAddress=info@leap.se + Validity + Not Before: Aug 31 16:30:17 2012 GMT + Not After : Aug 31 16:30:17 2013 GMT + Subject: C=US, ST=cyberspace, L=net, O=LEAP Encryption Access Project, CN=localhost/emailAddress=info@leap.se + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:bc:f1:c4:05:ce:4b:d5:9b:9a:fa:c1:a5:0c:89: + 15:7e:05:69:b6:a4:62:38:3a:d6:14:4a:36:aa:3c: + 31:70:54:2e:bf:7d:05:19:ad:7b:0c:a9:a6:7d:46: + be:83:62:cb:ea:b9:48:6c:7d:78:a0:10:0b:ad:8a: + 74:7a:b8:ff:32:85:64:36:90:dc:38:dd:90:6e:07: + 82:70:ae:5f:4e:1f:f4:46:98:f3:98:b4:fa:08:65: + bf:d6:ec:a9:ba:7e:a8:f0:40:a2:d0:1a:cb:e6:fc: + 95:c5:54:63:92:5b:b8:0a:36:cc:26:d3:2b:ad:16: + ff:49:53:f4:65:7c:64:27:9a:f5:12:75:11:a5:0c: + 5a:ea:1e:e4:31:f3:a6:2b:db:0e:4a:5d:aa:47:3a: + f0:5e:2a:d5:6f:74:b6:f8:bc:9a:73:d0:fa:8a:be: + a8:69:47:9b:07:45:d9:b5:cd:1c:9b:c5:41:9a:65: + cc:99:a0:bd:bf:b5:e8:9f:66:5f:69:c9:6d:c8:68: + 50:68:74:ae:8e:12:7e:9c:24:4f:dc:05:61:b7:8a: + 6d:2a:95:43:d9:3f:fe:d8:c9:a7:ae:63:cd:30:d5: + 95:84:18:2d:12:b5:2d:a6:fe:37:dd:74:b8:f8:a5: + 59:18:8f:ca:f7:ae:63:0d:9d:66:51:7d:9c:40:48: + 9b:a1 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + Netscape Comment: + OpenSSL Generated Certificate + X509v3 Subject Key Identifier: + B2:50:B4:C6:38:8F:BA:C4:3B:69:4C:6B:45:7C:CF:08:48:36:02:E0 + X509v3 Authority Key Identifier: + keyid:9E:92:44:BF:81:67:96:A2:9B:7E:6E:E6:A6:BB:9D:29:DA:E7:38:0A + + Signature Algorithm: sha1WithRSAEncryption + aa:ab:d4:27:e3:cb:42:05:55:fd:24:b3:e5:55:7d:fb:ce:6c: + ff:c7:96:f0:7d:30:a1:53:4a:04:eb:a4:24:5e:96:ee:65:ef: + e5:aa:08:47:9d:aa:95:2a:bb:6a:28:9f:51:62:63:d9:7d:1a: + 81:a0:72:f7:9f:33:6b:3b:f4:dc:85:cd:2a:ee:83:a9:93:3d: + 75:53:91:fa:0b:1b:10:83:11:2c:03:4e:ac:bf:c3:e6:25:74: + 9f:14:13:4a:43:66:c2:d7:1c:6c:94:3e:a6:f3:a5:bd:01:2c: + 9f:20:29:2e:62:82:12:d8:8b:70:1b:88:2b:18:68:5a:45:80: + 46:2a:6a:d5:df:1f:d3:e8:57:39:0a:be:1a:d8:b0:3e:e5:b6: + c3:69:b7:5e:c0:7b:b3:a8:a6:78:ee:0a:3d:a0:74:40:fb:42: + 9f:f4:98:7f:47:cc:15:28:eb:b1:95:77:82:a8:65:9b:46:c3: + 4f:f9:f4:72:be:bd:24:28:5c:0d:b3:89:e4:13:71:c8:a7:54: + 1b:26:15:f3:c1:b2:a9:13:77:54:c2:b9:b0:c7:24:39:00:4c: + 1a:a7:9b:e7:ad:4a:3a:32:c2:81:0d:13:2d:27:ea:98:00:a9: + 0e:9e:38:3b:8f:80:34:17:17:3d:49:7e:f4:a5:19:05:28:08: + 7d:de:d3:1f +-----BEGIN CERTIFICATE----- +MIIECjCCAvKgAwIBAgIJAOv0BS09D8bzMA0GCSqGSIb3DQEBBQUAMIGAMQswCQYD +VQQGEwJVUzETMBEGA1UECAwKY3liZXJzcGFjZTEnMCUGA1UECgweTEVBUCBFbmNy +eXB0aW9uIEFjY2VzcyBQcm9qZWN0MRYwFAYDVQQDDA10ZXN0cy1sZWFwLnNlMRsw +GQYJKoZIhvcNAQkBFgxpbmZvQGxlYXAuc2UwHhcNMTIwODMxMTYzMDE3WhcNMTMw +ODMxMTYzMDE3WjCBijELMAkGA1UEBhMCVVMxEzARBgNVBAgMCmN5YmVyc3BhY2Ux +DDAKBgNVBAcMA25ldDEnMCUGA1UECgweTEVBUCBFbmNyeXB0aW9uIEFjY2VzcyBQ +cm9qZWN0MRIwEAYDVQQDDAlsb2NhbGhvc3QxGzAZBgkqhkiG9w0BCQEWDGluZm9A +bGVhcC5zZTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALzxxAXOS9Wb +mvrBpQyJFX4FabakYjg61hRKNqo8MXBULr99BRmtewyppn1GvoNiy+q5SGx9eKAQ +C62KdHq4/zKFZDaQ3DjdkG4HgnCuX04f9EaY85i0+ghlv9bsqbp+qPBAotAay+b8 +lcVUY5JbuAo2zCbTK60W/0lT9GV8ZCea9RJ1EaUMWuoe5DHzpivbDkpdqkc68F4q +1W90tvi8mnPQ+oq+qGlHmwdF2bXNHJvFQZplzJmgvb+16J9mX2nJbchoUGh0ro4S +fpwkT9wFYbeKbSqVQ9k//tjJp65jzTDVlYQYLRK1Lab+N910uPilWRiPyveuYw2d +ZlF9nEBIm6ECAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0EHxYdT3Bl +blNTTCBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFLJQtMY4j7rEO2lM +a0V8zwhINgLgMB8GA1UdIwQYMBaAFJ6SRL+BZ5aim35u5qa7nSna5zgKMA0GCSqG +SIb3DQEBBQUAA4IBAQCqq9Qn48tCBVX9JLPlVX37zmz/x5bwfTChU0oE66QkXpbu +Ze/lqghHnaqVKrtqKJ9RYmPZfRqBoHL3nzNrO/Tchc0q7oOpkz11U5H6CxsQgxEs +A06sv8PmJXSfFBNKQ2bC1xxslD6m86W9ASyfICkuYoIS2ItwG4grGGhaRYBGKmrV +3x/T6Fc5Cr4a2LA+5bbDabdewHuzqKZ47go9oHRA+0Kf9Jh/R8wVKOuxlXeCqGWb +RsNP+fRyvr0kKFwNs4nkE3HIp1QbJhXzwbKpE3dUwrmwxyQ5AEwap5vnrUo6MsKB +DRMtJ+qYAKkOnjg7j4A0Fxc9SX70pRkFKAh93tMf +-----END CERTIFICATE----- diff --git a/src/leap/testing/leaptestskey.pem b/src/leap/testing/leaptestskey.pem new file mode 100644 index 00000000..fe6291a1 --- /dev/null +++ b/src/leap/testing/leaptestskey.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAvPHEBc5L1Zua+sGlDIkVfgVptqRiODrWFEo2qjwxcFQuv30F +Ga17DKmmfUa+g2LL6rlIbH14oBALrYp0erj/MoVkNpDcON2QbgeCcK5fTh/0Rpjz +mLT6CGW/1uypun6o8ECi0BrL5vyVxVRjklu4CjbMJtMrrRb/SVP0ZXxkJ5r1EnUR +pQxa6h7kMfOmK9sOSl2qRzrwXirVb3S2+Lyac9D6ir6oaUebB0XZtc0cm8VBmmXM +maC9v7Xon2ZfacltyGhQaHSujhJ+nCRP3AVht4ptKpVD2T/+2MmnrmPNMNWVhBgt +ErUtpv433XS4+KVZGI/K965jDZ1mUX2cQEiboQIDAQABAoIBAQCh/+yhSbrtoCgm +PegEsnix/3QfPBxWt+Obq/HozglZlWQrnMbFuF+bgM4V9ZUdU5UhYNF+66mEG53X +orGyE3IDYCmHO3cGbroKDPhDIs7mTjGEYlniIbGLh6oPXgU8uKKis9ik84TGPOUx +NuTUtT07zLYHx+FX3DLwLUKLzTaWWSRgA7nxNwCY8aPqDxCkXEyZHvSlm9KYZnhe +nVevycoHR+chxL6X/ebbBt2FKR7tl4328mlDXvMXr0vahPH94CuXEvfTj+f6ZxZF +OctdikyRfd8O3ebrUw0XjafPYyTsDMH0/rQovEBVlecEHqh6Z9dBFlogRq5DSun9 +jem4bBXRAoGBAPGPi4g21pTQPqTFxpqea8TsPqIfo3csfMDPdzT246MxzALHqCfG +yZi4g2JYJrReSWHulZDORO5skSKNEb5VTA/3xFhKLt8CULZOakKBDLkzRXlnDFXg +Jsu9vtjDWjQcJsdsRx1tc5V6s+hmel70aaUu/maUlEYZnyIXaTe+1SB1AoGBAMg9 +EMEO5YN52pOI5qPH8j7uyVKtZWKRiR6jb5KA5TxWqZalSdPV6YwDqV/e+HjWrZNw +kSEFONY0seKpIHwXchx91aym7rDHUgOoBQfCWufRMYvRXLhfOTBu4X+U52++i8wt +FvKgh6eSmc7VayAaDfHp7yfrIfS03IiN0T35mGj9AoGAPCoXg7a83VW8tId5/trE +VsjMlM6yhSU0cUV7GFsBuYzWlj6qODX/0iTqvFzeTwBI4LZu1CE78/Jgd62RJMnT +5wo8Ag1//RVziuSe/K9tvtbxT9qFrQHmR8qbtRt65Q257uOeFstDBZEJLDIR+oJ/ +qZ+5x0zsXUVWaERSdYr3RF0CgYEApKDgN3oB5Ti4Jnh1984aMver+heptYKmU9RX +lQH4dsVhpQO8UTgcTgtso+/0JZWLHB9+ksFyW1rzrcETfjLglOA4XzzYHeuiWHM5 +v4lhqBpsO+Ij80oHAPUI3RYVud/VnEauCUlGftWfM1hwPPJu6KhHAnDleAWDE5pV +oDinwBkCgYEAnn/OceaqA2fNYp1IRegbFzpewjUlHLq3bXiCIVhO7W/HqsdfUxjE +VVdjEno/pAG7ZCO5j8u+rLkG2ZIVY3qsUENUiXz52Q08qEltgM8nfirK7vIQkfd9 +YISRE3QHYJd+ArY4v+7rNeF1O5eIEyzPAbvG5raeZFcZ6POxy66uWKo= +-----END RSA PRIVATE KEY----- diff --git a/src/leap/testing/test_basetest.py b/src/leap/testing/test_basetest.py new file mode 100644 index 00000000..14d8f8a3 --- /dev/null +++ b/src/leap/testing/test_basetest.py @@ -0,0 +1,91 @@ +"""becase it's oh so meta""" +try: + import unittest2 as unittest +except ImportError: + import unittest + +import os +import StringIO + +from leap.testing.basetest import BaseLeapTest + +# global for tempdir checking +_tempdir = None + + +class _TestCaseRunner(object): + def run_testcase(self, testcase=None): + if not testcase: + return None + loader = unittest.TestLoader() + suite = loader.loadTestsFromTestCase(testcase) + + # Create runner, and run testcase + io = StringIO.StringIO() + runner = unittest.TextTestRunner(stream=io) + results = runner.run(suite) + return results + + +class TestAbstractBaseLeapTest(unittest.TestCase, _TestCaseRunner): + + def test_abstract_base_class(self): + class _BaseTest(BaseLeapTest): + def test_dummy_method(self): + pass + + def test_tautology(self): + assert True + + results = self.run_testcase(_BaseTest) + + # should be 2 errors: NotImplemented + # raised for setUp/tearDown + self.assertEquals(results.testsRun, 2) + self.assertEquals(len(results.failures), 0) + self.assertEquals(len(results.errors), 2) + + +class TestInitBaseLeapTest(BaseLeapTest): + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_path_is_changed(self): + os_path = os.environ['PATH'] + self.assertTrue(os_path.startswith(self.tempdir)) + + def test_old_path_is_saved(self): + self.assertTrue(len(self.old_path) > 1) + + +class TestCleanedBaseLeapTest(unittest.TestCase, _TestCaseRunner): + + def test_tempdir_is_cleaned_after_tests(self): + class _BaseTest(BaseLeapTest): + def setUp(self): + global _tempdir + _tempdir = self.tempdir + + def tearDown(self): + pass + + def test_tempdir_created(self): + self.assertTrue(os.path.isdir(self.tempdir)) + + def test_tempdir_created_on_setupclass(self): + self.assertEqual(_tempdir, self.tempdir) + + results = self.run_testcase(_BaseTest) + self.assertEquals(results.testsRun, 2) + self.assertEquals(len(results.failures), 0) + self.assertEquals(len(results.errors), 0) + + # did we cleaned the tempdir? + self.assertFalse(os.path.isdir(_tempdir)) + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/tests/fakeclient.py b/src/leap/tests/fakeclient.py deleted file mode 100644 index 45de2cd6..00000000 --- a/src/leap/tests/fakeclient.py +++ /dev/null @@ -1,63 +0,0 @@ -fakeoutput = """ -mullvad Sun Jun 17 14:34:57 2012 OpenVPN 2.2.1 i486-linux-gnu [SSL] [LZO2] [EPOLL] [PKCS11] [eurephia] [MH] [PF_INET6] [IPv6 payload 20110424-2 (2.2RC2)] built - on Mar 23 2012 -Sun Jun 17 14:34:57 2012 MANAGEMENT: TCP Socket listening on [AF_INET]127.0.0.1:7505 -Sun Jun 17 14:34:57 2012 NOTE: the current --script-security setting may allow this configuration to call user-defined scripts -Sun Jun 17 14:34:57 2012 WARNING: file 'ssl/1021380964266.key' is group or others accessible -Sun Jun 17 14:34:57 2012 LZO compression initialized -Sun Jun 17 14:34:57 2012 Control Channel MTU parms [ L:1542 D:138 EF:38 EB:0 ET:0 EL:0 ] -Sun Jun 17 14:34:57 2012 Socket Buffers: R=[163840->131072] S=[163840->131072] -Sun Jun 17 14:34:57 2012 Data Channel MTU parms [ L:1542 D:1450 EF:42 EB:135 ET:0 EL:0 AF:3/1 ] -Sun Jun 17 14:34:57 2012 Local Options hash (VER=V4): '41690919' -Sun Jun 17 14:34:57 2012 Expected Remote Options hash (VER=V4): '530fdded' -Sun Jun 17 14:34:57 2012 UDPv4 link local: [undef] -Sun Jun 17 14:34:57 2012 UDPv4 link remote: [AF_INET]46.21.99.25:1197 -Sun Jun 17 14:34:57 2012 TLS: Initial packet from [AF_INET]46.21.99.25:1197, sid=63c29ace 1d3060d0 -Sun Jun 17 14:34:58 2012 VERIFY OK: depth=2, /C=NA/ST=None/L=None/O=Mullvad/CN=Mullvad_CA/emailAddress=info@mullvad.net -Sun Jun 17 14:34:58 2012 VERIFY OK: depth=1, /C=NA/ST=None/L=None/O=Mullvad/CN=master.mullvad.net/emailAddress=info@mullvad.net -Sun Jun 17 14:34:58 2012 Validating certificate key usage -Sun Jun 17 14:34:58 2012 ++ Certificate has key usage 00a0, expects 00a0 -Sun Jun 17 14:34:58 2012 VERIFY KU OK -Sun Jun 17 14:34:58 2012 Validating certificate extended key usage -Sun Jun 17 14:34:58 2012 ++ Certificate has EKU (str) TLS Web Server Authentication, expects TLS Web Server Authentication -Sun Jun 17 14:34:58 2012 VERIFY EKU OK -Sun Jun 17 14:34:58 2012 VERIFY OK: depth=0, /C=NA/ST=None/L=None/O=Mullvad/CN=se2.mullvad.net/emailAddress=info@mullvad.net -Sun Jun 17 14:34:59 2012 Data Channel Encrypt: Cipher 'BF-CBC' initialized with 128 bit key -Sun Jun 17 14:34:59 2012 Data Channel Encrypt: Using 160 bit message hash 'SHA1' for HMAC authentication -Sun Jun 17 14:34:59 2012 Data Channel Decrypt: Cipher 'BF-CBC' initialized with 128 bit key -Sun Jun 17 14:34:59 2012 Data Channel Decrypt: Using 160 bit message hash 'SHA1' for HMAC authentication -Sun Jun 17 14:34:59 2012 Control Channel: TLSv1, cipher TLSv1/SSLv3 DHE-RSA-AES256-SHA, 2048 bit RSA -Sun Jun 17 14:34:59 2012 [se2.mullvad.net] Peer Connection Initiated with [AF_INET]46.21.99.25:1197 -Sun Jun 17 14:35:01 2012 SENT CONTROL [se2.mullvad.net]: 'PUSH_REQUEST' (status=1) -Sun Jun 17 14:35:02 2012 PUSH: Received control message: 'PUSH_REPLY,redirect-gateway def1 bypass-dhcp,dhcp-option DNS 10.11.0.1,route 10.11.0.1,topology net30,ifconfig 10.11.0.202 10.11.0.201' -Sun Jun 17 14:35:02 2012 OPTIONS IMPORT: --ifconfig/up options modified -Sun Jun 17 14:35:02 2012 OPTIONS IMPORT: route options modified -Sun Jun 17 14:35:02 2012 OPTIONS IMPORT: --ip-win32 and/or --dhcp-option options modified -Sun Jun 17 14:35:02 2012 ROUTE default_gateway=192.168.0.1 -Sun Jun 17 14:35:02 2012 TUN/TAP device tun0 opened -Sun Jun 17 14:35:02 2012 TUN/TAP TX queue length set to 100 -Sun Jun 17 14:35:02 2012 do_ifconfig, tt->ipv6=0, tt->did_ifconfig_ipv6_setup=0 -Sun Jun 17 14:35:02 2012 /sbin/ifconfig tun0 10.11.0.202 pointopoint 10.11.0.201 mtu 1500 -Sun Jun 17 14:35:02 2012 /etc/openvpn/update-resolv-conf tun0 1500 1542 10.11.0.202 10.11.0.201 init -dhcp-option DNS 10.11.0.1 -Sun Jun 17 14:35:05 2012 /sbin/route add -net 46.21.99.25 netmask 255.255.255.255 gw 192.168.0.1 -Sun Jun 17 14:35:05 2012 /sbin/route add -net 0.0.0.0 netmask 128.0.0.0 gw 10.11.0.201 -Sun Jun 17 14:35:05 2012 /sbin/route add -net 128.0.0.0 netmask 128.0.0.0 gw 10.11.0.201 -Sun Jun 17 14:35:05 2012 /sbin/route add -net 10.11.0.1 netmask 255.255.255.255 gw 10.11.0.201 -Sun Jun 17 14:35:05 2012 Initialization Sequence Completed -Sun Jun 17 14:34:57 2012 MANAGEMENT: TCP Socket listening on [AF_INET]127.0.0.1:7505 -""" - -import time -import sys - - -def write_output(): - for line in fakeoutput.split('\n'): - sys.stdout.write(line + '\n') - sys.stdout.flush() - #print(line) - time.sleep(0.1) - -if __name__ == "__main__": - write_output() diff --git a/src/leap/tests/mocks/__init__.py b/src/leap/tests/mocks/__init__.py deleted file mode 100644 index 06f96870..00000000 --- a/src/leap/tests/mocks/__init__.py +++ /dev/null @@ -1 +0,0 @@ -import manager diff --git a/src/leap/tests/mocks/manager.py b/src/leap/tests/mocks/manager.py deleted file mode 100644 index 564631cd..00000000 --- a/src/leap/tests/mocks/manager.py +++ /dev/null @@ -1,20 +0,0 @@ -from mock import Mock - -from eip_client.vpnmanager import OpenVPNManager - -vpn_commands = { - 'status': [ - 'OpenVPN STATISTICS', 'Updated,Mon Jun 25 11:51:21 2012', - 'TUN/TAP read bytes,306170', 'TUN/TAP write bytes,872102', - 'TCP/UDP read bytes,986177', 'TCP/UDP write bytes,439329', - 'Auth read bytes,872102'], - 'state': ['1340616463,CONNECTED,SUCCESS,172.28.0.2,198.252.153.38'], - # XXX add more tests - } - - -def get_openvpn_manager_mocks(): - manager = OpenVPNManager() - manager.status = Mock(return_value='\n'.join(vpn_commands['status'])) - manager.state = Mock(return_value=vpn_commands['state'][0]) - return manager diff --git a/src/leap/util/__init__.py b/src/leap/util/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/util/__init__.py diff --git a/src/leap/utils/coroutines.py b/src/leap/util/coroutines.py index 5e25eb63..0657fc04 100644 --- a/src/leap/utils/coroutines.py +++ b/src/leap/util/coroutines.py @@ -4,10 +4,13 @@ from __future__ import division, print_function +import logging from subprocess import PIPE, Popen import sys from threading import Thread +logger = logging.getLogger(__name__) + ON_POSIX = 'posix' in sys.builtin_module_names @@ -38,8 +41,7 @@ for each event if callable(callback): callback(m) else: - #XXX log instead - print('not a callable passed') + logger.debug('not a callable passed') except GeneratorExit: return @@ -72,10 +74,10 @@ def watch_output(out, observers): :type out: fd :param observers: tuple of coroutines to send data\ for each event - :type ovservers: tuple + :type observers: tuple """ - observer_dict = {observer: process_events(observer) - for observer in observers} + observer_dict = dict(((observer, process_events(observer)) + for observer in observers)) for line in iter(out.readline, b''): for obs in observer_dict: observer_dict[obs].send(line) diff --git a/src/leap/util/dicts.py b/src/leap/util/dicts.py new file mode 100644 index 00000000..001ca96b --- /dev/null +++ b/src/leap/util/dicts.py @@ -0,0 +1,268 @@ +# Backport of OrderedDict() class that runs +# on Python 2.4, 2.5, 2.6, 2.7 and pypy. +# Passes Python2.7's test suite and incorporates all the latest updates. + +try: + from thread import get_ident as _get_ident +except ImportError: + from dummy_thread import get_ident as _get_ident + +try: + from _abcoll import KeysView, ValuesView, ItemsView +except ImportError: + pass + + +class OrderedDict(dict): + 'Dictionary that remembers insertion order' + # An inherited dict maps keys to values. + # The inherited dict provides __getitem__, __len__, __contains__, and get. + # The remaining methods are order-aware. + # Big-O running times for all methods are the same as for regular + # dictionaries. + + # The internal self.__map dictionary maps keys to links in a doubly + # linked list. + # The circular doubly linked list starts and ends with a sentinel element. + # The sentinel element never gets deleted (this simplifies the algorithm). + # Each link is stored as a list of length three: [PREV, NEXT, KEY]. + + def __init__(self, *args, **kwds): + '''Initialize an ordered dictionary. Signature is the same as for + regular dictionaries, but keyword arguments are not recommended + because their insertion order is arbitrary. + + ''' + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + try: + self.__root + except AttributeError: + self.__root = root = [] # sentinel node + root[:] = [root, root, None] + self.__map = {} + self.__update(*args, **kwds) + + def __setitem__(self, key, value, dict_setitem=dict.__setitem__): + 'od.__setitem__(i, y) <==> od[i]=y' + # Setting a new item creates a new link which goes at the end + # of the linked list, and the inherited dictionary is updated + # with the new key/value pair. + if key not in self: + root = self.__root + last = root[0] + last[1] = root[0] = self.__map[key] = [last, root, key] + dict_setitem(self, key, value) + + def __delitem__(self, key, dict_delitem=dict.__delitem__): + 'od.__delitem__(y) <==> del od[y]' + # Deleting an existing item uses self.__map to find the link which is + # then removed by updating the links in the predecessor and successor + # nodes. + dict_delitem(self, key) + link_prev, link_next, key = self.__map.pop(key) + link_prev[1] = link_next + link_next[0] = link_prev + + def __iter__(self): + 'od.__iter__() <==> iter(od)' + root = self.__root + curr = root[1] + while curr is not root: + yield curr[2] + curr = curr[1] + + def __reversed__(self): + 'od.__reversed__() <==> reversed(od)' + root = self.__root + curr = root[0] + while curr is not root: + yield curr[2] + curr = curr[0] + + def clear(self): + 'od.clear() -> None. Remove all items from od.' + try: + for node in self.__map.itervalues(): + del node[:] + root = self.__root + root[:] = [root, root, None] + self.__map.clear() + except AttributeError: + pass + dict.clear(self) + + def popitem(self, last=True): + '''od.popitem() -> (k, v), return and remove a (key, value) pair. + Pairs are returned in LIFO order if last is true or FIFO order if + false. + ''' + if not self: + raise KeyError('dictionary is empty') + root = self.__root + if last: + link = root[0] + link_prev = link[0] + link_prev[1] = root + root[0] = link_prev + else: + link = root[1] + link_next = link[1] + root[1] = link_next + link_next[0] = root + key = link[2] + del self.__map[key] + value = dict.pop(self, key) + return key, value + + # -- the following methods do not depend on the internal structure -- + + def keys(self): + 'od.keys() -> list of keys in od' + return list(self) + + def values(self): + 'od.values() -> list of values in od' + return [self[key] for key in self] + + def items(self): + 'od.items() -> list of (key, value) pairs in od' + return [(key, self[key]) for key in self] + + def iterkeys(self): + 'od.iterkeys() -> an iterator over the keys in od' + return iter(self) + + def itervalues(self): + 'od.itervalues -> an iterator over the values in od' + for k in self: + yield self[k] + + def iteritems(self): + 'od.iteritems -> an iterator over the (key, value) items in od' + for k in self: + yield (k, self[k]) + + def update(*args, **kwds): + '''od.update(E, **F) -> None. Update od from dict/iterable E and F. + + If E is a dict instance, does: for k in E: od[k] = E[k] + If E has a .keys() method, does: for k in E.keys(): + od[k] = E[k] + Or if E is an iterable of items, does: for k, v in E: od[k] = v + In either case, this is followed by: for k, v in F.items(): + od[k] = v + ''' + + if len(args) > 2: + raise TypeError('update() takes at most 2 positional ' + 'arguments (%d given)' % (len(args),)) + elif not args: + raise TypeError('update() takes at least 1 argument (0 given)') + self = args[0] + # Make progressively weaker assumptions about "other" + other = () + if len(args) == 2: + other = args[1] + if isinstance(other, dict): + for key in other: + self[key] = other[key] + elif hasattr(other, 'keys'): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + + __update = update # let subclasses override update + # without breaking __init__ + + __marker = object() + + def pop(self, key, default=__marker): + '''od.pop(k[,d]) -> v + remove specified key and return the corresponding value. + If key is not found, d is returned if given, + otherwise KeyError is raised. + + ''' + if key in self: + result = self[key] + del self[key] + return result + if default is self.__marker: + raise KeyError(key) + return default + + def setdefault(self, key, default=None): + 'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od' + if key in self: + return self[key] + self[key] = default + return default + + def __repr__(self, _repr_running={}): + 'od.__repr__() <==> repr(od)' + call_key = id(self), _get_ident() + if call_key in _repr_running: + return '...' + _repr_running[call_key] = 1 + try: + if not self: + return '%s()' % (self.__class__.__name__,) + return '%s(%r)' % (self.__class__.__name__, self.items()) + finally: + del _repr_running[call_key] + + def __reduce__(self): + 'Return state information for pickling' + items = [[k, self[k]] for k in self] + inst_dict = vars(self).copy() + for k in vars(OrderedDict()): + inst_dict.pop(k, None) + if inst_dict: + return (self.__class__, (items,), inst_dict) + return self.__class__, (items,) + + def copy(self): + 'od.copy() -> a shallow copy of od' + return self.__class__(self) + + @classmethod + def fromkeys(cls, iterable, value=None): + '''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S + and values equal to v (which defaults to None). + + ''' + d = cls() + for key in iterable: + d[key] = value + return d + + def __eq__(self, other): + '''od.__eq__(y) <==> od==y. + Comparison to another OD is order-sensitive + while comparison to a regular mapping is order-insensitive. + ''' + if isinstance(other, OrderedDict): + return len(self) == len(other) and self.items() == other.items() + return dict.__eq__(self, other) + + def __ne__(self, other): + return not self == other + + # -- the following methods are only used in Python 2.7 -- + + def viewkeys(self): + "od.viewkeys() -> a set-like object providing a view on od's keys" + return KeysView(self) + + def viewvalues(self): + "od.viewvalues() -> an object providing a view on od's values" + return ValuesView(self) + + def viewitems(self): + "od.viewitems() -> a set-like object providing a view on od's items" + return ItemsView(self) diff --git a/src/leap/util/fileutil.py b/src/leap/util/fileutil.py new file mode 100644 index 00000000..aef4cfe0 --- /dev/null +++ b/src/leap/util/fileutil.py @@ -0,0 +1,115 @@ +import errno +from itertools import chain +import logging +import os +import platform +import stat + + +logger = logging.getLogger() + + +def is_user_executable(fpath): + st = os.stat(fpath) + return bool(st.st_mode & stat.S_IXUSR) + + +def extend_path(): + ourplatform = platform.system() + if ourplatform == "Linux": + return "/usr/local/sbin:/usr/sbin" + # XXX add mac / win extended search paths? + + +def which(program, path=None): + """ + an implementation of which + that extends the path with + other locations, like sbin + (f.i., openvpn binary is likely to be there) + @param program: a string representing the binary we're looking for. + """ + def is_exe(fpath): + """ + check that path exists, + it's a file, + and is executable by the owner + """ + # we would check for access, + # but it's likely that we're + # using uid 0 + polkitd + + return os.path.isfile(fpath)\ + and is_user_executable(fpath) + + def ext_candidates(fpath): + yield fpath + for ext in os.environ.get("PATHEXT", "").split(os.pathsep): + yield fpath + ext + + def iter_path(pathset): + """ + returns iterator with + full path for a given path list + and the current target bin. + """ + for path in pathset.split(os.pathsep): + exe_file = os.path.join(path, program) + #print 'file=%s' % exe_file + for candidate in ext_candidates(exe_file): + if is_exe(candidate): + yield candidate + + fpath, fname = os.path.split(program) + if fpath: + if is_exe(program): + return program + else: + # extended iterator + # with extra path + if path is None: + path = os.environ['PATH'] + extended_path = chain( + iter_path(path), + iter_path(extend_path())) + for candidate in extended_path: + if candidate is not None: + return candidate + + # sorry bro. + return None + + +def mkdir_p(path): + """ + implements mkdir -p functionality + """ + try: + os.makedirs(path) + except OSError as exc: + if exc.errno == errno.EEXIST: + pass + else: + raise + + +def check_and_fix_urw_only(_file): + """ + test for 600 mode and try + to set it if anything different found + """ + mode = stat.S_IMODE( + os.stat(_file).st_mode) + + if mode != int('600', 8): + try: + logger.warning( + 'bad permission on %s ' + 'attempting to set 600', + _file) + os.chmod(_file, stat.S_IRUSR | stat.S_IWUSR) + except OSError: + logger.error( + 'error while trying to chmod 600 %s', + _file) + raise diff --git a/src/leap/util/leap_argparse.py b/src/leap/util/leap_argparse.py new file mode 100644 index 00000000..2f996a31 --- /dev/null +++ b/src/leap/util/leap_argparse.py @@ -0,0 +1,41 @@ +import argparse + + +def build_parser(): + """ + all the options for the leap arg parser + Some of these could be switched on only if debug flag is present! + """ + epilog = "Copyright 2012 The Leap Project" + parser = argparse.ArgumentParser(description=""" +Launches main LEAP Client""", epilog=epilog) + parser.add_argument('-d', '--debug', action="store_true", + help='launches in debug mode') + parser.add_argument('-c', '--config', metavar="CONFIG FILE", nargs='?', + action="store", dest="config_file", + type=argparse.FileType('r'), + help='optional config file') + parser.add_argument('--logfile', metavar="LOG FILE", nargs='?', + action="store", dest="log_file", + #type=argparse.FileType('w'), + help='optional log file') + parser.add_argument('--openvpn-verbosity', nargs='?', + type=int, + action="store", dest="openvpn_verb", + help='verbosity level for openvpn logs [1-6]') + parser.add_argument('-l', '--no-provider-checks', + action="store_true", default=False, + help="skips download of provider config files. gets " + "config from local files only. Will fail if cannot " + "find any") + parser.add_argument('-k', '--no-ca-verify', + action="store_true", default=False, + help="(insecure). Skips verification of the server " + "certificate used in TLS handshake.") + return parser + + +def init_leapc_args(): + parser = build_parser() + opts = parser.parse_args() + return parser, opts diff --git a/src/leap/util/tests/__init__.py b/src/leap/util/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/util/tests/__init__.py diff --git a/src/leap/util/tests/test_fileutil.py b/src/leap/util/tests/test_fileutil.py new file mode 100644 index 00000000..f5131b3d --- /dev/null +++ b/src/leap/util/tests/test_fileutil.py @@ -0,0 +1,100 @@ +import os +import platform +import shutil +import stat +import tempfile +import unittest + +from leap.util import fileutil + + +class FileUtilTest(unittest.TestCase): + """ + test our file utils + """ + + def setUp(self): + self.system = platform.system() + self.create_temp_dir() + + def tearDown(self): + self.remove_temp_dir() + + # + # helpers + # + + def create_temp_dir(self): + self.tmpdir = tempfile.mkdtemp() + + def remove_temp_dir(self): + shutil.rmtree(self.tmpdir) + + def get_file_path(self, filename): + return os.path.join( + self.tmpdir, + filename) + + def touch_exec_file(self): + fp = self.get_file_path('testexec') + open(fp, 'w').close() + os.chmod( + fp, + stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) + return fp + + def get_mode(self, fp): + return stat.S_IMODE(os.stat(fp).st_mode) + + # + # tests + # + + def test_is_user_executable(self): + """ + touch_exec_file creates in mode 700? + """ + # XXX could check access X_OK + + fp = self.touch_exec_file() + mode = self.get_mode(fp) + self.assertEqual(mode, int('700', 8)) + + def test_which(self): + """ + which implementation ok? + not a very reliable test, + but I cannot think of anything smarter now + I guess it's highly improbable that copy + """ + # XXX yep, we can change the syspath + # for the test... ! + + if self.system == "Linux": + self.assertEqual( + fileutil.which('cp'), + '/bin/cp') + + def test_mkdir_p(self): + """ + our own mkdir -p implementation ok? + """ + testdir = self.get_file_path( + os.path.join('test', 'foo', 'bar')) + self.assertEqual(os.path.isdir(testdir), False) + fileutil.mkdir_p(testdir) + self.assertEqual(os.path.isdir(testdir), True) + + def test_check_and_fix_urw_only(self): + """ + ensure check_and_fix_urx_only ok? + """ + fp = self.touch_exec_file() + mode = self.get_mode(fp) + self.assertEqual(mode, int('700', 8)) + fileutil.check_and_fix_urw_only(fp) + mode = self.get_mode(fp) + self.assertEqual(mode, int('600', 8)) + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/util/tests/test_leap_argparse.py b/src/leap/util/tests/test_leap_argparse.py new file mode 100644 index 00000000..082919b7 --- /dev/null +++ b/src/leap/util/tests/test_leap_argparse.py @@ -0,0 +1,35 @@ +from argparse import Namespace +import unittest + +from leap.util import leap_argparse + + +class LeapArgParseTest(unittest.TestCase): + """ + Test argparse options for eip client + """ + + def setUp(self): + """ + get the parser + """ + self.parser = leap_argparse.build_parser() + + def test_debug_mode(self): + """ + test debug mode option + """ + opts = self.parser.parse_args( + ['--debug']) + self.assertEqual( + opts, + Namespace( + config_file=None, + debug=True, + log_file=None, + no_provider_checks=False, + no_ca_verify=False, + openvpn_verb=None)) + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/util/web.py b/src/leap/util/web.py new file mode 100644 index 00000000..b2aef058 --- /dev/null +++ b/src/leap/util/web.py @@ -0,0 +1,39 @@ +""" +web related utilities +""" + + +class UsageError(Exception): + """ """ + + +def get_https_domain_and_port(full_domain): + """ + returns a tuple with domain and port + from a full_domain string that can + contain a colon + """ + if full_domain is None: + return None, None + + https_sch = "https://" + http_sch = "http://" + + if full_domain.startswith(https_sch): + full_domain = full_domain.lstrip(https_sch) + elif full_domain.startswith(http_sch): + raise UsageError( + "cannot be called with a domain " + "that begins with 'http://'") + + domain_split = full_domain.split(':') + _len = len(domain_split) + if _len == 1: + domain, port = full_domain, 443 + elif _len == 2: + domain, port = domain_split + else: + raise UsageError( + "must be called with one only parameter" + "in the form domain[:port]") + return domain, port diff --git a/src/leap/utils/leap_argparse.py b/src/leap/utils/leap_argparse.py deleted file mode 100644 index 9c355134..00000000 --- a/src/leap/utils/leap_argparse.py +++ /dev/null @@ -1,20 +0,0 @@ -import argparse - - -def build_parser(): - epilog = "Copyright 2012 The Leap Project" - parser = argparse.ArgumentParser(description=""" -Launches main LEAP Client""", epilog=epilog) - parser.add_argument('--debug', action="store_true", - help='launches in debug mode') - parser.add_argument('--config', metavar="CONFIG FILE", nargs='?', - action="store", dest="config_file", - type=argparse.FileType('r'), - help='optional config file') - return parser - - -def init_leapc_args(): - parser = build_parser() - opts = parser.parse_args() - return parser, opts |