summaryrefslogtreecommitdiff
path: root/src/leap
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap')
-rw-r--r--src/leap/__init__.py35
-rw-r--r--src/leap/_version.py197
-rw-r--r--src/leap/app.py77
-rw-r--r--src/leap/base/__init__.py (renamed from src/leap/utils/__init__.py)0
-rw-r--r--src/leap/base/auth.py376
-rw-r--r--src/leap/base/authentication.py11
-rw-r--r--src/leap/base/checks.py127
-rw-r--r--src/leap/base/config.py279
-rw-r--r--src/leap/base/connection.py115
-rw-r--r--src/leap/base/constants.py32
-rw-r--r--src/leap/base/exceptions.py77
-rw-r--r--src/leap/base/network.py84
-rw-r--r--src/leap/base/pluggableconfig.py421
-rw-r--r--src/leap/base/providers.py29
-rw-r--r--src/leap/base/specs.py59
-rw-r--r--src/leap/base/tests/__init__.py0
-rw-r--r--src/leap/base/tests/test_checks.py124
-rw-r--r--src/leap/base/tests/test_config.py247
-rw-r--r--src/leap/base/tests/test_providers.py143
-rw-r--r--src/leap/base/tests/test_validation.py92
-rw-r--r--src/leap/baseapp/config.py40
-rw-r--r--src/leap/baseapp/constants.py6
-rw-r--r--src/leap/baseapp/dialogs.py58
-rw-r--r--src/leap/baseapp/eip.py217
-rw-r--r--src/leap/baseapp/leap_app.py153
-rw-r--r--src/leap/baseapp/log.py65
-rw-r--r--src/leap/baseapp/mainwindow.py520
-rw-r--r--src/leap/baseapp/network.py40
-rw-r--r--src/leap/baseapp/permcheck.py17
-rw-r--r--src/leap/baseapp/systray.py245
-rw-r--r--src/leap/certs/__init__.py7
-rw-r--r--src/leap/crypto/__init__.py0
-rw-r--r--src/leap/crypto/certs.py71
-rw-r--r--src/leap/crypto/leapkeyring.py69
-rw-r--r--src/leap/eip/checks.py518
-rw-r--r--src/leap/eip/conductor.py272
-rw-r--r--src/leap/eip/config.py303
-rw-r--r--src/leap/eip/constants.py3
-rw-r--r--src/leap/eip/eipconnection.py350
-rw-r--r--src/leap/eip/exceptions.py156
-rw-r--r--src/leap/eip/openvpnconnection.py460
-rw-r--r--src/leap/eip/specs.py124
-rw-r--r--src/leap/eip/tests/__init__.py0
-rw-r--r--src/leap/eip/tests/data.py48
-rw-r--r--src/leap/eip/tests/test_checks.py367
-rw-r--r--src/leap/eip/tests/test_config.py153
-rw-r--r--src/leap/eip/tests/test_eipconnection.py191
-rw-r--r--src/leap/eip/tests/test_openvpnconnection.py147
-rw-r--r--src/leap/eip/udstelnet.py38
-rw-r--r--src/leap/eip/vpnmanager.py262
-rw-r--r--src/leap/eip/vpnwatcher.py169
-rw-r--r--src/leap/gui/__init__.py10
-rw-r--r--src/leap/gui/constants.py13
-rw-r--r--src/leap/gui/firstrun/__init__.py29
-rw-r--r--src/leap/gui/firstrun/connect.py231
-rw-r--r--src/leap/gui/firstrun/constants.py0
-rw-r--r--src/leap/gui/firstrun/intro.py68
-rw-r--r--src/leap/gui/firstrun/last.py92
-rw-r--r--src/leap/gui/firstrun/login.py330
-rw-r--r--src/leap/gui/firstrun/mixins.py18
-rw-r--r--src/leap/gui/firstrun/providerinfo.py98
-rw-r--r--src/leap/gui/firstrun/providerselect.py472
-rw-r--r--src/leap/gui/firstrun/providersetup.py174
-rw-r--r--src/leap/gui/firstrun/register.py368
-rw-r--r--src/leap/gui/firstrun/regvalidation.py204
-rwxr-xr-xsrc/leap/gui/firstrun/tests/integration/fake_provider.py295
-rwxr-xr-xsrc/leap/gui/firstrun/wizard.py286
-rw-r--r--src/leap/gui/mainwindow_rc.py1035
-rw-r--r--src/leap/gui/progress.py448
-rw-r--r--src/leap/gui/styles.py16
-rw-r--r--src/leap/gui/test_mainwindow_rc.py29
-rw-r--r--src/leap/gui/tests/integration/fake_user_signup.py84
-rw-r--r--src/leap/gui/threads.py21
-rw-r--r--src/leap/gui/utils.py34
-rw-r--r--src/leap/soledad/README13
-rw-r--r--src/leap/soledad/__init__.py164
-rw-r--r--src/leap/soledad/swiftclient/__init__.py5
-rw-r--r--src/leap/soledad/swiftclient/client.py1056
-rw-r--r--src/leap/soledad/swiftclient/openstack/__init__.py0
-rw-r--r--src/leap/soledad/swiftclient/openstack/common/__init__.py0
-rw-r--r--src/leap/soledad/swiftclient/openstack/common/setup.py342
-rw-r--r--src/leap/soledad/swiftclient/versioninfo1
-rw-r--r--src/leap/soledad/u1db/__init__.py697
-rw-r--r--src/leap/soledad/u1db/backends/__init__.py211
-rw-r--r--src/leap/soledad/u1db/backends/dbschema.sql42
-rw-r--r--src/leap/soledad/u1db/backends/inmemory.py469
-rw-r--r--src/leap/soledad/u1db/backends/sqlite_backend.py926
-rw-r--r--src/leap/soledad/u1db/commandline/__init__.py15
-rw-r--r--src/leap/soledad/u1db/commandline/client.py497
-rw-r--r--src/leap/soledad/u1db/commandline/command.py80
-rw-r--r--src/leap/soledad/u1db/commandline/serve.py34
-rw-r--r--src/leap/soledad/u1db/errors.py189
-rw-r--r--src/leap/soledad/u1db/query_parser.py370
-rw-r--r--src/leap/soledad/u1db/remote/__init__.py15
-rw-r--r--src/leap/soledad/u1db/remote/basic_auth_middleware.py68
-rw-r--r--src/leap/soledad/u1db/remote/http_app.py629
-rw-r--r--src/leap/soledad/u1db/remote/http_client.py218
-rw-r--r--src/leap/soledad/u1db/remote/http_database.py143
-rw-r--r--src/leap/soledad/u1db/remote/http_errors.py46
-rw-r--r--src/leap/soledad/u1db/remote/http_target.py135
-rw-r--r--src/leap/soledad/u1db/remote/oauth_middleware.py89
-rw-r--r--src/leap/soledad/u1db/remote/server_state.py67
-rw-r--r--src/leap/soledad/u1db/remote/ssl_match_hostname.py64
-rw-r--r--src/leap/soledad/u1db/remote/utils.py23
-rw-r--r--src/leap/soledad/u1db/sync.py304
-rw-r--r--src/leap/soledad/u1db/tests/__init__.py463
-rw-r--r--src/leap/soledad/u1db/tests/c_backend_wrapper.pyx1541
-rw-r--r--src/leap/soledad/u1db/tests/commandline/__init__.py47
-rw-r--r--src/leap/soledad/u1db/tests/commandline/test_client.py916
-rw-r--r--src/leap/soledad/u1db/tests/commandline/test_command.py105
-rw-r--r--src/leap/soledad/u1db/tests/commandline/test_serve.py101
-rw-r--r--src/leap/soledad/u1db/tests/test_auth_middleware.py309
-rw-r--r--src/leap/soledad/u1db/tests/test_backends.py1895
-rw-r--r--src/leap/soledad/u1db/tests/test_c_backend.py634
-rw-r--r--src/leap/soledad/u1db/tests/test_common_backend.py33
-rw-r--r--src/leap/soledad/u1db/tests/test_document.py148
-rw-r--r--src/leap/soledad/u1db/tests/test_errors.py61
-rw-r--r--src/leap/soledad/u1db/tests/test_http_app.py1133
-rw-r--r--src/leap/soledad/u1db/tests/test_http_client.py361
-rw-r--r--src/leap/soledad/u1db/tests/test_http_database.py256
-rw-r--r--src/leap/soledad/u1db/tests/test_https.py117
-rw-r--r--src/leap/soledad/u1db/tests/test_inmemory.py128
-rw-r--r--src/leap/soledad/u1db/tests/test_open.py69
-rw-r--r--src/leap/soledad/u1db/tests/test_query_parser.py443
-rw-r--r--src/leap/soledad/u1db/tests/test_remote_sync_target.py314
-rw-r--r--src/leap/soledad/u1db/tests/test_remote_utils.py36
-rw-r--r--src/leap/soledad/u1db/tests/test_server_state.py93
-rw-r--r--src/leap/soledad/u1db/tests/test_sqlite_backend.py493
-rw-r--r--src/leap/soledad/u1db/tests/test_sync.py1285
-rw-r--r--src/leap/soledad/u1db/tests/test_test_infrastructure.py41
-rw-r--r--src/leap/soledad/u1db/tests/test_vectorclock.py121
-rw-r--r--src/leap/soledad/u1db/tests/testing-certs/Makefile35
-rw-r--r--src/leap/soledad/u1db/tests/testing-certs/cacert.pem58
-rw-r--r--src/leap/soledad/u1db/tests/testing-certs/testing.cert61
-rw-r--r--src/leap/soledad/u1db/tests/testing-certs/testing.key16
-rw-r--r--src/leap/soledad/u1db/vectorclock.py89
-rw-r--r--src/leap/testing/__init__.py0
-rw-r--r--src/leap/testing/basetest.py85
-rw-r--r--src/leap/testing/cacert.pem23
-rw-r--r--src/leap/testing/https_server.py68
-rw-r--r--src/leap/testing/leaptestscert.pem84
-rw-r--r--src/leap/testing/leaptestskey.pem27
-rw-r--r--src/leap/testing/test_basetest.py91
-rw-r--r--src/leap/tests/fakeclient.py63
-rw-r--r--src/leap/tests/mocks/__init__.py1
-rw-r--r--src/leap/tests/mocks/manager.py20
-rw-r--r--src/leap/util/__init__.py0
-rw-r--r--src/leap/util/coroutines.py (renamed from src/leap/utils/coroutines.py)12
-rw-r--r--src/leap/util/dicts.py268
-rw-r--r--src/leap/util/fileutil.py115
-rw-r--r--src/leap/util/leap_argparse.py41
-rw-r--r--src/leap/util/tests/__init__.py0
-rw-r--r--src/leap/util/tests/test_fileutil.py100
-rw-r--r--src/leap/util/tests/test_leap_argparse.py35
-rw-r--r--src/leap/util/web.py39
-rw-r--r--src/leap/utils/leap_argparse.py20
156 files changed, 30117 insertions, 1238 deletions
diff --git a/src/leap/__init__.py b/src/leap/__init__.py
index e69de29b..5e003931 100644
--- a/src/leap/__init__.py
+++ b/src/leap/__init__.py
@@ -0,0 +1,35 @@
+"""
+LEAP Encryption Access Project
+website: U{https://leap.se/}
+"""
+
+from leap import eip
+from leap import baseapp
+from leap import util
+
+__all__ = [eip, baseapp, util]
+
+__version__ = "unknown"
+try:
+ from ._version import get_versions
+ __version__ = get_versions()['version']
+ del get_versions
+except ImportError:
+ #running on a tree that has not run
+ #the setup.py setver
+ pass
+
+__appname__ = "unknown"
+try:
+ from leap._appname import __appname__
+except ImportError:
+ #running on a tree that has not run
+ #the setup.py setver
+ pass
+
+__full_version__ = __appname__ + '/' + str(__version__)
+
+try:
+ from leap._branding import BRANDING as __branding
+except ImportError:
+ __branding = {}
diff --git a/src/leap/_version.py b/src/leap/_version.py
new file mode 100644
index 00000000..c33430ea
--- /dev/null
+++ b/src/leap/_version.py
@@ -0,0 +1,197 @@
+
+IN_LONG_VERSION_PY = True
+# This file helps to compute a version number in source trees obtained from
+# git-archive tarball (such as those provided by githubs download-from-tag
+# feature). Distribution tarballs (build by setup.py sdist) and build
+# directories (produced by setup.py build) will contain a much shorter file
+# that just contains the computed version number.
+
+# This file is released into the public domain. Generated by
+# versioneer-0.7+ (https://github.com/warner/python-versioneer)
+
+# these strings will be replaced by git during git-archive
+git_refnames = "$Format:%d$"
+git_full = "$Format:%H$"
+
+
+import subprocess
+import sys
+
+def run_command(args, cwd=None, verbose=False):
+ try:
+ # remember shell=False, so use git.cmd on windows, not just git
+ p = subprocess.Popen(args, stdout=subprocess.PIPE, cwd=cwd)
+ except EnvironmentError:
+ e = sys.exc_info()[1]
+ if verbose:
+ print("unable to run %s" % args[0])
+ print(e)
+ return None
+ stdout = p.communicate()[0].strip()
+ if sys.version >= '3':
+ stdout = stdout.decode()
+ if p.returncode != 0:
+ if verbose:
+ print("unable to run %s (error)" % args[0])
+ return None
+ return stdout
+
+
+import sys
+import re
+import os.path
+
+def get_expanded_variables(versionfile_source):
+ # the code embedded in _version.py can just fetch the value of these
+ # variables. When used from setup.py, we don't want to import
+ # _version.py, so we do it with a regexp instead. This function is not
+ # used from _version.py.
+ variables = {}
+ try:
+ for line in open(versionfile_source,"r").readlines():
+ if line.strip().startswith("git_refnames ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ variables["refnames"] = mo.group(1)
+ if line.strip().startswith("git_full ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ variables["full"] = mo.group(1)
+ except EnvironmentError:
+ pass
+ return variables
+
+def versions_from_expanded_variables(variables, tag_prefix, verbose=False):
+ refnames = variables["refnames"].strip()
+ if refnames.startswith("$Format"):
+ if verbose:
+ print("variables are unexpanded, not using")
+ return {} # unexpanded, so not in an unpacked git-archive tarball
+ refs = set([r.strip() for r in refnames.strip("()").split(",")])
+ for ref in list(refs):
+ if not re.search(r'\d', ref):
+ if verbose:
+ print("discarding '%s', no digits" % ref)
+ refs.discard(ref)
+ # Assume all version tags have a digit. git's %d expansion
+ # behaves like git log --decorate=short and strips out the
+ # refs/heads/ and refs/tags/ prefixes that would let us
+ # distinguish between branches and tags. By ignoring refnames
+ # without digits, we filter out many common branch names like
+ # "release" and "stabilization", as well as "HEAD" and "master".
+ if verbose:
+ print("remaining refs: %s" % ",".join(sorted(refs)))
+ for ref in sorted(refs):
+ # sorting will prefer e.g. "2.0" over "2.0rc1"
+ if ref.startswith(tag_prefix):
+ r = ref[len(tag_prefix):]
+ if verbose:
+ print("picking %s" % r)
+ return { "version": r,
+ "full": variables["full"].strip() }
+ # no suitable tags, so we use the full revision id
+ if verbose:
+ print("no suitable tags, using full revision id")
+ return { "version": variables["full"].strip(),
+ "full": variables["full"].strip() }
+
+def versions_from_vcs(tag_prefix, versionfile_source, verbose=False):
+ # this runs 'git' from the root of the source tree. That either means
+ # someone ran a setup.py command (and this code is in versioneer.py, so
+ # IN_LONG_VERSION_PY=False, thus the containing directory is the root of
+ # the source tree), or someone ran a project-specific entry point (and
+ # this code is in _version.py, so IN_LONG_VERSION_PY=True, thus the
+ # containing directory is somewhere deeper in the source tree). This only
+ # gets called if the git-archive 'subst' variables were *not* expanded,
+ # and _version.py hasn't already been rewritten with a short version
+ # string, meaning we're inside a checked out source tree.
+
+ try:
+ here = os.path.abspath(__file__)
+ except NameError:
+ # some py2exe/bbfreeze/non-CPython implementations don't do __file__
+ return {} # not always correct
+
+ # versionfile_source is the relative path from the top of the source tree
+ # (where the .git directory might live) to this file. Invert this to find
+ # the root from __file__.
+ root = here
+ if IN_LONG_VERSION_PY:
+ for i in range(len(versionfile_source.split("/"))):
+ root = os.path.dirname(root)
+ else:
+ root = os.path.dirname(here)
+ if not os.path.exists(os.path.join(root, ".git")):
+ if verbose:
+ print("no .git in %s" % root)
+ return {}
+
+ GIT = "git"
+ if sys.platform == "win32":
+ GIT = "git.cmd"
+ stdout = run_command([GIT, "describe", "--tags", "--dirty", "--always"],
+ cwd=root)
+ if stdout is None:
+ return {}
+ if not stdout.startswith(tag_prefix):
+ if verbose:
+ print("tag '%s' doesn't start with prefix '%s'" % (stdout, tag_prefix))
+ return {}
+ tag = stdout[len(tag_prefix):]
+ stdout = run_command([GIT, "rev-parse", "HEAD"], cwd=root)
+ if stdout is None:
+ return {}
+ full = stdout.strip()
+ if tag.endswith("-dirty"):
+ full += "-dirty"
+ return {"version": tag, "full": full}
+
+
+def versions_from_parentdir(parentdir_prefix, versionfile_source, verbose=False):
+ if IN_LONG_VERSION_PY:
+ # We're running from _version.py. If it's from a source tree
+ # (execute-in-place), we can work upwards to find the root of the
+ # tree, and then check the parent directory for a version string. If
+ # it's in an installed application, there's no hope.
+ try:
+ here = os.path.abspath(__file__)
+ except NameError:
+ # py2exe/bbfreeze/non-CPython don't have __file__
+ return {} # without __file__, we have no hope
+ # versionfile_source is the relative path from the top of the source
+ # tree to _version.py. Invert this to find the root from __file__.
+ root = here
+ for i in range(len(versionfile_source.split("/"))):
+ root = os.path.dirname(root)
+ else:
+ # we're running from versioneer.py, which means we're running from
+ # the setup.py in a source tree. sys.argv[0] is setup.py in the root.
+ here = os.path.abspath(sys.argv[0])
+ root = os.path.dirname(here)
+
+ # Source tarballs conventionally unpack into a directory that includes
+ # both the project name and a version string.
+ dirname = os.path.basename(root)
+ if not dirname.startswith(parentdir_prefix):
+ if verbose:
+ print("guessing rootdir is '%s', but '%s' doesn't start with prefix '%s'" %
+ (root, dirname, parentdir_prefix))
+ return None
+ return {"version": dirname[len(parentdir_prefix):], "full": ""}
+
+tag_prefix = ""
+parentdir_prefix = "leap_client-"
+versionfile_source = "src/leap/_version.py"
+
+def get_versions(default={"version": "unknown", "full": ""}, verbose=False):
+ variables = { "refnames": git_refnames, "full": git_full }
+ ver = versions_from_expanded_variables(variables, tag_prefix, verbose)
+ if not ver:
+ ver = versions_from_vcs(tag_prefix, versionfile_source, verbose)
+ if not ver:
+ ver = versions_from_parentdir(parentdir_prefix, versionfile_source,
+ verbose)
+ if not ver:
+ ver = default
+ return ver
+
diff --git a/src/leap/app.py b/src/leap/app.py
index 0a61fd4f..d594c7cd 100644
--- a/src/leap/app.py
+++ b/src/leap/app.py
@@ -1,12 +1,24 @@
+# vim: tabstop=8 expandtab shiftwidth=4 softtabstop=4
+from functools import partial
import logging
+import signal
+
# This is only needed for Python v2 but is harmless for Python v3.
import sip
sip.setapi('QVariant', 2)
+sip.setapi('QString', 2)
from PyQt4.QtGui import (QApplication, QSystemTrayIcon, QMessageBox)
+from PyQt4.QtCore import QTimer
+from leap import __version__ as VERSION
from leap.baseapp.mainwindow import LeapWindow
-logger = logging.getLogger(name=__name__)
+
+def sigint_handler(*args, **kwargs):
+ logger = kwargs.get('logger', None)
+ logger.debug('SIGINT catched. shutting down...')
+ mainwindow = args[0]
+ mainwindow.shutdownSignal.emit()
def main():
@@ -15,26 +27,77 @@ def main():
long live to the (hidden) leap window!
"""
import sys
- from leap.utils import leap_argparse
+ from leap.util import leap_argparse
parser, opts = leap_argparse.init_leapc_args()
debug = getattr(opts, 'debug', False)
- #XXX get debug level and set logger accordingly
+ # XXX get severity from command line args
if debug:
- logger.debug('args: ', opts)
+ level = logging.DEBUG
+ else:
+ level = logging.WARNING
+
+ logger = logging.getLogger(name='leap')
+ logger.setLevel(level)
+ console = logging.StreamHandler()
+ console.setLevel(level)
+ formatter = logging.Formatter(
+ '%(asctime)s '
+ '- %(name)s - %(levelname)s - %(message)s')
+ console.setFormatter(formatter)
+ logger.addHandler(console)
+ #logger.debug(opts)
+ logger.info('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
+ logger.info('LEAP client version %s', VERSION)
+ logger.info('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
+ logfile = getattr(opts, 'log_file', False)
+ if logfile:
+ logger.debug('setting logfile to %s ', logfile)
+ fileh = logging.FileHandler(logfile)
+ fileh.setLevel(logging.DEBUG)
+ fileh.setFormatter(formatter)
+ logger.addHandler(fileh)
+
+ logger.info('Starting app')
app = QApplication(sys.argv)
+ # needed for initializing qsettings
+ # it will write .config/leap/leap.conf
+ # top level app settings
+ # in a platform independent way
+ app.setOrganizationName("leap")
+ app.setApplicationName("leap")
+ app.setOrganizationDomain("leap.se")
+
if not QSystemTrayIcon.isSystemTrayAvailable():
QMessageBox.critical(None, "Systray",
- "I couldn't detect any \
-system tray on this system.")
+ "I couldn't detect"
+ "any system tray on this system.")
sys.exit(1)
if not debug:
QApplication.setQuitOnLastWindowClosed(False)
window = LeapWindow(opts)
- window.show()
+
+ # this dummy timer ensures that
+ # control is given to the outside loop, so we
+ # can hook our sigint handler.
+ timer = QTimer()
+ timer.start(500)
+ timer.timeout.connect(lambda: None)
+
+ sigint_window = partial(sigint_handler, window, logger=logger)
+ signal.signal(signal.SIGINT, sigint_window)
+
+ if debug:
+ # we only show the main window
+ # if debug mode active.
+ # if not, it will be set visible
+ # from the systray menu.
+ window.show()
+
+ # run main loop
sys.exit(app.exec_())
if __name__ == "__main__":
diff --git a/src/leap/utils/__init__.py b/src/leap/base/__init__.py
index e69de29b..e69de29b 100644
--- a/src/leap/utils/__init__.py
+++ b/src/leap/base/__init__.py
diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py
new file mode 100644
index 00000000..50533278
--- /dev/null
+++ b/src/leap/base/auth.py
@@ -0,0 +1,376 @@
+import binascii
+import json
+import logging
+#import urlparse
+
+import requests
+import srp
+
+from PyQt4 import QtCore
+
+from leap.base import constants as baseconstants
+from leap.crypto import leapkeyring
+from leap.util.web import get_https_domain_and_port
+
+logger = logging.getLogger(__name__)
+
+SIGNUP_TIMEOUT = getattr(baseconstants, 'SIGNUP_TIMEOUT', 5)
+
+"""
+Registration and authentication classes for the
+SRP auth mechanism used in the leap platform.
+
+We're using the srp library which uses a c-based implementation
+of the protocol if the c extension is available, and a python-based
+one if not.
+"""
+
+
+class ImproperlyConfigured(Exception):
+ """
+ """
+
+
+class SRPAuthenticationError(Exception):
+ """
+ exception raised
+ for authentication errors
+ """
+
+
+def null_check(value, value_name):
+ try:
+ assert value is not None
+ except AssertionError:
+ raise ImproperlyConfigured(
+ "%s parameter cannot be None" % value_name)
+
+
+safe_unhexlify = lambda x: binascii.unhexlify(x) \
+ if (len(x) % 2 == 0) else binascii.unhexlify('0' + x)
+
+
+class LeapSRPRegister(object):
+
+ def __init__(self,
+ schema="https",
+ provider=None,
+ port=None,
+ verify=True,
+ register_path="1/users.json",
+ method="POST",
+ fetcher=requests,
+ srp=srp,
+ hashfun=srp.SHA256,
+ ng_constant=srp.NG_1024):
+
+ null_check(provider, provider)
+
+ self.schema = schema
+
+ # XXX FIXME
+ self.provider = provider
+ self.port = port
+ # XXX splitting server,port
+ # deprecate port call.
+ domain, port = get_https_domain_and_port(provider)
+ self.provider = domain
+ self.port = port
+
+ self.verify = verify
+ self.register_path = register_path
+ self.method = method
+ self.fetcher = fetcher
+ self.srp = srp
+ self.HASHFUN = hashfun
+ self.NG = ng_constant
+
+ self.init_session()
+
+ def init_session(self):
+ self.session = self.fetcher.session()
+
+ def get_registration_uri(self):
+ # XXX assert is https!
+ # use urlparse
+ if self.port:
+ uri = "%s://%s:%s/%s" % (
+ self.schema,
+ self.provider,
+ self.port,
+ self.register_path)
+ else:
+ uri = "%s://%s/%s" % (
+ self.schema,
+ self.provider,
+ self.register_path)
+
+ return uri
+
+ def register_user(self, username, password, keep=False):
+ """
+ @rtype: tuple
+ @rparam: (ok, request)
+ """
+ salt, vkey = self.srp.create_salted_verification_key(
+ username,
+ password,
+ self.HASHFUN,
+ self.NG)
+
+ user_data = {
+ 'user[login]': username,
+ 'user[password_verifier]': binascii.hexlify(vkey),
+ 'user[password_salt]': binascii.hexlify(salt)}
+
+ uri = self.get_registration_uri()
+ logger.debug('post to uri: %s' % uri)
+
+ # XXX get self.method
+ req = self.session.post(
+ uri, data=user_data,
+ timeout=SIGNUP_TIMEOUT,
+ verify=self.verify)
+ logger.debug(req)
+ logger.debug('user_data: %s', user_data)
+ #logger.debug('response: %s', req.text)
+ # we catch it in the form
+ #req.raise_for_status()
+ return (req.ok, req)
+
+
+class SRPAuth(requests.auth.AuthBase):
+
+ def __init__(self, username, password, server=None, verify=None):
+ # sanity check
+ null_check(server, 'server')
+ self.username = username
+ self.password = password
+ self.server = server
+ self.verify = verify
+
+ self.init_data = None
+ self.session = requests.session()
+
+ self.init_srp()
+
+ def get_json_data(self, response):
+ return json.loads(response.content)
+
+ def init_srp(self):
+ usr = srp.User(
+ self.username,
+ self.password,
+ srp.SHA256,
+ srp.NG_1024)
+ uname, A = usr.start_authentication()
+
+ self.srp_usr = usr
+ self.A = A
+
+ def get_auth_data(self):
+ return {
+ 'login': self.username,
+ 'A': binascii.hexlify(self.A)
+ }
+
+ def get_init_data(self):
+ try:
+ init_session = self.session.post(
+ self.server + '/1/sessions.json/',
+ data=self.get_auth_data(),
+ verify=self.verify)
+ except requests.exceptions.ConnectionError:
+ raise SRPAuthenticationError(
+ "No connection made (salt).")
+ if init_session.status_code not in (200, ):
+ raise SRPAuthenticationError(
+ "No valid response (salt).")
+
+ # XXX should get auth_result.json instead
+ self.init_data = self.get_json_data(init_session)
+ return self.init_data
+
+ def get_server_proof_data(self):
+ try:
+ auth_result = self.session.put(
+ #self.server + '/1/sessions.json/' + self.username,
+ self.server + '/1/sessions/' + self.username,
+ data={'client_auth': binascii.hexlify(self.M)},
+ verify=self.verify)
+ except requests.exceptions.ConnectionError:
+ raise SRPAuthenticationError(
+ "No connection made (HAMK).")
+
+ if auth_result.status_code not in (200, ):
+ raise SRPAuthenticationError(
+ "No valid response (HAMK).")
+
+ # XXX should get auth_result.json instead
+ try:
+ self.auth_data = self.get_json_data(auth_result)
+ except ValueError:
+ raise SRPAuthenticationError(
+ "No valid data sent (HAMK)")
+
+ return self.auth_data
+
+ def authenticate(self):
+ logger.debug('start authentication...')
+
+ init_data = self.get_init_data()
+ salt = init_data.get('salt', None)
+ B = init_data.get('B', None)
+
+ # XXX refactor this function
+ # move checks and un-hex
+ # to routines
+
+ if not salt or not B:
+ raise SRPAuthenticationError(
+ "Server did not send initial data.")
+
+ try:
+ unhex_salt = safe_unhexlify(salt)
+ except TypeError:
+ raise SRPAuthenticationError(
+ "Bad data from server (salt)")
+ try:
+ unhex_B = safe_unhexlify(B)
+ except TypeError:
+ raise SRPAuthenticationError(
+ "Bad data from server (B)")
+
+ self.M = self.srp_usr.process_challenge(
+ unhex_salt,
+ unhex_B
+ )
+
+ proof_data = self.get_server_proof_data()
+
+ HAMK = proof_data.get("M2", None)
+ if not HAMK:
+ errors = proof_data.get('errors', None)
+ if errors:
+ logger.error(errors)
+ raise SRPAuthenticationError("Server did not send HAMK.")
+
+ try:
+ unhex_HAMK = safe_unhexlify(HAMK)
+ except TypeError:
+ raise SRPAuthenticationError(
+ "Bad data from server (HAMK)")
+
+ self.srp_usr.verify_session(
+ unhex_HAMK)
+
+ try:
+ assert self.srp_usr.authenticated()
+ logger.debug('user is authenticated!')
+ except (AssertionError):
+ raise SRPAuthenticationError(
+ "Auth verification failed.")
+
+ def __call__(self, req):
+ self.authenticate()
+ req.session = self.session
+ return req
+
+
+def srpauth_protected(user=None, passwd=None, server=None, verify=True):
+ """
+ decorator factory that accepts
+ user and password keyword arguments
+ and add those to the decorated request
+ """
+ def srpauth(fn):
+ def wrapper(*args, **kwargs):
+ if user and passwd:
+ auth = SRPAuth(user, passwd, server, verify)
+ kwargs['auth'] = auth
+ kwargs['verify'] = verify
+ return fn(*args, **kwargs)
+ return wrapper
+ return srpauth
+
+
+def get_leap_credentials():
+ settings = QtCore.QSettings()
+ full_username = settings.value('eip_username')
+ username, domain = full_username.split('@')
+ seed = settings.value('%s_seed' % domain, None)
+ password = leapkeyring.leap_get_password(full_username, seed=seed)
+ return (username, password)
+
+
+# XXX TODO
+# Pass verify as single argument,
+# in srpauth_protected style
+
+def magick_srpauth(fn):
+ """
+ decorator that gets user and password
+ from the config file and adds those to
+ the decorated request
+ """
+ logger.debug('magick srp auth decorator called')
+
+ def wrapper(*args, **kwargs):
+ #uri = args[0]
+ # XXX Ugh!
+ # Problem with this approach.
+ # This won't work when we're using
+ # api.foo.bar
+ # Unless we keep a table with the
+ # equivalencies...
+ user, passwd = get_leap_credentials()
+
+ # XXX pass verify and server too
+ # (pop)
+ auth = SRPAuth(user, passwd)
+ kwargs['auth'] = auth
+ return fn(*args, **kwargs)
+ return wrapper
+
+
+if __name__ == "__main__":
+ """
+ To test against test_provider (twisted version)
+ Register an user: (will be valid during the session)
+ >>> python auth.py add test password
+
+ Test login with that user:
+ >>> python auth.py login test password
+ """
+
+ import sys
+
+ if len(sys.argv) not in (4, 5):
+ print 'Usage: auth <add|login> <user> <pass> [server]'
+ sys.exit(0)
+
+ action = sys.argv[1]
+ user = sys.argv[2]
+ passwd = sys.argv[3]
+
+ if len(sys.argv) == 5:
+ SERVER = sys.argv[4]
+ else:
+ SERVER = "https://localhost:8443"
+
+ if action == "login":
+
+ @srpauth_protected(
+ user=user, passwd=passwd, server=SERVER, verify=False)
+ def test_srp_protected_get(*args, **kwargs):
+ req = requests.get(*args, **kwargs)
+ req.raise_for_status
+ return req
+
+ req = test_srp_protected_get('https://localhost:8443/1/cert')
+ print 'cert :', req.content[:200] + "..."
+ sys.exit(0)
+
+ if action == "add":
+ auth = LeapSRPRegister(provider=SERVER, verify=False)
+ auth.register_user(user, passwd)
diff --git a/src/leap/base/authentication.py b/src/leap/base/authentication.py
new file mode 100644
index 00000000..09ff1d07
--- /dev/null
+++ b/src/leap/base/authentication.py
@@ -0,0 +1,11 @@
+"""
+Authentication Base Class
+"""
+
+
+class Authentication(object):
+ """
+ I have no idea how Authentication (certs,?)
+ will be done, but stub it here.
+ """
+ pass
diff --git a/src/leap/base/checks.py b/src/leap/base/checks.py
new file mode 100644
index 00000000..23446f4a
--- /dev/null
+++ b/src/leap/base/checks.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+import logging
+import platform
+import socket
+
+import netifaces
+import ping
+import requests
+
+from leap.base import constants
+from leap.base import exceptions
+
+logger = logging.getLogger(name=__name__)
+
+
+class LeapNetworkChecker(object):
+ """
+ all network related checks
+ """
+ def __init__(self, *args, **kwargs):
+ provider_gw = kwargs.pop('provider_gw', None)
+ self.provider_gateway = provider_gw
+
+ def run_all(self, checker=None):
+ if not checker:
+ checker = self
+ #self.error = None # ?
+
+ # for MVS
+ checker.check_tunnel_default_interface()
+ checker.check_internet_connection()
+ checker.is_internet_up()
+
+ if self.provider_gateway:
+ checker.ping_gateway(self.provider_gateway)
+
+ def check_internet_connection(self):
+ try:
+ # XXX remove this hardcoded random ip
+ # ping leap.se or eip provider instead...?
+ requests.get('http://216.172.161.165')
+
+ except (requests.HTTPError, requests.RequestException) as e:
+ raise exceptions.NoInternetConnection(e.message)
+ except requests.ConnectionError as e:
+ error = "Unidentified Connection Error"
+ if e.message == "[Errno 113] No route to host":
+ if not self.is_internet_up():
+ error = "No valid internet connection found."
+ else:
+ error = "Provider server appears to be down."
+ logger.error(error)
+ raise exceptions.NoInternetConnection(error)
+ logger.debug('Network appears to be up.')
+
+ def is_internet_up(self):
+ iface, gateway = self.get_default_interface_gateway()
+ self.ping_gateway(self.provider_gateway)
+
+ def check_tunnel_default_interface(self):
+ """
+ Raises an TunnelNotDefaultRouteError
+ (including when no routes are present)
+ """
+ if not platform.system() == "Linux":
+ raise NotImplementedError
+
+ f = open("/proc/net/route")
+ route_table = f.readlines()
+ f.close()
+ #toss out header
+ route_table.pop(0)
+
+ if not route_table:
+ raise exceptions.TunnelNotDefaultRouteError()
+
+ line = route_table.pop(0)
+ iface, destination = line.split('\t')[0:2]
+ if not destination == '00000000' or not iface == 'tun0':
+ raise exceptions.TunnelNotDefaultRouteError()
+
+ def get_default_interface_gateway(self):
+ """only impletemented for linux so far."""
+ if not platform.system() == "Linux":
+ raise NotImplementedError
+
+ # XXX use psutil
+ f = open("/proc/net/route")
+ route_table = f.readlines()
+ f.close()
+ #toss out header
+ route_table.pop(0)
+
+ default_iface = None
+ gateway = None
+ while route_table:
+ line = route_table.pop(0)
+ iface, destination, gateway = line.split('\t')[0:3]
+ if destination == '00000000':
+ default_iface = iface
+ break
+
+ if not default_iface:
+ raise exceptions.NoDefaultInterfaceFoundError
+
+ if default_iface not in netifaces.interfaces():
+ raise exceptions.InterfaceNotFoundError
+
+ return default_iface, gateway
+
+ def ping_gateway(self, gateway):
+ # TODO: Discuss how much packet loss (%) is acceptable.
+
+ # XXX -- validate gateway
+ # -- is it a valid ip? (there's something in util)
+ # -- is it a domain?
+ # -- can we resolve? -- raise NoDNSError if not.
+ packet_loss = ping.quiet_ping(gateway)[0]
+ if packet_loss > constants.MAX_ICMP_PACKET_LOSS:
+ raise exceptions.NoConnectionToGateway
+
+ def check_name_resolution(self, domain_name):
+ try:
+ socket.gethostbyname(domain_name)
+ return True
+ except socket.gaierror:
+ raise exceptions.CannotResolveDomainError
diff --git a/src/leap/base/config.py b/src/leap/base/config.py
new file mode 100644
index 00000000..0255fbab
--- /dev/null
+++ b/src/leap/base/config.py
@@ -0,0 +1,279 @@
+"""
+Configuration Base Class
+"""
+import grp
+import json
+import logging
+import socket
+import tempfile
+import os
+
+logger = logging.getLogger(name=__name__)
+
+import requests
+
+from leap.base import exceptions
+from leap.base import constants
+from leap.base.pluggableconfig import PluggableConfig
+from leap.util.fileutil import (mkdir_p)
+
+# move to base!
+from leap.eip import exceptions as eipexceptions
+
+
+class BaseLeapConfig(object):
+ slug = None
+
+ # XXX we have to enforce that every derived class
+ # has a slug (via interface)
+ # get property getter that raises NI..
+
+ def save(self):
+ raise NotImplementedError("abstract base class")
+
+ def load(self):
+ raise NotImplementedError("abstract base class")
+
+ def get_config(self, *kwargs):
+ raise NotImplementedError("abstract base class")
+
+ @property
+ def config(self):
+ return self.get_config()
+
+ def get_value(self, *kwargs):
+ raise NotImplementedError("abstract base class")
+
+
+class MetaConfigWithSpec(type):
+ """
+ metaclass for JSONLeapConfig classes.
+ It creates a configuration spec out of
+ the `spec` dictionary. The `properties` attribute
+ of the spec dict is turn into the `schema` attribute
+ of the new class (which will be used to validate against).
+ """
+ # XXX in the near future, this is the
+ # place where we want to enforce
+ # singletons, read-only and similar stuff.
+
+ def __new__(meta, classname, bases, classDict):
+ schema_obj = classDict.get('spec', None)
+
+ # not quite happy with this workaround.
+ # I want to raise if missing spec dict, but only
+ # for grand-children of this metaclass.
+ # maybe should use abc module for this.
+ abcderived = ("JSONLeapConfig",)
+ if schema_obj is None and classname not in abcderived:
+ raise exceptions.ImproperlyConfigured(
+ "missing spec dict on your derived class (%s)" % classname)
+
+ # we create a configuration spec attribute
+ # from the spec dict
+ config_class = type(
+ classname + "Spec",
+ (PluggableConfig, object),
+ {'options': schema_obj})
+ classDict['spec'] = config_class
+
+ return type.__new__(meta, classname, bases, classDict)
+
+##########################################################
+# some hacking still in progress:
+
+# Configs have:
+
+# - a slug (from where a filename/folder is derived)
+# - a spec (for validation and defaults).
+# this spec is conformant to the json-schema.
+# basically a dict that will be used
+# for type casting and validation, and defaults settings.
+
+# all config objects, since they are derived from BaseConfig, implement basic
+# useful methods:
+# - save
+# - load
+
+##########################################################
+
+
+class JSONLeapConfig(BaseLeapConfig):
+
+ __metaclass__ = MetaConfigWithSpec
+
+ def __init__(self, *args, **kwargs):
+ # sanity check
+ try:
+ assert self.slug is not None
+ except AssertionError:
+ raise exceptions.ImproperlyConfigured(
+ "missing slug on JSONLeapConfig"
+ " derived class")
+ try:
+ assert self.spec is not None
+ except AssertionError:
+ raise exceptions.ImproperlyConfigured(
+ "missing spec on JSONLeapConfig"
+ " derived class")
+ assert issubclass(self.spec, PluggableConfig)
+
+ self.domain = kwargs.pop('domain', None)
+ self._config = self.spec(format="json")
+ self._config.load()
+ self.fetcher = kwargs.pop('fetcher', requests)
+
+ # mandatory baseconfig interface
+
+ def save(self, to=None):
+ if to is None:
+ to = self.filename
+ folder, filename = os.path.split(to)
+ if folder and not os.path.isdir(folder):
+ mkdir_p(folder)
+ self._config.serialize(to)
+
+ def load(self, fromfile=None, from_uri=None, fetcher=None, verify=False):
+ if from_uri is not None:
+ fetched = self.fetch(from_uri, fetcher=fetcher, verify=verify)
+ if fetched:
+ return
+ if fromfile is None:
+ fromfile = self.filename
+ if os.path.isfile(fromfile):
+ self._config.load(fromfile=fromfile)
+ else:
+ logger.error('tried to load config from non-existent path')
+ logger.error('Not Found: %s', fromfile)
+
+ def fetch(self, uri, fetcher=None, verify=True):
+ if not fetcher:
+ fetcher = self.fetcher
+ logger.debug('verify: %s', verify)
+ logger.debug('uri: %s', uri)
+ request = fetcher.get(uri, verify=verify)
+ # XXX should send a if-modified-since header
+
+ # XXX get 404, ...
+ # and raise a UnableToFetch...
+ request.raise_for_status()
+ fd, fname = tempfile.mkstemp(suffix=".json")
+
+ if request.json:
+ self._config.load(json.dumps(request.json))
+
+ else:
+ # not request.json
+ # might be server did not announce content properly,
+ # let's try deserializing all the same.
+ try:
+ self._config.load(request.content)
+ except ValueError:
+ raise eipexceptions.LeapBadConfigFetchedError
+
+ return True
+
+ def get_config(self):
+ return self._config.config
+
+ # public methods
+
+ def get_filename(self):
+ return self._slug_to_filename()
+
+ @property
+ def filename(self):
+ return self.get_filename()
+
+ def validate(self, data):
+ logger.debug('validating schema')
+ self._config.validate(data)
+ return True
+
+ # private
+
+ def _slug_to_filename(self):
+ # is this going to work in winland if slug is "foo/bar" ?
+ folder, filename = os.path.split(self.slug)
+ config_file = get_config_file(filename, folder)
+ return config_file
+
+ def exists(self):
+ return os.path.isfile(self.filename)
+
+
+#
+# utility functions
+#
+# (might be moved to some class as we see fit, but
+# let's remain functional for a while)
+# maybe base.config.util ??
+#
+
+
+def get_config_dir():
+ """
+ get the base dir for all leap config
+ @rparam: config path
+ @rtype: string
+ """
+ # TODO
+ # check for $XDG_CONFIG_HOME var?
+ # get a more sensible path for win/mac
+ # kclair: opinion? ^^
+
+ return os.path.expanduser(
+ os.path.join('~',
+ '.config',
+ 'leap'))
+
+
+def get_config_file(filename, folder=None):
+ """
+ concatenates the given filename
+ with leap config dir.
+ @param filename: name of the file
+ @type filename: string
+ @rparam: full path to config file
+ """
+ path = []
+ path.append(get_config_dir())
+ if folder is not None:
+ path.append(folder)
+ path.append(filename)
+ return os.path.join(*path)
+
+
+def get_default_provider_path():
+ default_subpath = os.path.join("providers",
+ constants.DEFAULT_PROVIDER)
+ default_provider_path = get_config_file(
+ '',
+ folder=default_subpath)
+ return default_provider_path
+
+
+def get_provider_path(domain):
+ # XXX if not domain, return get_default_provider_path
+ default_subpath = os.path.join("providers", domain)
+ provider_path = get_config_file(
+ '',
+ folder=default_subpath)
+ return provider_path
+
+
+def validate_ip(ip_str):
+ """
+ raises exception if the ip_str is
+ not a valid representation of an ip
+ """
+ socket.inet_aton(ip_str)
+
+
+def get_username():
+ return os.getlogin()
+
+
+def get_groupname():
+ gid = os.getgroups()[-1]
+ return grp.getgrgid(gid).gr_name
diff --git a/src/leap/base/connection.py b/src/leap/base/connection.py
new file mode 100644
index 00000000..41d13935
--- /dev/null
+++ b/src/leap/base/connection.py
@@ -0,0 +1,115 @@
+"""
+Base Connection Classs
+"""
+from __future__ import (division, unicode_literals, print_function)
+
+import logging
+
+from leap.base.authentication import Authentication
+
+logger = logging.getLogger(name=__name__)
+
+
+class Connection(Authentication):
+ # JSONLeapConfig
+ #spec = {}
+
+ def __init__(self, *args, **kwargs):
+ self.connection_state = None
+ self.desired_connection_state = None
+ #XXX FIXME diamond inheritance gotcha..
+ #If you inherit from >1 class,
+ #super is only initializing one
+ #of the bases..!!
+ # I think we better pass config as a constructor
+ # parameter -- kali 2012-08-30 04:33
+ super(Connection, self).__init__(*args, **kwargs)
+
+ def connect(self):
+ """
+ entry point for connection process
+ """
+ pass
+
+ def disconnect(self):
+ """
+ disconnects client
+ """
+ pass
+
+ #def shutdown(self):
+ #"""
+ #shutdown and quit
+ #"""
+ #self.desired_con_state = self.status.DISCONNECTED
+
+ def connection_state(self):
+ """
+ returns the current connection state
+ """
+ return self.status.current
+
+ def desired_connection_state(self):
+ """
+ returns the desired_connection state
+ """
+ return self.desired_connection_state
+
+ def get_icon_name(self):
+ """
+ get icon name from status object
+ """
+ return self.status.get_state_icon()
+
+ #
+ # private methods
+ #
+
+ def _disconnect(self):
+ """
+ private method for disconnecting
+ """
+ if self.subp is not None:
+ self.subp.terminate()
+ self.subp = None
+ # XXX signal state changes! :)
+
+ def _is_alive(self):
+ """
+ don't know yet
+ """
+ pass
+
+ def _connect(self):
+ """
+ entry point for connection cascade methods.
+ """
+ #conn_result = ConState.DISCONNECTED
+ try:
+ conn_result = self._try_connection()
+ except UnrecoverableError as except_msg:
+ logger.error("FATAL: %s" % unicode(except_msg))
+ conn_result = self.status.UNRECOVERABLE
+ except Exception as except_msg:
+ self.error_queue.append(except_msg)
+ logger.error("Failed Connection: %s" %
+ unicode(except_msg))
+ return conn_result
+
+
+class ConnectionError(Exception):
+ """
+ generic connection error
+ """
+ def __str__(self):
+ if len(self.args) >= 1:
+ return repr(self.args[0])
+ else:
+ raise self()
+
+
+class UnrecoverableError(ConnectionError):
+ """
+ we cannot do anything about it, sorry
+ """
+ pass
diff --git a/src/leap/base/constants.py b/src/leap/base/constants.py
new file mode 100644
index 00000000..f7be8d98
--- /dev/null
+++ b/src/leap/base/constants.py
@@ -0,0 +1,32 @@
+"""constants to be used in base module"""
+from leap import __branding
+APP_NAME = __branding.get("short_name", "leap")
+
+# default provider placeholder
+# using `example.org` we make sure that this
+# is not going to be resolved during the tests phases
+# (we expect testers to add it to their /etc/hosts
+
+DEFAULT_PROVIDER = __branding.get(
+ "provider_domain",
+ "testprovider.example.org")
+
+DEFINITION_EXPECTED_PATH = "provider.json"
+
+DEFAULT_PROVIDER_DEFINITION = {
+ u'api_uri': u'https://api.%s/' % DEFAULT_PROVIDER,
+ u'api_version': u'0.1.0',
+ u'ca_cert_fingerprint': u'8aab80ae4326fd30721689db813733783fe0bd7e',
+ u'ca_cert_uri': u'https://%s/cacert.pem' % DEFAULT_PROVIDER,
+ u'description': {u'en': u'This is a test provider'},
+ u'display_name': {u'en': u'Test Provider'},
+ u'domain': u'%s' % DEFAULT_PROVIDER,
+ u'enrollment_policy': u'open',
+ u'public_key': u'cb7dbd679f911e85bc2e51bd44afd7308ee19c21',
+ u'serial': 1,
+ u'services': [u'eip'],
+ u'version': u'0.1.0'}
+
+MAX_ICMP_PACKET_LOSS = 10
+
+ROUTE_CHECK_INTERVAL = 10
diff --git a/src/leap/base/exceptions.py b/src/leap/base/exceptions.py
new file mode 100644
index 00000000..227da953
--- /dev/null
+++ b/src/leap/base/exceptions.py
@@ -0,0 +1,77 @@
+"""
+Exception attributes and their meaning/uses
+-------------------------------------------
+
+* critical: if True, will abort execution prematurely,
+ after attempting any cleaning
+ action.
+
+* failfirst: breaks any error_check loop that is examining
+ the error queue.
+
+* message: the message that will be used in the __repr__ of the exception.
+
+* usermessage: the message that will be passed to user in ErrorDialogs
+ in Qt-land.
+"""
+
+
+class LeapException(Exception):
+ """
+ base LeapClient exception
+ sets some parameters that we will check
+ during error checking routines
+ """
+ critical = False
+ failfirst = False
+ warning = False
+
+
+class CriticalError(LeapException):
+ """
+ we cannot do anything about it
+ """
+ critical = True
+ failfirst = True
+
+
+# In use ???
+# don't thing so. purge if not...
+
+class MissingConfigFileError(Exception):
+ pass
+
+
+class ImproperlyConfigured(Exception):
+ pass
+
+
+class NoDefaultInterfaceFoundError(LeapException):
+ message = "no default interface found"
+ usermessage = "Looks like your computer is not connected to the internet"
+
+
+class InterfaceNotFoundError(LeapException):
+ # XXX should take iface arg on init maybe?
+ message = "interface not found"
+
+
+class NoConnectionToGateway(CriticalError):
+ message = "no connection to gateway"
+ usermessage = "Looks like there are problems with your internet connection"
+
+
+class NoInternetConnection(CriticalError):
+ message = "No Internet connection found"
+ usermessage = "It looks like there is no internet connection."
+ # and now we try to connect to our web to troubleshoot LOL :P
+
+
+class CannotResolveDomainError(LeapException):
+ message = "Cannot resolve domain"
+ usermessage = "Domain cannot be found"
+
+
+class TunnelNotDefaultRouteError(CriticalError):
+ message = "Tunnel connection dissapeared. VPN down?"
+ usermessage = "The Encrypted Connection was lost. Shutting down..."
diff --git a/src/leap/base/network.py b/src/leap/base/network.py
new file mode 100644
index 00000000..3aba3f61
--- /dev/null
+++ b/src/leap/base/network.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+from __future__ import (print_function)
+import logging
+import threading
+
+from leap.eip.config import get_eip_gateway
+from leap.base.checks import LeapNetworkChecker
+from leap.base.constants import ROUTE_CHECK_INTERVAL
+from leap.base.exceptions import TunnelNotDefaultRouteError
+from leap.util.coroutines import (launch_thread, process_events)
+
+from time import sleep
+
+logger = logging.getLogger(name=__name__)
+
+
+class NetworkCheckerThread(object):
+ """
+ Manages network checking thread that makes sure we have a working network
+ connection.
+ """
+ def __init__(self, *args, **kwargs):
+ self.status_signals = kwargs.pop('status_signals', None)
+ #self.watcher_cb = kwargs.pop('status_signals', None)
+ self.error_cb = kwargs.pop(
+ 'error_cb',
+ lambda exc: logger.error("%s", exc.message))
+ self.shutdown = threading.Event()
+
+ # XXX get provider_gateway and pass it to checker
+ # see in eip.config for function
+ # #718
+ self.checker = LeapNetworkChecker(
+ provider_gw=get_eip_gateway())
+
+ def start(self):
+ self.process_handle = self._launch_recurrent_network_checks(
+ (self.error_cb,))
+
+ def stop(self):
+ self.shutdown.set()
+ logger.debug("network checked stopped.")
+
+ def run_checks(self):
+ pass
+
+ #private methods
+
+ #here all the observers in fail_callbacks expect one positional argument,
+ #which is exception so we can try by passing a lambda with logger to
+ #check it works.
+ def _network_checks_thread(self, fail_callbacks):
+ #TODO: replace this with waiting for a signal from openvpn
+ while True:
+ try:
+ self.checker.check_tunnel_default_interface()
+ break
+ except TunnelNotDefaultRouteError:
+ # XXX ??? why do we sleep here???
+ # aa: If the openvpn isn't up and running yet,
+ # let's give it a moment to breath.
+ sleep(1)
+
+ fail_observer_dict = dict(((
+ observer,
+ process_events(observer)) for observer in fail_callbacks))
+ while not self.shutdown.is_set():
+ try:
+ self.checker.check_tunnel_default_interface()
+ self.checker.check_internet_connection()
+ sleep(ROUTE_CHECK_INTERVAL)
+ except Exception as exc:
+ for obs in fail_observer_dict:
+ fail_observer_dict[obs].send(exc)
+ sleep(ROUTE_CHECK_INTERVAL)
+ #reset event
+ self.shutdown.clear()
+
+ def _launch_recurrent_network_checks(self, fail_callbacks):
+ #we need to wrap the fail callback in a tuple
+ watcher = launch_thread(
+ self._network_checks_thread,
+ (fail_callbacks,))
+ return watcher
diff --git a/src/leap/base/pluggableconfig.py b/src/leap/base/pluggableconfig.py
new file mode 100644
index 00000000..b8615ad8
--- /dev/null
+++ b/src/leap/base/pluggableconfig.py
@@ -0,0 +1,421 @@
+"""
+generic configuration handlers
+"""
+import copy
+import json
+import logging
+import os
+import time
+import urlparse
+
+import jsonschema
+
+logger = logging.getLogger(__name__)
+
+
+__all__ = ['PluggableConfig',
+ 'adaptors',
+ 'types',
+ 'UnknownOptionException',
+ 'MissingValueException',
+ 'ConfigurationProviderException',
+ 'TypeCastException']
+
+# exceptions
+
+
+class UnknownOptionException(Exception):
+ """exception raised when a non-configuration
+ value is present in the configuration"""
+
+
+class MissingValueException(Exception):
+ """exception raised when a required value is missing"""
+
+
+class ConfigurationProviderException(Exception):
+ """exception raised when a configuration provider is missing, etc"""
+
+
+class TypeCastException(Exception):
+ """exception raised when a
+ configuration item cannot be coerced to a type"""
+
+
+class ConfigAdaptor(object):
+ """
+ abstract base class for config adaotors for
+ serialization/deserialization and custom validation
+ and type casting.
+ """
+ def read(self, filename):
+ raise NotImplementedError("abstract base class")
+
+ def write(self, config, filename):
+ with open(filename, 'w') as f:
+ self._write(f, config)
+
+ def _write(self, fp, config):
+ raise NotImplementedError("abstract base class")
+
+ def validate(self, config, schema):
+ raise NotImplementedError("abstract base class")
+
+
+adaptors = {}
+
+
+class JSONSchemaEncoder(json.JSONEncoder):
+ """
+ custom default encoder that
+ casts python objects to json objects for
+ the schema validation
+ """
+ def default(self, obj):
+ if obj is str:
+ return 'string'
+ if obj is unicode:
+ return 'string'
+ if obj is int:
+ return 'integer'
+ if obj is list:
+ return 'array'
+ if obj is dict:
+ return 'object'
+ if obj is bool:
+ return 'boolean'
+
+
+class JSONAdaptor(ConfigAdaptor):
+ indent = 2
+ extensions = ['json']
+
+ def read(self, _from):
+ if isinstance(_from, file):
+ _from_string = _from.read()
+ if isinstance(_from, str):
+ _from_string = _from
+ return json.loads(_from_string)
+
+ def _write(self, fp, config):
+ fp.write(json.dumps(config,
+ indent=self.indent,
+ sort_keys=True))
+
+ def validate(self, config, schema_obj):
+ schema_json = JSONSchemaEncoder().encode(schema_obj)
+ schema = json.loads(schema_json)
+ jsonschema.validate(config, schema)
+
+
+adaptors['json'] = JSONAdaptor()
+
+#
+# Adaptors
+#
+# Allow to apply a predefined set of types to the
+# specs, so it checks the validity of formats and cast it
+# to proper python types.
+
+# TODO:
+# - multilingual object.
+# - HTTPS uri
+
+
+class DateType(object):
+ fmt = '%Y-%m-%d'
+
+ def to_python(self, data):
+ return time.strptime(data, self.fmt)
+
+ def get_prep_value(self, data):
+ return time.strftime(self.fmt, data)
+
+
+class URIType(object):
+
+ def to_python(self, data):
+ parsed = urlparse.urlparse(data)
+ if not parsed.scheme:
+ raise TypeCastException("uri %s has no schema" % data)
+ return parsed
+
+ def get_prep_value(self, data):
+ return data.geturl()
+
+
+class HTTPSURIType(object):
+
+ def to_python(self, data):
+ parsed = urlparse.urlparse(data)
+ if not parsed.scheme:
+ raise TypeCastException("uri %s has no schema" % data)
+ if parsed.scheme != "https":
+ raise TypeCastException(
+ "uri %s does not has "
+ "https schema" % data)
+ return parsed
+
+ def get_prep_value(self, data):
+ return data.geturl()
+
+
+types = {
+ 'date': DateType(),
+ 'uri': URIType(),
+ 'https-uri': HTTPSURIType(),
+}
+
+
+class PluggableConfig(object):
+
+ options = {}
+
+ def __init__(self,
+ adaptors=adaptors,
+ types=types,
+ format=None):
+
+ self.config = {}
+ self.adaptors = adaptors
+ self.types = types
+ self._format = format
+
+ @property
+ def option_dict(self):
+ if hasattr(self, 'options') and isinstance(self.options, dict):
+ return self.options.get('properties', None)
+
+ def items(self):
+ """
+ act like an iterator
+ """
+ if isinstance(self.option_dict, dict):
+ return self.option_dict.items()
+ return self.options
+
+ def validate(self, config, format=None):
+ """
+ validate config
+ """
+ schema = self.options
+ if format is None:
+ format = self._format
+
+ if format:
+ adaptor = self.get_adaptor(self._format)
+ adaptor.validate(config, schema)
+ else:
+ # we really should make format mandatory...
+ logger.error('no format passed to validate')
+
+ # first round of validation is ok.
+ # now we proceed to cast types if any specified.
+ self.to_python(config)
+
+ def to_python(self, config):
+ """
+ cast types following first type and then format indications.
+ """
+ unseen_options = [i for i in config if i not in self.option_dict]
+ if unseen_options:
+ raise UnknownOptionException(
+ "Unknown options: %s" % ', '.join(unseen_options))
+
+ for key, value in config.items():
+ _type = self.option_dict[key].get('type')
+ if _type is None and 'default' in self.option_dict[key]:
+ _type = type(self.option_dict[key]['default'])
+ if _type is not None:
+ tocast = True
+ if not callable(_type) and isinstance(value, _type):
+ tocast = False
+ if tocast:
+ try:
+ config[key] = _type(value)
+ except BaseException, e:
+ raise TypeCastException(
+ "Could not coerce %s, %s, "
+ "to type %s: %s" % (key, value, _type.__name__, e))
+ _format = self.option_dict[key].get('format', None)
+ _ftype = self.types.get(_format, None)
+ if _ftype:
+ try:
+ config[key] = _ftype.to_python(value)
+ except BaseException, e:
+ raise TypeCastException(
+ "Could not coerce %s, %s, "
+ "to format %s: %s" % (key, value,
+ _ftype.__class__.__name__,
+ e))
+
+ return config
+
+ def prep_value(self, config):
+ """
+ the inverse of to_python method,
+ called just before serialization
+ """
+ for key, value in config.items():
+ _format = self.option_dict[key].get('format', None)
+ _ftype = self.types.get(_format, None)
+ if _ftype and hasattr(_ftype, 'get_prep_value'):
+ try:
+ config[key] = _ftype.get_prep_value(value)
+ except BaseException, e:
+ raise TypeCastException(
+ "Could not serialize %s, %s, "
+ "by format %s: %s" % (key, value,
+ _ftype.__class__.__name__,
+ e))
+ else:
+ config[key] = value
+ return config
+
+ # methods for adding configuration
+
+ def get_default_values(self):
+ """
+ return a config options from configuration defaults
+ """
+ defaults = {}
+ for key, value in self.items():
+ if 'default' in value:
+ defaults[key] = value['default']
+ return copy.deepcopy(defaults)
+
+ def get_adaptor(self, format):
+ """
+ get specified format adaptor or
+ guess for a given filename
+ """
+ adaptor = self.adaptors.get(format, None)
+ if adaptor:
+ return adaptor
+
+ # not registered in adaptors dict, let's try all
+ for adaptor in self.adaptors.values():
+ if format in adaptor.extensions:
+ return adaptor
+
+ def filename2format(self, filename):
+ extension = os.path.splitext(filename)[-1]
+ return extension.lstrip('.') or None
+
+ def serialize(self, filename, format=None, full=False):
+ if not format:
+ format = self._format
+ if not format:
+ format = self.filename2format(filename)
+ if not format:
+ raise Exception('Please specify a format')
+ # TODO: more specific exception type
+
+ adaptor = self.get_adaptor(format)
+ if not adaptor:
+ raise Exception("Adaptor not found for format: %s" % format)
+
+ config = copy.deepcopy(self.config)
+ serializable = self.prep_value(config)
+ adaptor.write(serializable, filename)
+
+ def deserialize(self, string=None, fromfile=None, format=None):
+ """
+ load configuration from a file or string
+ """
+
+ def _try_deserialize():
+ if fromfile:
+ with open(fromfile, 'r') as f:
+ content = adaptor.read(f)
+ elif string:
+ content = adaptor.read(string)
+ return content
+
+ # XXX cleanup this!
+
+ if fromfile:
+ assert os.path.exists(fromfile)
+ if not format:
+ format = self.filename2format(fromfile)
+
+ if not format:
+ format = self._format
+ if format:
+ adaptor = self.get_adaptor(format)
+ else:
+ adaptor = None
+
+ if adaptor:
+ content = _try_deserialize()
+ return content
+
+ # no adaptor, let's try rest of adaptors
+
+ adaptors = self.adaptors[:]
+
+ if format:
+ adaptors.sort(
+ key=lambda x: int(
+ format in x.extensions),
+ reverse=True)
+
+ for adaptor in adaptors:
+ content = _try_deserialize()
+ return content
+
+ def load(self, *args, **kwargs):
+ """
+ load from string or file
+ if no string of fromfile option is given,
+ it will attempt to load from defaults
+ defined in the schema.
+ """
+ string = args[0] if args else None
+ fromfile = kwargs.get("fromfile", None)
+ content = None
+
+ # start with defaults, so we can
+ # have partial values applied.
+ content = self.get_default_values()
+ if string and isinstance(string, str):
+ content = self.deserialize(string)
+
+ if not string and fromfile is not None:
+ #import ipdb;ipdb.set_trace()
+ content = self.deserialize(fromfile=fromfile)
+
+ if not content:
+ logger.error('no content could be loaded')
+ # XXX raise!
+ return
+
+ # lazy evaluation until first level of nesting
+ # to allow lambdas with context-dependant info
+ # like os.path.expanduser
+ for k, v in content.iteritems():
+ if callable(v):
+ content[k] = v()
+
+ self.validate(content)
+ self.config = content
+ return True
+
+
+def testmain():
+ from tests import test_validation as t
+ import pprint
+
+ config = PluggableConfig(_format="json")
+ properties = copy.deepcopy(t.sample_spec)
+
+ config.options = properties
+ config.load(fromfile='data.json')
+
+ print 'config'
+ pprint.pprint(config.config)
+
+ config.serialize('/tmp/testserial.json')
+
+if __name__ == "__main__":
+ testmain()
diff --git a/src/leap/base/providers.py b/src/leap/base/providers.py
new file mode 100644
index 00000000..d41f3695
--- /dev/null
+++ b/src/leap/base/providers.py
@@ -0,0 +1,29 @@
+"""all dealing with leap-providers: definition files, updating"""
+from leap.base import config as baseconfig
+from leap.base import specs
+
+
+class LeapProviderDefinition(baseconfig.JSONLeapConfig):
+ spec = specs.leap_provider_spec
+
+ def _get_slug(self):
+ domain = getattr(self, 'domain', None)
+ if domain:
+ path = baseconfig.get_provider_path(domain)
+ else:
+ path = baseconfig.get_default_provider_path()
+
+ return baseconfig.get_config_file(
+ 'provider.json', folder=path)
+
+ def _set_slug(self, *args, **kwargs):
+ raise AttributeError("you cannot set slug")
+
+ slug = property(_get_slug, _set_slug)
+
+
+class LeapProviderSet(object):
+ # we gather them from the filesystem
+ # TODO: (MVS+)
+ def __init__(self):
+ self.count = 0
diff --git a/src/leap/base/specs.py b/src/leap/base/specs.py
new file mode 100644
index 00000000..b4bb8dcf
--- /dev/null
+++ b/src/leap/base/specs.py
@@ -0,0 +1,59 @@
+leap_provider_spec = {
+ 'description': 'provider definition',
+ 'type': 'object',
+ 'properties': {
+ 'serial': {
+ 'type': int,
+ 'default': 1,
+ 'required': True,
+ },
+ 'version': {
+ 'type': unicode,
+ 'default': '0.1.0'
+ #'required': True
+ },
+ 'domain': {
+ 'type': unicode, # XXX define uri type
+ 'default': 'testprovider.example.org'
+ #'required': True,
+ },
+ 'display_name': {
+ 'type': dict, # XXX multilingual object?
+ 'default': {u'en': u'Test Provider'}
+ #'required': True
+ },
+ 'description': {
+ 'type': dict,
+ 'default': {u'en': u'Test provider'}
+ },
+ 'enrollment_policy': {
+ 'type': unicode, # oneof ??
+ 'default': 'open'
+ },
+ 'services': {
+ 'type': list, # oneof ??
+ 'default': ['eip']
+ },
+ 'api_version': {
+ 'type': unicode,
+ 'default': '0.1.0' # version regexp
+ },
+ 'api_uri': {
+ 'type': unicode # uri
+ },
+ 'public_key': {
+ 'type': unicode # fingerprint
+ },
+ 'ca_cert_fingerprint': {
+ 'type': unicode,
+ },
+ 'ca_cert_uri': {
+ 'type': unicode,
+ 'format': 'https-uri'
+ },
+ 'languages': {
+ 'type': list,
+ 'default': ['en']
+ }
+ }
+}
diff --git a/src/leap/base/tests/__init__.py b/src/leap/base/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/base/tests/__init__.py
diff --git a/src/leap/base/tests/test_checks.py b/src/leap/base/tests/test_checks.py
new file mode 100644
index 00000000..8d573b1e
--- /dev/null
+++ b/src/leap/base/tests/test_checks.py
@@ -0,0 +1,124 @@
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+import os
+
+from mock import (patch, Mock)
+from StringIO import StringIO
+
+import ping
+import requests
+
+from leap.base import checks
+from leap.base import exceptions
+from leap.testing.basetest import BaseLeapTest
+
+_uid = os.getuid()
+
+
+class LeapNetworkCheckTest(BaseLeapTest):
+ __name__ = "leap_network_check_tests"
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def test_checker_should_implement_check_methods(self):
+ checker = checks.LeapNetworkChecker()
+
+ self.assertTrue(hasattr(checker, "check_internet_connection"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "check_tunnel_default_interface"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "is_internet_up"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "ping_gateway"),
+ "missing meth")
+
+ def test_checker_should_actually_call_all_tests(self):
+ checker = checks.LeapNetworkChecker()
+ mc = Mock()
+ checker.run_all(checker=mc)
+ self.assertTrue(mc.check_internet_connection.called, "not called")
+ self.assertTrue(mc.check_tunnel_default_interface.called, "not called")
+ self.assertTrue(mc.is_internet_up.called, "not called")
+
+ # ping gateway only called if we pass provider_gw
+ checker = checks.LeapNetworkChecker(provider_gw="0.0.0.0")
+ mc = Mock()
+ checker.run_all(checker=mc)
+ self.assertTrue(mc.check_internet_connection.called, "not called")
+ self.assertTrue(mc.check_tunnel_default_interface.called, "not called")
+ self.assertTrue(mc.ping_gateway.called, "not called")
+ self.assertTrue(mc.is_internet_up.called, "not called")
+
+ def test_get_default_interface_no_interface(self):
+ checker = checks.LeapNetworkChecker()
+ with patch('leap.base.checks.open', create=True) as mock_open:
+ with self.assertRaises(exceptions.NoDefaultInterfaceFoundError):
+ mock_open.return_value = StringIO(
+ "Iface\tDestination Gateway\t"
+ "Flags\tRefCntd\tUse\tMetric\t"
+ "Mask\tMTU\tWindow\tIRTT")
+ checker.get_default_interface_gateway()
+
+ def test_check_tunnel_default_interface(self):
+ checker = checks.LeapNetworkChecker()
+ with patch('leap.base.checks.open', create=True) as mock_open:
+ with self.assertRaises(exceptions.TunnelNotDefaultRouteError):
+ mock_open.return_value = StringIO(
+ "Iface\tDestination Gateway\t"
+ "Flags\tRefCntd\tUse\tMetric\t"
+ "Mask\tMTU\tWindow\tIRTT")
+ checker.check_tunnel_default_interface()
+
+ with patch('leap.base.checks.open', create=True) as mock_open:
+ with self.assertRaises(exceptions.TunnelNotDefaultRouteError):
+ mock_open.return_value = StringIO(
+ "Iface\tDestination Gateway\t"
+ "Flags\tRefCntd\tUse\tMetric\t"
+ "Mask\tMTU\tWindow\tIRTT\n"
+ "wlan0\t00000000\t0102A8C0\t"
+ "0003\t0\t0\t0\t00000000\t0\t0\t0")
+ checker.check_tunnel_default_interface()
+
+ with patch('leap.base.checks.open', create=True) as mock_open:
+ mock_open.return_value = StringIO(
+ "Iface\tDestination Gateway\t"
+ "Flags\tRefCntd\tUse\tMetric\t"
+ "Mask\tMTU\tWindow\tIRTT\n"
+ "tun0\t00000000\t01002A0A\t0003\t0\t0\t0\t00000080\t0\t0\t0")
+ checker.check_tunnel_default_interface()
+
+ def test_ping_gateway_fail(self):
+ checker = checks.LeapNetworkChecker()
+ with patch.object(ping, "quiet_ping") as mocked_ping:
+ with self.assertRaises(exceptions.NoConnectionToGateway):
+ mocked_ping.return_value = [11, "", ""]
+ checker.ping_gateway("4.2.2.2")
+
+ def test_check_internet_connection_failures(self):
+ checker = checks.LeapNetworkChecker()
+ with patch.object(requests, "get") as mocked_get:
+ mocked_get.side_effect = requests.HTTPError
+ with self.assertRaises(exceptions.NoInternetConnection):
+ checker.check_internet_connection()
+
+ with patch.object(requests, "get") as mocked_get:
+ mocked_get.side_effect = requests.RequestException
+ with self.assertRaises(exceptions.NoInternetConnection):
+ checker.check_internet_connection()
+
+ #TODO: Mock possible errors that can be raised by is_internet_up
+ with patch.object(requests, "get") as mocked_get:
+ mocked_get.side_effect = requests.ConnectionError
+ with self.assertRaises(exceptions.NoInternetConnection):
+ checker.check_internet_connection()
+
+ @unittest.skipUnless(_uid == 0, "root only")
+ def test_ping_gateway(self):
+ checker = checks.LeapNetworkChecker()
+ checker.ping_gateway("4.2.2.2")
diff --git a/src/leap/base/tests/test_config.py b/src/leap/base/tests/test_config.py
new file mode 100644
index 00000000..d03149b2
--- /dev/null
+++ b/src/leap/base/tests/test_config.py
@@ -0,0 +1,247 @@
+import json
+import os
+import platform
+import socket
+#import tempfile
+
+import mock
+import requests
+
+from leap.base import config
+from leap.base import constants
+from leap.base import exceptions
+from leap.eip import constants as eipconstants
+from leap.util.fileutil import mkdir_p
+from leap.testing.basetest import BaseLeapTest
+
+
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+_system = platform.system()
+
+
+class JSONLeapConfigTest(BaseLeapTest):
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def test_metaclass(self):
+ with self.assertRaises(exceptions.ImproperlyConfigured) as exc:
+ class DummyTestConfig(config.JSONLeapConfig):
+ __metaclass__ = config.MetaConfigWithSpec
+ exc.startswith("missing spec dict")
+
+ class DummyTestConfig(config.JSONLeapConfig):
+ __metaclass__ = config.MetaConfigWithSpec
+ spec = {'properties': {}}
+ with self.assertRaises(exceptions.ImproperlyConfigured) as exc:
+ DummyTestConfig()
+ exc.startswith("missing slug")
+
+ class DummyTestConfig(config.JSONLeapConfig):
+ __metaclass__ = config.MetaConfigWithSpec
+ spec = {'properties': {}}
+ slug = "foo"
+ DummyTestConfig()
+
+######################################3
+#
+# provider fetch tests block
+#
+
+
+class ProviderTest(BaseLeapTest):
+ # override per test fixtures
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+
+# XXX depreacated. similar test in eip.checks
+
+#class BareHomeTestCase(ProviderTest):
+#
+ #__name__ = "provider_config_tests_bare_home"
+#
+ #def test_should_raise_if_missing_eip_json(self):
+ #with self.assertRaises(exceptions.MissingConfigFileError):
+ #config.get_config_json(os.path.join(self.home, 'eip.json'))
+
+
+class ProviderDefinitionTestCase(ProviderTest):
+ # XXX MOVE TO eip.test_checks
+ # -- kali 2012-08-24 00:38
+
+ __name__ = "provider_config_tests"
+
+ def setUp(self):
+ # dump a sample eip file
+ # XXX Move to Use EIP Spec Instead!!!
+ # XXX tests to be moved to eip.checks and eip.providers
+ # XXX can use eipconfig.dump_default_eipconfig
+
+ path = os.path.join(self.home, '.config', 'leap')
+ mkdir_p(path)
+ with open(os.path.join(path, 'eip.json'), 'w') as fp:
+ json.dump(eipconstants.EIP_SAMPLE_JSON, fp)
+
+
+# these tests below should move to
+# eip.checks
+# config.Configuration has been deprecated
+
+# TODO:
+# - We're instantiating a ProviderTest because we're doing the home wipeoff
+# on setUpClass instead of the setUp (for speedup of the general cases).
+
+# We really should be testing all of them in the same testCase, and
+# doing an extra wipe of the tempdir... but be careful!!!! do not mess with
+# os.environ home more than needed... that could potentially bite!
+
+# XXX actually, another thing to fix here is separating tests:
+# - test that requests has been called.
+# - check deeper for error types/msgs
+
+# we SHOULD inject requests dep in the constructor
+# (so we can pass mock easily).
+
+
+#class ProviderFetchConError(ProviderTest):
+ #def test_connection_error(self):
+ #with mock.patch.object(requests, "get") as mock_method:
+ #mock_method.side_effect = requests.ConnectionError
+ #cf = config.Configuration()
+ #self.assertIsInstance(cf.error, str)
+#
+#
+#class ProviderFetchHttpError(ProviderTest):
+ #def test_file_not_found(self):
+ #with mock.patch.object(requests, "get") as mock_method:
+ #mock_method.side_effect = requests.HTTPError
+ #cf = config.Configuration()
+ #self.assertIsInstance(cf.error, str)
+#
+#
+#class ProviderFetchInvalidUrl(ProviderTest):
+ #def test_invalid_url(self):
+ #cf = config.Configuration("ht")
+ #self.assertTrue(cf.error)
+
+
+# end provider fetch tests
+###########################################
+
+
+class ConfigHelperFunctions(BaseLeapTest):
+
+ __name__ = "config_helper_tests"
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ # tests
+
+ @unittest.skipUnless(_system == "Linux", "linux only")
+ def test_lin_get_config_file(self):
+ """
+ config file path where expected? (linux)
+ """
+ self.assertEqual(
+ config.get_config_file(
+ 'test', folder="foo/bar"),
+ os.path.expanduser(
+ '~/.config/leap/foo/bar/test')
+ )
+
+ @unittest.skipUnless(_system == "Darwin", "mac only")
+ def test_mac_get_config_file(self):
+ """
+ config file path where expected? (mac)
+ """
+ self._missing_test_for_plat(do_raise=True)
+
+ @unittest.skipUnless(_system == "Windows", "win only")
+ def test_win_get_config_file(self):
+ """
+ config file path where expected?
+ """
+ self._missing_test_for_plat(do_raise=True)
+
+ #
+ # XXX hey, I'm raising exceptions here
+ # on purpose. just wanted to make sure
+ # that the skip stuff is doing it right.
+ # If you're working on win/macos tests,
+ # feel free to remove tests that you see
+ # are too redundant.
+
+ @unittest.skipUnless(_system == "Linux", "linux only")
+ def test_lin_get_config_dir(self):
+ """
+ nice config dir? (linux)
+ """
+ self.assertEqual(
+ config.get_config_dir(),
+ os.path.expanduser('~/.config/leap'))
+
+ @unittest.skipUnless(_system == "Darwin", "mac only")
+ def test_mac_get_config_dir(self):
+ """
+ nice config dir? (mac)
+ """
+ self._missing_test_for_plat(do_raise=True)
+
+ @unittest.skipUnless(_system == "Windows", "win only")
+ def test_win_get_config_dir(self):
+ """
+ nice config dir? (win)
+ """
+ self._missing_test_for_plat(do_raise=True)
+
+ # provider paths
+
+ @unittest.skipUnless(_system == "Linux", "linux only")
+ def test_get_default_provider_path(self):
+ """
+ is default provider path ok?
+ """
+ self.assertEqual(
+ config.get_default_provider_path(),
+ os.path.expanduser(
+ '~/.config/leap/providers/%s/' %
+ constants.DEFAULT_PROVIDER)
+ )
+
+ # validate ip
+
+ def test_validate_ip(self):
+ """
+ check our ip validation
+ """
+ config.validate_ip('3.3.3.3')
+ with self.assertRaises(socket.error):
+ config.validate_ip('255.255.255.256')
+ with self.assertRaises(socket.error):
+ config.validate_ip('foobar')
+
+ @unittest.skip
+ def test_validate_domain(self):
+ """
+ code to be written yet
+ """
+ raise NotImplementedError
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/base/tests/test_providers.py b/src/leap/base/tests/test_providers.py
new file mode 100644
index 00000000..15c4ed58
--- /dev/null
+++ b/src/leap/base/tests/test_providers.py
@@ -0,0 +1,143 @@
+import copy
+import json
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+import os
+
+import jsonschema
+
+from leap import __branding as BRANDING
+from leap.testing.basetest import BaseLeapTest
+from leap.base import providers
+
+
+EXPECTED_DEFAULT_CONFIG = {
+ u"api_version": u"0.1.0",
+ u"description": {u'en': u"Test provider"},
+ u"display_name": {u'en': u"Test Provider"},
+ u"domain": u"testprovider.example.org",
+ u"enrollment_policy": u"open",
+ u"serial": 1,
+ u"services": [
+ u"eip"
+ ],
+ u"languages": [u"en"],
+ u"version": u"0.1.0"
+}
+
+
+class TestLeapProviderDefinition(BaseLeapTest):
+ def setUp(self):
+ self.domain = "testprovider.example.org"
+ self.definition = providers.LeapProviderDefinition(
+ domain=self.domain)
+ self.definition.save()
+ self.definition.load()
+ self.config = self.definition.config
+
+ def tearDown(self):
+ if hasattr(self, 'testfile') and os.path.isfile(self.testfile):
+ os.remove(self.testfile)
+
+ # tests
+
+ # XXX most of these tests can be made more abstract
+ # and moved to test_baseconfig *triangulate!*
+
+ def test_provider_slug_property(self):
+ slug = self.definition.slug
+ self.assertEquals(
+ slug,
+ os.path.join(
+ self.home,
+ '.config', 'leap', 'providers',
+ '%s' % self.domain,
+ 'provider.json'))
+ with self.assertRaises(AttributeError):
+ self.definition.slug = 23
+
+ def test_provider_dump(self):
+ # check a good provider definition is dumped to disk
+ self.testfile = self.get_tempfile('test.json')
+ self.definition.save(to=self.testfile)
+ deserialized = json.load(open(self.testfile, 'rb'))
+ self.maxDiff = None
+ self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG)
+
+ def test_provider_dump_to_slug(self):
+ # same as above, but we test the ability to save to a
+ # file generated from the slug.
+ # XXX THIS TEST SHOULD MOVE TO test_baseconfig
+ self.definition.save()
+ filename = self.definition.filename
+ self.assertTrue(os.path.isfile(filename))
+ deserialized = json.load(open(filename, 'rb'))
+ self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG)
+
+ def test_provider_load(self):
+ # check loading provider from disk file
+ self.testfile = self.get_tempfile('test_load.json')
+ with open(self.testfile, 'w') as wf:
+ wf.write(json.dumps(EXPECTED_DEFAULT_CONFIG))
+ self.definition.load(fromfile=self.testfile)
+ self.assertDictEqual(self.config,
+ EXPECTED_DEFAULT_CONFIG)
+
+ def test_provider_validation(self):
+ self.definition.validate(self.config)
+ _config = copy.deepcopy(self.config)
+ _config['serial'] = 'aaa'
+ with self.assertRaises(jsonschema.ValidationError):
+ self.definition.validate(_config)
+
+ @unittest.skip
+ def test_load_malformed_json_definition(self):
+ raise NotImplementedError
+
+ @unittest.skip
+ def test_type_validation(self):
+ # check various type validation
+ # type cast
+ raise NotImplementedError
+
+
+class TestLeapProviderSet(BaseLeapTest):
+
+ def setUp(self):
+ self.providers = providers.LeapProviderSet()
+
+ def tearDown(self):
+ pass
+ ###
+
+ def test_get_zero_count(self):
+ self.assertEqual(self.providers.count, 0)
+
+ @unittest.skip
+ def test_count_defined_providers(self):
+ # check the method used for making
+ # the list of providers
+ raise NotImplementedError
+
+ @unittest.skip
+ def test_get_default_provider(self):
+ raise NotImplementedError
+
+ @unittest.skip
+ def test_should_be_at_least_one_provider_after_init(self):
+ # when we init an empty environment,
+ # there should be at least one provider,
+ # that will be a dump of the default provider definition
+ # somehow a high level test
+ raise NotImplementedError
+
+ @unittest.skip
+ def test_get_eip_remote_from_default_provider(self):
+ # from: default provider
+ # expect: remote eip domain
+ raise NotImplementedError
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/base/tests/test_validation.py b/src/leap/base/tests/test_validation.py
new file mode 100644
index 00000000..87e99648
--- /dev/null
+++ b/src/leap/base/tests/test_validation.py
@@ -0,0 +1,92 @@
+import copy
+import datetime
+#import json
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+import os
+
+import jsonschema
+
+from leap.base.config import JSONLeapConfig
+from leap.base import pluggableconfig
+from leap.testing.basetest import BaseLeapTest
+
+SAMPLE_CONFIG_DICT = {
+ 'prop_one': 1,
+ 'prop_uri': "http://example.org",
+ 'prop_date': '2012-12-12',
+}
+
+EXPECTED_CONFIG = {
+ 'prop_one': 1,
+ 'prop_uri': "http://example.org",
+ 'prop_date': datetime.datetime(2012, 12, 12)
+}
+
+sample_spec = {
+ 'description': 'sample schema definition',
+ 'type': 'object',
+ 'properties': {
+ 'prop_one': {
+ 'type': int,
+ 'default': 1,
+ 'required': True
+ },
+ 'prop_uri': {
+ 'type': str,
+ 'default': 'http://example.org',
+ 'required': True,
+ 'format': 'uri'
+ },
+ 'prop_date': {
+ 'type': str,
+ 'default': '2012-12-12',
+ 'format': 'date'
+ }
+ }
+}
+
+
+class SampleConfig(JSONLeapConfig):
+ spec = sample_spec
+
+ @property
+ def slug(self):
+ return os.path.expanduser('~/sampleconfig.json')
+
+
+class TestJSONLeapConfigValidation(BaseLeapTest):
+ def setUp(self):
+ self.sampleconfig = SampleConfig()
+ self.sampleconfig.save()
+ self.sampleconfig.load()
+ self.config = self.sampleconfig.config
+
+ def tearDown(self):
+ if hasattr(self, 'testfile') and os.path.isfile(self.testfile):
+ os.remove(self.testfile)
+
+ # tests
+
+ def test_good_validation(self):
+ self.sampleconfig.validate(SAMPLE_CONFIG_DICT)
+
+ def test_broken_int(self):
+ _config = copy.deepcopy(SAMPLE_CONFIG_DICT)
+ _config['prop_one'] = '1'
+ with self.assertRaises(jsonschema.ValidationError):
+ self.sampleconfig.validate(_config)
+
+ def test_format_property(self):
+ # JsonSchema Validator does not check the format property.
+ # We should have to extend the Configuration class
+ blah = copy.deepcopy(SAMPLE_CONFIG_DICT)
+ blah['prop_uri'] = 'xxx'
+ with self.assertRaises(pluggableconfig.TypeCastException):
+ self.sampleconfig.validate(blah)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/baseapp/config.py b/src/leap/baseapp/config.py
deleted file mode 100644
index efdb4726..00000000
--- a/src/leap/baseapp/config.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import ConfigParser
-import os
-
-
-def get_config(config_file=None):
- """
- temporary method for getting configs,
- mainly for early stage development process.
- in the future we will get preferences
- from the storage api
- """
- config = ConfigParser.ConfigParser()
- #config.readfp(open('defaults.cfg'))
- #XXX does this work on win / mac also???
- conf_path_list = ['eip.cfg', # XXX build a
- # proper path with platform-specific places
- # XXX make .config/foo
- os.path.expanduser('~/.eip.cfg')]
- if config_file:
- config.readfp(config_file)
- else:
- config.read(conf_path_list)
- return config
-
-
-# XXX wrapper around config? to get default values
-
-def get_with_defaults(config, section, option):
- if config.has_option(section, option):
- return config.get(section, option)
- else:
- # XXX lookup in defaults dict???
- pass
-
-
-def get_vpn_stdout_mockup():
- command = "python"
- args = ["-u", "-c", "from eip_client import fakeclient;\
-fakeclient.write_output()"]
- return command, args
diff --git a/src/leap/baseapp/constants.py b/src/leap/baseapp/constants.py
new file mode 100644
index 00000000..e312be21
--- /dev/null
+++ b/src/leap/baseapp/constants.py
@@ -0,0 +1,6 @@
+# This timer used for polling vpn manager state.
+
+# XXX what is an optimum polling interval?
+# too little will be overkill, too much will
+# miss transition states.
+TIMER_MILLISECONDS = 250.0
diff --git a/src/leap/baseapp/dialogs.py b/src/leap/baseapp/dialogs.py
new file mode 100644
index 00000000..3cb539cf
--- /dev/null
+++ b/src/leap/baseapp/dialogs.py
@@ -0,0 +1,58 @@
+# vim: tabstop=8 expandtab shiftwidth=4 softtabstop=4
+import logging
+
+from PyQt4.QtGui import (QDialog, QFrame, QPushButton, QLabel, QMessageBox)
+
+logger = logging.getLogger(name=__name__)
+
+
+class ErrorDialog(QDialog):
+ def __init__(self, parent=None, errtype=None, msg=None, label=None):
+ super(ErrorDialog, self).__init__(parent)
+ frameStyle = QFrame.Sunken | QFrame.Panel
+ self.warningLabel = QLabel()
+ self.warningLabel.setFrameStyle(frameStyle)
+ self.warningButton = QPushButton("QMessageBox.&warning()")
+
+ if msg is not None:
+ self.msg = msg
+ if label is not None:
+ self.label = label
+ if errtype == "critical":
+ self.criticalMessage(self.msg, self.label)
+
+ def warningMessage(self, msg, label):
+ msgBox = QMessageBox(QMessageBox.Warning,
+ "QMessageBox.warning()", msg,
+ QMessageBox.NoButton, self)
+ msgBox.addButton("&Ok", QMessageBox.AcceptRole)
+ if msgBox.exec_() == QMessageBox.AcceptRole:
+ pass
+ # do whatever we want to do after
+ # closing the dialog. we can pass that
+ # in the constructor
+
+ def criticalMessage(self, msg, label):
+ msgBox = QMessageBox(QMessageBox.Critical,
+ "QMessageBox.critical()", msg,
+ QMessageBox.NoButton, self)
+ msgBox.addButton("&Ok", QMessageBox.AcceptRole)
+ msgBox.exec_()
+
+ # It's critical, so we exit.
+ # We should better emit a signal and connect it
+ # with the proper shutdownAndQuit method, but
+ # this suffices for now.
+ logger.info('Quitting')
+ import sys
+ sys.exit()
+
+ def confirmMessage(self, msg, label, action):
+ msgBox = QMessageBox(QMessageBox.Critical,
+ "QMessageBox.critical()", msg,
+ QMessageBox.NoButton, self)
+ msgBox.addButton("&Ok", QMessageBox.AcceptRole)
+ msgBox.addButton("&Cancel", QMessageBox.RejectRole)
+
+ if msgBox.exec_() == QMessageBox.AcceptRole:
+ action()
diff --git a/src/leap/baseapp/eip.py b/src/leap/baseapp/eip.py
new file mode 100644
index 00000000..54acbc0e
--- /dev/null
+++ b/src/leap/baseapp/eip.py
@@ -0,0 +1,217 @@
+from __future__ import print_function
+import logging
+import time
+#import sys
+
+from PyQt4 import QtCore
+
+from leap.baseapp.dialogs import ErrorDialog
+from leap.baseapp import constants
+from leap.eip import exceptions as eip_exceptions
+from leap.eip.eipconnection import EIPConnection
+
+logger = logging.getLogger(name=__name__)
+
+
+class EIPConductorAppMixin(object):
+ """
+ initializes an instance of EIPConnection,
+ gathers errors, and passes status-change signals
+ from Qt land along to the conductor.
+ Connects the eip connect/disconnect logic
+ to the switches in the app (buttons/menu items).
+ """
+
+ def __init__(self, *args, **kwargs):
+ opts = kwargs.pop('opts')
+ config_file = getattr(opts, 'config_file', None)
+ provider = kwargs.pop('provider')
+
+ self.eip_service_started = False
+
+ # conductor (eip connection) is in charge of all
+ # vpn-related configuration / monitoring.
+ # we pass a tuple of signals that will be
+ # triggered when status changes.
+
+ self.conductor = EIPConnection(
+ watcher_cb=self.newLogLine.emit,
+ config_file=config_file,
+ checker_signals=(self.eipStatusChange.emit, ),
+ status_signals=(self.openvpnStatusChange.emit, ),
+ debug=self.debugmode,
+ ovpn_verbosity=opts.openvpn_verb,
+ provider=provider)
+
+ self.skip_download = opts.no_provider_checks
+ self.skip_verify = opts.no_ca_verify
+
+ def run_eip_checks(self):
+ """
+ runs eip checks and
+ the error checking loop
+ """
+ logger.debug('running EIP CHECKS')
+ self.conductor.run_checks(
+ skip_download=self.skip_download,
+ skip_verify=self.skip_verify)
+ self.error_check()
+
+ self.start_eipconnection.emit()
+
+ def error_check(self):
+ """
+ consumes the conductor error queue.
+ pops errors, and acts accordingly (launching user dialogs).
+ """
+ logger.debug('error check')
+
+ errq = self.conductor.error_queue
+ while errq.qsize() != 0:
+ logger.debug('%s errors left in conductor queue', errq.qsize())
+ # we get exception and original traceback from queue
+ error, tb = errq.get()
+
+ # redundant log, debugging the loop.
+ logger.error('%s: %s', error.__class__.__name__, error.message)
+
+ if issubclass(error.__class__, eip_exceptions.EIPClientError):
+ self.triggerEIPError.emit(error)
+
+ else:
+ # deprecated form of raising exception.
+ raise error, None, tb
+
+ if error.failfirst is True:
+ break
+
+ @QtCore.pyqtSlot(object)
+ def onEIPError(self, error):
+ """
+ check severity and launches
+ dialogs informing user about the errors.
+ in the future we plan to derive errors to
+ our log viewer.
+ """
+
+ if getattr(error, 'usermessage', None):
+ message = error.usermessage
+ else:
+ message = error.message
+
+ # XXX
+ # check headless = False before
+ # launching dialog.
+ # (so Qt tests can assert stuff)
+
+ if error.critical:
+ logger.critical(error.message)
+ #critical error (non recoverable),
+ #we give user some info and quit.
+ #(critical error dialog will exit app)
+ ErrorDialog(errtype="critical",
+ msg=message,
+ label="critical error")
+ elif error.warning:
+ logger.warning(error.message)
+
+ else:
+ dialog = ErrorDialog()
+ dialog.warningMessage(message, 'error')
+
+ @QtCore.pyqtSlot()
+ def statusUpdate(self):
+ """
+ polls status and updates ui with real time
+ info about transferred bytes / connection state.
+ right now is triggered by a timer tick
+ (timer controlled by StatusAwareTrayIcon class)
+ """
+ # TODO I guess it's too expensive to poll
+ # continously. move to signal events instead.
+ # (i.e., subscribe to connection status changes
+ # from openvpn manager)
+
+ if not self.eip_service_started:
+ # there is a race condition
+ # going on here. Depending on how long we take
+ # to init the qt app, the management socket
+ # is not ready yet.
+ return
+
+ #if self.conductor.with_errors:
+ #XXX how to wait on pkexec???
+ #something better that this workaround, plz!!
+ #I removed the pkexec pass authentication at all.
+ #time.sleep(5)
+ #logger.debug('timeout')
+ #logger.error('errors. disconnect')
+ #self.start_or_stopVPN() # is stop
+
+ state = self.conductor.poll_connection_state()
+ if not state:
+ return
+
+ ts, con_status, ok, ip, remote = state
+ self.set_statusbarMessage(con_status)
+ self.setIconToolTip()
+
+ ts = time.strftime("%a %b %d %X", ts)
+ if self.debugmode:
+ self.updateTS.setText(ts)
+ self.status_label.setText(con_status)
+ self.ip_label.setText(ip)
+ self.remote_label.setText(remote)
+
+ # status i/o
+
+ status = self.conductor.get_status_io()
+ if status and self.debugmode:
+ #XXX move this to systray menu indicators
+ ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read) = status
+ ts = time.strftime("%a %b %d %X", ts)
+ self.updateTS.setText(ts)
+ self.tun_read_bytes.setText(tun_read)
+ self.tun_write_bytes.setText(tun_write)
+
+ @QtCore.pyqtSlot()
+ def start_or_stopVPN(self):
+ """
+ stub for running child process with vpn
+ """
+ if self.conductor.has_errors():
+ logger.debug('not starting vpn; conductor has errors')
+
+ if self.eip_service_started is False:
+ try:
+ self.conductor.connect()
+
+ except eip_exceptions.EIPNoCommandError as exc:
+ self.triggerEIPError.emit(exc)
+
+ except Exception as err:
+ # raise generic exception (Bad Thing Happened?)
+ logger.exception(err)
+ else:
+ # no errors, so go on.
+ if self.debugmode:
+ self.startStopButton.setText('&Disconnect')
+ self.eip_service_started = True
+ self.toggleEIPAct()
+
+ # XXX decouple! (timer is init by icons class).
+ # we could bring Timer Init to this Mixin
+ # or to its own Mixin.
+ self.timer.start(constants.TIMER_MILLISECONDS)
+ self.network_checker.start()
+ return
+
+ if self.eip_service_started is True:
+ self.network_checker.stop()
+ self.conductor.disconnect()
+ if self.debugmode:
+ self.startStopButton.setText('&Connect')
+ self.eip_service_started = False
+ self.toggleEIPAct()
+ self.timer.stop()
+ return
diff --git a/src/leap/baseapp/leap_app.py b/src/leap/baseapp/leap_app.py
new file mode 100644
index 00000000..4b63dd2f
--- /dev/null
+++ b/src/leap/baseapp/leap_app.py
@@ -0,0 +1,153 @@
+import logging
+
+import sip
+sip.setapi('QVariant', 2)
+
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+from leap.gui import mainwindow_rc
+
+logger = logging.getLogger(name=__name__)
+
+
+APP_LOGO = ':/images/leap-color-small.png'
+
+
+class MainWindowMixin(object):
+ """
+ create the main window
+ for leap app
+ """
+
+ def __init__(self, *args, **kwargs):
+ # XXX set initial visibility
+ # debug = no visible
+
+ widget = QtGui.QWidget()
+ self.setCentralWidget(widget)
+
+ mainLayout = QtGui.QVBoxLayout()
+ # add widgets to layout
+ #self.createWindowHeader()
+ #mainLayout.addWidget(self.headerBox)
+
+ # created in systray
+ mainLayout.addWidget(self.statusIconBox)
+ if self.debugmode:
+ mainLayout.addWidget(self.statusBox)
+ mainLayout.addWidget(self.loggerBox)
+ widget.setLayout(mainLayout)
+
+ self.createMainActions()
+ self.createMainMenus()
+
+ self.setWindowTitle("LEAP Client")
+ self.set_app_icon()
+ self.set_statusbarMessage('ready')
+
+ def createMainActions(self):
+ #self.openAct = QtGui.QAction("&Open...", self, shortcut="Ctrl+O",
+ #triggered=self.open)
+
+ self.firstRunWizardAct = QtGui.QAction(
+ "&First run wizard...", self,
+ triggered=self.stop_connection_and_launch_first_run_wizard)
+ self.aboutAct = QtGui.QAction("&About", self, triggered=self.about)
+
+ #self.aboutQtAct = QtGui.QAction("About &Qt", self,
+ #triggered=QtGui.qApp.aboutQt)
+
+ def createMainMenus(self):
+ self.connMenu = QtGui.QMenu("&Connections", self)
+ #self.viewMenu.addSeparator()
+ self.connMenu.addAction(self.quitAction)
+
+ self.settingsMenu = QtGui.QMenu("&Settings", self)
+ self.settingsMenu.addAction(self.firstRunWizardAct)
+
+ self.helpMenu = QtGui.QMenu("&Help", self)
+ self.helpMenu.addAction(self.aboutAct)
+ #self.helpMenu.addAction(self.aboutQtAct)
+
+ self.menuBar().addMenu(self.connMenu)
+ self.menuBar().addMenu(self.settingsMenu)
+ self.menuBar().addMenu(self.helpMenu)
+
+ def stop_connection_and_launch_first_run_wizard(self):
+ settings = QtCore.QSettings()
+ settings.setValue('FirstRunWizardDone', False)
+ logger.debug('should run first run wizard again...')
+
+ status = self.conductor.get_icon_name()
+ if status != "disconnected":
+ self.start_or_stopVPN()
+
+ self.launch_first_run_wizard()
+ #from leap.gui.firstrunwizard import FirstRunWizard
+ #wizard = FirstRunWizard(
+ #parent=self,
+ #success_cb=self.initReady.emit)
+ #wizard.show()
+
+ def set_app_icon(self):
+ icon = QtGui.QIcon(APP_LOGO)
+ self.setWindowIcon(icon)
+
+ #def createWindowHeader(self):
+ #"""
+ #description lines for main window
+ #"""
+ #self.headerBox = QtGui.QGroupBox()
+ #self.headerLabel = QtGui.QLabel(
+ #"<font size=40>LEAP Encryption Access Project</font>")
+ #self.headerLabelSub = QtGui.QLabel(
+ #"<br><i>your internet encryption toolkit</i>")
+#
+ #pixmap = QtGui.QPixmap(APP_LOGO)
+ #leap_lbl = QtGui.QLabel()
+ #leap_lbl.setPixmap(pixmap)
+#
+ #headerLayout = QtGui.QHBoxLayout()
+ #headerLayout.addWidget(leap_lbl)
+ #headerLayout.addWidget(self.headerLabel)
+ #headerLayout.addWidget(self.headerLabelSub)
+ #headerLayout.addStretch()
+ #self.headerBox.setLayout(headerLayout)
+
+ def set_statusbarMessage(self, msg):
+ self.statusBar().showMessage(msg)
+
+ def closeEvent(self, event):
+ """
+ redefines close event (persistent window behaviour)
+ """
+ if self.trayIcon.isVisible() and not self.debugmode:
+ QtGui.QMessageBox.information(
+ self, "Systray",
+ "The program will keep running "
+ "in the system tray. To "
+ "terminate the program, choose "
+ "<b>Quit</b> in the "
+ "context menu of the system tray entry.")
+ self.hide()
+ event.ignore()
+ return
+ self.cleanupAndQuit()
+
+ def cleanupAndQuit(self):
+ """
+ cleans state before shutting down app.
+ """
+ # save geometry for restoring
+ settings = QtCore.QSettings()
+ geom_key = "DebugGeometry" if self.debugmode else "Geometry"
+ settings.setValue(geom_key, self.saveGeometry())
+
+ # TODO:make sure to shutdown all child process / threads
+ # in conductor
+ # XXX send signal instead?
+ logger.info('Shutting down')
+ self.conductor.cleanup()
+ logger.info('Exiting. Bye.')
+ QtGui.qApp.quit()
diff --git a/src/leap/baseapp/log.py b/src/leap/baseapp/log.py
new file mode 100644
index 00000000..8a7f81c3
--- /dev/null
+++ b/src/leap/baseapp/log.py
@@ -0,0 +1,65 @@
+import logging
+
+from PyQt4 import QtGui
+from PyQt4 import QtCore
+
+vpnlogger = logging.getLogger('leap.openvpn')
+
+
+class LogPaneMixin(object):
+ """
+ a simple log pane
+ that writes new lines as they come
+ """
+
+ def createLogBrowser(self):
+ """
+ creates Browser widget for displaying logs
+ (in debug mode only).
+ """
+ self.loggerBox = QtGui.QGroupBox()
+ logging_layout = QtGui.QVBoxLayout()
+ self.logbrowser = QtGui.QTextBrowser()
+
+ startStopButton = QtGui.QPushButton("&Connect")
+ self.startStopButton = startStopButton
+
+ logging_layout.addWidget(self.logbrowser)
+ logging_layout.addWidget(self.startStopButton)
+ self.loggerBox.setLayout(logging_layout)
+
+ # status box
+
+ self.statusBox = QtGui.QGroupBox()
+ grid = QtGui.QGridLayout()
+
+ self.updateTS = QtGui.QLabel('')
+ self.status_label = QtGui.QLabel('Disconnected')
+ self.ip_label = QtGui.QLabel('')
+ self.remote_label = QtGui.QLabel('')
+
+ tun_read_label = QtGui.QLabel("tun read")
+ self.tun_read_bytes = QtGui.QLabel("0")
+ tun_write_label = QtGui.QLabel("tun write")
+ self.tun_write_bytes = QtGui.QLabel("0")
+
+ grid.addWidget(self.updateTS, 0, 0)
+ grid.addWidget(self.status_label, 0, 1)
+ grid.addWidget(self.ip_label, 1, 0)
+ grid.addWidget(self.remote_label, 1, 1)
+ grid.addWidget(tun_read_label, 2, 0)
+ grid.addWidget(self.tun_read_bytes, 2, 1)
+ grid.addWidget(tun_write_label, 3, 0)
+ grid.addWidget(self.tun_write_bytes, 3, 1)
+
+ self.statusBox.setLayout(grid)
+
+ @QtCore.pyqtSlot(str)
+ def onLoggerNewLine(self, line):
+ """
+ simple slot: writes new line to logger Pane.
+ """
+ msg = line[:-1]
+ if self.debugmode:
+ self.logbrowser.append(msg)
+ vpnlogger.info(msg)
diff --git a/src/leap/baseapp/mainwindow.py b/src/leap/baseapp/mainwindow.py
index 68b6de8f..8d61bf5c 100644
--- a/src/leap/baseapp/mainwindow.py
+++ b/src/leap/baseapp/mainwindow.py
@@ -1,398 +1,170 @@
# vim: set fileencoding=utf-8 :
#!/usr/bin/env python
import logging
-import time
-logger = logging.getLogger(name=__name__)
-from PyQt4.QtGui import (QMainWindow, QWidget, QVBoxLayout, QMessageBox,
- QSystemTrayIcon, QGroupBox, QLabel, QPixmap,
- QHBoxLayout, QIcon,
- QPushButton, QGridLayout, QAction, QMenu,
- QTextBrowser, qApp)
-from PyQt4.QtCore import (pyqtSlot, pyqtSignal, QTimer)
+import sip
+sip.setapi('QString', 2)
+sip.setapi('QVariant', 2)
+
+from PyQt4 import QtCore
+from PyQt4 import QtGui
-from leap.gui import mainwindow_rc
-from leap.eip.conductor import EIPConductor
+from leap.baseapp.eip import EIPConductorAppMixin
+from leap.baseapp.log import LogPaneMixin
+from leap.baseapp.systray import StatusAwareTrayIconMixin
+from leap.baseapp.network import NetworkCheckerAppMixin
+from leap.baseapp.leap_app import MainWindowMixin
+from leap.eip.checks import ProviderCertChecker
+from leap.gui.threads import FunThread
+logger = logging.getLogger(name=__name__)
-class LeapWindow(QMainWindow):
- #XXX tbd: refactor into model / view / controller
- #and put in its own modules...
- newLogLine = pyqtSignal([str])
- statusChange = pyqtSignal([object])
+class LeapWindow(QtGui.QMainWindow,
+ MainWindowMixin, EIPConductorAppMixin,
+ StatusAwareTrayIconMixin,
+ NetworkCheckerAppMixin,
+ LogPaneMixin):
+ """
+ main window for the leap app.
+ Initializes all of its base classes
+ We keep here some signal initialization
+ that gets tricky otherwise.
+ """
+
+ # signals
+
+ newLogLine = QtCore.pyqtSignal([str])
+ mainappReady = QtCore.pyqtSignal([])
+ initReady = QtCore.pyqtSignal([])
+ networkError = QtCore.pyqtSignal([object])
+ triggerEIPError = QtCore.pyqtSignal([object])
+ start_eipconnection = QtCore.pyqtSignal([])
+ shutdownSignal = QtCore.pyqtSignal([])
+
+ # this is status change got from openvpn management
+ openvpnStatusChange = QtCore.pyqtSignal([object])
+ # this is global eip status
+ eipStatusChange = QtCore.pyqtSignal([str])
def __init__(self, opts):
- super(LeapWindow, self).__init__()
+ logger.debug('init leap window')
self.debugmode = getattr(opts, 'debug', False)
-
- self.vpn_service_started = False
-
- self.createWindowHeader()
- self.createIconGroupBox()
-
- self.createActions()
- self.createTrayIcon()
+ super(LeapWindow, self).__init__()
if self.debugmode:
self.createLogBrowser()
- # create timer
- self.timer = QTimer()
-
- # bind signals
-
- self.trayIcon.activated.connect(self.iconActivated)
- self.newLogLine.connect(self.onLoggerNewLine)
- self.statusChange.connect(self.onStatusChange)
- self.timer.timeout.connect(self.onTimerTick)
-
- widget = QWidget()
- self.setCentralWidget(widget)
+ settings = QtCore.QSettings()
+ self.provider_domain = settings.value("provider_domain", None)
+ self.eip_username = settings.value("eip_username", None)
- # add widgets to layout
- mainLayout = QVBoxLayout()
- mainLayout.addWidget(self.headerBox)
- mainLayout.addWidget(self.statusIconBox)
- if self.debugmode:
- mainLayout.addWidget(self.statusBox)
- mainLayout.addWidget(self.loggerBox)
- widget.setLayout(mainLayout)
+ logger.debug('provider: %s', self.provider_domain)
+ logger.debug('eip_username: %s', self.eip_username)
- #
- # conductor is in charge of all
- # vpn-related configuration / monitoring.
- # we pass a tuple of signals that will be
- # triggered when status changes.
- #
- config_file = getattr(opts, 'config_file', None)
- self.conductor = EIPConductor(
- watcher_cb=self.newLogLine.emit,
- config_file=config_file,
- status_signals=(self.statusChange.emit, ))
+ EIPConductorAppMixin.__init__(
+ self, opts=opts, provider=self.provider_domain)
+ StatusAwareTrayIconMixin.__init__(self)
+ NetworkCheckerAppMixin.__init__(self)
+ MainWindowMixin.__init__(self)
- self.trayIcon.show()
+ geom_key = "DebugGeometry" if self.debugmode else "Geometry"
+ geom = settings.value(geom_key)
+ if geom:
+ self.restoreGeometry(geom)
- self.setWindowTitle("Leap")
- self.resize(400, 300)
+ # XXX check for wizard
+ self.wizard_done = settings.value("FirstRunWizardDone")
- self.set_statusbarMessage('ready')
+ self.initchecks = FunThread(self.run_eip_checks)
- if self.conductor.autostart:
- self.start_or_stopVPN()
+ # bind signals
+ self.initchecks.finished.connect(
+ lambda: logger.debug('Initial checks thread finished'))
+ self.trayIcon.activated.connect(self.iconActivated)
+ self.newLogLine.connect(
+ lambda line: self.onLoggerNewLine(line))
+ self.timer.timeout.connect(
+ lambda: self.onTimerTick())
+ self.networkError.connect(
+ lambda exc: self.onNetworkError(exc))
+ self.triggerEIPError.connect(
+ lambda exc: self.onEIPError(exc))
- def closeEvent(self, event):
- """
- redefines close event (persistent window behaviour)
- """
- if self.trayIcon.isVisible() and not self.debugmode:
- QMessageBox.information(self, "Systray",
- "The program will keep running "
- "in the system tray. To "
- "terminate the program, choose "
- "<b>Quit</b> in the "
- "context menu of the system tray entry.")
- self.hide()
- event.ignore()
if self.debugmode:
+ self.startStopButton.clicked.connect(
+ lambda: self.start_or_stopVPN())
+ self.start_eipconnection.connect(
+ lambda: self.start_or_stopVPN())
+ self.shutdownSignal.connect(
+ self.cleanupAndQuit)
+
+ # status change.
+ # TODO unify
+ self.openvpnStatusChange.connect(
+ lambda status: self.onOpenVPNStatusChange(status))
+ self.eipStatusChange.connect(
+ lambda newstatus: self.onEIPConnStatusChange(newstatus))
+ self.eipStatusChange.connect(
+ lambda newstatus: self.toggleEIPAct())
+
+ # do first run wizard and init signals
+ self.mainappReady.connect(self.do_first_run_wizard_check)
+ self.initReady.connect(self.runchecks_and_eipconnect)
+
+ # ... all ready. go!
+ # connected to do_first_run_wizard_check
+ self.mainappReady.emit()
+
+ def do_first_run_wizard_check(self):
+ """
+ checks whether first run wizard needs to be run
+ launches it if needed
+ and emits initReady signal if not.
+ """
+
+ logger.debug('first run wizard check...')
+ need_wizard = False
+
+ # do checks (can overlap if wizard was interrupted)
+ if not self.wizard_done:
+ need_wizard = True
+
+ if not self.provider_domain:
+ need_wizard = True
+ else:
+ pcertchecker = ProviderCertChecker(domain=self.provider_domain)
+ if not pcertchecker.is_cert_valid(do_raise=False):
+ logger.warning('missing valid client cert. need wizard')
+ need_wizard = True
+
+ # launch wizard if needed
+ if need_wizard:
+ self.launch_first_run_wizard()
+ else: # no wizard needed
+ logger.debug('running first run wizard')
+ self.initReady.emit()
+
+ def launch_first_run_wizard(self):
+ """
+ launches wizard and blocks
+ """
+ from leap.gui.firstrun.wizard import FirstRunWizard
+ wizard = FirstRunWizard(
+ self.conductor,
+ parent=self,
+ eip_username=self.eip_username,
+ start_eipconnection_signal=self.start_eipconnection,
+ eip_statuschange_signal=self.eipStatusChange,
+ quitcallback=self.onWizardCancel)
+ wizard.show()
+
+ def onWizardCancel(self):
+ if not self.wizard_done:
+ logger.debug(
+ 'clicked on Cancel during first '
+ 'run wizard. shutting down')
self.cleanupAndQuit()
- def setIcon(self, name):
- icon = self.Icons.get(name)
- self.trayIcon.setIcon(icon)
- self.setWindowIcon(icon)
-
- def setToolTip(self):
- """
- get readable status and place it on systray tooltip
- """
- status = self.conductor.status.get_readable_status()
- self.trayIcon.setToolTip(status)
-
- def iconActivated(self, reason):
- """
- handles left click, left double click
- showing the trayicon menu
- """
- #XXX there's a bug here!
- #menu shows on (0,0) corner first time,
- #until double clicked at least once.
- if reason in (QSystemTrayIcon.Trigger,
- QSystemTrayIcon.DoubleClick):
- self.trayIconMenu.show()
-
- def createWindowHeader(self):
- """
- description lines for main window
- """
- #XXX good candidate to refactor out! :)
- self.headerBox = QGroupBox()
- self.headerLabel = QLabel("<font size=40><b>E</b>ncryption \
-<b>I</b>nternet <b>P</b>roxy</font>")
- self.headerLabelSub = QLabel("<i>trust your \
-technolust</i>")
-
- pixmap = QPixmap(':/images/leapfrog.jpg')
- frog_lbl = QLabel()
- frog_lbl.setPixmap(pixmap)
-
- headerLayout = QHBoxLayout()
- headerLayout.addWidget(frog_lbl)
- headerLayout.addWidget(self.headerLabel)
- headerLayout.addWidget(self.headerLabelSub)
- headerLayout.addStretch()
- self.headerBox.setLayout(headerLayout)
-
- def getIcon(self, icon_name):
- # XXX get from connection dict
- icons = {'disconnected': 0,
- 'connecting': 1,
- 'connected': 2}
- return icons.get(icon_name, None)
-
- def createIconGroupBox(self):
- """
- dummy icongroupbox
- (to be removed from here -- reference only)
- """
- icons = {
- 'disconnected': ':/images/conn_error.png',
- 'connecting': ':/images/conn_connecting.png',
- 'connected': ':/images/conn_connected.png'
- }
- con_widgets = {
- 'disconnected': QLabel(),
- 'connecting': QLabel(),
- 'connected': QLabel(),
- }
- con_widgets['disconnected'].setPixmap(
- QPixmap(icons['disconnected']))
- con_widgets['connecting'].setPixmap(
- QPixmap(icons['connecting']))
- con_widgets['connected'].setPixmap(
- QPixmap(icons['connected'])),
- self.ConnectionWidgets = con_widgets
-
- con_icons = {
- 'disconnected': QIcon(icons['disconnected']),
- 'connecting': QIcon(icons['connecting']),
- 'connected': QIcon(icons['connected'])
- }
- self.Icons = con_icons
-
- self.statusIconBox = QGroupBox("Connection Status")
- statusIconLayout = QHBoxLayout()
- statusIconLayout.addWidget(self.ConnectionWidgets['disconnected'])
- statusIconLayout.addWidget(self.ConnectionWidgets['connecting'])
- statusIconLayout.addWidget(self.ConnectionWidgets['connected'])
- statusIconLayout.itemAt(1).widget().hide()
- statusIconLayout.itemAt(2).widget().hide()
- self.statusIconBox.setLayout(statusIconLayout)
-
- def createActions(self):
- """
- creates actions to be binded to tray icon
- """
- self.connectVPNAction = QAction("Connect to &VPN", self,
- triggered=self.hide)
- # XXX change action name on (dis)connect
- self.dis_connectAction = QAction("&(Dis)connect", self,
- triggered=self.start_or_stopVPN)
- self.minimizeAction = QAction("Mi&nimize", self,
- triggered=self.hide)
- self.maximizeAction = QAction("Ma&ximize", self,
- triggered=self.showMaximized)
- self.restoreAction = QAction("&Restore", self,
- triggered=self.showNormal)
- self.quitAction = QAction("&Quit", self,
- triggered=self.cleanupAndQuit)
-
- def createTrayIcon(self):
- """
- creates the tray icon
- """
- self.trayIconMenu = QMenu(self)
-
- self.trayIconMenu.addAction(self.connectVPNAction)
- self.trayIconMenu.addAction(self.dis_connectAction)
- self.trayIconMenu.addSeparator()
- self.trayIconMenu.addAction(self.minimizeAction)
- self.trayIconMenu.addAction(self.maximizeAction)
- self.trayIconMenu.addAction(self.restoreAction)
- self.trayIconMenu.addSeparator()
- self.trayIconMenu.addAction(self.quitAction)
-
- self.trayIcon = QSystemTrayIcon(self)
- self.trayIcon.setContextMenu(self.trayIconMenu)
-
- def createLogBrowser(self):
- """
- creates Browser widget for displaying logs
- (in debug mode only).
- """
- self.loggerBox = QGroupBox()
- logging_layout = QVBoxLayout()
- self.logbrowser = QTextBrowser()
-
- startStopButton = QPushButton("&Connect")
- startStopButton.clicked.connect(self.start_or_stopVPN)
- self.startStopButton = startStopButton
-
- logging_layout.addWidget(self.logbrowser)
- logging_layout.addWidget(self.startStopButton)
- self.loggerBox.setLayout(logging_layout)
-
- # status box
-
- self.statusBox = QGroupBox()
- grid = QGridLayout()
-
- self.updateTS = QLabel('')
- self.status_label = QLabel('Disconnected')
- self.ip_label = QLabel('')
- self.remote_label = QLabel('')
-
- tun_read_label = QLabel("tun read")
- self.tun_read_bytes = QLabel("0")
- tun_write_label = QLabel("tun write")
- self.tun_write_bytes = QLabel("0")
-
- grid.addWidget(self.updateTS, 0, 0)
- grid.addWidget(self.status_label, 0, 1)
- grid.addWidget(self.ip_label, 1, 0)
- grid.addWidget(self.remote_label, 1, 1)
- grid.addWidget(tun_read_label, 2, 0)
- grid.addWidget(self.tun_read_bytes, 2, 1)
- grid.addWidget(tun_write_label, 3, 0)
- grid.addWidget(self.tun_write_bytes, 3, 1)
-
- self.statusBox.setLayout(grid)
-
- @pyqtSlot(str)
- def onLoggerNewLine(self, line):
- """
- simple slot: writes new line to logger Pane.
- """
- if self.debugmode:
- self.logbrowser.append(line[:-1])
-
- def set_statusbarMessage(self, msg):
- self.statusBar().showMessage(msg)
-
- @pyqtSlot(object)
- def onStatusChange(self, status):
- """
- slot for status changes. triggers new signals for
- updating icon, status bar, etc.
- """
-
- print('STATUS CHANGED! (on Qt-land)')
- print('%s -> %s' % (status.previous, status.current))
- icon_name = self.conductor.get_icon_name()
- self.setIcon(icon_name)
- print 'icon = ', icon_name
-
- # change connection pixmap widget
- self.setConnWidget(icon_name)
-
- def setConnWidget(self, icon_name):
- #print 'changing icon to %s' % icon_name
- oldlayout = self.statusIconBox.layout()
-
- # XXX reuse with icons
- # XXX move states to StateWidget
- states = {"disconnected": 0,
- "connecting": 1,
- "connected": 2}
-
- for i in range(3):
- oldlayout.itemAt(i).widget().hide()
- new = states[icon_name]
- oldlayout.itemAt(new).widget().show()
-
- @pyqtSlot()
- def start_or_stopVPN(self):
- """
- stub for running child process with vpn
- """
- if self.vpn_service_started is False:
- self.conductor.connect()
- if self.debugmode:
- self.startStopButton.setText('&Disconnect')
- self.vpn_service_started = True
-
- # XXX what is optimum polling interval?
- # too little is overkill, too much
- # will miss transition states..
-
- self.timer.start(250.0)
- return
- if self.vpn_service_started is True:
- self.conductor.disconnect()
- # FIXME this should trigger also
- # statuschange event. why isn't working??
- if self.debugmode:
- self.startStopButton.setText('&Connect')
- self.vpn_service_started = False
- self.timer.stop()
- return
-
- @pyqtSlot()
- def onTimerTick(self):
- self.statusUpdate()
-
- @pyqtSlot()
- def statusUpdate(self):
- """
- called on timer tick
- polls status and updates ui with real time
- info about transferred bytes / connection state.
- """
- # XXX it's too expensive to poll
- # continously. move to signal events instead.
-
- if not self.vpn_service_started:
- return
-
- # XXX remove all access to manager layer
- # from here.
- if self.conductor.manager.with_errors:
- #XXX how to wait on pkexec???
- #something better that this workaround, plz!!
- time.sleep(10)
- print('errors. disconnect.')
- self.start_or_stopVPN() # is stop
-
- state = self.conductor.poll_connection_state()
- if not state:
- return
-
- ts, con_status, ok, ip, remote = state
- self.set_statusbarMessage(con_status)
- self.setToolTip()
-
- ts = time.strftime("%a %b %d %X", ts)
- if self.debugmode:
- self.updateTS.setText(ts)
- self.status_label.setText(con_status)
- self.ip_label.setText(ip)
- self.remote_label.setText(remote)
-
- # status i/o
-
- status = self.conductor.manager.get_status_io()
- if status and self.debugmode:
- #XXX move this to systray menu indicators
- ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read) = status
- ts = time.strftime("%a %b %d %X", ts)
- self.updateTS.setText(ts)
- self.tun_read_bytes.setText(tun_read)
- self.tun_write_bytes.setText(tun_write)
-
- def cleanupAndQuit(self):
- """
- cleans state before shutting down app.
- """
- # TODO:make sure to shutdown all child process / threads
- # in conductor
- self.conductor.cleanup()
- qApp.quit()
+ def runchecks_and_eipconnect(self):
+ self.show_systray_icon()
+ self.initchecks.begin()
diff --git a/src/leap/baseapp/network.py b/src/leap/baseapp/network.py
new file mode 100644
index 00000000..077d5164
--- /dev/null
+++ b/src/leap/baseapp/network.py
@@ -0,0 +1,40 @@
+from __future__ import print_function
+
+import logging
+
+logger = logging.getLogger(name=__name__)
+
+from PyQt4 import QtCore
+
+from leap.baseapp.dialogs import ErrorDialog
+from leap.base.network import NetworkCheckerThread
+
+
+class NetworkCheckerAppMixin(object):
+ """
+ initialize an instance of the Network Checker,
+ which gathers error and passes them on.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.network_checker = NetworkCheckerThread(
+ error_cb=self.networkError.emit,
+ debug=self.debugmode)
+
+ # XXX move run_checks to slot
+ self.network_checker.run_checks()
+
+ @QtCore.pyqtSlot(object)
+ def onNetworkError(self, exc):
+ """
+ slot that receives a network exceptions
+ and raises a user error message
+ """
+ logger.debug('handling network exception')
+ logger.error(exc.message)
+ dialog = ErrorDialog(parent=self)
+
+ if exc.critical:
+ dialog.criticalMessage(exc.usermessage, "network error")
+ else:
+ dialog.warningMessage(exc.usermessage, "network error")
diff --git a/src/leap/baseapp/permcheck.py b/src/leap/baseapp/permcheck.py
new file mode 100644
index 00000000..6b74cb6e
--- /dev/null
+++ b/src/leap/baseapp/permcheck.py
@@ -0,0 +1,17 @@
+import commands
+import os
+
+from leap.util.fileutil import which
+
+
+def is_pkexec_in_system():
+ pkexec_path = which('pkexec')
+ if not pkexec_path:
+ return False
+ return os.access(pkexec_path, os.X_OK)
+
+
+def is_auth_agent_running():
+ return bool(
+ commands.getoutput(
+ 'ps aux | grep polkit-[g]nome-authentication-agent-1'))
diff --git a/src/leap/baseapp/systray.py b/src/leap/baseapp/systray.py
new file mode 100644
index 00000000..49f044aa
--- /dev/null
+++ b/src/leap/baseapp/systray.py
@@ -0,0 +1,245 @@
+import logging
+import sip
+sip.setapi('QString', 2)
+sip.setapi('QVariant', 2)
+
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+from leap import __branding as BRANDING
+from leap import __version__ as VERSION
+
+from leap.gui import mainwindow_rc
+
+logger = logging.getLogger(__name__)
+
+
+class StatusAwareTrayIconMixin(object):
+ """
+ a mix of several functions needed
+ to create a systray and make it
+ get updated from conductor status
+ polling.
+ """
+ states = {
+ "disconnected": 0,
+ "connecting": 1,
+ "connected": 2}
+
+ iconpath = {
+ "disconnected": ':/images/conn_error.png',
+ "connecting": ':/images/conn_connecting.png',
+ "connected": ':/images/conn_connected.png'}
+
+ Icons = {
+ 'disconnected': lambda self: QtGui.QIcon(
+ self.iconpath['disconnected']),
+ 'connecting': lambda self: QtGui.QIcon(
+ self.iconpath['connecting']),
+ 'connected': lambda self: QtGui.QIcon(
+ self.iconpath['connected'])
+ }
+
+ def __init__(self, *args, **kwargs):
+ self.createIconGroupBox()
+ self.createActions()
+ self.createTrayIcon()
+
+ # not sure if this really belongs here, but...
+ self.timer = QtCore.QTimer()
+
+ def show_systray_icon(self):
+ #logger.debug('showing tray icon................')
+ self.trayIcon.show()
+
+ def createIconGroupBox(self):
+ """
+ dummy icongroupbox
+ (to be removed from here -- reference only)
+ """
+ con_widgets = {
+ 'disconnected': QtGui.QLabel(),
+ 'connecting': QtGui.QLabel(),
+ 'connected': QtGui.QLabel(),
+ }
+ con_widgets['disconnected'].setPixmap(
+ QtGui.QPixmap(
+ self.iconpath['disconnected']))
+ con_widgets['connecting'].setPixmap(
+ QtGui.QPixmap(
+ self.iconpath['connecting']))
+ con_widgets['connected'].setPixmap(
+ QtGui.QPixmap(
+ self.iconpath['connected'])),
+ self.ConnectionWidgets = con_widgets
+
+ self.statusIconBox = QtGui.QGroupBox("EIP Connection Status")
+ statusIconLayout = QtGui.QHBoxLayout()
+ statusIconLayout.addWidget(self.ConnectionWidgets['disconnected'])
+ statusIconLayout.addWidget(self.ConnectionWidgets['connecting'])
+ statusIconLayout.addWidget(self.ConnectionWidgets['connected'])
+ statusIconLayout.itemAt(1).widget().hide()
+ statusIconLayout.itemAt(2).widget().hide()
+
+ self.leapConnStatus = QtGui.QLabel("<b>disconnected</b>")
+ statusIconLayout.addWidget(self.leapConnStatus)
+
+ self.statusIconBox.setLayout(statusIconLayout)
+
+ def createTrayIcon(self):
+ """
+ creates the tray icon
+ """
+ self.trayIconMenu = QtGui.QMenu(self)
+
+ self.trayIconMenu.addAction(self.connAct)
+ self.trayIconMenu.addSeparator()
+ self.trayIconMenu.addAction(self.detailsAct)
+ self.trayIconMenu.addSeparator()
+ self.trayIconMenu.addAction(self.aboutAct)
+ # we should get this hidden inside the "about" dialog
+ # (as a little button maybe)
+ #self.trayIconMenu.addAction(self.aboutQtAct)
+ self.trayIconMenu.addSeparator()
+ self.trayIconMenu.addAction(self.quitAction)
+
+ self.trayIcon = QtGui.QSystemTrayIcon(self)
+ self.setIcon('disconnected')
+ self.trayIcon.setContextMenu(self.trayIconMenu)
+
+ #self.trayIconMenu.setContextMenuPolicy(QtCore.Qt.CustomContextMenu)
+ #self.trayIconMenu.customContextMenuRequested.connect(
+ #self.on_context_menu)
+
+ def bad(self):
+ logger.error('this should not be called')
+
+ def createActions(self):
+ """
+ creates actions to be binded to tray icon
+ """
+ # XXX change action name on (dis)connect
+ self.connAct = QtGui.QAction("Encryption ON turn &off", self,
+ triggered=lambda: self.start_or_stopVPN())
+
+ self.detailsAct = QtGui.QAction("&Details...",
+ self,
+ triggered=self.detailsWin)
+ self.aboutAct = QtGui.QAction("&About", self,
+ triggered=self.about)
+ self.aboutQtAct = QtGui.QAction("About Q&t", self,
+ triggered=QtGui.qApp.aboutQt)
+ self.quitAction = QtGui.QAction("&Quit", self,
+ triggered=self.cleanupAndQuit)
+
+ def toggleEIPAct(self):
+ # this is too simple by now.
+ # XXX get STATUS CONSTANTS INSTEAD
+
+ icon_status = self.conductor.get_icon_name()
+ if icon_status == "connected":
+ self.connAct.setEnabled(True)
+ self.connAct.setText('Encryption ON turn o&ff')
+ return
+ if icon_status == "disconnected":
+ self.connAct.setEnabled(True)
+ self.connAct.setText('Encryption OFF turn &on')
+ return
+ if icon_status == "connecting":
+ self.connAct.setDisabled(True)
+ self.connAct.setText('connecting...')
+ return
+
+ def detailsWin(self):
+ visible = self.isVisible()
+ if visible:
+ self.hide()
+ else:
+ self.show()
+
+ def about(self):
+ # move to widget
+ flavor = BRANDING.get('short_name', None)
+ content = ("LEAP client<br>"
+ "(version <b>%s</b>)<br>" % VERSION)
+ if flavor:
+ content = content + ('<br>Flavor: <i>%s</i><br>' % flavor)
+ content = content + (
+ "<br><a href='https://leap.se/'>"
+ "https://leap.se</a>")
+ QtGui.QMessageBox.about(self, "About", content)
+
+ def setConnWidget(self, icon_name):
+ oldlayout = self.statusIconBox.layout()
+
+ for i in range(3):
+ oldlayout.itemAt(i).widget().hide()
+ new = self.states[icon_name]
+ oldlayout.itemAt(new).widget().show()
+
+ def setIcon(self, name):
+ icon_fun = self.Icons.get(name)
+ if icon_fun and callable(icon_fun):
+ icon = icon_fun(self)
+ self.trayIcon.setIcon(icon)
+
+ def getIcon(self, icon_name):
+ return self.states.get(icon_name, None)
+
+ def setIconToolTip(self):
+ """
+ get readable status and place it on systray tooltip
+ """
+ status = self.conductor.status.get_readable_status()
+ self.trayIcon.setToolTip(status)
+
+ def iconActivated(self, reason):
+ """
+ handles left click, left double click
+ showing the trayicon menu
+ """
+ if reason in (QtGui.QSystemTrayIcon.Trigger,
+ QtGui.QSystemTrayIcon.DoubleClick):
+ context_menu = self.trayIcon.contextMenu()
+ # for some reason, context_menu.show()
+ # is failing in a way beyond my understanding.
+ # (not working the first time it's clicked).
+ # this works however.
+ context_menu.exec_(self.trayIcon.geometry().center())
+
+ @QtCore.pyqtSlot()
+ def onTimerTick(self):
+ self.statusUpdate()
+
+ @QtCore.pyqtSlot(object)
+ def onOpenVPNStatusChange(self, status):
+ """
+ updates icon, according to the openvpn status change.
+ """
+ icon_name = self.conductor.get_icon_name()
+
+ # XXX refactor. Use QStateMachine
+
+ if icon_name in ("disconnected", "connected"):
+ self.eipStatusChange.emit(icon_name)
+
+ if icon_name in ("connecting"):
+ # let's see how it matches
+ leap_status_name = self.conductor.get_leap_status()
+ self.eipStatusChange.emit(leap_status_name)
+
+ self.setIcon(icon_name)
+ # change connection pixmap widget
+ self.setConnWidget(icon_name)
+
+ @QtCore.pyqtSlot(str)
+ def onEIPConnStatusChange(self, newstatus):
+ """
+ slot for EIP status changes
+ not to be confused with onOpenVPNStatusChange.
+ this only updates the non-debug LEAP Status line
+ next to the connection icon.
+ """
+ # XXX move bold to style sheet
+ self.leapConnStatus.setText(
+ "<b>%s</b>" % newstatus)
diff --git a/src/leap/certs/__init__.py b/src/leap/certs/__init__.py
new file mode 100644
index 00000000..c4d009b1
--- /dev/null
+++ b/src/leap/certs/__init__.py
@@ -0,0 +1,7 @@
+import os
+
+_where = os.path.split(__file__)[0]
+
+
+def where(filename):
+ return os.path.join(_where, filename)
diff --git a/src/leap/crypto/__init__.py b/src/leap/crypto/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/crypto/__init__.py
diff --git a/src/leap/crypto/certs.py b/src/leap/crypto/certs.py
new file mode 100644
index 00000000..8908865d
--- /dev/null
+++ b/src/leap/crypto/certs.py
@@ -0,0 +1,71 @@
+import ctypes
+import socket
+
+import gnutls.connection
+import gnutls.crypto
+import gnutls.library
+
+
+def get_https_cert_from_domain(domain):
+ """
+ @param domain: a domain name to get a certificate from.
+ """
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ cred = gnutls.connection.X509Credentials()
+
+ session = gnutls.connection.ClientSession(sock, cred)
+ session.connect((domain, 443))
+ session.handshake()
+ cert = session.peer_certificate
+ return cert
+
+
+def get_cert_from_file(filepath):
+ with open(filepath) as f:
+ cert = gnutls.crypto.X509Certificate(f.read())
+ return cert
+
+
+def get_cert_fingerprint(domain=None, filepath=None,
+ hash_type="SHA256", sep=":"):
+ """
+ @param domain: a domain name to get a fingerprint from
+ @type domain: str
+ @param filepath: path to a file containing a PEM file
+ @type filepath: str
+ @param hash_type: the hash function to be used in the fingerprint.
+ must be one of SHA1, SHA224, SHA256, SHA384, SHA512
+ @type hash_type: str
+ @rparam: hex_fpr, a hexadecimal representation of a bytestring
+ containing the fingerprint.
+ @rtype: string
+ """
+ if domain:
+ cert = get_https_cert_from_domain(domain)
+ if filepath:
+ cert = get_cert_from_file(filepath)
+
+ _buffer = ctypes.create_string_buffer(64)
+ buffer_length = ctypes.c_size_t(64)
+
+ SUPPORTED_DIGEST_FUN = ("SHA1", "SHA224", "SHA256", "SHA384", "SHA512")
+ if hash_type in SUPPORTED_DIGEST_FUN:
+ digestfunction = getattr(
+ gnutls.library.constants,
+ "GNUTLS_DIG_%s" % hash_type)
+ else:
+ # XXX improperlyconfigured or something
+ raise Exception("digest function not supported")
+
+ gnutls.library.functions.gnutls_x509_crt_get_fingerprint(
+ cert._c_object, digestfunction,
+ ctypes.byref(_buffer), ctypes.byref(buffer_length))
+
+ # deinit
+ #server_cert._X509Certificate__deinit(server_cert._c_object)
+ # needed? is segfaulting
+
+ fpr = ctypes.string_at(_buffer, buffer_length.value)
+ hex_fpr = sep.join(u"%02X" % ord(char) for char in fpr)
+
+ return hex_fpr
diff --git a/src/leap/crypto/leapkeyring.py b/src/leap/crypto/leapkeyring.py
new file mode 100644
index 00000000..d4be7bf9
--- /dev/null
+++ b/src/leap/crypto/leapkeyring.py
@@ -0,0 +1,69 @@
+import keyring
+
+from leap.base.config import get_config_file
+
+#############
+# Disclaimer
+#############
+# This currently is not a keyring, it's more like a joke.
+# No, seriously.
+# We're affected by this **bug**
+
+# https://bitbucket.org/kang/python-keyring-lib/
+# issue/65/dbusexception-method-opensession-with
+
+# so using the gnome keyring does not seem feasible right now.
+# I thought this was the next best option to store secrets in plain sight.
+
+# in the future we should move to use the gnome/kde/macosx/win keyrings.
+
+
+class LeapCryptedFileKeyring(keyring.backend.CryptedFileKeyring):
+
+ filename = ".secrets"
+
+ @property
+ def file_path(self):
+ return get_config_file(self.filename)
+
+ def __init__(self, seed=None):
+ self.seed = seed
+
+ def _get_new_password(self):
+ # XXX every time this method is called,
+ # $deity kills a kitten.
+ return "secret%s" % self.seed
+
+ def _init_file(self):
+ self.keyring_key = self._get_new_password()
+ self.set_password('keyring_setting', 'pass_ref', 'pass_ref_value')
+
+ def _unlock(self):
+ self.keyring_key = self._get_new_password()
+ print 'keyring key ', self.keyring_key
+ try:
+ ref_pw = self.get_password(
+ 'keyring_setting',
+ 'pass_ref')
+ print 'ref pw ', ref_pw
+ assert ref_pw == "pass_ref_value"
+ except AssertionError:
+ self._lock()
+ raise ValueError('Incorrect password')
+
+
+def leap_set_password(key, value, seed="xxx"):
+ keyring.set_keyring(LeapCryptedFileKeyring(seed=seed))
+ keyring.set_password('leap', key, value)
+
+
+def leap_get_password(key, seed="xxx"):
+ keyring.set_keyring(LeapCryptedFileKeyring(seed=seed))
+ #import ipdb;ipdb.set_trace()
+ return keyring.get_password('leap', key)
+
+
+if __name__ == "__main__":
+ leap_set_password('test', 'bar')
+ passwd = leap_get_password('test')
+ assert passwd == 'bar'
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py
new file mode 100644
index 00000000..116c535e
--- /dev/null
+++ b/src/leap/eip/checks.py
@@ -0,0 +1,518 @@
+import logging
+import ssl
+#import platform
+import time
+import os
+
+import gnutls.crypto
+#import netifaces
+#import ping
+import requests
+
+from leap import __branding as BRANDING
+from leap import certs as leapcerts
+from leap.base.auth import srpauth_protected, magick_srpauth
+from leap.base import config as baseconfig
+from leap.base import constants as baseconstants
+from leap.base import providers
+from leap.crypto import certs
+from leap.eip import config as eipconfig
+from leap.eip import constants as eipconstants
+from leap.eip import exceptions as eipexceptions
+from leap.eip import specs as eipspecs
+from leap.util.fileutil import mkdir_p
+
+logger = logging.getLogger(name=__name__)
+
+"""
+ProviderCertChecker
+-------------------
+Checks on certificates. To be moved to base.
+docs TBD
+
+EIPConfigChecker
+----------
+It is used from the eip conductor (a instance of EIPConnection that is
+managed from the QtApp), running `run_all` method before trying to call
+`connect` or any other of the state-changing methods.
+
+It checks that the needed files are provided or can be discovered over the
+net. Much of these tests are not specific to EIP module, and can be splitted
+into base.tests to be invoked by the base leap init routines.
+However, I'm testing them alltogether for the sake of having the whole unit
+reachable and testable as a whole.
+
+"""
+
+
+def get_branding_ca_cert(domain):
+ # XXX deprecated
+ ca_file = BRANDING.get('provider_ca_file')
+ if ca_file:
+ return leapcerts.where(ca_file)
+
+
+class ProviderCertChecker(object):
+ """
+ Several checks needed for getting
+ client certs and checking tls connection
+ with provider.
+ """
+ def __init__(self, fetcher=requests,
+ domain=None):
+
+ self.fetcher = fetcher
+ self.domain = domain
+ self.cacert = eipspecs.provider_ca_path(domain)
+
+ def run_all(
+ self, checker=None,
+ skip_download=False, skip_verify=False):
+
+ if not checker:
+ checker = self
+
+ do_verify = not skip_verify
+ logger.debug('do_verify: %s', do_verify)
+ # checker.download_ca_cert()
+
+ # For MVS+
+ # checker.download_ca_signature()
+ # checker.get_ca_signatures()
+ # checker.is_there_trust_path()
+
+ # For MVS
+ checker.is_there_provider_ca()
+
+ # XXX FAKE IT!!!
+ checker.is_https_working(verify=do_verify, autocacert=True)
+ checker.check_new_cert_needed(verify=do_verify)
+
+ def download_ca_cert(self, uri=None, verify=True):
+ req = self.fetcher.get(uri, verify=verify)
+ req.raise_for_status()
+
+ # should check domain exists
+ capath = self._get_ca_cert_path(self.domain)
+ with open(capath, 'w') as f:
+ f.write(req.content)
+
+ def check_ca_cert_fingerprint(
+ self, hash_type="SHA256",
+ fingerprint=None):
+ """
+ compares the fingerprint in
+ the ca cert with a string
+ we are passed
+ returns True if they are equal, False if not.
+ @param hash_type: digest function
+ @type hash_type: str
+ @param fingerprint: the fingerprint to compare with.
+ @type fingerprint: str (with : separator)
+ @rtype bool
+ """
+ ca_cert_path = self.ca_cert_path
+ ca_cert_fpr = certs.get_cert_fingerprint(
+ filepath=ca_cert_path)
+ return ca_cert_fpr == fingerprint
+
+ def verify_api_https(self, uri):
+ assert uri.startswith('https://')
+ cacert = self.ca_cert_path
+ verify = cacert and cacert or True
+ req = self.fetcher.get(uri, verify=verify)
+ req.raise_for_status()
+ return True
+
+ def download_ca_signature(self):
+ # MVS+
+ raise NotImplementedError
+
+ def get_ca_signatures(self):
+ # MVS+
+ raise NotImplementedError
+
+ def is_there_trust_path(self):
+ # MVS+
+ raise NotImplementedError
+
+ def is_there_provider_ca(self):
+ if not self.cacert:
+ return False
+ cacert_exists = os.path.isfile(self.cacert)
+ if cacert_exists:
+ logger.debug('True')
+ return True
+ logger.debug('False!')
+ return False
+
+ def is_https_working(
+ self, uri=None, verify=True,
+ autocacert=False):
+ if uri is None:
+ uri = self._get_root_uri()
+ # XXX raise InsecureURI or something better
+ try:
+ assert uri.startswith('https')
+ except AssertionError:
+ raise AssertionError(
+ "uri passed should start with https")
+ if autocacert and verify is True and self.cacert is not None:
+ logger.debug('verify cert: %s', self.cacert)
+ verify = self.cacert
+ #import pdb4qt; pdb4qt.set_trace()
+ logger.debug('is https working?')
+ logger.debug('uri: %s (verify:%s)', uri, verify)
+ try:
+ self.fetcher.get(uri, verify=verify)
+
+ except requests.exceptions.SSLError as exc:
+ logger.error("SSLError")
+ # XXX RAISE! See #638
+ #raise eipexceptions.HttpsBadCertError
+ logger.warning('BUG #638 CERT VERIFICATION FAILED! '
+ '(this should be CRITICAL)')
+ logger.warning('SSLError: %s', exc.message)
+
+ except requests.exceptions.ConnectionError:
+ logger.error('ConnectionError')
+ raise eipexceptions.HttpsNotSupported
+
+ else:
+ logger.debug('True')
+ return True
+
+ def check_new_cert_needed(self, skip_download=False, verify=True):
+ logger.debug('is new cert needed?')
+ if not self.is_cert_valid(do_raise=False):
+ logger.debug('True')
+ self.download_new_client_cert(
+ skip_download=skip_download,
+ verify=verify)
+ return True
+ logger.debug('False')
+ return False
+
+ def download_new_client_cert(self, uri=None, verify=True,
+ skip_download=False,
+ credentials=None):
+ logger.debug('download new client cert')
+ if skip_download:
+ return True
+ if uri is None:
+ uri = self._get_client_cert_uri()
+ # XXX raise InsecureURI or something better
+ assert uri.startswith('https')
+
+ if verify is True and self.cacert is not None:
+ verify = self.cacert
+
+ fgetfn = self.fetcher.get
+
+ if credentials:
+ user, passwd = credentials
+
+ logger.debug('domain = %s', self.domain)
+
+ @srpauth_protected(user, passwd,
+ server="https://%s" % self.domain,
+ verify=verify)
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
+
+ else:
+ # XXX FIXME fix decorated args
+ @magick_srpauth(verify)
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
+ try:
+
+ # XXX FIXME!!!!
+ # verify=verify
+ # Workaround for #638. return to verification
+ # when That's done!!!
+ #req = self.fetcher.get(uri, verify=False)
+ req = getfn(uri, verify=False)
+ req.raise_for_status()
+
+ except requests.exceptions.SSLError:
+ logger.warning('SSLError while fetching cert. '
+ 'Look below for stack trace.')
+ # XXX raise better exception
+ raise
+ try:
+ pemfile_content = req.content
+ self.is_valid_pemfile(pemfile_content)
+ cert_path = self._get_client_cert_path()
+ self.write_cert(pemfile_content, to=cert_path)
+ except:
+ logger.warning('Error while validating cert')
+ raise
+ return True
+
+ def is_cert_valid(self, cert_path=None, do_raise=True):
+ exists = lambda: self.is_certificate_exists()
+ valid_pemfile = lambda: self.is_valid_pemfile()
+ not_expired = lambda: self.is_cert_not_expired()
+
+ valid = exists() and valid_pemfile() and not_expired()
+ if not valid:
+ if do_raise:
+ raise Exception('missing valid cert')
+ else:
+ return False
+ return True
+
+ def is_certificate_exists(self, certfile=None):
+ if certfile is None:
+ certfile = self._get_client_cert_path()
+ return os.path.isfile(certfile)
+
+ def is_cert_not_expired(self, certfile=None, now=time.gmtime):
+ if certfile is None:
+ certfile = self._get_client_cert_path()
+ with open(certfile) as cf:
+ cert_s = cf.read()
+ cert = gnutls.crypto.X509Certificate(cert_s)
+ from_ = time.gmtime(cert.activation_time)
+ to_ = time.gmtime(cert.expiration_time)
+ return from_ < now() < to_
+
+ def is_valid_pemfile(self, cert_s=None):
+ """
+ checks that the passed string
+ is a valid pem certificate
+ @param cert_s: string containing pem content
+ @type cert_s: string
+ @rtype: bool
+ """
+ if cert_s is None:
+ certfile = self._get_client_cert_path()
+ with open(certfile) as cf:
+ cert_s = cf.read()
+ try:
+ # XXX get a real cert validation
+ # so far this is only checking begin/end
+ # delimiters :)
+ # XXX use gnutls for get proper
+ # validation.
+ # crypto.X509Certificate(cert_s)
+ sep = "-" * 5 + "BEGIN CERTIFICATE" + "-" * 5
+ # we might have private key and cert in the same file
+ certparts = cert_s.split(sep)
+ if len(certparts) > 1:
+ cert_s = sep + certparts[1]
+ ssl.PEM_cert_to_DER_cert(cert_s)
+ except:
+ # XXX raise proper exception
+ raise
+ return True
+
+ @property
+ def ca_cert_path(self):
+ return self._get_ca_cert_path(self.domain)
+
+ def _get_root_uri(self):
+ return u"https://%s/" % self.domain
+
+ def _get_client_cert_uri(self):
+ # XXX get the whole thing from constants
+ return "https://%s/1/cert" % self.domain
+
+ def _get_client_cert_path(self):
+ return eipspecs.client_cert_path(domain=self.domain)
+
+ def _get_ca_cert_path(self, domain):
+ # XXX this folder path will be broken for win
+ # and this should be moved to eipspecs.ca_path
+
+ # XXX use baseconfig.get_provider_path(folder=Foo)
+ # !!!
+
+ capath = baseconfig.get_config_file(
+ 'cacert.pem',
+ folder='providers/%s/keys/ca' % domain)
+ folder, fname = os.path.split(capath)
+ if not os.path.isdir(folder):
+ mkdir_p(folder)
+ return capath
+
+ def write_cert(self, pemfile_content, to=None):
+ folder, filename = os.path.split(to)
+ if not os.path.isdir(folder):
+ mkdir_p(folder)
+ with open(to, 'w') as cert_f:
+ cert_f.write(pemfile_content)
+
+
+class EIPConfigChecker(object):
+ """
+ Several checks needed
+ to ensure a EIPConnection
+ can be sucessfully established.
+ use run_all to run all checks.
+ """
+
+ def __init__(self, fetcher=requests, domain=None):
+ # we do not want to accept too many
+ # argument on init.
+ # we want tests
+ # to be explicitely run.
+
+ self.fetcher = fetcher
+
+ # if not domain, get from config
+ self.domain = domain
+
+ self.eipconfig = eipconfig.EIPConfig(domain=domain)
+ self.defaultprovider = providers.LeapProviderDefinition(domain=domain)
+ self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain)
+
+ def run_all(self, checker=None, skip_download=False):
+ """
+ runs all checks in a row.
+ will raise if some error encountered.
+ catching those exceptions is not
+ our responsibility at this moment
+ """
+ if not checker:
+ checker = self
+
+ # let's call all tests
+ # needed for a sane eip session.
+
+ # TODO: get rid of check_default.
+ # check_complete should
+ # be enough. but here to make early tests easier.
+ checker.check_default_eipconfig()
+
+ checker.check_is_there_default_provider()
+ checker.fetch_definition(skip_download=skip_download)
+ checker.fetch_eip_service_config(skip_download=skip_download)
+ checker.check_complete_eip_config()
+ #checker.ping_gateway()
+
+ # public checks
+
+ def check_default_eipconfig(self):
+ """
+ checks if default eipconfig exists,
+ and dumps a default file if not
+ """
+ # XXX ONLY a transient check
+ # because some old function still checks
+ # for eip config at the beginning.
+
+ # it *really* does not make sense to
+ # dump it right now, we can get an in-memory
+ # config object and dump it to disk in a
+ # later moment
+ logger.debug('checking default eip config')
+ if not self._is_there_default_eipconfig():
+ self._dump_default_eipconfig()
+
+ def check_is_there_default_provider(self, config=None):
+ """
+ raises EIPMissingDefaultProvider if no
+ default provider found on eip config.
+ This is catched by ui and runs FirstRunWizard (MVS+)
+ """
+ if config is None:
+ config = self.eipconfig.config
+ logger.debug('checking default provider')
+ provider = config.get('provider', None)
+ if provider is None:
+ raise eipexceptions.EIPMissingDefaultProvider
+ # XXX raise also if malformed ProviderDefinition?
+ return True
+
+ def fetch_definition(self, skip_download=False,
+ config=None, uri=None,
+ domain=None):
+ """
+ fetches a definition file from server
+ """
+ # TODO:
+ # - Implement diff
+ # - overwrite only if different.
+ # (attend to serial field different, for instance)
+
+ logger.debug('fetching definition')
+
+ if skip_download:
+ logger.debug('(fetching def skipped)')
+ return True
+ if config is None:
+ config = self.defaultprovider.config
+ if uri is None:
+ if not domain:
+ domain = config.get('provider', None)
+ uri = self._get_provider_definition_uri(domain=domain)
+
+ # FIXME! Pass ca path verify!!!
+ # BUG #638
+ # FIXME FIXME FIXME
+ self.defaultprovider.load(
+ from_uri=uri,
+ fetcher=self.fetcher,
+ verify=False)
+ self.defaultprovider.save()
+
+ def fetch_eip_service_config(self, skip_download=False,
+ config=None, uri=None, domain=None):
+ if skip_download:
+ return True
+ if config is None:
+ config = self.eipserviceconfig.config
+ if uri is None:
+ if not domain:
+ domain = self.domain or config.get('provider', None)
+ uri = self._get_eip_service_uri(domain=domain)
+
+ self.eipserviceconfig.load(from_uri=uri, fetcher=self.fetcher)
+ self.eipserviceconfig.save()
+
+ def check_complete_eip_config(self, config=None):
+ # TODO check for gateway
+ if config is None:
+ config = self.eipconfig.config
+ try:
+ 'trying assertions'
+ assert 'provider' in config
+ assert config['provider'] is not None
+ # XXX assert there is gateway !!
+ except AssertionError:
+ raise eipexceptions.EIPConfigurationError
+
+ # XXX TODO:
+ # We should WRITE eip config if missing or
+ # incomplete at this point
+ #self.eipconfig.save()
+
+ #
+ # private helpers
+ #
+
+ def _is_there_default_eipconfig(self):
+ return self.eipconfig.exists()
+
+ def _dump_default_eipconfig(self):
+ self.eipconfig.save()
+
+ def _get_provider_definition_uri(self, domain=None, path=None):
+ if domain is None:
+ domain = self.domain or baseconstants.DEFAULT_PROVIDER
+ if path is None:
+ path = baseconstants.DEFINITION_EXPECTED_PATH
+ uri = u"https://%s/%s" % (domain, path)
+ logger.debug('getting provider definition from %s' % uri)
+ return uri
+
+ def _get_eip_service_uri(self, domain=None, path=None):
+ if domain is None:
+ domain = self.domain or baseconstants.DEFAULT_PROVIDER
+ if path is None:
+ path = eipconstants.EIP_SERVICE_EXPECTED_PATH
+ uri = "https://%s/%s" % (domain, path)
+ logger.debug('getting eip service file from %s', uri)
+ return uri
diff --git a/src/leap/eip/conductor.py b/src/leap/eip/conductor.py
deleted file mode 100644
index e3adadc4..00000000
--- a/src/leap/eip/conductor.py
+++ /dev/null
@@ -1,272 +0,0 @@
-"""
-stablishes a vpn connection and monitors its state
-"""
-from __future__ import (division, unicode_literals, print_function)
-#import threading
-from functools import partial
-import logging
-
-from leap.utils.coroutines import spawn_and_watch_process
-from leap.baseapp.config import get_config, get_vpn_stdout_mockup
-from leap.eip.vpnwatcher import EIPConnectionStatus, status_watcher
-from leap.eip.vpnmanager import OpenVPNManager, ConnectionRefusedError
-
-logger = logging.getLogger(name=__name__)
-
-
-# TODO Move exceptions to their own module
-
-
-class ConnectionError(Exception):
- """
- generic connection error
- """
- pass
-
-
-class EIPClientError(Exception):
- """
- base EIPClient exception
- """
- def __str__(self):
- if len(self.args) >= 1:
- return repr(self.args[0])
- else:
- return ConnectionError
-
-
-class UnrecoverableError(EIPClientError):
- """
- we cannot do anything about it, sorry
- """
- pass
-
-
-class OpenVPNConnection(object):
- """
- All related to invocation
- of the openvpn binary
- """
- # Connection Methods
-
- def __init__(self, config_file=None, watcher_cb=None):
- #XXX FIXME
- #change watcher_cb to line_observer
- """
- :param config_file: configuration file to read from
- :param watcher_cb: callback to be \
-called for each line in watched stdout
- :param signal_map: dictionary of signal names and callables \
-to be triggered for each one of them.
- :type config_file: str
- :type watcher_cb: function
- :type signal_map: dict
- """
- # XXX get host/port from config
- self.manager = OpenVPNManager()
-
- self.config_file = config_file
- self.watcher_cb = watcher_cb
- #self.signal_maps = signal_maps
-
- self.subp = None
- self.watcher = None
-
- self.server = None
- self.port = None
- self.proto = None
-
- self.autostart = True
-
- self._get_config()
-
- def _set_command_mockup(self):
- """
- sets command and args for a command mockup
- that just mimics the output from the real thing
- """
- command, args = get_vpn_stdout_mockup()
- self.command, self.args = command, args
-
- def _get_config(self):
- """
- retrieves the config options from defaults or
- home file, or config file passed in command line.
- """
- config = get_config(config_file=self.config_file)
- self.config = config
-
- if config.has_option('openvpn', 'command'):
- commandline = config.get('openvpn', 'command')
- if commandline == "mockup":
- self._set_command_mockup()
- return
- command_split = commandline.split(' ')
- command = command_split[0]
- if len(command_split) > 1:
- args = command_split[1:]
- else:
- args = []
- self.command = command
- #print("debug: command = %s" % command)
- self.args = args
- else:
- self._set_command_mockup()
-
- if config.has_option('openvpn', 'autostart'):
- autostart = config.get('openvpn', 'autostart')
- self.autostart = autostart
-
- def _launch_openvpn(self):
- """
- invocation of openvpn binaries in a subprocess.
- """
- #XXX TODO:
- #deprecate watcher_cb,
- #use _only_ signal_maps instead
-
- if self.watcher_cb is not None:
- linewrite_callback = self.watcher_cb
- else:
- #XXX get logger instead
- linewrite_callback = lambda line: print('watcher: %s' % line)
-
- observers = (linewrite_callback,
- partial(status_watcher, self.status))
- subp, watcher = spawn_and_watch_process(
- self.command,
- self.args,
- observers=observers)
- self.subp = subp
- self.watcher = watcher
-
- conn_result = self.status.CONNECTED
- return conn_result
-
- def _try_connection(self):
- """
- attempts to connect
- """
- if self.subp is not None:
- print('cowardly refusing to launch subprocess again')
- return
- self._launch_openvpn()
-
- def cleanup(self):
- """
- terminates child subprocess
- """
- if self.subp:
- self.subp.terminate()
-
-
-class EIPConductor(OpenVPNConnection):
- """
- Manages the execution of the OpenVPN process, auto starts, monitors the
- network connection, handles configuration, fixes leaky hosts, handles
- errors, etc.
- Preferences will be stored via the Storage API. (TBD)
- Status updates (connected, bandwidth, etc) are signaled to the GUI.
- """
-
- def __init__(self, *args, **kwargs):
- self.settingsfile = kwargs.get('settingsfile', None)
- self.logfile = kwargs.get('logfile', None)
- self.error_queue = []
- self.desired_con_state = None # ???
-
- status_signals = kwargs.pop('status_signals', None)
- self.status = EIPConnectionStatus(callbacks=status_signals)
-
- super(EIPConductor, self).__init__(*args, **kwargs)
-
- def connect(self):
- """
- entry point for connection process
- """
- self.manager.forget_errors()
- self._try_connection()
- # XXX should capture errors?
-
- def disconnect(self):
- """
- disconnects client
- """
- self._disconnect()
- self.status.change_to(self.status.DISCONNECTED)
- pass
-
- def shutdown(self):
- """
- shutdown and quit
- """
- self.desired_con_state = self.status.DISCONNECTED
-
- def connection_state(self):
- """
- returns the current connection state
- """
- return self.status.current
-
- def desired_connection_state(self):
- """
- returns the desired_connection state
- """
- return self.desired_con_state
-
- def poll_connection_state(self):
- """
- """
- try:
- state = self.manager.get_connection_state()
- except ConnectionRefusedError:
- # connection refused. might be not ready yet.
- return
- if not state:
- return
- (ts, status_step,
- ok, ip, remote) = state
- self.status.set_vpn_state(status_step)
- status_step = self.status.get_readable_status()
- return (ts, status_step, ok, ip, remote)
-
- def get_icon_name(self):
- """
- get icon name from status object
- """
- return self.status.get_state_icon()
-
- #
- # private methods
- #
-
- def _disconnect(self):
- """
- private method for disconnecting
- """
- if self.subp is not None:
- self.subp.terminate()
- self.subp = None
- # XXX signal state changes! :)
-
- def _is_alive(self):
- """
- don't know yet
- """
- pass
-
- def _connect(self):
- """
- entry point for connection cascade methods.
- """
- #conn_result = ConState.DISCONNECTED
- try:
- conn_result = self._try_connection()
- except UnrecoverableError as except_msg:
- logger.error("FATAL: %s" % unicode(except_msg))
- conn_result = self.status.UNRECOVERABLE
- except Exception as except_msg:
- self.error_queue.append(except_msg)
- logger.error("Failed Connection: %s" %
- unicode(except_msg))
- return conn_result
diff --git a/src/leap/eip/config.py b/src/leap/eip/config.py
new file mode 100644
index 00000000..42c00380
--- /dev/null
+++ b/src/leap/eip/config.py
@@ -0,0 +1,303 @@
+import logging
+import os
+import platform
+import tempfile
+
+from leap import __branding as BRANDING
+from leap import certs
+from leap.util.fileutil import (which, mkdir_p, check_and_fix_urw_only)
+
+from leap.base import config as baseconfig
+from leap.baseapp.permcheck import (is_pkexec_in_system,
+ is_auth_agent_running)
+from leap.eip import exceptions as eip_exceptions
+from leap.eip import specs as eipspecs
+
+logger = logging.getLogger(name=__name__)
+provider_ca_file = BRANDING.get('provider_ca_file', None)
+
+
+class EIPConfig(baseconfig.JSONLeapConfig):
+ spec = eipspecs.eipconfig_spec
+
+ def _get_slug(self):
+ eipjsonpath = baseconfig.get_config_file(
+ 'eip.json')
+ return eipjsonpath
+
+ def _set_slug(self, *args, **kwargs):
+ raise AttributeError("you cannot set slug")
+
+ slug = property(_get_slug, _set_slug)
+
+
+class EIPServiceConfig(baseconfig.JSONLeapConfig):
+ spec = eipspecs.eipservice_config_spec
+
+ def _get_slug(self):
+ domain = getattr(self, 'domain', None)
+ if domain:
+ path = baseconfig.get_provider_path(domain)
+ else:
+ path = baseconfig.get_default_provider_path()
+ return baseconfig.get_config_file(
+ 'eip-service.json', folder=path)
+
+ def _set_slug(self):
+ raise AttributeError("you cannot set slug")
+
+ slug = property(_get_slug, _set_slug)
+
+
+def get_socket_path():
+ socket_path = os.path.join(
+ tempfile.mkdtemp(prefix="leap-tmp"),
+ 'openvpn.socket')
+ logger.debug('socket path: %s', socket_path)
+ return socket_path
+
+
+def get_eip_gateway(provider=None):
+ """
+ return the first host in eip service config
+ that matches the name defined in the eip.json config
+ file.
+ """
+ placeholder = "testprovider.example.org"
+ # XXX check for null on provider??
+
+ eipconfig = EIPConfig(domain=provider)
+ eipconfig.load()
+ conf = eipconfig.config
+
+ primary_gateway = conf.get('primary_gateway', None)
+ if not primary_gateway:
+ return placeholder
+
+ eipserviceconfig = EIPServiceConfig(domain=provider)
+ eipserviceconfig.load()
+ eipsconf = eipserviceconfig.get_config()
+ gateways = eipsconf.get('gateways', None)
+ if not gateways:
+ logger.error('missing gateways in eip service config')
+ return placeholder
+ if len(gateways) > 0:
+ for gw in gateways:
+ name = gw.get('name', None)
+ if not name:
+ return
+
+ if name == primary_gateway:
+ hosts = gw.get('hosts', None)
+ if not hosts:
+ logger.error('no hosts')
+ return
+ if len(hosts) > 0:
+ return hosts[0]
+ else:
+ logger.error('no hosts')
+ logger.error('could not find primary gateway in provider'
+ 'gateway list')
+
+
+def build_ovpn_options(daemon=False, socket_path=None, **kwargs):
+ """
+ build a list of options
+ to be passed in the
+ openvpn invocation
+ @rtype: list
+ @rparam: options
+ """
+ # XXX review which of the
+ # options we don't need.
+
+ # TODO pass also the config file,
+ # since we will need to take some
+ # things from there if present.
+
+ provider = kwargs.pop('provider', None)
+
+ # get user/group name
+ # also from config.
+ user = baseconfig.get_username()
+ group = baseconfig.get_groupname()
+
+ opts = []
+
+ opts.append('--client')
+
+ opts.append('--dev')
+ # XXX same in win?
+ opts.append('tun')
+ opts.append('--persist-tun')
+ opts.append('--persist-key')
+
+ verbosity = kwargs.get('ovpn_verbosity', None)
+ if verbosity and 1 <= verbosity <= 6:
+ opts.append('--verb')
+ opts.append("%s" % verbosity)
+
+ # remote
+ opts.append('--remote')
+ gw = get_eip_gateway(provider=provider)
+ logger.debug('setting eip gateway to %s', gw)
+ opts.append(str(gw))
+ opts.append('1194')
+ #opts.append('80')
+ opts.append('udp')
+
+ opts.append('--tls-client')
+ opts.append('--remote-cert-tls')
+ opts.append('server')
+
+ # set user and group
+ opts.append('--user')
+ opts.append('%s' % user)
+ opts.append('--group')
+ opts.append('%s' % group)
+
+ opts.append('--management-client-user')
+ opts.append('%s' % user)
+ opts.append('--management-signal')
+
+ # set default options for management
+ # interface. unix sockets or telnet interface for win.
+ # XXX take them from the config object.
+
+ ourplatform = platform.system()
+ if ourplatform in ("Linux", "Mac"):
+ opts.append('--management')
+
+ if socket_path is None:
+ socket_path = get_socket_path()
+ opts.append(socket_path)
+ opts.append('unix')
+
+ if ourplatform == "Windows":
+ opts.append('--management')
+ opts.append('localhost')
+ # XXX which is a good choice?
+ opts.append('7777')
+
+ # certs
+ client_cert_path = eipspecs.client_cert_path(provider)
+ ca_cert_path = eipspecs.provider_ca_path(provider)
+
+ opts.append('--cert')
+ opts.append(client_cert_path)
+ opts.append('--key')
+ opts.append(client_cert_path)
+ opts.append('--ca')
+ opts.append(ca_cert_path)
+
+ # we cannot run in daemon mode
+ # with the current subp setting.
+ # see: https://leap.se/code/issues/383
+ #if daemon is True:
+ #opts.append('--daemon')
+
+ logger.debug('vpn options: %s', opts)
+ return opts
+
+
+def build_ovpn_command(debug=False, do_pkexec_check=True, vpnbin=None,
+ socket_path=None, **kwargs):
+ """
+ build a string with the
+ complete openvpn invocation
+
+ @rtype [string, [list of strings]]
+ @rparam: a list containing the command string
+ and a list of options.
+ """
+ command = []
+ use_pkexec = True
+ ovpn = None
+
+ # XXX get use_pkexec from config instead.
+
+ if platform.system() == "Linux" and use_pkexec and do_pkexec_check:
+
+ # check for both pkexec
+ # AND a suitable authentication
+ # agent running.
+ logger.info('use_pkexec set to True')
+
+ if not is_pkexec_in_system():
+ logger.error('no pkexec in system')
+ raise eip_exceptions.EIPNoPkexecAvailable
+
+ if not is_auth_agent_running():
+ logger.warning(
+ "no polkit auth agent found. "
+ "pkexec will use its own text "
+ "based authentication agent. "
+ "that's probably a bad idea")
+ raise eip_exceptions.EIPNoPolkitAuthAgentAvailable
+
+ command.append('pkexec')
+ if vpnbin is None:
+ ovpn = which('openvpn')
+ else:
+ ovpn = vpnbin
+ if ovpn:
+ vpn_command = ovpn
+ else:
+ vpn_command = "openvpn"
+ command.append(vpn_command)
+ daemon_mode = not debug
+
+ for opt in build_ovpn_options(daemon=daemon_mode, socket_path=socket_path,
+ **kwargs):
+ command.append(opt)
+
+ # XXX check len and raise proper error
+
+ return [command[0], command[1:]]
+
+
+def check_vpn_keys(provider=None):
+ """
+ performs an existance and permission check
+ over the openvpn keys file.
+ Currently we're expecting a single file
+ per provider, containing the CA cert,
+ the provider key, and our client certificate
+ """
+ assert provider is not None
+ provider_ca = eipspecs.provider_ca_path(provider)
+ client_cert = eipspecs.client_cert_path(provider)
+
+ logger.debug('provider ca = %s', provider_ca)
+ logger.debug('client cert = %s', client_cert)
+
+ # if no keys, raise error.
+ # it's catched by the ui and signal user.
+
+ if not os.path.isfile(provider_ca):
+ # not there. let's try to copy.
+ folder, filename = os.path.split(provider_ca)
+ if not os.path.isdir(folder):
+ mkdir_p(folder)
+ if provider_ca_file:
+ cacert = certs.where(provider_ca_file)
+ with open(provider_ca, 'w') as pca:
+ with open(cacert, 'r') as cac:
+ pca.write(cac.read())
+
+ if not os.path.isfile(provider_ca):
+ logger.error('key file %s not found. aborting.',
+ provider_ca)
+ raise eip_exceptions.EIPInitNoKeyFileError
+
+ if not os.path.isfile(client_cert):
+ logger.error('key file %s not found. aborting.',
+ client_cert)
+ raise eip_exceptions.EIPInitNoKeyFileError
+
+ for keyfile in (provider_ca, client_cert):
+ # bad perms? try to fix them
+ try:
+ check_and_fix_urw_only(keyfile)
+ except OSError:
+ raise eip_exceptions.EIPInitBadKeyFilePermError
diff --git a/src/leap/eip/constants.py b/src/leap/eip/constants.py
new file mode 100644
index 00000000..9af5a947
--- /dev/null
+++ b/src/leap/eip/constants.py
@@ -0,0 +1,3 @@
+# not used anymore with the new JSONConfig.slug
+EIP_CONFIG = "eip.json"
+EIP_SERVICE_EXPECTED_PATH = "1/config/eip-service.json"
diff --git a/src/leap/eip/eipconnection.py b/src/leap/eip/eipconnection.py
new file mode 100644
index 00000000..7828c864
--- /dev/null
+++ b/src/leap/eip/eipconnection.py
@@ -0,0 +1,350 @@
+"""
+EIP Connection Class
+"""
+from __future__ import (absolute_import,)
+import logging
+import Queue
+import sys
+
+from leap.eip.checks import ProviderCertChecker
+from leap.eip.checks import EIPConfigChecker
+from leap.eip import config as eipconfig
+from leap.eip import exceptions as eip_exceptions
+from leap.eip.openvpnconnection import OpenVPNConnection
+
+logger = logging.getLogger(name=__name__)
+
+
+class EIPConnection(OpenVPNConnection):
+ """
+ Manages the execution of the OpenVPN process, auto starts, monitors the
+ network connection, handles configuration, fixes leaky hosts, handles
+ errors, etc.
+ Status updates (connected, bandwidth, etc) are signaled to the GUI.
+ """
+
+ def __init__(self,
+ provider_cert_checker=ProviderCertChecker,
+ config_checker=EIPConfigChecker,
+ *args, **kwargs):
+ self.settingsfile = kwargs.get('settingsfile', None)
+ self.logfile = kwargs.get('logfile', None)
+ self.provider = kwargs.pop('provider', None)
+ self._providercertchecker = provider_cert_checker
+ self._configchecker = config_checker
+
+ self.error_queue = Queue.Queue()
+
+ status_signals = kwargs.pop('status_signals', None)
+ self.status = EIPConnectionStatus(callbacks=status_signals)
+
+ checker_signals = kwargs.pop('checker_signals', None)
+ self.checker_signals = checker_signals
+
+ self.init_checkers()
+
+ host = eipconfig.get_socket_path()
+ kwargs['host'] = host
+
+ super(EIPConnection, self).__init__(*args, **kwargs)
+
+ def has_errors(self):
+ return True if self.error_queue.qsize() != 0 else False
+
+ def init_checkers(self):
+ # initialize checkers
+ self.provider_cert_checker = self._providercertchecker(
+ domain=self.provider)
+ self.config_checker = self._configchecker(domain=self.provider)
+
+ def set_provider_domain(self, domain):
+ """
+ sets the provider domain.
+ used from the first run wizard when we launch the run_checks
+ and connect process after having initialized the conductor.
+ """
+ # This looks convoluted, right.
+ # We have to reinstantiate checkers cause we're passing
+ # the domain param that we did not know at the beginning
+ # (only for the firstrunwizard case)
+ self.provider = domain
+ self.init_checkers()
+
+ def run_checks(self, skip_download=False, skip_verify=False):
+ """
+ run all eip checks previous to attempting a connection
+ """
+ logger.debug('running conductor checks')
+
+ def push_err(exc):
+ # keep the original traceback!
+ exc_traceback = sys.exc_info()[2]
+ self.error_queue.put((exc, exc_traceback))
+
+ try:
+ # network (1)
+ if self.checker_signals:
+ for signal in self.checker_signals:
+ signal('checking encryption keys')
+ self.provider_cert_checker.run_all(skip_verify=skip_verify)
+ except Exception as exc:
+ push_err(exc)
+ try:
+ if self.checker_signals:
+ for signal in self.checker_signals:
+ signal('checking provider config')
+ self.config_checker.run_all(skip_download=skip_download)
+ except Exception as exc:
+ push_err(exc)
+ try:
+ self.run_openvpn_checks()
+ except Exception as exc:
+ push_err(exc)
+
+ def connect(self):
+ """
+ entry point for connection process
+ """
+ #self.forget_errors()
+ self._try_connection()
+
+ def disconnect(self):
+ """
+ disconnects client
+ """
+ self.cleanup()
+ logger.debug("disconnect: clicked.")
+ self.status.change_to(self.status.DISCONNECTED)
+
+ #def shutdown(self):
+ #"""
+ #shutdown and quit
+ #"""
+ #self.desired_con_state = self.status.DISCONNECTED
+
+ def connection_state(self):
+ """
+ returns the current connection state
+ """
+ return self.status.current
+
+ def poll_connection_state(self):
+ """
+ """
+ try:
+ state = self.get_connection_state()
+ except eip_exceptions.ConnectionRefusedError:
+ # connection refused. might be not ready yet.
+ logger.warning('connection refused')
+ return
+ if not state:
+ logger.debug('no state')
+ return
+ (ts, status_step,
+ ok, ip, remote) = state
+ self.status.set_vpn_state(status_step)
+ status_step = self.status.get_readable_status()
+ return (ts, status_step, ok, ip, remote)
+
+ def get_icon_name(self):
+ """
+ get icon name from status object
+ """
+ return self.status.get_state_icon()
+
+ def get_leap_status(self):
+ return self.status.get_leap_status()
+
+ #
+ # private methods
+ #
+
+ #def _disconnect(self):
+ # """
+ # private method for disconnecting
+ # """
+ # if self.subp is not None:
+ # logger.debug('disconnecting...')
+ # self.subp.terminate()
+ # self.subp = None
+
+ #def _is_alive(self):
+ #"""
+ #don't know yet
+ #"""
+ #pass
+
+ def _connect(self):
+ """
+ entry point for connection cascade methods.
+ """
+ try:
+ conn_result = self._try_connection()
+ except eip_exceptions.UnrecoverableError as except_msg:
+ logger.error("FATAL: %s" % unicode(except_msg))
+ conn_result = self.status.UNRECOVERABLE
+
+ # XXX enqueue exceptions themselves instead?
+ except Exception as except_msg:
+ self.error_queue.append(except_msg)
+ logger.error("Failed Connection: %s" %
+ unicode(except_msg))
+ return conn_result
+
+
+class EIPConnectionStatus(object):
+ """
+ Keep track of client (gui) and openvpn
+ states.
+
+ These are the OpenVPN states:
+ CONNECTING -- OpenVPN's initial state.
+ WAIT -- (Client only) Waiting for initial response
+ from server.
+ AUTH -- (Client only) Authenticating with server.
+ GET_CONFIG -- (Client only) Downloading configuration options
+ from server.
+ ASSIGN_IP -- Assigning IP address to virtual network
+ interface.
+ ADD_ROUTES -- Adding routes to system.
+ CONNECTED -- Initialization Sequence Completed.
+ RECONNECTING -- A restart has occurred.
+ EXITING -- A graceful exit is in progress.
+
+ We add some extra states:
+
+ DISCONNECTED -- GUI initial state.
+ UNRECOVERABLE -- An unrecoverable error has been raised
+ while invoking openvpn service.
+ """
+ CONNECTING = 1
+ WAIT = 2
+ AUTH = 3
+ GET_CONFIG = 4
+ ASSIGN_IP = 5
+ ADD_ROUTES = 6
+ CONNECTED = 7
+ RECONNECTING = 8
+ EXITING = 9
+
+ # gui specific states:
+ UNRECOVERABLE = 11
+ DISCONNECTED = 0
+
+ def __init__(self, callbacks=None):
+ """
+ EIPConnectionStatus is initialized with a tuple
+ of signals to be triggered.
+ :param callbacks: a tuple of (callable) observers
+ :type callbacks: tuple
+ """
+ self.current = self.DISCONNECTED
+ self.previous = None
+ # (callbacks to connect to signals in Qt-land)
+ self.callbacks = callbacks
+
+ def get_readable_status(self):
+ # XXX DRY status / labels a little bit.
+ # think we'll want to i18n this.
+ human_status = {
+ 0: 'disconnected',
+ 1: 'connecting',
+ 2: 'waiting',
+ 3: 'authenticating',
+ 4: 'getting config',
+ 5: 'assigning ip',
+ 6: 'adding routes',
+ 7: 'connected',
+ 8: 'reconnecting',
+ 9: 'exiting',
+ 11: 'unrecoverable error',
+ }
+ return human_status[self.current]
+
+ def get_leap_status(self):
+ # XXX improve nomenclature
+ leap_status = {
+ 0: 'disconnected',
+ 1: 'connecting to gateway',
+ 2: 'connecting to gateway',
+ 3: 'authenticating',
+ 4: 'establishing network encryption',
+ 5: 'establishing network encryption',
+ 6: 'establishing network encryption',
+ 7: 'connected',
+ 8: 'reconnecting',
+ 9: 'exiting',
+ 11: 'unrecoverable error',
+ }
+ return leap_status[self.current]
+
+ def get_state_icon(self):
+ """
+ returns the high level icon
+ for each fine-grain openvpn state
+ """
+ connecting = (self.CONNECTING,
+ self.WAIT,
+ self.AUTH,
+ self.GET_CONFIG,
+ self.ASSIGN_IP,
+ self.ADD_ROUTES)
+ connected = (self.CONNECTED,)
+ disconnected = (self.DISCONNECTED,
+ self.UNRECOVERABLE)
+
+ # this can be made smarter,
+ # but it's like it'll change,
+ # so +readability.
+
+ if self.current in connecting:
+ return "connecting"
+ if self.current in connected:
+ return "connected"
+ if self.current in disconnected:
+ return "disconnected"
+
+ def set_vpn_state(self, status):
+ """
+ accepts a state string from the management
+ interface, and sets the internal state.
+ :param status: openvpn STATE (uppercase).
+ :type status: str
+ """
+ if hasattr(self, status):
+ self.change_to(getattr(self, status))
+
+ def set_current(self, to):
+ """
+ setter for the 'current' property
+ :param to: destination state
+ :type to: int
+ """
+ self.current = to
+
+ def change_to(self, to):
+ """
+ :param to: destination state
+ :type to: int
+ """
+ if to == self.current:
+ return
+ changed = False
+ from_ = self.current
+ self.current = to
+
+ # We can add transition restrictions
+ # here to ensure no transitions are
+ # allowed outside the fsm.
+
+ self.set_current(to)
+ changed = True
+
+ #trigger signals (as callbacks)
+ #print('current state: %s' % self.current)
+ if changed:
+ self.previous = from_
+ if self.callbacks:
+ for cb in self.callbacks:
+ if callable(cb):
+ cb(self)
diff --git a/src/leap/eip/exceptions.py b/src/leap/eip/exceptions.py
new file mode 100644
index 00000000..41eed77a
--- /dev/null
+++ b/src/leap/eip/exceptions.py
@@ -0,0 +1,156 @@
+"""
+Generic error hierarchy
+Leap/EIP exceptions used for exception handling,
+logging, and notifying user of errors
+during leap operation.
+
+Exception hierarchy
+-------------------
+All EIP Errors must inherit from EIPClientError (note: move that to
+a more generic LEAPClientBaseError).
+
+Exception attributes and their meaning/uses
+-------------------------------------------
+
+* critical: if True, will abort execution prematurely,
+ after attempting any cleaning
+ action.
+
+* failfirst: breaks any error_check loop that is examining
+ the error queue.
+
+* message: the message that will be used in the __repr__ of the exception.
+
+* usermessage: the message that will be passed to user in ErrorDialogs
+ in Qt-land.
+
+TODO:
+
+* EIPClientError:
+ Should inherit from LeapException
+
+* gettext / i18n for user messages.
+
+"""
+from leap.base.exceptions import LeapException
+
+
+# This should inherit from LeapException
+class EIPClientError(Exception):
+ """
+ base EIPClient exception
+ """
+ critical = False
+ failfirst = False
+ warning = False
+
+
+class CriticalError(EIPClientError):
+ """
+ we cannot do anything about it, sorry
+ """
+ critical = True
+ failfirst = True
+
+
+class Warning(EIPClientError):
+ """
+ just that, warnings
+ """
+ warning = True
+
+
+class EIPNoPolkitAuthAgentAvailable(CriticalError):
+ message = "No polkit authentication agent could be found"
+ usermessage = ("We could not find any authentication "
+ "agent in your system.<br/>"
+ "Make sure you have "
+ "<b>polkit-gnome-authentication-agent-1</b> "
+ "running and try again.")
+
+
+class EIPNoPkexecAvailable(Warning):
+ message = "No pkexec binary found"
+ usermessage = ("We could not find <b>pkexec</b> in your "
+ "system.<br/> Do you want to try "
+ "<b>setuid workaround</b>? "
+ "(<i>DOES NOTHING YET</i>)")
+ failfirst = True
+
+
+class EIPNoCommandError(EIPClientError):
+ message = "no suitable openvpn command found"
+ usermessage = ("No suitable openvpn command found. "
+ "<br/>(Might be a permissions problem)")
+
+
+class EIPBadCertError(Warning):
+ # XXX this should be critical and fail close
+ message = "cert verification failed"
+ usermessage = "there is a problem with provider certificate"
+
+
+class LeapBadConfigFetchedError(Warning):
+ message = "provider sent a malformed json file"
+ usermessage = "an error occurred during configuratio of leap services"
+
+
+class OpenVPNAlreadyRunning(EIPClientError):
+ message = "Another OpenVPN Process is already running."
+ usermessage = ("Another OpenVPN Process has been detected."
+ "Please close it before starting leap-client")
+
+
+class HttpsNotSupported(LeapException):
+ message = "connection refused while accessing via https"
+ usermessage = "Server does not allow secure connections."
+
+
+class HttpsBadCertError(LeapException):
+ message = "verification error on cert"
+ usermessage = "Server certificate could not be verified."
+
+#
+# errors still needing some love
+#
+
+
+class EIPInitNoKeyFileError(CriticalError):
+ message = "No vpn keys found in the expected path"
+ usermessage = "We could not find your eip certs in the expected path"
+
+
+class EIPInitBadKeyFilePermError(Warning):
+ # I don't know if we should be telling user or not,
+ # we try to fix permissions and should only re-raise
+ # if permission check failed.
+ pass
+
+
+class EIPInitNoProviderError(EIPClientError):
+ pass
+
+
+class EIPInitBadProviderError(EIPClientError):
+ pass
+
+
+class EIPConfigurationError(EIPClientError):
+ pass
+
+#
+# Errors that probably we don't need anymore
+# chase down for them and check.
+#
+
+
+class MissingSocketError(Exception):
+ pass
+
+
+class ConnectionRefusedError(Exception):
+ pass
+
+
+class EIPMissingDefaultProvider(Exception):
+ pass
diff --git a/src/leap/eip/openvpnconnection.py b/src/leap/eip/openvpnconnection.py
new file mode 100644
index 00000000..859378c0
--- /dev/null
+++ b/src/leap/eip/openvpnconnection.py
@@ -0,0 +1,460 @@
+"""
+OpenVPN Connection
+"""
+from __future__ import (print_function)
+import logging
+import os
+import psutil
+import shutil
+import socket
+import time
+from functools import partial
+
+logger = logging.getLogger(name=__name__)
+
+from leap.base.connection import Connection
+from leap.util.coroutines import spawn_and_watch_process
+
+from leap.eip.udstelnet import UDSTelnet
+from leap.eip import config as eip_config
+from leap.eip import exceptions as eip_exceptions
+
+
+class OpenVPNConnection(Connection):
+ """
+ All related to invocation
+ of the openvpn binary
+ """
+
+ def __init__(self,
+ watcher_cb=None,
+ debug=False,
+ host=None,
+ port="unix",
+ password=None,
+ *args, **kwargs):
+ """
+ :param config_file: configuration file to read from
+ :param watcher_cb: callback to be \
+called for each line in watched stdout
+ :param signal_map: dictionary of signal names and callables \
+to be triggered for each one of them.
+ :type config_file: str
+ :type watcher_cb: function
+ :type signal_map: dict
+ """
+ #XXX FIXME
+ #change watcher_cb to line_observer
+
+ logger.debug('init openvpn connection')
+ self.debug = debug
+ # XXX if not host: raise ImproperlyConfigured
+ self.ovpn_verbosity = kwargs.get('ovpn_verbosity', None)
+
+ #self.config_file = config_file
+ self.watcher_cb = watcher_cb
+ #self.signal_maps = signal_maps
+
+ self.subp = None
+ self.watcher = None
+
+ self.server = None
+ self.port = None
+ self.proto = None
+
+ #XXX workaround for signaling
+ #the ui that we don't know how to
+ #manage a connection error
+ #self.with_errors = False
+
+ self.command = None
+ self.args = None
+
+ # XXX get autostart from config
+ self.autostart = True
+
+ #
+ # management init methods
+ #
+
+ self.host = host
+ if isinstance(port, str) and port.isdigit():
+ port = int(port)
+ elif port == "unix":
+ port = "unix"
+ else:
+ port = None
+ self.port = port
+ self.password = password
+
+ def run_openvpn_checks(self):
+ logger.debug('running openvpn checks')
+ self._check_if_running_instance()
+ self._set_ovpn_command()
+ self._check_vpn_keys()
+
+ def _set_ovpn_command(self):
+ # XXX check also for command-line --command flag
+ try:
+ command, args = eip_config.build_ovpn_command(
+ provider=self.provider,
+ debug=self.debug,
+ socket_path=self.host,
+ ovpn_verbosity=self.ovpn_verbosity)
+ except eip_exceptions.EIPNoPolkitAuthAgentAvailable:
+ command = args = None
+ raise
+ except eip_exceptions.EIPNoPkexecAvailable:
+ command = args = None
+ raise
+
+ # XXX if not command, signal error.
+ self.command = command
+ self.args = args
+
+ def _check_vpn_keys(self):
+ """
+ checks for correct permissions on vpn keys
+ """
+ try:
+ eip_config.check_vpn_keys(provider=self.provider)
+ except eip_exceptions.EIPInitBadKeyFilePermError:
+ logger.error('Bad VPN Keys permission!')
+ # do nothing now
+ # and raise the rest ...
+
+ def _launch_openvpn(self):
+ """
+ invocation of openvpn binaries in a subprocess.
+ """
+ #XXX TODO:
+ #deprecate watcher_cb,
+ #use _only_ signal_maps instead
+
+ logger.debug('_launch_openvpn called')
+ if self.watcher_cb is not None:
+ linewrite_callback = self.watcher_cb
+ else:
+ #XXX get logger instead
+ linewrite_callback = lambda line: print('watcher: %s' % line)
+
+ # the partial is not
+ # being applied now because we're not observing the process
+ # stdout like we did in the early stages. but I leave it
+ # here since it will be handy for observing patterns in the
+ # thru-the-manager updates (with regex)
+ observers = (linewrite_callback,
+ partial(lambda con_status, line: None, self.status))
+ subp, watcher = spawn_and_watch_process(
+ self.command,
+ self.args,
+ observers=observers)
+ self.subp = subp
+ self.watcher = watcher
+
+ def _try_connection(self):
+ """
+ attempts to connect
+ """
+ if self.command is None:
+ raise eip_exceptions.EIPNoCommandError
+ if self.subp is not None:
+ logger.debug('cowardly refusing to launch subprocess again')
+
+ self._launch_openvpn()
+
+ def _check_if_running_instance(self):
+ """
+ check if openvpn is already running
+ """
+ for process in psutil.get_process_list():
+ if process.name == "openvpn":
+ logger.debug('an openvpn instance is already running.')
+ logger.debug('attempting to stop openvpn instance.')
+ if not self._stop():
+ raise eip_exceptions.OpenVPNAlreadyRunning
+
+ logger.debug('no openvpn instance found.')
+
+ def cleanup(self):
+ """
+ terminates openvpn child subprocess
+ """
+ if self.subp:
+ try:
+ self._stop()
+ except eip_exceptions.ConnectionRefusedError:
+ logger.warning(
+ 'unable to send sigterm signal to openvpn: '
+ 'connection refused.')
+
+ # XXX kali --
+ # XXX review-me
+ # I think this will block if child process
+ # does not return.
+ # Maybe we can .poll() for a given
+ # interval and exit in any case.
+
+ RETCODE = self.subp.wait()
+ if RETCODE:
+ logger.error(
+ 'cannot terminate subprocess! Retcode %s'
+ '(We might have left openvpn running)' % RETCODE)
+
+ self.cleanup_tempfiles()
+
+ def cleanup_tempfiles(self):
+ """
+ remove all temporal files
+ we might have left behind
+ """
+ # if self.port is 'unix', we have
+ # created a temporal socket path that, under
+ # normal circumstances, we should be able to
+ # delete
+
+ if self.port == "unix":
+ logger.debug('cleaning socket file temp folder')
+
+ tempfolder = os.path.split(self.host)[0]
+ if os.path.isdir(tempfolder):
+ try:
+ shutil.rmtree(tempfolder)
+ except OSError:
+ logger.error('could not delete tmpfolder %s' % tempfolder)
+
+ def _get_openvpn_process(self):
+ # plist = [p for p in psutil.get_process_list() if p.name == "openvpn"]
+ # return plist[0] if plist else None
+ for process in psutil.get_process_list():
+ if process.name == "openvpn":
+ return process
+ return None
+
+ # management methods
+ #
+ # XXX REVIEW-ME
+ # REFACTOR INFO: (former "manager".
+ # Can we move to another
+ # base class to test independently?)
+ #
+
+ #def forget_errors(self):
+ #logger.debug('forgetting errors')
+ #self.with_errors = False
+
+ def connect_to_management(self):
+ """Connect to openvpn management interface"""
+ #logger.debug('connecting socket')
+ if hasattr(self, 'tn'):
+ self.close()
+ self.tn = UDSTelnet(self.host, self.port)
+
+ # XXX make password optional
+ # specially for win. we should generate
+ # the pass on the fly when invoking manager
+ # from conductor
+
+ #self.tn.read_until('ENTER PASSWORD:', 2)
+ #self.tn.write(self.password + '\n')
+ #self.tn.read_until('SUCCESS:', 2)
+ if self.tn:
+ self._seek_to_eof()
+ return True
+
+ def _seek_to_eof(self):
+ """
+ Read as much as available. Position seek pointer to end of stream
+ """
+ try:
+ b = self.tn.read_eager()
+ except EOFError:
+ logger.debug("Could not read from socket. Assuming it died.")
+ return
+ while b:
+ try:
+ b = self.tn.read_eager()
+ except EOFError:
+ logger.debug("Could not read from socket. Assuming it died.")
+
+ def connected(self):
+ """
+ Returns True if connected
+ rtype: bool
+ """
+ return hasattr(self, 'tn')
+
+ def close(self, announce=True):
+ """
+ Close connection to openvpn management interface
+ """
+ logger.debug('closing socket')
+ if announce:
+ self.tn.write("quit\n")
+ self.tn.read_all()
+ self.tn.get_socket().close()
+ del self.tn
+
+ def _send_command(self, cmd):
+ """
+ Send a command to openvpn and return response as list
+ """
+ if not self.connected():
+ try:
+ self.connect_to_management()
+ except eip_exceptions.MissingSocketError:
+ logger.warning('missing management socket')
+ return []
+ try:
+ if hasattr(self, 'tn'):
+ self.tn.write(cmd + "\n")
+ except socket.error:
+ logger.error('socket error')
+ self.close(announce=False)
+ return []
+ buf = self.tn.read_until(b"END", 2)
+ self._seek_to_eof()
+ blist = buf.split('\r\n')
+ if blist[-1].startswith('END'):
+ del blist[-1]
+ return blist
+ else:
+ return []
+
+ def _send_short_command(self, cmd):
+ """
+ parse output from commands that are
+ delimited by "success" instead
+ """
+ if not self.connected():
+ self.connect()
+ self.tn.write(cmd + "\n")
+ # XXX not working?
+ buf = self.tn.read_until(b"SUCCESS", 2)
+ self._seek_to_eof()
+ blist = buf.split('\r\n')
+ return blist
+
+ #
+ # useful vpn commands
+ #
+
+ def pid(self):
+ #XXX broken
+ return self._send_short_command("pid")
+
+ def make_error(self):
+ """
+ capture error and wrap it in an
+ understandable format
+ """
+ #XXX get helpful error codes
+ self.with_errors = True
+ now = int(time.time())
+ return '%s,LAUNCHER ERROR,ERROR,-,-' % now
+
+ def state(self):
+ """
+ OpenVPN command: state
+ """
+ state = self._send_command("state")
+ if not state:
+ return None
+ if isinstance(state, str):
+ return state
+ if isinstance(state, list):
+ if len(state) == 1:
+ return state[0]
+ else:
+ return state[-1]
+
+ def vpn_status(self):
+ """
+ OpenVPN command: status
+ """
+ #logger.debug('status called')
+ status = self._send_command("status")
+ return status
+
+ def vpn_status2(self):
+ """
+ OpenVPN command: last 2 statuses
+ """
+ return self._send_command("status 2")
+
+ def _stop(self):
+ """
+ stop openvpn process
+ by sending SIGTERM to the management
+ interface
+ """
+ logger.debug("disconnecting...")
+ if self.connected():
+ try:
+ self._send_command("signal SIGTERM\n")
+ except socket.error:
+ logger.warning('management socket died')
+ return
+
+ if self.subp:
+ # ???
+ return True
+
+ #shutting openvpn failured
+ #try patching in old openvpn host and trying again
+ process = self._get_openvpn_process()
+ if process:
+ logger.debug('process :%s' % process)
+ cmdline = process.cmdline
+
+ if isinstance(cmdline, list):
+ _index = cmdline.index("--management")
+ self.host = cmdline[_index + 1]
+ self._send_command("signal SIGTERM\n")
+
+ #make sure the process was terminated
+ process = self._get_openvpn_process()
+ if not process:
+ logger.debug("Existing OpenVPN Process Terminated")
+ return True
+ else:
+ logger.error("Unable to terminate existing OpenVPN Process.")
+ return False
+
+ return True
+
+ #
+ # parse info
+ #
+
+ def get_status_io(self):
+ status = self.vpn_status()
+ if isinstance(status, str):
+ lines = status.split('\n')
+ if isinstance(status, list):
+ lines = status
+ try:
+ (header, when, tun_read, tun_write,
+ tcp_read, tcp_write, auth_read) = tuple(lines)
+ except ValueError:
+ return None
+
+ when_ts = time.strptime(when.split(',')[1], "%a %b %d %H:%M:%S %Y")
+ sep = ','
+ # XXX cleanup!
+ tun_read = tun_read.split(sep)[1]
+ tun_write = tun_write.split(sep)[1]
+ tcp_read = tcp_read.split(sep)[1]
+ tcp_write = tcp_write.split(sep)[1]
+ auth_read = auth_read.split(sep)[1]
+
+ # XXX this could be a named tuple. prettier.
+ return when_ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read)
+
+ def get_connection_state(self):
+ state = self.state()
+ if state is not None:
+ ts, status_step, ok, ip, remote = state.split(',')
+ ts = time.gmtime(float(ts))
+ # XXX this could be a named tuple. prettier.
+ return ts, status_step, ok, ip, remote
diff --git a/src/leap/eip/specs.py b/src/leap/eip/specs.py
new file mode 100644
index 00000000..57e7537b
--- /dev/null
+++ b/src/leap/eip/specs.py
@@ -0,0 +1,124 @@
+from __future__ import (unicode_literals)
+import os
+
+from leap import __branding
+from leap.base import config as baseconfig
+
+# XXX move provider stuff to base config
+
+PROVIDER_CA_CERT = __branding.get(
+ 'provider_ca_file',
+ 'cacert.pem')
+
+provider_ca_path = lambda domain: str(os.path.join(
+ #baseconfig.get_default_provider_path(),
+ baseconfig.get_provider_path(domain),
+ 'keys', 'ca',
+ 'cacert.pem'
+)) if domain else None
+
+default_provider_ca_path = lambda: str(os.path.join(
+ baseconfig.get_default_provider_path(),
+ 'keys', 'ca',
+ PROVIDER_CA_CERT
+))
+
+PROVIDER_DOMAIN = __branding.get('provider_domain', 'testprovider.example.org')
+
+
+client_cert_path = lambda domain: unicode(os.path.join(
+ baseconfig.get_provider_path(domain),
+ 'keys', 'client',
+ 'openvpn.pem'
+)) if domain else None
+
+default_client_cert_path = lambda: unicode(os.path.join(
+ baseconfig.get_default_provider_path(),
+ 'keys', 'client',
+ 'openvpn.pem'
+))
+
+eipconfig_spec = {
+ 'description': 'sample eipconfig',
+ 'type': 'object',
+ 'properties': {
+ 'provider': {
+ 'type': unicode,
+ 'default': u"%s" % PROVIDER_DOMAIN,
+ 'required': True,
+ },
+ 'transport': {
+ 'type': unicode,
+ 'default': u"openvpn",
+ },
+ 'openvpn_protocol': {
+ 'type': unicode,
+ 'default': u"tcp"
+ },
+ 'openvpn_port': {
+ 'type': int,
+ 'default': 80
+ },
+ 'openvpn_ca_certificate': {
+ 'type': unicode, # path
+ 'default': default_provider_ca_path
+ },
+ 'openvpn_client_certificate': {
+ 'type': unicode, # path
+ 'default': default_client_cert_path
+ },
+ 'connect_on_login': {
+ 'type': bool,
+ 'default': True
+ },
+ 'block_cleartext_traffic': {
+ 'type': bool,
+ 'default': True
+ },
+ 'primary_gateway': {
+ 'type': unicode,
+ 'default': u"turkey",
+ #'required': True
+ },
+ 'secondary_gateway': {
+ 'type': unicode,
+ 'default': u"france"
+ },
+ 'management_password': {
+ 'type': unicode
+ }
+ }
+}
+
+eipservice_config_spec = {
+ 'description': 'sample eip service config',
+ 'type': 'object',
+ 'properties': {
+ 'serial': {
+ 'type': int,
+ 'required': True,
+ 'default': 1
+ },
+ 'version': {
+ 'type': unicode,
+ 'required': True,
+ 'default': "0.1.0"
+ },
+ 'capabilities': {
+ 'type': dict,
+ 'default': {
+ "transport": ["openvpn"],
+ "ports": ["80", "53"],
+ "protocols": ["udp", "tcp"],
+ "static_ips": True,
+ "adblock": True}
+ },
+ 'gateways': {
+ 'type': list,
+ 'default': [{"country_code": "us",
+ "label": {"en":"west"},
+ "capabilities": {},
+ "hosts": ["1.2.3.4", "1.2.3.5"]}]
+ }
+ }
+}
diff --git a/src/leap/eip/tests/__init__.py b/src/leap/eip/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/eip/tests/__init__.py
diff --git a/src/leap/eip/tests/data.py b/src/leap/eip/tests/data.py
new file mode 100644
index 00000000..cadf720e
--- /dev/null
+++ b/src/leap/eip/tests/data.py
@@ -0,0 +1,48 @@
+from __future__ import unicode_literals
+import os
+
+#from leap import __branding
+
+# sample data used in tests
+
+#PROVIDER = __branding.get('provider_domain')
+PROVIDER = "testprovider.example.org"
+
+EIP_SAMPLE_CONFIG = {
+ "provider": "%s" % PROVIDER,
+ "transport": "openvpn",
+ "openvpn_protocol": "tcp",
+ "openvpn_port": 80,
+ "openvpn_ca_certificate": os.path.expanduser(
+ "~/.config/leap/providers/"
+ "%s/"
+ "keys/ca/cacert.pem" % PROVIDER),
+ "openvpn_client_certificate": os.path.expanduser(
+ "~/.config/leap/providers/"
+ "%s/"
+ "keys/client/openvpn.pem" % PROVIDER),
+ "connect_on_login": True,
+ "block_cleartext_traffic": True,
+ "primary_gateway": "turkey",
+ "secondary_gateway": "france",
+ #"management_password": "oph7Que1othahwiech6J"
+}
+
+EIP_SAMPLE_SERVICE = {
+ "serial": 1,
+ "version": "0.1.0",
+ "capabilities": {
+ "transport": ["openvpn"],
+ "ports": ["80", "53"],
+ "protocols": ["udp", "tcp"],
+ "static_ips": True,
+ "adblock": True
+ },
+ "gateways": [
+ {"country_code": "tr",
+ "name": "turkey",
+ "label": {"en":"Ankara, Turkey"},
+ "capabilities": {},
+ "hosts": ["192.0.43.10"]}
+ ]
+}
diff --git a/src/leap/eip/tests/test_checks.py b/src/leap/eip/tests/test_checks.py
new file mode 100644
index 00000000..1d7bfc17
--- /dev/null
+++ b/src/leap/eip/tests/test_checks.py
@@ -0,0 +1,367 @@
+from BaseHTTPServer import BaseHTTPRequestHandler
+import copy
+import json
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+import os
+import time
+import urlparse
+
+from mock import (patch, Mock)
+
+import jsonschema
+#import ping
+import requests
+
+from leap.base import config as baseconfig
+from leap.base.constants import (DEFAULT_PROVIDER_DEFINITION,
+ DEFINITION_EXPECTED_PATH)
+from leap.eip import checks as eipchecks
+from leap.eip import specs as eipspecs
+from leap.eip import exceptions as eipexceptions
+from leap.eip.tests import data as testdata
+from leap.testing.basetest import BaseLeapTest
+from leap.testing.https_server import BaseHTTPSServerTestCase
+from leap.testing.https_server import where as where_cert
+
+
+class NoLogRequestHandler:
+ def log_message(self, *args):
+ # don't write log msg to stderr
+ pass
+
+ def read(self, n=None):
+ return ''
+
+
+class EIPCheckTest(BaseLeapTest):
+
+ __name__ = "eip_check_tests"
+ provider = "testprovider.example.org"
+ maxDiff = None
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ # test methods are there, and can be called from run_all
+
+ def test_checker_should_implement_check_methods(self):
+ checker = eipchecks.EIPConfigChecker(domain=self.provider)
+
+ self.assertTrue(hasattr(checker, "check_default_eipconfig"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "check_is_there_default_provider"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "fetch_definition"), "missing meth")
+ self.assertTrue(hasattr(checker, "fetch_eip_service_config"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "check_complete_eip_config"),
+ "missing meth")
+
+ def test_checker_should_actually_call_all_tests(self):
+ checker = eipchecks.EIPConfigChecker(domain=self.provider)
+
+ mc = Mock()
+ checker.run_all(checker=mc)
+ self.assertTrue(mc.check_default_eipconfig.called, "not called")
+ self.assertTrue(mc.check_is_there_default_provider.called,
+ "not called")
+ self.assertTrue(mc.fetch_definition.called,
+ "not called")
+ self.assertTrue(mc.fetch_eip_service_config.called,
+ "not called")
+ self.assertTrue(mc.check_complete_eip_config.called,
+ "not called")
+
+ # test individual check methods
+
+ def test_check_default_eipconfig(self):
+ checker = eipchecks.EIPConfigChecker(domain=self.provider)
+ # no eip config (empty home)
+ eipconfig_path = checker.eipconfig.filename
+ self.assertFalse(os.path.isfile(eipconfig_path))
+ checker.check_default_eipconfig()
+ # we've written one, so it should be there.
+ self.assertTrue(os.path.isfile(eipconfig_path))
+ with open(eipconfig_path, 'rb') as fp:
+ deserialized = json.load(fp)
+
+ # force re-evaluation of the paths
+ # small workaround for evaluating home dirs correctly
+ EIP_SAMPLE_CONFIG = copy.copy(testdata.EIP_SAMPLE_CONFIG)
+ EIP_SAMPLE_CONFIG['openvpn_client_certificate'] = \
+ eipspecs.client_cert_path(self.provider)
+ EIP_SAMPLE_CONFIG['openvpn_ca_certificate'] = \
+ eipspecs.provider_ca_path(self.provider)
+ self.assertEqual(deserialized, EIP_SAMPLE_CONFIG)
+
+ # TODO: shold ALSO run validation methods.
+
+ def test_check_is_there_default_provider(self):
+ checker = eipchecks.EIPConfigChecker(domain=self.provider)
+ # we do dump a sample eip config, but lacking a
+ # default provider entry.
+ # This error will be possible catched in a different
+ # place, when JSONConfig does validation of required fields.
+
+ # passing direct config
+ with self.assertRaises(eipexceptions.EIPMissingDefaultProvider):
+ checker.check_is_there_default_provider(config={})
+
+ # ok. now, messing with real files...
+ # blank out default_provider
+ sampleconfig = copy.copy(testdata.EIP_SAMPLE_CONFIG)
+ sampleconfig['provider'] = None
+ eipcfg_path = checker.eipconfig.filename
+ with open(eipcfg_path, 'w') as fp:
+ json.dump(sampleconfig, fp)
+ #with self.assertRaises(eipexceptions.EIPMissingDefaultProvider):
+ # XXX we should catch this as one of our errors, but do not
+ # see how to do it quickly.
+ with self.assertRaises(jsonschema.ValidationError):
+ #import ipdb;ipdb.set_trace()
+ checker.eipconfig.load(fromfile=eipcfg_path)
+ checker.check_is_there_default_provider()
+
+ sampleconfig = testdata.EIP_SAMPLE_CONFIG
+ #eipcfg_path = checker._get_default_eipconfig_path()
+ with open(eipcfg_path, 'w') as fp:
+ json.dump(sampleconfig, fp)
+ checker.eipconfig.load()
+ self.assertTrue(checker.check_is_there_default_provider())
+
+ def test_fetch_definition(self):
+ with patch.object(requests, "get") as mocked_get:
+ mocked_get.return_value.status_code = 200
+ mocked_get.return_value.json = DEFAULT_PROVIDER_DEFINITION
+ checker = eipchecks.EIPConfigChecker(fetcher=requests)
+ sampleconfig = testdata.EIP_SAMPLE_CONFIG
+ checker.fetch_definition(config=sampleconfig)
+
+ fn = os.path.join(baseconfig.get_default_provider_path(),
+ DEFINITION_EXPECTED_PATH)
+ with open(fn, 'r') as fp:
+ deserialized = json.load(fp)
+ self.assertEqual(DEFAULT_PROVIDER_DEFINITION, deserialized)
+
+ # XXX TODO check for ConnectionError, HTTPError, InvalidUrl
+ # (and proper EIPExceptions are raised).
+ # Look at base.test_config.
+
+ def test_fetch_eip_service_config(self):
+ with patch.object(requests, "get") as mocked_get:
+ mocked_get.return_value.status_code = 200
+ mocked_get.return_value.json = testdata.EIP_SAMPLE_SERVICE
+ checker = eipchecks.EIPConfigChecker(fetcher=requests)
+ sampleconfig = testdata.EIP_SAMPLE_CONFIG
+ checker.fetch_eip_service_config(config=sampleconfig)
+
+ def test_check_complete_eip_config(self):
+ checker = eipchecks.EIPConfigChecker()
+ with self.assertRaises(eipexceptions.EIPConfigurationError):
+ sampleconfig = copy.copy(testdata.EIP_SAMPLE_CONFIG)
+ sampleconfig['provider'] = None
+ checker.check_complete_eip_config(config=sampleconfig)
+ with self.assertRaises(eipexceptions.EIPConfigurationError):
+ sampleconfig = copy.copy(testdata.EIP_SAMPLE_CONFIG)
+ del sampleconfig['provider']
+ checker.check_complete_eip_config(config=sampleconfig)
+
+ # normal case
+ sampleconfig = copy.copy(testdata.EIP_SAMPLE_CONFIG)
+ checker.check_complete_eip_config(config=sampleconfig)
+
+
+class ProviderCertCheckerTest(BaseLeapTest):
+
+ __name__ = "provider_cert_checker_tests"
+ provider = "testprovider.example.org"
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ # test methods are there, and can be called from run_all
+
+ def test_checker_should_implement_check_methods(self):
+ checker = eipchecks.ProviderCertChecker()
+
+ # For MVS+
+ self.assertTrue(hasattr(checker, "download_ca_cert"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "download_ca_signature"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "get_ca_signatures"), "missing meth")
+ self.assertTrue(hasattr(checker, "is_there_trust_path"),
+ "missing meth")
+
+ # For MVS
+ self.assertTrue(hasattr(checker, "is_there_provider_ca"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "is_https_working"), "missing meth")
+ self.assertTrue(hasattr(checker, "check_new_cert_needed"),
+ "missing meth")
+
+ def test_checker_should_actually_call_all_tests(self):
+ checker = eipchecks.ProviderCertChecker()
+
+ mc = Mock()
+ checker.run_all(checker=mc)
+ # XXX MVS+
+ #self.assertTrue(mc.download_ca_cert.called, "not called")
+ #self.assertTrue(mc.download_ca_signature.called, "not called")
+ #self.assertTrue(mc.get_ca_signatures.called, "not called")
+ #self.assertTrue(mc.is_there_trust_path.called, "not called")
+
+ # For MVS
+ self.assertTrue(mc.is_there_provider_ca.called, "not called")
+ self.assertTrue(mc.is_https_working.called,
+ "not called")
+ self.assertTrue(mc.check_new_cert_needed.called,
+ "not called")
+
+ # test individual check methods
+
+ @unittest.skip
+ def test_is_there_provider_ca(self):
+ # XXX commenting out this test.
+ # With the generic client this does not make sense,
+ # we should dump one there.
+ # or test conductor logic.
+ checker = eipchecks.ProviderCertChecker()
+ self.assertTrue(
+ checker.is_there_provider_ca())
+
+
+class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase, BaseLeapTest):
+ provider = "testprovider.example.org"
+
+ class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler):
+ responses = {
+ '/': ['OK', ''],
+ '/client.cert': [
+ # XXX get sample cert
+ '-----BEGIN CERTIFICATE-----',
+ '-----END CERTIFICATE-----'],
+ '/badclient.cert': [
+ 'BADCERT']}
+
+ def do_GET(self):
+ path = urlparse.urlparse(self.path)
+ message = '\n'.join(self.responses.get(
+ path.path, None))
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(message)
+
+ def test_is_https_working(self):
+ fetcher = requests
+ uri = "https://%s/" % (self.get_server())
+ # bare requests call. this should just pass (if there is
+ # an https service there).
+ fetcher.get(uri, verify=False)
+ checker = eipchecks.ProviderCertChecker(fetcher=fetcher)
+ self.assertTrue(checker.is_https_working(uri=uri, verify=False))
+
+ # for local debugs, when in doubt
+ #self.assertTrue(checker.is_https_working(uri="https://github.com",
+ #verify=True))
+
+ # for the two checks below, I know they fail because no ca
+ # cert is passed to them, and I know that's the error that
+ # requests return with our implementation.
+ # We're receiving this because our
+ # server is dying prematurely when the handshake is interrupted on the
+ # client side.
+ # Since we have access to the server, we could check that
+ # the error raised has been:
+ # SSL23_READ_BYTES: alert bad certificate
+ with self.assertRaises(requests.exceptions.SSLError) as exc:
+ fetcher.get(uri, verify=True)
+ self.assertTrue(
+ "SSL23_GET_SERVER_HELLO:unknown protocol" in exc.message)
+
+ # XXX FIXME! Uncomment after #638 is done
+ #with self.assertRaises(eipexceptions.EIPBadCertError) as exc:
+ #checker.is_https_working(uri=uri, verify=True)
+ #self.assertTrue(
+ #"cert verification failed" in exc.message)
+
+ # get cacert from testing.https_server
+ cacert = where_cert('cacert.pem')
+ fetcher.get(uri, verify=cacert)
+ self.assertTrue(checker.is_https_working(uri=uri, verify=cacert))
+
+ # same, but get cacert from leap.custom
+ # XXX TODO!
+
+ @unittest.skip
+ def test_download_new_client_cert(self):
+ # FIXME
+ # Magick srp decorator broken right now...
+ # Have to mock the decorator and inject something that
+ # can bypass the authentication
+
+ uri = "https://%s/client.cert" % (self.get_server())
+ cacert = where_cert('cacert.pem')
+ checker = eipchecks.ProviderCertChecker(domain=self.provider)
+ credentials = "testuser", "testpassword"
+ self.assertTrue(checker.download_new_client_cert(
+ credentials=credentials, uri=uri, verify=cacert))
+
+ # now download a malformed cert
+ uri = "https://%s/badclient.cert" % (self.get_server())
+ cacert = where_cert('cacert.pem')
+ checker = eipchecks.ProviderCertChecker()
+ with self.assertRaises(ValueError):
+ self.assertTrue(checker.download_new_client_cert(
+ credentials=credentials, uri=uri, verify=cacert))
+
+ # did we write cert to its path?
+ clientcertfile = eipspecs.client_cert_path()
+ self.assertTrue(os.path.isfile(clientcertfile))
+ certfile = eipspecs.client_cert_path()
+ with open(certfile, 'r') as cf:
+ certcontent = cf.read()
+ self.assertEqual(certcontent,
+ '\n'.join(
+ self.request_handler.responses['/client.cert']))
+ os.remove(clientcertfile)
+
+ def test_is_cert_valid(self):
+ checker = eipchecks.ProviderCertChecker()
+ # TODO: better exception catching
+ # should raise eipexceptions.BadClientCertificate, and give reasons
+ # on msg.
+ with self.assertRaises(Exception) as exc:
+ self.assertFalse(checker.is_cert_valid())
+ exc.message = "missing cert"
+
+ def test_bad_validity_certs(self):
+ checker = eipchecks.ProviderCertChecker()
+ certfile = where_cert('leaptestscert.pem')
+ self.assertFalse(checker.is_cert_not_expired(
+ certfile=certfile,
+ now=lambda: time.mktime((2038, 1, 1, 1, 1, 1, 1, 1, 1))))
+ self.assertFalse(checker.is_cert_not_expired(
+ certfile=certfile,
+ now=lambda: time.mktime((1970, 1, 1, 1, 1, 1, 1, 1, 1))))
+
+ def test_check_new_cert_needed(self):
+ # check: missing cert
+ checker = eipchecks.ProviderCertChecker(domain=self.provider)
+ self.assertTrue(checker.check_new_cert_needed(skip_download=True))
+ # TODO check: malformed cert
+ # TODO check: expired cert
+ # TODO check: pass test server uri instead of skip
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/eip/tests/test_config.py b/src/leap/eip/tests/test_config.py
new file mode 100644
index 00000000..50538240
--- /dev/null
+++ b/src/leap/eip/tests/test_config.py
@@ -0,0 +1,153 @@
+import json
+import os
+import platform
+import stat
+
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+#from leap.base import constants
+#from leap.eip import config as eip_config
+from leap import __branding as BRANDING
+from leap.eip import config as eipconfig
+from leap.eip.tests.data import EIP_SAMPLE_CONFIG, EIP_SAMPLE_SERVICE
+from leap.testing.basetest import BaseLeapTest
+from leap.util.fileutil import mkdir_p
+
+_system = platform.system()
+
+#PROVIDER = BRANDING.get('provider_domain')
+#PROVIDER_SHORTNAME = BRANDING.get('short_name')
+
+
+class EIPConfigTest(BaseLeapTest):
+
+ __name__ = "eip_config_tests"
+ provider = "testprovider.example.org"
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ #
+ # helpers
+ #
+
+ def touch_exec(self):
+ path = os.path.join(
+ self.tempdir, 'bin')
+ mkdir_p(path)
+ tfile = os.path.join(
+ path,
+ 'openvpn')
+ open(tfile, 'wb').close()
+ os.chmod(tfile, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
+
+ def write_sample_eipservice(self):
+ conf = eipconfig.EIPServiceConfig()
+ folder, f = os.path.split(conf.filename)
+ if not os.path.isdir(folder):
+ mkdir_p(folder)
+ with open(conf.filename, 'w') as fd:
+ fd.write(json.dumps(EIP_SAMPLE_SERVICE))
+
+ def write_sample_eipconfig(self):
+ conf = eipconfig.EIPConfig()
+ folder, f = os.path.split(conf.filename)
+ if not os.path.isdir(folder):
+ mkdir_p(folder)
+ with open(conf.filename, 'w') as fd:
+ fd.write(json.dumps(EIP_SAMPLE_CONFIG))
+
+ def get_expected_openvpn_args(self):
+ args = []
+ username = self.get_username()
+ groupname = self.get_groupname()
+
+ args.append('--client')
+ args.append('--dev')
+ #does this have to be tap for win??
+ args.append('tun')
+ args.append('--persist-tun')
+ args.append('--persist-key')
+ args.append('--remote')
+ args.append('%s' % eipconfig.get_eip_gateway(
+ provider=self.provider))
+ # XXX get port!?
+ args.append('1194')
+ # XXX get proto
+ args.append('udp')
+ args.append('--tls-client')
+ args.append('--remote-cert-tls')
+ args.append('server')
+
+ args.append('--user')
+ args.append(username)
+ args.append('--group')
+ args.append(groupname)
+ args.append('--management-client-user')
+ args.append(username)
+ args.append('--management-signal')
+
+ args.append('--management')
+ #XXX hey!
+ #get platform switches here!
+ args.append('/tmp/test.socket')
+ args.append('unix')
+
+ # certs
+ # XXX get values from specs?
+ args.append('--cert')
+ args.append(os.path.join(
+ self.home,
+ '.config', 'leap', 'providers',
+ '%s' % self.provider,
+ 'keys', 'client',
+ 'openvpn.pem'))
+ args.append('--key')
+ args.append(os.path.join(
+ self.home,
+ '.config', 'leap', 'providers',
+ '%s' % self.provider,
+ 'keys', 'client',
+ 'openvpn.pem'))
+ args.append('--ca')
+ args.append(os.path.join(
+ self.home,
+ '.config', 'leap', 'providers',
+ '%s' % self.provider,
+ 'keys', 'ca',
+ 'cacert.pem'))
+ return args
+
+ # build command string
+ # these tests are going to have to check
+ # many combinations. we should inject some
+ # params in the function call, to disable
+ # some checks.
+
+ def test_build_ovpn_command_empty_config(self):
+ self.touch_exec()
+ self.write_sample_eipservice()
+ self.write_sample_eipconfig()
+
+ from leap.eip import config as eipconfig
+ from leap.util.fileutil import which
+ path = os.environ['PATH']
+ vpnbin = which('openvpn', path=path)
+ print 'path =', path
+ print 'vpnbin = ', vpnbin
+ command, args = eipconfig.build_ovpn_command(
+ do_pkexec_check=False, vpnbin=vpnbin,
+ socket_path="/tmp/test.socket",
+ provider=self.provider)
+ self.assertEqual(command, self.home + '/bin/openvpn')
+ self.assertEqual(args, self.get_expected_openvpn_args())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/eip/tests/test_eipconnection.py b/src/leap/eip/tests/test_eipconnection.py
new file mode 100644
index 00000000..aefca36f
--- /dev/null
+++ b/src/leap/eip/tests/test_eipconnection.py
@@ -0,0 +1,191 @@
+import logging
+import platform
+import os
+
+logging.basicConfig()
+logger = logging.getLogger(name=__name__)
+
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+from mock import Mock, patch # MagicMock
+
+from leap.eip.eipconnection import EIPConnection
+from leap.eip.exceptions import ConnectionRefusedError
+from leap.eip import specs as eipspecs
+from leap.testing.basetest import BaseLeapTest
+
+_system = platform.system()
+
+PROVIDER = "testprovider.example.org"
+
+
+class NotImplementedError(Exception):
+ pass
+
+
+@patch('OpenVPNConnection._get_or_create_config')
+@patch('OpenVPNConnection._set_ovpn_command')
+class MockedEIPConnection(EIPConnection):
+
+ def _set_ovpn_command(self):
+ self.command = "mock_command"
+ self.args = [1, 2, 3]
+
+
+class EIPConductorTest(BaseLeapTest):
+
+ __name__ = "eip_conductor_tests"
+ provider = PROVIDER
+
+ def setUp(self):
+ # XXX there's a conceptual/design
+ # mistake here.
+ # If we're testing just attrs after init,
+ # init shold not be doing so much side effects.
+
+ # for instance:
+ # We have to TOUCH a keys file because
+ # we're triggerig the key checks FROM
+ # the constructor. me not like that,
+ # key checker should better be called explicitelly.
+
+ # XXX change to keys_checker invocation
+ # (see config_checker)
+
+ keyfiles = (eipspecs.provider_ca_path(domain=self.provider),
+ eipspecs.client_cert_path(domain=self.provider))
+ for filepath in keyfiles:
+ self.touch(filepath)
+ self.chmod600(filepath)
+
+ # we init the manager with only
+ # some methods mocked
+ self.manager = Mock(name="openvpnmanager_mock")
+ self.con = MockedEIPConnection()
+ self.con.provider = self.provider
+ self.con.run_openvpn_checks()
+
+ def tearDown(self):
+ del self.con
+
+ #
+ # tests
+ #
+
+ def test_vpnconnection_defaults(self):
+ """
+ default attrs as expected
+ """
+ con = self.con
+ self.assertEqual(con.autostart, True)
+
+ def test_ovpn_command(self):
+ """
+ set_ovpn_command called
+ """
+ self.assertEqual(self.con.command,
+ "mock_command")
+ self.assertEqual(self.con.args,
+ [1, 2, 3])
+
+ # config checks
+
+ def test_config_checked_called(self):
+ # XXX this single test is taking half of the time
+ # needed to run tests. (roughly 3 secs for this only)
+ # We should modularize and inject Mocks on more places.
+
+ del(self.con)
+ config_checker = Mock()
+ self.con = MockedEIPConnection(config_checker=config_checker)
+ self.assertTrue(config_checker.called)
+ self.con.run_checks()
+ self.con.config_checker.run_all.assert_called_with(
+ skip_download=False)
+
+ # XXX test for cert_checker also
+
+ # connect/disconnect calls
+
+ def test_disconnect(self):
+ """
+ disconnect method calls private and changes status
+ """
+ self.con._disconnect = Mock(
+ name="_disconnect")
+
+ # first we set status to connected
+ self.con.status.set_current(self.con.status.CONNECTED)
+ self.assertEqual(self.con.status.current,
+ self.con.status.CONNECTED)
+
+ # disconnect
+ self.con.cleanup = Mock()
+ self.con.disconnect()
+ self.con.cleanup.assert_called_once_with()
+
+ # new status should be disconnected
+ # XXX this should evolve and check no errors
+ # during disconnection
+ self.assertEqual(self.con.status.current,
+ self.con.status.DISCONNECTED)
+
+ def test_connect(self):
+ """
+ connect calls _launch_openvpn private
+ """
+ self.con._launch_openvpn = Mock()
+ self.con.connect()
+ self.con._launch_openvpn.assert_called_once_with()
+
+ # XXX tests breaking here ...
+
+ def test_good_poll_connection_state(self):
+ """
+ """
+ #@patch --
+ # self.manager.get_connection_state
+
+ #XXX review this set of poll_state tests
+ #they SHOULD NOT NEED TO MOCK ANYTHING IN THE
+ #lower layers!! -- status, vpn_manager..
+ #right now we're testing implementation, not
+ #behavior!!!
+ good_state = ["1345466946", "unknown_state", "ok",
+ "192.168.1.1", "192.168.1.100"]
+ self.con.get_connection_state = Mock(return_value=good_state)
+ self.con.status.set_vpn_state = Mock()
+
+ state = self.con.poll_connection_state()
+ good_state[1] = "disconnected"
+ final_state = tuple(good_state)
+ self.con.status.set_vpn_state.assert_called_with("unknown_state")
+ self.assertEqual(state, final_state)
+
+ # TODO between "good" and "bad" (exception raised) cases,
+ # we can still test for malformed states and see that only good
+ # states do have a change (and from only the expected transition
+ # states).
+
+ def test_bad_poll_connection_state(self):
+ """
+ get connection state raises ConnectionRefusedError
+ state is None
+ """
+ self.con.get_connection_state = Mock(
+ side_effect=ConnectionRefusedError('foo!'))
+ state = self.con.poll_connection_state()
+ self.assertEqual(state, None)
+
+
+ # XXX more things to test:
+ # - called config routines during initz.
+ # - raising proper exceptions with no config
+ # - called proper checks on config / permissions
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/eip/tests/test_openvpnconnection.py b/src/leap/eip/tests/test_openvpnconnection.py
new file mode 100644
index 00000000..0f27facf
--- /dev/null
+++ b/src/leap/eip/tests/test_openvpnconnection.py
@@ -0,0 +1,147 @@
+import logging
+import os
+import platform
+import psutil
+import shutil
+#import socket
+
+logging.basicConfig()
+logger = logging.getLogger(name=__name__)
+
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+from mock import Mock, patch # MagicMock
+
+from leap.eip import config as eipconfig
+from leap.eip import openvpnconnection
+from leap.eip import exceptions as eipexceptions
+from leap.eip.udstelnet import UDSTelnet
+from leap.testing.basetest import BaseLeapTest
+
+_system = platform.system()
+
+
+class NotImplementedError(Exception):
+ pass
+
+
+mock_UDSTelnet = Mock(spec=UDSTelnet)
+# XXX cautious!!!
+# this might be fragile right now (counting a global
+# reference of calls I think.
+# investigate this other form instead:
+# http://www.voidspace.org.uk/python/mock/patch.html#start-and-stop
+
+# XXX redo after merge-refactor
+
+
+@patch('openvpnconnection.OpenVPNConnection.connect_to_management')
+class MockedOpenVPNConnection(openvpnconnection.OpenVPNConnection):
+ def __init__(self, *args, **kwargs):
+ self.mock_UDSTelnet = Mock()
+ super(MockedOpenVPNConnection, self).__init__(
+ *args, **kwargs)
+ self.tn = self.mock_UDSTelnet(self.host, self.port)
+
+ def connect_to_management(self):
+ #print 'patched connect'
+ self.tn = mock_UDSTelnet(self.host, port=self.port)
+
+
+class OpenVPNConnectionTest(BaseLeapTest):
+
+ __name__ = "vpnconnection_tests"
+
+ def setUp(self):
+ # XXX this will have to change for win, host=localhost
+ host = eipconfig.get_socket_path()
+ self.manager = MockedOpenVPNConnection(host=host)
+
+ def tearDown(self):
+ # remove the socket folder.
+ # XXX only if posix. in win, host is localhost, so nothing
+ # has to be done.
+ if self.manager.host:
+ folder, fpath = os.path.split(self.manager.host)
+ assert folder.startswith('/tmp/leap-tmp') # safety check
+ shutil.rmtree(folder)
+
+ del self.manager
+
+ #
+ # tests
+ #
+
+ def test_detect_vpn(self):
+ # XXX review, not sure if captured all the logic
+ # while fixing. kali.
+ openvpn_connection = openvpnconnection.OpenVPNConnection()
+
+ with patch.object(psutil, "get_process_list") as mocked_psutil:
+ mocked_process = Mock()
+ mocked_process.name = "openvpn"
+ mocked_psutil.return_value = [mocked_process]
+ with self.assertRaises(eipexceptions.OpenVPNAlreadyRunning):
+ openvpn_connection._check_if_running_instance()
+
+ openvpn_connection._check_if_running_instance()
+
+ @unittest.skipIf(_system == "Windows", "lin/mac only")
+ def test_lin_mac_default_init(self):
+ """
+ check default host for management iface
+ """
+ self.assertTrue(self.manager.host.startswith('/tmp/leap-tmp'))
+ self.assertEqual(self.manager.port, 'unix')
+
+ @unittest.skipUnless(_system == "Windows", "win only")
+ def test_win_default_init(self):
+ """
+ check default host for management iface
+ """
+ # XXX should we make the platform specific switch
+ # here or in the vpn command string building?
+ self.assertEqual(self.manager.host, 'localhost')
+ self.assertEqual(self.manager.port, 7777)
+
+ def test_port_types_init(self):
+ self.manager = MockedOpenVPNConnection(port="42")
+ self.assertEqual(self.manager.port, 42)
+ self.manager = MockedOpenVPNConnection()
+ self.assertEqual(self.manager.port, "unix")
+ self.manager = MockedOpenVPNConnection(port="bad")
+ self.assertEqual(self.manager.port, None)
+
+ def test_uds_telnet_called_on_connect(self):
+ self.manager.connect_to_management()
+ mock_UDSTelnet.assert_called_with(
+ self.manager.host,
+ port=self.manager.port)
+
+ @unittest.skip
+ def test_connect(self):
+ raise NotImplementedError
+ # XXX calls close
+ # calls UDSTelnet mock.
+
+ # XXX
+ # tests to write:
+ # UDSTelnetTest (for real?)
+ # HAVE A LOOK AT CORE TESTS FOR TELNETLIB.
+ # very illustrative instead...
+
+ # - raise MissingSocket
+ # - raise ConnectionRefusedError
+ # - test send command
+ # - tries connect
+ # - ... tries?
+ # - ... calls _seek_to_eof
+ # - ... read_until --> return value
+ # - ...
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/eip/udstelnet.py b/src/leap/eip/udstelnet.py
new file mode 100644
index 00000000..18e927c2
--- /dev/null
+++ b/src/leap/eip/udstelnet.py
@@ -0,0 +1,38 @@
+import os
+import socket
+import telnetlib
+
+from leap.eip import exceptions as eip_exceptions
+
+
+class UDSTelnet(telnetlib.Telnet):
+ """
+ a telnet-alike class, that can listen
+ on unix domain sockets
+ """
+
+ def open(self, host, port=23, timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
+ """Connect to a host. If port is 'unix', it
+ will open a connection over unix docmain sockets.
+
+ The optional second argument is the port number, which
+ defaults to the standard telnet port (23).
+
+ Don't try to reopen an already connected instance.
+ """
+ self.eof = 0
+ self.host = host
+ self.port = port
+ self.timeout = timeout
+
+ if self.port == "unix":
+ # unix sockets spoken
+ if not os.path.exists(self.host):
+ raise eip_exceptions.MissingSocketError
+ self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ try:
+ self.sock.connect(self.host)
+ except socket.error:
+ raise eip_exceptions.ConnectionRefusedError
+ else:
+ self.sock = socket.create_connection((host, port), timeout)
diff --git a/src/leap/eip/vpnmanager.py b/src/leap/eip/vpnmanager.py
deleted file mode 100644
index 78777cfb..00000000
--- a/src/leap/eip/vpnmanager.py
+++ /dev/null
@@ -1,262 +0,0 @@
-from __future__ import (print_function)
-import logging
-import os
-import socket
-import telnetlib
-import time
-
-logger = logging.getLogger(name=__name__)
-
-TELNET_PORT = 23
-
-
-class MissingSocketError(Exception):
- pass
-
-
-class ConnectionRefusedError(Exception):
- pass
-
-
-class UDSTelnet(telnetlib.Telnet):
-
- def open(self, host, port=0, timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
- """Connect to a host. If port is 'unix', it
- will open a connection over unix docmain sockets.
-
- The optional second argument is the port number, which
- defaults to the standard telnet port (23).
-
- Don't try to reopen an already connected instance.
- """
- self.eof = 0
- if not port:
- port = TELNET_PORT
- self.host = host
- self.port = port
- self.timeout = timeout
-
- if self.port == "unix":
- # unix sockets spoken
- if not os.path.exists(self.host):
- raise MissingSocketError
- self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- try:
- self.sock.connect(self.host)
- except socket.error:
- raise ConnectionRefusedError
- else:
- self.sock = socket.create_connection((host, port), timeout)
-
-
-# this class based in code from cube-routed project
-
-class OpenVPNManager(object):
- """
- Run commands over OpenVPN management interface
- and parses the output.
- """
- # XXX might need a lock to avoid
- # race conditions here...
-
- def __init__(self, host="/tmp/.eip.sock", port="unix", password=None):
- #XXX hardcoded host here. change.
- self.host = host
- if isinstance(port, str) and port.isdigit():
- port = int(port)
- self.port = port
- self.password = password
- self.tn = None
-
- #XXX workaround for signaling
- #the ui that we don't know how to
- #manage a connection error
- self.with_errors = False
-
- def forget_errors(self):
- print('forgetting errors')
- self.with_errors = False
-
- def connect(self):
- """Connect to openvpn management interface"""
- try:
- self.close()
- except:
- #XXX don't like this general
- #catch here.
- pass
- if self.connected():
- return True
- self.tn = UDSTelnet(self.host, self.port)
-
- # XXX make password optional
- # specially for win plat. we should generate
- # the pass on the fly when invoking manager
- # from conductor
-
- #self.tn.read_until('ENTER PASSWORD:', 2)
- #self.tn.write(self.password + '\n')
- #self.tn.read_until('SUCCESS:', 2)
-
- self._seek_to_eof()
- self.forget_errors()
- return True
-
- def _seek_to_eof(self):
- """
- Read as much as available. Position seek pointer to end of stream
- """
- b = self.tn.read_eager()
- while b:
- b = self.tn.read_eager()
-
- def connected(self):
- """
- Returns True if connected
- rtype: bool
- """
- #return bool(getattr(self, 'tn', None))
- try:
- assert self.tn
- return True
- except:
- #XXX get rid of
- #this pokemon exception!!!
- return False
-
- def close(self, announce=True):
- """
- Close connection to openvpn management interface
- """
- if announce:
- self.tn.write("quit\n")
- self.tn.read_all()
- self.tn.get_socket().close()
- del self.tn
-
- def _send_command(self, cmd, tries=0):
- """
- Send a command to openvpn and return response as list
- """
- if tries > 3:
- return []
- if not self.connected():
- try:
- self.connect()
- except MissingSocketError:
- #XXX capture more helpful error
- #messages
- #pass
- return self.make_error()
- try:
- self.tn.write(cmd + "\n")
- except socket.error:
- logger.error('socket error')
- print('socket error!')
- self.close(announce=False)
- self._send_command(cmd, tries=tries + 1)
- return []
- buf = self.tn.read_until(b"END", 2)
- self._seek_to_eof()
- blist = buf.split('\r\n')
- if blist[-1].startswith('END'):
- del blist[-1]
- return blist
- else:
- return []
-
- def _send_short_command(self, cmd):
- """
- parse output from commands that are
- delimited by "success" instead
- """
- if not self.connected():
- self.connect()
- self.tn.write(cmd + "\n")
- # XXX not working?
- buf = self.tn.read_until(b"SUCCESS", 2)
- self._seek_to_eof()
- blist = buf.split('\r\n')
- return blist
-
- #
- # useful vpn commands
- #
-
- def pid(self):
- #XXX broken
- return self._send_short_command("pid")
-
- def make_error(self):
- """
- capture error and wrap it in an
- understandable format
- """
- #XXX get helpful error codes
- self.with_errors = True
- now = int(time.time())
- return '%s,LAUNCHER ERROR,ERROR,-,-' % now
-
- def state(self):
- """
- OpenVPN command: state
- """
- state = self._send_command("state")
- if not state:
- return None
- if isinstance(state, str):
- return state
- if isinstance(state, list):
- if len(state) == 1:
- return state[0]
- else:
- return state[-1]
-
- def status(self):
- """
- OpenVPN command: status
- """
- status = self._send_command("status")
- return status
-
- def status2(self):
- """
- OpenVPN command: last 2 statuses
- """
- return self._send_command("status 2")
-
- #
- # parse info
- #
-
- def get_status_io(self):
- status = self.status()
- if isinstance(status, str):
- lines = status.split('\n')
- if isinstance(status, list):
- lines = status
- try:
- (header, when, tun_read, tun_write,
- tcp_read, tcp_write, auth_read) = tuple(lines)
- except ValueError:
- return None
-
- when_ts = time.strptime(when.split(',')[1], "%a %b %d %H:%M:%S %Y")
- sep = ','
- # XXX cleanup!
- tun_read = tun_read.split(sep)[1]
- tun_write = tun_write.split(sep)[1]
- tcp_read = tcp_read.split(sep)[1]
- tcp_write = tcp_write.split(sep)[1]
- auth_read = auth_read.split(sep)[1]
-
- # XXX this could be a named tuple. prettier.
- return when_ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read)
-
- def get_connection_state(self):
- state = self.state()
- if state is not None:
- ts, status_step, ok, ip, remote = state.split(',')
- ts = time.gmtime(float(ts))
- # XXX this could be a named tuple. prettier.
- return ts, status_step, ok, ip, remote
diff --git a/src/leap/eip/vpnwatcher.py b/src/leap/eip/vpnwatcher.py
deleted file mode 100644
index 09bd5811..00000000
--- a/src/leap/eip/vpnwatcher.py
+++ /dev/null
@@ -1,169 +0,0 @@
-"""generic watcher object that keeps track of connection status"""
-# This should be deprecated in favor of daemon mode + management
-# interface. But we can leave it here for debug purposes.
-
-
-class EIPConnectionStatus(object):
- """
- Keep track of client (gui) and openvpn
- states.
-
- These are the OpenVPN states:
- CONNECTING -- OpenVPN's initial state.
- WAIT -- (Client only) Waiting for initial response
- from server.
- AUTH -- (Client only) Authenticating with server.
- GET_CONFIG -- (Client only) Downloading configuration options
- from server.
- ASSIGN_IP -- Assigning IP address to virtual network
- interface.
- ADD_ROUTES -- Adding routes to system.
- CONNECTED -- Initialization Sequence Completed.
- RECONNECTING -- A restart has occurred.
- EXITING -- A graceful exit is in progress.
-
- We add some extra states:
-
- DISCONNECTED -- GUI initial state.
- UNRECOVERABLE -- An unrecoverable error has been raised
- while invoking openvpn service.
- """
- CONNECTING = 1
- WAIT = 2
- AUTH = 3
- GET_CONFIG = 4
- ASSIGN_IP = 5
- ADD_ROUTES = 6
- CONNECTED = 7
- RECONNECTING = 8
- EXITING = 9
-
- # gui specific states:
- UNRECOVERABLE = 11
- DISCONNECTED = 0
-
- def __init__(self, callbacks=None):
- """
- EIPConnectionStatus is initialized with a tuple
- of signals to be triggered.
- :param callbacks: a tuple of (callable) observers
- :type callbacks: tuple
- """
- # (callbacks to connect to signals in Qt-land)
- self.current = self.DISCONNECTED
- self.previous = None
- self.callbacks = callbacks
-
- def get_readable_status(self):
- # XXX DRY status / labels a little bit.
- # think we'll want to i18n this.
- human_status = {
- 0: 'disconnected',
- 1: 'connecting',
- 2: 'waiting',
- 3: 'authenticating',
- 4: 'getting config',
- 5: 'assigning ip',
- 6: 'adding routes',
- 7: 'connected',
- 8: 'reconnecting',
- 9: 'exiting',
- 11: 'unrecoverable error',
- }
- return human_status[self.current]
-
- def get_state_icon(self):
- """
- returns the high level icon
- for each fine-grain openvpn state
- """
- connecting = (self.CONNECTING,
- self.WAIT,
- self.AUTH,
- self.GET_CONFIG,
- self.ASSIGN_IP,
- self.ADD_ROUTES)
- connected = (self.CONNECTED,)
- disconnected = (self.DISCONNECTED,
- self.UNRECOVERABLE)
-
- # this can be made smarter,
- # but it's like it'll change,
- # so +readability.
-
- if self.current in connecting:
- return "connecting"
- if self.current in connected:
- return "connected"
- if self.current in disconnected:
- return "disconnected"
-
- def set_vpn_state(self, status):
- """
- accepts a state string from the management
- interface, and sets the internal state.
- :param status: openvpn STATE (uppercase).
- :type status: str
- """
- if hasattr(self, status):
- self.change_to(getattr(self, status))
-
- def set_current(self, to):
- """
- setter for the 'current' property
- :param to: destination state
- :type to: int
- """
- self.current = to
-
- def change_to(self, to):
- """
- :param to: destination state
- :type to: int
- """
- if to == self.current:
- return
- changed = False
- from_ = self.current
- self.current = to
-
- # We can add transition restrictions
- # here to ensure no transitions are
- # allowed outside the fsm.
-
- self.set_current(to)
- changed = True
-
- #trigger signals (as callbacks)
- #print('current state: %s' % self.current)
- if changed:
- self.previous = from_
- if self.callbacks:
- for cb in self.callbacks:
- if callable(cb):
- cb(self)
-
-
-def status_watcher(cs, line):
- """
- a wrapper that calls to ConnectionStatus object
- :param cs: a EIPConnectionStatus instance
- :type cs: EIPConnectionStatus object
- :param line: a single line of the watched output
- :type line: str
- """
- #print('status watcher watching')
-
- # from the mullvad code, should watch for
- # things like:
- # "Initialization Sequence Completed"
- # "With Errors"
- # "Tap-Win32"
-
- if "Completed" in line:
- cs.change_to(cs.CONNECTED)
- return
-
- if "Initial packet from" in line:
- cs.change_to(cs.CONNECTING)
- return
diff --git a/src/leap/gui/__init__.py b/src/leap/gui/__init__.py
index e69de29b..9b8f8746 100644
--- a/src/leap/gui/__init__.py
+++ b/src/leap/gui/__init__.py
@@ -0,0 +1,10 @@
+try:
+ import sip
+ sip.setapi('QString', 2)
+ sip.setapi('QVariant', 2)
+except ValueError:
+ pass
+
+import firstrun
+
+__all__ = ['firstrun']
diff --git a/src/leap/gui/constants.py b/src/leap/gui/constants.py
new file mode 100644
index 00000000..277f3540
--- /dev/null
+++ b/src/leap/gui/constants.py
@@ -0,0 +1,13 @@
+import time
+
+APP_LOGO = ':/images/leap-color-small.png'
+
+# bare is the username portion of a JID
+# full includes the "at" and some extra chars
+# that can be allowed for fqdn
+
+BARE_USERNAME_REGEX = r"^[A-Za-z\d_]+$"
+FULL_USERNAME_REGEX = r"^[A-Za-z\d_@.-]+$"
+
+GUI_PAUSE_FOR_USER_SECONDS = 1
+pause_for_user = lambda: time.sleep(GUI_PAUSE_FOR_USER_SECONDS)
diff --git a/src/leap/gui/firstrun/__init__.py b/src/leap/gui/firstrun/__init__.py
new file mode 100644
index 00000000..8a70d90e
--- /dev/null
+++ b/src/leap/gui/firstrun/__init__.py
@@ -0,0 +1,29 @@
+try:
+ import sip
+ sip.setapi('QString', 2)
+ sip.setapi('QVariant', 2)
+except ValueError:
+ pass
+
+import connect
+import intro
+import last
+import login
+import mixins
+import providerinfo
+import providerselect
+import providersetup
+import register
+import regvalidation
+
+__all__ = [
+ 'connect',
+ 'intro',
+ 'last',
+ 'login',
+ 'mixins',
+ 'providerinfo',
+ 'providerselect',
+ 'providersetup',
+ 'register',
+ 'regvalidation']
diff --git a/src/leap/gui/firstrun/connect.py b/src/leap/gui/firstrun/connect.py
new file mode 100644
index 00000000..a0fe021c
--- /dev/null
+++ b/src/leap/gui/firstrun/connect.py
@@ -0,0 +1,231 @@
+"""
+Connecting Page, used in First Run Wizard
+"""
+# XXX FIXME
+# DEPRECATED. All functionality moved to regvalidation
+# This file should be removed after checking that one is ok.
+# XXX
+
+import logging
+
+from PyQt4 import QtGui
+
+logger = logging.getLogger(__name__)
+
+from leap.base import auth
+
+from leap.gui.constants import APP_LOGO
+from leap.gui.styles import ErrorLabelStyleSheet
+
+
+class ConnectingPage(QtGui.QWizardPage):
+
+ # XXX change to a ValidationPage
+
+ def __init__(self, parent=None):
+ super(ConnectingPage, self).__init__(parent)
+
+ self.setTitle("Connecting")
+ self.setSubTitle('Connecting to provider.')
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ self.status = QtGui.QLabel("")
+ self.status.setWordWrap(True)
+ self.progress = QtGui.QProgressBar()
+ self.progress.setMaximum(100)
+ self.progress.hide()
+
+ # for pre-checks
+ self.status_line_1 = QtGui.QLabel()
+ self.status_line_2 = QtGui.QLabel()
+ self.status_line_3 = QtGui.QLabel()
+ self.status_line_4 = QtGui.QLabel()
+
+ # for connecting signals...
+ self.status_line_5 = QtGui.QLabel()
+
+ layout = QtGui.QGridLayout()
+ layout.addWidget(self.status, 0, 1)
+ layout.addWidget(self.progress, 5, 1)
+ layout.addWidget(self.status_line_1, 8, 1)
+ layout.addWidget(self.status_line_2, 9, 1)
+ layout.addWidget(self.status_line_3, 10, 1)
+ layout.addWidget(self.status_line_4, 11, 1)
+
+ # XXX to be used?
+ #self.validation_status = QtGui.QLabel("")
+ #self.validation_status.setStyleSheet(
+ #ErrorLabelStyleSheet)
+ #self.validation_msg = QtGui.QLabel("")
+
+ self.setLayout(layout)
+
+ self.goto_login_again = False
+
+ def set_status(self, status):
+ self.status.setText(status)
+ self.status.setWordWrap(True)
+
+ def set_status_line(self, line, status):
+ line = getattr(self, 'status_line_%s' % line)
+ if line:
+ line.setText(status)
+
+ def set_validation_status(self, status):
+ # Do not remember if we're using
+ # status lines > 3 now...
+ # if we are, move below
+ self.status_line_3.setStyleSheet(
+ ErrorLabelStyleSheet)
+ self.status_line_3.setText(status)
+
+ def set_validation_message(self, message):
+ self.status_line_4.setText(message)
+ self.status_line_4.setWordWrap(True)
+
+ def get_donemsg(self, msg):
+ return "%s ... done" % msg
+
+ def run_eip_checks_for_provider_and_connect(self, domain):
+ wizard = self.wizard()
+ conductor = wizard.conductor
+ start_eip_signal = getattr(
+ wizard,
+ 'start_eipconnection_signal', None)
+
+ if conductor:
+ conductor.set_provider_domain(domain)
+ conductor.run_checks()
+ self.conductor = conductor
+ errors = self.eip_error_check()
+ if not errors and start_eip_signal:
+ start_eip_signal.emit()
+
+ else:
+ logger.warning(
+ "No conductor found. This means that "
+ "probably the wizard has been launched "
+ "in an stand-alone way")
+
+ def eip_error_check(self):
+ """
+ a version of the main app error checker,
+ but integrated within the connecting page of the wizard.
+ consumes the conductor error queue.
+ pops errors, and add those to the wizard page
+ """
+ logger.debug('eip error check from connecting page')
+ errq = self.conductor.error_queue
+ # XXX missing!
+
+ def fetch_and_validate(self):
+ # XXX MOVE TO validate function in register-validation
+ import time
+ domain = self.field('provider_domain')
+ wizard = self.wizard()
+ #pconfig = wizard.providerconfig
+ eipconfigchecker = wizard.eipconfigchecker()
+ pCertChecker = wizard.providercertchecker(
+ domain=domain)
+
+ # username and password are in different fields
+ # if they were stored in log_in or sign_up pages.
+ from_login = self.wizard().from_login
+ unamek_base = 'userName'
+ passwk_base = 'userPassword'
+ unamek = 'login_%s' % unamek_base if from_login else unamek_base
+ passwk = 'login_%s' % passwk_base if from_login else passwk_base
+
+ username = self.field(unamek)
+ password = self.field(passwk)
+ credentials = username, password
+
+ self.progress.show()
+
+ fetching_eip_conf_msg = 'Fetching eip service configuration'
+ self.set_status(fetching_eip_conf_msg)
+ self.progress.setValue(30)
+
+ # Fetching eip service
+ eipconfigchecker.fetch_eip_service_config(
+ domain=domain)
+
+ self.status_line_1.setText(
+ self.get_donemsg(fetching_eip_conf_msg))
+
+ getting_client_cert_msg = 'Getting client certificate'
+ self.set_status(getting_client_cert_msg)
+ self.progress.setValue(66)
+
+ # Download cert
+ try:
+ pCertChecker.download_new_client_cert(
+ credentials=credentials,
+ # FIXME FIXME FIXME
+ # XXX FIX THIS!!!!!
+ # BUG #638. remove verify
+ # FIXME FIXME FIXME
+ verify=False)
+ except auth.SRPAuthenticationError as exc:
+ self.set_validation_status(
+ "Authentication error: %s" % exc.message)
+ return False
+
+ time.sleep(2)
+ self.status_line_2.setText(
+ self.get_donemsg(getting_client_cert_msg))
+
+ validating_clientcert_msg = 'Validating client certificate'
+ self.set_status(validating_clientcert_msg)
+ self.progress.setValue(90)
+ time.sleep(2)
+ self.status_line_3.setText(
+ self.get_donemsg(validating_clientcert_msg))
+
+ self.progress.setValue(100)
+ time.sleep(3)
+
+ # here we go! :)
+ self.run_eip_checks_for_provider_and_connect(domain)
+
+ #self.validation_block = self.wait_for_validation_block()
+
+ # XXX signal timeout!
+ return True
+
+ #
+ # wizardpage methods
+ #
+
+ def nextId(self):
+ wizard = self.wizard()
+ # XXX this does not work because
+ # page login has already been met
+ #if self.goto_login_again:
+ #next_ = "login"
+ #else:
+ #next_ = "lastpage"
+ next_ = "lastpage"
+ return wizard.get_page_index(next_)
+
+ def initializePage(self):
+ # XXX if we're coming from signup page
+ # we could say something like
+ # 'registration successful!'
+ self.status.setText(
+ "We have "
+ "all we need to connect with the provider.<br><br> "
+ "Click <i>next</i> to continue. ")
+ self.progress.setValue(0)
+ self.progress.hide()
+ self.status_line_1.setText('')
+ self.status_line_2.setText('')
+ self.status_line_3.setText('')
+
+ def validatePage(self):
+ # XXX remove
+ validated = self.fetch_and_validate()
+ return validated
diff --git a/src/leap/gui/firstrun/constants.py b/src/leap/gui/firstrun/constants.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/gui/firstrun/constants.py
diff --git a/src/leap/gui/firstrun/intro.py b/src/leap/gui/firstrun/intro.py
new file mode 100644
index 00000000..4bb008c7
--- /dev/null
+++ b/src/leap/gui/firstrun/intro.py
@@ -0,0 +1,68 @@
+"""
+Intro page used in first run wizard
+"""
+
+from PyQt4 import QtGui
+
+from leap.gui.constants import APP_LOGO
+
+
+class IntroPage(QtGui.QWizardPage):
+ def __init__(self, parent=None):
+ super(IntroPage, self).__init__(parent)
+
+ self.setTitle("First run wizard.")
+
+ #self.setPixmap(
+ #QtGui.QWizard.WatermarkPixmap,
+ #QtGui.QPixmap(':/images/watermark1.png'))
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ label = QtGui.QLabel(
+ "Now we will guide you through "
+ "some configuration that is needed before you "
+ "can connect for the first time.<br><br>"
+ "If you ever need to modify these options again, "
+ "you can find the wizard in the '<i>Settings</i>' menu from the "
+ "main window.<br><br>"
+ "Do you want to <b>sign up</b> for a new account, or <b>log "
+ "in</b> with an already existing username?<br>")
+ label.setWordWrap(True)
+
+ radiobuttonGroup = QtGui.QGroupBox()
+
+ self.sign_up = QtGui.QRadioButton(
+ "Sign up for a new account.")
+ self.sign_up.setChecked(True)
+ self.log_in = QtGui.QRadioButton(
+ "Log In with my credentials.")
+
+ radiobLayout = QtGui.QVBoxLayout()
+ radiobLayout.addWidget(self.sign_up)
+ radiobLayout.addWidget(self.log_in)
+ radiobuttonGroup.setLayout(radiobLayout)
+
+ layout = QtGui.QVBoxLayout()
+ layout.addWidget(label)
+ layout.addWidget(radiobuttonGroup)
+ self.setLayout(layout)
+
+ self.registerField('is_signup', self.sign_up)
+
+ def validatePage(self):
+ return True
+
+ def nextId(self):
+ """
+ returns next id
+ in a non-linear wizard
+ """
+ if self.sign_up.isChecked():
+ next_ = 'providerselection'
+ if self.log_in.isChecked():
+ next_ = 'login'
+ wizard = self.wizard()
+ return wizard.get_page_index(next_)
diff --git a/src/leap/gui/firstrun/last.py b/src/leap/gui/firstrun/last.py
new file mode 100644
index 00000000..d33d2e77
--- /dev/null
+++ b/src/leap/gui/firstrun/last.py
@@ -0,0 +1,92 @@
+"""
+Last Page, used in First Run Wizard
+"""
+import logging
+
+from PyQt4 import QtGui
+
+from leap.util.coroutines import coroutine
+from leap.gui.constants import APP_LOGO
+
+logger = logging.getLogger(__name__)
+
+
+class LastPage(QtGui.QWizardPage):
+ def __init__(self, parent=None):
+ super(LastPage, self).__init__(parent)
+
+ self.setTitle("Connecting to Encrypted Internet Proxy service...")
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ #self.setPixmap(
+ #QtGui.QWizard.WatermarkPixmap,
+ #QtGui.QPixmap(':/images/watermark2.png'))
+
+ self.label = QtGui.QLabel()
+ self.label.setWordWrap(True)
+
+ # XXX REFACTOR to a Validating Page...
+ self.status_line_1 = QtGui.QLabel()
+ self.status_line_2 = QtGui.QLabel()
+ self.status_line_3 = QtGui.QLabel()
+ self.status_line_4 = QtGui.QLabel()
+
+ layout = QtGui.QVBoxLayout()
+ layout.addWidget(self.label)
+
+ # make loop
+ layout.addWidget(self.status_line_1)
+ layout.addWidget(self.status_line_2)
+ layout.addWidget(self.status_line_3)
+ layout.addWidget(self.status_line_4)
+
+ self.setLayout(layout)
+
+ def set_status_line(self, line, status):
+ statusline = getattr(self, 'status_line_%s' % line)
+ if statusline:
+ statusline.setText(status)
+
+ def set_finished_status(self):
+ self.setTitle('You are now using an encrypted connection!')
+ finishText = self.wizard().buttonText(
+ QtGui.QWizard.FinishButton)
+ finishText = finishText.replace('&', '')
+ self.label.setText(
+ "Click '<i>%s</i>' to end the wizard and "
+ "save your settings." % finishText)
+
+ @coroutine
+ def eip_status_handler(self):
+ # XXX this can be changed to use
+ # signals. See progress.py
+ logger.debug('logging status in last page')
+ self.validation_done = False
+ status_count = 0
+ try:
+ while True:
+ status = (yield)
+ status_count += 1
+ # XXX add to line...
+ logger.debug('status --> %s', status)
+ self.set_status_line(status_count, status)
+ if status == "connected":
+ self.set_finished_status()
+ break
+ except GeneratorExit:
+ pass
+ except StopIteration:
+ pass
+
+ def initializePage(self):
+ wizard = self.wizard()
+ if not wizard:
+ return
+ eip_status_handler = self.eip_status_handler()
+ eip_statuschange_signal = wizard.eip_statuschange_signal
+ if eip_statuschange_signal:
+ eip_statuschange_signal.connect(
+ lambda status: eip_status_handler.send(status))
diff --git a/src/leap/gui/firstrun/login.py b/src/leap/gui/firstrun/login.py
new file mode 100644
index 00000000..02bace86
--- /dev/null
+++ b/src/leap/gui/firstrun/login.py
@@ -0,0 +1,330 @@
+"""
+LogIn Page, used inf First Run Wizard
+"""
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+import requests
+
+from leap.base import auth
+from leap.gui.firstrun.mixins import UserFormMixIn
+from leap.gui.progress import InlineValidationPage
+from leap.gui import styles
+
+from leap.gui.constants import APP_LOGO, FULL_USERNAME_REGEX
+
+
+class LogInPage(InlineValidationPage, UserFormMixIn): # InlineValidationPage
+
+ def __init__(self, parent=None):
+
+ super(LogInPage, self).__init__(parent)
+ self.current_page = "login"
+
+ self.setTitle("Log In")
+ self.setSubTitle("Log in with your credentials.")
+ self.current_page = "login"
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ self.setupSteps()
+ self.setupUI()
+
+ self.do_confirm_next = False
+
+ def setupUI(self):
+ userNameLabel = QtGui.QLabel("User &name:")
+ userNameLineEdit = QtGui.QLineEdit()
+ userNameLineEdit.cursorPositionChanged.connect(
+ self.reset_validation_status)
+ userNameLabel.setBuddy(userNameLineEdit)
+
+ # let's add regex validator
+ usernameRe = QtCore.QRegExp(FULL_USERNAME_REGEX)
+ userNameLineEdit.setValidator(
+ QtGui.QRegExpValidator(usernameRe, self))
+
+ #userNameLineEdit.setPlaceholderText(
+ #'username@provider.example.org')
+ self.userNameLineEdit = userNameLineEdit
+
+ userPasswordLabel = QtGui.QLabel("&Password:")
+ self.userPasswordLineEdit = QtGui.QLineEdit()
+ self.userPasswordLineEdit.setEchoMode(
+ QtGui.QLineEdit.Password)
+ userPasswordLabel.setBuddy(self.userPasswordLineEdit)
+
+ self.registerField('login_userName*', self.userNameLineEdit)
+ self.registerField('login_userPassword*', self.userPasswordLineEdit)
+
+ layout = QtGui.QGridLayout()
+ layout.setColumnMinimumWidth(0, 20)
+
+ validationMsg = QtGui.QLabel("")
+ validationMsg.setStyleSheet(styles.ErrorLabelStyleSheet)
+ self.validationMsg = validationMsg
+
+ layout.addWidget(validationMsg, 0, 3)
+ layout.addWidget(userNameLabel, 1, 0)
+ layout.addWidget(self.userNameLineEdit, 1, 3)
+ layout.addWidget(userPasswordLabel, 2, 0)
+ layout.addWidget(self.userPasswordLineEdit, 2, 3)
+
+ # add validation frame
+ self.setupValidationFrame()
+ layout.addWidget(self.valFrame, 4, 2, 4, 2)
+ self.valFrame.hide()
+
+ self.nextText("Log in")
+ self.setLayout(layout)
+
+ #self.registerField('is_login_wizard')
+
+ def nextText(self, text):
+ self.setButtonText(
+ QtGui.QWizard.NextButton, text)
+
+ def nextFocus(self):
+ self.wizard().button(
+ QtGui.QWizard.NextButton).setFocus()
+
+ def disableNextButton(self):
+ self.wizard().button(
+ QtGui.QWizard.NextButton).setDisabled(True)
+
+ def onUserNameEdit(self, *args):
+ if self.initial_username_sample:
+ self.userNameLineEdit.setText('')
+ # XXX set regular color
+ self.initial_username_sample = None
+
+ def disableFields(self):
+ for field in (self.userNameLineEdit,
+ self.userPasswordLineEdit):
+ field.setDisabled(True)
+
+ def populateErrors(self):
+ # XXX could move this to ValidationMixin
+ # used in providerselect and register too
+
+ errors = self.wizard().get_validation_error(
+ self.current_page)
+ #prev_er = getattr(self, 'prevalidation_error', None)
+ showerr = self.validationMsg.setText
+
+ #if not errors and prev_er:
+ #showerr(prev_er)
+ #return
+#
+ if errors:
+ bad_str = getattr(self, 'bad_string', None)
+ cur_str = self.userNameLineEdit.text()
+
+ if bad_str is None:
+ # first time we fall here.
+ # save the current bad_string value
+ self.bad_string = cur_str
+ showerr(errors)
+ else:
+ #if prev_er:
+ #showerr(prev_er)
+ #return
+ # not the first time
+ if cur_str == bad_str:
+ showerr(errors)
+ else:
+ self.focused_field = False
+ showerr('')
+
+ def cleanup_errormsg(self):
+ """
+ we reset bad_string to None
+ should be called before leaving the page
+ """
+ self.bad_string = None
+
+ def paintEvent(self, event):
+ """
+ we hook our populate errors
+ on paintEvent because we need it to catch
+ when user enters the page coming from next,
+ and initializePage does not cover that case.
+ Maybe there's a better event to hook upon.
+ """
+ super(LogInPage, self).paintEvent(event)
+ self.populateErrors()
+
+ def set_prevalidation_error(self, error):
+ self.prevalidation_error = error
+
+ # pagewizard methods
+
+ def nextId(self):
+ wizard = self.wizard()
+ if not wizard:
+ return
+ if wizard.is_provider_setup is False:
+ next_ = 'providersetupvalidation'
+ if wizard.is_provider_setup is True:
+ # XXX bad name, ok, gonna change that
+ next_ = 'signupvalidation'
+ return wizard.get_page_index(next_)
+
+ def initializePage(self):
+ super(LogInPage, self).initializePage()
+ username = self.userNameLineEdit
+ username.setText('username@provider.example.org')
+ username.cursorPositionChanged.connect(
+ self.onUserNameEdit)
+ self.initial_username_sample = True
+ self.validationMsg.setText('')
+ self.valFrame.hide()
+
+ def reset_validation_status(self):
+ """
+ empty the validation msg
+ and clean the inline validation widget.
+ """
+ self.validationMsg.setText('')
+ self.steps.removeAllSteps()
+ self.clearTable()
+
+ def validatePage(self):
+ """
+ if not register done, do checks.
+ if done, wait for click.
+ """
+ self.disableNextButton()
+ self.cleanup_errormsg()
+ self.clean_wizard_errors(self.current_page)
+
+ if self.do_confirm_next:
+ full_username = self.userNameLineEdit.text()
+ password = self.userPasswordLineEdit.text()
+ username, domain = full_username.split('@')
+ self.setField('provider_domain', domain)
+ self.setField('login_userName', username)
+ self.setField('login_userPassword', password)
+
+ return True
+
+ if not self.is_done():
+ self.reset_validation_status()
+ self.do_checks()
+
+ return self.is_done()
+
+ def _do_checks(self):
+ # XXX convert this to inline
+
+ full_username = self.userNameLineEdit.text()
+ ###########################
+ # 0) check user@domain form
+ ###########################
+
+ def checkusername():
+ if full_username.count('@') != 1:
+ return self.fail(
+ self.tr(
+ "Username must be in the username@provider form."))
+ else:
+ return True
+
+ yield(("head_sentinel", 0), checkusername)
+
+ # XXX I think this is not needed
+ # since we're also checking for the is_signup field.
+ #self.wizard().from_login = True
+
+ username, domain = full_username.split('@')
+ password = self.userPasswordLineEdit.text()
+
+ # We try a call to an authenticated
+ # page here as a mean to catch
+ # srp authentication errors while
+ wizard = self.wizard()
+ eipconfigchecker = wizard.eipconfigchecker()
+
+ ########################
+ # 1) try name resolution
+ ########################
+ # show the frame before going on...
+ QtCore.QMetaObject.invokeMethod(
+ self, "showStepsFrame")
+
+ # Able to contact domain?
+ # can get definition?
+ # two-by-one
+ def resolvedomain():
+ try:
+ eipconfigchecker.fetch_definition(domain=domain)
+
+ # we're using requests here for all
+ # the possible error cases that it catches.
+ except requests.exceptions.ConnectionError as exc:
+ return self.fail(exc.message[1])
+ except requests.exceptions.HTTPError as exc:
+ return self.fail(exc.message)
+ except Exception as exc:
+ # XXX get catchall error msg
+ return self.fail(
+ exc.message)
+
+ yield((self.tr("resolving domain name"), 20), resolvedomain)
+
+ wizard.set_providerconfig(
+ eipconfigchecker.defaultprovider.config)
+
+ ########################
+ # 2) do authentication
+ ########################
+ credentials = username, password
+ pCertChecker = wizard.providercertchecker(
+ domain=domain)
+
+ def validate_credentials():
+ #################
+ # FIXME #BUG #638
+ verify = False
+
+ try:
+ pCertChecker.download_new_client_cert(
+ credentials=credentials,
+ verify=verify)
+
+ except auth.SRPAuthenticationError as exc:
+ return self.fail(
+ self.tr("Authentication error: %s" % exc.message))
+
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ else:
+ return True
+
+ yield(('Validating credentials', 20), validate_credentials)
+
+ self.set_done()
+ yield(("end_sentinel", 0), lambda: None)
+
+ def green_validation_status(self):
+ val = self.validationMsg
+ val.setText(self.tr('Credentials validated.'))
+ val.setStyleSheet(styles.GreenLineEdit)
+
+ def on_checks_validation_ready(self):
+ """
+ after checks
+ """
+ if self.is_done():
+ self.disableFields()
+ self.cleanup_errormsg()
+ self.clean_wizard_errors(self.current_page)
+ # make the user confirm the transition
+ # to next page.
+ self.nextText('&Next')
+ self.nextFocus()
+ self.green_validation_status()
+ self.do_confirm_next = True
diff --git a/src/leap/gui/firstrun/mixins.py b/src/leap/gui/firstrun/mixins.py
new file mode 100644
index 00000000..c4731893
--- /dev/null
+++ b/src/leap/gui/firstrun/mixins.py
@@ -0,0 +1,18 @@
+"""
+mixins used in First Run Wizard
+"""
+
+
+class UserFormMixIn(object):
+
+ def reset_validation_status(self):
+ """
+ empty the validation msg
+ """
+ self.validationMsg.setText('')
+
+ def set_validation_status(self, msg):
+ """
+ set generic validation status
+ """
+ self.validationMsg.setText(msg)
diff --git a/src/leap/gui/firstrun/providerinfo.py b/src/leap/gui/firstrun/providerinfo.py
new file mode 100644
index 00000000..c5b2984c
--- /dev/null
+++ b/src/leap/gui/firstrun/providerinfo.py
@@ -0,0 +1,98 @@
+"""
+Provider Info Page, used in First run Wizard
+"""
+import logging
+
+from PyQt4 import QtGui
+
+from leap.gui.constants import APP_LOGO
+
+logger = logging.getLogger(__name__)
+
+
+class ProviderInfoPage(QtGui.QWizardPage):
+
+ def __init__(self, parent=None):
+ super(ProviderInfoPage, self).__init__(parent)
+
+ self.setTitle(self.tr("Provider Info"))
+ self.setSubTitle(self.tr(
+ "This is what provider says."))
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ self.create_info_panel()
+
+ def create_info_panel(self):
+ # Use stacked widget instead
+ # of reparenting the layout.
+
+ infoWidget = QtGui.QStackedWidget()
+
+ info = QtGui.QWidget()
+ layout = QtGui.QVBoxLayout()
+
+ displayName = QtGui.QLabel("")
+ description = QtGui.QLabel("")
+ enrollment_policy = QtGui.QLabel("")
+
+ # XXX set stylesheet...
+ # prettify a little bit.
+ # bigger fonts and so on...
+
+ # We could use a QFrame here
+
+ layout.addWidget(displayName)
+ layout.addWidget(description)
+ layout.addWidget(enrollment_policy)
+ layout.addStretch(1)
+
+ info.setLayout(layout)
+ infoWidget.addWidget(info)
+
+ pageLayout = QtGui.QVBoxLayout()
+ pageLayout.addWidget(infoWidget)
+ self.setLayout(pageLayout)
+
+ # add refs to self to allow for
+ # updates.
+ # Watch out! Have to get rid of these references!
+ # this should be better handled with signals !!
+ self.displayName = displayName
+ self.description = description
+ self.enrollment_policy = enrollment_policy
+
+ def show_provider_info(self):
+
+ # XXX get multilingual objects
+ # directly from the config object
+
+ lang = "en"
+ pconfig = self.wizard().providerconfig
+
+ dn = pconfig.get('display_name')
+ display_name = dn[lang] if dn else ''
+ domain_name = self.field('provider_domain')
+
+ self.displayName.setText(
+ "<b>%s</b> https://%s" % (display_name, domain_name))
+
+ desc = pconfig.get('description')
+ description_text = desc[lang] if desc else ''
+ self.description.setText(
+ "<i>%s</i>" % description_text)
+
+ enroll = pconfig.get('enrollment_policy')
+ if enroll:
+ self.enrollment_policy.setText(
+ 'enrollment policy: %s' % enroll)
+
+ def nextId(self):
+ wizard = self.wizard()
+ next_ = "providersetupvalidation"
+ return wizard.get_page_index(next_)
+
+ def initializePage(self):
+ self.show_provider_info()
diff --git a/src/leap/gui/firstrun/providerselect.py b/src/leap/gui/firstrun/providerselect.py
new file mode 100644
index 00000000..a4be51a9
--- /dev/null
+++ b/src/leap/gui/firstrun/providerselect.py
@@ -0,0 +1,472 @@
+"""
+Select Provider Page, used in First Run Wizard
+"""
+import logging
+
+import requests
+
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+from leap.base import exceptions as baseexceptions
+#from leap.crypto import certs
+from leap.eip import exceptions as eipexceptions
+from leap.gui.progress import InlineValidationPage
+from leap.gui import styles
+from leap.gui.utils import delay
+from leap.util.web import get_https_domain_and_port
+
+from leap.gui.constants import APP_LOGO
+
+logger = logging.getLogger(__name__)
+
+
+class SelectProviderPage(InlineValidationPage):
+
+ launchChecks = QtCore.pyqtSignal()
+
+ def __init__(self, parent=None, providers=None):
+ super(SelectProviderPage, self).__init__(parent)
+ self.current_page = 'providerselection'
+
+ self.setTitle(self.tr("Enter Provider"))
+ self.setSubTitle(self.tr(
+ "Please enter the domain of the provider you want "
+ "to use for your connection.")
+ )
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ self.did_cert_check = False
+
+ self.is_done = False
+
+ self.setupSteps()
+ self.setupUI()
+
+ self.launchChecks.connect(
+ self.launch_checks)
+
+ self.providerNameEdit.editingFinished.connect(
+ lambda: self.providerCheckButton.setFocus(True))
+
+ def setupUI(self):
+ """
+ initializes the UI
+ """
+ providerNameLabel = QtGui.QLabel("h&ttps://")
+ # note that we expect the bare domain name
+ # we will add the scheme later
+ providerNameEdit = QtGui.QLineEdit()
+ providerNameEdit.cursorPositionChanged.connect(
+ self.reset_validation_status)
+ providerNameLabel.setBuddy(providerNameEdit)
+
+ # add regex validator
+ providerDomainRe = QtCore.QRegExp(r"^[a-z\d_-.]+$")
+ providerNameEdit.setValidator(
+ QtGui.QRegExpValidator(providerDomainRe, self))
+ self.providerNameEdit = providerNameEdit
+
+ # Eventually we will seed a list of
+ # well known providers here.
+
+ #providercombo = QtGui.QComboBox()
+ #if providers:
+ #for provider in providers:
+ #providercombo.addItem(provider)
+ #providerNameSelect = providercombo
+
+ self.registerField("provider_domain*", self.providerNameEdit)
+ #self.registerField('provider_name_index', providerNameSelect)
+
+ validationMsg = QtGui.QLabel("")
+ validationMsg.setStyleSheet(styles.ErrorLabelStyleSheet)
+ self.validationMsg = validationMsg
+ providerCheckButton = QtGui.QPushButton(self.tr("chec&k!"))
+ self.providerCheckButton = providerCheckButton
+
+ # cert info
+
+ # this is used in the callback
+ # for the checkbox changes.
+ # tricky, since the first time came
+ # from the exception message.
+ # should get string from exception too!
+ self.bad_cert_status = self.tr(
+ "Server certificate could not be verified.")
+
+ self.certInfo = QtGui.QLabel("")
+ self.certInfo.setWordWrap(True)
+ self.certWarning = QtGui.QLabel("")
+ self.trustProviderCertCheckBox = QtGui.QCheckBox(
+ "&Trust this provider certificate.")
+
+ self.trustProviderCertCheckBox.stateChanged.connect(
+ self.onTrustCheckChanged)
+ self.providerNameEdit.textChanged.connect(
+ self.onProviderChanged)
+ self.providerCheckButton.clicked.connect(
+ self.onCheckButtonClicked)
+
+ layout = QtGui.QGridLayout()
+ layout.addWidget(validationMsg, 0, 2)
+ layout.addWidget(providerNameLabel, 1, 1)
+ layout.addWidget(providerNameEdit, 1, 2)
+ layout.addWidget(providerCheckButton, 1, 3)
+
+ # add certinfo group
+ # XXX not shown now. should move to validation box.
+ #layout.addWidget(certinfoGroup, 4, 1, 4, 2)
+ #self.certinfoGroup = certinfoGroup
+ #self.certinfoGroup.hide()
+
+ # add validation frame
+ self.setupValidationFrame()
+ layout.addWidget(self.valFrame, 4, 2, 4, 2)
+ self.valFrame.hide()
+
+ self.setLayout(layout)
+
+ # certinfo
+
+ def setupCertInfoGroup(self):
+ # XXX not used now.
+ certinfoGroup = QtGui.QGroupBox(
+ self.tr("Certificate validation"))
+ certinfoLayout = QtGui.QVBoxLayout()
+ certinfoLayout.addWidget(self.certInfo)
+ certinfoLayout.addWidget(self.certWarning)
+ certinfoLayout.addWidget(self.trustProviderCertCheckBox)
+ certinfoGroup.setLayout(certinfoLayout)
+ self.certinfoGroup = self.certinfoGroup
+
+ # progress frame
+
+ def setupValidationFrame(self):
+ qframe = QtGui.QFrame
+ valFrame = qframe()
+ valFrame.setFrameStyle(qframe.NoFrame)
+ valframeLayout = QtGui.QVBoxLayout()
+ zeros = (0, 0, 0, 0)
+ valframeLayout.setContentsMargins(*zeros)
+
+ valframeLayout.addWidget(self.stepsTableWidget)
+ valFrame.setLayout(valframeLayout)
+ self.valFrame = valFrame
+
+ @QtCore.pyqtSlot()
+ def onDisableCheckButton(self):
+ #print 'CHECK BUTTON DISABLED!!!'
+ self.providerCheckButton.setDisabled(True)
+
+ @QtCore.pyqtSlot()
+ def launch_checks(self):
+ self.do_checks()
+
+ def onCheckButtonClicked(self):
+ QtCore.QMetaObject.invokeMethod(
+ self, "onDisableCheckButton")
+
+ QtCore.QMetaObject.invokeMethod(
+ self, "showStepsFrame")
+
+ delay(self, "launch_checks")
+
+ def _do_checks(self):
+ """
+ generator that yields actual checks
+ that are executed in a separate thread
+ """
+
+ wizard = self.wizard()
+ full_domain = self.providerNameEdit.text()
+
+ # we check if we have a port in the domain string.
+ domain, port = get_https_domain_and_port(full_domain)
+ _domain = u"%s:%s" % (domain, port) if port != 443 else unicode(domain)
+
+ netchecker = wizard.netchecker()
+
+ providercertchecker = wizard.providercertchecker()
+ eipconfigchecker = wizard.eipconfigchecker(domain=_domain)
+
+ yield(("head_sentinel", 0), lambda: None)
+
+ ########################
+ # 1) try name resolution
+ ########################
+
+ def namecheck():
+ """
+ in which we check if
+ we are able to name resolve
+ this domain
+ """
+ try:
+ netchecker.check_name_resolution(
+ domain)
+
+ except baseexceptions.LeapException as exc:
+ logger.error(exc.message)
+ return self.fail(exc.usermessage)
+
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ else:
+ return True
+
+ logger.debug('checking name resolution')
+ yield((self.tr("checking domain name"), 20), namecheck)
+
+ #########################
+ # 2) try https connection
+ #########################
+
+ def httpscheck():
+ """
+ in which we check
+ if the provider
+ is offering service over
+ https
+ """
+ try:
+ providercertchecker.is_https_working(
+ "https://%s" % _domain,
+ verify=True)
+
+ except eipexceptions.HttpsBadCertError as exc:
+ logger.debug('exception')
+ return self.fail(exc.usermessage)
+ # XXX skipping for now...
+ ##############################################
+ # We had this validation logic
+ # in the provider selection page before
+ ##############################################
+ #if self.trustProviderCertCheckBox.isChecked():
+ #pass
+ #else:
+ #fingerprint = certs.get_cert_fingerprint(
+ #domain=domain, sep=" ")
+
+ # it's ok if we've trusted this fgprt before
+ #trustedcrts = wizard.trusted_certs
+ #if trustedcrts and \
+ # fingerprint.replace(' ', '') in trustedcrts:
+ #pass
+ #else:
+ # let your user face panick :P
+ #self.add_cert_info(fingerprint)
+ #self.did_cert_check = True
+ #self.completeChanged.emit()
+ #return False
+
+ except baseexceptions.LeapException as exc:
+ return self.fail(exc.usermessage)
+
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ else:
+ return True
+
+ logger.debug('checking https connection')
+ yield((self.tr("checking https connection"), 40), httpscheck)
+
+ ##################################
+ # 3) try download provider info...
+ ##################################
+
+ def fetchinfo():
+ try:
+ # XXX we already set _domain in the initialization
+ # so it should not be needed here.
+ eipconfigchecker.fetch_definition(domain=_domain)
+ wizard.set_providerconfig(
+ eipconfigchecker.defaultprovider.config)
+ except requests.exceptions.SSLError:
+ # XXX we should have catched this before.
+ # but cert checking is broken.
+ return self.fail(self.tr(
+ "Could not get info from provider."))
+ except requests.exceptions.ConnectionError:
+ return self.fail(self.tr(
+ "Could not download provider info "
+ "(refused conn.)."))
+
+ except Exception as exc:
+ return self.fail(
+ self.tr(exc.message))
+ else:
+ return True
+
+ yield((self.tr("fetching provider info"), 80), fetchinfo)
+
+ # done!
+
+ self.is_done = True
+ yield(("end_sentinel", 100), lambda: None)
+
+ def on_checks_validation_ready(self):
+ """
+ called after _do_checks has finished.
+ """
+ self.domain_checked = True
+ self.completeChanged.emit()
+ # let's set focus...
+ if self.is_done:
+ self.wizard().clean_validation_error(self.current_page)
+ nextbutton = self.wizard().button(QtGui.QWizard.NextButton)
+ nextbutton.setFocus()
+ else:
+ self.providerNameEdit.setFocus()
+
+ # cert trust verification
+ # (disabled for now)
+
+ def is_insecure_cert_trusted(self):
+ return self.trustProviderCertCheckBox.isChecked()
+
+ def onTrustCheckChanged(self, state):
+ checked = False
+ if state == 2:
+ checked = True
+
+ if checked:
+ self.reset_validation_status()
+ else:
+ self.set_validation_status(self.bad_cert_status)
+
+ # trigger signal to redraw next button
+ self.completeChanged.emit()
+
+ def add_cert_info(self, certinfo):
+ self.certWarning.setText(
+ "Do you want to <b>trust this provider certificate?</b>")
+ self.certInfo.setText(
+ 'SHA-256 fingerprint: <i>%s</i><br>' % certinfo)
+ self.certInfo.setWordWrap(True)
+ self.certinfoGroup.show()
+
+ def onProviderChanged(self, text):
+ self.is_done = False
+ provider = self.providerNameEdit.text()
+ if provider:
+ self.providerCheckButton.setDisabled(False)
+ else:
+ self.providerCheckButton.setDisabled(True)
+ self.completeChanged.emit()
+
+ def reset_validation_status(self):
+ """
+ empty the validation msg
+ and clean the inline validation widget.
+ """
+ self.validationMsg.setText('')
+ self.steps.removeAllSteps()
+ self.clearTable()
+ self.domain_checked = False
+
+ # pagewizard methods
+
+ def isComplete(self):
+ provider = self.providerNameEdit.text()
+
+ if not self.is_done:
+ return False
+
+ if not provider:
+ return False
+ else:
+ if self.is_insecure_cert_trusted():
+ return True
+ if not self.did_cert_check:
+ if self.is_done:
+ # XXX sure?
+ return True
+ return False
+
+ def populateErrors(self):
+ # XXX could move this to ValidationMixin
+ # with some defaults for the validating fields
+ # (now it only allows one field, manually specified)
+
+ #logger.debug('getting errors')
+ errors = self.wizard().get_validation_error(
+ self.current_page)
+ if errors:
+ bad_str = getattr(self, 'bad_string', None)
+ cur_str = self.providerNameEdit.text()
+ showerr = self.validationMsg.setText
+ markred = lambda: self.providerNameEdit.setStyleSheet(
+ styles.ErrorLineEdit)
+ umarkrd = lambda: self.providerNameEdit.setStyleSheet(
+ styles.RegularLineEdit)
+ if bad_str is None:
+ # first time we fall here.
+ # save the current bad_string value
+ self.bad_string = cur_str
+ showerr(errors)
+ markred()
+ else:
+ # not the first time
+ # XXX hey, this is getting convoluted.
+ # roll out this.
+ # but be careful about all the possibilities
+ # with going back and forth once you
+ # enter a domain.
+ if cur_str == bad_str:
+ showerr(errors)
+ markred()
+ else:
+ if not getattr(self, 'domain_checked', None):
+ showerr('')
+ umarkrd()
+ else:
+ self.bad_string = cur_str
+ showerr(errors)
+
+ def cleanup_errormsg(self):
+ """
+ we reset bad_string to None
+ should be called before leaving the page
+ """
+ self.bad_string = None
+ self.domain_checked = False
+
+ def paintEvent(self, event):
+ """
+ we hook our populate errors
+ on paintEvent because we need it to catch
+ when user enters the page coming from next,
+ and initializePage does not cover that case.
+ Maybe there's a better event to hook upon.
+ """
+ super(SelectProviderPage, self).paintEvent(event)
+ self.populateErrors()
+
+ def initializePage(self):
+ self.validationMsg.setText('')
+ if hasattr(self, 'certinfoGroup'):
+ # XXX remove ?
+ self.certinfoGroup.hide()
+ self.is_done = False
+ self.providerCheckButton.setDisabled(True)
+ self.valFrame.hide()
+ self.steps.removeAllSteps()
+ self.clearTable()
+
+ def validatePage(self):
+ # some cleanup before we leave the page
+ self.cleanup_errormsg()
+
+ # go
+ return True
+
+ def nextId(self):
+ wizard = self.wizard()
+ if not wizard:
+ return
+ return wizard.get_page_index('providerinfo')
diff --git a/src/leap/gui/firstrun/providersetup.py b/src/leap/gui/firstrun/providersetup.py
new file mode 100644
index 00000000..1a362794
--- /dev/null
+++ b/src/leap/gui/firstrun/providersetup.py
@@ -0,0 +1,174 @@
+"""
+Provider Setup Validation Page,
+used if First Run Wizard
+"""
+import logging
+
+from PyQt4 import QtGui
+
+from leap.base import exceptions as baseexceptions
+from leap.gui.progress import ValidationPage
+
+from leap.gui.constants import APP_LOGO
+
+logger = logging.getLogger(__name__)
+
+
+class ProviderSetupValidationPage(ValidationPage):
+ def __init__(self, parent=None):
+ super(ProviderSetupValidationPage, self).__init__(parent)
+ self.current_page = "providersetupvalidation"
+
+ # XXX needed anymore?
+ is_signup = self.field("is_signup")
+ self.is_signup = is_signup
+
+ self.setTitle(self.tr("Provider setup"))
+ self.setSubTitle(
+ self.tr("Doing autoconfig."))
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ def _do_checks(self):
+ """
+ generator that yields actual checks
+ that are executed in a separate thread
+ """
+
+ full_domain = self.field('provider_domain')
+ wizard = self.wizard()
+ pconfig = wizard.providerconfig
+
+ #pCertChecker = wizard.providercertchecker
+ #certchecker = pCertChecker(domain=full_domain)
+ pCertChecker = wizard.providercertchecker(
+ domain=full_domain)
+
+ yield(("head_sentinel", 0), lambda: None)
+
+ ########################
+ # 1) fetch ca cert
+ ########################
+
+ def fetchcacert():
+ if pconfig:
+ ca_cert_uri = pconfig.get('ca_cert_uri').geturl()
+ else:
+ ca_cert_uri = None
+
+ # XXX check scheme == "https"
+ # XXX passing verify == False because
+ # we have trusted right before.
+ # We should check it's the same domain!!!
+ # (Check with the trusted fingerprints dict
+ # or something smart)
+ try:
+ pCertChecker.download_ca_cert(
+ uri=ca_cert_uri,
+ verify=False)
+
+ except baseexceptions.LeapException as exc:
+ logger.error(exc.message)
+ # XXX this should be _ method
+ return self.fail(self.tr(exc.usermessage))
+
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ else:
+ return True
+
+ yield((self.tr('Fetching CA certificate'), 30),
+ fetchcacert)
+
+ #########################
+ # 2) check CA fingerprint
+ #########################
+
+ def checkcafingerprint():
+ # XXX get the real thing!!!
+ pass
+ #ca_cert_fingerprint = pconfig.get('ca_cert_fingerprint', None)
+
+ # XXX get fingerprint dict (types)
+ #sha256_fpr = ca_cert_fingerprint.split('=')[1]
+
+ #validate_fpr = pCertChecker.check_ca_cert_fingerprint(
+ #fingerprint=sha256_fpr)
+ #if not validate_fpr:
+ # XXX update validationMsg
+ # should catch exception
+ #return False
+
+ yield((self.tr("Checking CA fingerprint"), 60),
+ checkcafingerprint)
+
+ #########################
+ # 2) check CA fingerprint
+ #########################
+
+ def validatecacert():
+ pass
+ #api_uri = pconfig.get('api_uri', None)
+ #try:
+ #api_cert_verified = pCertChecker.verify_api_https(api_uri)
+ #except requests.exceptions.SSLError as exc:
+ #logger.error('BUG #638. %s' % exc.message)
+ # XXX RAISE! See #638
+ # bypassing until the hostname is fixed.
+ # We probably should raise yet-another-warning
+ # here saying user that the hostname "XX.XX.XX.XX' does not
+ # match 'foo.bar.baz'
+ #api_cert_verified = True
+
+ #if not api_cert_verified:
+ # XXX update validationMsg
+ # should catch exception
+ #return False
+
+ #???
+ #ca_cert_path = checker.ca_cert_path
+
+ yield((self.tr('Validating api certificate'), 90), validatecacert)
+
+ self.set_done()
+ yield(('end_sentinel', 100), lambda: None)
+
+ def on_checks_validation_ready(self):
+ """
+ called after _do_checks has finished
+ (connected to checker thread finished signal)
+ """
+ prevpage = "providerselection" if self.is_signup else "login"
+ wizard = self.wizard()
+
+ if self.errors:
+ logger.debug('going back with errors')
+ name, first_error = self.pop_first_error()
+ wizard.set_validation_error(
+ prevpage,
+ first_error)
+ # XXX don't go back, signal error
+ #self.go_back()
+ else:
+ logger.debug('should be going next, wait on user')
+ #self.go_next()
+
+ def nextId(self):
+ wizard = self.wizard()
+ if not wizard:
+ return
+ is_signup = self.field('is_signup')
+ if is_signup is True:
+ next_ = 'signup'
+ if is_signup is False:
+ # XXX bad name. change to connect again.
+ next_ = 'signupvalidation'
+ return wizard.get_page_index(next_)
+
+ def initializePage(self):
+ super(ProviderSetupValidationPage, self).initializePage()
+ self.set_undone()
+ self.completeChanged.emit()
diff --git a/src/leap/gui/firstrun/register.py b/src/leap/gui/firstrun/register.py
new file mode 100644
index 00000000..e85723cb
--- /dev/null
+++ b/src/leap/gui/firstrun/register.py
@@ -0,0 +1,368 @@
+"""
+Register User Page, used in First Run Wizard
+"""
+import json
+import logging
+import socket
+
+import requests
+
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+from leap.gui.firstrun.mixins import UserFormMixIn
+
+logger = logging.getLogger(__name__)
+
+from leap.base import auth
+from leap.gui import styles
+from leap.gui.constants import APP_LOGO, BARE_USERNAME_REGEX
+from leap.gui.progress import InlineValidationPage
+from leap.gui.styles import ErrorLabelStyleSheet
+
+
+class RegisterUserPage(InlineValidationPage, UserFormMixIn):
+
+ def __init__(self, parent=None):
+
+ super(RegisterUserPage, self).__init__(parent)
+ self.current_page = "signup"
+
+ self.setTitle(self.tr("Sign Up"))
+ # subtitle is set in the initializePage
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ # commit page means there's no way back after this...
+ # XXX should change the text on the "commit" button...
+ self.setCommitPage(True)
+
+ self.setupSteps()
+ self.setupUI()
+ self.do_confirm_next = False
+ self.focused_field = False
+
+ def setupUI(self):
+ userNameLabel = QtGui.QLabel("User &name:")
+ userNameLineEdit = QtGui.QLineEdit()
+ userNameLineEdit.cursorPositionChanged.connect(
+ self.reset_validation_status)
+ userNameLabel.setBuddy(userNameLineEdit)
+
+ # let's add regex validator
+ usernameRe = QtCore.QRegExp(BARE_USERNAME_REGEX)
+ userNameLineEdit.setValidator(
+ QtGui.QRegExpValidator(usernameRe, self))
+ self.userNameLineEdit = userNameLineEdit
+
+ userPasswordLabel = QtGui.QLabel("&Password:")
+ self.userPasswordLineEdit = QtGui.QLineEdit()
+ self.userPasswordLineEdit.setEchoMode(
+ QtGui.QLineEdit.Password)
+ userPasswordLabel.setBuddy(self.userPasswordLineEdit)
+
+ userPassword2Label = QtGui.QLabel("Password (again):")
+ self.userPassword2LineEdit = QtGui.QLineEdit()
+ self.userPassword2LineEdit.setEchoMode(
+ QtGui.QLineEdit.Password)
+ userPassword2Label.setBuddy(self.userPassword2LineEdit)
+
+ rememberPasswordCheckBox = QtGui.QCheckBox(
+ "&Remember username and password.")
+ rememberPasswordCheckBox.setChecked(True)
+
+ self.registerField('userName*', self.userNameLineEdit)
+ self.registerField('userPassword*', self.userPasswordLineEdit)
+ self.registerField('userPassword2*', self.userPassword2LineEdit)
+
+ # XXX missing password confirmation
+ # XXX validator!
+
+ self.registerField('rememberPassword', rememberPasswordCheckBox)
+
+ layout = QtGui.QGridLayout()
+ layout.setColumnMinimumWidth(0, 20)
+
+ validationMsg = QtGui.QLabel("")
+ validationMsg.setStyleSheet(ErrorLabelStyleSheet)
+
+ self.validationMsg = validationMsg
+
+ layout.addWidget(validationMsg, 0, 3)
+ layout.addWidget(userNameLabel, 1, 0)
+ layout.addWidget(self.userNameLineEdit, 1, 3)
+ layout.addWidget(userPasswordLabel, 2, 0)
+ layout.addWidget(userPassword2Label, 3, 0)
+ layout.addWidget(self.userPasswordLineEdit, 2, 3)
+ layout.addWidget(self.userPassword2LineEdit, 3, 3)
+ layout.addWidget(rememberPasswordCheckBox, 4, 3, 4, 4)
+
+ # add validation frame
+ self.setupValidationFrame()
+ layout.addWidget(self.valFrame, 5, 2, 5, 2)
+ self.valFrame.hide()
+
+ self.setLayout(layout)
+ self.commitText("Sign up!")
+
+ # commit button
+
+ def commitText(self, text):
+ # change "commit" button text
+ self.setButtonText(
+ QtGui.QWizard.CommitButton, text)
+
+ @property
+ def commitButton(self):
+ return self.wizard().button(QtGui.QWizard.CommitButton)
+
+ def commitFocus(self):
+ self.commitButton.setFocus()
+
+ def disableCommitButton(self):
+ self.commitButton.setDisabled(True)
+
+ def disableFields(self):
+ for field in (self.userNameLineEdit,
+ self.userPasswordLineEdit,
+ self.userPassword2LineEdit):
+ field.setDisabled(True)
+
+ # error painting
+
+ def markRedAndGetFocus(self, field):
+ field.setStyleSheet(styles.ErrorLineEdit)
+ if not self.focused_field:
+ self.focused_field = True
+ field.setFocus(QtCore.Qt.OtherFocusReason)
+
+ def markRegular(self, field):
+ field.setStyleSheet(styles.RegularLineEdit)
+
+ def populateErrors(self):
+ def showerr(text):
+ self.validationMsg.setText(text)
+ err_lower = text.lower()
+ if "username" in err_lower:
+ self.markRedAndGetFocus(
+ self.userNameLineEdit)
+ if "password" in err_lower:
+ self.markRedAndGetFocus(
+ self.userPasswordLineEdit)
+
+ def unmarkred():
+ for field in (self.userNameLineEdit,
+ self.userPasswordLineEdit,
+ self.userPassword2LineEdit):
+ self.markRegular(field)
+
+ errors = self.wizard().get_validation_error(
+ self.current_page)
+ if errors:
+ bad_str = getattr(self, 'bad_string', None)
+ cur_str = self.userNameLineEdit.text()
+ #prev_er = getattr(self, 'prevalidation_error', None)
+
+ if bad_str is None:
+ # first time we fall here.
+ # save the current bad_string value
+ self.bad_string = cur_str
+ showerr(errors)
+ else:
+ #if prev_er:
+ #showerr(prev_er)
+ #return
+ # not the first time
+ if cur_str == bad_str:
+ showerr(errors)
+ else:
+ self.focused_field = False
+ showerr('')
+ unmarkred()
+ else:
+ # no errors
+ self.focused_field = False
+ unmarkred()
+
+ def cleanup_errormsg(self):
+ """
+ we reset bad_string to None
+ should be called before leaving the page
+ """
+ self.bad_string = None
+
+ def paintEvent(self, event):
+ """
+ we hook our populate errors
+ on paintEvent because we need it to catch
+ when user enters the page coming from next,
+ and initializePage does not cover that case.
+ Maybe there's a better event to hook upon.
+ """
+ super(RegisterUserPage, self).paintEvent(event)
+ self.populateErrors()
+
+ def _do_checks(self):
+ """
+ generator that yields actual checks
+ that are executed in a separate thread
+ """
+ provider = self.field('provider_domain')
+ username = self.userNameLineEdit.text()
+ password = self.userPasswordLineEdit.text()
+ password2 = self.userPassword2LineEdit.text()
+
+ def checkpass():
+ # we better have here
+ # some call to a password checker...
+ # to assess strenght and avoid silly stuff.
+
+ if password != password2:
+ return self.fail(self.tr('Password does not match..'))
+
+ if len(password) < 6:
+ #self.set_prevalidation_error('Password too short.')
+ return self.fail(self.tr('Password too short.'))
+
+ if password == "123456":
+ # joking, but not too much.
+ #self.set_prevalidation_error('Password too obvious.')
+ return self.fail(self.tr('Password too obvious.'))
+
+ # go
+ return True
+
+ yield(("head_sentinel", 0), checkpass)
+
+ # XXX should emit signal for .show the frame!
+ # XXX HERE!
+
+ ##################################################
+ # 1) register user
+ ##################################################
+
+ # show the frame before going on...
+ QtCore.QMetaObject.invokeMethod(
+ self, "showStepsFrame")
+
+ def register():
+ # XXX FIXME!
+ verify = False
+
+ signup = auth.LeapSRPRegister(
+ schema="https",
+ provider=provider,
+ verify=verify)
+ try:
+ ok, req = signup.register_user(
+ username, password)
+
+ except socket.timeout:
+ return self.fail(
+ self.tr("Error connecting to provider (timeout)"))
+
+ except requests.exceptions.ConnectionError as exc:
+ logger.error(exc.message)
+ return self.fail(
+ self.tr('Error Connecting to provider (connerr).'))
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ # XXX check for != OK instead???
+
+ if req.status_code in (404, 500):
+ return self.fail(
+ self.tr(
+ "Error during registration (%s)") % req.status_code)
+
+ validation_msgs = json.loads(req.content)
+ errors = validation_msgs.get('errors', None)
+ logger.debug('validation errors: %s' % validation_msgs)
+
+ if errors and errors.get('login', None):
+ # XXX this sometimes catch the blank username
+ # but we're not allowing that (soon)
+ return self.fail(
+ self.tr('Username not available.'))
+
+ logger.debug('registering user')
+ yield(("registering with provider", 40), register)
+
+ self.set_done()
+ yield(("end_sentinel", 0), lambda: None)
+
+ def on_checks_validation_ready(self):
+ """
+ after checks
+ """
+ if self.is_done():
+ self.disableFields()
+ self.cleanup_errormsg()
+ self.clean_wizard_errors(self.current_page)
+ # make the user confirm the transition
+ # to next page.
+ self.commitText('Connect!')
+ self.commitFocus()
+ self.green_validation_status()
+ self.do_confirm_next = True
+
+ def green_validation_status(self):
+ val = self.validationMsg
+ val.setText(self.tr('Registration succeeded!'))
+ val.setStyleSheet(styles.GreenLineEdit)
+
+ def reset_validation_status(self):
+ """
+ empty the validation msg
+ and clean the inline validation widget.
+ """
+ self.validationMsg.setText('')
+ self.steps.removeAllSteps()
+ self.clearTable()
+
+ # pagewizard methods
+
+ def validatePage(self):
+ """
+ if not register done, do checks.
+ if done, wait for click.
+ """
+ self.disableCommitButton()
+ self.cleanup_errormsg()
+ self.clean_wizard_errors(self.current_page)
+
+ # After a successful validation
+ # (ie, success register with server)
+ # we change the commit button text
+ # and set this flag to True.
+ if self.do_confirm_next:
+ return True
+
+ if not self.is_done():
+ # calls checks, which after successful
+ # execution will call on_checks_validation_ready
+ self.reset_validation_status()
+ self.do_checks()
+
+ return self.is_done()
+
+ def initializePage(self):
+ """
+ inits wizard page
+ """
+ provider = self.field('provider_domain')
+ self.setSubTitle(
+ self.tr("Register a new user with provider %s.") %
+ provider)
+ self.validationMsg.setText('')
+ self.userPassword2LineEdit.setText('')
+ self.valFrame.hide()
+
+ def nextId(self):
+ wizard = self.wizard()
+ if not wizard:
+ return
+ # XXX this should be called connect
+ return wizard.get_page_index('signupvalidation')
diff --git a/src/leap/gui/firstrun/regvalidation.py b/src/leap/gui/firstrun/regvalidation.py
new file mode 100644
index 00000000..0e67834b
--- /dev/null
+++ b/src/leap/gui/firstrun/regvalidation.py
@@ -0,0 +1,204 @@
+"""
+Provider Setup Validation Page,
+used in First Run Wizard
+"""
+# XXX This page is called regvalidation
+# but it's implementing functionality in the former
+# connect page.
+# We should remame it to connect again, when we integrate
+# the login branch of the wizard.
+
+import logging
+#import json
+#import socket
+
+from PyQt4 import QtGui
+
+#import requests
+
+from leap.gui.progress import ValidationPage
+from leap.util.web import get_https_domain_and_port
+
+from leap.base import auth
+from leap.gui.constants import APP_LOGO
+
+logger = logging.getLogger(__name__)
+
+
+class RegisterUserValidationPage(ValidationPage):
+
+ def __init__(self, parent=None):
+ super(RegisterUserValidationPage, self).__init__(parent)
+
+ title = "Connecting..."
+ # XXX uh... really?
+ subtitle = "Checking connection with provider."
+
+ self.setTitle(title)
+ self.setSubTitle(subtitle)
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ def _do_checks(self, update_signal=None):
+ """
+ executes actual checks in a separate thread
+
+ we initialize the srp protocol register
+ and try to register user.
+ """
+ wizard = self.wizard()
+ full_domain = self.field('provider_domain')
+ domain, port = get_https_domain_and_port(full_domain)
+ _domain = u"%s:%s" % (domain, port) if port != 443 else unicode(domain)
+
+ # FIXME #BUG 638 FIXME FIXME FIXME
+ verify = False # !!!!!!!!!!!!!!!!
+ # FIXME #BUG 638 FIXME FIXME FIXME
+
+ ###########################################
+ # Set Credentials.
+ # username and password are in different fields
+ # if they were stored in log_in or sign_up pages.
+ is_signup = self.field("is_signup")
+
+ unamek_base = 'userName'
+ passwk_base = 'userPassword'
+ unamek = 'login_%s' % unamek_base if not is_signup else unamek_base
+ passwk = 'login_%s' % passwk_base if not is_signup else passwk_base
+
+ username = self.field(unamek)
+ password = self.field(passwk)
+ credentials = username, password
+
+ eipconfigchecker = wizard.eipconfigchecker(domain=_domain)
+ #XXX change for _domain (sanitized)
+ pCertChecker = wizard.providercertchecker(
+ domain=full_domain)
+
+ yield(("head_sentinel", 0), lambda: None)
+
+ ##################################################
+ # 1) fetching eip service config
+ ##################################################
+ def fetcheipconf():
+ try:
+ eipconfigchecker.fetch_eip_service_config(
+ domain=full_domain)
+
+ # XXX get specific exception
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ yield((self.tr("Fetching provider config..."), 40),
+ fetcheipconf)
+
+ ##################################################
+ # 2) getting client certificate
+ ##################################################
+
+ def fetcheipcert():
+ try:
+ pCertChecker.download_new_client_cert(
+ credentials=credentials,
+ verify=verify)
+
+ except auth.SRPAuthenticationError as exc:
+ return self.fail(self.tr(
+ "Authentication error: %s" % exc.message))
+ else:
+ return True
+
+ yield((self.tr("Fetching eip certificate"), 80),
+ fetcheipcert)
+
+ ################
+ # end !
+ ################
+ self.set_done()
+ yield(("end_sentinel", 100), lambda: None)
+
+ def on_checks_validation_ready(self):
+ """
+ called after _do_checks has finished
+ (connected to checker thread finished signal)
+ """
+ # this should be called CONNECT PAGE AGAIN.
+ # here we go! :)
+ full_domain = self.field('provider_domain')
+ domain, port = get_https_domain_and_port(full_domain)
+ _domain = u"%s:%s" % (domain, port) if port != 443 else unicode(domain)
+ self.run_eip_checks_for_provider_and_connect(_domain)
+
+ def run_eip_checks_for_provider_and_connect(self, domain):
+ wizard = self.wizard()
+ conductor = wizard.conductor
+ start_eip_signal = getattr(
+ wizard,
+ 'start_eipconnection_signal', None)
+
+ if conductor:
+ conductor.set_provider_domain(domain)
+ conductor.run_checks()
+ self.conductor = conductor
+ errors = self.eip_error_check()
+ if not errors and start_eip_signal:
+ start_eip_signal.emit()
+
+ else:
+ logger.warning(
+ "No conductor found. This means that "
+ "probably the wizard has been launched "
+ "in an stand-alone way.")
+
+ # XXX look for a better place to signal
+ # we are done.
+ # We could probably have a fake validatePage
+ # that checks if the domain transfer has been
+ # done to conductor object, triggers the start_signal
+ # and does the go_next()
+ self.set_done()
+
+ def eip_error_check(self):
+ """
+ a version of the main app error checker,
+ but integrated within the connecting page of the wizard.
+ consumes the conductor error queue.
+ pops errors, and add those to the wizard page
+ """
+ logger.debug('eip error check from connecting page')
+ errq = self.conductor.error_queue
+ # XXX missing!
+
+ def _do_validation(self):
+ """
+ called after _do_checks has finished
+ (connected to checker thread finished signal)
+ """
+ is_signup = self.field("is_signup")
+ prevpage = "signup" if is_signup else "login"
+
+ wizard = self.wizard()
+ if self.errors:
+ logger.debug('going back with errors')
+ logger.error(self.errors)
+ name, first_error = self.pop_first_error()
+ wizard.set_validation_error(
+ prevpage,
+ first_error)
+ self.go_back()
+ else:
+ logger.debug('should go next, wait for user to click next')
+ #self.go_next()
+
+ def nextId(self):
+ wizard = self.wizard()
+ if not wizard:
+ return
+ return wizard.get_page_index('lastpage')
+
+ def initializePage(self):
+ super(RegisterUserValidationPage, self).initializePage()
+ self.set_undone()
+ self.completeChanged.emit()
diff --git a/src/leap/gui/firstrun/tests/integration/fake_provider.py b/src/leap/gui/firstrun/tests/integration/fake_provider.py
new file mode 100755
index 00000000..33ee0ee6
--- /dev/null
+++ b/src/leap/gui/firstrun/tests/integration/fake_provider.py
@@ -0,0 +1,295 @@
+#!/usr/bin/env python
+"""A server faking some of the provider resources and apis,
+used for testing Leap Client requests
+
+It needs that you create a subfolder named 'certs',
+and that you place the following files:
+
+[ ] certs/leaptestscert.pem
+[ ] certs/leaptestskey.pem
+[ ] certs/cacert.pem
+[ ] certs/openvpn.pem
+
+[ ] provider.json
+[ ] eip-service.json
+"""
+# XXX NOTE: intended for manual debug.
+# I intend to include this as a regular test after 0.2.0 release
+# (so we can add twisted as a dep there)
+import binascii
+import json
+import os
+import sys
+
+# python SRP LIB (! important MUST be >=1.0.1 !)
+import srp
+
+# GnuTLS Example -- is not working as expected
+from gnutls import crypto
+from gnutls.constants import COMP_LZO, COMP_DEFLATE, COMP_NULL
+from gnutls.interfaces.twisted import X509Credentials
+
+# Going with OpenSSL as a workaround instead
+# But we DO NOT want to introduce this dependency.
+from OpenSSL import SSL
+
+from zope.interface import Interface, Attribute, implements
+
+from twisted.web.server import Site
+from twisted.web.static import File
+from twisted.web.resource import Resource
+from twisted.internet import reactor
+
+# See
+# http://twistedmatrix.com/documents/current/web/howto/web-in-60/index.htmln
+# for more examples
+
+"""
+Testing the FAKE_API:
+#####################
+
+ 1) register an user
+ >> curl -d "user[login]=me" -d "user[password_salt]=foo" \
+ -d "user[password_verifier]=beef" http://localhost:8000/1/users.json
+ << {"errors": null}
+
+ 2) check that if you try to register again, it will fail:
+ >> curl -d "user[login]=me" -d "user[password_salt]=foo" \
+ -d "user[password_verifier]=beef" http://localhost:8000/1/users.json
+ << {"errors": {"login": "already taken!"}}
+
+"""
+
+# Globals to mock user/sessiondb
+
+USERDB = {}
+SESSIONDB = {}
+
+
+safe_unhexlify = lambda x: binascii.unhexlify(x) \
+ if (len(x) % 2 == 0) else binascii.unhexlify('0' + x)
+
+
+class IUser(Interface):
+ login = Attribute("User login.")
+ salt = Attribute("Password salt.")
+ verifier = Attribute("Password verifier.")
+ session = Attribute("Session.")
+ svr = Attribute("Server verifier.")
+
+
+class User(object):
+ implements(IUser)
+
+ def __init__(self, login, salt, verifier):
+ self.login = login
+ self.salt = salt
+ self.verifier = verifier
+ self.session = None
+
+ def set_server_verifier(self, svr):
+ self.svr = svr
+
+ def set_session(self, session):
+ SESSIONDB[session] = self
+ self.session = session
+
+
+class FakeUsers(Resource):
+ def __init__(self, name):
+ self.name = name
+
+ def render_POST(self, request):
+ args = request.args
+
+ login = args['user[login]'][0]
+ salt = args['user[password_salt]'][0]
+ verifier = args['user[password_verifier]'][0]
+
+ if login in USERDB:
+ return "%s\n" % json.dumps(
+ {'errors': {'login': 'already taken!'}})
+
+ print login, verifier, salt
+ user = User(login, salt, verifier)
+ USERDB[login] = user
+ return json.dumps({'errors': None})
+
+
+def get_user(request):
+ login = request.args.get('login')
+ if login:
+ user = USERDB.get(login[0], None)
+ if user:
+ return user
+
+ session = request.getSession()
+ user = SESSIONDB.get(session, None)
+ return user
+
+
+class FakeSession(Resource):
+ def __init__(self, name):
+ self.name = name
+
+ def render_GET(self, request):
+ return "%s\n" % json.dumps({'errors': None})
+
+ def render_POST(self, request):
+
+ user = get_user(request)
+
+ if not user:
+ # XXX get real error from demo provider
+ return json.dumps({'errors': 'no such user'})
+
+ A = request.args['A'][0]
+
+ _A = safe_unhexlify(A)
+ _salt = safe_unhexlify(user.salt)
+ _verifier = safe_unhexlify(user.verifier)
+
+ svr = srp.Verifier(
+ user.login,
+ _salt,
+ _verifier,
+ _A,
+ hash_alg=srp.SHA256,
+ ng_type=srp.NG_1024)
+
+ s, B = svr.get_challenge()
+
+ _B = binascii.hexlify(B)
+
+ print 'login = %s' % user.login
+ print 'salt = %s' % user.salt
+ print 'len(_salt) = %s' % len(_salt)
+ print 'vkey = %s' % user.verifier
+ print 'len(vkey) = %s' % len(_verifier)
+ print 's = %s' % binascii.hexlify(s)
+ print 'B = %s' % _B
+ print 'len(B) = %s' % len(_B)
+
+ session = request.getSession()
+ user.set_session(session)
+ user.set_server_verifier(svr)
+
+ # yep, this is tricky.
+ # some things are *already* unhexlified.
+ data = {
+ 'salt': user.salt,
+ 'B': _B,
+ 'errors': None}
+
+ return json.dumps(data)
+
+ def render_PUT(self, request):
+
+ # XXX check session???
+ user = get_user(request)
+
+ if not user:
+ print 'NO USER'
+ return json.dumps({'errors': 'no such user'})
+
+ data = request.content.read()
+ auth = data.split("client_auth=")
+ M = auth[1] if len(auth) > 1 else None
+ # if not H, return
+ if not M:
+ return json.dumps({'errors': 'no M proof passed by client'})
+
+ svr = user.svr
+ HAMK = svr.verify_session(binascii.unhexlify(M))
+ if HAMK is None:
+ print 'verification failed!!!'
+ raise Exception("Authentication failed!")
+ #import ipdb;ipdb.set_trace()
+
+ assert svr.authenticated()
+ print "***"
+ print 'server authenticated user SRP!'
+ print "***"
+
+ return json.dumps(
+ {'M2': binascii.hexlify(HAMK), 'errors': None})
+
+
+class API_Sessions(Resource):
+ def getChild(self, name, request):
+ return FakeSession(name)
+
+
+def get_certs_path():
+ script_path = os.path.realpath(os.path.dirname(sys.argv[0]))
+ certs_path = os.path.join(script_path, 'certs')
+ return certs_path
+
+
+def get_TLS_credentials():
+ # XXX this is giving errors
+ # XXX REview! We want to use gnutls!
+ certs_path = get_certs_path()
+
+ cert = crypto.X509Certificate(
+ open(certs_path + '/leaptestscert.pem').read())
+ key = crypto.X509PrivateKey(
+ open(certs_path + '/leaptestskey.pem').read())
+ ca = crypto.X509Certificate(
+ open(certs_path + '/cacert.pem').read())
+ #crl = crypto.X509CRL(open(certs_path + '/crl.pem').read())
+ #cred = crypto.X509Credentials(cert, key, [ca], [crl])
+ cred = X509Credentials(cert, key, [ca])
+ cred.verify_peer = True
+ cred.session_params.compressions = (COMP_LZO, COMP_DEFLATE, COMP_NULL)
+ return cred
+
+
+class OpenSSLServerContextFactory:
+ # XXX workaround for broken TLS interface
+ # from gnuTLS.
+
+ def getContext(self):
+ """Create an SSL context.
+ This is a sample implementation that loads a certificate from a file
+ called 'server.pem'."""
+ certs_path = get_certs_path()
+
+ ctx = SSL.Context(SSL.SSLv23_METHOD)
+ ctx.use_certificate_file(certs_path + '/leaptestscert.pem')
+ ctx.use_privatekey_file(certs_path + '/leaptestskey.pem')
+ return ctx
+
+
+if __name__ == "__main__":
+
+ from twisted.python import log
+ log.startLogging(sys.stdout)
+
+ root = Resource()
+ root.putChild("provider.json", File("./provider.json"))
+ config = Resource()
+ config.putChild(
+ "eip-service.json",
+ File("./eip-service.json"))
+ apiv1 = Resource()
+ apiv1.putChild("config", config)
+ apiv1.putChild("sessions.json", API_Sessions())
+ apiv1.putChild("users.json", FakeUsers(None))
+ apiv1.putChild("cert", File(get_certs_path() + '/openvpn.pem'))
+ root.putChild("1", apiv1)
+
+ cred = get_TLS_credentials()
+
+ factory = Site(root)
+
+ # regular http (for debugging with curl)
+ reactor.listenTCP(8000, factory)
+
+ # TLS with gnutls --- seems broken :(
+ #reactor.listenTLS(8003, factory, cred)
+
+ # OpenSSL
+ reactor.listenSSL(8443, factory, OpenSSLServerContextFactory())
+
+ reactor.run()
diff --git a/src/leap/gui/firstrun/wizard.py b/src/leap/gui/firstrun/wizard.py
new file mode 100755
index 00000000..9b77b877
--- /dev/null
+++ b/src/leap/gui/firstrun/wizard.py
@@ -0,0 +1,286 @@
+#!/usr/bin/env python
+import logging
+
+import sip
+sip.setapi('QString', 2)
+sip.setapi('QVariant', 2)
+
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+from leap.base import checks as basechecks
+from leap.crypto import leapkeyring
+from leap.eip import checks as eipchecks
+
+from leap.gui import firstrun
+
+from leap.gui import mainwindow_rc
+
+try:
+ from collections import OrderedDict
+except ImportError:
+ # We must be in 2.6
+ from leap.util.dicts import OrderedDict
+
+logger = logging.getLogger(__name__)
+
+"""
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+Work in progress!
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+This wizard still needs to be refactored out.
+
+TODO-ish:
+
+[X] Break file in wizard / pages files (and its own folder).
+[ ] Separate presentation from logic.
+[ ] Have a "manager" class for connections, that can be
+ dep-injected for testing.
+[ ] Document signals used / expected.
+[ ] Separate style from widgets.
+[ ] Fix TOFU Widget for provider cert.
+[X] Refactor widgets out.
+[ ] Follow more MVC style.
+[ ] Maybe separate "first run wizard" into different wizards
+ that share some of the pages?
+"""
+
+
+class FirstRunWizard(QtGui.QWizard):
+
+ def __init__(
+ self,
+ conductor_instance,
+ parent=None,
+ eip_username=None,
+ providers=None,
+ success_cb=None, is_provider_setup=False,
+ trusted_certs=None,
+ netchecker=basechecks.LeapNetworkChecker,
+ providercertchecker=eipchecks.ProviderCertChecker,
+ eipconfigchecker=eipchecks.EIPConfigChecker,
+ start_eipconnection_signal=None,
+ eip_statuschange_signal=None,
+ debug_server=None,
+ quitcallback=None):
+ super(FirstRunWizard, self).__init__(
+ parent,
+ QtCore.Qt.WindowStaysOnTopHint)
+
+ # we keep a reference to the conductor
+ # to be able to launch eip checks and connection
+ # in the connection page, before the wizard has ended.
+ self.conductor = conductor_instance
+
+ self.eip_username = eip_username
+ self.providers = providers
+
+ # success callback
+ self.success_cb = success_cb
+
+ # is provider setup?
+ self.is_provider_setup = is_provider_setup
+
+ # a dict with trusted fingerprints
+ # in the form {'nospacesfingerprint': ['host1', 'host2']}
+ self.trusted_certs = trusted_certs
+
+ # Checkers
+ self.netchecker = netchecker
+ self.providercertchecker = providercertchecker
+ self.eipconfigchecker = eipconfigchecker
+
+ # debug server
+ self.debug_server = debug_server
+
+ # Signals
+ # will be emitted in connecting page
+ self.start_eipconnection_signal = start_eipconnection_signal
+ self.eip_statuschange_signal = eip_statuschange_signal
+
+ if quitcallback is not None:
+ self.button(
+ QtGui.QWizard.CancelButton).clicked.connect(
+ quitcallback)
+
+ self.providerconfig = None
+ # previously registered
+ # if True, jumps to LogIn page.
+ # by setting 1st page??
+ #self.is_previously_registered = is_previously_registered
+ # XXX ??? ^v
+ self.is_previously_registered = bool(self.eip_username)
+ self.from_login = False
+
+ pages_dict = OrderedDict((
+ ('intro', firstrun.intro.IntroPage),
+ ('providerselection',
+ firstrun.providerselect.SelectProviderPage),
+ ('login', firstrun.login.LogInPage),
+ ('providerinfo', firstrun.providerinfo.ProviderInfoPage),
+ ('providersetupvalidation',
+ firstrun.providersetup.ProviderSetupValidationPage),
+ ('signup', firstrun.register.RegisterUserPage),
+ ('signupvalidation',
+ firstrun.regvalidation.RegisterUserValidationPage),
+ ('connecting', firstrun.connect.ConnectingPage),
+ ('lastpage', firstrun.last.LastPage)
+ ))
+ self.add_pages_from_dict(pages_dict)
+
+ self.validation_errors = {}
+
+ self.setPixmap(
+ QtGui.QWizard.BannerPixmap,
+ QtGui.QPixmap(':/images/banner.png'))
+ self.setPixmap(
+ QtGui.QWizard.BackgroundPixmap,
+ QtGui.QPixmap(':/images/background.png'))
+
+ # set options
+ self.setOption(QtGui.QWizard.IndependentPages, on=False)
+ self.setOption(QtGui.QWizard.NoBackButtonOnStartPage, on=True)
+
+ self.setWindowTitle("First Run Wizard")
+
+ # TODO: set style for MAC / windows ...
+ #self.setWizardStyle()
+
+ def add_pages_from_dict(self, pages_dict):
+ """
+ @param pages_dict: the dictionary with pages, where
+ values are a tuple of InstanceofWizardPage, kwargs.
+ @type pages_dict: dict
+ """
+ for name, page in pages_dict.items():
+ # XXX check for is_previously registered
+ # and skip adding the signup branch if so
+ self.addPage(page())
+ self.pages_dict = pages_dict
+
+ def get_page_index(self, page_name):
+ """
+ returns the index of the given page
+ @param page_name: the name of the desired page
+ @type page_name: str
+ @rparam: index of page in wizard
+ @rtype: int
+ """
+ return self.pages_dict.keys().index(page_name)
+
+ def set_validation_error(self, pagename, error):
+ self.validation_errors[pagename] = error
+
+ def clean_validation_error(self, pagename):
+ vald = self.validation_errors
+ if pagename in vald:
+ del vald[pagename]
+
+ def get_validation_error(self, pagename):
+ return self.validation_errors.get(pagename, None)
+
+ def set_providerconfig(self, providerconfig):
+ self.providerconfig = providerconfig
+
+ def setWindowFlags(self, flags):
+ logger.debug('setting window flags')
+ QtGui.QWizard.setWindowFlags(self, flags)
+
+ def focusOutEvent(self, event):
+ # needed ?
+ self.setFocus(True)
+ self.activateWindow()
+ self.raise_()
+ self.show()
+
+ def accept(self):
+ """
+ final step in the wizard.
+ gather the info, update settings
+ and call the success callback if any has been passed.
+ """
+ super(FirstRunWizard, self).accept()
+
+ # username and password are in different fields
+ # if they were stored in log_in or sign_up pages.
+ from_login = self.from_login
+ unamek_base = 'userName'
+ passwk_base = 'userPassword'
+ unamek = 'login_%s' % unamek_base if from_login else unamek_base
+ passwk = 'login_%s' % passwk_base if from_login else passwk_base
+
+ username = self.field(unamek)
+ password = self.field(passwk)
+ provider = self.field('provider_domain')
+ remember_pass = self.field('rememberPassword')
+
+ logger.debug('chosen provider: %s', provider)
+ logger.debug('username: %s', username)
+ logger.debug('remember password: %s', remember_pass)
+
+ # we are assuming here that we only remember one username
+ # in the form username@provider.domain
+ # We probably could extend this to support some form of
+ # profiles.
+
+ settings = QtCore.QSettings()
+
+ settings.setValue("FirstRunWizardDone", True)
+ settings.setValue("provider_domain", provider)
+ full_username = "%s@%s" % (username, provider)
+
+ settings.setValue("remember_user_and_pass", remember_pass)
+
+ if remember_pass:
+ settings.setValue("eip_username", full_username)
+ seed = self.get_random_str(10)
+ settings.setValue("%s_seed" % provider, seed)
+
+ # XXX #744: comment out for 0.2.0 release
+ # if we need to have a version of python-keyring < 0.9
+ leapkeyring.leap_set_password(
+ full_username, password, seed=seed)
+
+ logger.debug('First Run Wizard Done.')
+ cb = self.success_cb
+ if cb and callable(cb):
+ self.success_cb()
+
+ def get_provider_by_index(self):
+ provider = self.field('provider_index')
+ return self.providers[provider]
+
+ def get_random_str(self, n):
+ from string import (ascii_uppercase, ascii_lowercase, digits)
+ from random import choice
+ return ''.join(choice(
+ ascii_uppercase +
+ ascii_lowercase +
+ digits) for x in range(n))
+
+
+if __name__ == '__main__':
+ # standalone test
+ # it can be (somehow) run against
+ # gui/tests/integration/fake_user_signup.py
+
+ import sys
+ import logging
+ logging.basicConfig()
+ logger = logging.getLogger()
+ logger.setLevel(logging.DEBUG)
+
+ app = QtGui.QApplication(sys.argv)
+ server = sys.argv[1] if len(sys.argv) > 1 else None
+
+ trusted_certs = {
+ "3DF83F316BFA0186"
+ "0A11A5C9C7FC24B9"
+ "18C62B941192CC1A"
+ "49AE62218B2A4B7C": ['springbok']}
+
+ wizard = FirstRunWizard(
+ None, trusted_certs=trusted_certs,
+ debug_server=server)
+ wizard.show()
+ sys.exit(app.exec_())
diff --git a/src/leap/gui/mainwindow_rc.py b/src/leap/gui/mainwindow_rc.py
index e5a671f3..5bee35c7 100644
--- a/src/leap/gui/mainwindow_rc.py
+++ b/src/leap/gui/mainwindow_rc.py
@@ -2,7 +2,7 @@
# Resource object code
#
-# Created: Sun Jul 22 17:08:49 2012
+# Created: Wed Nov 21 04:25:36 2012
# by: The Resource Compiler for PyQt (Qt v4.8.2)
#
# WARNING! All changes made in this file will be lost!
@@ -236,6 +236,87 @@ qt_resource_data = "\
\x71\xa4\x40\xda\x14\x7a\xd1\x73\x1f\xf4\x7f\xb7\xf9\x1f\xc2\x26\
\x56\xd5\x70\x45\xfc\x8a\x00\x00\x00\x00\x49\x45\x4e\x44\xae\x42\
\x60\x82\
+\x00\x00\x04\xec\
+\x89\
+\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\
+\x00\x00\x18\x00\x00\x00\x18\x08\x06\x00\x00\x00\xe0\x77\x3d\xf8\
+\x00\x00\x00\x04\x73\x42\x49\x54\x08\x08\x08\x08\x7c\x08\x64\x88\
+\x00\x00\x00\x09\x70\x48\x59\x73\x00\x00\x06\xec\x00\x00\x06\xec\
+\x01\x1e\x75\x38\x35\x00\x00\x00\x19\x74\x45\x58\x74\x53\x6f\x66\
+\x74\x77\x61\x72\x65\x00\x77\x77\x77\x2e\x69\x6e\x6b\x73\x63\x61\
+\x70\x65\x2e\x6f\x72\x67\x9b\xee\x3c\x1a\x00\x00\x00\x13\x74\x45\
+\x58\x74\x41\x75\x74\x68\x6f\x72\x00\x52\x6f\x64\x6e\x65\x79\x20\
+\x44\x61\x77\x65\x73\x0e\xd8\x7e\x1d\x00\x00\x04\x4a\x49\x44\x41\
+\x54\x48\x89\x8d\x96\x5d\x6c\x53\x65\x18\xc7\x7f\xef\x39\x6b\xbb\
+\x7e\x9c\x75\x65\xad\x2b\x9b\xfb\xd0\x31\xdd\x14\xb6\x8c\x19\x44\
+\x90\x44\x63\x82\x42\x88\x5e\x90\x98\xcc\x19\x15\x13\xd4\x18\x76\
+\x61\xd4\x18\xe3\x85\x57\xca\x05\xe1\xc2\x0c\xa3\xa8\x51\xd0\x4c\
+\x12\xe3\x85\x31\x80\x26\x6a\xe2\x85\x23\xb0\x38\xb6\xc1\x1c\xce\
+\xb1\x40\x59\xf6\xe5\xca\xda\xae\xed\xfa\x75\x7a\x5e\x2f\x4e\xd7\
+\x59\xd6\x32\xfe\xc9\x7b\xf3\x9e\xe7\xf9\xff\x9f\xe7\xff\x9e\xf3\
+\x9c\x57\x48\x29\x59\x0f\xbd\x7b\x85\x0d\x17\xed\x1e\xbb\xb2\x07\
+\x20\x94\x30\x7e\x22\xc6\x48\xcf\x59\x99\x5a\x2f\x57\x94\x12\xf8\
+\xec\x55\x61\x71\x65\x6d\x47\xfc\xbe\xda\x47\x9d\x5a\xa5\xbf\xda\
+\x69\xaf\xda\xe0\x28\x2f\x07\x58\x5c\x4e\x26\xe7\xe3\x89\x9b\xf1\
+\x68\x78\x6e\x6e\x61\xfa\x8f\x98\x9a\x7a\xfb\x95\xe3\x32\x73\xc7\
+\x02\x9f\x76\x89\x8e\xba\xda\xda\x2f\xb7\x37\xdf\xdf\xe6\x2a\x13\
+\x8a\x94\x06\x82\xc2\x38\x89\x40\x08\x85\x98\x2e\x8d\xf3\x13\xe3\
+\x97\xa6\xa6\xa7\x5f\x7e\xed\x94\x1c\x5a\x57\xa0\xef\xa0\xfd\x70\
+\x5b\xf3\x96\x03\xcd\xde\x8a\x6a\x61\x64\xd7\x73\xc0\x14\x53\x54\
+\x26\x82\x4b\xf3\x97\x26\x2e\x7f\xd5\xfd\x79\xe2\xdd\x92\x02\x27\
+\x5f\x2a\x7b\xe1\x89\xce\x1d\xc7\xbc\x76\x55\x13\xc5\x98\xac\x4e\
+\x10\x0a\xa4\xa2\x6b\x45\x80\x60\x22\x1b\xfd\x6d\xf0\xdc\xa1\x17\
+\x4f\xe8\x5f\xaf\x11\x38\xfa\x9c\xf0\x6e\xdb\xf4\xc0\xf9\x6d\xf5\
+\xfe\x26\x30\xf2\x89\xca\xc6\x76\xd4\x07\xf7\xa3\xd4\x74\x80\xd5\
+\x65\x6e\xa6\xe3\x64\x03\xfd\x64\x2f\x9e\x40\x46\x67\xff\x27\xa3\
+\x30\x70\x63\x6e\x72\xe0\xea\xd8\xf6\x37\xbf\x95\x41\x73\x27\x87\
+\x06\x8f\xa7\x6f\x6b\x7d\x4d\x01\x39\x80\x52\xff\x08\x4a\xe3\xae\
+\x55\xf2\x5c\x27\x6a\xf3\x6e\x2c\x7b\x8f\x9a\x5d\xe5\x61\xb0\xb5\
+\xbe\xa6\xa9\xc1\xe3\xe9\x5b\x95\x04\x7a\xbb\x44\x47\x5b\x53\xcb\
+\x4e\x15\xbd\x98\x31\xc8\x70\x00\xfd\xfc\xc7\x64\xce\xbc\x81\x7e\
+\xe1\x13\xc8\x75\x2d\xb4\x8d\x28\xb5\x0f\x15\xc4\xaa\xe8\xb4\x35\
+\xb5\xec\xec\xed\x12\x1d\x00\x65\x00\xee\x72\x65\x9f\x5f\x73\x38\
+\x05\x6b\x0f\x35\x3b\xf6\x03\xfa\xc0\xf1\x3c\x29\xb3\xc3\xa8\xf7\
+\x3e\x8e\xf0\xb5\x98\x22\xf6\x0d\x05\xf1\x02\xf0\x6b\x0e\xa7\xbb\
+\x5c\xd9\x07\x0c\x29\x00\x9a\xc3\xd5\x69\x55\xd5\xe2\xd5\x47\xe7\
+\x56\xc9\x01\xe1\xbe\x1b\xe1\xb9\x67\xf5\x79\x70\x7c\x4d\x8e\x55\
+\x55\xd1\x1c\xae\xce\xbc\x45\x15\x6e\x5f\x9d\x90\xc5\xed\x29\xa8\
+\xae\xa2\x06\xcb\x53\x47\xa0\xcc\x66\x76\x37\xfa\x3d\xd9\xa9\x81\
+\xb5\x71\x52\xa7\xc2\xed\xab\x83\x9c\x45\x76\xbb\x56\x25\xa5\xa4\
+\xe8\xab\xb9\x02\x9b\x86\x65\xf7\x87\x08\xcd\x6f\x92\x8f\x9f\x21\
+\xf5\xdd\xf3\xa0\xa7\x10\xe5\x6e\x44\x45\x2d\x38\x7d\x08\x21\x90\
+\xd2\xe4\xcc\x0b\x24\x12\xd1\x9b\x42\xbd\xab\x81\x6c\xba\x28\xb7\
+\x94\x06\x65\xcd\x4f\x22\x2a\x1b\x00\x30\xa6\xff\x24\xd5\xb7\x1f\
+\x74\x73\x14\xc9\x64\x04\x99\x8c\x80\xc5\x8e\xe2\xae\x03\xab\x93\
+\x44\x22\x7a\x33\x6f\xd1\x52\x64\x61\x0a\xb5\xbc\x28\xb1\xb1\x34\
+\x83\x91\xb3\xc1\x98\x1d\xc1\x98\x1d\x41\x3f\xd7\x9b\x27\x2f\x40\
+\x26\x81\x11\xfc\x07\x99\x8a\x99\x9c\x2b\x1d\x44\x97\x63\x83\xc9\
+\xe8\xfc\x33\x36\x23\x05\xaa\x05\xd2\xcb\xc8\x74\xcc\xfc\x88\x72\
+\x5d\xa5\x7f\x3c\x74\x3b\x03\x0b\x90\x52\xed\x44\x97\x63\x83\x79\
+\x81\x48\xd2\x38\x3d\x1b\xcf\xbc\x53\x1f\xb9\xe4\x44\x1a\x45\x93\
+\xac\xcf\x7e\x83\xda\xb8\xcb\x2c\xf4\xd7\xf7\xd1\x2f\x9e\x2c\xce\
+\x2e\x14\xe6\xd2\x65\xf1\x48\xd2\x38\x0d\x39\x8b\x7a\x4e\xc9\xa1\
+\xd1\xc0\xb5\xfe\xac\xb7\xb5\x64\x55\xc2\xe5\x47\x54\x36\x98\xe7\
+\x60\xd3\x4a\xc6\x65\xbd\xad\x8c\x06\xae\xf5\xf7\xe4\x26\x6b\x7e\
+\x54\x04\x42\xa1\xee\xe1\x90\x31\x29\x1c\xde\xd2\xbd\xaf\x03\xe1\
+\xf0\x32\x1c\x32\x26\x03\xa1\x50\x77\x7e\xef\xd6\x69\xfa\x58\x7b\
+\xe7\x31\x5f\x78\x54\x23\xb3\x5c\x90\xac\xf8\xdb\x10\x0e\xf3\xab\
+\x35\x82\x13\xc8\xa5\xe9\x42\x76\x8b\x83\x85\xca\xcd\xd1\xdf\x47\
+\x06\x8b\x4f\xd3\x15\xf4\x1d\xb4\x1f\xde\xd2\xd4\x7a\x60\x93\x1a\
+\xaa\x26\x74\xfd\xce\x4a\xf7\x34\x72\x35\xeb\x99\xbf\x3c\x79\xe5\
+\xf6\xff\x83\x15\x7c\xf0\xb4\xd8\xbe\xb9\xa9\xe6\x8b\x1d\x0d\xd5\
+\xad\xae\xd8\x94\x22\x13\x21\x90\xb7\xcc\x29\xa1\x22\xec\x1e\x62\
+\xae\x3a\xa3\xff\xfa\xfc\xdf\xe7\xc6\x66\x5e\x3f\xf2\x0b\xfd\x52\
+\x16\x8e\x84\x02\x01\x21\x84\x0a\x54\x01\x95\x9a\x1d\xdf\x7b\x7b\
+\xac\x6f\xdd\x57\xb7\xb1\x6d\x83\xbb\xd2\x53\xe3\x10\x2e\x9f\xcd\
+\xb0\x00\xfc\x9b\x54\xf4\x99\x84\x8c\x2d\x86\xc3\xe1\x2b\x81\xd9\
+\xbf\x0e\xff\x9c\xfe\x28\x9e\x22\x08\x84\x80\xb0\x94\x32\x5c\xb2\
+\x03\x21\x84\x13\xf0\x00\xee\xdc\xd2\x5c\x56\x3c\x5b\xeb\x69\x79\
+\xb8\x51\x74\x18\x12\xe5\xc2\x75\x39\x3c\x74\x83\xc9\x78\x86\x10\
+\x10\x03\x96\x80\x48\x6e\x2d\x4a\xb9\x7a\x01\x28\x79\xab\xc8\x89\
+\x59\x00\x2b\x60\xcb\x2d\x0b\xa0\x02\x3a\x90\x02\xd2\x40\x12\xc8\
+\x48\x79\xab\x87\x26\xfe\x03\x26\x93\xd5\x41\x51\x76\x98\xdb\x00\
+\x00\x00\x00\x49\x45\x4e\x44\xae\x42\x60\x82\
\x00\x00\x0b\xd7\
\x89\
\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\
@@ -744,6 +825,927 @@ qt_resource_data = "\
\x8f\xf3\x2f\x02\x93\x69\x3a\xed\x1c\xe8\xee\xee\x4e\xd2\xa7\x46\
\xff\xff\x67\x8f\x8f\x7b\xf9\x5f\x5a\xf1\x31\x65\xff\xe0\x15\x90\
\x00\x00\x00\x00\x49\x45\x4e\x44\xae\x42\x60\x82\
+\x00\x00\x2e\x85\
+\x89\
+\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\
+\x00\x00\x80\x00\x00\x00\x65\x08\x06\x00\x00\x00\x85\xb7\xeb\xfa\
+\x00\x00\x00\x04\x73\x42\x49\x54\x08\x08\x08\x08\x7c\x08\x64\x88\
+\x00\x00\x00\x09\x70\x48\x59\x73\x00\x00\x02\x4b\x00\x00\x02\x4b\
+\x01\x08\x6c\xbf\x82\x00\x00\x00\x19\x74\x45\x58\x74\x53\x6f\x66\
+\x74\x77\x61\x72\x65\x00\x77\x77\x77\x2e\x69\x6e\x6b\x73\x63\x61\
+\x70\x65\x2e\x6f\x72\x67\x9b\xee\x3c\x1a\x00\x00\x20\x00\x49\x44\
+\x41\x54\x78\x9c\xed\x9d\x77\x7c\x5c\xc5\xb9\xf7\xbf\x73\xb6\xaf\
+\x56\xbb\xd2\xaa\xf7\x2e\xcb\xb2\x6c\xcb\xdd\x32\x6e\x80\x4d\xc0\
+\x38\x40\xe8\x10\x7a\x0b\x24\xe4\xa6\x70\x79\x93\x10\x20\x84\xf4\
+\x0a\xc9\x0d\x29\x70\x81\x84\x40\x6e\x08\x3d\xc4\x74\x03\xae\x72\
+\xef\x2a\x96\x64\x15\xab\x97\x55\xd9\xae\x2d\x67\xde\x3f\x56\x92\
+\x6d\x59\xb6\x31\x2e\x58\xc0\xef\xf3\x39\x20\xcf\x99\x3e\xcf\xce\
+\x99\x79\xaa\x90\x52\xf2\x39\x0e\x40\x08\xa1\x03\x72\x81\x10\xf0\
+\x20\x30\x05\xc8\x06\x1e\x01\x9e\x90\x52\xb6\x7f\x72\xbd\x3b\xf9\
+\x10\x9f\x13\x00\x08\x21\x04\x30\x1d\xb8\x18\xb8\x09\xc8\x00\x82\
+\x80\x6e\x54\x56\x2f\xf0\x3f\x80\x16\x98\x0f\x94\xc9\x71\x3e\x81\
+\x9f\x69\x02\x10\x42\x24\x01\xff\x00\xa6\x01\xb1\x00\xf1\x68\x99\
+\x43\x14\xd3\x45\x14\x6f\xe6\x4f\xc5\xe9\x75\xe0\xf2\x39\x70\x7a\
+\x1d\x78\x07\x07\x18\x35\x5f\x39\x52\xca\xc6\xd3\xdf\xf3\x93\x07\
+\xed\x27\xdd\x81\xd3\x09\x21\xc4\x32\xc0\x0c\xec\x06\xfc\xc0\x7f\
+\x80\x92\xf3\xb1\x72\x03\x71\xcc\x21\x8a\x5c\x0c\x91\xcc\x12\xdc\
+\xd1\x79\xb8\xf3\xce\x1b\x29\x1f\x56\xc3\xd2\xed\xeb\x15\x1b\xaa\
+\x5f\x65\x7f\x77\x25\xc0\x54\xa0\xf1\xf4\x8e\xe2\xe4\x42\xf9\xa4\
+\x3b\x70\x9a\x31\x07\x78\x01\xa8\x06\x1a\x8b\x31\x96\xbc\x49\x01\
+\x6f\x52\xc0\x35\xd8\x0f\x2c\xfe\x10\xca\x7a\x3a\x77\x1f\xfc\x6f\
+\x8d\xa2\x11\xb6\xa8\x04\x96\x4c\xbb\x99\xb4\xb8\x42\x80\xc7\x85\
+\x10\xff\x75\xba\x3a\x7f\x2a\xf0\x99\x20\x00\x21\x84\x46\x08\xf1\
+\x04\xf0\xff\xec\x68\xb9\x06\x3b\xcf\x91\xc3\x4e\x8a\x39\x1f\xeb\
+\x11\xcb\x2d\xec\xed\x71\x8d\x95\xae\x51\xb4\x9c\x37\xfd\x56\x92\
+\x63\x73\x13\x81\x5f\x09\x21\x96\x9c\xa2\xae\x9f\x72\x7c\x26\x08\
+\x40\x4a\x19\x06\xd6\x00\xc6\x07\x16\x69\x79\x52\x97\xc1\xb5\xd8\
+\xd1\x22\x8e\x5a\x2e\xc1\xe3\xce\x3d\xd2\x3b\xad\x46\xcf\x17\x66\
+\xdc\x4e\x82\x2d\x53\x0f\xbc\x2e\x84\xf8\xab\x10\x62\xf9\xc9\xed\
+\xf9\xa9\xc7\x67\x82\x00\x86\xf0\x06\xe0\x34\x67\xf8\xd9\x73\x67\
+\x85\x44\xa7\xfa\x8f\x55\x40\x91\x6a\x72\x92\xc7\xd9\x7c\xa4\xf7\
+\x7a\xad\x91\x0b\x66\x7e\x05\x7b\x74\xaa\x11\xb8\x11\x78\x4a\x08\
+\x11\x73\x12\xfb\x3c\x26\x84\x10\x93\x84\x10\xdf\x15\x42\x4c\x3d\
+\xd1\xba\x3e\x33\x04\x20\xa5\xec\x01\x7a\xfa\xbc\x30\x75\x62\x48\
+\x04\x1e\xde\x53\x81\x56\x0e\x1e\xab\xdc\xac\x8e\xc6\xa6\xa3\xbd\
+\x37\xe8\xcc\x2c\x9b\x75\x17\x16\x53\x2c\x40\x02\xb0\xe8\xe4\xf4\
+\xf8\x50\x08\x21\x66\x0b\x21\xbe\x2f\x84\x58\x07\xec\xd1\xea\x0d\
+\x3f\x03\x76\x08\x21\xde\x14\x42\xcc\xfd\xb8\xf5\x7e\x66\x08\x60\
+\x08\x9b\x7a\x3d\x43\x7f\xd9\x82\x33\x82\x0f\xef\xde\x85\x56\x06\
+\x8e\x56\xa0\xb4\xa3\xe5\x98\x95\x86\xa5\x0e\xbd\x69\x0a\x8a\x62\
+\x00\xb8\x5a\x08\x91\x75\xe2\x5d\x8d\x40\x08\xa1\x08\x21\x7e\x06\
+\x6c\x44\xa3\xfd\x31\xb3\xaf\x9d\x97\x76\xe5\x0f\xf9\xde\xab\x2d\
+\x7c\xf5\xb7\xaf\x91\x98\x91\x7f\x3e\x50\x2e\x84\x78\x45\x08\x91\
+\x73\xbc\xf5\x7f\x66\x08\x40\x08\x61\x04\x96\xf5\x79\x0f\xa4\xc9\
+\x98\xe0\xac\xe0\x43\x7b\x76\xa0\x91\xc1\x23\x95\x4b\x73\xf7\x67\
+\x1e\xab\xee\xea\xbe\x68\x77\xce\x84\xdf\x90\x90\x74\x19\xc0\xd5\
+\xc0\x57\x4f\xbc\xc7\x20\x84\x88\x05\x56\x24\x19\x92\xbf\xfb\xe3\
+\x92\xff\x41\x77\xe7\xcb\xb0\xf4\xdb\x6a\xeb\xb9\x0f\xf2\xa3\xe6\
+\x78\x9e\x10\xf3\x43\x9a\x3b\x5f\x24\xf1\x86\xdf\x23\xd2\xa7\x5c\
+\x02\x54\x09\x21\xae\x3c\x9e\x36\x3e\x33\x04\x00\x7c\x01\xb0\x8e\
+\xec\x00\x43\x90\xf6\xc0\xec\xe0\x43\x7b\xb6\xa1\xc8\xd0\x58\x85\
+\x34\x52\xcd\xb4\xfb\x3d\x9d\x47\xaa\x74\x50\x9a\x1a\x15\xd3\x2c\
+\x13\x80\x46\x17\x83\xd6\x6c\x03\xf0\x9d\x68\x67\x85\x10\xc5\xc0\
+\x96\x7c\x73\xc1\xf9\x6f\xcf\x59\xcd\xa3\xf3\xb2\x43\xc1\xe4\x02\
+\x30\x59\xc2\xa8\xa1\x9d\x00\x41\xb3\x5d\xdb\x1e\x3f\x95\xae\xac\
+\x85\xc8\x70\x10\xa0\x1d\x78\xf7\x78\xda\xf9\x2c\x11\x80\x01\x60\
+\x34\x01\x00\xc8\xb8\xc0\x9c\xe0\x0f\x2a\xb7\x1c\x89\x08\xa6\x77\
+\xec\xaf\x3f\x52\xa5\xfb\xdc\x25\xad\xa0\x68\x00\xa2\xd2\x0b\x48\
+\x5d\x7c\x1d\xc0\x31\xcf\x16\x47\xc3\x10\xc3\xaa\x7c\xa6\x6d\x56\
+\xee\x87\xb3\xd7\x71\x45\xee\x5e\x6f\x4f\x46\x7e\x84\x69\x27\xa5\
+\x2e\xa3\xbd\x59\xbf\xbc\x7c\x93\x5a\xf8\xd2\x2f\xe0\xc7\xd3\xe1\
+\xa7\xb3\xa0\xbd\xca\x0b\x5c\x2a\xa5\xec\x3b\x9e\xb6\x3e\x31\x02\
+\x10\x42\x24\x0a\x21\xfe\x2d\x84\xf8\xa7\x10\xe2\x3b\x42\x88\x23\
+\x5e\xb9\x4e\x12\x06\x00\x0e\xfe\x04\x1c\x0c\x99\xe0\x9f\x1b\x7c\
+\xb0\x72\x33\x8a\x0c\x8f\x7e\x37\xbd\xa3\x79\xcc\x4f\x84\x3b\x6c\
+\xab\xf0\xab\xc6\x79\x00\x61\x5d\x6f\x63\xf2\xf5\x17\xa2\xb7\x25\
+\xc0\xc7\x20\x00\x21\x84\x41\x08\x71\xb1\x10\xe2\xdf\xc0\xeb\x5f\
+\x88\x5f\x66\x7d\x7b\xe6\x07\x5c\x9f\xb6\xb3\xa7\x36\x2f\xdf\xac\
+\x97\xc2\xb3\xa8\x4f\xbf\xf3\xeb\x4d\x51\x7d\xb7\x0d\xe4\x4c\x9c\
+\x65\x9b\xad\x74\x96\x3f\x06\xcd\xdb\x9f\x26\x1c\xbc\x1b\x98\x28\
+\xa5\xdc\x7e\xbc\xed\x7e\x22\xac\xe0\xa1\xed\x6d\x05\x3a\x63\x76\
+\xcc\xbc\x8b\x18\xd8\xf8\xc6\x55\xd2\xef\x7e\x48\x08\xf1\x6b\xe0\
+\x11\x29\x65\xef\x29\x68\xd6\x09\x63\xef\x00\xc3\x90\x89\xfe\xb2\
+\xe0\x03\x95\xeb\x75\x3f\x2a\x9e\x83\x2a\x34\xc3\xe9\x99\x4e\x47\
+\xea\x58\xf9\xf7\x79\x8a\x02\x80\x40\x2f\x76\x47\x5f\x91\x55\x2c\
+\x34\x5a\x74\xd6\x78\x80\xe2\x8f\xd2\x21\x21\x84\x1e\x38\x0f\xb8\
+\x12\xb8\xd8\x6e\x4a\xb2\x9e\x9b\x7b\x15\x4b\x73\xaf\xc5\xee\xf0\
+\x70\x5f\x42\x65\x73\x7d\x5e\x81\xfb\xd6\xde\xd8\xda\x94\xa0\x61\
+\xa6\x4e\x63\x98\xda\x13\xa8\x62\xe5\xae\xfb\xd9\x53\xf3\x3c\x03\
+\xae\xe6\x3e\xe0\xab\x52\xca\x63\x5e\x69\x8f\x84\xd3\x4e\x00\x43\
+\x57\x96\x37\x48\xc8\x8b\x15\x37\x3f\xc3\xec\x8b\xa7\xd2\xb1\x6b\
+\x3b\x55\xcf\xff\xd5\x18\xdc\xf0\xaf\xfb\xf1\xbb\xee\x11\x42\x3c\
+\x0b\x3c\x26\xa5\xdc\x79\x12\x9b\x76\x02\xf4\x1f\x61\x07\x18\x86\
+\x4c\xf2\xcf\x0b\x7e\xbf\x72\x9d\xee\xc7\x93\xca\x90\x91\x1d\x52\
+\xaf\x86\xf3\x2c\x81\x40\x9f\x5b\xaf\x8f\x1d\xce\xd7\x1b\x4c\xdc\
+\x1c\x96\xfa\x59\x68\x45\x9d\xf9\xf2\x98\x4c\xa1\x8f\x10\x8c\x3e\
+\x42\x00\xd7\x0b\x21\xbe\x23\xa5\x74\x8c\xae\x7f\x48\xdc\xbc\x84\
+\xc8\xa2\x5f\x62\xd0\x9a\x63\xe6\x67\x2e\x67\x69\xee\xb5\xcc\x48\
+\x3d\x07\x77\xb0\x1f\x67\xb0\xc6\xdd\x97\xe3\xdb\x65\x2b\x9a\x20\
+\x6e\xf7\x86\x0b\x31\xf6\x15\x69\xa2\x55\xb1\x76\xd3\x63\xbc\xf9\
+\xce\x23\xa8\xaa\xda\x49\x84\xaf\xf1\x9b\x13\x59\x7c\x38\xcd\x04\
+\x30\xc4\x32\x7d\x95\x69\x5f\x8a\xe2\xc6\xa7\xb1\xf8\xf6\x3b\x0c\
+\x96\xa8\xb8\xac\x79\xf3\x49\x98\x58\x2c\xcb\x5f\x38\x5b\x78\xab\
+\x56\x9a\x68\x7a\xee\x76\x1a\x02\xb7\x0b\x21\x3a\x89\x70\xf0\x86\
+\x9f\x9d\x52\x4a\xf5\x63\x36\xef\x04\x70\xfa\x41\x95\xa8\x8a\x38\
+\xf2\xe7\x4f\xa6\xf8\xcf\x0a\xde\x57\xb9\x56\xf7\xd3\xe2\x79\x43\
+\x44\x20\x4a\xbb\x9a\x6b\xd7\xa6\xe7\xcd\x06\x90\x52\xa8\xfb\xbd\
+\x05\x56\x04\x6d\xa6\x2f\xd9\xa2\x85\x41\xd8\x86\xcb\xea\xa3\xe3\
+\x01\xf4\x40\x14\xe0\x18\x1a\xb7\x06\x58\x4c\xe4\x86\x70\xa9\x40\
+\xd8\xa7\x26\xcf\xa7\x2c\xa7\x8c\x82\x84\x5c\xb4\xa6\x06\xec\x09\
+\x4f\xf8\xf4\x59\x5f\x0b\x25\xe8\x5b\xdb\xda\x5a\xca\x1a\x76\x0d\
+\xdc\xf8\x85\x50\xe3\xde\xc1\x9c\x99\x0b\x8c\x9b\x37\xbf\xcd\x8b\
+\x2f\x3e\x8a\xd3\xd9\xdb\x00\x7c\x4d\x4a\xf9\xe6\xc7\x9c\x83\xc3\
+\x70\x5a\xc4\xc1\x42\x08\x0b\xb0\x06\x45\x5b\xca\x65\xbf\x80\x25\
+\xdf\x06\xa0\xac\x98\xca\x78\xeb\xa1\xdb\xe5\xae\x37\xde\xf3\x36\
+\xbb\x3e\x30\xab\xf6\xb7\xa1\x66\x17\x54\x07\xa1\x0a\xe8\x86\xa1\
+\xff\xbe\x49\x44\x8a\xf7\xb6\x94\xd2\x39\x46\x5b\xd1\x80\x2a\xa5\
+\xf4\x8c\x4a\x7f\x08\xf8\x01\x40\xc7\xaf\x09\xc5\x9a\x8f\x4d\xfc\
+\xa2\x39\x6a\x8d\xee\xe7\x45\xf3\x91\x88\xaa\xb8\x94\x55\x8f\xcc\
+\x3a\x77\x11\x40\xe7\x60\xfa\xda\xd6\xc1\x9c\x62\xf3\xc5\xd6\x01\
+\x11\xab\x3d\xe4\xee\x1d\xf2\xb9\x58\xff\x8d\xa9\x3e\x20\x8e\x88\
+\x8e\xc1\xd5\xc0\x15\x40\x52\x3e\x30\xc7\x26\x58\x72\x4b\x3a\x39\
+\x19\xfd\x58\xa2\x0e\x11\x35\xb4\x35\xb7\x88\x7d\x9b\x5a\xbf\x18\
+\xe5\xd0\x5e\x38\x1d\xa0\xa7\xa7\x99\xcd\x9b\x5f\xa4\xb9\xb9\x3a\
+\x08\xfc\x06\x78\x58\x4a\x79\xc2\x37\x8c\x83\x71\xba\x76\x80\x07\
+\xb0\x26\x95\x72\xe7\xcb\x90\x37\x2f\xd2\xb0\x42\xcd\xe8\xc5\x07\
+\x98\xb2\x6c\x89\x39\xcb\x31\xbb\x6b\xb7\x23\x36\xd8\xb7\x68\x7a\
+\x12\xeb\xb6\x68\x31\xec\x86\xde\x10\x54\x91\x40\x25\x37\x50\xcd\
+\x0d\x78\x09\x0a\x21\xd6\x00\xcf\x10\xf9\xb5\x95\x01\x33\x81\x09\
+\x80\x14\x42\xd4\x00\xdb\x80\x5a\x22\xda\x3d\x0f\x0c\xb7\xd1\xed\
+\x42\x1b\x6b\x3e\x76\xa7\x65\x86\x67\x41\xf0\xde\xea\x35\xba\x5f\
+\x15\xcd\xcf\x19\xe8\x4e\x00\x50\x51\x06\xdb\x06\xb3\x52\x74\x8b\
+\xf0\x1c\xbc\xf8\x32\x14\xa4\xaf\x72\x0d\x5d\x9b\x5f\x07\xe8\x05\
+\x36\x01\x25\xd9\xc0\xd9\x43\x4f\x3a\x50\x87\x1c\x9c\x5c\xd4\x7c\
+\xb0\xd8\xd1\xdb\xdf\xc7\xa6\xba\x06\xa5\x60\x87\xfb\xd6\x89\x6a\
+\xd4\xcc\x78\xaf\xd7\xc9\xe6\xcd\xaf\x52\x53\xb3\x3e\x24\xa5\x7c\
+\x96\xc8\x56\xbf\xe7\xa3\x4d\xf5\xf1\xe1\x94\xef\x00\x42\x88\x89\
+\xc0\x4e\xbe\xf2\xb2\x8e\xe8\xac\x30\xdd\xbd\xb5\xf4\xf8\xba\x0c\
+\x83\x4e\xcd\x03\xbf\x9f\x2b\x73\x26\x58\x52\x14\xa4\x0e\x21\xb4\
+\x02\xa9\x97\xa0\x1f\x0c\x4b\xff\x53\x15\xad\x4d\xbb\xf4\x1f\x68\
+\xfa\x35\xfb\xa7\x33\x18\xec\x67\xfb\x56\x37\xbe\xed\x69\x88\xb0\
+\x40\x02\x4d\x44\x76\x86\x4a\xa0\x9e\xc8\x12\x03\x5a\x83\x82\x35\
+\x31\x16\xd5\x15\x8d\x46\xa3\x3f\xec\xb1\x5a\x2d\xdc\x7a\x61\x06\
+\x57\xc5\x3d\xae\xda\xa3\xe4\x47\xba\x05\x89\x7a\xcb\x6a\xdd\xaf\
+\x27\x94\x7d\x7d\xe9\x55\x81\xda\xc0\xc4\x4d\x7d\xd3\x12\x33\x8c\
+\xc5\x09\xf9\x52\x0d\xd3\x5f\x5d\x4e\xf7\xe6\xd7\xe9\xd9\xfe\x0e\
+\x21\xef\x00\x10\x51\x27\x5a\x4c\x64\xd1\xb3\x47\xd5\xa5\xc9\x66\
+\x4d\xd4\xf7\x58\x00\xe0\xf3\xb3\xbe\xae\x8e\x9c\x5e\x97\x5e\x54\
+\x84\xbe\x9d\x28\x0c\xe9\xca\xae\x5d\xef\xb1\x63\xc7\x9b\x04\x83\
+\x83\x1f\x02\x77\x4b\x29\x2b\x4e\x68\x01\x8e\x35\xb6\xd3\x40\x00\
+\x7f\x42\x28\x77\x92\x7f\x17\x58\xa7\x8c\x7e\xe9\xb9\xe2\x8e\x09\
+\x3b\x2e\xbb\xad\xf0\xac\xe1\xa4\xae\x3e\x57\xdf\x33\xbb\x5b\xc2\
+\xd2\x62\x89\x97\x52\xf5\xd6\x98\xdf\xde\xef\xd6\x76\x17\x01\x20\
+\x65\x1b\x75\x5b\x5b\x68\xdd\x34\x03\x54\x45\xe7\x56\x84\x31\xac\
+\xc5\xa6\x18\xc9\x8b\x8e\x25\x23\xdb\x16\x32\xdb\x75\x5a\x55\x4a\
+\x59\xf3\x58\x4a\xd8\xb3\x27\x6d\xd4\x0e\x17\x72\x5d\x7e\xf9\x32\
+\xc5\x62\x89\x8a\x12\xde\x7d\xc1\x0b\xe4\xbd\xee\x74\xab\x2f\x96\
+\x8f\x00\xa5\xce\xb2\xfa\x89\x95\x77\x2a\xff\x38\xe7\x7a\xe9\xcb\
+\xaa\x2d\xf5\xb2\x23\xba\xe5\xa9\x97\x70\xed\xaa\x02\x20\x85\x03\
+\xbf\xf4\xbc\xa3\xd4\xa3\x2b\xe5\x43\xc3\x1d\x24\xee\xdb\x87\xea\
+\x74\x51\xd2\xd0\x6b\xf3\x74\x47\xdf\x1f\xd5\xb8\xbf\x86\x8d\x1b\
+\x5f\xc6\xed\x76\xb4\x02\xf7\x4a\x29\xff\xef\xa3\xf4\xeb\x44\x71\
+\x3a\x08\x20\x16\x58\x83\xa2\x9f\x44\xfe\xd7\x21\xba\xf0\xb0\x3c\
+\xf1\xb6\x1e\xf7\x0f\xfe\x7a\x89\xa5\xa9\xbd\x8d\x77\x5d\x12\x6d\
+\x4c\x0c\x61\x97\x0b\xb5\xb7\x0b\xe9\x6f\x75\xf7\xe4\x6c\x6e\x8d\
+\xb1\xaa\x5e\x7b\xd8\x1b\x65\x93\xfe\xf8\x98\xc0\x80\x3d\x4a\x51\
+\x11\x47\x17\xe7\x06\xeb\xff\x37\xbb\xb9\x6f\x73\x5c\x2e\x80\x94\
+\x78\x97\x2f\x5f\xdc\x9d\x90\x10\x3b\xc2\xa7\xd7\x32\x18\xbc\x44\
+\xf9\xd6\x36\xbb\x68\x98\x33\x56\x05\x12\x42\xbe\x20\xad\xae\x00\
+\xdd\x2e\x3f\xde\xbf\x27\x5f\xd2\xfe\x5e\xe9\xd2\x45\x52\x90\x0c\
+\xe0\xd9\x55\x4f\xd6\xa6\xb7\xb8\xee\x5d\x17\x36\x6f\x84\x55\xa0\
+\x40\x20\x56\x88\x8e\x44\x8d\xa6\x43\x17\xd1\x2b\x94\xaa\xaa\x82\
+\x56\x84\xf5\x59\x7a\x77\x6b\x8e\x53\xdd\x61\xec\x59\x8e\x44\x6c\
+\x6f\xcf\xa1\x43\x73\x25\xe5\xe5\x2f\xd2\xd9\xb9\xcf\x4f\xe4\x3b\
+\xff\xb3\xd1\xe7\x97\x53\x89\xd3\x75\x08\xcc\x00\x56\xa2\xe8\x0a\
+\xd0\xc5\x82\xa2\x8b\x3c\x42\x07\x8a\x16\x84\x0e\xab\xdd\x80\x26\
+\x49\xc1\xd7\xd9\x49\xa0\xa3\x13\x8d\xea\x63\xe9\xcd\x82\x25\x37\
+\xe9\x90\x3a\x2d\xbe\x41\xbd\xcb\x10\x6f\xac\x0b\xea\x8c\x4e\x5f\
+\x50\x87\x3f\xa0\xd7\xfa\xc3\x1a\x53\x28\xac\xb1\xa8\x52\xc4\x48\
+\x29\xec\x1c\x7e\xa6\x09\x34\x3e\x95\x5d\xe3\xd8\x18\x37\x71\xe1\
+\xc2\x99\xbb\x73\x73\x33\x4a\x47\xbd\x97\x52\xd0\x3b\xd7\xf8\xfb\
+\x1d\x49\xbd\x6f\x46\x0f\xf8\xf1\xb9\x03\x08\x5f\x88\xa8\x50\x98\
+\x04\x55\x92\x3a\x5c\x67\x48\x23\xaa\x57\xcc\x89\xcf\xde\x9a\x79\
+\x7f\xb3\x8a\xb6\x40\x4a\xa9\xde\xbf\xe7\x81\x40\x5a\xc8\xa9\x9d\
+\xfd\xbf\x33\x35\x30\x06\x35\x6a\xf0\x91\x2a\x76\x91\x4d\x08\x8b\
+\x98\x8c\xc0\xfa\x56\x6b\x73\x43\xbd\xda\x92\xb3\xbe\x79\x1a\x9b\
+\xea\x0d\xd4\xd6\x6e\x04\xe4\x2b\xc0\x3d\x52\xca\x86\x93\x3e\xf9\
+\xc7\xc0\x69\x53\x0a\x15\x42\x14\x01\xbb\x88\xf0\xc9\x7f\x03\xd8\
+\x89\x68\xd6\x4e\x05\x7a\x80\xf5\xc0\x3a\x22\xea\x5a\xbd\x40\x0e\
+\x70\xb9\xce\xc0\x97\x2e\xbc\x1d\x71\xf5\x77\x21\x3e\xed\xc8\xf5\
+\x4b\x90\x83\x21\x6d\x9f\x3f\xa8\xeb\x73\x0f\x1a\x5c\xde\x80\xde\
+\xeb\xf3\x1b\x3d\x21\x4f\x74\xa8\xe2\x9d\xdb\xfd\xb6\xe2\xd2\x58\
+\xa7\x55\xab\x19\xb0\x69\x8c\x4e\xab\xce\xe2\xb4\x6a\xec\x1e\xb3\
+\x12\x27\x15\xb4\xdf\xee\xa9\x5e\x37\xf3\xf9\x87\x82\x9d\x8d\x55\
+\x8b\x8f\x50\x7d\x68\xcd\x14\x6b\x5d\xaf\x55\x57\x34\x60\xcc\xdb\
+\x53\x13\x7f\xe3\x24\x4b\xd8\xbd\xeb\xb7\xfb\xbe\x33\xd5\xd6\x18\
+\xb3\x3b\xef\xb5\x09\x93\x47\x72\x6a\xf1\x90\xc1\x2e\x32\x15\x88\
+\x62\x0a\x91\x03\xea\x08\xfe\xbc\xa7\x91\x67\xaa\xf5\x6c\xae\x6a\
+\x27\x14\x0a\x54\x00\xdf\x94\x52\xbe\xf7\x71\xe6\xf4\x64\xe0\xb4\
+\x6a\x05\x0b\x21\x2e\x05\xd6\x4a\x29\xbb\x0e\x4a\x33\x1e\x8d\x99\
+\x21\x84\x98\x06\xfc\x5a\x67\xe0\x9c\x0b\x6f\x87\xb1\x08\xc1\x08\
+\x9d\xd1\xd0\x66\x07\xa7\x15\x14\x13\xd8\x74\x90\xd9\xdc\x7c\xf6\
+\xfe\x77\xde\xfd\x83\xf1\xd5\x6b\x0d\x2d\x75\x46\xcf\x39\x63\xd5\
+\x3f\xd9\xdf\xbf\xe6\x5f\xfb\xd7\x2f\x90\x52\xaa\xab\xfe\xef\xe7\
+\xdb\xfc\x1e\xe7\xcc\xd1\x79\x3a\x6c\xba\x55\x1b\x27\x5b\x47\xe4\
+\xfc\x95\x89\x77\xac\x29\xf5\x35\xab\x37\x76\x3e\xbb\x28\x63\x55\
+\xf6\xea\x84\xca\xe4\xa9\x64\xb2\x87\x0c\x45\x83\x49\xa6\x23\x85\
+\x7b\x30\xac\x0e\xf6\xfb\xa4\xae\xd7\xab\xda\x5a\x9d\xc1\xe4\x3d\
+\x5d\x1e\xcd\xa6\xb6\x01\xde\x6d\xe8\xa3\xc7\x13\xec\x27\x72\x25\
+\xfd\xa3\x94\x63\xcb\x1f\x4e\x17\xc6\x8d\x5a\xb8\x10\xe2\x8b\xc0\
+\xcf\x75\x06\x8a\x97\xdf\x06\xb7\x5c\x06\x69\xf1\x90\x9f\x4d\x38\
+\x3a\x1a\xcd\xc1\x79\x03\x01\xdb\xc0\x8a\x15\x7f\xdb\xd9\xdd\x3d\
+\x79\x81\x6a\xd3\x6e\xdf\x70\x95\x5e\xdd\x48\xd7\x61\x0b\x6b\x92\
+\xe1\xea\xf2\xba\xf7\xb2\x0d\x32\x6c\x04\x08\xfa\x7d\x03\xef\x3f\
+\xf7\x93\x7e\xa9\x86\x47\xce\x09\xaa\x42\xd3\x8a\xb9\xf6\x44\x55\
+\x11\xa6\xe1\x34\x5d\x94\x69\xdf\x17\x06\x82\xc1\x94\x40\x67\x51\
+\xe1\xee\x59\x6f\x1b\xd4\x98\x9c\xb0\x14\x49\x7a\x45\x17\x3d\xcc\
+\x60\xaa\xe8\x71\xf1\x46\x5d\x37\x6f\xec\xeb\x62\x5d\x4b\x2f\x41\
+\x55\x42\x84\x31\xf4\x17\x22\xec\xee\x9e\x53\x32\x51\xc7\x89\x71\
+\x43\x00\x30\xc2\x51\xbb\x06\xb8\xfd\xbc\xf3\x58\xf8\xf2\xcb\x10\
+\x15\x75\x68\x9e\x3d\x7b\x6e\x28\x2f\x2f\x7f\x20\x4f\x4a\x25\x11\
+\x60\x60\xbe\xa5\xb2\xbb\x58\x9f\xf8\x3a\xfb\xe3\x0f\xa9\x0b\xe9\
+\xfc\x4f\xc3\xea\xfe\xdc\xa0\xe7\x10\x79\xbf\xd3\xd1\x5a\xb7\xfe\
+\xe5\x3f\xa4\x30\xb4\x75\x6f\x2c\xb6\xec\xe8\xb0\x1b\x4a\x01\x92\
+\xe2\x44\xe5\xac\xa9\xba\xbe\x04\x93\x46\xa3\x7f\xa9\xd5\xae\xaa\
+\xb2\xd0\xba\x6e\x61\x20\xdf\x1a\xad\xf7\x06\xc3\xac\x6c\xec\xe1\
+\x8d\x7d\xdd\xbc\x59\xdf\x45\xd3\xc0\x21\xfc\x9a\xbd\xc0\xa3\xc0\
+\xdf\x4e\x36\x23\xe7\x44\x31\xae\xec\x02\x86\x94\x3b\x9f\x05\x9e\
+\x15\x42\xdc\x71\xee\xb9\xfc\x65\xc5\x0a\x88\x8b\x03\x97\x2b\xbd\
+\xed\xf5\xd7\xff\xaf\xc5\xed\x4e\x2d\x3b\xa8\x88\xcf\x57\x68\x9a\
+\x68\x02\x21\x10\x1d\x12\x99\x3c\xfc\xe2\xc1\xae\x8a\xca\xdc\xa0\
+\xe7\x30\x55\x2a\x6b\x5c\x5a\xfe\xa4\x79\x17\x6d\xa8\x58\xff\xef\
+\xb9\x03\x16\xed\x9a\xce\x38\x43\xd9\xc4\x5c\xcd\xfa\x92\x09\x1a\
+\xab\x5e\x27\x4a\x00\x0a\x1d\xc6\x8a\x26\xad\xf0\x85\x7c\x52\xbe\
+\x5a\xd5\xa5\x7f\xb7\xa1\x8a\x55\xfb\x1d\x0c\x86\x55\x00\x17\xb0\
+\x83\x08\x13\x6a\xfb\xd0\xdf\xbb\xce\x54\x0b\xa2\x71\x45\x00\xa3\
+\xf0\xfa\xc6\x8d\xfc\x65\xe1\x42\x85\x7b\xee\xb9\xbd\x33\x14\xba\
+\xd7\x0a\x62\xf6\xc1\x19\x02\xf1\xda\xfa\xb0\x96\x49\x00\x46\x34\
+\xfb\x7d\x84\x92\x01\xe6\xfa\x1c\xab\xae\xee\xdf\x7f\x44\xdd\xbd\
+\x8c\x49\x65\x73\xbb\x3a\xea\xde\xee\x5d\xd2\xa3\xbd\x3a\x5f\xd3\
+\xa3\x08\x31\x6f\xf8\x9d\x4e\x55\xb6\xc6\x0c\x6a\x67\xd4\x86\x34\
+\x3b\xeb\xf6\xab\xe2\xde\xf7\xab\x21\xb2\xd0\xbf\x03\xca\x81\xda\
+\x33\x75\xb1\xc7\xc2\xb8\x25\x00\x29\x65\xbb\x10\xa2\xcd\x5c\xfa\
+\xa5\xd4\xee\xde\x1b\x63\x63\xad\x42\x3f\x3a\x8f\x67\xaa\x79\x64\
+\xdb\x8f\xc3\xe0\x6d\x21\x44\xb4\x1a\xdc\xf3\x44\xcb\xa6\x79\xa3\
+\xf3\x22\x19\xa0\xbd\xbf\x9a\x95\x95\x3e\x5e\xd9\x96\x36\xad\xba\
+\xf5\xdc\x60\x57\x72\x65\x48\x88\xe4\x83\xb3\x15\xf6\x19\xf5\x00\
+\x4e\xb7\x62\xad\x6e\x1c\x49\xfe\xa1\x94\xf2\xb5\x93\x38\xbc\xd3\
+\x86\x71\x4b\x00\x42\x88\x04\x20\xc5\x36\x3f\x9b\xba\xb3\x5b\x95\
+\x92\x6d\xb6\x75\x26\xb7\x71\x84\xa3\x88\x22\x3c\xbe\x6c\x7d\xd2\
+\xf0\x3f\x93\x30\x9b\x5b\xf1\xf4\xbe\xd2\xb4\xce\xae\x95\x52\x07\
+\xf8\x70\xb8\x2b\x58\x53\x33\x20\x5f\xde\x9a\xc0\x96\xc6\x62\xc2\
+\xea\x08\x43\x48\x00\x53\xa7\x77\x27\x6c\x6d\x48\xea\x43\x44\xec\
+\x06\xf5\xaa\xd8\x1e\x3b\xa8\x9d\x06\xe0\xf2\x2a\xb1\x43\x04\x30\
+\x08\x7c\x62\xd7\xb8\x13\xc5\xb8\x25\x00\xe0\x1c\x40\xd8\x72\x52\
+\x11\x51\x1a\x6d\xc5\x82\xbd\x67\xa5\xd6\x25\xad\x4d\xad\x4d\x9e\
+\x0e\x98\x03\x09\xba\x1a\xa9\x11\xd3\x86\x33\x27\x63\x4a\xf9\xc1\
+\x9e\xf2\x75\x69\x6b\xb7\x9a\xe4\xcb\x5b\x5a\xf8\x70\x6f\x09\xc1\
+\xf0\x61\x37\x83\x83\x61\x6c\x0a\xa7\xe4\x7f\xa5\x7f\x63\xdd\xe3\
+\x31\x73\x00\xf2\xfb\x4d\x23\xb2\x03\x7f\x48\xd8\xf6\x36\x02\xf0\
+\xe1\xe9\xe4\xdc\x9d\x6c\x8c\x67\x02\x98\x05\x60\xcb\x49\x19\x49\
+\x68\xcb\xef\x9c\xdf\x9f\xe8\xac\x2f\xda\x90\x1f\xf6\x96\x18\x74\
+\x9a\x40\xb0\x21\x75\xeb\xb6\xa6\xd2\xa7\x9f\xd3\x4d\x7a\xfe\xa5\
+\x22\xa3\xd3\xf5\xc5\xe3\xfd\x38\x27\x3d\xe1\x9d\xd3\x75\xa3\x79\
+\xb5\x6f\x9e\xc1\x66\xf7\x6b\x47\x0c\x31\x34\x1a\x8d\xa8\xd9\x0f\
+\xc0\x3b\x27\x61\x2c\x9f\x18\xc6\x33\x01\x4c\x35\x27\xd9\xd1\x9a\
+\x8d\x87\x24\x7a\xad\xbe\xdc\x95\xf3\x6b\x3f\x9c\xfd\x97\x8e\x81\
+\xab\xd7\xbe\x17\x6f\x72\x75\x0b\xbd\xa7\x37\x38\x68\x4c\xa8\x94\
+\x98\xf4\xda\x41\x8f\x59\x13\x0a\x58\x14\x35\x64\x43\xca\x58\xc6\
+\x62\xe1\x8e\x42\xd6\x12\xc7\xfc\x3f\x7c\xd7\x8e\x76\x8e\x8e\xa2\
+\x7c\x3d\x51\x66\x05\xa1\x35\xe0\x1f\x74\x01\x6c\x38\x35\xc3\x3b\
+\x3d\x18\xcf\x04\x30\xc5\x96\x7b\x98\xaa\xde\x40\x79\x87\xb5\x72\
+\x4f\x77\xd4\xac\x17\xad\x39\x51\x55\x0b\x72\xd6\x5c\xe1\xeb\x2e\
+\x23\xa2\xa1\x73\x18\x84\x54\xc3\x3a\x9f\xab\xdf\xe8\xe9\xeb\x37\
+\xba\x7b\xdd\x06\x77\x8f\xcf\xe4\xea\x0e\x9a\xdd\xbd\x61\x9d\xbb\
+\x57\x74\x78\x7a\x73\x3a\x7d\xce\xcc\x07\xfd\x1e\x65\xfd\x43\x0e\
+\xee\xc7\x81\x10\x90\x99\xa6\xc3\x16\xad\x40\x44\xd8\xb3\xed\xd4\
+\x0e\xf3\xd4\x62\x5c\x12\x80\x10\xc2\x0e\x24\xda\x72\x0e\x10\x80\
+\x3f\x24\x76\xbe\xd2\x10\x9f\xe0\x0e\x6a\xca\x08\xd1\x06\x44\x3d\
+\x1d\x9d\xbc\x60\x9b\x31\xba\xe2\x27\xbd\xf5\x71\xca\x90\x04\xef\
+\x60\x48\xa1\x68\x02\x66\x5b\x5c\xc0\x6c\x8b\x73\x26\x64\x8f\xa4\
+\xab\xd0\xbe\x4b\x9a\x6a\xfe\xee\x0c\x47\x55\xf5\xb6\xc3\x7f\xee\
+\x02\xf8\x37\xd0\x2a\x25\x8e\xa6\x96\xa0\x81\x08\xa3\xa8\xe7\x44\
+\x75\xf2\x3e\x69\x8c\x4b\x02\x60\xe8\x17\x6d\xcb\x49\x41\xaa\xaa\
+\x5a\xe7\xb4\xac\x5e\xd5\x1e\xbd\x50\x4a\x11\x39\xa4\x85\x35\xee\
+\xe1\x8c\x3b\x75\x51\x93\xae\x4f\x98\xd8\xf3\xa7\xde\xda\x1d\x56\
+\x35\x34\x5a\x1a\x38\x02\x89\xe8\xae\x52\x0d\x55\xff\x54\x63\x62\
+\xb6\x86\xcd\x93\x25\xa4\x60\x04\x9c\x5b\x22\xaf\x23\x3a\xf7\x87\
+\xa9\x8c\x8f\x77\x8c\x57\x02\x88\xc0\x62\xe5\xf5\x6d\xa2\xbf\xd3\
+\x6c\x5d\x7c\x48\x7a\xe0\x50\xcb\x9c\x3e\x45\x1b\xff\xe5\xf8\x89\
+\x31\x3f\xea\x6b\x58\x55\x1a\x74\x8f\x30\x80\xa4\xa4\xaf\x0e\xc3\
+\x9e\x7f\x85\x62\x2c\xe5\x6a\xd4\x14\x15\x16\x1e\xd6\x86\xb3\x05\
+\xa0\xff\xd3\xb8\xf8\x30\x0e\x09\x60\xc8\xa1\xd3\x03\x18\x75\x6c\
+\xdc\x0d\xd8\xfc\x46\x46\x9b\x94\x84\x74\x87\x6d\xcb\x61\xd0\xde\
+\x17\x9b\xb3\xe8\x52\x4f\xf7\xaa\x25\x1e\x87\xe6\xa5\x50\xac\x61\
+\x55\x38\xaa\x34\x84\x58\x70\xd4\x06\x23\x04\xd0\x7d\xb2\xfa\x7f\
+\xa6\x61\xdc\x11\x00\xf0\x4d\xe0\xab\x7c\x69\x3a\x2c\x30\xaa\xfc\
+\xaf\xdb\x4c\x69\xcb\x16\x2e\x4e\xb5\xa0\x57\x22\xaa\x63\x01\x46\
+\x5b\xf2\x78\x09\xb1\x03\xf7\xa0\xe6\x65\x77\xd4\x9c\x97\xe5\xa8\
+\xab\xc3\xd1\xe0\x6c\x85\x88\xe6\xe1\xa7\x12\xe3\x91\x00\xbe\x82\
+\xa2\xc0\x0d\xf3\x24\x69\xd1\x0a\xff\x54\x77\xb2\xa3\x6f\x26\x3b\
+\xfb\x24\x73\xe3\x36\x70\x7e\x4a\x22\x41\xc2\xc0\x20\x61\xb9\x03\
+\x57\x28\x8c\x37\x38\x95\xb0\x3c\x9c\xfd\x7b\x2c\x04\xbd\xd0\xdf\
+\x00\xf0\xe1\x49\x1e\xc3\x19\x83\x71\x65\x1c\x2a\x22\x7c\xf9\x09\
+\x2c\x99\x08\x69\xb1\x91\xfb\xfb\xa5\x9e\xc8\xf7\x5e\x22\x28\x77\
+\xcc\xe5\xa1\x3d\x59\xe6\x4e\x7f\x8f\xe8\x53\x37\xd2\xe6\x2d\xc5\
+\x15\x98\x47\x58\x46\x1d\xa5\xda\x23\xa3\xf6\x0d\x08\xf9\x7d\xc0\
+\x73\x27\x69\x08\x67\x1c\xc6\xdb\x0e\x10\x39\xa4\xdd\x34\xff\x40\
+\xca\x62\xef\x2c\x9e\x8d\x6e\x43\x25\x72\x27\x94\x68\xf2\x97\xe7\
+\x4d\x56\x52\x12\xf4\x7b\x56\xa8\x95\xa1\xf6\x1e\x3f\x5d\x5d\xe0\
+\x1f\x9c\x82\xe0\xf8\x08\xa1\xf2\x05\x80\x67\xce\x14\xe5\x8d\x53\
+\x81\x71\xb5\x03\x00\xd7\x92\x6f\x82\x44\xcb\x81\x14\x81\x86\x79\
+\x83\x35\x23\xff\xd4\x88\xce\xa8\xa2\x84\x42\x93\x8d\xec\x99\x57\
+\x2b\x25\x31\x53\x13\x03\x94\x94\xcc\x65\xfa\x34\x0d\xd9\x59\x1b\
+\x31\x99\xd6\x01\xfd\xc7\x6c\xa9\x63\x3b\xf4\xee\x03\x78\xf2\xe4\
+\x0f\xe3\xcc\xc1\xb8\x21\x80\x21\xf3\xf1\x8b\x98\xee\x83\xbd\x8f\
+\x85\xf0\x39\xca\x47\x5e\x5e\x35\x30\x85\x88\xe3\x47\x4c\x79\x71\
+\xb5\x23\x65\x14\x74\x13\x97\xb0\x68\xc2\x02\xb6\xa3\x51\x9c\xc4\
+\xc5\xcf\xa1\xb8\xf8\x2c\xa6\x4f\x8f\x22\x37\x77\x2b\x66\xf3\x1a\
+\x10\x63\x9f\xf0\x2b\x5e\x00\xd8\x22\xa5\xdc\x7c\x2a\xc7\xf5\x49\
+\x63\xdc\x10\x00\x70\x1b\x20\x28\x05\xb4\x7e\x2d\x3b\x7e\x57\x46\
+\xf3\xbb\xeb\x90\xd2\x8d\x45\xda\xc9\x0c\x6e\x01\x48\x3c\xbf\x50\
+\x33\xba\xa0\x3d\x9b\xe9\x33\x2f\x47\xa3\x8f\x22\xb2\x98\x42\xe8\
+\x88\x8d\x9d\xc1\xc4\x89\x0b\x98\x31\x3d\x8e\x82\xfc\x9d\x58\x2d\
+\xab\x40\xb4\x01\xe0\x75\x40\xc3\xfb\x00\x7f\x38\x5d\x83\xfb\xa4\
+\x30\x6e\x74\x02\x85\x10\x35\x64\x51\xc0\x7d\xa3\x5e\x18\xe3\x9a\
+\x99\x7c\x47\x3f\x8d\xd1\x06\x7e\x12\x97\x37\xf3\x8d\x1b\x5d\xba\
+\x68\xc3\x11\x5d\xb5\x35\x6c\x66\x55\x47\x35\x73\x61\x94\x5b\xd0\
+\x61\xec\xdb\xaa\xb2\xfa\x71\x85\x0d\x8f\xb7\x01\xb9\x52\x1e\xdb\
+\x93\xd8\x78\xc6\xb8\xd8\x01\x86\x4e\xff\xb9\x8c\xe5\x7b\xcb\xef\
+\xc8\x60\xf3\x2f\x8a\x89\xdd\xde\x66\x28\x31\xac\x3c\xda\xe2\x03\
+\xe4\xcc\x62\xd1\xe4\x0b\x68\x12\x0a\x63\xbb\x7d\xc9\x9b\xa1\xd0\
+\xb2\x11\xe0\xd7\x9f\xf6\xc5\x87\xf1\xb5\x03\xfc\x91\x68\xee\xe2\
+\xe7\x8c\x79\x77\x31\x6b\x94\xca\xff\x2e\x9c\x5e\x1b\x35\xe5\xf6\
+\x98\xad\xba\xb3\xf5\x6d\xe4\x64\x86\x85\xf6\x88\xa6\x24\x6a\x18\
+\x6f\xc5\x5b\x6c\x75\xf7\x72\x28\x27\xb0\xab\x0e\x1e\x28\xe8\x06\
+\xb2\xa5\x94\xc7\x70\x27\x31\xfe\x31\x9e\xae\x81\xf5\xb8\x88\xa8\
+\x5f\xce\x3a\x90\x68\x54\x94\xea\x3b\x33\x32\x9d\x4b\xe3\xe3\x66\
+\x83\x4c\x5c\x16\xf8\x4a\x1c\x81\x88\x8c\xbf\x47\xa4\x74\x6f\xd0\
+\x5e\xd0\xb0\x56\x7b\x91\x77\x8f\x52\x66\xed\x53\x12\xf2\xa5\x14\
+\x56\x00\x45\x83\x79\xf2\x85\x2c\xe8\xa8\x61\x43\xc3\x26\x8a\x90\
+\x44\x76\x8e\xed\x2f\x01\xfc\xcf\x89\x2c\xfe\x90\x8f\x82\xb9\x40\
+\x3c\xb0\x0f\xe8\x22\x72\xf3\x18\x38\xd3\x14\x46\xc7\xd3\x0e\xf0\
+\x20\xf0\x43\xce\x05\xae\x04\x83\x50\x6a\x6f\xcd\xc8\x70\x2c\x4b\
+\x88\x9f\x23\x0e\x52\xea\x98\x9b\xbc\xb7\xca\xae\x77\x4f\x1c\xab\
+\x0e\x89\x90\xfb\x34\x93\x1b\xd6\x68\x2f\x6a\xdb\xa0\x5c\xa0\xee\
+\x53\xa6\x24\x78\x85\x25\x7f\xd0\x4d\xf7\xce\x67\x9a\x0d\xe1\xe8\
+\x8c\x38\x7e\x36\x1b\x1a\x37\xdf\x02\x74\x00\xd1\xa3\x1e\x0b\x91\
+\xb3\xc3\xe8\x47\x47\x64\xb1\xe3\x89\x78\x0b\x8d\x61\x6c\x45\x13\
+\x95\x08\x21\x0c\x3f\x03\x43\xff\xff\x83\x94\xf2\xfd\x13\x9d\xa3\
+\x8f\x83\x71\x41\x00\x42\x88\x74\xa0\x86\xe9\x98\x74\x5f\x12\x5c\
+\x91\x95\x5c\x7f\x6d\x71\x6a\x8e\x18\x63\x92\x53\xa3\x7a\x3f\x2c\
+\x8d\x6b\x58\xfc\x51\xeb\x0e\x08\xa3\x7f\xa7\xb2\xa0\xee\xef\xdb\
+\x4a\x33\xdf\x7e\x3d\xc9\xea\x7d\xe3\x05\xa4\x94\x48\x35\x8c\x2a\
+\x55\xa4\x1a\x46\xaa\x2a\x72\xe8\xef\x91\x34\xa9\x0e\xa5\x87\x51\
+\xd5\x03\x69\x8a\xcd\x82\x2d\xa7\x10\x5b\x6e\x11\xae\xac\x74\x2e\
+\x2f\xbb\x8c\x58\x5f\x08\xa7\xa3\x13\x57\x6f\x27\x4e\x47\x27\xce\
+\xde\xc8\xe3\x6a\xac\xa2\x66\xdb\x1a\x86\xb6\x9a\xc7\xa4\x94\x77\
+\x9f\xac\x39\xfb\xa8\x38\xe3\x3f\x01\x42\x88\x42\xe0\xdf\x09\x73\
+\xf4\xa6\x39\x97\xc5\xf0\x95\x45\x19\x28\x9a\xc3\xe4\x7f\x23\xe8\
+\xf2\xd9\xec\x1f\xa9\x62\x95\x36\x3a\x45\x93\x7e\x57\x28\x30\x6b\
+\xfd\xaa\xb8\xf0\xa6\x2d\xe6\x58\xd3\x15\x74\xdc\xf4\xd8\xc7\xea\
+\xa7\x37\xd0\x1f\x6e\xd6\x35\xfa\x4d\x39\x39\x51\x00\xad\x22\xe8\
+\x68\x29\x48\xb7\x14\x28\x39\x06\x0d\x82\xa4\xcc\x43\xcd\xe2\x43\
+\x41\xbf\x6a\x7b\xf4\xd2\x40\x7f\x15\xc6\x9d\x3e\xf8\x07\x7c\x4d\
+\x08\xf1\x3b\x29\x65\xed\x98\x0d\x9c\x22\x9c\xf1\x04\x00\xdc\x17\
+\x6b\xd6\x4e\x78\xfb\xca\x22\x5a\xd4\x70\xd8\xb7\xc3\xdf\x6b\xcd\
+\xd6\xb5\x8a\x58\xad\x2e\x24\x64\x26\x82\xe8\x83\x33\x87\x54\xcd\
+\xa4\xa0\xaa\x19\xd0\x29\x61\xdb\x41\xc9\x3e\xdc\xd4\x52\xa7\xe9\
+\x63\xa3\x62\x60\xb3\x26\x87\x1e\x91\x0a\xa4\x02\xb2\x9b\x82\x35\
+\x01\xf3\xa5\x1a\xf7\xe4\xe2\x6e\x22\x5b\xf8\x71\xc1\x65\xf5\x97\
+\xf7\xd9\x99\x60\x22\xc7\x0e\xd0\x6d\x31\x36\xed\xce\x48\x4d\x32\
+\x2a\x4a\xab\x26\x3c\x86\xff\x43\x41\x38\x23\xd5\xb8\xf5\x82\xb6\
+\xb5\xd1\xd5\xe9\x18\x74\xb5\xe4\xbe\x02\x78\xe0\x5a\xe0\x87\xc7\
+\xdb\xfe\x89\xe0\x8c\x26\x80\x21\xd7\xeb\x57\x2d\x9d\x64\x63\x9a\
+\xcd\xc4\x34\xd0\x00\x09\x34\x93\xc0\x90\x13\x77\x9f\x5e\x76\x0e\
+\x98\x65\x5b\x9f\x19\x97\x33\x4a\xe2\xd6\x4b\x4b\x77\x77\x8c\x23\
+\x75\xb0\xd7\xcc\x0e\x4d\x98\x0d\x9a\x04\x6a\x94\x7c\x42\x4c\x39\
+\xac\x01\x29\x06\x6a\xb4\xcb\xf6\x76\x30\x79\xa1\xaa\x28\x4d\xee\
+\xd4\xdc\xe3\x72\xf2\x1c\xd6\xaa\x3d\x5d\xc9\xee\xba\xa0\x56\x96\
+\x0d\x7f\x8d\x1c\x51\x86\xb6\xad\x99\xf6\x78\xc0\x98\x28\x75\xbd\
+\x70\xf8\x6e\x15\x1f\xc3\x3a\x5b\xf3\x9e\x6c\xad\xd7\x95\x59\x62\
+\x65\x70\x93\x16\x7c\x11\x1b\xe1\x31\xcf\x2e\xa7\x12\x67\x34\x01\
+\x00\xe7\x03\xc6\x0b\x4a\x6c\x47\xcc\x60\x0a\x88\x24\x93\x57\x7a\
+\x92\xfb\x06\x5d\x34\xf4\x6b\x69\xe8\x4f\xc3\x13\x9c\x0e\x47\x17\
+\xf9\xab\x68\x6b\xb6\x69\x6f\x36\x7a\x89\x9b\x0d\xd0\x1a\x9f\xd0\
+\x00\x63\x72\x1a\xc6\x84\xcb\x3a\x58\xde\x67\xf7\x4d\x80\x03\xae\
+\xda\xfb\xcc\xfa\xae\xcd\x59\x76\x2b\x43\x86\xa5\x79\xc2\x7c\x98\
+\x4b\xbb\x18\x0b\xab\x0c\x3a\x16\x65\x57\xfe\xfd\x5d\x4c\x5a\x1b\
+\xbe\x90\xcd\x9e\x80\xf4\xb7\x1f\x5b\x3b\xf9\x54\xe0\x4c\x27\x00\
+\x23\xc0\xbf\x77\xf6\x71\xf5\xec\x38\x8c\x3a\x05\x40\x25\xa4\xd6\
+\xe1\xf0\xb5\x53\xdf\x6f\xa0\x71\xa0\x00\x5f\x28\x97\x31\x7e\x69\
+\x47\x82\x5b\x24\xae\xdb\xa1\x5c\x3f\x5d\x45\x37\x62\xf2\x5d\x91\
+\x57\x98\xfe\x51\xca\x86\x35\xd2\xd1\x99\xec\xae\x0d\xe9\xd4\xb2\
+\x83\xcf\xa0\x03\x26\x9d\x63\x53\x76\x82\x11\xe4\x88\xa4\xaa\x50\
+\x35\x1f\x12\x8f\x26\xca\xc4\x86\x28\x73\x44\xa2\x99\xa1\xac\x2a\
+\xe2\xfa\xa9\x4a\xb0\xc6\xb1\x61\xe6\x07\xcd\x73\x35\x1d\x61\xc2\
+\x92\x34\x21\x84\xe9\x74\x5a\x10\x9f\xe9\x04\xb0\x1b\xd8\xff\xca\
+\xb6\xbe\xcc\x2f\xfc\xae\x92\x6f\x96\x18\xb9\xb8\xc7\xef\x51\x02\
+\x6a\x21\x70\xb8\xb3\xa1\x63\x23\xd0\x22\xe6\x6e\xa8\x57\x16\x1f\
+\xa2\xfb\x17\x56\x34\x75\x0e\xab\x35\xff\x58\x85\x5d\xb6\x40\x79\
+\x5f\xac\xf7\x90\x5f\x3d\x80\xd3\xa0\x1b\xd8\x90\x13\xaf\x91\xc8\
+\x43\x16\x3c\x17\xd3\x08\x51\x19\xf4\xec\x89\x89\xa6\x14\x10\xf1\
+\x3d\x9b\x6a\x45\x38\x54\x00\xa0\x2b\x8c\x9b\x6b\x49\xb7\xc9\x07\
+\x92\x3b\xc5\x4f\xdf\xed\x9a\xef\x0b\xaa\x3b\x85\x10\x37\x4b\x29\
+\xd7\x7d\x8c\xf1\x1d\x37\xce\x68\x56\xb0\x94\x72\x2b\x50\x00\xbc\
+\xb4\xba\xda\x83\x66\xb9\x83\xb7\x26\xf9\xea\x7b\x0d\x72\x2d\xc7\
+\xed\x90\x59\x69\xdb\xa5\x5c\x5b\x3b\x7a\xf1\x01\xf6\x27\x26\x1f\
+\x35\x2a\x44\x58\x23\x1d\x6d\xe9\xee\xf2\xbe\x58\x5f\x19\x88\x43\
+\x6e\x19\x6e\x83\xce\x5d\x9e\x17\x1f\x96\x1c\x1a\x2a\x46\x81\xde\
+\x18\xa9\xb5\x00\x68\x35\x34\xc5\xdb\x48\x65\x68\x47\x33\xb4\x3d\
+\xdd\x14\x24\x38\xa2\x63\x10\x63\xd6\x8a\xfb\x2f\x4e\x63\xf7\x0f\
+\x8a\x39\x7b\x42\x74\x01\xb0\x5a\x08\xf1\xeb\xa1\x18\x07\xa7\x14\
+\x67\x34\x01\x00\x48\x29\x03\xc0\x8b\x00\xe5\xdb\xa0\xec\xbf\xd4\
+\xfc\x0f\x16\xf8\xd3\x9e\x29\xf2\xbb\xeb\x6c\xa1\x0f\xa5\xa0\xe3\
+\x58\x75\x04\x89\xda\x56\xae\xb9\xdb\xd8\x2f\x32\x27\x8d\xf5\xbe\
+\x22\x3f\xef\x88\x91\x36\x5c\xb6\x40\x79\x6b\xfa\x00\x21\x6d\xb8\
+\x6c\xf4\x3b\x8f\x41\xe7\x5b\x9f\x1b\x3f\x28\x47\x11\x05\x40\x34\
+\xda\x66\x00\x45\xa1\x37\xc9\x0e\x08\xec\x00\x01\xd5\x2f\xaf\x70\
+\x3d\x1f\x57\x9d\x55\xbb\x71\x97\x76\x5b\xfd\x2e\x76\xed\xa8\x13\
+\xb5\xeb\xda\xd4\xce\x9d\x66\x4b\xc0\xff\xde\xb7\x0a\x78\xfc\xba\
+\x2c\xc5\x66\xd2\xdc\x03\x6c\x17\x42\x8c\xe9\xc1\xec\xe3\x42\x08\
+\x91\x29\x84\xf8\x82\x10\xe2\x2c\x18\x07\x04\x30\x84\x4d\x00\xe1\
+\x30\xc4\xc6\x12\x75\xd9\xdd\xc4\xd8\x0b\x65\xd7\xfb\x19\xc1\xc5\
+\x4f\x16\xfb\xe2\xcb\x93\x83\xe5\x01\x0d\xbb\xc7\x28\x27\x1d\xa2\
+\xf0\xc3\x0d\x9a\xbb\x4b\x83\x98\xc7\xe4\x0f\x04\x75\xda\xaa\x81\
+\xa8\xe8\xc3\x0e\x7f\xaa\xa2\xf6\x8e\xfc\xea\x85\x88\x1b\xfd\x7e\
+\x50\x75\x07\x36\x78\xfe\xd6\xac\x0e\x54\xec\x66\x0c\x8d\xa1\x34\
+\x69\x70\x0a\x81\x3f\x39\x8e\x56\xc4\x81\xc3\xe5\xbb\xcd\x4f\x56\
+\xd7\x39\xfb\xa6\xfd\xd4\x92\x69\xb5\x94\x98\x1b\x13\x52\x82\xee\
+\x01\xe9\x9c\xd7\xae\xb4\x4c\x6d\x33\xd5\x18\xb7\xc9\x6d\xc1\xd4\
+\x89\x2d\xa1\xdf\x7f\x4d\xcb\xb5\xe7\x88\x22\x45\x61\x9d\x10\xe2\
+\xe7\x43\x9e\xc5\x4f\x08\x42\x88\xf3\x81\x46\x43\x94\xf6\x2d\xe0\
+\x35\x21\x44\xdc\x78\xe1\x04\xde\x0b\xfc\xb2\x72\x2d\x4c\x3c\xf0\
+\xe5\xf7\xbe\xf7\x1c\xd5\xf5\x55\x4c\x1f\x4e\x48\xf1\x2a\x95\xf3\
+\xdb\xf5\xbd\xb1\x3e\x31\x1b\x29\x7c\x43\x57\xbc\xd9\x63\xd5\x39\
+\x8c\xba\x8c\x8c\x55\x1b\x8b\xa7\x1c\xe2\x2c\xc2\x6d\x1d\xdc\xd0\
+\x1b\xeb\x2b\x18\x6b\xe1\x01\x82\x72\x30\xb4\xda\xfb\x97\xb6\x20\
+\x83\x11\xf7\x32\x42\x13\x20\x79\xf1\x66\x52\x96\xd8\xd0\x18\x4b\
+\x00\x2e\x26\x7e\xdd\xcd\xb1\xf1\x1a\xad\x86\x43\xce\x0b\xb7\x7e\
+\x90\x5d\xdd\xe9\x6f\x2a\xb2\xe8\x0d\xed\x5f\x9e\x5c\x62\xfd\x86\
+\xac\xd9\x52\x10\x76\xd9\x6b\xeb\x30\x7a\x3d\x14\x8c\x6e\xeb\xed\
+\xcd\xf0\xc0\x53\xa0\xaa\x6c\x02\x2e\x93\x52\x1e\x3b\x88\xd1\x18\
+\x10\x42\x64\x03\xdb\xa7\x5e\x94\x15\x73\xfe\xf7\xa6\xf1\x8b\xb2\
+\x57\x01\xbe\x7f\xa6\x1f\x02\x87\x71\xc7\x9c\xe9\x91\xc5\x97\x21\
+\x3a\xc5\x00\xd5\x78\x49\x3a\x77\x3e\xd9\xce\x6e\xca\x7b\x7a\x28\
+\x03\x68\x37\xab\xc5\x2f\xe4\xf9\x31\x85\x44\x77\x4a\xd7\xbc\xed\
+\x3a\x67\xc9\x54\x71\x74\xfa\x96\x15\xb9\x05\x23\x93\xae\x2a\x6a\
+\x6f\x47\xaa\x67\x6f\x48\x7b\xe8\x09\xff\x60\x84\x65\x48\x5d\xeb\
+\x7f\x72\x5f\x90\xc1\x09\x07\x6a\x09\xeb\x69\x5f\x79\x16\xed\x2b\
+\x21\x3a\xbb\x9a\xcc\xcb\x7a\x66\x24\x26\x0d\xe8\x34\x72\x3a\x82\
+\x2e\x40\x91\x52\x6a\xba\xfc\x4d\xb2\xd3\xdf\x34\x01\xc0\x1d\x18\
+\x4c\xa9\xef\xeb\x5f\xf5\x68\x6c\xe1\xa2\xeb\x75\x4d\xab\xe6\x17\
+\xf5\xcc\x1b\x70\xb2\xaa\xbe\x9e\x99\x6a\xf8\x80\xee\xe2\x79\x33\
+\xa1\xdf\x17\xcf\xa3\xcf\x0f\xcc\x0e\x86\x82\x5b\x85\x10\xd7\x4a\
+\x29\x57\x7e\x8c\x39\x7c\xdc\x12\x6f\x8c\xb9\xe5\xef\x67\x63\x88\
+\xd6\x63\x8e\xd1\xe3\xed\x0f\x4c\x38\xe3\x3f\x01\x43\x86\x20\x19\
+\x67\x95\x1a\xa8\xdb\x80\x53\xb4\x91\x80\x87\x45\x48\x8a\x84\xc0\
+\xbe\x68\x31\x01\x7b\x02\x6b\x0f\x2e\xe3\xd3\xca\x84\xfa\xd4\x75\
+\xe7\xd5\x4c\xf8\x91\xbd\x2b\xe9\xdd\xf5\x61\x8d\x7f\x4c\x7f\xbb\
+\x01\x9d\x6e\x8f\xdb\x68\x4a\x05\x70\x47\x07\x37\xb4\x66\xba\xd5\
+\xc8\xe2\x8f\x0d\x89\x94\xeb\x7c\x4f\xef\x1e\x54\x3d\x13\x8e\x94\
+\x27\x59\xd6\x1a\x6e\xd0\x7d\x2f\xc6\x37\x78\x4b\x91\x3d\xa6\xdc\
+\x91\x14\xbb\x2d\x31\x29\x76\x5b\x7c\xb2\x7d\x7b\x6c\x49\x4a\x8f\
+\xfd\xb2\xfc\xcb\xaa\x88\x08\x85\xf8\xb0\xa9\x61\xa6\x94\x74\xfe\
+\x5d\x66\x2d\x7a\x85\xd4\x4d\x36\x2b\x67\x95\x96\xe2\x4c\x8c\x3f\
+\x60\x71\x2c\x04\x5c\xb9\xa0\x87\xeb\x2e\x2c\x26\xc6\x96\x96\x08\
+\xbc\x23\x84\xf8\x95\x10\xe2\x23\xb8\xbb\x1e\x99\xc3\x2f\x02\x4b\
+\x97\x3f\x38\x1d\xa3\x55\x8f\x10\x90\x37\x3b\x07\xc0\x7c\xc6\x13\
+\xc0\x90\xf8\xb4\x71\x5b\xfd\x15\xac\xda\x59\xa2\x01\x0e\x31\xd1\
+\x8a\xb3\x53\x96\x92\x42\x66\x42\x0a\xab\x0f\x2b\x2b\xc2\xba\x5e\
+\xfb\xda\x79\xb5\x85\x3f\x9b\xd4\x9c\xf9\x74\x45\xc0\xd0\xbd\x0e\
+\x18\x09\x13\x57\x9f\x9a\xd1\xa7\x2a\x6a\x6f\x5b\xaa\xbb\xbc\x37\
+\xce\x3b\x57\x22\xe3\x47\xd7\x71\x30\x36\xf8\x9f\xdb\xe8\x95\xfd\
+\x63\x06\x6b\x34\xeb\x55\xcf\x25\x13\xfb\xb6\x5e\x34\xa1\x3f\xdb\
+\xa8\x91\x25\x8e\xc1\xe6\xbc\xbf\xd6\x7e\x63\x42\xbd\x73\xcb\xbb\
+\xc3\x6d\x2a\x42\xe1\x7b\x73\xbe\x57\xfc\xbb\xc5\xbf\x73\x68\xd1\
+\xba\x82\x61\x35\x6a\x5b\x6b\x7b\x03\xc0\x5b\xa4\x9c\xf5\x94\xcc\
+\xd9\x21\xc0\x96\x91\xc5\xdc\x29\x93\xd9\x62\x34\xb2\x1f\x22\x44\
+\xf0\xb5\x65\x3b\xb9\x70\x41\x3a\x39\xd9\x0b\x15\x8d\x46\xff\xdf\
+\x40\xa5\x10\xe2\xc2\x63\xcd\x9f\x10\xc2\x00\x3c\x9c\x58\x60\x63\
+\xce\x35\x93\x68\xfc\x67\xc9\xae\x35\x37\x2f\x0b\x87\x5b\xa6\x00\
+\x24\x8d\x97\x33\xc0\x8a\x79\xf3\xae\x5e\x56\x52\x72\x36\xf9\xc9\
+\xb5\x55\xdf\xbc\xf0\x11\x93\x4e\x13\xce\x1e\x7e\xef\xf3\xb1\x61\
+\x4d\x39\x73\x03\x83\xac\x6d\x6b\x61\x38\xc8\xc3\x98\xd0\x05\xa3\
+\xbb\x12\xbb\x96\x55\x5a\x5c\x13\x27\x3e\xb7\x7c\x61\x45\x6b\x0a\
+\x93\x15\x30\xeb\xc0\xad\x93\x8a\x57\x8f\xf0\xe9\x55\x06\x0d\x42\
+\x09\x68\x55\x42\x26\x94\xb0\x5e\x6a\xd4\xd7\xc2\xaf\x85\x76\xab\
+\xdb\x0f\x8b\x11\xac\x55\x24\x33\xa3\xfb\xdb\x4a\xb2\x82\x89\x8a\
+\x4e\x68\xa5\x2a\x19\x68\x1e\xc4\x9a\x6a\x40\xd1\x45\x3e\x23\x16\
+\x4f\x5a\xeb\x05\xd9\x5f\x53\x62\xe3\xe2\x46\xbc\x59\x74\xb8\x3a\
+\xb8\xfc\x8f\x97\x33\x98\xee\xe7\x9a\x9c\x29\x3d\x16\xbb\x2e\x1e\
+\xa0\x48\x3a\x2b\xbe\x25\x6a\x53\x88\x78\x52\xf5\xf7\xf5\xb1\xa1\
+\xb1\x91\x32\x55\x8d\xa8\xb0\x3d\xf6\xe6\xac\xe0\x96\x7d\x85\xba\
+\xfd\xfb\xd7\x30\x30\xb0\x1f\x22\x37\xa4\xff\x92\x52\xb6\x8f\x31\
+\x6f\xa5\xc0\xdf\xf4\x5a\x31\xe5\x6b\x5f\xca\xa3\xdf\xf6\x70\x97\
+\x4e\x29\x8c\x07\x94\x95\x2b\xef\x63\xdf\xbe\x77\x06\xc6\x0b\x01\
+\xfc\x6a\xf6\xec\x4b\xfe\xbb\xb4\xf4\x02\x00\xf4\xda\x80\xf7\xdb\
+\x17\xfe\x66\x6b\x4e\x52\xe3\x88\x36\xcf\xf6\x9d\xec\xec\x76\x30\
+\x35\x14\x62\x63\x6b\x13\xd3\xa4\x1c\xdb\x27\x80\x22\xe8\x88\xb7\
+\xb2\x77\x62\xaa\x2e\x66\x95\x79\xad\x1d\x34\x69\x1c\xe3\x36\x54\
+\xd7\xbb\xb7\xfb\xd7\xf6\x67\x13\x24\x87\xce\x55\xbe\xc6\x29\x67\
+\xa6\x78\x84\xaf\xd3\x4f\xfb\x36\x37\xed\xdb\x5c\x74\xec\xf0\x10\
+\xf4\x86\xd1\x1a\x15\x92\xa6\x58\x48\x9d\x69\x21\x75\x66\x34\xf6\
+\x1c\x0b\xa9\xad\x0b\xb9\x68\xf9\xd5\x00\x3c\x70\xdb\x03\xbc\xf3\
+\xf2\x3b\x70\x31\xa4\xcd\x8d\xe6\xc2\x73\x0f\x9c\x6e\xd3\xf1\xd5\
+\xdf\x47\x95\x51\x83\x4c\x05\x90\x2a\x8d\x0d\x8d\xf4\xf6\xf5\x45\
+\x0e\xbc\x7f\xfd\x60\xa6\xa3\xbc\x76\x7a\x9c\xc3\x51\x43\x73\x73\
+\x39\xe1\xf0\xa0\x13\xf8\x1e\xf0\x67\x29\xa5\x2a\x84\x98\x4e\x24\
+\x3e\xc2\xc5\xcb\x4a\x6c\xe2\x96\x04\x41\x49\xc0\x14\x5e\x75\xc3\
+\xeb\x1a\x97\x8b\xb6\x9d\x3b\xdb\x93\x3f\xf8\xe0\x27\x4a\x4b\xcb\
+\x86\x9a\xf1\x42\x00\x97\x4e\x9b\x76\xc1\x4b\xb3\x66\x5d\x72\x48\
+\xfa\xfc\xa2\xb5\x1b\xaf\x5b\xf0\x6c\x81\x10\xd2\x1e\x0a\x51\xfd\
+\xfe\x6a\x26\x00\x22\xac\xb2\xa3\xa5\x91\x7c\xa9\x62\x01\x10\xd0\
+\x15\x17\x4d\x75\x71\x3a\x31\xc9\x36\x26\x0f\xbb\x19\x7f\xa1\xed\
+\xfe\x56\x19\x7f\xc9\x98\x6a\x63\x0e\x47\x1b\xfb\xf7\x57\xd3\xe4\
+\x69\x62\xd5\xdc\x66\xbc\xf8\x41\x4a\x50\x41\x09\x86\x49\x6b\x6c\
+\x21\xbc\xa9\x83\xae\x9d\x6e\x42\xfe\x11\x96\x7f\x0b\x91\x00\x15\
+\x1d\x44\xa4\x8a\x33\x20\xe2\x60\x4a\x6b\x54\x48\x9e\x6a\x21\x3b\
+\xbb\x00\x06\x4c\xac\x7d\x6b\x6d\x08\xf8\x1b\x70\x39\x93\xb1\x9d\
+\x7f\x77\x3e\x99\x33\x0e\xc8\x3c\x62\xd4\x40\xf7\x8f\x94\x8a\x01\
+\x3d\xea\x08\x87\xd2\xef\xa7\xbc\xa6\x86\xdc\x60\x90\xa4\x97\x36\
+\x4c\xdf\xf7\xce\xae\x99\x79\xc1\xa0\x97\x8e\x8e\x1d\x38\x1c\xb5\
+\x84\xc3\x83\x0d\x44\x42\xe3\x4c\x2d\x4c\x32\xf2\xeb\xcb\xd3\xb1\
+\xed\x73\x1e\x67\x2a\x29\x00\x00\x0d\xf6\x49\x44\x41\x54\xb0\xa0\
+\xbb\x8f\xb6\x84\xc9\x75\x2b\xbe\xf4\x74\x3e\x40\x47\x47\x23\xbf\
+\xfd\xed\xdd\x0c\x0c\x74\x3c\x30\x5e\x08\x60\x49\x49\xc9\x39\xef\
+\xce\x9b\x77\xd5\x61\xef\x6c\xe6\x81\xae\xef\x5f\xfa\x93\x66\x9b\
+\x79\x60\xc6\xbe\x06\xd6\xee\x6b\x60\x3e\x80\x54\xa9\x34\x21\x7a\
+\x52\xa2\xb1\xc7\x47\xc9\x62\x31\xc6\x67\xa1\xd9\x9d\xce\x3a\xed\
+\xcb\xf4\xf4\xb4\xb0\x7f\x7f\x75\x64\xc1\x9b\xaa\x69\x6e\xae\xc6\
+\xe3\x39\x2c\x1a\x0d\x44\x3c\x82\x1c\xac\xd1\xd3\x4b\x84\x5d\xbd\
+\x11\x58\x39\x3a\x66\xdf\x90\x67\xd3\x62\xc0\x46\x44\xa3\xa8\x18\
+\xb8\x85\x08\x0b\xfe\x3e\x29\xe5\x4b\x42\x08\x2b\xf0\x55\x73\xa2\
+\xf6\x9e\xa2\x8b\x12\xe2\x0b\x2f\x8c\xc3\x9a\x11\x51\x58\x36\x0c\
+\x06\x03\x3f\x52\x2a\xf7\xd9\x74\xa1\x83\xa5\x84\xee\xee\x6e\xb6\
+\x36\xef\x67\xfe\x9b\x3b\x4a\xab\x5e\xd9\x34\x6b\x12\x08\xa1\xaa\
+\x61\xfa\xfa\xf6\x71\x7b\x69\x2d\x17\x4d\xb5\x51\x92\x6a\x62\xf3\
+\x4b\xb5\xce\x59\x0e\xa7\x15\xa0\x2e\xf7\x9c\x9d\xef\x2f\xf9\xe5\
+\xd4\x8e\x8e\x46\x7e\xf3\x9b\xaf\xe0\x74\x3a\xb6\x02\xe7\x8c\x17\
+\x02\xf8\xfa\x84\x09\x67\xfd\x7e\xd1\xa2\x1b\x8e\xf0\x5e\xca\xab\
+\xcf\x7a\x7e\xf5\xa2\x89\x1f\xe6\xbc\xbf\x5a\xc6\x85\xc3\x98\xcb\
+\xa6\xe9\xd6\x45\x5b\xc4\xb0\x0d\x59\x48\x86\xa4\x5f\xaa\xd2\x19\
+\x0e\x8a\x8e\x97\x56\x93\xfe\xda\x66\x53\xe2\xce\x5d\x4e\x5a\xfa\
+\xcd\xdd\x2e\x8f\xf7\x7d\x22\x26\xe0\x7d\x63\x3c\xbd\xc3\x7f\x9f\
+\x4a\x25\x51\x21\x84\x89\x08\x71\xdc\x9b\x34\xc5\x92\x55\xb8\x3c\
+\x8e\xbc\xf3\x62\x31\xea\x24\xb7\x75\x57\x76\xcd\x48\x0f\x26\x1e\
+\x9c\x5f\x55\xd9\xbb\x6f\x1f\xc1\x57\xd7\x97\xa8\xcf\xaf\x9f\x57\
+\x0c\x68\x2d\x3a\x5f\xf8\xc3\x9b\x2b\x35\x00\xdd\xaf\xee\xdd\x95\
+\xd0\xe5\x1e\x11\x81\xbf\x9c\x7e\xbe\x7b\xcf\xf4\xdb\x2c\x43\x8b\
+\x5f\x05\x2c\x94\x52\xf6\x7c\x2c\x02\x18\xe2\x51\x5f\xc9\x01\x57\
+\xef\xdb\xa4\x94\x6f\x7f\xec\xd1\x1f\xbb\xbd\x27\xf3\xf2\x66\xde\
+\x72\xee\xb9\xb7\x1f\x35\x5f\x56\x7c\x53\xdd\x79\x13\x9f\xae\xcb\
+\x4f\xe8\x36\x5a\x63\x74\x39\x1e\xad\xb5\xb3\x47\x9f\xea\x6a\x33\
+\xe7\x07\x9b\x2c\x45\xfa\x46\xcb\x24\x6b\x53\x28\x39\x4f\x5a\x6c\
+\xb1\x2b\x96\xde\x41\xdb\x07\x9b\xdf\x05\x96\x7d\xd2\x1e\xbb\x0f\
+\x86\x10\x42\x0b\xdc\x00\xdc\xaf\xd1\x2b\x39\xd9\x8b\x6d\x14\x9e\
+\x1f\xc7\xad\xd6\x66\x79\x59\xe9\xe0\x68\xe6\x84\xf4\x78\x58\xfb\
+\xd4\x6b\x13\x75\x7f\x5b\xb5\x60\xfa\xcc\x54\x57\xed\x9f\x2f\xac\
+\x99\xc8\xab\xd5\xeb\xe8\xf6\x1c\xa2\xed\xfc\x78\xd4\x2c\xee\xdd\
+\x51\x8f\xd3\xe9\xd8\x05\x2c\x1d\xf6\xd8\x7e\x5c\x04\x20\x84\xc8\
+\x07\xee\x04\x6e\x52\x14\x6d\x5c\x5a\xda\x2c\x1c\x8e\x3a\xbc\xde\
+\x6e\x09\xdc\x22\xa5\xfc\xeb\x09\x8c\xfd\x68\xed\x96\x27\x26\x16\
+\xcd\x3d\xf7\xdc\x6f\x06\x16\x2f\xbe\x63\xe0\x48\xf9\x54\x55\x06\
+\xde\xde\xb7\x69\x5b\xf2\x2f\x0d\xe7\xab\x42\xa3\x3b\x52\xbe\xda\
+\x67\x57\xf0\xc1\xf5\xf7\x39\x81\x49\x1f\x97\xb3\x76\xaa\x31\x44\
+\x08\xd7\x03\xdf\x07\xf2\xcc\x71\x3a\xae\x5c\x60\x27\x31\xf9\x3c\
+\x99\x9b\xea\xeb\x9c\x92\xdf\xdb\x93\x6e\x6f\x35\xd9\xa3\xdb\xe3\
+\xcd\x7a\x67\xf8\xb9\xb7\x27\x6e\xec\xaf\xcf\x33\xdf\xdd\xf8\x9e\
+\x8e\x01\xff\x61\xa6\xf0\xdf\x6f\x85\x9f\x76\xd0\x04\x4c\x3f\x38\
+\x30\xe7\x61\x04\x30\xb4\x15\x0d\x6b\xc1\xc6\x12\xb9\x8e\x4c\x05\
+\xae\x02\x66\x68\xb5\x26\x12\x13\x27\x91\x90\x50\x4c\x7a\x7a\x19\
+\x69\x69\xb3\x58\xbb\xf6\x67\x54\x56\xbe\x34\x00\x24\x9f\x0a\xa7\
+\x49\x42\x88\xf6\xb4\xb4\xd9\xc9\xc9\xc9\x53\x29\x28\x28\xfb\x70\
+\xf1\xe2\x9b\x16\x0d\x31\x88\x46\xa0\x4a\xd9\xf9\xbb\x15\x2b\x7a\
+\x1a\x3c\x5d\xa1\x1b\x57\xce\x1c\xf3\xae\x0e\x10\xe8\x77\xf1\x7c\
+\xd1\xc5\xf8\x3a\x1d\x77\x4b\x29\x3f\x9e\x02\xe0\x69\xc4\xd0\x39\
+\xe2\xcb\xc0\xfd\x40\x41\x5c\x5c\x21\x67\x9f\xfd\x30\x76\xfb\x21\
+\xd2\x6b\xaf\x56\x13\xec\xce\xab\x7a\x74\xc7\xf2\xc6\xe7\x63\x53\
+\xa3\x89\x8a\xd2\x93\x2f\x22\x67\x0f\x00\x26\x56\x40\xb5\x9f\xef\
+\x4a\x29\x7f\x71\x70\x41\xed\x50\x23\x93\x80\xff\x22\x12\xec\xea\
+\x50\x39\xbb\xd1\x06\x51\x49\x10\x95\x48\xba\x46\x4b\x62\x6c\x1e\
+\x62\x28\xaa\xaa\xd9\x6c\xaf\x04\x8a\x4b\x4b\x6f\xa6\xa6\xfa\x15\
+\x5b\x48\x55\xbf\x08\xbc\x70\x92\x27\x40\x0f\x24\x19\x8d\x11\x4d\
+\xeb\xda\xda\x0d\x8b\x3b\x3a\x6a\x37\x5e\x7e\xf9\x83\xc5\x3a\x9d\
+\x31\x1a\xc0\x17\x08\x54\x3c\xf4\xc2\x0b\xf1\xfd\x1e\xcf\xa4\xe8\
+\x54\xe3\xc6\xa3\xd5\xb7\xe9\xbe\xdf\xe3\xeb\x74\x54\x00\x7f\x3e\
+\x99\xfd\x3c\x55\x18\xf2\x4d\xf4\x8c\x10\xe2\x39\xe0\x1a\x87\xa3\
+\xe6\xfe\x97\x5f\xbe\x7e\xc2\x8c\x19\xb7\x53\x5a\x7a\x13\x22\xe2\
+\x17\xcb\x1c\x0a\xeb\xb2\xfe\x23\x8b\x1a\x4c\x4d\x4a\xbe\x46\xaa\
+\xa9\x00\x71\x26\x5a\x14\x0b\x51\x2e\x3d\xb1\xd5\x7e\x24\x70\x58\
+\x20\x2a\xed\x50\xa4\xea\x15\xe4\x96\xc1\x45\x0f\xc3\x40\x27\x78\
+\x3c\xa0\x4d\x04\x63\x22\x68\x0f\x98\xd0\xb5\x7a\x7b\x3c\xa6\xc6\
+\x0f\xfa\xac\xbe\xbe\xf4\x48\xe7\x84\x0b\xc0\xe6\xdd\xc6\x65\x13\
+\x92\x79\xbe\xaa\xed\x3a\x4e\x32\x01\x10\x51\xd3\x12\x11\x02\x88\
+\xc0\xe5\x72\xcc\x79\xe6\x99\x7b\x1a\x2e\xbe\xf8\xbb\x5d\x3e\x4c\
+\xed\x3f\x7f\xed\xb5\x59\xa1\x70\xd8\x00\x60\x88\x56\x8e\xf8\x3d\
+\xef\xde\x5c\x41\xd5\x5f\x5e\x04\xf8\xef\xf1\xe6\xf4\x69\xd8\x55\
+\xbe\x10\xe2\x1f\xaa\x1a\xbc\x6a\xf3\xe6\x3f\x3e\xd0\xd4\xb4\x6a\
+\xe2\xe2\xc5\x3f\x24\x26\x26\x1b\x80\x60\x78\x50\x57\x91\x58\xd2\
+\x32\xa5\x73\x57\x2a\x80\xc3\x47\x3a\xbe\x91\x55\xdf\x2e\xa5\xdc\
+\x3f\xba\x5e\x2d\xf0\x17\x4a\x2e\x80\xbb\x57\x44\x78\x8e\x07\xc3\
+\x3b\x30\xc0\x40\xaf\x03\xb7\x57\xc5\xaf\xc6\x49\xe2\x63\x6b\x8b\
+\xaf\xd0\x25\x74\x55\xac\xca\x6c\x5e\xb7\x28\x14\xf2\xab\x7a\x11\
+\xaa\xf9\xe5\x84\x2d\xf9\x1b\xa3\x32\x95\xe7\xab\xda\x2e\x10\x42\
+\x94\x9c\xe4\x20\x87\x4b\x0d\xd8\x98\x53\xf9\x1d\x4c\xd8\x1d\x5a\
+\x8c\x1e\x19\x09\x11\x24\xff\xdd\xb5\xa3\x7e\x63\x5c\xf3\xd2\x83\
+\x33\xeb\xad\x9a\x31\x17\x56\xaa\x2a\x6b\xee\xfc\x11\x52\x55\xdf\
+\x92\x52\xbe\x75\x12\xfb\x77\x5a\x31\x14\x3a\xf7\xff\x84\x10\x2f\
+\x76\x75\x55\xfc\xe2\xa5\x97\xae\xfd\xd6\xac\x59\x5f\x63\xf2\xe4\
+\x6b\x70\xb9\xda\xa6\x6e\x4d\x9d\x61\x9a\xdc\xb9\xbb\x5f\x20\x63\
+\x20\x62\x92\xb4\x26\x52\xf4\xc5\xb1\xea\xd3\x02\x41\xc2\x41\x08\
+\x78\xc1\x30\xca\x81\x86\xd9\x66\xc3\x6c\x3b\xc0\x9d\x90\xb2\x15\
+\xaf\xbf\xa9\x3b\x35\x59\xef\x4a\xce\x7f\x3f\xad\xa3\x92\xff\x67\
+\x7d\x52\x63\x52\xc2\xca\xe2\xcc\x38\xa6\x25\x59\x75\xdb\x3b\x9d\
+\x3b\x84\x10\x4f\x03\x0f\x49\x29\x5b\x4f\xc2\x98\x2f\xcb\x62\x31\
+\x56\x99\x01\x91\x70\xac\x71\x40\xdf\x9b\xb9\x9b\xea\x6b\x62\x5b\
+\x97\x8e\xce\x6c\xb6\x1b\xc6\x3c\xd5\x56\x3c\xf6\x3c\x3d\xdb\xaa\
+\xc2\xc0\x3d\x27\xa1\x4f\x9f\x38\xa4\x94\x41\xe0\xdb\x42\x88\x0f\
+\x36\x6c\x78\xe4\xaf\x8d\x8d\x1f\xda\x6d\xb6\x4c\x8b\xd6\x92\xc4\
+\x8e\xd8\x6c\xef\xb4\xbe\x86\x18\x80\xd7\x18\x11\x9e\x1c\x91\x00\
+\xae\xa1\xea\xbd\x35\xfc\x64\xba\x0e\x7b\x16\x89\xe9\x79\x14\xcd\
+\x5c\x4c\xfe\x94\x32\xf4\x06\x63\x48\x6b\x34\x79\x8c\x46\x8b\x47\
+\x28\x22\xac\x84\x85\x10\x3e\x53\xba\xf4\xa6\x75\x89\xac\x34\x61\
+\xaa\x75\x74\xcb\xfa\x8d\x15\x6a\x4a\xa2\x4b\xea\xf5\x53\x37\xdd\
+\x34\x5f\xbc\x5a\xd3\xa1\xf9\xc3\xd6\xa6\xdb\x56\xed\x77\x7c\x59\
+\x08\xf1\x14\xf0\xb8\x94\x72\xd7\xc7\x19\xa4\x88\xc8\xe3\x17\xe5\
+\x70\x60\x9d\x43\x42\xad\x7b\xae\x78\xa5\xbe\xcf\xe8\x9e\x31\x56\
+\x19\x53\x9c\xee\x30\x39\xae\xb7\xbd\x87\x2d\xf7\xff\x0f\xc0\x53\
+\x52\xca\x4f\x95\xc7\x2f\x29\xe5\xeb\x42\x88\xd2\x8e\x8e\xed\x4f\
+\x76\x74\x6c\x5f\x1a\x1d\x9d\x8a\x3f\x36\x2f\xb5\xa4\xaf\x81\x56\
+\x60\x6b\x24\xdb\xb7\x8e\x64\x70\x22\xa4\x94\xc3\x21\xdd\x7f\x4b\
+\xc4\xf6\x2d\x0d\xb0\x5b\x62\xe2\x29\x5b\x7e\x23\x8b\x97\xdd\x85\
+\xd1\x65\x53\x0d\xaa\x25\xa0\x53\x8c\x23\x3a\x6a\x52\x4a\xb6\x6f\
+\x7f\xb2\x73\x62\xe5\x0f\xf7\xda\x1b\x5b\xbc\x73\x4d\x86\x28\x6b\
+\x61\x4a\x42\xbb\x2d\x21\xab\x30\x25\xce\xb4\xaf\xdf\xcb\x1f\xb6\
+\x34\xf2\x6c\x45\x2b\xde\x60\x78\x33\xf0\x04\xf0\x4f\x29\xa5\x6b\
+\xac\x8e\x8c\xd9\x39\x21\x6e\x01\xf1\xe4\x1d\x54\x60\x26\x1e\xa7\
+\xde\xbf\xf1\xd9\xe2\x77\x27\x05\x34\x21\xcb\x11\x8a\xf4\xcd\xbb\
+\x2f\xb3\xa2\xe4\xd2\xd4\xf9\x07\x27\xae\xbc\xe6\x3b\xec\xfb\xe7\
+\x5b\x3e\xa0\xe0\x24\xed\x4a\x67\x24\x84\x10\x77\x02\xbf\x02\x2c\
+\xb1\x44\x58\x95\x12\x7e\x2a\xa5\xfc\xfe\x11\xcb\x8c\xc5\x07\x18\
+\x22\x88\x5b\x89\x44\xbe\xb6\x58\xad\xc9\xcc\x9d\x7b\x3d\xf3\xe7\
+\xdf\x2a\x93\x92\x26\x08\x80\xde\xde\xda\xba\x86\x86\x0f\xf2\x4d\
+\xfe\xb6\xc6\xe2\xaa\x1f\xa6\xd3\xcf\x3a\x63\x03\x96\xf3\xf2\x30\
+\x27\x44\x6b\xf2\x76\x45\x25\x48\x4f\x5c\x92\x21\x2d\x29\x96\x57\
+\x6a\x3a\x79\x6c\x6b\x13\xf5\xfd\x5e\x0f\xf0\x12\xf0\x14\xb0\xfa\
+\x58\x96\xb2\x42\x88\x7f\x27\x33\xfd\x8b\x57\xf3\x26\xb5\x31\xfb\
+\xd7\xae\xc8\xdb\x76\x16\x07\x34\x35\x82\xe8\xf4\x0d\xc4\x98\xfb\
+\x49\xb2\x6b\x49\x8e\x4d\x21\xfd\x9e\xbd\x67\x9f\x77\x9e\xa5\xa0\
+\x70\xf9\x48\x1c\x80\xd6\xf7\x36\xb0\x62\xe9\x57\x00\x7e\x29\xa5\
+\xfc\xce\x47\x9e\xcd\x71\x0a\x21\x44\x26\x30\x8f\x88\xc1\x6a\x40\
+\x4a\xf9\xfc\x51\xf3\x1f\x6d\x0d\x86\x2c\x73\xee\x06\xbe\x41\xc4\
+\xf2\x95\xdc\xdc\x32\xce\x3a\xeb\x16\xa4\x1c\x08\x99\xcd\xd1\x5a\
+\x80\xe2\xbd\x3f\x5e\x6b\xf2\x36\xcf\x17\x41\x36\xcb\xdd\x14\xc6\
+\xea\xa8\x38\x2f\x8f\x0c\x9b\x91\x0c\x55\x0a\xb9\xdb\x68\xa7\x2f\
+\x2e\x49\xec\x0f\x29\xfc\x63\x6f\x27\xef\xd4\x77\x23\x23\x42\x93\
+\x3f\x13\x89\xa4\xe5\x18\xba\xee\x15\x03\xa5\x43\x4f\x09\x30\x7f\
+\x86\xf8\x86\xc1\x9d\x36\xd5\xb3\x37\x4d\xba\xb0\x9a\x1d\xc4\xc7\
+\x2a\x24\xc7\x24\x63\x8b\x8e\x45\x33\xc4\xde\x17\x7e\x1f\xf6\x8b\
+\x76\xa2\x69\x9e\x7b\xe1\xb2\x3f\x57\xa4\xa5\x97\x4d\x02\x08\x0f\
+\x06\x78\x71\xca\xe5\x0c\xd4\x34\xb5\x02\x13\x8f\x67\xf7\xf9\xac\
+\xe0\x23\x71\x02\x87\xb4\x4f\x6e\x24\xc2\xab\x9e\x29\x84\xc2\x92\
+\x25\x77\x90\x93\x33\x14\x90\xc3\xb5\xdf\x3d\xa3\xee\x27\x06\x40\
+\x27\x24\x0d\xb2\x02\xf0\x91\x96\x1a\x4d\xf9\xd2\x5c\xa6\x18\xb4\
+\x8c\x04\x68\xae\xd6\xdb\xa8\xd0\x58\x78\xa3\xcb\xcf\x8b\x4d\x03\
+\x38\x07\x43\x3e\x22\xec\xe4\x74\x84\x10\xc4\x65\x43\x6a\x31\xa4\
+\x14\x43\x6a\x31\x7a\xf3\x04\x02\x76\x0b\xe8\x8f\xc0\xd8\xd3\x74\
+\x77\x61\xbf\xa0\x07\xe1\x29\x06\xb8\xfc\xf2\x17\x1a\xec\xf6\xc2\
+\x1c\x80\x6d\x0f\xff\x85\x2d\x3f\xf8\x23\x44\x74\xe9\x5e\xfe\x98\
+\x73\xf4\xa9\xc6\x71\xcb\x02\x84\x10\x25\x44\x58\x94\x97\xc7\xc7\
+\x67\xe6\xce\x9c\x79\x31\x99\x99\x25\x98\xb6\x3c\xec\x2d\xd6\xb4\
+\x0e\xab\x29\x39\xa9\x67\x2f\x0e\x66\x01\x03\x45\xf1\xec\x38\x2b\
+\x83\x39\x1a\xe5\x50\x7b\xad\x46\x0c\x3c\xd7\x0b\x0f\x4c\xfc\x16\
+\x32\xb5\x04\x92\x8b\x0e\xbf\x89\x1c\x0d\xba\x3d\xb5\xc4\x5e\x1d\
+\x05\xe1\x11\xbf\xf1\xd7\x5d\xb7\xb2\xc7\x6c\x8e\x8f\x77\xee\x6b\
+\xe6\x85\x92\xcb\x08\xfb\x07\xff\x23\xa5\xfc\xe2\x71\x0d\xf2\x33\
+\x84\x13\x92\x06\x0a\x21\xbe\x00\x3c\x9d\x94\x94\x9b\x32\x63\xe2\
+\x2c\x96\xb9\x9e\x47\x77\x40\xe8\xaa\xd2\xc3\x1a\x1a\x58\x04\xa0\
+\x08\xda\x66\xa4\xd0\x50\x9a\x42\x99\x38\x48\x01\x43\x05\x8a\xe6\
+\xfc\xc4\x55\x9b\x72\x69\xf4\x61\x0d\x1c\x0d\x86\xff\x6c\xc1\xfa\
+\xff\x8a\x10\xf2\x90\x03\xe1\x6d\xb7\x6e\x0e\x2a\x1a\xbd\xee\x8d\
+\xf3\xef\xa2\xe5\xed\xf5\x1e\x22\xfc\xfe\xa6\x8f\x3d\xc8\x4f\x39\
+\x4e\x48\x27\x70\x48\x02\x38\xa5\xb3\xb3\xfe\x5f\x6f\x7c\xf8\x3c\
+\x4f\xd4\x0b\x5c\x07\xf8\x70\x0a\xf1\x2c\x62\x32\xeb\x51\xf0\xa9\
+\x92\xd4\xcd\x6d\x9c\xf5\xf4\x76\xea\xf6\xf5\x0d\xdf\x4e\x22\x1d\
+\x78\x23\xed\xb7\x1f\x28\x42\xfd\xe8\xde\x38\x2d\xbf\x58\x8d\xed\
+\xde\x69\xa3\x17\x1f\xf0\x2a\x1a\xbd\xae\xfe\x85\x77\x68\x79\x7b\
+\x3d\x44\xc2\xb9\x7d\xbe\xf8\x47\xc1\x09\x2b\x85\x4a\x29\x7b\xa4\
+\x94\x57\x01\x97\xef\xec\x93\x5d\x0f\xef\x81\x8a\x83\xe5\x75\x46\
+\xe6\x31\x8d\x26\x61\xa0\x0d\x20\xa4\x52\xb8\xb2\x9e\x19\x7f\xdf\
+\xc5\xb6\x0e\x37\xd5\x00\xb1\x55\x0e\xcb\xf7\x73\xfe\xbc\xf3\xd8\
+\xad\xa9\x2a\x31\x37\xae\xc2\xfc\xd7\x85\xc0\x61\xfe\x00\x45\x50\
+\x09\x05\x5d\x1e\xca\xbf\xf5\x2b\x88\x28\x6a\x3c\x72\xa2\xe3\xfb\
+\xb4\xe3\xa4\x2a\x84\x0c\x31\x6e\x7e\x2f\xe0\xda\x73\x92\xe0\xd2\
+\x0c\xd0\x0e\x5f\xda\x24\x3d\xd4\xd2\xc6\xc0\x21\x76\xfa\x32\xd6\
+\xc8\xfa\xb3\x66\xe3\x35\x3f\x88\x2e\x6f\xeb\x06\x7b\x6f\xc0\x76\
+\xb8\x1d\x3f\x80\xf0\x78\xb0\x2f\xaf\x44\xd3\x31\x6b\xcc\xf7\x80\
+\xa6\x4f\x2f\x8b\x3b\xae\x14\xbb\x1f\x7d\xd6\x0f\xcc\x3a\xc9\x2c\
+\xe9\x4f\x25\x4e\x89\x46\x50\x84\x81\xc3\x9f\xd2\x4c\xe8\x6f\xcb\
+\x83\xd4\x11\x23\x6c\x82\xb2\x9d\x0d\xa2\xe5\x50\xd7\x6c\x42\xa1\
+\x65\xee\x53\x28\xfe\xf4\x49\xee\xd9\xe5\x2f\xe6\x32\xda\x6a\x59\
+\xd3\xda\x8e\x7d\xb9\x0b\xe1\x3f\xb2\x45\xb0\x9f\xb0\x76\xb7\x41\
+\x13\x7e\x3a\x88\x0c\xab\x5f\x95\x52\xfe\xe9\xa4\x0e\xea\x53\x8a\
+\x53\x62\x17\x20\xa5\x7c\x0a\x58\xd8\xea\xa3\xf5\xa7\x95\xf0\x7e\
+\x27\xc3\xfa\xb4\x3a\x91\xc2\x02\x26\xb2\x1a\x71\x20\xa8\x83\x54\
+\x49\xf7\x74\xa3\x9d\x6a\xae\xe8\x5b\x1a\xbf\xfe\x50\xb3\x68\xfd\
+\xd6\x2a\xec\xe7\x29\x63\x2c\x7e\x37\x83\x94\xd3\xc1\x6a\xb6\xd2\
+\xcf\x76\x34\xa1\x7f\x0d\x22\xc3\xea\x2b\x9f\x2f\xfe\x47\xc7\x29\
+\xd5\x09\x14\x42\x24\x11\x91\x46\x9e\x3d\xc9\x06\x37\xe5\x80\x75\
+\xf8\x3a\x1f\x62\xa7\xdc\x4d\xba\x08\x11\x07\x90\x75\x23\x9b\x0a\
+\x96\xa1\x4d\xc8\x32\x4f\x88\x7d\x77\x93\x2b\x8c\x26\x19\xd3\x0b\
+\x1b\x89\x7e\x70\x32\x60\x46\xd2\xce\x20\x0d\xf4\x13\xc6\x41\x3a\
+\x7e\x0e\xb5\xe8\x7d\x0b\xd8\x4f\x33\x30\x75\xb4\x72\xe6\xe7\x38\
+\x32\x4e\xb9\x52\xa8\x88\x68\x2c\x3c\x08\x3c\x10\xad\x45\xb9\x21\
+\x07\xa6\x0c\x89\xf6\xa5\xa4\x5d\x54\x33\x80\x9b\xa2\x98\x52\x56\
+\xe5\xdd\xc5\xc2\x69\xd3\x68\xfa\x63\xf3\x75\x9d\x5f\xaf\x8f\x75\
+\xa1\xfd\x83\x81\x7e\x29\x70\x90\x45\x80\x8c\x23\x36\xb2\x15\xd8\
+\x4a\x18\x38\x5b\x4a\xb9\xe6\x94\x0e\xe8\x53\x86\xd3\xa6\x15\x2c\
+\x84\x38\x87\x48\xe4\x8d\xe4\x45\x89\x70\x45\x06\xe8\x14\x90\xe0\
+\x13\xfb\xd9\xa1\xf3\x61\x98\xf2\x2b\xa6\x67\x66\xb2\x2a\x21\x81\
+\x45\x5f\x7e\x9f\x6d\xff\xa8\x3b\x60\xf9\x7b\x44\xec\x27\xf2\xeb\
+\x8f\x88\x9f\x4f\xab\x87\xad\x4f\x03\x4e\x9b\x6d\xe0\x90\x27\xcc\
+\x52\xe0\x9f\xab\xba\x90\x3f\xae\x80\x66\x2f\x08\x30\x91\x49\x59\
+\x28\x9d\x7e\x24\xb2\xb5\x95\x52\xc0\xfb\xf4\x22\xe2\x0c\x9a\x63\
+\x78\x01\x71\x02\x1f\x00\xb0\x02\xf8\xd1\x29\x1e\xc2\xa7\x12\xa7\
+\xd5\x38\x54\x4a\xd9\x29\xa5\xbc\x06\xc8\xe9\xf0\xf3\xc1\xcf\x2b\
+\xe1\x9d\x8e\xc8\x01\x51\x46\x71\x8e\xd7\x41\x73\x38\x8c\xcd\xe3\
+\x61\xab\x5e\x43\xd6\x33\x67\x1f\x25\x2e\x6f\x08\x78\x17\x18\xa4\
+\x06\xf8\xf2\x90\xa6\xcc\xe7\x38\x4e\x7c\x22\xd6\xc1\x43\xdc\xb9\
+\xa5\x21\xc9\xb7\x5f\x6a\xc6\xf3\xe8\x5e\xe8\x0f\x40\xe5\xf6\x88\
+\x29\x57\x4b\x33\x89\x00\x57\xe6\x32\x67\x82\x6d\xd8\x23\xe0\x28\
+\xac\x05\x1c\x74\x01\x17\x48\x29\x8f\xa8\x2a\xfe\x39\x8e\x8e\x4f\
+\xcc\x3c\x5c\x4a\x19\x96\x52\x3e\x02\x94\x56\x3b\x69\x7f\xb8\x02\
+\xaa\xbb\x22\x7e\x74\xdc\x1e\x26\x84\x54\x76\x03\xc6\xf7\x2f\xa2\
+\xf3\xb0\xc2\x95\x40\x0d\x7e\x60\xb9\x94\x72\x6c\xbf\xff\x9f\xe3\
+\x23\xe1\x13\xf7\x0f\x20\xa5\xac\x03\xce\xf7\x84\x18\x78\xa7\x29\
+\x62\x7f\x09\xd0\xd1\x86\x0b\x20\xd5\xc4\xcc\xbb\x8a\x39\xa0\xea\
+\xdd\x05\xac\x07\x22\x86\x28\x9f\xea\x78\x3e\xa7\x03\x9f\x38\x01\
+\x00\x0c\xe9\x0c\x5e\x5a\xd9\x44\xf0\xae\x47\xa0\xcd\x01\xed\xed\
+\xcc\x0d\x85\x22\xd1\xbd\x7e\x3f\x8f\xcc\x28\x85\x30\x3e\x22\xdf\
+\x7d\x95\xa7\xa5\x94\x87\xe9\xb8\x7f\x8e\xe3\xc7\x19\x41\x00\x30\
+\x72\x4b\xb8\x74\xcb\x5e\x06\xae\xfe\x21\xbc\xb2\x16\xe5\xad\x75\
+\x91\x20\x0e\x5a\x85\x94\x1f\xa4\xa1\x61\x25\xe0\xa1\x9e\x88\x11\
+\xcb\xe7\x38\x09\x38\xe3\xac\x83\x85\x10\x05\xc0\x9f\x80\x73\x67\
+\x4f\x82\xd9\x93\x60\x57\x1d\xec\xaa\x85\x7e\x17\x0d\x44\x0e\x7d\
+\x7b\x3f\xe1\x6e\x7e\x6a\x70\xc6\x11\xc0\x30\x84\x10\x4b\x89\xd8\
+\xc3\x25\x0d\x25\x35\x01\xd7\x0f\x5b\xb5\x7e\x8e\x93\x83\xff\x0f\
+\x92\x04\x28\x92\xfd\x58\xc9\xac\x00\x00\x00\x00\x49\x45\x4e\x44\
+\xae\x42\x60\x82\
+\x00\x00\x05\x24\
+\x89\
+\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\
+\x00\x00\x18\x00\x00\x00\x18\x08\x06\x00\x00\x00\xe0\x77\x3d\xf8\
+\x00\x00\x00\x04\x73\x42\x49\x54\x08\x08\x08\x08\x7c\x08\x64\x88\
+\x00\x00\x00\x09\x70\x48\x59\x73\x00\x00\x06\xec\x00\x00\x06\xec\
+\x01\x1e\x75\x38\x35\x00\x00\x00\x19\x74\x45\x58\x74\x53\x6f\x66\
+\x74\x77\x61\x72\x65\x00\x77\x77\x77\x2e\x69\x6e\x6b\x73\x63\x61\
+\x70\x65\x2e\x6f\x72\x67\x9b\xee\x3c\x1a\x00\x00\x00\x13\x74\x45\
+\x58\x74\x41\x75\x74\x68\x6f\x72\x00\x52\x6f\x64\x6e\x65\x79\x20\
+\x44\x61\x77\x65\x73\x0e\xd8\x7e\x1d\x00\x00\x04\x82\x49\x44\x41\
+\x54\x48\x89\x8d\x96\x6d\x88\x54\x55\x18\xc7\x7f\xe7\xbe\xcd\xdc\
+\xb9\xb3\xb3\x33\xfb\x1a\xeb\xbe\xe9\x6e\xad\x2f\x69\xea\x4a\x68\
+\x8a\x22\x24\x94\x58\x7e\xaa\x48\x23\x30\x41\xa1\xec\x5b\x51\x41\
+\xf8\x31\xa3\x3e\x44\x94\x44\x92\x09\xa5\x46\x54\x84\x21\x59\x44\
+\x18\x96\x19\xe8\xb2\x5a\x9b\xba\xea\xa4\xfb\x66\xee\xce\xce\xec\
+\xec\xcc\xdc\xbd\xf3\x76\xef\xe9\x43\xb3\x63\xe6\xbe\x3d\xf0\x7c\
+\xfb\x9f\xff\xef\x3c\xcf\x3d\xf7\x39\x47\x48\x29\x99\x2d\x36\xbf\
+\x27\x7c\x6a\x40\x7f\xc0\x6f\xaa\x8f\x02\x64\x1d\xf7\x84\x3b\x51\
+\xb8\xf0\xed\x8b\x32\x37\xdb\x5a\x31\x1d\x60\xf7\x01\xa1\x67\x2a\
+\x42\x6f\x37\xd7\x76\xac\xaf\x0a\xd6\xd5\x07\x2c\x7f\xb5\x69\x05\
+\x7c\x52\xba\x38\xb6\x93\xb3\xed\x6c\x62\xdc\x8e\xdd\xea\x8b\x5d\
+\x39\x15\x4c\xa7\x5e\xfe\x70\x97\x2c\xcc\x19\xb0\xf5\xa0\xb1\xa2\
+\xb9\x7e\xfe\xa1\x65\x1d\xab\x96\x7a\x5a\x56\xc9\xbb\x0e\x92\x3b\
+\x75\x02\x81\xa1\x9a\x50\x50\xbd\x9e\x2b\xe7\xff\xe8\x1f\xbe\xb1\
+\xe3\xd8\xce\x7c\xf7\xac\x80\x6d\x9f\x87\xf7\x2d\x69\x59\xf9\x5c\
+\x5d\x7d\x4d\x9d\xe3\xa6\x67\xeb\x00\x00\x01\x35\x44\x7c\x78\x74\
+\xe4\x42\x5f\xd7\xc7\x47\x9f\x4a\xbe\x36\x2d\xe0\x89\xc3\xe6\xb3\
+\x6b\x96\xae\xdf\x6f\x04\xf5\xa0\x27\xdd\x39\x99\x87\xb4\x7a\xaa\
+\xb4\x05\xa4\x8a\x43\x8c\x24\xaf\x65\xce\xf4\xfc\xfa\xc2\x17\xcf\
+\x38\x9f\xdc\x05\x78\xfc\x80\xa8\x59\xdc\xb6\xfc\xb7\xc6\xd6\x79\
+\x6d\x73\x33\x17\x6c\xa8\xdc\x43\x83\xb1\x0c\x9f\x08\x72\xc9\xf9\
+\x9e\x73\x99\x4f\x19\xb8\xd1\x17\xbd\x18\xed\x59\xfd\xcd\x2e\x39\
+\x0a\xa0\x4c\xca\xab\xab\xea\x8e\x34\x35\x37\xcc\xd1\x1c\x56\x5a\
+\x4f\xd2\xea\x5b\x8d\x5f\x54\x10\x2b\x46\x39\x97\x39\x8c\x2b\x5d\
+\x5a\x9a\x5b\xdb\xaa\xab\x6a\x8f\x4c\xea\x14\xf8\xf7\xa3\xde\xd7\
+\xd2\xb1\xd6\x15\x73\x33\xaf\xd1\xdb\xb9\xd7\xdc\x88\x82\x4a\xda\
+\x1d\xe1\x97\xd4\x07\xb8\xb2\x08\x40\x51\xb8\xb4\x37\xcf\x5f\xbb\
+\xf5\xa0\xb1\xa2\x0c\xf0\xfb\xd4\x2d\x56\xd0\xb4\x40\x20\xc4\x54\
+\xa9\x94\x53\x53\xfc\xac\x0b\xed\x22\xa0\x86\xc9\x4b\x87\x4b\xce\
+\x09\xd2\xee\x2d\x14\xa1\xa2\x08\x0d\x21\x54\x02\x96\xdf\xf2\xfb\
+\xd4\x2d\x00\x1a\x80\x65\x55\x76\xaa\x86\x86\xbc\xdd\xb1\x52\x97\
+\xc5\x5d\xbb\x5f\x57\xb1\x9b\x88\xd6\x02\xc0\xad\x42\x0f\x97\xb3\
+\x3f\xa2\x08\xed\x0e\x8d\x66\xe8\x58\x66\xa8\xb3\x0c\x08\x9b\x91\
+\x26\x4f\x4a\x54\x45\x9d\xd1\x7c\x81\xff\x21\x1a\x7d\xcb\xf1\x3c\
+\x8f\x91\x42\x2f\xc7\xe3\x7b\x31\xd5\x10\x86\x12\xb8\x43\x57\x14\
+\x0a\xe1\x40\xa8\xa9\x0c\x30\x0d\xb3\x5a\x11\x3a\x0a\x0a\x4c\x61\
+\x0c\x10\x50\x23\x2c\x31\x1f\x43\xb8\x3a\x49\xf7\x26\xc7\xe3\x7b\
+\x89\xe5\xa3\x00\x58\x6a\x15\x11\xa3\x11\x53\x0d\x23\x10\xe4\xbd\
+\x1c\xa6\xe1\xab\x2e\x03\x9c\xfc\x44\x5c\x57\x8c\x96\xff\x9a\x77\
+\x98\x9b\x08\x28\x61\xba\xed\x2f\xf1\x64\x91\x4e\x73\x1b\x16\xb5\
+\x14\x65\x8e\x0b\xf6\x57\x0c\xe5\xcf\x97\xb5\xb6\x9b\xc0\x76\x12\
+\xf8\x14\x8b\xb0\xde\x48\xde\x4b\xe1\xe4\xb3\xf1\x32\x20\x69\x8f\
+\x0d\x08\xe9\xad\x14\xc2\x00\x60\x51\xe0\x11\x96\x04\x36\x23\x50\
+\x18\x2b\x0c\x52\x70\x73\xdc\xa3\x2d\x06\xa0\x3f\x77\x96\x53\xa9\
+\xfd\x53\x56\x99\xf3\x6c\x86\x73\xbd\x54\xa9\xb5\x24\x27\xc6\x07\
+\xca\x00\xdb\x49\x75\x25\x9d\xa1\xad\xe8\x3a\xa6\x5a\x89\x81\x85\
+\xf0\x34\x3c\xcf\xa3\xcd\xd8\x80\x2e\x4c\x14\xa1\x13\x2b\x5c\xe5\
+\x58\xe2\x55\x60\xe6\x09\xac\xba\x3a\xb6\x93\xe9\x2a\x03\xb2\x39\
+\xf7\x78\x2e\x53\x7c\x65\x3c\xd0\x67\x49\x3c\xbc\xa2\xa4\xb5\x76\
+\x0d\x9a\xf0\x53\x55\x3a\x31\xb6\x1b\xe7\x64\xea\x1d\x26\xbc\xc4\
+\x8c\xe6\x02\x85\x82\x2d\xed\x6c\xce\x3d\x0e\xa5\xff\xe0\xd8\xce\
+\x7c\x77\x74\xa0\xff\x74\x8d\xd2\x00\xc0\x60\xbe\x9b\x44\xb1\xbf\
+\xbc\xc8\x93\x2e\x97\x9c\xef\xb8\xea\xfc\x34\xa3\x39\x40\x8d\x52\
+\x47\x74\x70\xf0\xf4\xe4\x64\x2d\x1f\xfc\x78\x22\xb6\x3d\x31\xe4\
+\x44\x4d\xc5\xc2\xc3\xe5\x77\xfb\x6b\x3c\x59\xa4\x20\x1d\x06\xf2\
+\xe7\xf8\x21\xf9\xe6\xac\xe6\xa6\x12\x20\x31\xe4\x44\xe3\x89\xf8\
+\xf6\x72\x45\xff\x9f\xa6\x9d\x0b\xef\x7f\x3f\xe9\xff\xbb\xc2\x95\
+\x1e\x0b\xcd\x87\xc9\xca\x34\xd7\xb3\x67\x98\xad\xef\x9a\xd0\x08\
+\x3b\xb5\xe9\xae\xde\xde\x3d\x53\x4e\xd3\xc9\xd8\xf6\x59\x78\x5f\
+\xc7\xbc\x96\x1d\x6a\x4d\xa1\x3e\xe9\xc6\x67\xdd\x35\x40\x58\x8d\
+\xe0\x8e\x6a\xc3\xbd\x43\xfd\x87\x8e\x3e\x9d\x9a\xfe\x3e\x98\x8c\
+\x4d\x6f\x19\xab\xdb\xda\xeb\x3f\x5a\xd0\x5e\xb7\x28\xad\x24\x94\
+\xac\x9c\xfa\x46\xf3\x0b\x3f\x15\x6e\xd8\xfb\xeb\xda\xc8\xe5\x8b\
+\x5d\xb1\xe7\x7f\xde\x57\x3c\x2d\x65\x69\xea\x4d\x05\x10\x42\xa8\
+\x40\x35\x10\x36\x2b\xa8\xdd\xf8\x7a\xf0\xa5\x96\xe6\xea\x65\xa1\
+\x50\x20\x12\x08\x2b\x41\xa3\x42\xd1\x01\x72\x29\xaf\x38\x91\x2c\
+\x66\xd2\xa9\x6c\xf2\xfa\xf5\xd8\x9f\x27\xdf\x98\x78\x37\x67\x33\
+\x0a\x8c\x01\x49\x29\x65\x72\xda\x0a\x84\x10\x16\x10\x01\x2a\x4b\
+\x59\x61\x04\x89\x34\x3c\xa8\x2c\x6c\x5a\xa5\xad\x90\x12\x65\xe0\
+\x6c\xf1\xfc\xcd\xb3\x5e\xb4\x60\x33\x06\x64\x80\x14\x30\x5e\xca\
+\x84\x94\xb7\x1f\x00\xd3\xbe\x2a\x4a\x30\x1d\x30\x00\x5f\x29\x75\
+\x40\x05\x8a\x40\x0e\xc8\x03\x59\xa0\x20\xe5\xd4\x37\xd5\x3f\x13\
+\x05\x02\x8c\xec\xcf\x7e\xae\x00\x00\x00\x00\x49\x45\x4e\x44\xae\
+\x42\x60\x82\
+\x00\x00\x05\x64\
+\x89\
+\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\
+\x00\x00\x18\x00\x00\x00\x18\x08\x06\x00\x00\x00\xe0\x77\x3d\xf8\
+\x00\x00\x00\x04\x73\x42\x49\x54\x08\x08\x08\x08\x7c\x08\x64\x88\
+\x00\x00\x00\x09\x70\x48\x59\x73\x00\x00\x0d\xd7\x00\x00\x0d\xd7\
+\x01\x42\x28\x9b\x78\x00\x00\x00\x19\x74\x45\x58\x74\x53\x6f\x66\
+\x74\x77\x61\x72\x65\x00\x77\x77\x77\x2e\x69\x6e\x6b\x73\x63\x61\
+\x70\x65\x2e\x6f\x72\x67\x9b\xee\x3c\x1a\x00\x00\x04\xe1\x49\x44\
+\x41\x54\x48\x89\xb5\x95\x4b\x6c\x54\x55\x18\xc7\x7f\xe7\x9c\x3b\
+\xd3\x99\x32\xb5\xf4\x41\x1f\x38\x02\xa5\x8d\x4c\x54\xb0\x25\x48\
+\xa2\x12\x62\x14\x17\x46\x12\x64\x81\x26\xb0\x31\x18\x12\x36\x04\
+\xd2\x45\xe9\xca\x84\x5d\x29\x1a\x43\x80\x6e\x0c\x2b\xdc\x88\x09\
+\xa0\x21\x04\x35\x18\xa2\x26\x3e\xa3\x62\xa8\x4f\x7c\x61\x95\x81\
+\xb6\x4c\xcb\xbc\x67\xee\x39\x9f\x8b\xdb\xde\x4e\x2b\x10\x37\xde\
+\xe4\x4b\xee\xf3\xf7\xff\xfe\xdf\xf9\xee\x77\x94\x88\xf0\x7f\x1e\
+\xde\xdd\x1e\x9e\x53\xaa\x23\x6a\xcc\x33\xf1\x58\xec\xc9\xc6\xae\
+\x15\x2b\xeb\x97\x76\x2e\xb5\xc5\x62\x29\xfb\xd7\x5f\x63\xb9\x6b\
+\xd7\x7f\x2c\x55\x2a\xe7\x2b\xd6\x7e\xb0\x59\xa4\x70\x27\x86\xba\
+\xad\x03\xa5\xd4\x87\xb1\xd8\xcb\xc9\xc7\x1e\xdd\xdd\xd9\xde\xd6\
+\x11\xab\x8b\xa3\x4a\x45\xc8\xe5\xc1\x18\x48\x2c\xc2\x79\x86\x5c\
+\x21\xef\x7e\x1f\xbd\x7c\x65\xe2\xa7\x5f\xf7\x3c\x59\xad\xbe\xf7\
+\x9f\x04\xce\x2b\xb5\xa2\x75\x69\xe7\x1b\xa9\x0d\x8f\xaf\x4f\x54\
+\x6d\x84\x52\x29\x7c\x26\xc0\xec\xfb\x32\x13\x34\x24\xb8\x91\x9f\
+\x9e\xfa\xe5\xe3\x4f\xdf\xb6\xb7\xb2\xbb\x9f\x10\x29\xd5\xf2\xe6\
+\x09\xbc\xa7\x54\x6a\xc5\xda\xde\xf7\x7b\xba\xbb\x93\x3a\x57\x98\
+\x83\x2c\x80\x2e\x14\x11\xcf\xc3\xc6\x23\xf2\xdd\xa7\x9f\x7d\x91\
+\x1f\xbb\xf6\xf8\x13\x22\xfe\x2c\x53\xcf\x9e\x1c\x50\x4a\xb7\x2e\
+\x5f\x76\xa2\x67\xf9\xb2\x24\xd9\x3c\x4e\x24\x0c\x0b\xd8\x9a\x6b\
+\x07\xc1\x3d\xc0\x01\xce\xf7\x21\x5b\x54\x3d\x0f\xaf\x5e\xa7\xeb\
+\xeb\x5f\xab\x75\x10\x0a\x6c\x8a\xc7\x87\xee\x7f\xe8\x81\x3e\x29\
+\x56\x02\x80\x08\x95\xed\xdb\xa9\x6e\xd9\x32\x27\x52\x03\xb5\x9e\
+\x87\x1d\x18\xc0\xef\xed\x0d\xc5\x8d\x55\x3a\xd9\xf7\xe0\x0b\xef\
+\x46\xa3\xeb\x67\xb9\x1e\xc0\x05\xa5\xee\x5f\xf9\xd8\xfa\x17\xeb\
+\x7c\x67\x1c\x0a\x01\xec\x8e\x1d\xf8\xcf\x3d\x17\x64\xa1\x14\xfa\
+\xf4\xe9\xb0\x3c\x2e\x12\x81\xfd\xfb\xa1\xb7\x17\xfa\xfa\x70\x43\
+\x43\xf0\xf5\xd7\x88\x15\x9a\x1b\x9b\x96\x34\x25\xdb\x5f\x47\xa9\
+\x5e\x44\x44\x03\x78\x9e\xf7\x52\xc7\xe2\xa6\x25\x0e\x15\x64\xe9\
+\x79\xd8\x9e\x9e\xd0\xa6\xdb\xbe\x1d\x7f\xeb\xd6\x20\xd3\x68\x14\
+\x06\x07\x51\x7d\x7d\x28\xa5\x50\x91\x08\x74\x77\x63\x95\xc2\x29\
+\x85\xad\x38\x5a\xdb\x5a\x97\x9d\x85\x9e\xd0\xc1\xa2\xd6\x96\x07\
+\x11\x70\x5a\x07\x59\x3a\x07\x43\x43\xa8\xc1\x41\x58\xbd\x3a\x50\
+\xd9\xb1\x03\x17\x89\xa0\x52\x29\xd4\x9a\x35\xa1\xb8\x3d\x79\x12\
+\xff\xcc\x19\xc4\x18\x10\x41\xb4\xa6\x2e\x16\x5b\x1c\xd1\xfa\x29\
+\xe0\x67\x0d\x10\x6b\x6d\xba\xcf\x39\xc1\x69\x3d\x17\xd6\xe2\x86\
+\x87\x91\xcb\x97\xe7\x16\xec\xf9\xe7\x43\xb8\x88\xe0\xbf\xf9\x26\
+\xd5\x53\xa7\x82\xc4\xb4\xc6\x79\x1e\x62\x0c\x26\x12\xa5\x2e\x11\
+\xdf\x00\xa0\xdf\x52\x2a\x51\xb7\x28\xde\xee\xbc\xc8\x7c\x01\xad\
+\x71\xce\xe1\xbf\xf2\x0a\x32\x3a\x3a\x57\x2e\xe7\xf0\x7d\x9f\xe2\
+\xa1\x43\x64\x07\x06\x28\x5f\xb9\x82\x2d\x95\x70\xc6\x04\xdf\x18\
+\x83\xad\x5a\xe2\x4b\x9a\x97\x87\x25\xc2\x98\x60\xe1\x6a\x7b\x7b\
+\x06\x28\xc6\x20\x4a\x81\x13\x9c\xb8\xf0\x1f\x70\x80\xb5\x16\x3f\
+\x9d\x46\xd2\x69\x74\x4b\x0b\x5e\x32\x89\x6e\x6e\xc6\x96\xcb\x88\
+\xb8\xa0\x8b\xb6\x89\xe4\xbe\x79\x68\xd5\x75\xb9\xd7\xb4\x8b\x95\
+\x00\xae\x54\x00\x8f\x44\x88\xec\xdb\x07\xa9\x14\xd6\xd9\xda\xf6\
+\x26\xd6\xdf\x8f\x03\x0a\x87\x0f\x07\x5d\x37\x31\x41\x65\x7c\x1c\
+\xe2\x71\xbc\xce\x26\xf2\x13\x53\x7f\x84\xff\x41\x71\x32\xf3\xa7\
+\x3f\x39\x81\xd3\x26\xb4\xea\x97\xca\xe8\x5d\xbb\x90\x55\xab\x70\
+\x2e\xc8\xa6\xf0\xea\xab\x54\x3f\xfa\x28\x14\xa9\xef\xef\x27\xb6\
+\x77\x2f\x56\x24\x08\xc0\x2f\x14\xa8\x7a\x9a\x72\xae\xf8\x71\x58\
+\xa2\xc2\x44\x66\xb4\x9c\xcd\x3c\xeb\x46\x7f\x40\x25\x12\x38\xdf\
+\x27\x3e\x32\x82\x5e\xbb\x36\x84\xe5\x87\x87\x29\x1c\x3b\x86\x44\
+\xa3\x34\x1e\x3f\x4e\xdd\xc6\x8d\x00\x24\xfa\xfb\x71\xbe\x4f\xf6\
+\xc8\x91\xb0\xb4\x3e\x4c\x39\xe7\x2e\x84\x0e\xf0\xfd\xe3\x19\xe7\
+\xc6\x2d\x8e\xca\xf8\x38\xd5\x4c\x06\xff\xd2\xa5\x10\x9e\x3b\x78\
+\x90\xdc\xd1\xa3\x41\x96\xe5\x32\x93\x3b\x77\x52\xba\x78\x31\x28\
+\x63\xb9\x4c\xe9\xdb\x6f\xf1\x67\x5c\xe8\xf6\x26\xa6\x27\x32\x57\
+\x37\xc3\x15\xa8\x19\x76\x17\xea\xeb\x87\xbb\xfa\x52\xfd\xf6\xf2\
+\x2f\x46\x24\x58\x8b\xc4\xe0\x20\xf6\xd6\x2d\x72\x23\x23\x41\x8f\
+\xd7\x36\x41\x34\x4a\xeb\xc8\x08\xd9\x13\x27\x28\x5c\xbc\x18\x0c\
+\xc4\xa8\x87\xac\xec\x18\xff\xfb\xab\x9f\x36\x6f\x16\xf9\x7c\x9e\
+\xc0\x01\xa5\xf4\xa6\x64\xfb\x67\x1d\x0d\x0d\xeb\x2a\x63\xd7\x91\
+\x05\xc0\xd9\xde\x9f\x27\x52\x73\x0d\x50\x97\xba\xcf\xa5\xbf\xff\
+\x7d\xe4\xe9\x42\x69\xcf\xac\xfb\x79\xe3\xfa\xac\x52\xa9\xb6\x54\
+\xd7\xfb\x8d\xc6\x24\x2b\x7f\xa4\xff\x05\x58\x08\x0d\x47\xb7\x67\
+\x88\x76\x77\xca\x74\xfa\xe6\x17\x92\x9e\x9c\x37\xae\x6f\xbb\xe1\
+\xc4\xda\x9a\xdf\x58\xb2\x62\xe9\x23\xfe\x6f\xd7\xa2\x36\x5f\x9c\
+\x2f\xb0\x40\xc4\x6b\x6b\x44\x9a\x13\x99\xc9\x1f\xc7\xde\x89\x15\
+\x4a\x77\xdf\x70\x00\x94\x52\x0d\x31\x68\x39\x12\x31\x03\xa9\x55\
+\xcb\xb7\x35\x18\xd3\xaa\x7c\x8b\x9d\xce\x53\x9d\xce\x83\xd1\x98\
+\xc5\x09\x74\x43\x0c\xab\x95\x4c\x66\x72\x63\x17\xae\x5e\x3f\x78\
+\x08\xde\x05\xa6\x80\x29\xb9\x93\x03\xa5\x54\x04\x68\x01\x16\x03\
+\x4d\xbd\xd0\xb5\x49\xeb\x8d\x5d\x9e\xe9\x6b\x69\x69\x68\x6b\x6c\
+\x5c\xd4\x58\xf5\xad\xbd\x99\xc9\x4d\xdf\x98\xce\x5f\x1b\xf5\xed\
+\x97\xe7\xe0\x93\x71\x48\xcf\xc2\x81\x8c\x88\x64\xef\xe8\x60\x46\
+\xa8\x1e\xb8\x67\x26\xea\x81\x38\x50\x07\x18\x40\x11\xec\x3b\x15\
+\xa0\x0c\xe4\x81\x1c\x70\x0b\xc8\xca\xec\x8c\x98\x39\xfe\x01\x76\
+\x95\xba\xf1\x06\x3a\xff\x81\x00\x00\x00\x00\x49\x45\x4e\x44\xae\
+\x42\x60\x82\
"
qt_resource_name = "\
@@ -755,6 +1757,11 @@ qt_resource_name = "\
\x05\xcd\xf4\xe7\
\x00\x63\
\x00\x6f\x00\x6e\x00\x6e\x00\x5f\x00\x65\x00\x72\x00\x72\x00\x6f\x00\x72\x00\x2e\x00\x70\x00\x6e\x00\x67\
+\x00\x13\
+\x09\xd2\x6c\x67\
+\x00\x45\
+\x00\x6d\x00\x62\x00\x6c\x00\x65\x00\x6d\x00\x2d\x00\x71\x00\x75\x00\x65\x00\x73\x00\x74\x00\x69\x00\x6f\x00\x6e\x00\x2e\x00\x70\
+\x00\x6e\x00\x67\
\x00\x12\
\x04\xe4\x91\x47\
\x00\x63\
@@ -769,15 +1776,33 @@ qt_resource_name = "\
\x00\x63\
\x00\x6f\x00\x6e\x00\x6e\x00\x5f\x00\x63\x00\x6f\x00\x6e\x00\x6e\x00\x65\x00\x63\x00\x74\x00\x69\x00\x6e\x00\x67\x00\x2e\x00\x70\
\x00\x6e\x00\x67\
+\x00\x14\
+\x00\xe9\x23\x87\
+\x00\x6c\
+\x00\x65\x00\x61\x00\x70\x00\x2d\x00\x63\x00\x6f\x00\x6c\x00\x6f\x00\x72\x00\x2d\x00\x73\x00\x6d\x00\x61\x00\x6c\x00\x6c\x00\x2e\
+\x00\x70\x00\x6e\x00\x67\
+\x00\x11\
+\x06\x1a\x44\xa7\
+\x00\x44\
+\x00\x69\x00\x61\x00\x6c\x00\x6f\x00\x67\x00\x2d\x00\x61\x00\x63\x00\x63\x00\x65\x00\x70\x00\x74\x00\x2e\x00\x70\x00\x6e\x00\x67\
+\
+\x00\x10\
+\x0f\xc3\x90\x67\
+\x00\x44\
+\x00\x69\x00\x61\x00\x6c\x00\x6f\x00\x67\x00\x2d\x00\x65\x00\x72\x00\x72\x00\x6f\x00\x72\x00\x2e\x00\x70\x00\x6e\x00\x67\
"
qt_resource_struct = "\
\x00\x00\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x01\
-\x00\x00\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x02\
-\x00\x00\x00\x34\x00\x00\x00\x00\x00\x01\x00\x00\x0d\xf7\
+\x00\x00\x00\x00\x00\x02\x00\x00\x00\x08\x00\x00\x00\x02\
+\x00\x00\x00\xd4\x00\x00\x00\x00\x00\x01\x00\x00\x32\x3e\
+\x00\x00\x00\x60\x00\x00\x00\x00\x00\x01\x00\x00\x12\xe7\
\x00\x00\x00\x12\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\
-\x00\x00\x00\x5e\x00\x00\x00\x00\x00\x01\x00\x00\x19\xd2\
-\x00\x00\x00\x7c\x00\x00\x00\x00\x00\x01\x00\x00\x20\xbd\
+\x00\x00\x01\x02\x00\x00\x00\x00\x00\x01\x00\x00\x60\xc7\
+\x00\x00\x00\x8a\x00\x00\x00\x00\x00\x01\x00\x00\x1e\xc2\
+\x00\x00\x00\x34\x00\x00\x00\x00\x00\x01\x00\x00\x0d\xf7\
+\x00\x00\x00\xa8\x00\x00\x00\x00\x00\x01\x00\x00\x25\xad\
+\x00\x00\x01\x2a\x00\x00\x00\x00\x00\x01\x00\x00\x65\xef\
"
def qInitResources():
diff --git a/src/leap/gui/progress.py b/src/leap/gui/progress.py
new file mode 100644
index 00000000..64b87b2c
--- /dev/null
+++ b/src/leap/gui/progress.py
@@ -0,0 +1,448 @@
+"""
+classes used in progress pages
+from first run wizard
+"""
+try:
+ from collections import OrderedDict
+except ImportError:
+ # We must be in 2.6
+ from leap.util.dicts import OrderedDict
+
+import logging
+
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+from leap.gui.threads import FunThread
+
+from leap.gui import mainwindow_rc
+
+ICON_CHECKMARK = ":/images/Dialog-accept.png"
+ICON_FAILED = ":/images/Dialog-error.png"
+ICON_WAITING = ":/images/Emblem-question.png"
+
+logger = logging.getLogger(__name__)
+
+
+class ImgWidget(QtGui.QWidget):
+
+ # XXX move to widgets
+
+ def __init__(self, parent=None, img=None):
+ super(ImgWidget, self).__init__(parent)
+ self.pic = QtGui.QPixmap(img)
+
+ def paintEvent(self, event):
+ painter = QtGui.QPainter(self)
+ painter.drawPixmap(0, 0, self.pic)
+
+
+class ProgressStep(object):
+ """
+ Data model for sequential steps
+ to be used in a progress page in
+ connection wizard
+ """
+ NAME = 0
+ DONE = 1
+
+ def __init__(self, stepname, done, index=None):
+ """
+ @param step: the name of the step
+ @type step: str
+ @param done: whether is completed or not
+ @type done: bool
+ """
+ self.index = int(index) if index else 0
+ self.name = unicode(stepname)
+ self.done = bool(done)
+
+ @classmethod
+ def columns(self):
+ return ('name', 'done')
+
+
+class ProgressStepContainer(object):
+ """
+ a container for ProgressSteps objects
+ access data in the internal dict
+ """
+
+ def __init__(self):
+ self.dirty = False
+ self.steps = {}
+
+ def step(self, identity):
+ return self.step.get(identity)
+
+ def addStep(self, step):
+ self.steps[step.index] = step
+
+ def removeStep(self, step):
+ del self.steps[step.index]
+ del step
+ self.dirty = True
+
+ def removeAllSteps(self):
+ for item in iter(self):
+ self.removeStep(item)
+
+ @property
+ def columns(self):
+ return ProgressStep.columns()
+
+ def __len__(self):
+ return len(self.steps)
+
+ def __iter__(self):
+ for step in self.steps.values():
+ yield step
+
+
+class StepsTableWidget(QtGui.QTableWidget):
+ """
+ initializes a TableWidget
+ suitable for our display purposes, like removing
+ header info and grid display
+ """
+
+ def __init__(self, parent=None):
+ super(StepsTableWidget, self).__init__(parent)
+
+ # remove headers and all edit/select behavior
+ self.horizontalHeader().hide()
+ self.verticalHeader().hide()
+ self.setEditTriggers(
+ QtGui.QAbstractItemView.NoEditTriggers)
+ self.setSelectionMode(
+ QtGui.QAbstractItemView.NoSelection)
+ width = self.width()
+ # WTF? Here init width is 100...
+ # but on populating is 456... :(
+
+ # XXX do we need this initial?
+ logger.debug('init table. width=%s' % width)
+ self.horizontalHeader().resizeSection(0, width * 0.7)
+
+ # this disables the table grid.
+ # we should add alignment to the ImgWidget (it's top-left now)
+ self.setShowGrid(False)
+ self.setFocusPolicy(QtCore.Qt.NoFocus)
+ #self.setStyleSheet("QTableView{outline: 0;}")
+
+ # XXX change image for done to rc
+
+ # Note about the "done" status painting:
+ #
+ # XXX currently we are setting the CellWidget
+ # for the whole table on a per-row basis
+ # (on add_status_line method on ValidationPage).
+ # However, a more generic solution might be
+ # to implement a custom Delegate that overwrites
+ # the paint method (so it paints a checked tickmark if
+ # done is True and some other thing if checking or false).
+ # What we have now is quick and works because
+ # I'm supposing that on first fail we will
+ # go back to previous wizard page to signal the failure.
+ # A more generic solution could be used for
+ # some failing tests if they are not critical.
+
+
+class WithStepsMixIn(object):
+
+ # worker threads for checks
+
+ def setupStepsProcessingQueue(self):
+ self.steps_queue = Queue.Queue()
+ self.stepscheck_timer = QtCore.QTimer()
+ self.stepscheck_timer.timeout.connect(self.processStepsQueue)
+ self.stepscheck_timer.start(100)
+ # we need to keep a reference to child threads
+ self.threads = []
+
+ def do_checks(self):
+
+ # yo dawg, I heard you like checks
+ # so I put a __do_checks in your do_checks
+ # for calling others' _do_checks
+
+ def __do_checks(fun=None, queue=None):
+
+ for checkcase in fun():
+ checkmsg, checkfun = checkcase
+
+ queue.put(checkmsg)
+ if checkfun() is False:
+ queue.put("failed")
+ break
+
+ t = FunThread(fun=partial(
+ __do_checks,
+ fun=self._do_checks,
+ queue=self.steps_queue))
+ t.finished.connect(self.on_checks_validation_ready)
+ t.begin()
+ self.threads.append(t)
+
+ def fail(self, err=None):
+ """
+ return failed state
+ and send error notification as
+ a nice side effect
+ """
+ wizard = self.wizard()
+ senderr = lambda err: wizard.set_validation_error(
+ self.current_page, err)
+ self.set_undone()
+ if err:
+ senderr(err)
+ return False
+
+ @QtCore.pyqtSlot()
+ def launch_checks(self):
+ self.do_checks()
+
+ # slot
+ #@QtCore.pyqtSlot(str, int)
+ def onStepStatusChanged(self, status, progress=None):
+ if status not in ("head_sentinel", "end_sentinel"):
+ self.add_status_line(status)
+ if status in ("end_sentinel"):
+ self.checks_finished = True
+ self.set_checked_icon()
+ if progress and hasattr(self, 'progress'):
+ self.progress.setValue(progress)
+ self.progress.update()
+
+ def processStepsQueue(self):
+ """
+ consume steps queue
+ and pass messages
+ to the ui updater functions
+ """
+ while self.steps_queue.qsize():
+ try:
+ status = self.steps_queue.get(0)
+ if status == "failed":
+ self.set_failed_icon()
+ else:
+ self.onStepStatusChanged(*status)
+ except Queue.Empty:
+ pass
+
+ def setupSteps(self):
+ self.steps = ProgressStepContainer()
+ # steps table widget
+ self.stepsTableWidget = StepsTableWidget(self)
+ zeros = (0, 0, 0, 0)
+ self.stepsTableWidget.setContentsMargins(*zeros)
+ self.errors = OrderedDict()
+
+ def set_error(self, name, error):
+ self.errors[name] = error
+
+ def pop_first_error(self):
+ return list(reversed(self.errors.items())).pop()
+
+ def clean_errors(self):
+ self.errors = OrderedDict()
+
+ def clean_wizard_errors(self, pagename=None):
+ if pagename is None:
+ pagename = getattr(self, 'prev_page', None)
+ if pagename is None:
+ return
+ logger.debug('cleaning wizard errors for %s' % pagename)
+ self.wizard().set_validation_error(pagename, None)
+
+ def populateStepsTable(self):
+ # from examples,
+ # but I guess it's not needed to re-populate
+ # the whole table.
+ table = self.stepsTableWidget
+ table.setRowCount(len(self.steps))
+ columns = self.steps.columns
+ table.setColumnCount(len(columns))
+
+ for row, step in enumerate(self.steps):
+ item = QtGui.QTableWidgetItem(step.name)
+ item.setData(QtCore.Qt.UserRole,
+ long(id(step)))
+ table.setItem(row, columns.index('name'), item)
+ table.setItem(row, columns.index('done'),
+ QtGui.QTableWidgetItem(step.done))
+ self.resizeTable()
+ self.update()
+
+ def clearTable(self):
+ # ??? -- not sure what's the difference
+ #self.stepsTableWidget.clear()
+ self.stepsTableWidget.clearContents()
+
+ def resizeTable(self):
+ # resize first column to ~80%
+ table = self.stepsTableWidget
+ FIRST_COLUMN_PERCENT = 0.70
+ width = table.width()
+ logger.debug('populate table. width=%s' % width)
+ table.horizontalHeader().resizeSection(0, width * FIRST_COLUMN_PERCENT)
+
+ def set_item_icon(self, img=ICON_CHECKMARK, current=True):
+ """
+ mark the last item
+ as done
+ """
+ # setting cell widget.
+ # see note on StepsTableWidget about plans to
+ # change this for a better solution.
+ index = len(self.steps)
+ table = self.stepsTableWidget
+ _index = index - 1 if current else index - 2
+ table.setCellWidget(
+ _index,
+ ProgressStep.DONE,
+ ImgWidget(img=img))
+ table.update()
+
+ def set_failed_icon(self):
+ self.set_item_icon(img=ICON_FAILED, current=True)
+
+ def set_checking_icon(self):
+ self.set_item_icon(img=ICON_WAITING, current=True)
+
+ def set_checked_icon(self, current=True):
+ self.set_item_icon(current=current)
+
+ def add_status_line(self, message):
+ """
+ adds a new status line
+ and mark the next-to-last item
+ as done
+ """
+ index = len(self.steps)
+ step = ProgressStep(message, False, index=index)
+ self.steps.addStep(step)
+ self.populateStepsTable()
+ self.set_checking_icon()
+ self.set_checked_icon(current=False)
+
+ # Sets/unsets done flag
+ # for isComplete checks
+
+ def set_done(self):
+ self.done = True
+ self.completeChanged.emit()
+
+ def set_undone(self):
+ self.done = False
+ self.completeChanged.emit()
+
+ def is_done(self):
+ return self.done
+
+ def go_back(self):
+ self.wizard().back()
+
+ def go_next(self):
+ self.wizard().next()
+
+
+"""
+We will use one base class for the intermediate pages
+and another one for the in-page validations, both sharing the creation
+of the tablewidgets.
+The logic of this split comes from where I was trying to solve
+the ui update using signals, but now that it's working well with
+queues I could join them again.
+"""
+
+import Queue
+from functools import partial
+
+
+class InlineValidationPage(QtGui.QWizardPage, WithStepsMixIn):
+
+ def __init__(self, parent=None):
+ super(InlineValidationPage, self).__init__(parent)
+ self.setupStepsProcessingQueue()
+ self.done = False
+
+ # slot
+
+ @QtCore.pyqtSlot()
+ def showStepsFrame(self):
+ self.valFrame.show()
+ self.update()
+
+ # progress frame
+
+ def setupValidationFrame(self):
+ qframe = QtGui.QFrame
+ valFrame = qframe()
+ valFrame.setFrameStyle(qframe.NoFrame)
+ valframeLayout = QtGui.QVBoxLayout()
+ zeros = (0, 0, 0, 0)
+ valframeLayout.setContentsMargins(*zeros)
+
+ valframeLayout.addWidget(self.stepsTableWidget)
+ valFrame.setLayout(valframeLayout)
+ self.valFrame = valFrame
+
+
+class ValidationPage(QtGui.QWizardPage, WithStepsMixIn):
+ """
+ class to be used as an intermediate
+ between two pages in a wizard.
+ shows feedback to the user and goes back if errors,
+ goes forward if ok.
+ initializePage triggers a one shot timer
+ that calls do_checks.
+ Derived classes should implement
+ _do_checks and
+ _do_validation
+ """
+
+ # signals
+ stepChanged = QtCore.pyqtSignal([str, int])
+
+ def __init__(self, parent=None):
+ super(ValidationPage, self).__init__(parent)
+ self.setupSteps()
+ #self.connect_step_status()
+
+ layout = QtGui.QVBoxLayout()
+ self.progress = QtGui.QProgressBar(self)
+ layout.addWidget(self.progress)
+ layout.addWidget(self.stepsTableWidget)
+
+ self.setLayout(layout)
+ self.layout = layout
+
+ self.timer = QtCore.QTimer()
+ self.done = False
+
+ self.setupStepsProcessingQueue()
+
+ def isComplete(self):
+ return self.is_done()
+
+ ########################
+
+ def show_progress(self):
+ self.progress.show()
+ self.stepsTableWidget.show()
+
+ def hide_progress(self):
+ self.progress.hide()
+ self.stepsTableWidget.hide()
+
+ # pagewizard methods.
+ # if overriden, child classes should call super.
+
+ def initializePage(self):
+ self.clean_errors()
+ self.clean_wizard_errors()
+ self.steps.removeAllSteps()
+ self.clearTable()
+ self.resizeTable()
+ self.timer.singleShot(0, self.do_checks)
diff --git a/src/leap/gui/styles.py b/src/leap/gui/styles.py
new file mode 100644
index 00000000..b482922e
--- /dev/null
+++ b/src/leap/gui/styles.py
@@ -0,0 +1,16 @@
+GreenLineEdit = "QLabel {color: green; font-weight: bold}"
+ErrorLabelStyleSheet = """QLabel { color: red; font-weight: bold }"""
+ErrorLineEdit = """QLineEdit { border: 1px solid red; }"""
+
+
+# XXX this is bad.
+# and you should feel bad for it.
+# The original style has a sort of box color
+# white/beige left-top/right-bottom or something like
+# that.
+
+RegularLineEdit = """
+QLineEdit {
+ border: 1px solid black;
+}
+"""
diff --git a/src/leap/gui/test_mainwindow_rc.py b/src/leap/gui/test_mainwindow_rc.py
new file mode 100644
index 00000000..c5abb4aa
--- /dev/null
+++ b/src/leap/gui/test_mainwindow_rc.py
@@ -0,0 +1,29 @@
+import unittest
+import hashlib
+
+try:
+ import sip
+ sip.setapi('QVariant', 2)
+except ValueError:
+ pass
+
+from leap.gui import mainwindow_rc
+
+# I have to admit that there's something
+# perverse in testing this.
+# Even though, I still think that it _is_ a good idea
+# to put a check to avoid non-updated resources files.
+
+# so, if you came here because an updated resource
+# did break a test, what you have to do is getting
+# the md5 hash of your qt_resource_data and change it here.
+
+# annoying? yep. try making a script for that :P
+
+
+class MainWindowResourcesTest(unittest.TestCase):
+
+ def test_mainwindow_resources_hash(self):
+ self.assertEqual(
+ hashlib.md5(mainwindow_rc.qt_resource_data).hexdigest(),
+ '53e196f29061d8f08f112e5a2e64eb53')
diff --git a/src/leap/gui/tests/integration/fake_user_signup.py b/src/leap/gui/tests/integration/fake_user_signup.py
new file mode 100644
index 00000000..78873749
--- /dev/null
+++ b/src/leap/gui/tests/integration/fake_user_signup.py
@@ -0,0 +1,84 @@
+"""
+simple server to test registration and
+authentication
+
+To test:
+
+curl -d login=python_test_user -d password_salt=54321\
+ -d password_verifier=12341234 \
+ http://localhost:8000/users.json
+
+"""
+from BaseHTTPServer import HTTPServer
+from BaseHTTPServer import BaseHTTPRequestHandler
+import cgi
+import json
+import urlparse
+
+HOST = "localhost"
+PORT = 8000
+
+LOGIN_ERROR = """{"errors":{"login":["has already been taken"]}}"""
+
+from leap.base.tests.test_providers import EXPECTED_DEFAULT_CONFIG
+
+
+class request_handler(BaseHTTPRequestHandler):
+ responses = {
+ '/': ['ok\n'],
+ '/users.json': ['ok\n'],
+ '/timeout': ['ok\n'],
+ '/provider.json': ['%s\n' % json.dumps(EXPECTED_DEFAULT_CONFIG)]
+ }
+
+ def do_GET(self):
+ path = urlparse.urlparse(self.path)
+ message = '\n'.join(
+ self.responses.get(
+ path.path, None))
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(message)
+
+ def do_POST(self):
+ form = cgi.FieldStorage(
+ fp=self.rfile,
+ headers=self.headers,
+ environ={'REQUEST_METHOD': 'POST',
+ 'CONTENT_TYPE': self.headers['Content-Type'],
+ })
+ data = dict(
+ (key, form[key].value) for key in form.keys())
+ path = urlparse.urlparse(self.path)
+ message = '\n'.join(
+ self.responses.get(
+ path.path, ''))
+
+ login = data.get('login', None)
+ #password_salt = data.get('password_salt', None)
+ #password_verifier = data.get('password_verifier', None)
+
+ if path.geturl() == "/timeout":
+ print 'timeout'
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(message)
+ import time
+ time.sleep(10)
+ return
+
+ ok = True if (login == "python_test_user") else False
+ if ok:
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(message)
+
+ else:
+ self.send_response(500)
+ self.end_headers()
+ self.wfile.write(LOGIN_ERROR)
+
+
+if __name__ == "__main__":
+ server = HTTPServer((HOST, PORT), request_handler)
+ server.serve_forever()
diff --git a/src/leap/gui/threads.py b/src/leap/gui/threads.py
new file mode 100644
index 00000000..8aad8866
--- /dev/null
+++ b/src/leap/gui/threads.py
@@ -0,0 +1,21 @@
+from PyQt4 import QtCore
+
+
+class FunThread(QtCore.QThread):
+
+ def __init__(self, fun=None, parent=None):
+
+ QtCore.QThread.__init__(self, parent)
+ self.exiting = False
+ self.fun = fun
+
+ def __del__(self):
+ self.exiting = True
+ self.wait()
+
+ def run(self):
+ if self.fun:
+ self.fun()
+
+ def begin(self):
+ self.start()
diff --git a/src/leap/gui/utils.py b/src/leap/gui/utils.py
new file mode 100644
index 00000000..f91ac3ef
--- /dev/null
+++ b/src/leap/gui/utils.py
@@ -0,0 +1,34 @@
+"""
+utility functions to work with gui objects
+"""
+from PyQt4 import QtCore
+
+
+def layout_widgets(layout):
+ """
+ return a generator with all widgets in a layout
+ """
+ return (layout.itemAt(i) for i in range(layout.count()))
+
+
+DELAY_MSECS = 50
+
+
+def delay(obj, method_str=None, call_args=None):
+ """
+ Triggers a function or slot with a small delay.
+ this is a mainly a hack to get responsiveness in the ui
+ in cases in which the event loop freezes and the task
+ is not heavy enough to setup a processing queue.
+ """
+ if callable(obj) and not method_str:
+ fun = lambda: obj()
+
+ if method_str:
+ invoke = QtCore.QMetaObject.invokeMethod
+ if call_args:
+ fun = lambda: invoke(obj, method_str, call_args)
+ else:
+ fun = lambda: invoke(obj, method_str)
+
+ QtCore.QTimer().singleShot(DELAY_MSECS, fun)
diff --git a/src/leap/soledad/README b/src/leap/soledad/README
new file mode 100644
index 00000000..dc448374
--- /dev/null
+++ b/src/leap/soledad/README
@@ -0,0 +1,13 @@
+Soledad -- Synchronization Of Locally Encrypted Data Among Devices
+==================================================================
+
+Dependencies
+------------
+
+Soledad uses the following python libraries:
+
+ * u1db 0.1.4 [1]
+ * python-swiftclient 1.1.1 [2]
+
+[1] http://pypi.python.org/pypi/u1db/0.1.4
+[2] https://launchpad.net/python-swiftclient
diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py
new file mode 100644
index 00000000..3d685635
--- /dev/null
+++ b/src/leap/soledad/__init__.py
@@ -0,0 +1,164 @@
+# License?
+
+"""A U1DB implementation that uses OpenStack Swift as its persistence layer."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db.backends import CommonBackend, CommonSyncTarget
+from u1db import (
+ Document,
+ errors,
+ query_parser,
+ vectorclock,
+ )
+
+from swiftclient import client
+
+
+class OpenStackDatabase(CommonBackend):
+ """A U1DB implementation that uses OpenStack as its persistence layer."""
+
+ def __init__(self, auth_url, user, auth_key):
+ """Create a new OpenStack data container."""
+ self._auth_url = auth_url
+ self._user = user
+ self._auth_key = auth_key
+ self.set_document_factory(LeapDocument)
+ self._connection = swiftclient.Connection(self._auth_url, self._user,
+ self._auth_key)
+
+ #-------------------------------------------------------------------------
+ # implemented methods from Database
+ #-------------------------------------------------------------------------
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def set_document_size_limit(self, limit):
+ raise NotImplementedError(self.set_document_size_limit)
+
+ def whats_changed(self, old_generation=0):
+ raise NotImplementedError(self.whats_changed)
+
+ def get_doc(self, doc_id, include_deleted=False):
+ raise NotImplementedError(self.get_doc)
+
+ def get_all_docs(self, include_deleted=False):
+ """Get all documents from the database."""
+ raise NotImplementedError(self.get_all_docs)
+
+ def put_doc(self, doc):
+ raise NotImplementedError(self.put_doc)
+
+ def delete_doc(self, doc):
+ raise NotImplementedError(self.delete_doc)
+
+ # start of index-related methods: these are not supported by this backend.
+
+ def create_index(self, index_name, *index_expressions):
+ return False
+
+ def delete_index(self, index_name):
+ return False
+
+ def list_indexes(self):
+ return []
+
+ def get_from_index(self, index_name, *key_values):
+ return []
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ return []
+
+ def get_index_keys(self, index_name):
+ return []
+
+ # end of index-related methods: these are not supported by this backend.
+
+ def get_doc_conflicts(self, doc_id):
+ return []
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ raise NotImplementedError(self.resolve_doc)
+
+ def get_sync_target(self):
+ return OpenStackSyncTarget(self)
+
+ def close(self):
+ raise NotImplementedError(self.close)
+
+ def sync(self, url, creds=None, autocreate=True):
+ raise NotImplementedError(self.close)
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ raise NotImplementedError(self._get_replica_gen_and_trans_id)
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ raise NotImplementedError(self._set_replica_gen_and_trans_id)
+
+ #-------------------------------------------------------------------------
+ # implemented methods from CommonBackend
+ #-------------------------------------------------------------------------
+
+ def _get_generation(self):
+ raise NotImplementedError(self._get_generation)
+
+ def _get_generation_info(self):
+ raise NotImplementedError(self._get_generation_info)
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """Get just the document content, without fancy handling."""
+ raise NotImplementedError(self._get_doc)
+
+ def _has_conflicts(self, doc_id):
+ raise NotImplementedError(self._has_conflicts)
+
+ def _get_transaction_log(self):
+ raise NotImplementedError(self._get_transaction_log)
+
+ def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content):
+ raise NotImplementedError(self._put_and_update_indexes)
+
+
+ def _get_trans_id_for_gen(self, generation):
+ raise NotImplementedError(self._get_trans_id_for_gen)
+
+ #-------------------------------------------------------------------------
+ # OpenStack specific methods
+ #-------------------------------------------------------------------------
+
+ def _is_initialized(self, c):
+ raise NotImplementedError(self._is_initialized)
+
+ def _initialize(self, c):
+ raise NotImplementedError(self._initialize)
+
+ def _get_auth(self):
+ self._url, self._auth_token = self._connection.get_auth(self._auth_url,
+ self._user,
+ self._auth_key)
+ return self._url, self.auth_token
+
+
+class LeapDocument(Document):
+
+ def get_content_encrypted(self):
+ raise NotImplementedError(self.get_content_encrypted)
+
+ def set_content_encrypted(self):
+ raise NotImplementedError(self.set_content_encrypted)
+
+
+class OpenStackSyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ raise NotImplementedError(self.get_sync_info)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_replica_transaction_id):
+ raise NotImplementedError(self.record_sync_info)
diff --git a/src/leap/soledad/swiftclient/__init__.py b/src/leap/soledad/swiftclient/__init__.py
new file mode 100644
index 00000000..ba0b41a3
--- /dev/null
+++ b/src/leap/soledad/swiftclient/__init__.py
@@ -0,0 +1,5 @@
+# -*- encoding: utf-8 -*-
+""""
+OpenStack Swift Python client binding.
+"""
+from client import *
diff --git a/src/leap/soledad/swiftclient/client.py b/src/leap/soledad/swiftclient/client.py
new file mode 100644
index 00000000..79e6594f
--- /dev/null
+++ b/src/leap/soledad/swiftclient/client.py
@@ -0,0 +1,1056 @@
+# Copyright (c) 2010-2012 OpenStack, LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Cloud Files client library used internally
+"""
+
+import socket
+import os
+import logging
+import httplib
+
+from urllib import quote as _quote
+from urlparse import urlparse, urlunparse, urljoin
+
+try:
+ from eventlet.green.httplib import HTTPException, HTTPSConnection
+except ImportError:
+ from httplib import HTTPException, HTTPSConnection
+
+try:
+ from eventlet import sleep
+except ImportError:
+ from time import sleep
+
+try:
+ from swift.common.bufferedhttp \
+ import BufferedHTTPConnection as HTTPConnection
+except ImportError:
+ try:
+ from eventlet.green.httplib import HTTPConnection
+ except ImportError:
+ from httplib import HTTPConnection
+
+logger = logging.getLogger("swiftclient")
+
+
+def http_log(args, kwargs, resp, body):
+ if os.environ.get('SWIFTCLIENT_DEBUG', False):
+ ch = logging.StreamHandler()
+ logger.setLevel(logging.DEBUG)
+ logger.addHandler(ch)
+ elif not logger.isEnabledFor(logging.DEBUG):
+ return
+
+ string_parts = ['curl -i']
+ for element in args:
+ if element in ('GET', 'POST', 'PUT', 'HEAD'):
+ string_parts.append(' -X %s' % element)
+ else:
+ string_parts.append(' %s' % element)
+
+ if 'headers' in kwargs:
+ for element in kwargs['headers']:
+ header = ' -H "%s: %s"' % (element, kwargs['headers'][element])
+ string_parts.append(header)
+
+ logger.debug("REQ: %s\n" % "".join(string_parts))
+ if 'raw_body' in kwargs:
+ logger.debug("REQ BODY (RAW): %s\n" % (kwargs['raw_body']))
+ if 'body' in kwargs:
+ logger.debug("REQ BODY: %s\n" % (kwargs['body']))
+
+ logger.debug("RESP STATUS: %s\n", resp.status)
+ if body:
+ logger.debug("RESP BODY: %s\n", body)
+
+
+def quote(value, safe='/'):
+ """
+ Patched version of urllib.quote that encodes utf8 strings before quoting
+ """
+ if isinstance(value, unicode):
+ value = value.encode('utf8')
+ return _quote(value, safe)
+
+
+# look for a real json parser first
+try:
+ # simplejson is popular and pretty good
+ from simplejson import loads as json_loads
+ from simplejson import dumps as json_dumps
+except ImportError:
+ # 2.6 will have a json module in the stdlib
+ from json import loads as json_loads
+ from json import dumps as json_dumps
+
+
+class ClientException(Exception):
+
+ def __init__(self, msg, http_scheme='', http_host='', http_port='',
+ http_path='', http_query='', http_status=0, http_reason='',
+ http_device='', http_response_content=''):
+ Exception.__init__(self, msg)
+ self.msg = msg
+ self.http_scheme = http_scheme
+ self.http_host = http_host
+ self.http_port = http_port
+ self.http_path = http_path
+ self.http_query = http_query
+ self.http_status = http_status
+ self.http_reason = http_reason
+ self.http_device = http_device
+ self.http_response_content = http_response_content
+
+ def __str__(self):
+ a = self.msg
+ b = ''
+ if self.http_scheme:
+ b += '%s://' % self.http_scheme
+ if self.http_host:
+ b += self.http_host
+ if self.http_port:
+ b += ':%s' % self.http_port
+ if self.http_path:
+ b += self.http_path
+ if self.http_query:
+ b += '?%s' % self.http_query
+ if self.http_status:
+ if b:
+ b = '%s %s' % (b, self.http_status)
+ else:
+ b = str(self.http_status)
+ if self.http_reason:
+ if b:
+ b = '%s %s' % (b, self.http_reason)
+ else:
+ b = '- %s' % self.http_reason
+ if self.http_device:
+ if b:
+ b = '%s: device %s' % (b, self.http_device)
+ else:
+ b = 'device %s' % self.http_device
+ if self.http_response_content:
+ if len(self.http_response_content) <= 60:
+ b += ' %s' % self.http_response_content
+ else:
+ b += ' [first 60 chars of response] %s' \
+ % self.http_response_content[:60]
+ return b and '%s: %s' % (a, b) or a
+
+
+def http_connection(url, proxy=None):
+ """
+ Make an HTTPConnection or HTTPSConnection
+
+ :param url: url to connect to
+ :param proxy: proxy to connect through, if any; None by default; str of the
+ format 'http://127.0.0.1:8888' to set one
+ :returns: tuple of (parsed url, connection object)
+ :raises ClientException: Unable to handle protocol scheme
+ """
+ parsed = urlparse(url)
+ proxy_parsed = urlparse(proxy) if proxy else None
+ if parsed.scheme == 'http':
+ conn = HTTPConnection((proxy_parsed if proxy else parsed).netloc)
+ elif parsed.scheme == 'https':
+ conn = HTTPSConnection((proxy_parsed if proxy else parsed).netloc)
+ else:
+ raise ClientException('Cannot handle protocol scheme %s for url %s' %
+ (parsed.scheme, repr(url)))
+ if proxy:
+ conn._set_tunnel(parsed.hostname, parsed.port)
+ return parsed, conn
+
+
+def json_request(method, url, **kwargs):
+ """Takes a request in json parse it and return in json"""
+ kwargs.setdefault('headers', {})
+ if 'body' in kwargs:
+ kwargs['headers']['Content-Type'] = 'application/json'
+ kwargs['body'] = json_dumps(kwargs['body'])
+ parsed, conn = http_connection(url)
+ conn.request(method, parsed.path, **kwargs)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log((url, method,), kwargs, resp, body)
+ if body:
+ try:
+ body = json_loads(body)
+ except ValueError:
+ body = None
+ if not body or resp.status < 200 or resp.status >= 300:
+ raise ClientException('Auth GET failed', http_scheme=parsed.scheme,
+ http_host=conn.host,
+ http_port=conn.port,
+ http_path=parsed.path,
+ http_status=resp.status,
+ http_reason=resp.reason)
+ return resp, body
+
+
+def _get_auth_v1_0(url, user, key, snet):
+ parsed, conn = http_connection(url)
+ method = 'GET'
+ conn.request(method, parsed.path, '',
+ {'X-Auth-User': user, 'X-Auth-Key': key})
+ resp = conn.getresponse()
+ body = resp.read()
+ url = resp.getheader('x-storage-url')
+ http_log((url, method,), {}, resp, body)
+
+ # There is a side-effect on current Rackspace 1.0 server where a
+ # bad URL would get you that document page and a 200. We error out
+ # if we don't have a x-storage-url header and if we get a body.
+ if resp.status < 200 or resp.status >= 300 or (body and not url):
+ raise ClientException('Auth GET failed', http_scheme=parsed.scheme,
+ http_host=conn.host, http_port=conn.port,
+ http_path=parsed.path, http_status=resp.status,
+ http_reason=resp.reason)
+ if snet:
+ parsed = list(urlparse(url))
+ # Second item in the list is the netloc
+ netloc = parsed[1]
+ parsed[1] = 'snet-' + netloc
+ url = urlunparse(parsed)
+ return url, resp.getheader('x-storage-token',
+ resp.getheader('x-auth-token'))
+
+
+def _get_auth_v2_0(url, user, tenant_name, key, snet):
+ body = {'auth':
+ {'passwordCredentials': {'password': key, 'username': user},
+ 'tenantName': tenant_name}}
+ token_url = urljoin(url, "tokens")
+ resp, body = json_request("POST", token_url, body=body)
+ token_id = None
+ try:
+ url = None
+ catalogs = body['access']['serviceCatalog']
+ for service in catalogs:
+ if service['type'] == 'object-store':
+ url = service['endpoints'][0]['publicURL']
+ token_id = body['access']['token']['id']
+ if not url:
+ raise ClientException("There is no object-store endpoint "
+ "on this auth server.")
+ except(KeyError, IndexError):
+ raise ClientException("Error while getting answers from auth server")
+
+ if snet:
+ parsed = list(urlparse(url))
+ # Second item in the list is the netloc
+ parsed[1] = 'snet-' + parsed[1]
+ url = urlunparse(parsed)
+
+ return url, token_id
+
+
+def get_auth(url, user, key, snet=False, tenant_name=None, auth_version="1.0"):
+ """
+ Get authentication/authorization credentials.
+
+ The snet parameter is used for Rackspace's ServiceNet internal network
+ implementation. In this function, it simply adds *snet-* to the beginning
+ of the host name for the returned storage URL. With Rackspace Cloud Files,
+ use of this network path causes no bandwidth charges but requires the
+ client to be running on Rackspace's ServiceNet network.
+
+ :param url: authentication/authorization URL
+ :param user: user to authenticate as
+ :param key: key or password for authorization
+ :param snet: use SERVICENET internal network (see above), default is False
+ :param auth_version: OpenStack auth version, default is 1.0
+ :param tenant_name: The tenant/account name, required when connecting
+ to a auth 2.0 system.
+ :returns: tuple of (storage URL, auth token)
+ :raises: ClientException: HTTP GET request to auth URL failed
+ """
+ if auth_version in ["1.0", "1"]:
+ return _get_auth_v1_0(url, user, key, snet)
+ elif auth_version in ["2.0", "2"]:
+ if not tenant_name and ':' in user:
+ (tenant_name, user) = user.split(':')
+ if not tenant_name:
+ raise ClientException('No tenant specified')
+ return _get_auth_v2_0(url, user, tenant_name, key, snet)
+ else:
+ raise ClientException('Unknown auth_version %s specified.'
+ % auth_version)
+
+
+def get_account(url, token, marker=None, limit=None, prefix=None,
+ http_conn=None, full_listing=False):
+ """
+ Get a listing of containers for the account.
+
+ :param url: storage URL
+ :param token: auth token
+ :param marker: marker query
+ :param limit: limit query
+ :param prefix: prefix query
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :param full_listing: if True, return a full listing, else returns a max
+ of 10000 listings
+ :returns: a tuple of (response headers, a list of containers) The response
+ headers will be a dict and all header names will be lowercase.
+ :raises ClientException: HTTP GET request failed
+ """
+ if not http_conn:
+ http_conn = http_connection(url)
+ if full_listing:
+ rv = get_account(url, token, marker, limit, prefix, http_conn)
+ listing = rv[1]
+ while listing:
+ marker = listing[-1]['name']
+ listing = \
+ get_account(url, token, marker, limit, prefix, http_conn)[1]
+ if listing:
+ rv[1].extend(listing)
+ return rv
+ parsed, conn = http_conn
+ qs = 'format=json'
+ if marker:
+ qs += '&marker=%s' % quote(marker)
+ if limit:
+ qs += '&limit=%d' % limit
+ if prefix:
+ qs += '&prefix=%s' % quote(prefix)
+ full_path = '%s?%s' % (parsed.path, qs)
+ headers = {'X-Auth-Token': token}
+ conn.request('GET', full_path, '',
+ headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log(("%s?%s" % (url, qs), 'GET',), {'headers': headers}, resp, body)
+
+ resp_headers = {}
+ for header, value in resp.getheaders():
+ resp_headers[header.lower()] = value
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Account GET failed', http_scheme=parsed.scheme,
+ http_host=conn.host, http_port=conn.port,
+ http_path=parsed.path, http_query=qs,
+ http_status=resp.status, http_reason=resp.reason,
+ http_response_content=body)
+ if resp.status == 204:
+ body
+ return resp_headers, []
+ return resp_headers, json_loads(body)
+
+
+def head_account(url, token, http_conn=None):
+ """
+ Get account stats.
+
+ :param url: storage URL
+ :param token: auth token
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :returns: a dict containing the response's headers (all header names will
+ be lowercase)
+ :raises ClientException: HTTP HEAD request failed
+ """
+ if http_conn:
+ parsed, conn = http_conn
+ else:
+ parsed, conn = http_connection(url)
+ method = "HEAD"
+ headers = {'X-Auth-Token': token}
+ conn.request(method, parsed.path, '', headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log((url, method,), {'headers': headers}, resp, body)
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Account HEAD failed', http_scheme=parsed.scheme,
+ http_host=conn.host, http_port=conn.port,
+ http_path=parsed.path, http_status=resp.status,
+ http_reason=resp.reason,
+ http_response_content=body)
+ resp_headers = {}
+ for header, value in resp.getheaders():
+ resp_headers[header.lower()] = value
+ return resp_headers
+
+
+def post_account(url, token, headers, http_conn=None):
+ """
+ Update an account's metadata.
+
+ :param url: storage URL
+ :param token: auth token
+ :param headers: additional headers to include in the request
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :raises ClientException: HTTP POST request failed
+ """
+ if http_conn:
+ parsed, conn = http_conn
+ else:
+ parsed, conn = http_connection(url)
+ method = 'POST'
+ headers['X-Auth-Token'] = token
+ conn.request(method, parsed.path, '', headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log((url, method,), {'headers': headers}, resp, body)
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Account POST failed',
+ http_scheme=parsed.scheme,
+ http_host=conn.host,
+ http_port=conn.port,
+ http_path=parsed.path,
+ http_status=resp.status,
+ http_reason=resp.reason,
+ http_response_content=body)
+
+
+def get_container(url, token, container, marker=None, limit=None,
+ prefix=None, delimiter=None, http_conn=None,
+ full_listing=False):
+ """
+ Get a listing of objects for the container.
+
+ :param url: storage URL
+ :param token: auth token
+ :param container: container name to get a listing for
+ :param marker: marker query
+ :param limit: limit query
+ :param prefix: prefix query
+ :param delimeter: string to delimit the queries on
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :param full_listing: if True, return a full listing, else returns a max
+ of 10000 listings
+ :returns: a tuple of (response headers, a list of objects) The response
+ headers will be a dict and all header names will be lowercase.
+ :raises ClientException: HTTP GET request failed
+ """
+ if not http_conn:
+ http_conn = http_connection(url)
+ if full_listing:
+ rv = get_container(url, token, container, marker, limit, prefix,
+ delimiter, http_conn)
+ listing = rv[1]
+ while listing:
+ if not delimiter:
+ marker = listing[-1]['name']
+ else:
+ marker = listing[-1].get('name', listing[-1].get('subdir'))
+ listing = get_container(url, token, container, marker, limit,
+ prefix, delimiter, http_conn)[1]
+ if listing:
+ rv[1].extend(listing)
+ return rv
+ parsed, conn = http_conn
+ path = '%s/%s' % (parsed.path, quote(container))
+ qs = 'format=json'
+ if marker:
+ qs += '&marker=%s' % quote(marker)
+ if limit:
+ qs += '&limit=%d' % limit
+ if prefix:
+ qs += '&prefix=%s' % quote(prefix)
+ if delimiter:
+ qs += '&delimiter=%s' % quote(delimiter)
+ headers = {'X-Auth-Token': token}
+ method = 'GET'
+ conn.request(method, '%s?%s' % (path, qs), '', headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log(('%s?%s' % (url, qs), method,), {'headers': headers}, resp, body)
+
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Container GET failed',
+ http_scheme=parsed.scheme, http_host=conn.host,
+ http_port=conn.port, http_path=path,
+ http_query=qs, http_status=resp.status,
+ http_reason=resp.reason,
+ http_response_content=body)
+ resp_headers = {}
+ for header, value in resp.getheaders():
+ resp_headers[header.lower()] = value
+ if resp.status == 204:
+ return resp_headers, []
+ return resp_headers, json_loads(body)
+
+
+def head_container(url, token, container, http_conn=None, headers=None):
+ """
+ Get container stats.
+
+ :param url: storage URL
+ :param token: auth token
+ :param container: container name to get stats for
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :returns: a dict containing the response's headers (all header names will
+ be lowercase)
+ :raises ClientException: HTTP HEAD request failed
+ """
+ if http_conn:
+ parsed, conn = http_conn
+ else:
+ parsed, conn = http_connection(url)
+ path = '%s/%s' % (parsed.path, quote(container))
+ method = 'HEAD'
+ req_headers = {'X-Auth-Token': token}
+ if headers:
+ req_headers.update(headers)
+ conn.request(method, path, '', req_headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log(('%s?%s' % (url, path), method,),
+ {'headers': req_headers}, resp, body)
+
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Container HEAD failed',
+ http_scheme=parsed.scheme, http_host=conn.host,
+ http_port=conn.port, http_path=path,
+ http_status=resp.status, http_reason=resp.reason,
+ http_response_content=body)
+ resp_headers = {}
+ for header, value in resp.getheaders():
+ resp_headers[header.lower()] = value
+ return resp_headers
+
+
+def put_container(url, token, container, headers=None, http_conn=None):
+ """
+ Create a container
+
+ :param url: storage URL
+ :param token: auth token
+ :param container: container name to create
+ :param headers: additional headers to include in the request
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :raises ClientException: HTTP PUT request failed
+ """
+ if http_conn:
+ parsed, conn = http_conn
+ else:
+ parsed, conn = http_connection(url)
+ path = '%s/%s' % (parsed.path, quote(container))
+ method = 'PUT'
+ if not headers:
+ headers = {}
+ headers['X-Auth-Token'] = token
+ conn.request(method, path, '', headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log(('%s?%s' % (url, path), method,),
+ {'headers': headers}, resp, body)
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Container PUT failed',
+ http_scheme=parsed.scheme, http_host=conn.host,
+ http_port=conn.port, http_path=path,
+ http_status=resp.status, http_reason=resp.reason,
+ http_response_content=body)
+
+
+def post_container(url, token, container, headers, http_conn=None):
+ """
+ Update a container's metadata.
+
+ :param url: storage URL
+ :param token: auth token
+ :param container: container name to update
+ :param headers: additional headers to include in the request
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :raises ClientException: HTTP POST request failed
+ """
+ if http_conn:
+ parsed, conn = http_conn
+ else:
+ parsed, conn = http_connection(url)
+ path = '%s/%s' % (parsed.path, quote(container))
+ method = 'POST'
+ headers['X-Auth-Token'] = token
+ conn.request(method, path, '', headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log(('%s?%s' % (url, path), method,),
+ {'headers': headers}, resp, body)
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Container POST failed',
+ http_scheme=parsed.scheme, http_host=conn.host,
+ http_port=conn.port, http_path=path,
+ http_status=resp.status, http_reason=resp.reason,
+ http_response_content=body)
+
+
+def delete_container(url, token, container, http_conn=None):
+ """
+ Delete a container
+
+ :param url: storage URL
+ :param token: auth token
+ :param container: container name to delete
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :raises ClientException: HTTP DELETE request failed
+ """
+ if http_conn:
+ parsed, conn = http_conn
+ else:
+ parsed, conn = http_connection(url)
+ path = '%s/%s' % (parsed.path, quote(container))
+ headers = {'X-Auth-Token': token}
+ method = 'DELETE'
+ conn.request(method, path, '', headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log(('%s?%s' % (url, path), method,),
+ {'headers': headers}, resp, body)
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Container DELETE failed',
+ http_scheme=parsed.scheme, http_host=conn.host,
+ http_port=conn.port, http_path=path,
+ http_status=resp.status, http_reason=resp.reason,
+ http_response_content=body)
+
+
+def get_object(url, token, container, name, http_conn=None,
+ resp_chunk_size=None):
+ """
+ Get an object
+
+ :param url: storage URL
+ :param token: auth token
+ :param container: container name that the object is in
+ :param name: object name to get
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :param resp_chunk_size: if defined, chunk size of data to read. NOTE: If
+ you specify a resp_chunk_size you must fully read
+ the object's contents before making another
+ request.
+ :returns: a tuple of (response headers, the object's contents) The response
+ headers will be a dict and all header names will be lowercase.
+ :raises ClientException: HTTP GET request failed
+ """
+ if http_conn:
+ parsed, conn = http_conn
+ else:
+ parsed, conn = http_connection(url)
+ path = '%s/%s/%s' % (parsed.path, quote(container), quote(name))
+ method = 'GET'
+ headers = {'X-Auth-Token': token}
+ conn.request(method, path, '', headers)
+ resp = conn.getresponse()
+ if resp.status < 200 or resp.status >= 300:
+ body = resp.read()
+ http_log(('%s?%s' % (url, path), 'POST',),
+ {'headers': headers}, resp, body)
+ raise ClientException('Object GET failed', http_scheme=parsed.scheme,
+ http_host=conn.host, http_port=conn.port,
+ http_path=path, http_status=resp.status,
+ http_reason=resp.reason,
+ http_response_content=body)
+ if resp_chunk_size:
+
+ def _object_body():
+ buf = resp.read(resp_chunk_size)
+ while buf:
+ yield buf
+ buf = resp.read(resp_chunk_size)
+ object_body = _object_body()
+ else:
+ object_body = resp.read()
+ resp_headers = {}
+ for header, value in resp.getheaders():
+ resp_headers[header.lower()] = value
+ http_log(('%s?%s' % (url, path), 'POST',),
+ {'headers': headers}, resp, object_body)
+ return resp_headers, object_body
+
+
+def head_object(url, token, container, name, http_conn=None):
+ """
+ Get object info
+
+ :param url: storage URL
+ :param token: auth token
+ :param container: container name that the object is in
+ :param name: object name to get info for
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :returns: a dict containing the response's headers (all header names will
+ be lowercase)
+ :raises ClientException: HTTP HEAD request failed
+ """
+ if http_conn:
+ parsed, conn = http_conn
+ else:
+ parsed, conn = http_connection(url)
+ path = '%s/%s/%s' % (parsed.path, quote(container), quote(name))
+ method = 'HEAD'
+ headers = {'X-Auth-Token': token}
+ conn.request(method, path, '', headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log(('%s?%s' % (url, path), 'POST',),
+ {'headers': headers}, resp, body)
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Object HEAD failed', http_scheme=parsed.scheme,
+ http_host=conn.host, http_port=conn.port,
+ http_path=path, http_status=resp.status,
+ http_reason=resp.reason,
+ http_response_content=body)
+ resp_headers = {}
+ for header, value in resp.getheaders():
+ resp_headers[header.lower()] = value
+ return resp_headers
+
+
+def put_object(url, token=None, container=None, name=None, contents=None,
+ content_length=None, etag=None, chunk_size=65536,
+ content_type=None, headers=None, http_conn=None, proxy=None):
+ """
+ Put an object
+
+ :param url: storage URL
+ :param token: auth token; if None, no token will be sent
+ :param container: container name that the object is in; if None, the
+ container name is expected to be part of the url
+ :param name: object name to put; if None, the object name is expected to be
+ part of the url
+ :param contents: a string or a file like object to read object data from;
+ if None, a zero-byte put will be done
+ :param content_length: value to send as content-length header; also limits
+ the amount read from contents; if None, it will be
+ computed via the contents or chunked transfer
+ encoding will be used
+ :param etag: etag of contents; if None, no etag will be sent
+ :param chunk_size: chunk size of data to write; default 65536
+ :param content_type: value to send as content-type header; if None, no
+ content-type will be set (remote end will likely try
+ to auto-detect it)
+ :param headers: additional headers to include in the request, if any
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :param proxy: proxy to connect through, if any; None by default; str of the
+ format 'http://127.0.0.1:8888' to set one
+ :returns: etag from server response
+ :raises ClientException: HTTP PUT request failed
+ """
+ if http_conn:
+ parsed, conn = http_conn
+ else:
+ parsed, conn = http_connection(url, proxy=proxy)
+ path = parsed.path
+ if container:
+ path = '%s/%s' % (path.rstrip('/'), quote(container))
+ if name:
+ path = '%s/%s' % (path.rstrip('/'), quote(name))
+ if headers:
+ headers = dict(headers)
+ else:
+ headers = {}
+ if token:
+ headers['X-Auth-Token'] = token
+ if etag:
+ headers['ETag'] = etag.strip('"')
+ if content_length is not None:
+ headers['Content-Length'] = str(content_length)
+ else:
+ for n, v in headers.iteritems():
+ if n.lower() == 'content-length':
+ content_length = int(v)
+ if content_type is not None:
+ headers['Content-Type'] = content_type
+ if not contents:
+ headers['Content-Length'] = '0'
+ if hasattr(contents, 'read'):
+ conn.putrequest('PUT', path)
+ for header, value in headers.iteritems():
+ conn.putheader(header, value)
+ if content_length is None:
+ conn.putheader('Transfer-Encoding', 'chunked')
+ conn.endheaders()
+ chunk = contents.read(chunk_size)
+ while chunk:
+ conn.send('%x\r\n%s\r\n' % (len(chunk), chunk))
+ chunk = contents.read(chunk_size)
+ conn.send('0\r\n\r\n')
+ else:
+ conn.endheaders()
+ left = content_length
+ while left > 0:
+ size = chunk_size
+ if size > left:
+ size = left
+ chunk = contents.read(size)
+ conn.send(chunk)
+ left -= len(chunk)
+ else:
+ conn.request('PUT', path, contents, headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ headers = {'X-Auth-Token': token}
+ http_log(('%s?%s' % (url, path), 'PUT',),
+ {'headers': headers}, resp, body)
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Object PUT failed', http_scheme=parsed.scheme,
+ http_host=conn.host, http_port=conn.port,
+ http_path=path, http_status=resp.status,
+ http_reason=resp.reason,
+ http_response_content=body)
+ return resp.getheader('etag', '').strip('"')
+
+
+def post_object(url, token, container, name, headers, http_conn=None):
+ """
+ Update object metadata
+
+ :param url: storage URL
+ :param token: auth token
+ :param container: container name that the object is in
+ :param name: name of the object to update
+ :param headers: additional headers to include in the request
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :raises ClientException: HTTP POST request failed
+ """
+ if http_conn:
+ parsed, conn = http_conn
+ else:
+ parsed, conn = http_connection(url)
+ path = '%s/%s/%s' % (parsed.path, quote(container), quote(name))
+ headers['X-Auth-Token'] = token
+ conn.request('POST', path, '', headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log(('%s?%s' % (url, path), 'POST',),
+ {'headers': headers}, resp, body)
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Object POST failed', http_scheme=parsed.scheme,
+ http_host=conn.host, http_port=conn.port,
+ http_path=path, http_status=resp.status,
+ http_reason=resp.reason,
+ http_response_content=body)
+
+
+def delete_object(url, token=None, container=None, name=None, http_conn=None,
+ headers=None, proxy=None):
+ """
+ Delete object
+
+ :param url: storage URL
+ :param token: auth token; if None, no token will be sent
+ :param container: container name that the object is in; if None, the
+ container name is expected to be part of the url
+ :param name: object name to delete; if None, the object name is expected to
+ be part of the url
+ :param http_conn: HTTP connection object (If None, it will create the
+ conn object)
+ :param headers: additional headers to include in the request
+ :param proxy: proxy to connect through, if any; None by default; str of the
+ format 'http://127.0.0.1:8888' to set one
+ :raises ClientException: HTTP DELETE request failed
+ """
+ if http_conn:
+ parsed, conn = http_conn
+ else:
+ parsed, conn = http_connection(url, proxy=proxy)
+ path = parsed.path
+ if container:
+ path = '%s/%s' % (path.rstrip('/'), quote(container))
+ if name:
+ path = '%s/%s' % (path.rstrip('/'), quote(name))
+ if headers:
+ headers = dict(headers)
+ else:
+ headers = {}
+ if token:
+ headers['X-Auth-Token'] = token
+ conn.request('DELETE', path, '', headers)
+ resp = conn.getresponse()
+ body = resp.read()
+ http_log(('%s?%s' % (url, path), 'POST',),
+ {'headers': headers}, resp, body)
+ if resp.status < 200 or resp.status >= 300:
+ raise ClientException('Object DELETE failed',
+ http_scheme=parsed.scheme, http_host=conn.host,
+ http_port=conn.port, http_path=path,
+ http_status=resp.status, http_reason=resp.reason,
+ http_response_content=body)
+
+
+class Connection(object):
+ """Convenience class to make requests that will also retry the request"""
+
+ def __init__(self, authurl, user, key, retries=5, preauthurl=None,
+ preauthtoken=None, snet=False, starting_backoff=1,
+ tenant_name=None,
+ auth_version="1"):
+ """
+ :param authurl: authentication URL
+ :param user: user name to authenticate as
+ :param key: key/password to authenticate with
+ :param retries: Number of times to retry the request before failing
+ :param preauthurl: storage URL (if you have already authenticated)
+ :param preauthtoken: authentication token (if you have already
+ authenticated)
+ :param snet: use SERVICENET internal network default is False
+ :param auth_version: OpenStack auth version, default is 1.0
+ :param tenant_name: The tenant/account name, required when connecting
+ to a auth 2.0 system.
+ """
+ self.authurl = authurl
+ self.user = user
+ self.key = key
+ self.retries = retries
+ self.http_conn = None
+ self.url = preauthurl
+ self.token = preauthtoken
+ self.attempts = 0
+ self.snet = snet
+ self.starting_backoff = starting_backoff
+ self.auth_version = auth_version
+ self.tenant_name = tenant_name
+
+ def get_auth(self):
+ return get_auth(self.authurl, self.user,
+ self.key, snet=self.snet,
+ tenant_name=self.tenant_name,
+ auth_version=self.auth_version)
+
+ def http_connection(self):
+ return http_connection(self.url)
+
+ def _retry(self, reset_func, func, *args, **kwargs):
+ self.attempts = 0
+ backoff = self.starting_backoff
+ while self.attempts <= self.retries:
+ self.attempts += 1
+ try:
+ if not self.url or not self.token:
+ self.url, self.token = self.get_auth()
+ self.http_conn = None
+ if not self.http_conn:
+ self.http_conn = self.http_connection()
+ kwargs['http_conn'] = self.http_conn
+ rv = func(self.url, self.token, *args, **kwargs)
+ return rv
+ except (socket.error, HTTPException):
+ if self.attempts > self.retries:
+ raise
+ self.http_conn = None
+ except ClientException, err:
+ if self.attempts > self.retries:
+ raise
+ if err.http_status == 401:
+ self.url = self.token = None
+ if self.attempts > 1:
+ raise
+ elif err.http_status == 408:
+ self.http_conn = None
+ elif 500 <= err.http_status <= 599:
+ pass
+ else:
+ raise
+ sleep(backoff)
+ backoff *= 2
+ if reset_func:
+ reset_func(func, *args, **kwargs)
+
+ def head_account(self):
+ """Wrapper for :func:`head_account`"""
+ return self._retry(None, head_account)
+
+ def get_account(self, marker=None, limit=None, prefix=None,
+ full_listing=False):
+ """Wrapper for :func:`get_account`"""
+ # TODO(unknown): With full_listing=True this will restart the entire
+ # listing with each retry. Need to make a better version that just
+ # retries where it left off.
+ return self._retry(None, get_account, marker=marker, limit=limit,
+ prefix=prefix, full_listing=full_listing)
+
+ def post_account(self, headers):
+ """Wrapper for :func:`post_account`"""
+ return self._retry(None, post_account, headers)
+
+ def head_container(self, container):
+ """Wrapper for :func:`head_container`"""
+ return self._retry(None, head_container, container)
+
+ def get_container(self, container, marker=None, limit=None, prefix=None,
+ delimiter=None, full_listing=False):
+ """Wrapper for :func:`get_container`"""
+ # TODO(unknown): With full_listing=True this will restart the entire
+ # listing with each retry. Need to make a better version that just
+ # retries where it left off.
+ return self._retry(None, get_container, container, marker=marker,
+ limit=limit, prefix=prefix, delimiter=delimiter,
+ full_listing=full_listing)
+
+ def put_container(self, container, headers=None):
+ """Wrapper for :func:`put_container`"""
+ return self._retry(None, put_container, container, headers=headers)
+
+ def post_container(self, container, headers):
+ """Wrapper for :func:`post_container`"""
+ return self._retry(None, post_container, container, headers)
+
+ def delete_container(self, container):
+ """Wrapper for :func:`delete_container`"""
+ return self._retry(None, delete_container, container)
+
+ def head_object(self, container, obj):
+ """Wrapper for :func:`head_object`"""
+ return self._retry(None, head_object, container, obj)
+
+ def get_object(self, container, obj, resp_chunk_size=None):
+ """Wrapper for :func:`get_object`"""
+ return self._retry(None, get_object, container, obj,
+ resp_chunk_size=resp_chunk_size)
+
+ def put_object(self, container, obj, contents, content_length=None,
+ etag=None, chunk_size=65536, content_type=None,
+ headers=None):
+ """Wrapper for :func:`put_object`"""
+
+ def _default_reset(*args, **kwargs):
+ raise ClientException('put_object(%r, %r, ...) failure and no '
+ 'ability to reset contents for reupload.'
+ % (container, obj))
+
+ reset_func = _default_reset
+ tell = getattr(contents, 'tell', None)
+ seek = getattr(contents, 'seek', None)
+ if tell and seek:
+ orig_pos = tell()
+ reset_func = lambda *a, **k: seek(orig_pos)
+ elif not contents:
+ reset_func = lambda *a, **k: None
+
+ return self._retry(reset_func, put_object, container, obj, contents,
+ content_length=content_length, etag=etag,
+ chunk_size=chunk_size, content_type=content_type,
+ headers=headers)
+
+ def post_object(self, container, obj, headers):
+ """Wrapper for :func:`post_object`"""
+ return self._retry(None, post_object, container, obj, headers)
+
+ def delete_object(self, container, obj):
+ """Wrapper for :func:`delete_object`"""
+ return self._retry(None, delete_object, container, obj)
diff --git a/src/leap/soledad/swiftclient/openstack/__init__.py b/src/leap/soledad/swiftclient/openstack/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/soledad/swiftclient/openstack/__init__.py
diff --git a/src/leap/soledad/swiftclient/openstack/common/__init__.py b/src/leap/soledad/swiftclient/openstack/common/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/soledad/swiftclient/openstack/common/__init__.py
diff --git a/src/leap/soledad/swiftclient/openstack/common/setup.py b/src/leap/soledad/swiftclient/openstack/common/setup.py
new file mode 100644
index 00000000..caf06fa5
--- /dev/null
+++ b/src/leap/soledad/swiftclient/openstack/common/setup.py
@@ -0,0 +1,342 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Utilities with minimum-depends for use in setup.py
+"""
+
+import datetime
+import os
+import re
+import subprocess
+import sys
+
+from setuptools.command import sdist
+
+
+def parse_mailmap(mailmap='.mailmap'):
+ mapping = {}
+ if os.path.exists(mailmap):
+ fp = open(mailmap, 'r')
+ for l in fp:
+ l = l.strip()
+ if not l.startswith('#') and ' ' in l:
+ canonical_email, alias = l.split(' ')
+ mapping[alias] = canonical_email
+ return mapping
+
+
+def canonicalize_emails(changelog, mapping):
+ """Takes in a string and an email alias mapping and replaces all
+ instances of the aliases in the string with their real email.
+ """
+ for alias, email in mapping.iteritems():
+ changelog = changelog.replace(alias, email)
+ return changelog
+
+
+# Get requirements from the first file that exists
+def get_reqs_from_files(requirements_files):
+ reqs_in = []
+ for requirements_file in requirements_files:
+ if os.path.exists(requirements_file):
+ return open(requirements_file, 'r').read().split('\n')
+ return []
+
+
+def parse_requirements(requirements_files=['requirements.txt',
+ 'tools/pip-requires']):
+ requirements = []
+ for line in get_reqs_from_files(requirements_files):
+ # For the requirements list, we need to inject only the portion
+ # after egg= so that distutils knows the package it's looking for
+ # such as:
+ # -e git://github.com/openstack/nova/master#egg=nova
+ if re.match(r'\s*-e\s+', line):
+ requirements.append(re.sub(r'\s*-e\s+.*#egg=(.*)$', r'\1',
+ line))
+ # such as:
+ # http://github.com/openstack/nova/zipball/master#egg=nova
+ elif re.match(r'\s*https?:', line):
+ requirements.append(re.sub(r'\s*https?:.*#egg=(.*)$', r'\1',
+ line))
+ # -f lines are for index locations, and don't get used here
+ elif re.match(r'\s*-f\s+', line):
+ pass
+ # argparse is part of the standard library starting with 2.7
+ # adding it to the requirements list screws distro installs
+ elif line == 'argparse' and sys.version_info >= (2, 7):
+ pass
+ else:
+ requirements.append(line)
+
+ return requirements
+
+
+def parse_dependency_links(requirements_files=['requirements.txt',
+ 'tools/pip-requires']):
+ dependency_links = []
+ # dependency_links inject alternate locations to find packages listed
+ # in requirements
+ for line in get_reqs_from_files(requirements_files):
+ # skip comments and blank lines
+ if re.match(r'(\s*#)|(\s*$)', line):
+ continue
+ # lines with -e or -f need the whole line, minus the flag
+ if re.match(r'\s*-[ef]\s+', line):
+ dependency_links.append(re.sub(r'\s*-[ef]\s+', '', line))
+ # lines that are only urls can go in unmolested
+ elif re.match(r'\s*https?:', line):
+ dependency_links.append(line)
+ return dependency_links
+
+
+def write_requirements():
+ venv = os.environ.get('VIRTUAL_ENV', None)
+ if venv is not None:
+ with open("requirements.txt", "w") as req_file:
+ output = subprocess.Popen(["pip", "-E", venv, "freeze", "-l"],
+ stdout=subprocess.PIPE)
+ requirements = output.communicate()[0].strip()
+ req_file.write(requirements)
+
+
+def _run_shell_command(cmd):
+ output = subprocess.Popen(["/bin/sh", "-c", cmd],
+ stdout=subprocess.PIPE)
+ out = output.communicate()
+ if len(out) == 0:
+ return None
+ if len(out[0].strip()) == 0:
+ return None
+ return out[0].strip()
+
+
+def _get_git_next_version_suffix(branch_name):
+ datestamp = datetime.datetime.now().strftime('%Y%m%d')
+ if branch_name == 'milestone-proposed':
+ revno_prefix = "r"
+ else:
+ revno_prefix = ""
+ _run_shell_command("git fetch origin +refs/meta/*:refs/remotes/meta/*")
+ milestone_cmd = "git show meta/openstack/release:%s" % branch_name
+ milestonever = _run_shell_command(milestone_cmd)
+ if not milestonever:
+ milestonever = ""
+ post_version = _get_git_post_version()
+ revno = post_version.split(".")[-1]
+ return "%s~%s.%s%s" % (milestonever, datestamp, revno_prefix, revno)
+
+
+def _get_git_current_tag():
+ return _run_shell_command("git tag --contains HEAD")
+
+
+def _get_git_tag_info():
+ return _run_shell_command("git describe --tags")
+
+
+def _get_git_post_version():
+ current_tag = _get_git_current_tag()
+ if current_tag is not None:
+ return current_tag
+ else:
+ tag_info = _get_git_tag_info()
+ if tag_info is None:
+ base_version = "0.0"
+ cmd = "git --no-pager log --oneline"
+ out = _run_shell_command(cmd)
+ revno = len(out.split("\n"))
+ else:
+ tag_infos = tag_info.split("-")
+ base_version = "-".join(tag_infos[:-2])
+ revno = tag_infos[-2]
+ return "%s.%s" % (base_version, revno)
+
+
+def write_git_changelog():
+ """Write a changelog based on the git changelog."""
+ if os.path.isdir('.git'):
+ git_log_cmd = 'git log --stat'
+ changelog = _run_shell_command(git_log_cmd)
+ mailmap = parse_mailmap()
+ with open("ChangeLog", "w") as changelog_file:
+ changelog_file.write(canonicalize_emails(changelog, mailmap))
+
+
+def generate_authors():
+ """Create AUTHORS file using git commits."""
+ jenkins_email = 'jenkins@review.openstack.org'
+ old_authors = 'AUTHORS.in'
+ new_authors = 'AUTHORS'
+ if os.path.isdir('.git'):
+ # don't include jenkins email address in AUTHORS file
+ git_log_cmd = ("git log --format='%aN <%aE>' | sort -u | "
+ "grep -v " + jenkins_email)
+ changelog = _run_shell_command(git_log_cmd)
+ mailmap = parse_mailmap()
+ with open(new_authors, 'w') as new_authors_fh:
+ new_authors_fh.write(canonicalize_emails(changelog, mailmap))
+ if os.path.exists(old_authors):
+ with open(old_authors, "r") as old_authors_fh:
+ new_authors_fh.write('\n' + old_authors_fh.read())
+
+_rst_template = """%(heading)s
+%(underline)s
+
+.. automodule:: %(module)s
+ :members:
+ :undoc-members:
+ :show-inheritance:
+"""
+
+
+def read_versioninfo(project):
+ """Read the versioninfo file. If it doesn't exist, we're in a github
+ zipball, and there's really know way to know what version we really
+ are, but that should be ok, because the utility of that should be
+ just about nil if this code path is in use in the first place."""
+ versioninfo_path = os.path.join(project, 'versioninfo')
+ if os.path.exists(versioninfo_path):
+ with open(versioninfo_path, 'r') as vinfo:
+ version = vinfo.read().strip()
+ else:
+ version = "0.0.0"
+ return version
+
+
+def write_versioninfo(project, version):
+ """Write a simple file containing the version of the package."""
+ open(os.path.join(project, 'versioninfo'), 'w').write("%s\n" % version)
+
+
+def get_cmdclass():
+ """Return dict of commands to run from setup.py."""
+
+ cmdclass = dict()
+
+ def _find_modules(arg, dirname, files):
+ for filename in files:
+ if filename.endswith('.py') and filename != '__init__.py':
+ arg["%s.%s" % (dirname.replace('/', '.'),
+ filename[:-3])] = True
+
+ class LocalSDist(sdist.sdist):
+ """Builds the ChangeLog and Authors files from VC first."""
+
+ def run(self):
+ write_git_changelog()
+ generate_authors()
+ # sdist.sdist is an old style class, can't use super()
+ sdist.sdist.run(self)
+
+ cmdclass['sdist'] = LocalSDist
+
+ # If Sphinx is installed on the box running setup.py,
+ # enable setup.py to build the documentation, otherwise,
+ # just ignore it
+ try:
+ from sphinx.setup_command import BuildDoc
+
+ class LocalBuildDoc(BuildDoc):
+ def generate_autoindex(self):
+ print "**Autodocumenting from %s" % os.path.abspath(os.curdir)
+ modules = {}
+ option_dict = self.distribution.get_option_dict('build_sphinx')
+ source_dir = os.path.join(option_dict['source_dir'][1], 'api')
+ if not os.path.exists(source_dir):
+ os.makedirs(source_dir)
+ for pkg in self.distribution.packages:
+ if '.' not in pkg:
+ os.path.walk(pkg, _find_modules, modules)
+ module_list = modules.keys()
+ module_list.sort()
+ autoindex_filename = os.path.join(source_dir, 'autoindex.rst')
+ with open(autoindex_filename, 'w') as autoindex:
+ autoindex.write(""".. toctree::
+ :maxdepth: 1
+
+""")
+ for module in module_list:
+ output_filename = os.path.join(source_dir,
+ "%s.rst" % module)
+ heading = "The :mod:`%s` Module" % module
+ underline = "=" * len(heading)
+ values = dict(module=module, heading=heading,
+ underline=underline)
+
+ print "Generating %s" % output_filename
+ with open(output_filename, 'w') as output_file:
+ output_file.write(_rst_template % values)
+ autoindex.write(" %s.rst\n" % module)
+
+ def run(self):
+ if not os.getenv('SPHINX_DEBUG'):
+ self.generate_autoindex()
+
+ for builder in ['html', 'man']:
+ self.builder = builder
+ self.finalize_options()
+ self.project = self.distribution.get_name()
+ self.version = self.distribution.get_version()
+ self.release = self.distribution.get_version()
+ BuildDoc.run(self)
+ cmdclass['build_sphinx'] = LocalBuildDoc
+ except ImportError:
+ pass
+
+ return cmdclass
+
+
+def get_git_branchname():
+ for branch in _run_shell_command("git branch --color=never").split("\n"):
+ if branch.startswith('*'):
+ _branch_name = branch.split()[1].strip()
+ if _branch_name == "(no":
+ _branch_name = "no-branch"
+ return _branch_name
+
+
+def get_pre_version(projectname, base_version):
+ """Return a version which is based"""
+ if os.path.isdir('.git'):
+ current_tag = _get_git_current_tag()
+ if current_tag is not None:
+ version = current_tag
+ else:
+ branch_name = os.getenv('BRANCHNAME',
+ os.getenv('GERRIT_REFNAME',
+ get_git_branchname()))
+ version_suffix = _get_git_next_version_suffix(branch_name)
+ version = "%s~%s" % (base_version, version_suffix)
+ write_versioninfo(projectname, version)
+ return version.split('~')[0]
+ else:
+ version = read_versioninfo(projectname)
+ return version.split('~')[0]
+
+
+def get_post_version(projectname):
+ """Return a version which is equal to the tag that's on the current
+ revision if there is one, or tag plus number of additional revisions
+ if the current revision has no tag."""
+
+ if os.path.isdir('.git'):
+ version = _get_git_post_version()
+ write_versioninfo(projectname, version)
+ return version
+ return read_versioninfo(projectname)
diff --git a/src/leap/soledad/swiftclient/versioninfo b/src/leap/soledad/swiftclient/versioninfo
new file mode 100644
index 00000000..524cb552
--- /dev/null
+++ b/src/leap/soledad/swiftclient/versioninfo
@@ -0,0 +1 @@
+1.1.1
diff --git a/src/leap/soledad/u1db/__init__.py b/src/leap/soledad/u1db/__init__.py
new file mode 100644
index 00000000..ed41bb03
--- /dev/null
+++ b/src/leap/soledad/u1db/__init__.py
@@ -0,0 +1,697 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""U1DB"""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db.errors import InvalidJSON, InvalidContent
+
+__version_info__ = (0, 1, 4)
+__version__ = '.'.join(map(str, __version_info__))
+
+
+def open(path, create, document_factory=None):
+ """Open a database at the given location.
+
+ Will raise u1db.errors.DatabaseDoesNotExist if create=False and the
+ database does not already exist.
+
+ :param path: The filesystem path for the database to open.
+ :param create: True/False, should the database be created if it doesn't
+ already exist?
+ :param document_factory: A function that will be called with the same
+ parameters as Document.__init__.
+ :return: An instance of Database.
+ """
+ from u1db.backends import sqlite_backend
+ return sqlite_backend.SQLiteDatabase.open_database(
+ path, create=create, document_factory=document_factory)
+
+
+# constraints on database names (relevant for remote access, as regex)
+DBNAME_CONSTRAINTS = r"[a-zA-Z0-9][a-zA-Z0-9.-]*"
+
+# constraints on doc ids (as regex)
+# (no slashes, and no characters outside the ascii range)
+DOC_ID_CONSTRAINTS = r"[a-zA-Z0-9.%_-]+"
+
+
+class Database(object):
+ """A JSON Document data store.
+
+ This data store can be synchronized with other u1db.Database instances.
+ """
+
+ def set_document_factory(self, factory):
+ """Set the document factory that will be used to create objects to be
+ returned as documents by the database.
+
+ :param factory: A function that returns an object which at minimum must
+ satisfy the same interface as does the class DocumentBase.
+ Subclassing that class is the easiest way to create such
+ a function.
+ """
+ raise NotImplementedError(self.set_document_factory)
+
+ def set_document_size_limit(self, limit):
+ """Set the maximum allowed document size for this database.
+
+ :param limit: Maximum allowed document size in bytes.
+ """
+ raise NotImplementedError(self.set_document_size_limit)
+
+ def whats_changed(self, old_generation=0):
+ """Return a list of documents that have changed since old_generation.
+ This allows APPS to only store a db generation before going
+ 'offline', and then when coming back online they can use this
+ data to update whatever extra data they are storing.
+
+ :param old_generation: The generation of the database in the old
+ state.
+ :return: (generation, trans_id, [(doc_id, generation, trans_id),...])
+ The current generation of the database, its associated transaction
+ id, and a list of of changed documents since old_generation,
+ represented by tuples with for each document its doc_id and the
+ generation and transaction id corresponding to the last intervening
+ change and sorted by generation (old changes first)
+ """
+ raise NotImplementedError(self.whats_changed)
+
+ def get_doc(self, doc_id, include_deleted=False):
+ """Get the JSON string for the given document.
+
+ :param doc_id: The unique document identifier
+ :param include_deleted: If set to True, deleted documents will be
+ returned with empty content. Otherwise asking for a deleted
+ document will return None.
+ :return: a Document object.
+ """
+ raise NotImplementedError(self.get_doc)
+
+ def get_docs(self, doc_ids, check_for_conflicts=True,
+ include_deleted=False):
+ """Get the JSON content for many documents.
+
+ :param doc_ids: A list of document identifiers.
+ :param check_for_conflicts: If set to False, then the conflict check
+ will be skipped, and 'None' will be returned instead of True/False.
+ :param include_deleted: If set to True, deleted documents will be
+ returned with empty content. Otherwise deleted documents will not
+ be included in the results.
+ :return: iterable giving the Document object for each document id
+ in matching doc_ids order.
+ """
+ raise NotImplementedError(self.get_docs)
+
+ def get_all_docs(self, include_deleted=False):
+ """Get the JSON content for all documents in the database.
+
+ :param include_deleted: If set to True, deleted documents will be
+ returned with empty content. Otherwise deleted documents will not
+ be included in the results.
+ :return: (generation, [Document])
+ The current generation of the database, followed by a list of all
+ the documents in the database.
+ """
+ raise NotImplementedError(self.get_all_docs)
+
+ def create_doc(self, content, doc_id=None):
+ """Create a new document.
+
+ You can optionally specify the document identifier, but the document
+ must not already exist. See 'put_doc' if you want to override an
+ existing document.
+ If the database specifies a maximum document size and the document
+ exceeds it, create will fail and raise a DocumentTooBig exception.
+
+ :param content: A Python dictionary.
+ :param doc_id: An optional identifier specifying the document id.
+ :return: Document
+ """
+ raise NotImplementedError(self.create_doc)
+
+ def create_doc_from_json(self, json, doc_id=None):
+ """Create a new document.
+
+ You can optionally specify the document identifier, but the document
+ must not already exist. See 'put_doc' if you want to override an
+ existing document.
+ If the database specifies a maximum document size and the document
+ exceeds it, create will fail and raise a DocumentTooBig exception.
+
+ :param json: The JSON document string
+ :param doc_id: An optional identifier specifying the document id.
+ :return: Document
+ """
+ raise NotImplementedError(self.create_doc_from_json)
+
+ def put_doc(self, doc):
+ """Update a document.
+ If the document currently has conflicts, put will fail.
+ If the database specifies a maximum document size and the document
+ exceeds it, put will fail and raise a DocumentTooBig exception.
+
+ :param doc: A Document with new content.
+ :return: new_doc_rev - The new revision identifier for the document.
+ The Document object will also be updated.
+ """
+ raise NotImplementedError(self.put_doc)
+
+ def delete_doc(self, doc):
+ """Mark a document as deleted.
+ Will abort if the current revision doesn't match doc.rev.
+ This will also set doc.content to None.
+ """
+ raise NotImplementedError(self.delete_doc)
+
+ def create_index(self, index_name, *index_expressions):
+ """Create an named index, which can then be queried for future lookups.
+ Creating an index which already exists is not an error, and is cheap.
+ Creating an index which does not match the index_expressions of the
+ existing index is an error.
+ Creating an index will block until the expressions have been evaluated
+ and the index generated.
+
+ :param index_name: A unique name which can be used as a key prefix
+ :param index_expressions: index expressions defining the index
+ information.
+
+ Examples:
+
+ "fieldname", or "fieldname.subfieldname" to index alphabetically
+ sorted on the contents of a field.
+
+ "number(fieldname, width)", "lower(fieldname)"
+ """
+ raise NotImplementedError(self.create_index)
+
+ def delete_index(self, index_name):
+ """Remove a named index.
+
+ :param index_name: The name of the index we are removing
+ """
+ raise NotImplementedError(self.delete_index)
+
+ def list_indexes(self):
+ """List the definitions of all known indexes.
+
+ :return: A list of [('index-name', ['field', 'field2'])] definitions.
+ """
+ raise NotImplementedError(self.list_indexes)
+
+ def get_from_index(self, index_name, *key_values):
+ """Return documents that match the keys supplied.
+
+ You must supply exactly the same number of values as have been defined
+ in the index. It is possible to do a prefix match by using '*' to
+ indicate a wildcard match. You can only supply '*' to trailing entries,
+ (eg 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.)
+ It is also possible to append a '*' to the last supplied value (eg
+ 'val*', '*', '*' or 'val', 'val*', '*', but not 'val*', 'val', '*')
+
+ :param index_name: The index to query
+ :param key_values: values to match. eg, if you have
+ an index with 3 fields then you would have:
+ get_from_index(index_name, val1, val2, val3)
+ :return: List of [Document]
+ """
+ raise NotImplementedError(self.get_from_index)
+
+ def get_range_from_index(self, index_name, start_value, end_value):
+ """Return documents that fall within the specified range.
+
+ Both ends of the range are inclusive. For both start_value and
+ end_value, one must supply exactly the same number of values as have
+ been defined in the index, or pass None. In case of a single column
+ index, a string is accepted as an alternative for a tuple with a single
+ value. It is possible to do a prefix match by using '*' to indicate
+ a wildcard match. You can only supply '*' to trailing entries, (eg
+ 'val', '*', '*' is allowed, but '*', 'val', 'val' is not.) It is also
+ possible to append a '*' to the last supplied value (eg 'val*', '*',
+ '*' or 'val', 'val*', '*', but not 'val*', 'val', '*')
+
+ :param index_name: The index to query
+ :param start_values: tuples of values that define the lower bound of
+ the range. eg, if you have an index with 3 fields then you would
+ have: (val1, val2, val3)
+ :param end_values: tuples of values that define the upper bound of the
+ range. eg, if you have an index with 3 fields then you would have:
+ (val1, val2, val3)
+ :return: List of [Document]
+ """
+ raise NotImplementedError(self.get_range_from_index)
+
+ def get_index_keys(self, index_name):
+ """Return all keys under which documents are indexed in this index.
+
+ :param index_name: The index to query
+ :return: [] A list of tuples of indexed keys.
+ """
+ raise NotImplementedError(self.get_index_keys)
+
+ def get_doc_conflicts(self, doc_id):
+ """Get the list of conflicts for the given document.
+
+ The order of the conflicts is such that the first entry is the value
+ that would be returned by "get_doc".
+
+ :return: [doc] A list of the Document entries that are conflicted.
+ """
+ raise NotImplementedError(self.get_doc_conflicts)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ """Mark a document as no longer conflicted.
+
+ We take the list of revisions that the client knows about that it is
+ superseding. This may be a different list from the actual current
+ conflicts, in which case only those are removed as conflicted. This
+ may fail if the conflict list is significantly different from the
+ supplied information. (sync could have happened in the background from
+ the time you GET_DOC_CONFLICTS until the point where you RESOLVE)
+
+ :param doc: A Document with the new content to be inserted.
+ :param conflicted_doc_revs: A list of revisions that the new content
+ supersedes.
+ """
+ raise NotImplementedError(self.resolve_doc)
+
+ def get_sync_target(self):
+ """Return a SyncTarget object, for another u1db to synchronize with.
+
+ :return: An instance of SyncTarget.
+ """
+ raise NotImplementedError(self.get_sync_target)
+
+ def close(self):
+ """Release any resources associated with this database."""
+ raise NotImplementedError(self.close)
+
+ def sync(self, url, creds=None, autocreate=True):
+ """Synchronize documents with remote replica exposed at url.
+
+ :param url: the url of the target replica to sync with.
+ :param creds: optional dictionary giving credentials
+ to authorize the operation with the server. For using OAuth
+ the form of creds is:
+ {'oauth': {
+ 'consumer_key': ...,
+ 'consumer_secret': ...,
+ 'token_key': ...,
+ 'token_secret': ...
+ }}
+ :param autocreate: ask the target to create the db if non-existent.
+ :return: local_gen_before_sync The local generation before the
+ synchronisation was performed. This is useful to pass into
+ whatschanged, if an application wants to know which documents were
+ affected by a synchronisation.
+ """
+ from u1db.sync import Synchronizer
+ from u1db.remote.http_target import HTTPSyncTarget
+ return Synchronizer(self, HTTPSyncTarget(url, creds=creds)).sync(
+ autocreate=autocreate)
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ """Return the last known generation and transaction id for the other db
+ replica.
+
+ When you do a synchronization with another replica, the Database keeps
+ track of what generation the other database replica was at, and what
+ the associated transaction id was. This is used to determine what data
+ needs to be sent, and if two databases are claiming to be the same
+ replica.
+
+ :param other_replica_uid: The identifier for the other replica.
+ :return: (gen, trans_id) The generation and transaction id we
+ encountered during synchronization. If we've never synchronized
+ with the replica, this is (0, '').
+ """
+ raise NotImplementedError(self._get_replica_gen_and_trans_id)
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ """Set the last-known generation and transaction id for the other
+ database replica.
+
+ We have just performed some synchronization, and we want to track what
+ generation the other replica was at. See also
+ _get_replica_gen_and_trans_id.
+ :param other_replica_uid: The U1DB identifier for the other replica.
+ :param other_generation: The generation number for the other replica.
+ :param other_transaction_id: The transaction id associated with the
+ generation.
+ """
+ raise NotImplementedError(self._set_replica_gen_and_trans_id)
+
+ def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen,
+ replica_trans_id=''):
+ """Insert/update document into the database with a given revision.
+
+ This api is used during synchronization operations.
+
+ If a document would conflict and save_conflict is set to True, the
+ content will be selected as the 'current' content for doc.doc_id,
+ even though doc.rev doesn't supersede the currently stored revision.
+ The currently stored document will be added to the list of conflict
+ alternatives for the given doc_id.
+
+ This forces the new content to be 'current' so that we get convergence
+ after synchronizing, even if people don't resolve conflicts. Users can
+ then notice that their content is out of date, update it, and
+ synchronize again. (The alternative is that users could synchronize and
+ think the data has propagated, but their local copy looks fine, and the
+ remote copy is never updated again.)
+
+ :param doc: A Document object
+ :param save_conflict: If this document is a conflict, do you want to
+ save it as a conflict, or just ignore it.
+ :param replica_uid: A unique replica identifier.
+ :param replica_gen: The generation of the replica corresponding to the
+ this document. The replica arguments are optional, but are used
+ during synchronization.
+ :param replica_trans_id: The transaction_id associated with the
+ generation.
+ :return: (state, at_gen) - If we don't have doc_id already,
+ or if doc_rev supersedes the existing document revision,
+ then the content will be inserted, and state is 'inserted'.
+ If doc_rev is less than or equal to the existing revision,
+ then the put is ignored and state is respecitvely 'superseded'
+ or 'converged'.
+ If doc_rev is not strictly superseded or supersedes, then
+ state is 'conflicted'. The document will not be inserted if
+ save_conflict is False.
+ For 'inserted' or 'converged', at_gen is the insertion/current
+ generation.
+ """
+ raise NotImplementedError(self._put_doc_if_newer)
+
+
+class DocumentBase(object):
+ """Container for handling a single document.
+
+ :ivar doc_id: Unique identifier for this document.
+ :ivar rev: The revision identifier of the document.
+ :ivar json_string: The JSON string for this document.
+ :ivar has_conflicts: Boolean indicating if this document has conflicts
+ """
+
+ def __init__(self, doc_id, rev, json_string, has_conflicts=False):
+ self.doc_id = doc_id
+ self.rev = rev
+ if json_string is not None:
+ try:
+ value = json.loads(json_string)
+ except json.JSONDecodeError:
+ raise InvalidJSON
+ if not isinstance(value, dict):
+ raise InvalidJSON
+ self._json = json_string
+ self.has_conflicts = has_conflicts
+
+ def same_content_as(self, other):
+ """Compare the content of two documents."""
+ if self._json:
+ c1 = json.loads(self._json)
+ else:
+ c1 = None
+ if other._json:
+ c2 = json.loads(other._json)
+ else:
+ c2 = None
+ return c1 == c2
+
+ def __repr__(self):
+ if self.has_conflicts:
+ extra = ', conflicted'
+ else:
+ extra = ''
+ return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id,
+ self.rev, extra, self.get_json())
+
+ def __hash__(self):
+ raise NotImplementedError(self.__hash__)
+
+ def __eq__(self, other):
+ if not isinstance(other, Document):
+ return NotImplemented
+ return (
+ self.doc_id == other.doc_id and self.rev == other.rev and
+ self.same_content_as(other) and self.has_conflicts ==
+ other.has_conflicts)
+
+ def __lt__(self, other):
+ """This is meant for testing, not part of the official api.
+
+ It is implemented so that sorted([Document, Document]) can be used.
+ It doesn't imply that users would want their documents to be sorted in
+ this order.
+ """
+ # Since this is just for testing, we don't worry about comparing
+ # against things that aren't a Document.
+ return ((self.doc_id, self.rev, self.get_json())
+ < (other.doc_id, other.rev, other.get_json()))
+
+ def get_json(self):
+ """Get the json serialization of this document."""
+ if self._json is not None:
+ return self._json
+ return None
+
+ def get_size(self):
+ """Calculate the total size of the document."""
+ size = 0
+ json = self.get_json()
+ if json:
+ size += len(json)
+ if self.rev:
+ size += len(self.rev)
+ if self.doc_id:
+ size += len(self.doc_id)
+ return size
+
+ def set_json(self, json_string):
+ """Set the json serialization of this document."""
+ if json_string is not None:
+ try:
+ value = json.loads(json_string)
+ except json.JSONDecodeError:
+ raise InvalidJSON
+ if not isinstance(value, dict):
+ raise InvalidJSON
+ self._json = json_string
+
+ def make_tombstone(self):
+ """Make this document into a tombstone."""
+ self._json = None
+
+ def is_tombstone(self):
+ """Return True if the document is a tombstone, False otherwise."""
+ if self._json is not None:
+ return False
+ return True
+
+
+class Document(DocumentBase):
+ """Container for handling a single document.
+
+ :ivar doc_id: Unique identifier for this document.
+ :ivar rev: The revision identifier of the document.
+ :ivar json: The JSON string for this document.
+ :ivar has_conflicts: Boolean indicating if this document has conflicts
+ """
+
+ # The following part of the API is optional: no implementation is forced to
+ # have it but if the language supports dictionaries/hashtables, it makes
+ # Documents a lot more user friendly.
+
+ def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False):
+ # TODO: We convert the json in the superclass to check its validity so
+ # we might as well set _content here directly since the price is
+ # already being paid.
+ super(Document, self).__init__(doc_id, rev, json, has_conflicts)
+ self._content = None
+
+ def same_content_as(self, other):
+ """Compare the content of two documents."""
+ if self._json:
+ c1 = json.loads(self._json)
+ else:
+ c1 = self._content
+ if other._json:
+ c2 = json.loads(other._json)
+ else:
+ c2 = other._content
+ return c1 == c2
+
+ def get_json(self):
+ """Get the json serialization of this document."""
+ json_string = super(Document, self).get_json()
+ if json_string is not None:
+ return json_string
+ if self._content is not None:
+ return json.dumps(self._content)
+ return None
+
+ def set_json(self, json):
+ """Set the json serialization of this document."""
+ self._content = None
+ super(Document, self).set_json(json)
+
+ def make_tombstone(self):
+ """Make this document into a tombstone."""
+ self._content = None
+ super(Document, self).make_tombstone()
+
+ def is_tombstone(self):
+ """Return True if the document is a tombstone, False otherwise."""
+ if self._content is not None:
+ return False
+ return super(Document, self).is_tombstone()
+
+ def _get_content(self):
+ """Get the dictionary representing this document."""
+ if self._json is not None:
+ self._content = json.loads(self._json)
+ self._json = None
+ if self._content is not None:
+ return self._content
+ return None
+
+ def _set_content(self, content):
+ """Set the dictionary representing this document."""
+ try:
+ tmp = json.dumps(content)
+ except TypeError:
+ raise InvalidContent(
+ "Can not be converted to JSON: %r" % (content,))
+ if not tmp.startswith('{'):
+ raise InvalidContent(
+ "Can not be converted to a JSON object: %r." % (content,))
+ # We might as well store the JSON at this point since we did the work
+ # of encoding it, and it doesn't lose any information.
+ self._json = tmp
+ self._content = None
+
+ content = property(
+ _get_content, _set_content, doc="Content of the Document.")
+
+ # End of optional part.
+
+
+class SyncTarget(object):
+ """Functionality for using a Database as a synchronization target."""
+
+ def get_sync_info(self, source_replica_uid):
+ """Return information about known state.
+
+ Return the replica_uid and the current database generation of this
+ database, and the last-seen database generation for source_replica_uid
+
+ :param source_replica_uid: Another replica which we might have
+ synchronized with in the past.
+ :return: (target_replica_uid, target_replica_generation,
+ target_trans_id, source_replica_last_known_generation,
+ source_replica_last_known_transaction_id)
+ """
+ raise NotImplementedError(self.get_sync_info)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_replica_transaction_id):
+ """Record tip information for another replica.
+
+ After sync_exchange has been processed, the caller will have
+ received new content from this replica. This call allows the
+ source replica instigating the sync to inform us what their
+ generation became after applying the documents we returned.
+
+ This is used to allow future sync operations to not need to repeat data
+ that we just talked about. It also means that if this is called at the
+ wrong time, there can be database records that will never be
+ synchronized.
+
+ :param source_replica_uid: The identifier for the source replica.
+ :param source_replica_generation:
+ The database generation for the source replica.
+ :param source_replica_transaction_id: The transaction id associated
+ with the source replica generation.
+ """
+ raise NotImplementedError(self.record_sync_info)
+
+ def sync_exchange(self, docs_by_generation, source_replica_uid,
+ last_known_generation, last_known_trans_id,
+ return_doc_cb, ensure_callback=None):
+ """Incorporate the documents sent from the source replica.
+
+ This is not meant to be called by client code directly, but is used as
+ part of sync().
+
+ This adds docs to the local store, and determines documents that need
+ to be returned to the source replica.
+
+ Documents must be supplied in docs_by_generation paired with
+ the generation of their latest change in order from the oldest
+ change to the newest, that means from the oldest generation to
+ the newest.
+
+ Documents are also returned paired with the generation of
+ their latest change in order from the oldest change to the
+ newest.
+
+ :param docs_by_generation: A list of [(Document, generation,
+ transaction_id)] tuples indicating documents which should be
+ updated on this replica paired with the generation and transaction
+ id of their latest change.
+ :param source_replica_uid: The source replica's identifier
+ :param last_known_generation: The last generation that the source
+ replica knows about this target replica
+ :param last_known_trans_id: The last transaction id that the source
+ replica knows about this target replica
+ :param: return_doc_cb(doc, gen): is a callback
+ used to return documents to the source replica, it will
+ be invoked in turn with Documents that have changed since
+ last_known_generation together with the generation of
+ their last change.
+ :param: ensure_callback(replica_uid): if set the target may create
+ the target db if not yet existent, the callback can then
+ be used to inform of the created db replica uid.
+ :return: new_generation - After applying docs_by_generation, this is
+ the current generation for this replica
+ """
+ raise NotImplementedError(self.sync_exchange)
+
+ def _set_trace_hook(self, cb):
+ """Set a callback that will be invoked to trace database actions.
+
+ The callback will be passed a string indicating the current state, and
+ the sync target object. Implementations do not have to implement this
+ api, it is used by the test suite.
+
+ :param cb: A callable that takes cb(state)
+ """
+ raise NotImplementedError(self._set_trace_hook)
+
+ def _set_trace_hook_shallow(self, cb):
+ """Set a callback that will be invoked to trace database actions.
+
+ Similar to _set_trace_hook, for implementations that don't offer
+ state changes from the inner working of sync_exchange().
+
+ :param cb: A callable that takes cb(state)
+ """
+ self._set_trace_hook(cb)
diff --git a/src/leap/soledad/u1db/backends/__init__.py b/src/leap/soledad/u1db/backends/__init__.py
new file mode 100644
index 00000000..c8e5adc6
--- /dev/null
+++ b/src/leap/soledad/u1db/backends/__init__.py
@@ -0,0 +1,211 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Abstract classes and common implementations for the backends."""
+
+import re
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import uuid
+
+import u1db
+from u1db import (
+ errors,
+)
+import u1db.sync
+from u1db.vectorclock import VectorClockRev
+
+
+check_doc_id_re = re.compile("^" + u1db.DOC_ID_CONSTRAINTS + "$", re.UNICODE)
+
+
+class CommonSyncTarget(u1db.sync.LocalSyncTarget):
+ pass
+
+
+class CommonBackend(u1db.Database):
+
+ document_size_limit = 0
+
+ def _allocate_doc_id(self):
+ """Generate a unique identifier for this document."""
+ return 'D-' + uuid.uuid4().hex # 'D-' stands for document
+
+ def _allocate_transaction_id(self):
+ return 'T-' + uuid.uuid4().hex # 'T-' stands for transaction
+
+ def _allocate_doc_rev(self, old_doc_rev):
+ vcr = VectorClockRev(old_doc_rev)
+ vcr.increment(self._replica_uid)
+ return vcr.as_str()
+
+ def _check_doc_id(self, doc_id):
+ if not check_doc_id_re.match(doc_id):
+ raise errors.InvalidDocId()
+
+ def _check_doc_size(self, doc):
+ if not self.document_size_limit:
+ return
+ if doc.get_size() > self.document_size_limit:
+ raise errors.DocumentTooBig
+
+ def _get_generation(self):
+ """Return the current generation.
+
+ """
+ raise NotImplementedError(self._get_generation)
+
+ def _get_generation_info(self):
+ """Return the current generation and transaction id.
+
+ """
+ raise NotImplementedError(self._get_generation_info)
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """Extract the document from storage.
+
+ This can return None if the document doesn't exist.
+ """
+ raise NotImplementedError(self._get_doc)
+
+ def _has_conflicts(self, doc_id):
+ """Return True if the doc has conflicts, False otherwise."""
+ raise NotImplementedError(self._has_conflicts)
+
+ def create_doc(self, content, doc_id=None):
+ json_string = json.dumps(content)
+ if doc_id is None:
+ doc_id = self._allocate_doc_id()
+ doc = self._factory(doc_id, None, json_string)
+ self.put_doc(doc)
+ return doc
+
+ def create_doc_from_json(self, json, doc_id=None):
+ if doc_id is None:
+ doc_id = self._allocate_doc_id()
+ doc = self._factory(doc_id, None, json)
+ self.put_doc(doc)
+ return doc
+
+ def _get_transaction_log(self):
+ """This is only for the test suite, it is not part of the api."""
+ raise NotImplementedError(self._get_transaction_log)
+
+ def _put_and_update_indexes(self, doc_id, old_doc, new_rev, content):
+ raise NotImplementedError(self._put_and_update_indexes)
+
+ def get_docs(self, doc_ids, check_for_conflicts=True,
+ include_deleted=False):
+ for doc_id in doc_ids:
+ doc = self._get_doc(
+ doc_id, check_for_conflicts=check_for_conflicts)
+ if doc.is_tombstone() and not include_deleted:
+ continue
+ yield doc
+
+ def _get_trans_id_for_gen(self, generation):
+ """Get the transaction id corresponding to a particular generation.
+
+ Raises an InvalidGeneration when the generation does not exist.
+
+ """
+ raise NotImplementedError(self._get_trans_id_for_gen)
+
+ def validate_gen_and_trans_id(self, generation, trans_id):
+ """Validate the generation and transaction id.
+
+ Raises an InvalidGeneration when the generation does not exist, and an
+ InvalidTransactionId when it does but with a different transaction id.
+
+ """
+ if generation == 0:
+ return
+ known_trans_id = self._get_trans_id_for_gen(generation)
+ if known_trans_id != trans_id:
+ raise errors.InvalidTransactionId
+
+ def _validate_source(self, other_replica_uid, other_generation,
+ other_transaction_id):
+ """Validate the new generation and transaction id.
+
+ other_generation must be greater than what we have stored for this
+ replica, *or* it must be the same and the transaction_id must be the
+ same as well.
+ """
+ (old_generation,
+ old_transaction_id) = self._get_replica_gen_and_trans_id(
+ other_replica_uid)
+ if other_generation < old_generation:
+ raise errors.InvalidGeneration
+ if other_generation > old_generation:
+ return
+ if other_transaction_id == old_transaction_id:
+ return
+ raise errors.InvalidTransactionId
+
+ def _put_doc_if_newer(self, doc, save_conflict, replica_uid, replica_gen,
+ replica_trans_id=''):
+ cur_doc = self._get_doc(doc.doc_id)
+ doc_vcr = VectorClockRev(doc.rev)
+ if cur_doc is None:
+ cur_vcr = VectorClockRev(None)
+ else:
+ cur_vcr = VectorClockRev(cur_doc.rev)
+ self._validate_source(replica_uid, replica_gen, replica_trans_id)
+ if doc_vcr.is_newer(cur_vcr):
+ rev = doc.rev
+ self._prune_conflicts(doc, doc_vcr)
+ if doc.rev != rev:
+ # conflicts have been autoresolved
+ state = 'superseded'
+ else:
+ state = 'inserted'
+ self._put_and_update_indexes(cur_doc, doc)
+ elif doc.rev == cur_doc.rev:
+ # magical convergence
+ state = 'converged'
+ elif cur_vcr.is_newer(doc_vcr):
+ # Don't add this to seen_ids, because we have something newer,
+ # so we should send it back, and we should not generate a
+ # conflict
+ state = 'superseded'
+ elif cur_doc.same_content_as(doc):
+ # the documents have been edited to the same thing at both ends
+ doc_vcr.maximize(cur_vcr)
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ self._put_and_update_indexes(cur_doc, doc)
+ state = 'superseded'
+ else:
+ state = 'conflicted'
+ if save_conflict:
+ self._force_doc_sync_conflict(doc)
+ if replica_uid is not None and replica_gen is not None:
+ self._do_set_replica_gen_and_trans_id(
+ replica_uid, replica_gen, replica_trans_id)
+ return state, self._get_generation()
+
+ def _ensure_maximal_rev(self, cur_rev, extra_revs):
+ vcr = VectorClockRev(cur_rev)
+ for rev in extra_revs:
+ vcr.maximize(VectorClockRev(rev))
+ vcr.increment(self._replica_uid)
+ return vcr.as_str()
+
+ def set_document_size_limit(self, limit):
+ self.document_size_limit = limit
diff --git a/src/leap/soledad/u1db/backends/dbschema.sql b/src/leap/soledad/u1db/backends/dbschema.sql
new file mode 100644
index 00000000..ae027fc5
--- /dev/null
+++ b/src/leap/soledad/u1db/backends/dbschema.sql
@@ -0,0 +1,42 @@
+-- Database schema
+CREATE TABLE transaction_log (
+ generation INTEGER PRIMARY KEY AUTOINCREMENT,
+ doc_id TEXT NOT NULL,
+ transaction_id TEXT NOT NULL
+);
+CREATE TABLE document (
+ doc_id TEXT PRIMARY KEY,
+ doc_rev TEXT NOT NULL,
+ content TEXT
+);
+CREATE TABLE document_fields (
+ doc_id TEXT NOT NULL,
+ field_name TEXT NOT NULL,
+ value TEXT
+);
+CREATE INDEX document_fields_field_value_doc_idx
+ ON document_fields(field_name, value, doc_id);
+
+CREATE TABLE sync_log (
+ replica_uid TEXT PRIMARY KEY,
+ known_generation INTEGER,
+ known_transaction_id TEXT
+);
+CREATE TABLE conflicts (
+ doc_id TEXT,
+ doc_rev TEXT,
+ content TEXT,
+ CONSTRAINT conflicts_pkey PRIMARY KEY (doc_id, doc_rev)
+);
+CREATE TABLE index_definitions (
+ name TEXT,
+ offset INT,
+ field TEXT,
+ CONSTRAINT index_definitions_pkey PRIMARY KEY (name, offset)
+);
+create index index_definitions_field on index_definitions(field);
+CREATE TABLE u1db_config (
+ name TEXT PRIMARY KEY,
+ value TEXT
+);
+INSERT INTO u1db_config VALUES ('sql_schema', '0');
diff --git a/src/leap/soledad/u1db/backends/inmemory.py b/src/leap/soledad/u1db/backends/inmemory.py
new file mode 100644
index 00000000..a271bb37
--- /dev/null
+++ b/src/leap/soledad/u1db/backends/inmemory.py
@@ -0,0 +1,469 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The in-memory Database class for U1DB."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ Document,
+ errors,
+ query_parser,
+ vectorclock,
+ )
+from u1db.backends import CommonBackend, CommonSyncTarget
+
+
+def get_prefix(value):
+ key_prefix = '\x01'.join(value)
+ return key_prefix.rstrip('*')
+
+
+class InMemoryDatabase(CommonBackend):
+ """A database that only stores the data internally."""
+
+ def __init__(self, replica_uid, document_factory=None):
+ self._transaction_log = []
+ self._docs = {}
+ # Map from doc_id => [(doc_rev, doc)] conflicts beyond 'winner'
+ self._conflicts = {}
+ self._other_generations = {}
+ self._indexes = {}
+ self._replica_uid = replica_uid
+ self._factory = document_factory or Document
+
+ def _set_replica_uid(self, replica_uid):
+ """Force the replica_uid to be set."""
+ self._replica_uid = replica_uid
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def close(self):
+ # This is a no-op, We don't want to free the data because one client
+ # may be closing it, while another wants to inspect the results.
+ pass
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ return self._other_generations.get(other_replica_uid, (0, ''))
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ self._do_set_replica_gen_and_trans_id(
+ other_replica_uid, other_generation, other_transaction_id)
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ # TODO: to handle race conditions, we may want to check if the current
+ # value is greater than this new value.
+ self._other_generations[other_replica_uid] = (other_generation,
+ other_transaction_id)
+
+ def get_sync_target(self):
+ return InMemorySyncTarget(self)
+
+ def _get_transaction_log(self):
+ # snapshot!
+ return self._transaction_log[:]
+
+ def _get_generation(self):
+ return len(self._transaction_log)
+
+ def _get_generation_info(self):
+ if not self._transaction_log:
+ return 0, ''
+ return len(self._transaction_log), self._transaction_log[-1][1]
+
+ def _get_trans_id_for_gen(self, generation):
+ if generation == 0:
+ return ''
+ if generation > len(self._transaction_log):
+ raise errors.InvalidGeneration
+ return self._transaction_log[generation - 1][1]
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ self._check_doc_id(doc.doc_id)
+ self._check_doc_size(doc)
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc and old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ if old_doc and doc.rev is None and old_doc.is_tombstone():
+ new_rev = self._allocate_doc_rev(old_doc.rev)
+ else:
+ if old_doc is not None:
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ else:
+ if doc.rev is not None:
+ raise errors.RevisionConflict()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ for index in self._indexes.itervalues():
+ if old_doc is not None and not old_doc.is_tombstone():
+ index.remove_json(old_doc.doc_id, old_doc.get_json())
+ if not doc.is_tombstone():
+ index.add_json(doc.doc_id, doc.get_json())
+ trans_id = self._allocate_transaction_id()
+ self._docs[doc.doc_id] = (doc.rev, doc.get_json())
+ self._transaction_log.append((doc.doc_id, trans_id))
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ try:
+ doc_rev, content = self._docs[doc_id]
+ except KeyError:
+ return None
+ doc = self._factory(doc_id, doc_rev, content)
+ if check_for_conflicts:
+ doc.has_conflicts = (doc.doc_id in self._conflicts)
+ return doc
+
+ def _has_conflicts(self, doc_id):
+ return doc_id in self._conflicts
+
+ def get_doc(self, doc_id, include_deleted=False):
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc is None:
+ return None
+ if doc.is_tombstone() and not include_deleted:
+ return None
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Return all documents in the database."""
+ generation = self._get_generation()
+ results = []
+ for doc_id, (doc_rev, content) in self._docs.items():
+ if content is None and not include_deleted:
+ continue
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = self._has_conflicts(doc_id)
+ results.append(doc)
+ return (generation, results)
+
+ def get_doc_conflicts(self, doc_id):
+ if doc_id not in self._conflicts:
+ return []
+ result = [self._get_doc(doc_id)]
+ result[0].has_conflicts = True
+ result.extend([self._factory(doc_id, rev, content)
+ for rev, content in self._conflicts[doc_id]])
+ return result
+
+ def _replace_conflicts(self, doc, conflicts):
+ if not conflicts:
+ del self._conflicts[doc.doc_id]
+ else:
+ self._conflicts[doc.doc_id] = conflicts
+ doc.has_conflicts = bool(conflicts)
+
+ def _prune_conflicts(self, doc, doc_vcr):
+ if self._has_conflicts(doc.doc_id):
+ autoresolved = False
+ remaining_conflicts = []
+ cur_conflicts = self._conflicts[doc.doc_id]
+ for c_rev, c_doc in cur_conflicts:
+ c_vcr = vectorclock.VectorClockRev(c_rev)
+ if doc_vcr.is_newer(c_vcr):
+ continue
+ if doc.same_content_as(Document(doc.doc_id, c_rev, c_doc)):
+ doc_vcr.maximize(c_vcr)
+ autoresolved = True
+ continue
+ remaining_conflicts.append((c_rev, c_doc))
+ if autoresolved:
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ self._replace_conflicts(doc, remaining_conflicts)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ cur_doc = self._get_doc(doc.doc_id)
+ if cur_doc is None:
+ cur_rev = None
+ else:
+ cur_rev = cur_doc.rev
+ new_rev = self._ensure_maximal_rev(cur_rev, conflicted_doc_revs)
+ superseded_revs = set(conflicted_doc_revs)
+ remaining_conflicts = []
+ cur_conflicts = self._conflicts[doc.doc_id]
+ for c_rev, c_doc in cur_conflicts:
+ if c_rev in superseded_revs:
+ continue
+ remaining_conflicts.append((c_rev, c_doc))
+ doc.rev = new_rev
+ if cur_rev in superseded_revs:
+ self._put_and_update_indexes(cur_doc, doc)
+ else:
+ remaining_conflicts.append((new_rev, doc.get_json()))
+ self._replace_conflicts(doc, remaining_conflicts)
+
+ def delete_doc(self, doc):
+ if doc.doc_id not in self._docs:
+ raise errors.DocumentDoesNotExist
+ if self._docs[doc.doc_id][1] in ('null', None):
+ raise errors.DocumentAlreadyDeleted
+ doc.make_tombstone()
+ self.put_doc(doc)
+
+ def create_index(self, index_name, *index_expressions):
+ if index_name in self._indexes:
+ if self._indexes[index_name]._definition == list(
+ index_expressions):
+ return
+ raise errors.IndexNameTakenError
+ index = InMemoryIndex(index_name, list(index_expressions))
+ for doc_id, (doc_rev, doc) in self._docs.iteritems():
+ if doc is not None:
+ index.add_json(doc_id, doc)
+ self._indexes[index_name] = index
+
+ def delete_index(self, index_name):
+ del self._indexes[index_name]
+
+ def list_indexes(self):
+ definitions = []
+ for idx in self._indexes.itervalues():
+ definitions.append((idx._name, idx._definition))
+ return definitions
+
+ def get_from_index(self, index_name, *key_values):
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ doc_ids = index.lookup(key_values)
+ result = []
+ for doc_id in doc_ids:
+ result.append(self._get_doc(doc_id, check_for_conflicts=True))
+ return result
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ """Return all documents with key values in the specified range."""
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ doc_ids = index.lookup_range(start_value, end_value)
+ result = []
+ for doc_id in doc_ids:
+ result.append(self._get_doc(doc_id, check_for_conflicts=True))
+ return result
+
+ def get_index_keys(self, index_name):
+ try:
+ index = self._indexes[index_name]
+ except KeyError:
+ raise errors.IndexDoesNotExist
+ keys = index.keys()
+ # XXX inefficiency warning
+ return list(set([tuple(key.split('\x01')) for key in keys]))
+
+ def whats_changed(self, old_generation=0):
+ changes = []
+ relevant_tail = self._transaction_log[old_generation:]
+ # We don't use len(self._transaction_log) because _transaction_log may
+ # get mutated by a concurrent operation.
+ cur_generation = old_generation + len(relevant_tail)
+ last_trans_id = ''
+ if relevant_tail:
+ last_trans_id = relevant_tail[-1][1]
+ elif self._transaction_log:
+ last_trans_id = self._transaction_log[-1][1]
+ seen = set()
+ generation = cur_generation
+ for doc_id, trans_id in reversed(relevant_tail):
+ if doc_id not in seen:
+ changes.append((doc_id, generation, trans_id))
+ seen.add(doc_id)
+ generation -= 1
+ changes.reverse()
+ return (cur_generation, last_trans_id, changes)
+
+ def _force_doc_sync_conflict(self, doc):
+ my_doc = self._get_doc(doc.doc_id)
+ self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
+ self._conflicts.setdefault(doc.doc_id, []).append(
+ (my_doc.rev, my_doc.get_json()))
+ doc.has_conflicts = True
+ self._put_and_update_indexes(my_doc, doc)
+
+
+class InMemoryIndex(object):
+ """Interface for managing an Index."""
+
+ def __init__(self, index_name, index_definition):
+ self._name = index_name
+ self._definition = index_definition
+ self._values = {}
+ parser = query_parser.Parser()
+ self._getters = parser.parse_all(self._definition)
+
+ def evaluate_json(self, doc):
+ """Determine the 'key' after applying this index to the doc."""
+ raw = json.loads(doc)
+ return self.evaluate(raw)
+
+ def evaluate(self, obj):
+ """Evaluate a dict object, applying this definition."""
+ all_rows = [[]]
+ for getter in self._getters:
+ new_rows = []
+ keys = getter.get(obj)
+ if not keys:
+ return []
+ for key in keys:
+ new_rows.extend([row + [key] for row in all_rows])
+ all_rows = new_rows
+ all_rows = ['\x01'.join(row) for row in all_rows]
+ return all_rows
+
+ def add_json(self, doc_id, doc):
+ """Add this json doc to the index."""
+ keys = self.evaluate_json(doc)
+ if not keys:
+ return
+ for key in keys:
+ self._values.setdefault(key, []).append(doc_id)
+
+ def remove_json(self, doc_id, doc):
+ """Remove this json doc from the index."""
+ keys = self.evaluate_json(doc)
+ if keys:
+ for key in keys:
+ doc_ids = self._values[key]
+ doc_ids.remove(doc_id)
+ if not doc_ids:
+ del self._values[key]
+
+ def _find_non_wildcards(self, values):
+ """Check if this should be a wildcard match.
+
+ Further, this will raise an exception if the syntax is improperly
+ defined.
+
+ :return: The offset of the last value we need to match against.
+ """
+ if len(values) != len(self._definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ last = 0
+ for idx, val in enumerate(values):
+ if val.endswith('*'):
+ if val != '*':
+ # We have an 'x*' style wildcard
+ if is_wildcard:
+ # We were already in wildcard mode, so this is invalid
+ raise errors.InvalidGlobbing
+ last = idx + 1
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ # We were in wildcard mode, we can't follow that with
+ # non-wildcard
+ raise errors.InvalidGlobbing
+ last = idx + 1
+ if not is_wildcard:
+ return -1
+ return last
+
+ def lookup(self, values):
+ """Find docs that match the values."""
+ last = self._find_non_wildcards(values)
+ if last == -1:
+ return self._lookup_exact(values)
+ else:
+ return self._lookup_prefix(values[:last])
+
+ def lookup_range(self, start_values, end_values):
+ """Find docs within the range."""
+ # TODO: Wildly inefficient, which is unlikely to be a problem for the
+ # inmemory implementation.
+ if start_values:
+ self._find_non_wildcards(start_values)
+ start_values = get_prefix(start_values)
+ if end_values:
+ if self._find_non_wildcards(end_values) == -1:
+ exact = True
+ else:
+ exact = False
+ end_values = get_prefix(end_values)
+ found = []
+ for key, doc_ids in sorted(self._values.iteritems()):
+ if start_values and start_values > key:
+ continue
+ if end_values and end_values < key:
+ if exact:
+ break
+ else:
+ if not key.startswith(end_values):
+ break
+ found.extend(doc_ids)
+ return found
+
+ def keys(self):
+ """Find the indexed keys."""
+ return self._values.keys()
+
+ def _lookup_prefix(self, value):
+ """Find docs that match the prefix string in values."""
+ # TODO: We need a different data structure to make prefix style fast,
+ # some sort of sorted list would work, but a plain dict doesn't.
+ key_prefix = get_prefix(value)
+ all_doc_ids = []
+ for key, doc_ids in sorted(self._values.iteritems()):
+ if key.startswith(key_prefix):
+ all_doc_ids.extend(doc_ids)
+ return all_doc_ids
+
+ def _lookup_exact(self, value):
+ """Find docs that match exactly."""
+ key = '\x01'.join(value)
+ if key in self._values:
+ return self._values[key]
+ return ()
+
+
+class InMemorySyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_transaction_id)
diff --git a/src/leap/soledad/u1db/backends/sqlite_backend.py b/src/leap/soledad/u1db/backends/sqlite_backend.py
new file mode 100644
index 00000000..773213b5
--- /dev/null
+++ b/src/leap/soledad/u1db/backends/sqlite_backend.py
@@ -0,0 +1,926 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""A U1DB implementation that uses SQLite as its persistence layer."""
+
+import errno
+import os
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from sqlite3 import dbapi2
+import sys
+import time
+import uuid
+
+import pkg_resources
+
+from u1db.backends import CommonBackend, CommonSyncTarget
+from u1db import (
+ Document,
+ errors,
+ query_parser,
+ vectorclock,
+ )
+
+
+class SQLiteDatabase(CommonBackend):
+ """A U1DB implementation that uses SQLite as its persistence layer."""
+
+ _sqlite_registry = {}
+
+ def __init__(self, sqlite_file, document_factory=None):
+ """Create a new sqlite file."""
+ self._db_handle = dbapi2.connect(sqlite_file)
+ self._real_replica_uid = None
+ self._ensure_schema()
+ self._factory = document_factory or Document
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ def get_sync_target(self):
+ return SQLiteSyncTarget(self)
+
+ @classmethod
+ def _which_index_storage(cls, c):
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'index_storage'")
+ except dbapi2.OperationalError, e:
+ # The table does not exist yet
+ return None, e
+ else:
+ return c.fetchone()[0], None
+
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.5
+
+ @classmethod
+ def _open_database(cls, sqlite_file, document_factory=None):
+ if not os.path.isfile(sqlite_file):
+ raise errors.DatabaseDoesNotExist()
+ tries = 2
+ while True:
+ # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6)
+ # where without re-opening the database on Windows, it
+ # doesn't see the transaction that was just committed
+ db_handle = dbapi2.connect(sqlite_file)
+ c = db_handle.cursor()
+ v, err = cls._which_index_storage(c)
+ db_handle.close()
+ if v is not None:
+ break
+ # possibly another process is initializing it, wait for it to be
+ # done
+ if tries == 0:
+ raise err # go for the richest error?
+ tries -= 1
+ time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL)
+ return SQLiteDatabase._sqlite_registry[v](
+ sqlite_file, document_factory=document_factory)
+
+ @classmethod
+ def open_database(cls, sqlite_file, create, backend_cls=None,
+ document_factory=None):
+ try:
+ return cls._open_database(
+ sqlite_file, document_factory=document_factory)
+ except errors.DatabaseDoesNotExist:
+ if not create:
+ raise
+ if backend_cls is None:
+ # default is SQLitePartialExpandDatabase
+ backend_cls = SQLitePartialExpandDatabase
+ return backend_cls(sqlite_file, document_factory=document_factory)
+
+ @staticmethod
+ def delete_database(sqlite_file):
+ try:
+ os.unlink(sqlite_file)
+ except OSError as ex:
+ if ex.errno == errno.ENOENT:
+ raise errors.DatabaseDoesNotExist()
+ raise
+
+ @staticmethod
+ def register_implementation(klass):
+ """Register that we implement an SQLiteDatabase.
+
+ The attribute _index_storage_value will be used as the lookup key.
+ """
+ SQLiteDatabase._sqlite_registry[klass._index_storage_value] = klass
+
+ def _get_sqlite_handle(self):
+ """Get access to the underlying sqlite database.
+
+ This should only be used by the test suite, etc, for examining the
+ state of the underlying database.
+ """
+ return self._db_handle
+
+ def _close_sqlite_handle(self):
+ """Release access to the underlying sqlite database."""
+ self._db_handle.close()
+
+ def close(self):
+ self._close_sqlite_handle()
+
+ def _is_initialized(self, c):
+ """Check if this database has been initialized."""
+ c.execute("PRAGMA case_sensitive_like=ON")
+ try:
+ c.execute("SELECT value FROM u1db_config"
+ " WHERE name = 'sql_schema'")
+ except dbapi2.OperationalError:
+ # The table does not exist yet
+ val = None
+ else:
+ val = c.fetchone()
+ if val is not None:
+ return True
+ return False
+
+ def _initialize(self, c):
+ """Create the schema in the database."""
+ #read the script with sql commands
+ # TODO: Change how we set up the dependency. Most likely use something
+ # like lp:dirspec to grab the file from a common resource
+ # directory. Doesn't specifically need to be handled until we get
+ # to the point of packaging this.
+ schema_content = pkg_resources.resource_string(
+ __name__, 'dbschema.sql')
+ # Note: We'd like to use c.executescript() here, but it seems that
+ # executescript always commits, even if you set
+ # isolation_level = None, so if we want to properly handle
+ # exclusive locking and rollbacks between processes, we need
+ # to execute it line-by-line
+ for line in schema_content.split(';'):
+ if not line:
+ continue
+ c.execute(line)
+ #add extra fields
+ self._extra_schema_init(c)
+ # A unique identifier should be set for this replica. Implementations
+ # don't have to strictly use uuid here, but we do want the uid to be
+ # unique amongst all databases that will sync with each other.
+ # We might extend this to using something with hostname for easier
+ # debugging.
+ self._set_replica_uid_in_transaction(uuid.uuid4().hex)
+ c.execute("INSERT INTO u1db_config VALUES" " ('index_storage', ?)",
+ (self._index_storage_value,))
+
+ def _ensure_schema(self):
+ """Ensure that the database schema has been created."""
+ old_isolation_level = self._db_handle.isolation_level
+ c = self._db_handle.cursor()
+ if self._is_initialized(c):
+ return
+ try:
+ # autocommit/own mgmt of transactions
+ self._db_handle.isolation_level = None
+ with self._db_handle:
+ # only one execution path should initialize the db
+ c.execute("begin exclusive")
+ if self._is_initialized(c):
+ return
+ self._initialize(c)
+ finally:
+ self._db_handle.isolation_level = old_isolation_level
+
+ def _extra_schema_init(self, c):
+ """Add any extra fields, etc to the basic table definitions."""
+
+ def _parse_index_definition(self, index_field):
+ """Parse a field definition for an index, returning a Getter."""
+ # Note: We may want to keep a Parser object around, and cache the
+ # Getter objects for a greater length of time. Specifically, if
+ # you create a bunch of indexes, and then insert 50k docs, you'll
+ # re-parse the indexes between puts. The time to insert the docs
+ # is still likely to dominate put_doc time, though.
+ parser = query_parser.Parser()
+ getter = parser.parse(index_field)
+ return getter
+
+ def _update_indexes(self, doc_id, raw_doc, getters, db_cursor):
+ """Update document_fields for a single document.
+
+ :param doc_id: Identifier for this document
+ :param raw_doc: The python dict representation of the document.
+ :param getters: A list of [(field_name, Getter)]. Getter.get will be
+ called to evaluate the index definition for this document, and the
+ results will be inserted into the db.
+ :param db_cursor: An sqlite Cursor.
+ :return: None
+ """
+ values = []
+ for field_name, getter in getters:
+ for idx_value in getter.get(raw_doc):
+ values.append((doc_id, field_name, idx_value))
+ if values:
+ db_cursor.executemany(
+ "INSERT INTO document_fields VALUES (?, ?, ?)", values)
+
+ def _set_replica_uid(self, replica_uid):
+ """Force the replica_uid to be set."""
+ with self._db_handle:
+ self._set_replica_uid_in_transaction(replica_uid)
+
+ def _set_replica_uid_in_transaction(self, replica_uid):
+ """Set the replica_uid. A transaction should already be held."""
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO u1db_config"
+ " VALUES ('replica_uid', ?)",
+ (replica_uid,))
+ self._real_replica_uid = replica_uid
+
+ def _get_replica_uid(self):
+ if self._real_replica_uid is not None:
+ return self._real_replica_uid
+ c = self._db_handle.cursor()
+ c.execute("SELECT value FROM u1db_config WHERE name = 'replica_uid'")
+ val = c.fetchone()
+ if val is None:
+ return None
+ self._real_replica_uid = val[0]
+ return self._real_replica_uid
+
+ _replica_uid = property(_get_replica_uid)
+
+ def _get_generation(self):
+ c = self._db_handle.cursor()
+ c.execute('SELECT max(generation) FROM transaction_log')
+ val = c.fetchone()[0]
+ if val is None:
+ return 0
+ return val
+
+ def _get_generation_info(self):
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT max(generation), transaction_id FROM transaction_log ')
+ val = c.fetchone()
+ if val[0] is None:
+ return(0, '')
+ return val
+
+ def _get_trans_id_for_gen(self, generation):
+ if generation == 0:
+ return ''
+ c = self._db_handle.cursor()
+ c.execute(
+ 'SELECT transaction_id FROM transaction_log WHERE generation = ?',
+ (generation,))
+ val = c.fetchone()
+ if val is None:
+ raise errors.InvalidGeneration
+ return val[0]
+
+ def _get_transaction_log(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, transaction_id FROM transaction_log"
+ " ORDER BY generation")
+ return c.fetchall()
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """Get just the document content, without fancy handling."""
+ c = self._db_handle.cursor()
+ if check_for_conflicts:
+ c.execute(
+ "SELECT document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN "
+ "conflicts ON conflicts.doc_id = document.doc_id WHERE "
+ "document.doc_id = ? GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;", (doc_id,))
+ else:
+ c.execute(
+ "SELECT doc_rev, content, 0 FROM document WHERE doc_id = ?",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return None
+ doc_rev, content, conflicts = val
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ return doc
+
+ def _has_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT 1 FROM conflicts WHERE doc_id = ? LIMIT 1",
+ (doc_id,))
+ val = c.fetchone()
+ if val is None:
+ return False
+ else:
+ return True
+
+ def get_doc(self, doc_id, include_deleted=False):
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc is None:
+ return None
+ if doc.is_tombstone() and not include_deleted:
+ return None
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Get all documents from the database."""
+ generation = self._get_generation()
+ results = []
+ c = self._db_handle.cursor()
+ c.execute(
+ "SELECT document.doc_id, document.doc_rev, document.content, "
+ "count(conflicts.doc_rev) FROM document LEFT OUTER JOIN conflicts "
+ "ON conflicts.doc_id = document.doc_id GROUP BY document.doc_id, "
+ "document.doc_rev, document.content;")
+ rows = c.fetchall()
+ for doc_id, doc_rev, content, conflicts in rows:
+ if content is None and not include_deleted:
+ continue
+ doc = self._factory(doc_id, doc_rev, content)
+ doc.has_conflicts = conflicts > 0
+ results.append(doc)
+ return (generation, results)
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ self._check_doc_id(doc.doc_id)
+ self._check_doc_size(doc)
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc and old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ if old_doc and doc.rev is None and old_doc.is_tombstone():
+ new_rev = self._allocate_doc_rev(old_doc.rev)
+ else:
+ if old_doc is not None:
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ else:
+ if doc.rev is not None:
+ raise errors.RevisionConflict()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _expand_to_fields(self, doc_id, base_field, raw_doc, save_none):
+ """Convert a dict representation into named fields.
+
+ So something like: {'key1': 'val1', 'key2': 'val2'}
+ gets converted into: [(doc_id, 'key1', 'val1', 0)
+ (doc_id, 'key2', 'val2', 0)]
+ :param doc_id: Just added to every record.
+ :param base_field: if set, these are nested keys, so each field should
+ be appropriately prefixed.
+ :param raw_doc: The python dictionary.
+ """
+ # TODO: Handle lists
+ values = []
+ for field_name, value in raw_doc.iteritems():
+ if value is None and not save_none:
+ continue
+ if base_field:
+ full_name = base_field + '.' + field_name
+ else:
+ full_name = field_name
+ if value is None or isinstance(value, (int, float, basestring)):
+ values.append((doc_id, full_name, value, len(values)))
+ else:
+ subvalues = self._expand_to_fields(doc_id, full_name, value,
+ save_none)
+ for _, subfield_name, val, _ in subvalues:
+ values.append((doc_id, subfield_name, val, len(values)))
+ return values
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ """Actually insert a document into the database.
+
+ This both updates the existing documents content, and any indexes that
+ refer to this document.
+ """
+ raise NotImplementedError(self._put_and_update_indexes)
+
+ def whats_changed(self, old_generation=0):
+ c = self._db_handle.cursor()
+ c.execute("SELECT generation, doc_id, transaction_id"
+ " FROM transaction_log"
+ " WHERE generation > ? ORDER BY generation DESC",
+ (old_generation,))
+ results = c.fetchall()
+ cur_gen = old_generation
+ seen = set()
+ changes = []
+ newest_trans_id = ''
+ for generation, doc_id, trans_id in results:
+ if doc_id not in seen:
+ changes.append((doc_id, generation, trans_id))
+ seen.add(doc_id)
+ if changes:
+ cur_gen = changes[0][1] # max generation
+ newest_trans_id = changes[0][2]
+ changes.reverse()
+ else:
+ c.execute("SELECT generation, transaction_id"
+ " FROM transaction_log ORDER BY generation DESC LIMIT 1")
+ results = c.fetchone()
+ if not results:
+ cur_gen = 0
+ newest_trans_id = ''
+ else:
+ cur_gen, newest_trans_id = results
+
+ return cur_gen, newest_trans_id, changes
+
+ def delete_doc(self, doc):
+ with self._db_handle:
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc is None:
+ raise errors.DocumentDoesNotExist
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ if old_doc.is_tombstone():
+ raise errors.DocumentAlreadyDeleted
+ if old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ doc.make_tombstone()
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ def _get_conflicts(self, doc_id):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_rev, content FROM conflicts WHERE doc_id = ?",
+ (doc_id,))
+ return [self._factory(doc_id, doc_rev, content)
+ for doc_rev, content in c.fetchall()]
+
+ def get_doc_conflicts(self, doc_id):
+ with self._db_handle:
+ conflict_docs = self._get_conflicts(doc_id)
+ if not conflict_docs:
+ return []
+ this_doc = self._get_doc(doc_id)
+ this_doc.has_conflicts = True
+ return [this_doc] + conflict_docs
+
+ def _get_replica_gen_and_trans_id(self, other_replica_uid):
+ c = self._db_handle.cursor()
+ c.execute("SELECT known_generation, known_transaction_id FROM sync_log"
+ " WHERE replica_uid = ?",
+ (other_replica_uid,))
+ val = c.fetchone()
+ if val is None:
+ other_gen = 0
+ trans_id = ''
+ else:
+ other_gen = val[0]
+ trans_id = val[1]
+ return other_gen, trans_id
+
+ def _set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation, other_transaction_id):
+ with self._db_handle:
+ self._do_set_replica_gen_and_trans_id(
+ other_replica_uid, other_generation, other_transaction_id)
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ c = self._db_handle.cursor()
+ c.execute("INSERT OR REPLACE INTO sync_log VALUES (?, ?, ?)",
+ (other_replica_uid, other_generation,
+ other_transaction_id))
+
+ def _put_doc_if_newer(self, doc, save_conflict, replica_uid=None,
+ replica_gen=None, replica_trans_id=None):
+ with self._db_handle:
+ return super(SQLiteDatabase, self)._put_doc_if_newer(doc,
+ save_conflict=save_conflict,
+ replica_uid=replica_uid, replica_gen=replica_gen,
+ replica_trans_id=replica_trans_id)
+
+ def _add_conflict(self, c, doc_id, my_doc_rev, my_content):
+ c.execute("INSERT INTO conflicts VALUES (?, ?, ?)",
+ (doc_id, my_doc_rev, my_content))
+
+ def _delete_conflicts(self, c, doc, conflict_revs):
+ deleting = [(doc.doc_id, c_rev) for c_rev in conflict_revs]
+ c.executemany("DELETE FROM conflicts"
+ " WHERE doc_id=? AND doc_rev=?", deleting)
+ doc.has_conflicts = self._has_conflicts(doc.doc_id)
+
+ def _prune_conflicts(self, doc, doc_vcr):
+ if self._has_conflicts(doc.doc_id):
+ autoresolved = False
+ c_revs_to_prune = []
+ for c_doc in self._get_conflicts(doc.doc_id):
+ c_vcr = vectorclock.VectorClockRev(c_doc.rev)
+ if doc_vcr.is_newer(c_vcr):
+ c_revs_to_prune.append(c_doc.rev)
+ elif doc.same_content_as(c_doc):
+ c_revs_to_prune.append(c_doc.rev)
+ doc_vcr.maximize(c_vcr)
+ autoresolved = True
+ if autoresolved:
+ doc_vcr.increment(self._replica_uid)
+ doc.rev = doc_vcr.as_str()
+ c = self._db_handle.cursor()
+ self._delete_conflicts(c, doc, c_revs_to_prune)
+
+ def _force_doc_sync_conflict(self, doc):
+ my_doc = self._get_doc(doc.doc_id)
+ c = self._db_handle.cursor()
+ self._prune_conflicts(doc, vectorclock.VectorClockRev(doc.rev))
+ self._add_conflict(c, doc.doc_id, my_doc.rev, my_doc.get_json())
+ doc.has_conflicts = True
+ self._put_and_update_indexes(my_doc, doc)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ with self._db_handle:
+ cur_doc = self._get_doc(doc.doc_id)
+ # TODO: https://bugs.launchpad.net/u1db/+bug/928274
+ # I think we have a logic bug in resolve_doc
+ # Specifically, cur_doc.rev is always in the final vector
+ # clock of revisions that we supersede, even if it wasn't in
+ # conflicted_doc_revs. We still add it as a conflict, but the
+ # fact that _put_doc_if_newer propagates resolutions means I
+ # think that conflict could accidentally be resolved. We need
+ # to add a test for this case first. (create a rev, create a
+ # conflict, create another conflict, resolve the first rev
+ # and first conflict, then make sure that the resolved
+ # rev doesn't supersede the second conflict rev.) It *might*
+ # not matter, because the superseding rev is in as a
+ # conflict, but it does seem incorrect
+ new_rev = self._ensure_maximal_rev(cur_doc.rev,
+ conflicted_doc_revs)
+ superseded_revs = set(conflicted_doc_revs)
+ c = self._db_handle.cursor()
+ doc.rev = new_rev
+ if cur_doc.rev in superseded_revs:
+ self._put_and_update_indexes(cur_doc, doc)
+ else:
+ self._add_conflict(c, doc.doc_id, new_rev, doc.get_json())
+ # TODO: Is there some way that we could construct a rev that would
+ # end up in superseded_revs, such that we add a conflict, and
+ # then immediately delete it?
+ self._delete_conflicts(c, doc, superseded_revs)
+
+ def list_indexes(self):
+ """Return the list of indexes and their definitions."""
+ c = self._db_handle.cursor()
+ # TODO: How do we test the ordering?
+ c.execute("SELECT name, field FROM index_definitions"
+ " ORDER BY name, offset")
+ definitions = []
+ cur_name = None
+ for name, field in c.fetchall():
+ if cur_name != name:
+ definitions.append((name, []))
+ cur_name = name
+ definitions[-1][-1].append(field)
+ return definitions
+
+ def _get_index_definition(self, index_name):
+ """Return the stored definition for a given index_name."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions"
+ " WHERE name = ? ORDER BY offset", (index_name,))
+ fields = [x[0] for x in c.fetchall()]
+ if not fields:
+ raise errors.IndexDoesNotExist
+ return fields
+
+ @staticmethod
+ def _strip_glob(value):
+ """Remove the trailing * from a value."""
+ assert value[-1] == '*'
+ return value[:-1]
+
+ def _format_query(self, definition, key_values):
+ # First, build the definition. We join the document_fields table
+ # against itself, as many times as the 'width' of our definition.
+ # We then do a query for each key_value, one-at-a-time.
+ # Note: All of these strings are static, we could cache them, etc.
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = ["d.doc_id = d%d.doc_id"
+ " AND d%d.field_name = ?"
+ % (i, i) for i in range(len(definition))]
+ wildcard_where = [novalue_where[i]
+ + (" AND d%d.value NOT NULL" % (i,))
+ for i in range(len(definition))]
+ exact_where = [novalue_where[i]
+ + (" AND d%d.value = ?" % (i,))
+ for i in range(len(definition))]
+ like_where = [novalue_where[i]
+ + (" AND d%d.value GLOB ?" % (i,))
+ for i in range(len(definition))]
+ is_wildcard = False
+ # Merge the lists together, so that:
+ # [field1, field2, field3], [val1, val2, val3]
+ # Becomes:
+ # (field1, val1, field2, val2, field3, val3)
+ args = []
+ where = []
+ for idx, (field, value) in enumerate(zip(definition, key_values)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(exact_where[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_from_index(self, index_name, *key_values):
+ definition = self._get_index_definition(index_name)
+ if len(key_values) != len(definition):
+ raise errors.InvalidValueForIndex()
+ statement, args = self._format_query(definition, key_values)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def _format_range_query(self, definition, start_value, end_value):
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ wildcard_where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ like_where = [
+ novalue_where[i] + (
+ " AND (d%d.value < ? OR d%d.value GLOB ?)" % (i, i)) for i in
+ range(len(definition))]
+ range_where_lower = [
+ novalue_where[i] + (" AND d%d.value >= ?" % (i,)) for i in
+ range(len(definition))]
+ range_where_upper = [
+ novalue_where[i] + (" AND d%d.value <= ?" % (i,)) for i in
+ range(len(definition))]
+ args = []
+ where = []
+ if start_value:
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ if len(start_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, start_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(self._strip_glob(value))
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_lower[idx])
+ args.append(value)
+ if end_value:
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ if len(end_value) != len(definition):
+ raise errors.InvalidValueForIndex()
+ is_wildcard = False
+ for idx, (field, value) in enumerate(zip(definition, end_value)):
+ args.append(field)
+ if value.endswith('*'):
+ if value == '*':
+ where.append(wildcard_where[idx])
+ else:
+ # This is a glob match
+ if is_wildcard:
+ # We can't have a partial wildcard following
+ # another wildcard
+ raise errors.InvalidGlobbing
+ where.append(like_where[idx])
+ args.append(self._strip_glob(value))
+ args.append(value)
+ is_wildcard = True
+ else:
+ if is_wildcard:
+ raise errors.InvalidGlobbing
+ where.append(range_where_upper[idx])
+ args.append(value)
+ statement = (
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, %s LEFT OUTER JOIN conflicts c ON c.doc_id = "
+ "d.doc_id WHERE %s GROUP BY d.doc_id, d.doc_rev, d.content ORDER "
+ "BY %s;" % (', '.join(tables), ' AND '.join(where), ', '.join(
+ ['d%d.value' % i for i in range(len(definition))])))
+ return statement, args
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ """Return all documents with key values in the specified range."""
+ definition = self._get_index_definition(index_name)
+ statement, args = self._format_range_query(
+ definition, start_value, end_value)
+ c = self._db_handle.cursor()
+ try:
+ c.execute(statement, tuple(args))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, args))
+ res = c.fetchall()
+ results = []
+ for row in res:
+ doc = self._factory(row[0], row[1], row[2])
+ doc.has_conflicts = row[3] > 0
+ results.append(doc)
+ return results
+
+ def get_index_keys(self, index_name):
+ c = self._db_handle.cursor()
+ definition = self._get_index_definition(index_name)
+ value_fields = ', '.join([
+ 'd%d.value' % i for i in range(len(definition))])
+ tables = ["document_fields d%d" % i for i in range(len(definition))]
+ novalue_where = [
+ "d.doc_id = d%d.doc_id AND d%d.field_name = ?" % (i, i) for i in
+ range(len(definition))]
+ where = [
+ novalue_where[i] + (" AND d%d.value NOT NULL" % (i,)) for i in
+ range(len(definition))]
+ statement = (
+ "SELECT %s FROM document d, %s WHERE %s GROUP BY %s;" % (
+ value_fields, ', '.join(tables), ' AND '.join(where),
+ value_fields))
+ try:
+ c.execute(statement, tuple(definition))
+ except dbapi2.OperationalError, e:
+ raise dbapi2.OperationalError(str(e) +
+ '\nstatement: %s\nargs: %s\n' % (statement, tuple(definition)))
+ return c.fetchall()
+
+ def delete_index(self, index_name):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ c.execute("DELETE FROM index_definitions WHERE name = ?",
+ (index_name,))
+ c.execute(
+ "DELETE FROM document_fields WHERE document_fields.field_name "
+ " NOT IN (SELECT field from index_definitions)")
+
+
+class SQLiteSyncTarget(CommonSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_replica_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_replica_transaction_id)
+
+
+class SQLitePartialExpandDatabase(SQLiteDatabase):
+ """An SQLite Backend that expands documents into a document_field table.
+
+ It stores the original document text in document.doc. For fields that are
+ indexed, the data goes into document_fields.
+ """
+
+ _index_storage_value = 'expand referenced'
+
+ def _get_indexed_fields(self):
+ """Determine what fields are indexed."""
+ c = self._db_handle.cursor()
+ c.execute("SELECT field FROM index_definitions")
+ return set([x[0] for x in c.fetchall()])
+
+ def _evaluate_index(self, raw_doc, field):
+ parser = query_parser.Parser()
+ getter = parser.parse(field)
+ return getter.get(raw_doc)
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ c = self._db_handle.cursor()
+ if doc and not doc.is_tombstone():
+ raw_doc = json.loads(doc.get_json())
+ else:
+ raw_doc = {}
+ if old_doc is not None:
+ c.execute("UPDATE document SET doc_rev=?, content=?"
+ " WHERE doc_id = ?",
+ (doc.rev, doc.get_json(), doc.doc_id))
+ c.execute("DELETE FROM document_fields WHERE doc_id = ?",
+ (doc.doc_id,))
+ else:
+ c.execute("INSERT INTO document (doc_id, doc_rev, content)"
+ " VALUES (?, ?, ?)",
+ (doc.doc_id, doc.rev, doc.get_json()))
+ indexed_fields = self._get_indexed_fields()
+ if indexed_fields:
+ # It is expected that len(indexed_fields) is shorter than
+ # len(raw_doc)
+ getters = [(field, self._parse_index_definition(field))
+ for field in indexed_fields]
+ self._update_indexes(doc.doc_id, raw_doc, getters, c)
+ trans_id = self._allocate_transaction_id()
+ c.execute("INSERT INTO transaction_log(doc_id, transaction_id)"
+ " VALUES (?, ?)", (doc.doc_id, trans_id))
+
+ def create_index(self, index_name, *index_expressions):
+ with self._db_handle:
+ c = self._db_handle.cursor()
+ cur_fields = self._get_indexed_fields()
+ definition = [(index_name, idx, field)
+ for idx, field in enumerate(index_expressions)]
+ try:
+ c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
+ definition)
+ except dbapi2.IntegrityError as e:
+ stored_def = self._get_index_definition(index_name)
+ if stored_def == [x[-1] for x in definition]:
+ return
+ raise errors.IndexNameTakenError, e, sys.exc_info()[2]
+ new_fields = set(
+ [f for f in index_expressions if f not in cur_fields])
+ if new_fields:
+ self._update_all_indexes(new_fields)
+
+ def _iter_all_docs(self):
+ c = self._db_handle.cursor()
+ c.execute("SELECT doc_id, content FROM document")
+ while True:
+ next_rows = c.fetchmany()
+ if not next_rows:
+ break
+ for row in next_rows:
+ yield row
+
+ def _update_all_indexes(self, new_fields):
+ """Iterate all the documents, and add content to document_fields.
+
+ :param new_fields: The index definitions that need to be added.
+ """
+ getters = [(field, self._parse_index_definition(field))
+ for field in new_fields]
+ c = self._db_handle.cursor()
+ for doc_id, doc in self._iter_all_docs():
+ if doc is None:
+ continue
+ raw_doc = json.loads(doc)
+ self._update_indexes(doc_id, raw_doc, getters, c)
+
+SQLiteDatabase.register_implementation(SQLitePartialExpandDatabase)
diff --git a/src/leap/soledad/u1db/commandline/__init__.py b/src/leap/soledad/u1db/commandline/__init__.py
new file mode 100644
index 00000000..3f32e381
--- /dev/null
+++ b/src/leap/soledad/u1db/commandline/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
diff --git a/src/leap/soledad/u1db/commandline/client.py b/src/leap/soledad/u1db/commandline/client.py
new file mode 100644
index 00000000..15bf8561
--- /dev/null
+++ b/src/leap/soledad/u1db/commandline/client.py
@@ -0,0 +1,497 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Commandline bindings for the u1db-client program."""
+
+import argparse
+import os
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import sys
+
+from u1db import (
+ Document,
+ open as u1db_open,
+ sync,
+ errors,
+ )
+from u1db.commandline import command
+from u1db.remote import (
+ http_database,
+ http_target,
+ )
+
+
+client_commands = command.CommandGroup()
+
+
+def set_oauth_credentials(client):
+ keys = os.environ.get('OAUTH_CREDENTIALS', None)
+ if keys is not None:
+ consumer_key, consumer_secret, \
+ token_key, token_secret = keys.split(":")
+ client.set_oauth_credentials(consumer_key, consumer_secret,
+ token_key, token_secret)
+
+
+class OneDbCmd(command.Command):
+ """Base class for commands operating on one local or remote database."""
+
+ def _open(self, database, create):
+ if database.startswith(('http://', 'https://')):
+ db = http_database.HTTPDatabase(database)
+ set_oauth_credentials(db)
+ db.open(create)
+ return db
+ else:
+ return u1db_open(database, create)
+
+
+class CmdCreate(OneDbCmd):
+ """Create a new document from scratch"""
+
+ name = 'create'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to update',
+ metavar='database-path-or-url')
+ parser.add_argument('infile', nargs='?', default=None,
+ help='The file to read content from.')
+ parser.add_argument('--id', dest='doc_id', default=None,
+ help='Set the document identifier')
+
+ def run(self, database, infile, doc_id):
+ if infile is None:
+ infile = self.stdin
+ db = self._open(database, create=False)
+ doc = db.create_doc_from_json(infile.read(), doc_id=doc_id)
+ self.stderr.write('id: %s\nrev: %s\n' % (doc.doc_id, doc.rev))
+
+client_commands.register(CmdCreate)
+
+
+class CmdDelete(OneDbCmd):
+ """Delete a document from the database"""
+
+ name = 'delete'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to update',
+ metavar='database-path-or-url')
+ parser.add_argument('doc_id', help='The document id to retrieve')
+ parser.add_argument('doc_rev',
+ help='The revision of the document (which is being superseded.)')
+
+ def run(self, database, doc_id, doc_rev):
+ db = self._open(database, create=False)
+ doc = Document(doc_id, doc_rev, None)
+ db.delete_doc(doc)
+ self.stderr.write('rev: %s\n' % (doc.rev,))
+
+client_commands.register(CmdDelete)
+
+
+class CmdGet(OneDbCmd):
+ """Extract a document from the database"""
+
+ name = 'get'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to query',
+ metavar='database-path-or-url')
+ parser.add_argument('doc_id', help='The document id to retrieve.')
+ parser.add_argument('outfile', nargs='?', default=None,
+ help='The file to write the document to',
+ type=argparse.FileType('wb'))
+
+ def run(self, database, doc_id, outfile):
+ if outfile is None:
+ outfile = self.stdout
+ try:
+ db = self._open(database, create=False)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ doc = db.get_doc(doc_id)
+ if doc is None:
+ self.stderr.write('Document not found (id: %s)\n' % (doc_id,))
+ return 1 # failed
+ if doc.is_tombstone():
+ outfile.write('[document deleted]\n')
+ else:
+ outfile.write(doc.get_json() + '\n')
+ self.stderr.write('rev: %s\n' % (doc.rev,))
+ if doc.has_conflicts:
+ self.stderr.write("Document has conflicts.\n")
+
+client_commands.register(CmdGet)
+
+
+class CmdGetDocConflicts(OneDbCmd):
+ """Get the conflicts from a document"""
+
+ name = 'get-doc-conflicts'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local database to query',
+ metavar='database-path')
+ parser.add_argument('doc_id', help='The document id to retrieve.')
+
+ def run(self, database, doc_id):
+ try:
+ db = self._open(database, False)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ conflicts = db.get_doc_conflicts(doc_id)
+ if not conflicts:
+ if db.get_doc(doc_id) is None:
+ self.stderr.write("Document does not exist.\n")
+ return 1
+ self.stdout.write("[")
+ for i, doc in enumerate(conflicts):
+ if i:
+ self.stdout.write(",")
+ self.stdout.write(
+ json.dumps(dict(rev=doc.rev, content=doc.content), indent=4))
+ self.stdout.write("]\n")
+
+client_commands.register(CmdGetDocConflicts)
+
+
+class CmdInitDB(OneDbCmd):
+ """Create a new database"""
+
+ name = 'init-db'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to create',
+ metavar='database-path-or-url')
+ parser.add_argument('--replica-uid', default=None,
+ help='The unique identifier for this database (not for remote)')
+
+ def run(self, database, replica_uid):
+ db = self._open(database, create=True)
+ if replica_uid is not None:
+ db._set_replica_uid(replica_uid)
+
+client_commands.register(CmdInitDB)
+
+
+class CmdPut(OneDbCmd):
+ """Add a document to the database"""
+
+ name = 'put'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to update',
+ metavar='database-path-or-url'),
+ parser.add_argument('doc_id', help='The document id to retrieve')
+ parser.add_argument('doc_rev',
+ help='The revision of the document (which is being superseded.)')
+ parser.add_argument('infile', nargs='?', default=None,
+ help='The filename of the document that will be used for content',
+ type=argparse.FileType('rb'))
+
+ def run(self, database, doc_id, doc_rev, infile):
+ if infile is None:
+ infile = self.stdin
+ try:
+ db = self._open(database, create=False)
+ doc = Document(doc_id, doc_rev, infile.read())
+ doc_rev = db.put_doc(doc)
+ self.stderr.write('rev: %s\n' % (doc_rev,))
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ except errors.RevisionConflict:
+ if db.get_doc(doc_id) is None:
+ self.stderr.write("Document does not exist.\n")
+ else:
+ self.stderr.write("Given revision is not current.\n")
+ except errors.ConflictedDoc:
+ self.stderr.write(
+ "Document has conflicts.\n"
+ "Inspect with get-doc-conflicts, then resolve.\n")
+ else:
+ return
+ return 1
+
+client_commands.register(CmdPut)
+
+
+class CmdResolve(OneDbCmd):
+ """Resolve a conflicted document"""
+
+ name = 'resolve-doc'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database',
+ help='The local or remote database to update',
+ metavar='database-path-or-url'),
+ parser.add_argument('doc_id', help='The conflicted document id')
+ parser.add_argument('doc_revs', metavar="doc-rev", nargs="+",
+ help='The revisions that the new content supersedes')
+ parser.add_argument('--infile', nargs='?', default=None,
+ help='The filename of the document that will be used for content',
+ type=argparse.FileType('rb'))
+
+ def run(self, database, doc_id, doc_revs, infile):
+ if infile is None:
+ infile = self.stdin
+ try:
+ db = self._open(database, create=False)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ doc = db.get_doc(doc_id)
+ if doc is None:
+ self.stderr.write("Document does not exist.\n")
+ return 1
+ doc.set_json(infile.read())
+ db.resolve_doc(doc, doc_revs)
+ self.stderr.write("rev: %s\n" % db.get_doc(doc_id).rev)
+ if doc.has_conflicts:
+ self.stderr.write("Document still has conflicts.\n")
+
+client_commands.register(CmdResolve)
+
+
+class CmdSync(command.Command):
+ """Synchronize two databases"""
+
+ name = 'sync'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('source', help='database to sync from')
+ parser.add_argument('target', help='database to sync to')
+
+ def _open_target(self, target):
+ if target.startswith(('http://', 'https://')):
+ st = http_target.HTTPSyncTarget.connect(target)
+ set_oauth_credentials(st)
+ else:
+ db = u1db_open(target, create=True)
+ st = db.get_sync_target()
+ return st
+
+ def run(self, source, target):
+ """Start a Sync request."""
+ source_db = u1db_open(source, create=False)
+ st = self._open_target(target)
+ syncer = sync.Synchronizer(source_db, st)
+ syncer.sync()
+ source_db.close()
+
+client_commands.register(CmdSync)
+
+
+class CmdCreateIndex(OneDbCmd):
+ """Create an index"""
+
+ name = "create-index"
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database', help='The local database to update',
+ metavar='database-path')
+ parser.add_argument('index', help='the name of the index')
+ parser.add_argument('expression', help='an index expression',
+ nargs='+')
+
+ def run(self, database, index, expression):
+ try:
+ db = self._open(database, create=False)
+ db.create_index(index, *expression)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ except errors.IndexNameTakenError:
+ self.stderr.write("There is already a different index named %r.\n"
+ % (index,))
+ return 1
+ except errors.IndexDefinitionParseError:
+ self.stderr.write("Bad index expression.\n")
+ return 1
+
+client_commands.register(CmdCreateIndex)
+
+
+class CmdListIndexes(OneDbCmd):
+ """List existing indexes"""
+
+ name = "list-indexes"
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database', help='The local database to query',
+ metavar='database-path')
+
+ def run(self, database):
+ try:
+ db = self._open(database, create=False)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ for (index, expression) in db.list_indexes():
+ self.stdout.write("%s: %s\n" % (index, ", ".join(expression)))
+
+client_commands.register(CmdListIndexes)
+
+
+class CmdDeleteIndex(OneDbCmd):
+ """Delete an index"""
+
+ name = "delete-index"
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database', help='The local database to update',
+ metavar='database-path')
+ parser.add_argument('index', help='the name of the index')
+
+ def run(self, database, index):
+ try:
+ db = self._open(database, create=False)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ return 1
+ db.delete_index(index)
+
+client_commands.register(CmdDeleteIndex)
+
+
+class CmdGetIndexKeys(OneDbCmd):
+ """Get the index's keys"""
+
+ name = "get-index-keys"
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database', help='The local database to query',
+ metavar='database-path')
+ parser.add_argument('index', help='the name of the index')
+
+ def run(self, database, index):
+ try:
+ db = self._open(database, create=False)
+ for key in db.get_index_keys(index):
+ self.stdout.write("%s\n" % (", ".join(
+ [i.encode('utf-8') for i in key],)))
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ except errors.IndexDoesNotExist:
+ self.stderr.write("Index does not exist.\n")
+ else:
+ return
+ return 1
+
+client_commands.register(CmdGetIndexKeys)
+
+
+class CmdGetFromIndex(OneDbCmd):
+ """Find documents by searching an index"""
+
+ name = "get-from-index"
+ argv = None
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('database', help='The local database to query',
+ metavar='database-path')
+ parser.add_argument('index', help='the name of the index')
+ parser.add_argument('values', metavar="value",
+ help='the value to look up (one per index column)',
+ nargs="+")
+
+ def run(self, database, index, values):
+ try:
+ db = self._open(database, create=False)
+ docs = db.get_from_index(index, *values)
+ except errors.DatabaseDoesNotExist:
+ self.stderr.write("Database does not exist.\n")
+ except errors.IndexDoesNotExist:
+ self.stderr.write("Index does not exist.\n")
+ except errors.InvalidValueForIndex:
+ index_def = db._get_index_definition(index)
+ len_diff = len(index_def) - len(values)
+ if len_diff == 0:
+ # can't happen (HAH)
+ raise
+ argv = self.argv if self.argv is not None else sys.argv
+ self.stderr.write(
+ "Invalid query: "
+ "index %r requires %d query expression%s%s.\n"
+ "For example, the following would be valid:\n"
+ " %s %s %r %r %s\n"
+ % (index,
+ len(index_def),
+ "s" if len(index_def) > 1 else "",
+ ", not %d" % len(values) if len(values) else "",
+ argv[0], argv[1], database, index,
+ " ".join(map(repr,
+ values[:len(index_def)]
+ + ["*" for i in range(len_diff)])),
+ ))
+ except errors.InvalidGlobbing:
+ argv = self.argv if self.argv is not None else sys.argv
+ fixed = []
+ for (i, v) in enumerate(values):
+ fixed.append(v)
+ if v.endswith('*'):
+ break
+ # values has at least one element, so i is defined
+ fixed.extend('*' * (len(values) - i - 1))
+ self.stderr.write(
+ "Invalid query: a star can only be followed by stars.\n"
+ "For example, the following would be valid:\n"
+ " %s %s %r %r %s\n"
+ % (argv[0], argv[1], database, index,
+ " ".join(map(repr, fixed))))
+
+ else:
+ self.stdout.write("[")
+ for i, doc in enumerate(docs):
+ if i:
+ self.stdout.write(",")
+ self.stdout.write(
+ json.dumps(
+ dict(id=doc.doc_id, rev=doc.rev, content=doc.content),
+ indent=4))
+ self.stdout.write("]\n")
+ return
+ return 1
+
+client_commands.register(CmdGetFromIndex)
+
+
+def main(args):
+ return client_commands.run_argv(args, sys.stdin, sys.stdout, sys.stderr)
diff --git a/src/leap/soledad/u1db/commandline/command.py b/src/leap/soledad/u1db/commandline/command.py
new file mode 100644
index 00000000..eace0560
--- /dev/null
+++ b/src/leap/soledad/u1db/commandline/command.py
@@ -0,0 +1,80 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Command infrastructure for u1db"""
+
+import argparse
+import inspect
+
+
+class CommandGroup(object):
+ """A collection of commands."""
+
+ def __init__(self, description=None):
+ self.commands = {}
+ self.description = description
+
+ def register(self, cmd):
+ """Register a new command to be incorporated with this group."""
+ self.commands[cmd.name] = cmd
+
+ def make_argparser(self):
+ """Create an argparse.ArgumentParser"""
+ parser = argparse.ArgumentParser(description=self.description)
+ subs = parser.add_subparsers(title='commands')
+ for name, cmd in sorted(self.commands.iteritems()):
+ sub = subs.add_parser(name, help=cmd.__doc__)
+ sub.set_defaults(subcommand=cmd)
+ cmd._populate_subparser(sub)
+ return parser
+
+ def run_argv(self, argv, stdin, stdout, stderr):
+ """Run a command, from a sys.argv[1:] style input."""
+ parser = self.make_argparser()
+ args = parser.parse_args(argv)
+ cmd = args.subcommand(stdin, stdout, stderr)
+ params, _, _, _ = inspect.getargspec(cmd.run)
+ vals = []
+ for param in params[1:]:
+ vals.append(getattr(args, param))
+ return cmd.run(*vals)
+
+
+class Command(object):
+ """Definition of a Command that can be run.
+
+ :cvar name: The name of the command, so that you can run
+ 'u1db-client <name>'.
+ """
+
+ name = None
+
+ def __init__(self, stdin, stdout, stderr):
+ self.stdin = stdin
+ self.stdout = stdout
+ self.stderr = stderr
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ """Child classes should override this to provide their arguments."""
+ raise NotImplementedError(cls._populate_subparser)
+
+ def run(self, *args):
+ """This is where the magic happens.
+
+ Subclasses should implement this, requesting their specific arguments.
+ """
+ raise NotImplementedError(self.run)
diff --git a/src/leap/soledad/u1db/commandline/serve.py b/src/leap/soledad/u1db/commandline/serve.py
new file mode 100644
index 00000000..0bb0e641
--- /dev/null
+++ b/src/leap/soledad/u1db/commandline/serve.py
@@ -0,0 +1,34 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Build server for u1db-serve."""
+
+from paste import httpserver
+
+from u1db.remote import (
+ http_app,
+ server_state,
+ )
+
+
+def make_server(host, port, working_dir):
+ """Make a server on host and port exposing dbs living in working_dir."""
+ state = server_state.ServerState()
+ state.set_workingdir(working_dir)
+ application = http_app.HTTPApp(state)
+ server = httpserver.WSGIServer(application, (host, port),
+ httpserver.WSGIHandler)
+ return server
diff --git a/src/leap/soledad/u1db/errors.py b/src/leap/soledad/u1db/errors.py
new file mode 100644
index 00000000..967c7c38
--- /dev/null
+++ b/src/leap/soledad/u1db/errors.py
@@ -0,0 +1,189 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""A list of errors that u1db can raise."""
+
+
+class U1DBError(Exception):
+ """Generic base class for U1DB errors."""
+
+ # description/tag for identifying the error during transmission (http,...)
+ wire_description = "error"
+
+ def __init__(self, message=None):
+ self.message = message
+
+
+class RevisionConflict(U1DBError):
+ """The document revisions supplied does not match the current version."""
+
+ wire_description = "revision conflict"
+
+
+class InvalidJSON(U1DBError):
+ """Content was not valid json."""
+
+
+class InvalidContent(U1DBError):
+ """Content was not a python dictionary."""
+
+
+class InvalidDocId(U1DBError):
+ """A document was requested with an invalid document identifier."""
+
+ wire_description = "invalid document id"
+
+
+class MissingDocIds(U1DBError):
+ """Needs document ids."""
+
+ wire_description = "missing document ids"
+
+
+class DocumentTooBig(U1DBError):
+ """Document exceeds the maximum document size for this database."""
+
+ wire_description = "document too big"
+
+
+class UserQuotaExceeded(U1DBError):
+ """Document exceeds the maximum document size for this database."""
+
+ wire_description = "user quota exceeded"
+
+
+class SubscriptionNeeded(U1DBError):
+ """User needs a subscription to be able to use this replica.."""
+
+ wire_description = "user needs subscription"
+
+
+class InvalidTransactionId(U1DBError):
+ """Invalid transaction for generation."""
+
+ wire_description = "invalid transaction id"
+
+
+class InvalidGeneration(U1DBError):
+ """Generation was previously synced with a different transaction id."""
+
+ wire_description = "invalid generation"
+
+
+class ConflictedDoc(U1DBError):
+ """The document is conflicted, you must call resolve before put()"""
+
+
+class InvalidValueForIndex(U1DBError):
+ """The values supplied does not match the index definition."""
+
+
+class InvalidGlobbing(U1DBError):
+ """Raised if wildcard matches are not strictly at the tail of the request.
+ """
+
+
+class DocumentDoesNotExist(U1DBError):
+ """The document does not exist."""
+
+ wire_description = "document does not exist"
+
+
+class DocumentAlreadyDeleted(U1DBError):
+ """The document was already deleted."""
+
+ wire_description = "document already deleted"
+
+
+class DatabaseDoesNotExist(U1DBError):
+ """The database does not exist."""
+
+ wire_description = "database does not exist"
+
+
+class IndexNameTakenError(U1DBError):
+ """The given index name is already taken."""
+
+
+class IndexDefinitionParseError(U1DBError):
+ """The index definition cannot be parsed."""
+
+
+class IndexDoesNotExist(U1DBError):
+ """No index of that name exists."""
+
+
+class Unauthorized(U1DBError):
+ """Request wasn't authorized properly."""
+
+ wire_description = "unauthorized"
+
+
+class HTTPError(U1DBError):
+ """Unspecific HTTP errror."""
+
+ wire_description = None
+
+ def __init__(self, status, message=None, headers={}):
+ self.status = status
+ self.message = message
+ self.headers = headers
+
+ def __str__(self):
+ if not self.message:
+ return "HTTPError(%d)" % self.status
+ else:
+ return "HTTPError(%d, %r)" % (self.status, self.message)
+
+
+class Unavailable(HTTPError):
+ """Server not available not serve request."""
+
+ wire_description = "unavailable"
+
+ def __init__(self, message=None, headers={}):
+ super(Unavailable, self).__init__(503, message, headers)
+
+ def __str__(self):
+ if not self.message:
+ return "Unavailable()"
+ else:
+ return "Unavailable(%r)" % self.message
+
+
+class BrokenSyncStream(U1DBError):
+ """Unterminated or otherwise broken sync exchange stream."""
+
+ wire_description = None
+
+
+class UnknownAuthMethod(U1DBError):
+ """Unknown auhorization method."""
+
+ wire_description = None
+
+
+# mapping wire (transimission) descriptions/tags for errors to the exceptions
+wire_description_to_exc = dict(
+ (x.wire_description, x) for x in globals().values()
+ if getattr(x, 'wire_description', None) not in (None, "error")
+)
+wire_description_to_exc["error"] = U1DBError
+
+
+#
+# wire error descriptions not corresponding to an exception
+DOCUMENT_DELETED = "document deleted"
diff --git a/src/leap/soledad/u1db/query_parser.py b/src/leap/soledad/u1db/query_parser.py
new file mode 100644
index 00000000..f564821f
--- /dev/null
+++ b/src/leap/soledad/u1db/query_parser.py
@@ -0,0 +1,370 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Code for parsing Index definitions."""
+
+import re
+from u1db import (
+ errors,
+ )
+
+
+class Getter(object):
+ """Get values from a document based on a specification."""
+
+ def get(self, raw_doc):
+ """Get a value from the document.
+
+ :param raw_doc: a python dictionary to get the value from.
+ :return: A list of values that match the description.
+ """
+ raise NotImplementedError(self.get)
+
+
+class StaticGetter(Getter):
+ """A getter that returns a defined value (independent of the doc)."""
+
+ def __init__(self, value):
+ """Create a StaticGetter.
+
+ :param value: the value to return when get is called.
+ """
+ if value is None:
+ self.value = []
+ elif isinstance(value, list):
+ self.value = value
+ else:
+ self.value = [value]
+
+ def get(self, raw_doc):
+ return self.value
+
+
+def extract_field(raw_doc, subfields, index=0):
+ if not isinstance(raw_doc, dict):
+ return []
+ val = raw_doc.get(subfields[index])
+ if val is None:
+ return []
+ if index < len(subfields) - 1:
+ if isinstance(val, list):
+ results = []
+ for item in val:
+ results.extend(extract_field(item, subfields, index + 1))
+ return results
+ if isinstance(val, dict):
+ return extract_field(val, subfields, index + 1)
+ return []
+ if isinstance(val, dict):
+ return []
+ if isinstance(val, list):
+ # Strip anything in the list that isn't a simple type
+ return [v for v in val if not isinstance(v, (dict, list))]
+ return [val]
+
+
+class ExtractField(Getter):
+ """Extract a field from the document."""
+
+ def __init__(self, field):
+ """Create an ExtractField object.
+
+ When a document is passed to get() this will return a value
+ from the document based on the field specifier passed to
+ the constructor.
+
+ None will be returned if the field is nonexistant, or refers to an
+ object, rather than a simple type or list of simple types.
+
+ :param field: a specifier for the field to return.
+ This is either a field name, or a dotted field name.
+ """
+ self.field = field.split('.')
+
+ def get(self, raw_doc):
+ return extract_field(raw_doc, self.field)
+
+
+class Transformation(Getter):
+ """A transformation on a value from another Getter."""
+
+ name = None
+ arity = 1
+ args = ['expression']
+
+ def __init__(self, inner):
+ """Create a transformation.
+
+ :param inner: the argument(s) to the transformation.
+ """
+ self.inner = inner
+
+ def get(self, raw_doc):
+ inner_values = self.inner.get(raw_doc)
+ assert isinstance(inner_values, list),\
+ 'get() should always return a list'
+ return self.transform(inner_values)
+
+ def transform(self, values):
+ """Transform the values.
+
+ This should be implemented by subclasses to transform the
+ value when get() is called.
+
+ :param values: the values from the other Getter
+ :return: the transformed values.
+ """
+ raise NotImplementedError(self.transform)
+
+
+class Lower(Transformation):
+ """Lowercase a string.
+
+ This transformation will return None for non-string inputs. However,
+ it will lowercase any strings in a list, dropping any elements
+ that are not strings.
+ """
+
+ name = "lower"
+
+ def _can_transform(self, val):
+ return isinstance(val, basestring)
+
+ def transform(self, values):
+ if not values:
+ return []
+ return [val.lower() for val in values if self._can_transform(val)]
+
+
+class Number(Transformation):
+ """Convert an integer to a zero padded string.
+
+ This transformation will return None for non-integer inputs. However, it
+ will transform any integers in a list, dropping any elements that are not
+ integers.
+ """
+
+ name = 'number'
+ arity = 2
+ args = ['expression', int]
+
+ def __init__(self, inner, number):
+ super(Number, self).__init__(inner)
+ self.padding = "%%0%sd" % number
+
+ def _can_transform(self, val):
+ return isinstance(val, int) and not isinstance(val, bool)
+
+ def transform(self, values):
+ """Transform any integers in values into zero padded strings."""
+ if not values:
+ return []
+ return [self.padding % (v,) for v in values if self._can_transform(v)]
+
+
+class Bool(Transformation):
+ """Convert bool to string."""
+
+ name = "bool"
+ args = ['expression']
+
+ def _can_transform(self, val):
+ return isinstance(val, bool)
+
+ def transform(self, values):
+ """Transform any booleans in values into strings."""
+ if not values:
+ return []
+ return [('1' if v else '0') for v in values if self._can_transform(v)]
+
+
+class SplitWords(Transformation):
+ """Split a string on whitespace.
+
+ This Getter will return [] for non-string inputs. It will however
+ split any strings in an input list, discarding any elements that
+ are not strings.
+ """
+
+ name = "split_words"
+
+ def _can_transform(self, val):
+ return isinstance(val, basestring)
+
+ def transform(self, values):
+ if not values:
+ return []
+ result = set()
+ for value in values:
+ if self._can_transform(value):
+ for word in value.split():
+ result.add(word)
+ return list(result)
+
+
+class Combine(Transformation):
+ """Combine multiple expressions into a single index."""
+
+ name = "combine"
+ # variable number of args
+ arity = -1
+
+ def __init__(self, *inner):
+ super(Combine, self).__init__(inner)
+
+ def get(self, raw_doc):
+ inner_values = []
+ for inner in self.inner:
+ inner_values.extend(inner.get(raw_doc))
+ return self.transform(inner_values)
+
+ def transform(self, values):
+ return values
+
+
+class IsNull(Transformation):
+ """Indicate whether the input is None.
+
+ This Getter returns a bool indicating whether the input is nil.
+ """
+
+ name = "is_null"
+
+ def transform(self, values):
+ return [len(values) == 0]
+
+
+def check_fieldname(fieldname):
+ if fieldname.endswith('.'):
+ raise errors.IndexDefinitionParseError(
+ "Fieldname cannot end in '.':%s^" % (fieldname,))
+
+
+class Parser(object):
+ """Parse an index expression into a sequence of transformations."""
+
+ _transformations = {}
+ _delimiters = re.compile("\(|\)|,")
+
+ def __init__(self):
+ self._tokens = []
+
+ def _set_expression(self, expression):
+ self._open_parens = 0
+ self._tokens = []
+ expression = expression.strip()
+ while expression:
+ delimiter = self._delimiters.search(expression)
+ if delimiter:
+ idx = delimiter.start()
+ if idx == 0:
+ result, expression = (expression[:1], expression[1:])
+ self._tokens.append(result)
+ else:
+ result, expression = (expression[:idx], expression[idx:])
+ result = result.strip()
+ if result:
+ self._tokens.append(result)
+ else:
+ expression = expression.strip()
+ if expression:
+ self._tokens.append(expression)
+ expression = None
+
+ def _get_token(self):
+ if self._tokens:
+ return self._tokens.pop(0)
+
+ def _peek_token(self):
+ if self._tokens:
+ return self._tokens[0]
+
+ @staticmethod
+ def _to_getter(term):
+ if isinstance(term, Getter):
+ return term
+ check_fieldname(term)
+ return ExtractField(term)
+
+ def _parse_op(self, op_name):
+ self._get_token() # '('
+ op = self._transformations.get(op_name, None)
+ if op is None:
+ raise errors.IndexDefinitionParseError(
+ "Unknown operation: %s" % op_name)
+ args = []
+ while True:
+ args.append(self._parse_term())
+ sep = self._get_token()
+ if sep == ')':
+ break
+ if sep != ',':
+ raise errors.IndexDefinitionParseError(
+ "Unexpected token '%s' in parentheses." % (sep,))
+ parsed = []
+ for i, arg in enumerate(args):
+ arg_type = op.args[i % len(op.args)]
+ if arg_type == 'expression':
+ inner = self._to_getter(arg)
+ else:
+ try:
+ inner = arg_type(arg)
+ except ValueError, e:
+ raise errors.IndexDefinitionParseError(
+ "Invalid value %r for argument type %r "
+ "(%r)." % (arg, arg_type, e))
+ parsed.append(inner)
+ return op(*parsed)
+
+ def _parse_term(self):
+ term = self._get_token()
+ if term is None:
+ raise errors.IndexDefinitionParseError(
+ "Unexpected end of index definition.")
+ if term in (',', ')', '('):
+ raise errors.IndexDefinitionParseError(
+ "Unexpected token '%s' at start of expression." % (term,))
+ next_token = self._peek_token()
+ if next_token == '(':
+ return self._parse_op(term)
+ return term
+
+ def parse(self, expression):
+ self._set_expression(expression)
+ term = self._to_getter(self._parse_term())
+ if self._peek_token():
+ raise errors.IndexDefinitionParseError(
+ "Unexpected token '%s' after end of expression."
+ % (self._peek_token(),))
+ return term
+
+ def parse_all(self, fields):
+ return [self.parse(field) for field in fields]
+
+ @classmethod
+ def register_transormation(cls, transform):
+ assert transform.name not in cls._transformations, (
+ "Transform %s already registered for %s"
+ % (transform.name, cls._transformations[transform.name]))
+ cls._transformations[transform.name] = transform
+
+
+Parser.register_transormation(SplitWords)
+Parser.register_transormation(Lower)
+Parser.register_transormation(Number)
+Parser.register_transormation(Bool)
+Parser.register_transormation(IsNull)
+Parser.register_transormation(Combine)
diff --git a/src/leap/soledad/u1db/remote/__init__.py b/src/leap/soledad/u1db/remote/__init__.py
new file mode 100644
index 00000000..3f32e381
--- /dev/null
+++ b/src/leap/soledad/u1db/remote/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
diff --git a/src/leap/soledad/u1db/remote/basic_auth_middleware.py b/src/leap/soledad/u1db/remote/basic_auth_middleware.py
new file mode 100644
index 00000000..a2cbff62
--- /dev/null
+++ b/src/leap/soledad/u1db/remote/basic_auth_middleware.py
@@ -0,0 +1,68 @@
+# Copyright 2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+"""U1DB Basic Auth authorisation WSGI middleware."""
+import httplib
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from wsgiref.util import shift_path_info
+
+
+class Unauthorized(Exception):
+ """User authorization failed."""
+
+
+class BasicAuthMiddleware(object):
+ """U1DB Basic Auth Authorisation WSGI middleware."""
+
+ def __init__(self, app, prefix):
+ self.app = app
+ self.prefix = prefix
+
+ def _error(self, start_response, status, description, message=None):
+ start_response("%d %s" % (status, httplib.responses[status]),
+ [('content-type', 'application/json')])
+ err = {"error": description}
+ if message:
+ err['message'] = message
+ return [json.dumps(err)]
+
+ def __call__(self, environ, start_response):
+ if self.prefix and not environ['PATH_INFO'].startswith(self.prefix):
+ return self._error(start_response, 400, "bad request")
+ auth = environ.get('HTTP_AUTHORIZATION')
+ if not auth:
+ return self._error(start_response, 401, "unauthorized",
+ "Missing Basic Authentication.")
+ scheme, encoded = auth.split(None, 1)
+ if scheme.lower() != 'basic':
+ return self._error(
+ start_response, 401, "unauthorized",
+ "Missing Basic Authentication")
+ user, password = encoded.decode('base64').split(':', 1)
+ try:
+ self.verify_user(environ, user, password)
+ except Unauthorized:
+ return self._error(
+ start_response, 401, "unauthorized",
+ "Incorrect password or login.")
+ del environ['HTTP_AUTHORIZATION']
+ shift_path_info(environ)
+ return self.app(environ, start_response)
+
+ def verify_user(self, environ, username, password):
+ raise NotImplementedError(self.verify_user)
diff --git a/src/leap/soledad/u1db/remote/http_app.py b/src/leap/soledad/u1db/remote/http_app.py
new file mode 100644
index 00000000..3d7d4248
--- /dev/null
+++ b/src/leap/soledad/u1db/remote/http_app.py
@@ -0,0 +1,629 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""HTTP Application exposing U1DB."""
+
+import functools
+import httplib
+import inspect
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import sys
+import urlparse
+
+import routes.mapper
+
+from u1db import (
+ __version__ as _u1db_version,
+ DBNAME_CONSTRAINTS,
+ Document,
+ errors,
+ sync,
+ )
+from u1db.remote import (
+ http_errors,
+ utils,
+ )
+
+
+def parse_bool(expression):
+ """Parse boolean querystring parameter."""
+ if expression == 'true':
+ return True
+ return False
+
+
+def parse_list(expression):
+ if expression is None:
+ return []
+ return [t.strip() for t in expression.split(',')]
+
+
+def none_or_str(expression):
+ if expression is None:
+ return None
+ return str(expression)
+
+
+class BadRequest(Exception):
+ """Bad request."""
+
+
+class _FencedReader(object):
+ """Read and get lines from a file but not past a given length."""
+
+ MAXCHUNK = 8192
+
+ def __init__(self, rfile, total, max_entry_size):
+ self.rfile = rfile
+ self.remaining = total
+ self.max_entry_size = max_entry_size
+ self._kept = None
+
+ def read_chunk(self, atmost):
+ if self._kept is not None:
+ # ignore atmost, kept data should be a subchunk anyway
+ kept, self._kept = self._kept, None
+ return kept
+ if self.remaining == 0:
+ return ''
+ data = self.rfile.read(min(self.remaining, atmost))
+ self.remaining -= len(data)
+ return data
+
+ def getline(self):
+ line_parts = []
+ size = 0
+ while True:
+ chunk = self.read_chunk(self.MAXCHUNK)
+ if chunk == '':
+ break
+ nl = chunk.find("\n")
+ if nl != -1:
+ size += nl + 1
+ if size > self.max_entry_size:
+ raise BadRequest
+ line_parts.append(chunk[:nl + 1])
+ rest = chunk[nl + 1:]
+ self._kept = rest or None
+ break
+ else:
+ size += len(chunk)
+ if size > self.max_entry_size:
+ raise BadRequest
+ line_parts.append(chunk)
+ return ''.join(line_parts)
+
+
+def http_method(**control):
+ """Decoration for handling of query arguments and content for a HTTP
+ method.
+
+ args and content here are the query arguments and body of the incoming
+ HTTP requests.
+
+ Match query arguments to python method arguments:
+ w = http_method()(f)
+ w(self, args, content) => args["content"]=content;
+ f(self, **args)
+
+ JSON deserialize content to arguments:
+ w = http_method(content_as_args=True,...)(f)
+ w(self, args, content) => args.update(json.loads(content));
+ f(self, **args)
+
+ Support conversions (e.g int):
+ w = http_method(Arg=Conv,...)(f)
+ w(self, args, content) => args["Arg"]=Conv(args["Arg"]);
+ f(self, **args)
+
+ Enforce no use of query arguments:
+ w = http_method(no_query=True,...)(f)
+ w(self, args, content) raises BadRequest if args is not empty
+
+ Argument mismatches, deserialisation failures produce BadRequest.
+ """
+ content_as_args = control.pop('content_as_args', False)
+ no_query = control.pop('no_query', False)
+ conversions = control.items()
+
+ def wrap(f):
+ argspec = inspect.getargspec(f)
+ assert argspec.args[0] == "self"
+ nargs = len(argspec.args)
+ ndefaults = len(argspec.defaults or ())
+ required_args = set(argspec.args[1:nargs - ndefaults])
+ all_args = set(argspec.args)
+
+ @functools.wraps(f)
+ def wrapper(self, args, content):
+ if no_query and args:
+ raise BadRequest()
+ if content is not None:
+ if content_as_args:
+ try:
+ args.update(json.loads(content))
+ except ValueError:
+ raise BadRequest()
+ else:
+ args["content"] = content
+ if not (required_args <= set(args) <= all_args):
+ raise BadRequest("Missing required arguments.")
+ for name, conv in conversions:
+ if name not in args:
+ continue
+ try:
+ args[name] = conv(args[name])
+ except ValueError:
+ raise BadRequest()
+ return f(self, **args)
+
+ return wrapper
+
+ return wrap
+
+
+class URLToResource(object):
+ """Mappings from URLs to resources."""
+
+ def __init__(self):
+ self._map = routes.mapper.Mapper(controller_scan=None)
+
+ def register(self, resource_cls):
+ # register
+ self._map.connect(None, resource_cls.url_pattern,
+ resource_cls=resource_cls,
+ requirements={"dbname": DBNAME_CONSTRAINTS})
+ self._map.create_regs()
+ return resource_cls
+
+ def match(self, path):
+ params = self._map.match(path)
+ if params is None:
+ return None, None
+ resource_cls = params.pop('resource_cls')
+ return resource_cls, params
+
+url_to_resource = URLToResource()
+
+
+@url_to_resource.register
+class GlobalResource(object):
+ """Global (root) resource."""
+
+ url_pattern = "/"
+
+ def __init__(self, state, responder):
+ self.responder = responder
+
+ @http_method()
+ def get(self):
+ self.responder.send_response_json(version=_u1db_version)
+
+
+@url_to_resource.register
+class DatabaseResource(object):
+ """Database resource."""
+
+ url_pattern = "/{dbname}"
+
+ def __init__(self, dbname, state, responder):
+ self.dbname = dbname
+ self.state = state
+ self.responder = responder
+
+ @http_method()
+ def get(self):
+ self.state.check_database(self.dbname)
+ self.responder.send_response_json(200)
+
+ @http_method(content_as_args=True)
+ def put(self):
+ self.state.ensure_database(self.dbname)
+ self.responder.send_response_json(200, ok=True)
+
+ @http_method()
+ def delete(self):
+ self.state.delete_database(self.dbname)
+ self.responder.send_response_json(200, ok=True)
+
+
+@url_to_resource.register
+class DocsResource(object):
+ """Documents resource."""
+
+ url_pattern = "/{dbname}/docs"
+
+ def __init__(self, dbname, state, responder):
+ self.responder = responder
+ self.db = state.open_database(dbname)
+
+ @http_method(doc_ids=parse_list, check_for_conflicts=parse_bool,
+ include_deleted=parse_bool)
+ def get(self, doc_ids=None, check_for_conflicts=True,
+ include_deleted=False):
+ if doc_ids is None:
+ raise errors.MissingDocIds
+ docs = self.db.get_docs(doc_ids, include_deleted=include_deleted)
+ self.responder.content_type = 'application/json'
+ self.responder.start_response(200)
+ self.responder.start_stream(),
+ for doc in docs:
+ entry = dict(
+ doc_id=doc.doc_id, doc_rev=doc.rev, content=doc.get_json(),
+ has_conflicts=doc.has_conflicts)
+ self.responder.stream_entry(entry)
+ self.responder.end_stream()
+ self.responder.finish_response()
+
+
+@url_to_resource.register
+class DocResource(object):
+ """Document resource."""
+
+ url_pattern = "/{dbname}/doc/{id:.*}"
+
+ def __init__(self, dbname, id, state, responder):
+ self.id = id
+ self.responder = responder
+ self.db = state.open_database(dbname)
+
+ @http_method(old_rev=str)
+ def put(self, content, old_rev=None):
+ doc = Document(self.id, old_rev, content)
+ doc_rev = self.db.put_doc(doc)
+ if old_rev is None:
+ status = 201 # created
+ else:
+ status = 200
+ self.responder.send_response_json(status, rev=doc_rev)
+
+ @http_method(old_rev=str)
+ def delete(self, old_rev=None):
+ doc = Document(self.id, old_rev, None)
+ self.db.delete_doc(doc)
+ self.responder.send_response_json(200, rev=doc.rev)
+
+ @http_method(include_deleted=parse_bool)
+ def get(self, include_deleted=False):
+ doc = self.db.get_doc(self.id, include_deleted=include_deleted)
+ if doc is None:
+ wire_descr = errors.DocumentDoesNotExist.wire_description
+ self.responder.send_response_json(
+ http_errors.wire_description_to_status[wire_descr],
+ error=wire_descr,
+ headers={
+ 'x-u1db-rev': '',
+ 'x-u1db-has-conflicts': 'false'
+ })
+ return
+ headers = {
+ 'x-u1db-rev': doc.rev,
+ 'x-u1db-has-conflicts': json.dumps(doc.has_conflicts)
+ }
+ if doc.is_tombstone():
+ self.responder.send_response_json(
+ http_errors.wire_description_to_status[
+ errors.DOCUMENT_DELETED],
+ error=errors.DOCUMENT_DELETED,
+ headers=headers)
+ else:
+ self.responder.send_response_content(
+ doc.get_json(), headers=headers)
+
+
+@url_to_resource.register
+class SyncResource(object):
+ """Sync endpoint resource."""
+
+ # maximum allowed request body size
+ max_request_size = 15 * 1024 * 1024 # 15Mb
+ # maximum allowed entry/line size in request body
+ max_entry_size = 10 * 1024 * 1024 # 10Mb
+
+ url_pattern = "/{dbname}/sync-from/{source_replica_uid}"
+
+ # pluggable
+ sync_exchange_class = sync.SyncExchange
+
+ def __init__(self, dbname, source_replica_uid, state, responder):
+ self.source_replica_uid = source_replica_uid
+ self.responder = responder
+ self.state = state
+ self.dbname = dbname
+ self.replica_uid = None
+
+ def get_target(self):
+ return self.state.open_database(self.dbname).get_sync_target()
+
+ @http_method()
+ def get(self):
+ result = self.get_target().get_sync_info(self.source_replica_uid)
+ self.responder.send_response_json(
+ target_replica_uid=result[0], target_replica_generation=result[1],
+ target_replica_transaction_id=result[2],
+ source_replica_uid=self.source_replica_uid,
+ source_replica_generation=result[3],
+ source_transaction_id=result[4])
+
+ @http_method(generation=int,
+ content_as_args=True, no_query=True)
+ def put(self, generation, transaction_id):
+ self.get_target().record_sync_info(self.source_replica_uid,
+ generation,
+ transaction_id)
+ self.responder.send_response_json(ok=True)
+
+ # Implements the same logic as LocalSyncTarget.sync_exchange
+
+ @http_method(last_known_generation=int, last_known_trans_id=none_or_str,
+ content_as_args=True)
+ def post_args(self, last_known_generation, last_known_trans_id=None,
+ ensure=False):
+ if ensure:
+ db, self.replica_uid = self.state.ensure_database(self.dbname)
+ else:
+ db = self.state.open_database(self.dbname)
+ db.validate_gen_and_trans_id(
+ last_known_generation, last_known_trans_id)
+ self.sync_exch = self.sync_exchange_class(
+ db, self.source_replica_uid, last_known_generation)
+
+ @http_method(content_as_args=True)
+ def post_stream_entry(self, id, rev, content, gen, trans_id):
+ doc = Document(id, rev, content)
+ self.sync_exch.insert_doc_from_source(doc, gen, trans_id)
+
+ def post_end(self):
+
+ def send_doc(doc, gen, trans_id):
+ entry = dict(id=doc.doc_id, rev=doc.rev, content=doc.get_json(),
+ gen=gen, trans_id=trans_id)
+ self.responder.stream_entry(entry)
+
+ new_gen = self.sync_exch.find_changes_to_return()
+ self.responder.content_type = 'application/x-u1db-sync-stream'
+ self.responder.start_response(200)
+ self.responder.start_stream(),
+ header = {"new_generation": new_gen,
+ "new_transaction_id": self.sync_exch.new_trans_id}
+ if self.replica_uid is not None:
+ header['replica_uid'] = self.replica_uid
+ self.responder.stream_entry(header)
+ self.sync_exch.return_docs(send_doc)
+ self.responder.end_stream()
+ self.responder.finish_response()
+
+
+class HTTPResponder(object):
+ """Encode responses from the server back to the client."""
+
+ # a multi document response will put args and documents
+ # each on one line of the response body
+
+ def __init__(self, start_response):
+ self._started = False
+ self._stream_state = -1
+ self._no_initial_obj = True
+ self.sent_response = False
+ self._start_response = start_response
+ self._write = None
+ self.content_type = 'application/json'
+ self.content = []
+
+ def start_response(self, status, obj_dic=None, headers={}):
+ """start sending response with optional first json object."""
+ if self._started:
+ return
+ self._started = True
+ status_text = httplib.responses[status]
+ self._write = self._start_response('%d %s' % (status, status_text),
+ [('content-type', self.content_type),
+ ('cache-control', 'no-cache')] +
+ headers.items())
+ # xxx version in headers
+ if obj_dic is not None:
+ self._no_initial_obj = False
+ self._write(json.dumps(obj_dic) + "\r\n")
+
+ def finish_response(self):
+ """finish sending response."""
+ self.sent_response = True
+
+ def send_response_json(self, status=200, headers={}, **kwargs):
+ """send and finish response with json object body from keyword args."""
+ content = json.dumps(kwargs) + "\r\n"
+ self.send_response_content(content, headers=headers, status=status)
+
+ def send_response_content(self, content, status=200, headers={}):
+ """send and finish response with content"""
+ headers['content-length'] = str(len(content))
+ self.start_response(status, headers=headers)
+ if self._stream_state == 1:
+ self.content = [',\r\n', content]
+ else:
+ self.content = [content]
+ self.finish_response()
+
+ def start_stream(self):
+ "start stream (array) as part of the response."
+ assert self._started and self._no_initial_obj
+ self._stream_state = 0
+ self._write("[")
+
+ def stream_entry(self, entry):
+ "send stream entry as part of the response."
+ assert self._stream_state != -1
+ if self._stream_state == 0:
+ self._stream_state = 1
+ self._write('\r\n')
+ else:
+ self._write(',\r\n')
+ self._write(json.dumps(entry))
+
+ def end_stream(self):
+ "end stream (array)."
+ assert self._stream_state != -1
+ self._write("\r\n]\r\n")
+
+
+class HTTPInvocationByMethodWithBody(object):
+ """Invoke methods on a resource."""
+
+ def __init__(self, resource, environ, parameters):
+ self.resource = resource
+ self.environ = environ
+ self.max_request_size = getattr(
+ resource, 'max_request_size', parameters.max_request_size)
+ self.max_entry_size = getattr(
+ resource, 'max_entry_size', parameters.max_entry_size)
+
+ def _lookup(self, method):
+ try:
+ return getattr(self.resource, method)
+ except AttributeError:
+ raise BadRequest()
+
+ def __call__(self):
+ args = urlparse.parse_qsl(self.environ['QUERY_STRING'],
+ strict_parsing=False)
+ try:
+ args = dict(
+ (k.decode('utf-8'), v.decode('utf-8')) for k, v in args)
+ except ValueError:
+ raise BadRequest()
+ method = self.environ['REQUEST_METHOD'].lower()
+ if method in ('get', 'delete'):
+ meth = self._lookup(method)
+ return meth(args, None)
+ else:
+ # we expect content-length > 0, reconsider if we move
+ # to support chunked enconding
+ try:
+ content_length = int(self.environ['CONTENT_LENGTH'])
+ except (ValueError, KeyError):
+ raise BadRequest
+ if content_length <= 0:
+ raise BadRequest
+ if content_length > self.max_request_size:
+ raise BadRequest
+ reader = _FencedReader(self.environ['wsgi.input'], content_length,
+ self.max_entry_size)
+ content_type = self.environ.get('CONTENT_TYPE')
+ if content_type == 'application/json':
+ meth = self._lookup(method)
+ body = reader.read_chunk(sys.maxint)
+ return meth(args, body)
+ elif content_type == 'application/x-u1db-sync-stream':
+ meth_args = self._lookup('%s_args' % method)
+ meth_entry = self._lookup('%s_stream_entry' % method)
+ meth_end = self._lookup('%s_end' % method)
+ body_getline = reader.getline
+ if body_getline().strip() != '[':
+ raise BadRequest()
+ line = body_getline()
+ line, comma = utils.check_and_strip_comma(line.strip())
+ meth_args(args, line)
+ while True:
+ line = body_getline()
+ entry = line.strip()
+ if entry == ']':
+ break
+ if not entry or not comma: # empty or no prec comma
+ raise BadRequest
+ entry, comma = utils.check_and_strip_comma(entry)
+ meth_entry({}, entry)
+ if comma or body_getline(): # extra comma or data
+ raise BadRequest
+ return meth_end()
+ else:
+ raise BadRequest()
+
+
+class HTTPApp(object):
+
+ # maximum allowed request body size
+ max_request_size = 15 * 1024 * 1024 # 15Mb
+ # maximum allowed entry/line size in request body
+ max_entry_size = 10 * 1024 * 1024 # 10Mb
+
+ def __init__(self, state):
+ self.state = state
+
+ def _lookup_resource(self, environ, responder):
+ resource_cls, params = url_to_resource.match(environ['PATH_INFO'])
+ if resource_cls is None:
+ raise BadRequest # 404 instead?
+ resource = resource_cls(
+ state=self.state, responder=responder, **params)
+ return resource
+
+ def __call__(self, environ, start_response):
+ responder = HTTPResponder(start_response)
+ self.request_begin(environ)
+ try:
+ resource = self._lookup_resource(environ, responder)
+ HTTPInvocationByMethodWithBody(resource, environ, self)()
+ except errors.U1DBError, e:
+ self.request_u1db_error(environ, e)
+ status = http_errors.wire_description_to_status.get(
+ e.wire_description, 500)
+ responder.send_response_json(status, error=e.wire_description)
+ except BadRequest:
+ self.request_bad_request(environ)
+ responder.send_response_json(400, error="bad request")
+ except KeyboardInterrupt:
+ raise
+ except:
+ self.request_failed(environ)
+ raise
+ else:
+ self.request_done(environ)
+ return responder.content
+
+ # hooks for tracing requests
+
+ def request_begin(self, environ):
+ """Hook called at the beginning of processing a request."""
+ pass
+
+ def request_done(self, environ):
+ """Hook called when done processing a request."""
+ pass
+
+ def request_u1db_error(self, environ, exc):
+ """Hook called when processing a request resulted in a U1DBError.
+
+ U1DBError passed as exc.
+ """
+ pass
+
+ def request_bad_request(self, environ):
+ """Hook called when processing a bad request.
+
+ No actual processing was done.
+ """
+ pass
+
+ def request_failed(self, environ):
+ """Hook called when processing a request failed unexpectedly.
+
+ Invoked from an except block, so there's interpreter exception
+ information available.
+ """
+ pass
diff --git a/src/leap/soledad/u1db/remote/http_client.py b/src/leap/soledad/u1db/remote/http_client.py
new file mode 100644
index 00000000..decddda3
--- /dev/null
+++ b/src/leap/soledad/u1db/remote/http_client.py
@@ -0,0 +1,218 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Base class to make requests to a remote HTTP server."""
+
+import httplib
+from oauth import oauth
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import socket
+import ssl
+import sys
+import urlparse
+import urllib
+
+from time import sleep
+from u1db import (
+ errors,
+ )
+from u1db.remote import (
+ http_errors,
+ )
+
+from u1db.remote.ssl_match_hostname import ( # noqa
+ CertificateError,
+ match_hostname,
+ )
+
+# Ubuntu/debian
+# XXX other...
+CA_CERTS = "/etc/ssl/certs/ca-certificates.crt"
+
+
+def _encode_query_parameter(value):
+ """Encode query parameter."""
+ if isinstance(value, bool):
+ if value:
+ value = 'true'
+ else:
+ value = 'false'
+ return unicode(value).encode('utf-8')
+
+
+class _VerifiedHTTPSConnection(httplib.HTTPSConnection):
+ """HTTPSConnection verifying server side certificates."""
+ # derived from httplib.py
+
+ def connect(self):
+ "Connect to a host on a given (SSL) port."
+
+ sock = socket.create_connection((self.host, self.port),
+ self.timeout, self.source_address)
+ if self._tunnel_host:
+ self.sock = sock
+ self._tunnel()
+ if sys.platform.startswith('linux'):
+ cert_opts = {
+ 'cert_reqs': ssl.CERT_REQUIRED,
+ 'ca_certs': CA_CERTS
+ }
+ else:
+ # XXX no cert verification implemented elsewhere for now
+ cert_opts = {}
+ self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file,
+ ssl_version=ssl.PROTOCOL_SSLv3,
+ **cert_opts
+ )
+ if cert_opts:
+ match_hostname(self.sock.getpeercert(), self.host)
+
+
+class HTTPClientBase(object):
+ """Base class to make requests to a remote HTTP server."""
+
+ # by default use HMAC-SHA1 OAuth signature method to not disclose
+ # tokens
+ # NB: given that the content bodies are not covered by the
+ # signatures though, to achieve security (against man-in-the-middle
+ # attacks for example) one would need HTTPS
+ oauth_signature_method = oauth.OAuthSignatureMethod_HMAC_SHA1()
+
+ # Will use these delays to retry on 503 befor finally giving up. The final
+ # 0 is there to not wait after the final try fails.
+ _delays = (1, 1, 2, 4, 0)
+
+ def __init__(self, url, creds=None):
+ self._url = urlparse.urlsplit(url)
+ self._conn = None
+ self._creds = {}
+ if creds is not None:
+ if len(creds) != 1:
+ raise errors.UnknownAuthMethod()
+ auth_meth, credentials = creds.items()[0]
+ try:
+ set_creds = getattr(self, 'set_%s_credentials' % auth_meth)
+ except AttributeError:
+ raise errors.UnknownAuthMethod(auth_meth)
+ set_creds(**credentials)
+
+ def set_oauth_credentials(self, consumer_key, consumer_secret,
+ token_key, token_secret):
+ self._creds = {'oauth': (
+ oauth.OAuthConsumer(consumer_key, consumer_secret),
+ oauth.OAuthToken(token_key, token_secret))}
+
+ def _ensure_connection(self):
+ if self._conn is not None:
+ return
+ if self._url.scheme == 'https':
+ connClass = _VerifiedHTTPSConnection
+ else:
+ connClass = httplib.HTTPConnection
+ self._conn = connClass(self._url.hostname, self._url.port)
+
+ def close(self):
+ if self._conn:
+ self._conn.close()
+ self._conn = None
+
+ # xxx retry mechanism?
+
+ def _error(self, respdic):
+ descr = respdic.get("error")
+ exc_cls = errors.wire_description_to_exc.get(descr)
+ if exc_cls is not None:
+ message = respdic.get("message")
+ raise exc_cls(message)
+
+ def _response(self):
+ resp = self._conn.getresponse()
+ body = resp.read()
+ headers = dict(resp.getheaders())
+ if resp.status in (200, 201):
+ return body, headers
+ elif resp.status in http_errors.ERROR_STATUSES:
+ try:
+ respdic = json.loads(body)
+ except ValueError:
+ pass
+ else:
+ self._error(respdic)
+ # special case
+ if resp.status == 503:
+ raise errors.Unavailable(body, headers)
+ raise errors.HTTPError(resp.status, body, headers)
+
+ def _sign_request(self, method, url_query, params):
+ if 'oauth' in self._creds:
+ consumer, token = self._creds['oauth']
+ full_url = "%s://%s%s" % (self._url.scheme, self._url.netloc,
+ url_query)
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ consumer, token,
+ http_method=method,
+ parameters=params,
+ http_url=full_url
+ )
+ oauth_req.sign_request(
+ self.oauth_signature_method, consumer, token)
+ # Authorization: OAuth ...
+ return oauth_req.to_header().items()
+ else:
+ return []
+
+ def _request(self, method, url_parts, params=None, body=None,
+ content_type=None):
+ self._ensure_connection()
+ unquoted_url = url_query = self._url.path
+ if url_parts:
+ if not url_query.endswith('/'):
+ url_query += '/'
+ unquoted_url = url_query
+ url_query += '/'.join(urllib.quote(part, safe='')
+ for part in url_parts)
+ # oauth performs its own quoting
+ unquoted_url += '/'.join(url_parts)
+ encoded_params = {}
+ if params:
+ for key, value in params.items():
+ key = unicode(key).encode('utf-8')
+ encoded_params[key] = _encode_query_parameter(value)
+ url_query += ('?' + urllib.urlencode(encoded_params))
+ if body is not None and not isinstance(body, basestring):
+ body = json.dumps(body)
+ content_type = 'application/json'
+ headers = {}
+ if content_type:
+ headers['content-type'] = content_type
+ headers.update(
+ self._sign_request(method, unquoted_url, encoded_params))
+ for delay in self._delays:
+ try:
+ self._conn.request(method, url_query, body, headers)
+ return self._response()
+ except errors.Unavailable, e:
+ sleep(delay)
+ raise e
+
+ def _request_json(self, method, url_parts, params=None, body=None,
+ content_type=None):
+ res, headers = self._request(method, url_parts, params, body,
+ content_type)
+ return json.loads(res), headers
diff --git a/src/leap/soledad/u1db/remote/http_database.py b/src/leap/soledad/u1db/remote/http_database.py
new file mode 100644
index 00000000..6901baad
--- /dev/null
+++ b/src/leap/soledad/u1db/remote/http_database.py
@@ -0,0 +1,143 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""HTTPDatabase to access a remote db over the HTTP API."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import uuid
+
+from u1db import (
+ Database,
+ Document,
+ errors,
+ )
+from u1db.remote import (
+ http_client,
+ http_errors,
+ http_target,
+ )
+
+
+DOCUMENT_DELETED_STATUS = http_errors.wire_description_to_status[
+ errors.DOCUMENT_DELETED]
+
+
+class HTTPDatabase(http_client.HTTPClientBase, Database):
+ """Implement the Database API to a remote HTTP server."""
+
+ def __init__(self, url, document_factory=None, creds=None):
+ super(HTTPDatabase, self).__init__(url, creds=creds)
+ self._factory = document_factory or Document
+
+ def set_document_factory(self, factory):
+ self._factory = factory
+
+ @staticmethod
+ def open_database(url, create):
+ db = HTTPDatabase(url)
+ db.open(create)
+ return db
+
+ @staticmethod
+ def delete_database(url):
+ db = HTTPDatabase(url)
+ db._delete()
+ db.close()
+
+ def open(self, create):
+ if create:
+ self._ensure()
+ else:
+ self._check()
+
+ def _check(self):
+ return self._request_json('GET', [])[0]
+
+ def _ensure(self):
+ self._request_json('PUT', [], {}, {})
+
+ def _delete(self):
+ self._request_json('DELETE', [], {}, {})
+
+ def put_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ params = {}
+ if doc.rev is not None:
+ params['old_rev'] = doc.rev
+ res, headers = self._request_json('PUT', ['doc', doc.doc_id], params,
+ doc.get_json(), 'application/json')
+ doc.rev = res['rev']
+ return res['rev']
+
+ def get_doc(self, doc_id, include_deleted=False):
+ try:
+ res, headers = self._request(
+ 'GET', ['doc', doc_id], {"include_deleted": include_deleted})
+ except errors.DocumentDoesNotExist:
+ return None
+ except errors.HTTPError, e:
+ if (e.status == DOCUMENT_DELETED_STATUS and
+ 'x-u1db-rev' in e.headers):
+ res = None
+ headers = e.headers
+ else:
+ raise
+ doc_rev = headers['x-u1db-rev']
+ has_conflicts = json.loads(headers['x-u1db-has-conflicts'])
+ doc = self._factory(doc_id, doc_rev, res)
+ doc.has_conflicts = has_conflicts
+ return doc
+
+ def get_docs(self, doc_ids, check_for_conflicts=True,
+ include_deleted=False):
+ if not doc_ids:
+ return
+ doc_ids = ','.join(doc_ids)
+ res, headers = self._request(
+ 'GET', ['docs'], {
+ "doc_ids": doc_ids, "include_deleted": include_deleted,
+ "check_for_conflicts": check_for_conflicts})
+ for doc_dict in json.loads(res):
+ doc = self._factory(
+ doc_dict['doc_id'], doc_dict['doc_rev'], doc_dict['content'])
+ doc.has_conflicts = doc_dict['has_conflicts']
+ yield doc
+
+ def create_doc_from_json(self, content, doc_id=None):
+ if doc_id is None:
+ doc_id = 'D-%s' % (uuid.uuid4().hex,)
+ res, headers = self._request_json('PUT', ['doc', doc_id], {},
+ content, 'application/json')
+ new_doc = self._factory(doc_id, res['rev'], content)
+ return new_doc
+
+ def delete_doc(self, doc):
+ if doc.doc_id is None:
+ raise errors.InvalidDocId()
+ params = {'old_rev': doc.rev}
+ res, headers = self._request_json('DELETE',
+ ['doc', doc.doc_id], params)
+ doc.make_tombstone()
+ doc.rev = res['rev']
+
+ def get_sync_target(self):
+ st = http_target.HTTPSyncTarget(self._url.geturl())
+ st._creds = self._creds
+ return st
diff --git a/src/leap/soledad/u1db/remote/http_errors.py b/src/leap/soledad/u1db/remote/http_errors.py
new file mode 100644
index 00000000..2039c5b2
--- /dev/null
+++ b/src/leap/soledad/u1db/remote/http_errors.py
@@ -0,0 +1,46 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Information about the encoding of errors over HTTP."""
+
+from u1db import (
+ errors,
+ )
+
+
+# error wire descriptions mapping to HTTP status codes
+wire_description_to_status = dict([
+ (errors.InvalidDocId.wire_description, 400),
+ (errors.MissingDocIds.wire_description, 400),
+ (errors.Unauthorized.wire_description, 401),
+ (errors.DocumentTooBig.wire_description, 403),
+ (errors.UserQuotaExceeded.wire_description, 403),
+ (errors.SubscriptionNeeded.wire_description, 403),
+ (errors.DatabaseDoesNotExist.wire_description, 404),
+ (errors.DocumentDoesNotExist.wire_description, 404),
+ (errors.DocumentAlreadyDeleted.wire_description, 404),
+ (errors.RevisionConflict.wire_description, 409),
+ (errors.InvalidGeneration.wire_description, 409),
+ (errors.InvalidTransactionId.wire_description, 409),
+ (errors.Unavailable.wire_description, 503),
+# without matching exception
+ (errors.DOCUMENT_DELETED, 404)
+])
+
+
+ERROR_STATUSES = set(wire_description_to_status.values())
+# 400 included explicitly for tests
+ERROR_STATUSES.add(400)
diff --git a/src/leap/soledad/u1db/remote/http_target.py b/src/leap/soledad/u1db/remote/http_target.py
new file mode 100644
index 00000000..1028963e
--- /dev/null
+++ b/src/leap/soledad/u1db/remote/http_target.py
@@ -0,0 +1,135 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""SyncTarget API implementation to a remote HTTP server."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ Document,
+ SyncTarget,
+ )
+from u1db.errors import (
+ BrokenSyncStream,
+ )
+from u1db.remote import (
+ http_client,
+ utils,
+ )
+
+
+class HTTPSyncTarget(http_client.HTTPClientBase, SyncTarget):
+ """Implement the SyncTarget api to a remote HTTP server."""
+
+ @staticmethod
+ def connect(url):
+ return HTTPSyncTarget(url)
+
+ def get_sync_info(self, source_replica_uid):
+ self._ensure_connection()
+ res, _ = self._request_json('GET', ['sync-from', source_replica_uid])
+ return (res['target_replica_uid'], res['target_replica_generation'],
+ res['target_replica_transaction_id'],
+ res['source_replica_generation'], res['source_transaction_id'])
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_transaction_id):
+ self._ensure_connection()
+ if self._trace_hook: # for tests
+ self._trace_hook('record_sync_info')
+ self._request_json('PUT', ['sync-from', source_replica_uid], {},
+ {'generation': source_replica_generation,
+ 'transaction_id': source_transaction_id})
+
+ def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None):
+ parts = data.splitlines() # one at a time
+ if not parts or parts[0] != '[':
+ raise BrokenSyncStream
+ data = parts[1:-1]
+ comma = False
+ if data:
+ line, comma = utils.check_and_strip_comma(data[0])
+ res = json.loads(line)
+ if ensure_callback and 'replica_uid' in res:
+ ensure_callback(res['replica_uid'])
+ for entry in data[1:]:
+ if not comma: # missing in between comma
+ raise BrokenSyncStream
+ line, comma = utils.check_and_strip_comma(entry)
+ entry = json.loads(line)
+ doc = Document(entry['id'], entry['rev'], entry['content'])
+ return_doc_cb(doc, entry['gen'], entry['trans_id'])
+ if parts[-1] != ']':
+ try:
+ partdic = json.loads(parts[-1])
+ except ValueError:
+ pass
+ else:
+ if isinstance(partdic, dict):
+ self._error(partdic)
+ raise BrokenSyncStream
+ if not data or comma: # no entries or bad extra comma
+ raise BrokenSyncStream
+ return res
+
+ def sync_exchange(self, docs_by_generations, source_replica_uid,
+ last_known_generation, last_known_trans_id,
+ return_doc_cb, ensure_callback=None):
+ self._ensure_connection()
+ if self._trace_hook: # for tests
+ self._trace_hook('sync_exchange')
+ url = '%s/sync-from/%s' % (self._url.path, source_replica_uid)
+ self._conn.putrequest('POST', url)
+ self._conn.putheader('content-type', 'application/x-u1db-sync-stream')
+ for header_name, header_value in self._sign_request('POST', url, {}):
+ self._conn.putheader(header_name, header_value)
+ entries = ['[']
+ size = 1
+
+ def prepare(**dic):
+ entry = comma + '\r\n' + json.dumps(dic)
+ entries.append(entry)
+ return len(entry)
+
+ comma = ''
+ size += prepare(
+ last_known_generation=last_known_generation,
+ last_known_trans_id=last_known_trans_id,
+ ensure=ensure_callback is not None)
+ comma = ','
+ for doc, gen, trans_id in docs_by_generations:
+ size += prepare(id=doc.doc_id, rev=doc.rev, content=doc.get_json(),
+ gen=gen, trans_id=trans_id)
+ entries.append('\r\n]')
+ size += len(entries[-1])
+ self._conn.putheader('content-length', str(size))
+ self._conn.endheaders()
+ for entry in entries:
+ self._conn.send(entry)
+ entries = None
+ data, _ = self._response()
+ res = self._parse_sync_stream(data, return_doc_cb, ensure_callback)
+ data = None
+ return res['new_generation'], res['new_transaction_id']
+
+ # for tests
+ _trace_hook = None
+
+ def _set_trace_hook_shallow(self, cb):
+ self._trace_hook = cb
diff --git a/src/leap/soledad/u1db/remote/oauth_middleware.py b/src/leap/soledad/u1db/remote/oauth_middleware.py
new file mode 100644
index 00000000..5772580a
--- /dev/null
+++ b/src/leap/soledad/u1db/remote/oauth_middleware.py
@@ -0,0 +1,89 @@
+# Copyright 2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+"""U1DB OAuth authorisation WSGI middleware."""
+import httplib
+from oauth import oauth
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from urllib import quote
+from wsgiref.util import shift_path_info
+
+
+sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1()
+sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT()
+
+
+class OAuthMiddleware(object):
+ """U1DB OAuth Authorisation WSGI middleware."""
+
+ # max seconds the request timestamp is allowed to be shifted
+ # from arrival time
+ timestamp_threshold = 300
+
+ def __init__(self, app, base_url, prefix='/~/'):
+ self.app = app
+ self.base_url = base_url
+ self.prefix = prefix
+
+ def get_oauth_data_store(self):
+ """Provide a oauth.OAuthDataStore."""
+ raise NotImplementedError(self.get_oauth_data_store)
+
+ def _error(self, start_response, status, description, message=None):
+ start_response("%d %s" % (status, httplib.responses[status]),
+ [('content-type', 'application/json')])
+ err = {"error": description}
+ if message:
+ err['message'] = message
+ return [json.dumps(err)]
+
+ def __call__(self, environ, start_response):
+ if self.prefix and not environ['PATH_INFO'].startswith(self.prefix):
+ return self._error(start_response, 400, "bad request")
+ headers = {}
+ if 'HTTP_AUTHORIZATION' in environ:
+ headers['Authorization'] = environ['HTTP_AUTHORIZATION']
+ oauth_req = oauth.OAuthRequest.from_request(
+ http_method=environ['REQUEST_METHOD'],
+ http_url=self.base_url + environ['PATH_INFO'],
+ headers=headers,
+ query_string=environ['QUERY_STRING']
+ )
+ if oauth_req is None:
+ return self._error(start_response, 401, "unauthorized",
+ "Missing OAuth.")
+ try:
+ self.verify(environ, oauth_req)
+ except oauth.OAuthError, e:
+ return self._error(start_response, 401, "unauthorized",
+ e.message)
+ shift_path_info(environ)
+ return self.app(environ, start_response)
+
+ def verify(self, environ, oauth_req):
+ """Verify OAuth request, put user_id in the environ."""
+ oauth_server = oauth.OAuthServer(self.get_oauth_data_store())
+ oauth_server.timestamp_threshold = self.timestamp_threshold
+ oauth_server.add_signature_method(sign_meth_HMAC_SHA1)
+ oauth_server.add_signature_method(sign_meth_PLAINTEXT)
+ consumer, token, parameters = oauth_server.verify_request(oauth_req)
+ # filter out oauth bits
+ environ['QUERY_STRING'] = '&'.join("%s=%s" % (quote(k, safe=''),
+ quote(v, safe=''))
+ for k, v in parameters.iteritems())
+ return consumer, token
diff --git a/src/leap/soledad/u1db/remote/server_state.py b/src/leap/soledad/u1db/remote/server_state.py
new file mode 100644
index 00000000..96581359
--- /dev/null
+++ b/src/leap/soledad/u1db/remote/server_state.py
@@ -0,0 +1,67 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""State for servers exposing a set of U1DB databases."""
+import os
+import errno
+
+class ServerState(object):
+ """Passed to a Request when it is instantiated.
+
+ This is used to track server-side state, such as working-directory, open
+ databases, etc.
+ """
+
+ def __init__(self):
+ self._workingdir = None
+
+ def set_workingdir(self, path):
+ self._workingdir = path
+
+ def _relpath(self, relpath):
+ # Note: We don't want to allow absolute paths here, because we
+ # don't want to expose the filesystem. We should also check that
+ # relpath doesn't have '..' in it, etc.
+ return self._workingdir + '/' + relpath
+
+ def open_database(self, path):
+ """Open a database at the given location."""
+ from u1db.backends import sqlite_backend
+ full_path = self._relpath(path)
+ return sqlite_backend.SQLiteDatabase.open_database(full_path,
+ create=False)
+
+ def check_database(self, path):
+ """Check if the database at the given location exists.
+
+ Simply returns if it does or raises DatabaseDoesNotExist.
+ """
+ db = self.open_database(path)
+ db.close()
+
+ def ensure_database(self, path):
+ """Ensure database at the given location."""
+ from u1db.backends import sqlite_backend
+ full_path = self._relpath(path)
+ db = sqlite_backend.SQLiteDatabase.open_database(full_path,
+ create=True)
+ return db, db._replica_uid
+
+ def delete_database(self, path):
+ """Delete database at the given location."""
+ from u1db.backends import sqlite_backend
+ full_path = self._relpath(path)
+ sqlite_backend.SQLiteDatabase.delete_database(full_path)
diff --git a/src/leap/soledad/u1db/remote/ssl_match_hostname.py b/src/leap/soledad/u1db/remote/ssl_match_hostname.py
new file mode 100644
index 00000000..fbabc177
--- /dev/null
+++ b/src/leap/soledad/u1db/remote/ssl_match_hostname.py
@@ -0,0 +1,64 @@
+"""The match_hostname() function from Python 3.2, essential when using SSL."""
+# XXX put it here until it's packaged
+
+import re
+
+__version__ = '3.2a3'
+
+
+class CertificateError(ValueError):
+ pass
+
+
+def _dnsname_to_pat(dn):
+ pats = []
+ for frag in dn.split(r'.'):
+ if frag == '*':
+ # When '*' is a fragment by itself, it matches a non-empty dotless
+ # fragment.
+ pats.append('[^.]+')
+ else:
+ # Otherwise, '*' matches any dotless fragment.
+ frag = re.escape(frag)
+ pats.append(frag.replace(r'\*', '[^.]*'))
+ return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
+
+
+def match_hostname(cert, hostname):
+ """Verify that *cert* (in decoded format as returned by
+ SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules
+ are mostly followed, but IP addresses are not accepted for *hostname*.
+
+ CertificateError is raised on failure. On success, the function
+ returns nothing.
+ """
+ if not cert:
+ raise ValueError("empty or no certificate")
+ dnsnames = []
+ san = cert.get('subjectAltName', ())
+ for key, value in san:
+ if key == 'DNS':
+ if _dnsname_to_pat(value).match(hostname):
+ return
+ dnsnames.append(value)
+ if not san:
+ # The subject is only checked when subjectAltName is empty
+ for sub in cert.get('subject', ()):
+ for key, value in sub:
+ # XXX according to RFC 2818, the most specific Common Name
+ # must be used.
+ if key == 'commonName':
+ if _dnsname_to_pat(value).match(hostname):
+ return
+ dnsnames.append(value)
+ if len(dnsnames) > 1:
+ raise CertificateError("hostname %r "
+ "doesn't match either of %s"
+ % (hostname, ', '.join(map(repr, dnsnames))))
+ elif len(dnsnames) == 1:
+ raise CertificateError("hostname %r "
+ "doesn't match %r"
+ % (hostname, dnsnames[0]))
+ else:
+ raise CertificateError("no appropriate commonName or "
+ "subjectAltName fields were found")
diff --git a/src/leap/soledad/u1db/remote/utils.py b/src/leap/soledad/u1db/remote/utils.py
new file mode 100644
index 00000000..14cedea9
--- /dev/null
+++ b/src/leap/soledad/u1db/remote/utils.py
@@ -0,0 +1,23 @@
+# Copyright 2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Utilities for details of the procotol."""
+
+
+def check_and_strip_comma(line):
+ if line and line[-1] == ',':
+ return line[:-1], True
+ return line, False
diff --git a/src/leap/soledad/u1db/sync.py b/src/leap/soledad/u1db/sync.py
new file mode 100644
index 00000000..3375d097
--- /dev/null
+++ b/src/leap/soledad/u1db/sync.py
@@ -0,0 +1,304 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The synchronization utilities for U1DB."""
+from itertools import izip
+
+import u1db
+from u1db import errors
+
+
+class Synchronizer(object):
+ """Collect the state around synchronizing 2 U1DB replicas.
+
+ Synchronization is bi-directional, in that new items in the source are sent
+ to the target, and new items in the target are returned to the source.
+ However, it still recognizes that one side is initiating the request. Also,
+ at the moment, conflicts are only created in the source.
+ """
+
+ def __init__(self, source, sync_target):
+ """Create a new Synchronization object.
+
+ :param source: A Database
+ :param sync_target: A SyncTarget
+ """
+ self.source = source
+ self.sync_target = sync_target
+ self.target_replica_uid = None
+ self.num_inserted = 0
+
+ def _insert_doc_from_target(self, doc, replica_gen, trans_id):
+ """Try to insert synced document from target.
+
+ Implements TAKE OTHER semantics: any document from the target
+ that is in conflict will be taken as the new official value,
+ while the current conflicting value will be stored alongside
+ as a conflict. In the process indexes will be updated etc.
+
+ :return: None
+ """
+ # Increases self.num_inserted depending whether the document
+ # was effectively inserted.
+ state, _ = self.source._put_doc_if_newer(doc, save_conflict=True,
+ replica_uid=self.target_replica_uid, replica_gen=replica_gen,
+ replica_trans_id=trans_id)
+ if state == 'inserted':
+ self.num_inserted += 1
+ elif state == 'converged':
+ # magical convergence
+ pass
+ elif state == 'superseded':
+ # we have something newer, will be taken care of at the next sync
+ pass
+ else:
+ assert state == 'conflicted'
+ # The doc was saved as a conflict, so the database was updated
+ self.num_inserted += 1
+
+ def _record_sync_info_with_the_target(self, start_generation):
+ """Record our new after sync generation with the target if gapless.
+
+ Any documents received from the target will cause the local
+ database to increment its generation. We do not want to send
+ them back to the target in a future sync. However, there could
+ also be concurrent updates from another process doing eg
+ 'put_doc' while the sync was running. And we do want to
+ synchronize those documents. We can tell if there was a
+ concurrent update by comparing our new generation number
+ versus the generation we started, and how many documents we
+ inserted from the target. If it matches exactly, then we can
+ record with the target that they are fully up to date with our
+ new generation.
+ """
+ cur_gen, trans_id = self.source._get_generation_info()
+ if (cur_gen == start_generation + self.num_inserted
+ and self.num_inserted > 0):
+ self.sync_target.record_sync_info(
+ self.source._replica_uid, cur_gen, trans_id)
+
+ def sync(self, callback=None, autocreate=False):
+ """Synchronize documents between source and target."""
+ sync_target = self.sync_target
+ # get target identifier, its current generation,
+ # and its last-seen database generation for this source
+ try:
+ (self.target_replica_uid, target_gen, target_trans_id,
+ target_my_gen, target_my_trans_id) = sync_target.get_sync_info(
+ self.source._replica_uid)
+ except errors.DatabaseDoesNotExist:
+ if not autocreate:
+ raise
+ # will try to ask sync_exchange() to create the db
+ self.target_replica_uid = None
+ target_gen, target_trans_id = 0, ''
+ target_my_gen, target_my_trans_id = 0, ''
+ def ensure_callback(replica_uid):
+ self.target_replica_uid = replica_uid
+ else:
+ ensure_callback = None
+ # validate the generation and transaction id the target knows about us
+ self.source.validate_gen_and_trans_id(
+ target_my_gen, target_my_trans_id)
+ # what's changed since that generation and this current gen
+ my_gen, _, changes = self.source.whats_changed(target_my_gen)
+
+ # this source last-seen database generation for the target
+ if self.target_replica_uid is None:
+ target_last_known_gen, target_last_known_trans_id = 0, ''
+ else:
+ target_last_known_gen, target_last_known_trans_id = \
+ self.source._get_replica_gen_and_trans_id(self.target_replica_uid)
+ if not changes and target_last_known_gen == target_gen:
+ if target_trans_id != target_last_known_trans_id:
+ raise errors.InvalidTransactionId
+ return my_gen
+ changed_doc_ids = [doc_id for doc_id, _, _ in changes]
+ # prepare to send all the changed docs
+ docs_to_send = self.source.get_docs(changed_doc_ids,
+ check_for_conflicts=False, include_deleted=True)
+ # TODO: there must be a way to not iterate twice
+ docs_by_generation = zip(
+ docs_to_send, (gen for _, gen, _ in changes),
+ (trans for _, _, trans in changes))
+
+ # exchange documents and try to insert the returned ones with
+ # the target, return target synced-up-to gen
+ new_gen, new_trans_id = sync_target.sync_exchange(
+ docs_by_generation, self.source._replica_uid,
+ target_last_known_gen, target_last_known_trans_id,
+ self._insert_doc_from_target, ensure_callback=ensure_callback)
+ # record target synced-up-to generation including applying what we sent
+ self.source._set_replica_gen_and_trans_id(
+ self.target_replica_uid, new_gen, new_trans_id)
+
+ # if gapless record current reached generation with target
+ self._record_sync_info_with_the_target(my_gen)
+
+ return my_gen
+
+
+class SyncExchange(object):
+ """Steps and state for carrying through a sync exchange on a target."""
+
+ def __init__(self, db, source_replica_uid, last_known_generation):
+ self._db = db
+ self.source_replica_uid = source_replica_uid
+ self.source_last_known_generation = last_known_generation
+ self.seen_ids = {} # incoming ids not superseded
+ self.changes_to_return = None
+ self.new_gen = None
+ self.new_trans_id = None
+ # for tests
+ self._incoming_trace = []
+ self._trace_hook = None
+ self._db._last_exchange_log = {
+ 'receive': {'docs': self._incoming_trace},
+ 'return': None
+ }
+
+ def _set_trace_hook(self, cb):
+ self._trace_hook = cb
+
+ def _trace(self, state):
+ if not self._trace_hook:
+ return
+ self._trace_hook(state)
+
+ def insert_doc_from_source(self, doc, source_gen, trans_id):
+ """Try to insert synced document from source.
+
+ Conflicting documents are not inserted but will be sent over
+ to the sync source.
+
+ It keeps track of progress by storing the document source
+ generation as well.
+
+ The 1st step of a sync exchange is to call this repeatedly to
+ try insert all incoming documents from the source.
+
+ :param doc: A Document object.
+ :param source_gen: The source generation of doc.
+ :return: None
+ """
+ state, at_gen = self._db._put_doc_if_newer(doc, save_conflict=False,
+ replica_uid=self.source_replica_uid, replica_gen=source_gen,
+ replica_trans_id=trans_id)
+ if state == 'inserted':
+ self.seen_ids[doc.doc_id] = at_gen
+ elif state == 'converged':
+ # magical convergence
+ self.seen_ids[doc.doc_id] = at_gen
+ elif state == 'superseded':
+ # we have something newer that we will return
+ pass
+ else:
+ # conflict that we will returne
+ assert state == 'conflicted'
+ # for tests
+ self._incoming_trace.append((doc.doc_id, doc.rev))
+ self._db._last_exchange_log['receive'].update({
+ 'source_uid': self.source_replica_uid,
+ 'source_gen': source_gen
+ })
+
+ def find_changes_to_return(self):
+ """Find changes to return.
+
+ Find changes since last_known_generation in db generation
+ order using whats_changed. It excludes documents ids that have
+ already been considered (superseded by the sender, etc).
+
+ :return: new_generation - the generation of this database
+ which the caller can consider themselves to be synchronized after
+ processing the returned documents.
+ """
+ self._db._last_exchange_log['receive'].update({ # for tests
+ 'last_known_gen': self.source_last_known_generation
+ })
+ self._trace('before whats_changed')
+ gen, trans_id, changes = self._db.whats_changed(
+ self.source_last_known_generation)
+ self._trace('after whats_changed')
+ self.new_gen = gen
+ self.new_trans_id = trans_id
+ seen_ids = self.seen_ids
+ # changed docs that weren't superseded by or converged with
+ self.changes_to_return = [
+ (doc_id, gen, trans_id) for (doc_id, gen, trans_id) in changes
+ # there was a subsequent update
+ if doc_id not in seen_ids or seen_ids.get(doc_id) < gen]
+ return self.new_gen
+
+ def return_docs(self, return_doc_cb):
+ """Return the changed documents and their last change generation
+ repeatedly invoking the callback return_doc_cb.
+
+ The final step of a sync exchange.
+
+ :param: return_doc_cb(doc, gen, trans_id): is a callback
+ used to return the documents with their last change generation
+ to the target replica.
+ :return: None
+ """
+ changes_to_return = self.changes_to_return
+ # return docs, including conflicts
+ changed_doc_ids = [doc_id for doc_id, _, _ in changes_to_return]
+ self._trace('before get_docs')
+ docs = self._db.get_docs(
+ changed_doc_ids, check_for_conflicts=False, include_deleted=True)
+
+ docs_by_gen = izip(
+ docs, (gen for _, gen, _ in changes_to_return),
+ (trans_id for _, _, trans_id in changes_to_return))
+ _outgoing_trace = [] # for tests
+ for doc, gen, trans_id in docs_by_gen:
+ return_doc_cb(doc, gen, trans_id)
+ _outgoing_trace.append((doc.doc_id, doc.rev))
+ # for tests
+ self._db._last_exchange_log['return'] = {
+ 'docs': _outgoing_trace,
+ 'last_gen': self.new_gen
+ }
+
+
+class LocalSyncTarget(u1db.SyncTarget):
+ """Common sync target implementation logic for all local sync targets."""
+
+ def __init__(self, db):
+ self._db = db
+ self._trace_hook = None
+
+ def sync_exchange(self, docs_by_generations, source_replica_uid,
+ last_known_generation, last_known_trans_id,
+ return_doc_cb, ensure_callback=None):
+ self._db.validate_gen_and_trans_id(
+ last_known_generation, last_known_trans_id)
+ sync_exch = SyncExchange(
+ self._db, source_replica_uid, last_known_generation)
+ if self._trace_hook:
+ sync_exch._set_trace_hook(self._trace_hook)
+ # 1st step: try to insert incoming docs and record progress
+ for doc, doc_gen, trans_id in docs_by_generations:
+ sync_exch.insert_doc_from_source(doc, doc_gen, trans_id)
+ # 2nd step: find changed documents (including conflicts) to return
+ new_gen = sync_exch.find_changes_to_return()
+ # final step: return docs and record source replica sync point
+ sync_exch.return_docs(return_doc_cb)
+ return new_gen, sync_exch.new_trans_id
+
+ def _set_trace_hook(self, cb):
+ self._trace_hook = cb
diff --git a/src/leap/soledad/u1db/tests/__init__.py b/src/leap/soledad/u1db/tests/__init__.py
new file mode 100644
index 00000000..b8e16b15
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/__init__.py
@@ -0,0 +1,463 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test infrastructure for U1DB"""
+
+import copy
+import shutil
+import socket
+import tempfile
+import threading
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from wsgiref import simple_server
+
+from oauth import oauth
+from sqlite3 import dbapi2
+from StringIO import StringIO
+
+import testscenarios
+import testtools
+
+from u1db import (
+ errors,
+ Document,
+ )
+from u1db.backends import (
+ inmemory,
+ sqlite_backend,
+ )
+from u1db.remote import (
+ server_state,
+ )
+
+try:
+ from u1db.tests import c_backend_wrapper
+ c_backend_error = None
+except ImportError, e:
+ c_backend_wrapper = None # noqa
+ c_backend_error = e
+
+# Setting this means that failing assertions will not include this module in
+# their traceback. However testtools doesn't seem to set it, and we don't want
+# this level to be omitted, but the lower levels to be shown.
+# __unittest = 1
+
+
+class TestCase(testtools.TestCase):
+
+ def createTempDir(self, prefix='u1db-tmp-'):
+ """Create a temporary directory to do some work in.
+
+ This directory will be scheduled for cleanup when the test ends.
+ """
+ tempdir = tempfile.mkdtemp(prefix=prefix)
+ self.addCleanup(shutil.rmtree, tempdir)
+ return tempdir
+
+ def make_document(self, doc_id, doc_rev, content, has_conflicts=False):
+ return self.make_document_for_test(
+ self, doc_id, doc_rev, content, has_conflicts)
+
+ def make_document_for_test(self, test, doc_id, doc_rev, content,
+ has_conflicts):
+ return make_document_for_test(
+ test, doc_id, doc_rev, content, has_conflicts)
+
+ def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts):
+ """Assert that the document in the database looks correct."""
+ exp_doc = self.make_document(doc_id, doc_rev, content,
+ has_conflicts=has_conflicts)
+ self.assertEqual(exp_doc, db.get_doc(doc_id))
+
+ def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content,
+ has_conflicts):
+ """Assert that the document in the database looks correct."""
+ exp_doc = self.make_document(doc_id, doc_rev, content,
+ has_conflicts=has_conflicts)
+ self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True))
+
+ def assertGetDocConflicts(self, db, doc_id, conflicts):
+ """Assert what conflicts are stored for a given doc_id.
+
+ :param conflicts: A list of (doc_rev, content) pairs.
+ The first item must match the first item returned from the
+ database, however the rest can be returned in any order.
+ """
+ if conflicts:
+ conflicts = [(rev, (json.loads(cont) if isinstance(cont, basestring)
+ else cont)) for (rev, cont) in conflicts]
+ conflicts = conflicts[:1] + sorted(conflicts[1:])
+ actual = db.get_doc_conflicts(doc_id)
+ if actual:
+ actual = [(doc.rev, (json.loads(doc.get_json())
+ if doc.get_json() is not None else None)) for doc in actual]
+ actual = actual[:1] + sorted(actual[1:])
+ self.assertEqual(conflicts, actual)
+
+
+def multiply_scenarios(a_scenarios, b_scenarios):
+ """Create the cross-product of scenarios."""
+
+ all_scenarios = []
+ for a_name, a_attrs in a_scenarios:
+ for b_name, b_attrs in b_scenarios:
+ name = '%s,%s' % (a_name, b_name)
+ attrs = dict(a_attrs)
+ attrs.update(b_attrs)
+ all_scenarios.append((name, attrs))
+ return all_scenarios
+
+
+simple_doc = '{"key": "value"}'
+nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}'
+
+
+def make_memory_database_for_test(test, replica_uid):
+ return inmemory.InMemoryDatabase(replica_uid)
+
+
+def copy_memory_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ new_db = inmemory.InMemoryDatabase(db._replica_uid)
+ new_db._transaction_log = db._transaction_log[:]
+ new_db._docs = copy.deepcopy(db._docs)
+ new_db._conflicts = copy.deepcopy(db._conflicts)
+ new_db._indexes = copy.deepcopy(db._indexes)
+ new_db._factory = db._factory
+ return new_db
+
+
+def make_sqlite_partial_expanded_for_test(test, replica_uid):
+ db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ db._set_replica_uid(replica_uid)
+ return db
+
+
+def copy_sqlite_partial_expanded_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ new_db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ tmpfile = StringIO()
+ for line in db._db_handle.iterdump():
+ if not 'sqlite_sequence' in line: # work around bug in iterdump
+ tmpfile.write('%s\n' % line)
+ tmpfile.seek(0)
+ new_db._db_handle = dbapi2.connect(':memory:')
+ new_db._db_handle.cursor().executescript(tmpfile.read())
+ new_db._db_handle.commit()
+ new_db._set_replica_uid(db._replica_uid)
+ new_db._factory = db._factory
+ return new_db
+
+
+def make_document_for_test(test, doc_id, rev, content, has_conflicts=False):
+ return Document(doc_id, rev, content, has_conflicts=has_conflicts)
+
+
+def make_c_database_for_test(test, replica_uid):
+ if c_backend_wrapper is None:
+ test.skipTest('c_backend_wrapper is not available')
+ db = c_backend_wrapper.CDatabase(':memory:')
+ db._set_replica_uid(replica_uid)
+ return db
+
+
+def copy_c_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ if c_backend_wrapper is None:
+ test.skipTest('c_backend_wrapper is not available')
+ new_db = db._copy(db)
+ return new_db
+
+
+def make_c_document_for_test(test, doc_id, rev, content, has_conflicts=False):
+ if c_backend_wrapper is None:
+ test.skipTest('c_backend_wrapper is not available')
+ return c_backend_wrapper.make_document(
+ doc_id, rev, content, has_conflicts=has_conflicts)
+
+
+LOCAL_DATABASES_SCENARIOS = [
+ ('mem', {'make_database_for_test': make_memory_database_for_test,
+ 'copy_database_for_test': copy_memory_database_for_test,
+ 'make_document_for_test': make_document_for_test}),
+ ('sql', {'make_database_for_test':
+ make_sqlite_partial_expanded_for_test,
+ 'copy_database_for_test':
+ copy_sqlite_partial_expanded_for_test,
+ 'make_document_for_test': make_document_for_test}),
+ ]
+
+
+C_DATABASE_SCENARIOS = [
+ ('c', {'make_database_for_test': make_c_database_for_test,
+ 'copy_database_for_test': copy_c_database_for_test,
+ 'make_document_for_test': make_c_document_for_test})]
+
+
+class DatabaseBaseTests(TestCase):
+
+ accept_fixed_trans_id = False # set to True assertTransactionLog
+ # is happy with all trans ids = ''
+
+ scenarios = LOCAL_DATABASES_SCENARIOS
+
+ def create_database(self, replica_uid):
+ return self.make_database_for_test(self, replica_uid)
+
+ def copy_database(self, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
+ # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
+ # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
+ # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
+ # NINJA TO YOUR HOUSE.
+ return self.copy_database_for_test(self, db)
+
+ def setUp(self):
+ super(DatabaseBaseTests, self).setUp()
+ self.db = self.create_database('test')
+
+ def tearDown(self):
+ # TODO: Add close_database parameterization
+ # self.close_database(self.db)
+ super(DatabaseBaseTests, self).tearDown()
+
+ def assertTransactionLog(self, doc_ids, db):
+ """Assert that the given docs are in the transaction log."""
+ log = db._get_transaction_log()
+ just_ids = []
+ seen_transactions = set()
+ for doc_id, transaction_id in log:
+ just_ids.append(doc_id)
+ self.assertIsNot(None, transaction_id,
+ "Transaction id should not be None")
+ if transaction_id == '' and self.accept_fixed_trans_id:
+ continue
+ self.assertNotEqual('', transaction_id,
+ "Transaction id should be a unique string")
+ self.assertTrue(transaction_id.startswith('T-'))
+ self.assertNotIn(transaction_id, seen_transactions)
+ seen_transactions.add(transaction_id)
+ self.assertEqual(doc_ids, just_ids)
+
+ def getLastTransId(self, db):
+ """Return the transaction id for the last database update."""
+ return self.db._get_transaction_log()[-1][-1]
+
+
+class ServerStateForTests(server_state.ServerState):
+ """Used in the test suite, so we don't have to touch disk, etc."""
+
+ def __init__(self):
+ super(ServerStateForTests, self).__init__()
+ self._dbs = {}
+
+ def open_database(self, path):
+ try:
+ return self._dbs[path]
+ except KeyError:
+ raise errors.DatabaseDoesNotExist
+
+ def check_database(self, path):
+ # cares only about the possible exception
+ self.open_database(path)
+
+ def ensure_database(self, path):
+ try:
+ db = self.open_database(path)
+ except errors.DatabaseDoesNotExist:
+ db = self._create_database(path)
+ return db, db._replica_uid
+
+ def _copy_database(self, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
+ # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
+ # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
+ # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
+ # NINJA TO YOUR HOUSE.
+ new_db = copy_memory_database_for_test(None, db)
+ path = db._replica_uid
+ while path in self._dbs:
+ path += 'copy'
+ self._dbs[path] = new_db
+ return new_db
+
+ def _create_database(self, path):
+ db = inmemory.InMemoryDatabase(path)
+ self._dbs[path] = db
+ return db
+
+ def delete_database(self, path):
+ del self._dbs[path]
+
+
+class ResponderForTests(object):
+ """Responder for tests."""
+ _started = False
+ sent_response = False
+ status = None
+
+ def start_response(self, status='success', **kwargs):
+ self._started = True
+ self.status = status
+ self.kwargs = kwargs
+
+ def send_response(self, status='success', **kwargs):
+ self.start_response(status, **kwargs)
+ self.finish_response()
+
+ def finish_response(self):
+ self.sent_response = True
+
+
+class TestCaseWithServer(TestCase):
+
+ @staticmethod
+ def server_def():
+ # hook point
+ # should return (ServerClass, "shutdown method name", "url_scheme")
+ class _RequestHandler(simple_server.WSGIRequestHandler):
+ def log_request(*args):
+ pass # suppress
+
+ def make_server(host_port, application):
+ assert application, "forgot to override make_app(_with_state)?"
+ srv = simple_server.WSGIServer(host_port, _RequestHandler)
+ # patch the value in if it's None
+ if getattr(application, 'base_url', 1) is None:
+ application.base_url = "http://%s:%s" % srv.server_address
+ srv.set_app(application)
+ return srv
+
+ return make_server, "shutdown", "http"
+
+ @staticmethod
+ def make_app_with_state(state):
+ # hook point
+ return None
+
+ def make_app(self):
+ # potential hook point
+ self.request_state = ServerStateForTests()
+ return self.make_app_with_state(self.request_state)
+
+ def setUp(self):
+ super(TestCaseWithServer, self).setUp()
+ self.server = self.server_thread = None
+
+ @property
+ def url_scheme(self):
+ return self.server_def()[-1]
+
+ def startServer(self):
+ server_def = self.server_def()
+ server_class, shutdown_meth, _ = server_def
+ application = self.make_app()
+ self.server = server_class(('127.0.0.1', 0), application)
+ self.server_thread = threading.Thread(target=self.server.serve_forever,
+ kwargs=dict(poll_interval=0.01))
+ self.server_thread.start()
+ self.addCleanup(self.server_thread.join)
+ self.addCleanup(getattr(self.server, shutdown_meth))
+
+ def getURL(self, path=None):
+ host, port = self.server.server_address
+ if path is None:
+ path = ''
+ return '%s://%s:%s/%s' % (self.url_scheme, host, port, path)
+
+
+def socket_pair():
+ """Return a pair of TCP sockets connected to each other.
+
+ Unlike socket.socketpair, this should work on Windows.
+ """
+ sock_pair = getattr(socket, 'socket_pair', None)
+ if sock_pair:
+ return sock_pair(socket.AF_INET, socket.SOCK_STREAM)
+ listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ listen_sock.bind(('127.0.0.1', 0))
+ listen_sock.listen(1)
+ client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ client_sock.connect(listen_sock.getsockname())
+ server_sock, addr = listen_sock.accept()
+ listen_sock.close()
+ return server_sock, client_sock
+
+
+# OAuth related testing
+
+consumer1 = oauth.OAuthConsumer('K1', 'S1')
+token1 = oauth.OAuthToken('kkkk1', 'XYZ')
+consumer2 = oauth.OAuthConsumer('K2', 'S2')
+token2 = oauth.OAuthToken('kkkk2', 'ZYX')
+token3 = oauth.OAuthToken('kkkk3', 'ZYX')
+
+
+class TestingOAuthDataStore(oauth.OAuthDataStore):
+ """In memory predefined OAuthDataStore for testing."""
+
+ consumers = {
+ consumer1.key: consumer1,
+ consumer2.key: consumer2,
+ }
+
+ tokens = {
+ token1.key: token1,
+ token2.key: token2
+ }
+
+ def lookup_consumer(self, key):
+ return self.consumers.get(key)
+
+ def lookup_token(self, token_type, token_token):
+ return self.tokens.get(token_token)
+
+ def lookup_nonce(self, oauth_consumer, oauth_token, nonce):
+ return None
+
+testingOAuthStore = TestingOAuthDataStore()
+
+sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1()
+sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT()
+
+
+def load_with_scenarios(loader, standard_tests, pattern):
+ """Load the tests in a given module.
+
+ This just applies testscenarios.generate_scenarios to all the tests that
+ are present. We do it at load time rather than at run time, because it
+ plays nicer with various tools.
+ """
+ suite = loader.suiteClass()
+ suite.addTests(testscenarios.generate_scenarios(standard_tests))
+ return suite
diff --git a/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx b/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx
new file mode 100644
index 00000000..8a4b600d
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/c_backend_wrapper.pyx
@@ -0,0 +1,1541 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+#
+"""A Cython wrapper around the C implementation of U1DB Database backend."""
+
+cdef extern from "Python.h":
+ object PyString_FromStringAndSize(char *s, Py_ssize_t n)
+ int PyString_AsStringAndSize(object o, char **buf, Py_ssize_t *length
+ ) except -1
+ char *PyString_AsString(object) except NULL
+ char *PyString_AS_STRING(object)
+ char *strdup(char *)
+ void *calloc(size_t, size_t)
+ void free(void *)
+ ctypedef struct FILE:
+ pass
+ fprintf(FILE *, char *, ...)
+ FILE *stderr
+ size_t strlen(char *)
+
+cdef extern from "stdarg.h":
+ ctypedef struct va_list:
+ pass
+ void va_start(va_list, void*)
+ void va_start_int "va_start" (va_list, int)
+ void va_end(va_list)
+
+cdef extern from "u1db/u1db.h":
+ ctypedef struct u1database:
+ pass
+ ctypedef struct u1db_document:
+ char *doc_id
+ size_t doc_id_len
+ char *doc_rev
+ size_t doc_rev_len
+ char *json
+ size_t json_len
+ int has_conflicts
+ # Note: u1query is actually defined in u1db_internal.h, and in u1db.h it is
+ # just an opaque pointer. However, older versions of Cython don't let
+ # you have a forward declaration and a full declaration, so we just
+ # expose the whole thing here.
+ ctypedef struct u1query:
+ char *index_name
+ int num_fields
+ char **fields
+ cdef struct u1db_oauth_creds:
+ int auth_kind
+ char *consumer_key
+ char *consumer_secret
+ char *token_key
+ char *token_secret
+ ctypedef union u1db_creds
+ ctypedef u1db_creds* const_u1db_creds_ptr "const u1db_creds *"
+
+ ctypedef char* const_char_ptr "const char*"
+ ctypedef int (*u1db_doc_callback)(void *context, u1db_document *doc)
+ ctypedef int (*u1db_key_callback)(void *context, int num_fields,
+ const_char_ptr *key)
+ ctypedef int (*u1db_doc_gen_callback)(void *context,
+ u1db_document *doc, int gen, const_char_ptr trans_id)
+ ctypedef int (*u1db_trans_info_callback)(void *context,
+ const_char_ptr doc_id, int gen, const_char_ptr trans_id)
+
+ u1database * u1db_open(char *fname)
+ void u1db_free(u1database **)
+ int u1db_set_replica_uid(u1database *, char *replica_uid)
+ int u1db_set_document_size_limit(u1database *, int limit)
+ int u1db_get_replica_uid(u1database *, const_char_ptr *replica_uid)
+ int u1db_create_doc_from_json(u1database *db, char *json, char *doc_id,
+ u1db_document **doc)
+ int u1db_delete_doc(u1database *db, u1db_document *doc)
+ int u1db_get_doc(u1database *db, char *doc_id, int include_deleted,
+ u1db_document **doc)
+ int u1db_get_docs(u1database *db, int n_doc_ids, const_char_ptr *doc_ids,
+ int check_for_conflicts, int include_deleted,
+ void *context, u1db_doc_callback cb)
+ int u1db_get_all_docs(u1database *db, int include_deleted, int *generation,
+ void *context, u1db_doc_callback cb)
+ int u1db_put_doc(u1database *db, u1db_document *doc)
+ int u1db__validate_source(u1database *db, const_char_ptr replica_uid,
+ int replica_gen, const_char_ptr replica_trans_id)
+ int u1db__put_doc_if_newer(u1database *db, u1db_document *doc,
+ int save_conflict, char *replica_uid,
+ int replica_gen, char *replica_trans_id,
+ int *state, int *at_gen)
+ int u1db_resolve_doc(u1database *db, u1db_document *doc,
+ int n_revs, const_char_ptr *revs)
+ int u1db_delete_doc(u1database *db, u1db_document *doc)
+ int u1db_whats_changed(u1database *db, int *gen, char **trans_id,
+ void *context, u1db_trans_info_callback cb)
+ int u1db__get_transaction_log(u1database *db, void *context,
+ u1db_trans_info_callback cb)
+ int u1db_get_doc_conflicts(u1database *db, char *doc_id, void *context,
+ u1db_doc_callback cb)
+ int u1db_sync(u1database *db, const_char_ptr url,
+ const_u1db_creds_ptr creds, int *local_gen) nogil
+ int u1db_create_index_list(u1database *db, char *index_name,
+ int n_expressions, const_char_ptr *expressions)
+ int u1db_create_index(u1database *db, char *index_name, int n_expressions,
+ ...)
+ int u1db_get_from_index_list(u1database *db, u1query *query, void *context,
+ u1db_doc_callback cb, int n_values,
+ const_char_ptr *values)
+ int u1db_get_from_index(u1database *db, u1query *query, void *context,
+ u1db_doc_callback cb, int n_values, char *val0,
+ ...)
+ int u1db_get_range_from_index(u1database *db, u1query *query,
+ void *context, u1db_doc_callback cb,
+ int n_values, const_char_ptr *start_values,
+ const_char_ptr *end_values)
+ int u1db_delete_index(u1database *db, char *index_name)
+ int u1db_list_indexes(u1database *db, void *context,
+ int (*cb)(void *context, const_char_ptr index_name,
+ int n_expressions, const_char_ptr *expressions))
+ int u1db_get_index_keys(u1database *db, char *index_name, void *context,
+ u1db_key_callback cb)
+ int u1db_simple_lookup1(u1database *db, char *index_name, char *val1,
+ void *context, u1db_doc_callback cb)
+ int u1db_query_init(u1database *db, char *index_name, u1query **query)
+ void u1db_free_query(u1query **query)
+
+ int U1DB_OK
+ int U1DB_INVALID_PARAMETER
+ int U1DB_REVISION_CONFLICT
+ int U1DB_INVALID_DOC_ID
+ int U1DB_DOCUMENT_ALREADY_DELETED
+ int U1DB_DOCUMENT_DOES_NOT_EXIST
+ int U1DB_NOT_IMPLEMENTED
+ int U1DB_INVALID_JSON
+ int U1DB_DOCUMENT_TOO_BIG
+ int U1DB_USER_QUOTA_EXCEEDED
+ int U1DB_INVALID_VALUE_FOR_INDEX
+ int U1DB_INVALID_FIELD_SPECIFIER
+ int U1DB_INVALID_GLOBBING
+ int U1DB_BROKEN_SYNC_STREAM
+ int U1DB_DUPLICATE_INDEX_NAME
+ int U1DB_INDEX_DOES_NOT_EXIST
+ int U1DB_INVALID_GENERATION
+ int U1DB_INVALID_TRANSACTION_ID
+ int U1DB_INVALID_TRANSFORMATION_FUNCTION
+ int U1DB_UNKNOWN_OPERATION
+ int U1DB_INTERNAL_ERROR
+ int U1DB_TARGET_UNAVAILABLE
+
+ int U1DB_INSERTED
+ int U1DB_SUPERSEDED
+ int U1DB_CONVERGED
+ int U1DB_CONFLICTED
+
+ int U1DB_OAUTH_AUTH
+
+ void u1db_free_doc(u1db_document **doc)
+ int u1db_doc_set_json(u1db_document *doc, char *json)
+ int u1db_doc_get_size(u1db_document *doc)
+
+
+cdef extern from "u1db/u1db_internal.h":
+ ctypedef struct u1db_row:
+ u1db_row *next
+ int num_columns
+ int *column_sizes
+ unsigned char **columns
+
+ ctypedef struct u1db_table:
+ int status
+ u1db_row *first_row
+
+ ctypedef struct u1db_record:
+ u1db_record *next
+ char *doc_id
+ char *doc_rev
+ char *doc
+
+ ctypedef struct u1db_sync_exchange:
+ int target_gen
+ int num_doc_ids
+ char **doc_ids_to_return
+ int *gen_for_doc_ids
+ const_char_ptr *trans_ids_for_doc_ids
+
+ ctypedef int (*u1db__trace_callback)(void *context, const_char_ptr state)
+ ctypedef struct u1db_sync_target:
+ int (*get_sync_info)(u1db_sync_target *st, char *source_replica_uid,
+ const_char_ptr *st_replica_uid, int *st_gen,
+ char **st_trans_id, int *source_gen,
+ char **source_trans_id) nogil
+ int (*record_sync_info)(u1db_sync_target *st,
+ char *source_replica_uid, int source_gen, char *trans_id) nogil
+ int (*sync_exchange)(u1db_sync_target *st,
+ char *source_replica_uid, int n_docs,
+ u1db_document **docs, int *generations,
+ const_char_ptr *trans_ids,
+ int *target_gen, char **target_trans_id,
+ void *context, u1db_doc_gen_callback cb,
+ void *ensure_callback) nogil
+ int (*sync_exchange_doc_ids)(u1db_sync_target *st,
+ u1database *source_db, int n_doc_ids,
+ const_char_ptr *doc_ids, int *generations,
+ const_char_ptr *trans_ids,
+ int *target_gen, char **target_trans_id,
+ void *context,
+ u1db_doc_gen_callback cb,
+ void *ensure_callback) nogil
+ int (*get_sync_exchange)(u1db_sync_target *st,
+ char *source_replica_uid,
+ int last_known_source_gen,
+ u1db_sync_exchange **exchange) nogil
+ void (*finalize_sync_exchange)(u1db_sync_target *st,
+ u1db_sync_exchange **exchange) nogil
+ int (*_set_trace_hook)(u1db_sync_target *st,
+ void *context, u1db__trace_callback cb) nogil
+
+
+ void u1db__set_zero_delays()
+ int u1db__get_generation(u1database *, int *db_rev)
+ int u1db__get_document_size_limit(u1database *, int *limit)
+ int u1db__get_generation_info(u1database *, int *db_rev, char **trans_id)
+ int u1db__get_trans_id_for_gen(u1database *, int db_rev, char **trans_id)
+ int u1db_validate_gen_and_trans_id(u1database *, int db_rev,
+ const_char_ptr trans_id)
+ char *u1db__allocate_doc_id(u1database *)
+ int u1db__sql_close(u1database *)
+ u1database *u1db__copy(u1database *)
+ int u1db__sql_is_open(u1database *)
+ u1db_table *u1db__sql_run(u1database *, char *sql, size_t n)
+ void u1db__free_table(u1db_table **table)
+ u1db_record *u1db__create_record(char *doc_id, char *doc_rev, char *doc)
+ void u1db__free_records(u1db_record **)
+
+ int u1db__allocate_document(char *doc_id, char *revision, char *content,
+ int has_conflicts, u1db_document **result)
+ int u1db__generate_hex_uuid(char *)
+
+ int u1db__get_replica_gen_and_trans_id(u1database *db, char *replica_uid,
+ int *generation, char **trans_id)
+ int u1db__set_replica_gen_and_trans_id(u1database *db, char *replica_uid,
+ int generation, char *trans_id)
+ int u1db__sync_get_machine_info(u1database *db, char *other_replica_uid,
+ int *other_db_rev, char **my_replica_uid,
+ int *my_db_rev)
+ int u1db__sync_record_machine_info(u1database *db, char *replica_uid,
+ int db_rev)
+ int u1db__sync_exchange_seen_ids(u1db_sync_exchange *se, int *n_ids,
+ const_char_ptr **doc_ids)
+ int u1db__format_query(int n_fields, const_char_ptr *values, char **buf,
+ int *wildcard)
+ int u1db__get_sync_target(u1database *db, u1db_sync_target **sync_target)
+ int u1db__free_sync_target(u1db_sync_target **sync_target)
+ int u1db__sync_db_to_target(u1database *db, u1db_sync_target *target,
+ int *local_gen_before_sync) nogil
+
+ int u1db__sync_exchange_insert_doc_from_source(u1db_sync_exchange *se,
+ u1db_document *doc, int source_gen, const_char_ptr trans_id)
+ int u1db__sync_exchange_find_doc_ids_to_return(u1db_sync_exchange *se)
+ int u1db__sync_exchange_return_docs(u1db_sync_exchange *se, void *context,
+ int (*cb)(void *context,
+ u1db_document *doc, int gen,
+ const_char_ptr trans_id))
+ int u1db__create_http_sync_target(char *url, u1db_sync_target **target)
+ int u1db__create_oauth_http_sync_target(char *url,
+ char *consumer_key, char *consumer_secret,
+ char *token_key, char *token_secret,
+ u1db_sync_target **target)
+
+cdef extern from "u1db/u1db_http_internal.h":
+ int u1db__format_sync_url(u1db_sync_target *st,
+ const_char_ptr source_replica_uid, char **sync_url)
+ int u1db__get_oauth_authorization(u1db_sync_target *st,
+ char *http_method, char *url,
+ char **oauth_authorization)
+
+
+cdef extern from "u1db/u1db_vectorclock.h":
+ ctypedef struct u1db_vectorclock_item:
+ char *replica_uid
+ int generation
+
+ ctypedef struct u1db_vectorclock:
+ int num_items
+ u1db_vectorclock_item *items
+
+ u1db_vectorclock *u1db__vectorclock_from_str(char *s)
+ void u1db__free_vectorclock(u1db_vectorclock **clock)
+ int u1db__vectorclock_increment(u1db_vectorclock *clock, char *replica_uid)
+ int u1db__vectorclock_maximize(u1db_vectorclock *clock,
+ u1db_vectorclock *other)
+ int u1db__vectorclock_as_str(u1db_vectorclock *clock, char **result)
+ int u1db__vectorclock_is_newer(u1db_vectorclock *maybe_newer,
+ u1db_vectorclock *older)
+
+from u1db import errors
+from sqlite3 import dbapi2
+
+
+cdef int _append_trans_info_to_list(void *context, const_char_ptr doc_id,
+ int generation,
+ const_char_ptr trans_id) with gil:
+ a_list = <object>(context)
+ doc = doc_id
+ a_list.append((doc, generation, trans_id))
+ return 0
+
+
+cdef int _append_doc_to_list(void *context, u1db_document *doc) with gil:
+ a_list = <object>context
+ pydoc = CDocument()
+ pydoc._doc = doc
+ a_list.append(pydoc)
+ return 0
+
+cdef int _append_key_to_list(void *context, int num_fields,
+ const_char_ptr *key) with gil:
+ a_list = <object>(context)
+ field_list = []
+ for i from 0 <= i < num_fields:
+ field = key[i]
+ field_list.append(field.decode('utf-8'))
+ a_list.append(tuple(field_list))
+ return 0
+
+cdef _list_to_array(lst, const_char_ptr **res, int *count):
+ cdef const_char_ptr *tmp
+ count[0] = len(lst)
+ tmp = <const_char_ptr*>calloc(sizeof(char*), count[0])
+ for idx, x in enumerate(lst):
+ tmp[idx] = x
+ res[0] = tmp
+
+cdef _list_to_str_array(lst, const_char_ptr **res, int *count):
+ cdef const_char_ptr *tmp
+ count[0] = len(lst)
+ tmp = <const_char_ptr*>calloc(sizeof(char*), count[0])
+ new_objs = []
+ for idx, x in enumerate(lst):
+ if isinstance(x, unicode):
+ x = x.encode('utf-8')
+ new_objs.append(x)
+ tmp[idx] = x
+ res[0] = tmp
+ return new_objs
+
+
+cdef int _append_index_definition_to_list(void *context,
+ const_char_ptr index_name, int n_expressions,
+ const_char_ptr *expressions) with gil:
+ cdef int i
+
+ a_list = <object>(context)
+ exp_list = []
+ for i from 0 <= i < n_expressions:
+ s = expressions[i]
+ exp_list.append(s.decode('utf-8'))
+ a_list.append((index_name, exp_list))
+ return 0
+
+
+cdef int return_doc_cb_wrapper(void *context, u1db_document *doc,
+ int gen, const_char_ptr trans_id) with gil:
+ cdef CDocument pydoc
+ user_cb = <object>context
+ pydoc = CDocument()
+ pydoc._doc = doc
+ try:
+ user_cb(pydoc, gen, trans_id)
+ except Exception, e:
+ # We suppress the exception here, because intermediating through the C
+ # layer gets a bit crazy
+ return U1DB_INVALID_PARAMETER
+ return U1DB_OK
+
+
+cdef int _trace_hook(void *context, const_char_ptr state) with gil:
+ if context == NULL:
+ return U1DB_INVALID_PARAMETER
+ ctx = <object>context
+ try:
+ ctx(state)
+ except:
+ # Note: It would be nice if we could map the Python exception into
+ # something in C
+ return U1DB_INTERNAL_ERROR
+ return U1DB_OK
+
+
+cdef char *_ensure_str(object obj, object extra_objs) except NULL:
+ """Ensure that we have the UTF-8 representation of a parameter.
+
+ :param obj: A Unicode or String object.
+ :param extra_objs: This should be a Python list. If we have to convert obj
+ from being a Unicode object, this will hold the PyString object so that
+ we know the char* lifetime will be correct.
+ :return: A C pointer to the UTF-8 representation.
+ """
+ if isinstance(obj, unicode):
+ obj = obj.encode('utf-8')
+ extra_objs.append(obj)
+ return PyString_AsString(obj)
+
+
+def _format_query(fields):
+ """Wrapper around u1db__format_query for testing."""
+ cdef int status
+ cdef char *buf
+ cdef int wildcard[10]
+ cdef const_char_ptr *values
+ cdef int n_values
+
+ # keep a reference to new_objs so that the pointers in expressions
+ # remain valid.
+ new_objs = _list_to_str_array(fields, &values, &n_values)
+ try:
+ status = u1db__format_query(n_values, values, &buf, wildcard)
+ finally:
+ free(<void*>values)
+ handle_status("format_query", status)
+ if buf == NULL:
+ res = None
+ else:
+ res = buf
+ free(buf)
+ w = []
+ for i in range(len(fields)):
+ w.append(wildcard[i])
+ return res, w
+
+
+def make_document(doc_id, rev, content, has_conflicts=False):
+ cdef u1db_document *doc
+ cdef char *c_content = NULL, *c_rev = NULL, *c_doc_id = NULL
+ cdef int conflict
+
+ if has_conflicts:
+ conflict = 1
+ else:
+ conflict = 0
+ if doc_id is None:
+ c_doc_id = NULL
+ else:
+ c_doc_id = doc_id
+ if content is None:
+ c_content = NULL
+ else:
+ c_content = content
+ if rev is None:
+ c_rev = NULL
+ else:
+ c_rev = rev
+ handle_status(
+ "make_document",
+ u1db__allocate_document(c_doc_id, c_rev, c_content, conflict, &doc))
+ pydoc = CDocument()
+ pydoc._doc = doc
+ return pydoc
+
+
+def generate_hex_uuid():
+ uuid = PyString_FromStringAndSize(NULL, 32)
+ handle_status(
+ "Failed to generate uuid",
+ u1db__generate_hex_uuid(PyString_AS_STRING(uuid)))
+ return uuid
+
+
+cdef class CDocument(object):
+ """A thin wrapper around the C Document struct."""
+
+ cdef u1db_document *_doc
+
+ def __init__(self):
+ self._doc = NULL
+
+ def __dealloc__(self):
+ u1db_free_doc(&self._doc)
+
+ property doc_id:
+ def __get__(self):
+ if self._doc.doc_id == NULL:
+ return None
+ return PyString_FromStringAndSize(
+ self._doc.doc_id, self._doc.doc_id_len)
+
+ property rev:
+ def __get__(self):
+ if self._doc.doc_rev == NULL:
+ return None
+ return PyString_FromStringAndSize(
+ self._doc.doc_rev, self._doc.doc_rev_len)
+
+ def get_json(self):
+ if self._doc.json == NULL:
+ return None
+ return PyString_FromStringAndSize(
+ self._doc.json, self._doc.json_len)
+
+ def set_json(self, val):
+ u1db_doc_set_json(self._doc, val)
+
+ def get_size(self):
+ return u1db_doc_get_size(self._doc)
+
+ property has_conflicts:
+ def __get__(self):
+ if self._doc.has_conflicts:
+ return True
+ return False
+
+ def __repr__(self):
+ if self._doc.has_conflicts:
+ extra = ', conflicted'
+ else:
+ extra = ''
+ return '%s(%s, %s%s, %r)' % (self.__class__.__name__, self.doc_id,
+ self.rev, extra, self.get_json())
+
+ def __hash__(self):
+ raise NotImplementedError(self.__hash__)
+
+ def __richcmp__(self, other, int t):
+ try:
+ if t == 0: # Py_LT <
+ return ((self.doc_id, self.rev, self.get_json())
+ < (other.doc_id, other.rev, other.get_json()))
+ elif t == 2: # Py_EQ ==
+ return (self.doc_id == other.doc_id
+ and self.rev == other.rev
+ and self.get_json() == other.get_json()
+ and self.has_conflicts == other.has_conflicts)
+ except AttributeError:
+ # Fall through to NotImplemented
+ pass
+
+ return NotImplemented
+
+
+cdef object safe_str(const_char_ptr s):
+ if s == NULL:
+ return None
+ return s
+
+
+cdef class CQuery:
+
+ cdef u1query *_query
+
+ def __init__(self):
+ self._query = NULL
+
+ def __dealloc__(self):
+ u1db_free_query(&self._query)
+
+ def _check(self):
+ if self._query == NULL:
+ raise RuntimeError("No valid _query.")
+
+ property index_name:
+ def __get__(self):
+ self._check()
+ return safe_str(self._query.index_name)
+
+ property num_fields:
+ def __get__(self):
+ self._check()
+ return self._query.num_fields
+
+ property fields:
+ def __get__(self):
+ cdef int i
+ self._check()
+ fields = []
+ for i from 0 <= i < self._query.num_fields:
+ fields.append(safe_str(self._query.fields[i]))
+ return fields
+
+
+cdef handle_status(context, int status):
+ if status == U1DB_OK:
+ return
+ if status == U1DB_REVISION_CONFLICT:
+ raise errors.RevisionConflict()
+ if status == U1DB_INVALID_DOC_ID:
+ raise errors.InvalidDocId()
+ if status == U1DB_DOCUMENT_ALREADY_DELETED:
+ raise errors.DocumentAlreadyDeleted()
+ if status == U1DB_DOCUMENT_DOES_NOT_EXIST:
+ raise errors.DocumentDoesNotExist()
+ if status == U1DB_INVALID_PARAMETER:
+ raise RuntimeError('Bad parameters supplied')
+ if status == U1DB_NOT_IMPLEMENTED:
+ raise NotImplementedError("Functionality not implemented yet: %s"
+ % (context,))
+ if status == U1DB_INVALID_VALUE_FOR_INDEX:
+ raise errors.InvalidValueForIndex()
+ if status == U1DB_INVALID_GLOBBING:
+ raise errors.InvalidGlobbing()
+ if status == U1DB_INTERNAL_ERROR:
+ raise errors.U1DBError("internal error")
+ if status == U1DB_BROKEN_SYNC_STREAM:
+ raise errors.BrokenSyncStream()
+ if status == U1DB_CONFLICTED:
+ raise errors.ConflictedDoc()
+ if status == U1DB_DUPLICATE_INDEX_NAME:
+ raise errors.IndexNameTakenError()
+ if status == U1DB_INDEX_DOES_NOT_EXIST:
+ raise errors.IndexDoesNotExist
+ if status == U1DB_INVALID_GENERATION:
+ raise errors.InvalidGeneration
+ if status == U1DB_INVALID_TRANSACTION_ID:
+ raise errors.InvalidTransactionId
+ if status == U1DB_TARGET_UNAVAILABLE:
+ raise errors.Unavailable
+ if status == U1DB_INVALID_JSON:
+ raise errors.InvalidJSON
+ if status == U1DB_DOCUMENT_TOO_BIG:
+ raise errors.DocumentTooBig
+ if status == U1DB_USER_QUOTA_EXCEEDED:
+ raise errors.UserQuotaExceeded
+ if status == U1DB_INVALID_TRANSFORMATION_FUNCTION:
+ raise errors.IndexDefinitionParseError
+ if status == U1DB_UNKNOWN_OPERATION:
+ raise errors.IndexDefinitionParseError
+ if status == U1DB_INVALID_FIELD_SPECIFIER:
+ raise errors.IndexDefinitionParseError()
+ raise RuntimeError('%s (status: %s)' % (context, status))
+
+
+cdef class CDatabase
+cdef class CSyncTarget
+
+cdef class CSyncExchange(object):
+
+ cdef u1db_sync_exchange *_exchange
+ cdef CSyncTarget _target
+
+ def __init__(self, CSyncTarget target, source_replica_uid, source_gen):
+ self._target = target
+ assert self._target._st.get_sync_exchange != NULL, \
+ "get_sync_exchange is NULL?"
+ handle_status("get_sync_exchange",
+ self._target._st.get_sync_exchange(self._target._st,
+ source_replica_uid, source_gen, &self._exchange))
+
+ def __dealloc__(self):
+ if self._target is not None and self._target._st != NULL:
+ self._target._st.finalize_sync_exchange(self._target._st,
+ &self._exchange)
+
+ def _check(self):
+ if self._exchange == NULL:
+ raise RuntimeError("self._exchange is NULL")
+
+ property target_gen:
+ def __get__(self):
+ self._check()
+ return self._exchange.target_gen
+
+ def insert_doc_from_source(self, CDocument doc, source_gen,
+ source_trans_id):
+ self._check()
+ handle_status("insert_doc_from_source",
+ u1db__sync_exchange_insert_doc_from_source(self._exchange,
+ doc._doc, source_gen, source_trans_id))
+
+ def find_doc_ids_to_return(self):
+ self._check()
+ handle_status("find_doc_ids_to_return",
+ u1db__sync_exchange_find_doc_ids_to_return(self._exchange))
+
+ def return_docs(self, return_doc_cb):
+ self._check()
+ handle_status("return_docs",
+ u1db__sync_exchange_return_docs(self._exchange,
+ <void *>return_doc_cb, &return_doc_cb_wrapper))
+
+ def get_seen_ids(self):
+ cdef const_char_ptr *seen_ids
+ cdef int i, n_ids
+ self._check()
+ handle_status("sync_exchange_seen_ids",
+ u1db__sync_exchange_seen_ids(self._exchange, &n_ids, &seen_ids))
+ res = []
+ for i from 0 <= i < n_ids:
+ res.append(seen_ids[i])
+ if (seen_ids != NULL):
+ free(<void*>seen_ids)
+ return res
+
+ def get_doc_ids_to_return(self):
+ self._check()
+ res = []
+ if (self._exchange.num_doc_ids > 0
+ and self._exchange.doc_ids_to_return != NULL):
+ for i from 0 <= i < self._exchange.num_doc_ids:
+ res.append(
+ (self._exchange.doc_ids_to_return[i],
+ self._exchange.gen_for_doc_ids[i],
+ self._exchange.trans_ids_for_doc_ids[i]))
+ return res
+
+
+cdef class CSyncTarget(object):
+
+ cdef u1db_sync_target *_st
+ cdef CDatabase _db
+
+ def __init__(self):
+ self._db = None
+ self._st = NULL
+ u1db__set_zero_delays()
+
+ def __dealloc__(self):
+ u1db__free_sync_target(&self._st)
+
+ def _check(self):
+ if self._st == NULL:
+ raise RuntimeError("self._st is NULL")
+
+ def get_sync_info(self, source_replica_uid):
+ cdef const_char_ptr st_replica_uid = NULL
+ cdef int st_gen = 0, source_gen = 0, status
+ cdef char *trans_id = NULL
+ cdef char *st_trans_id = NULL
+ cdef char *c_source_replica_uid = NULL
+
+ self._check()
+ assert self._st.get_sync_info != NULL, "get_sync_info is NULL?"
+ c_source_replica_uid = source_replica_uid
+ with nogil:
+ status = self._st.get_sync_info(self._st, c_source_replica_uid,
+ &st_replica_uid, &st_gen, &st_trans_id, &source_gen, &trans_id)
+ handle_status("get_sync_info", status)
+ res_trans_id = None
+ res_st_trans_id = None
+ if trans_id != NULL:
+ res_trans_id = trans_id
+ free(trans_id)
+ if st_trans_id != NULL:
+ res_st_trans_id = st_trans_id
+ free(st_trans_id)
+ return (
+ safe_str(st_replica_uid), st_gen, res_st_trans_id, source_gen,
+ res_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_gen, source_trans_id):
+ cdef int status
+ cdef int c_source_gen
+ cdef char *c_source_replica_uid = NULL
+ cdef char *c_source_trans_id = NULL
+
+ self._check()
+ assert self._st.record_sync_info != NULL, "record_sync_info is NULL?"
+ c_source_replica_uid = source_replica_uid
+ c_source_gen = source_gen
+ c_source_trans_id = source_trans_id
+ with nogil:
+ status = self._st.record_sync_info(
+ self._st, c_source_replica_uid, c_source_gen,
+ c_source_trans_id)
+ handle_status("record_sync_info", status)
+
+ def _get_sync_exchange(self, source_replica_uid, source_gen):
+ self._check()
+ return CSyncExchange(self, source_replica_uid, source_gen)
+
+ def sync_exchange_doc_ids(self, source_db, doc_id_generations,
+ last_known_generation, last_known_trans_id,
+ return_doc_cb):
+ cdef const_char_ptr *doc_ids
+ cdef int *generations
+ cdef int num_doc_ids
+ cdef int target_gen
+ cdef char *target_trans_id = NULL
+ cdef int status
+ cdef CDatabase sdb
+
+ self._check()
+ assert self._st.sync_exchange_doc_ids != NULL, "sync_exchange_doc_ids is NULL?"
+ sdb = source_db
+ num_doc_ids = len(doc_id_generations)
+ doc_ids = <const_char_ptr *>calloc(num_doc_ids, sizeof(char *))
+ if doc_ids == NULL:
+ raise MemoryError
+ generations = <int *>calloc(num_doc_ids, sizeof(int))
+ if generations == NULL:
+ free(<void *>doc_ids)
+ raise MemoryError
+ trans_ids = <const_char_ptr*>calloc(num_doc_ids, sizeof(char *))
+ if trans_ids == NULL:
+ raise MemoryError
+ res_trans_id = ''
+ try:
+ for i, (doc_id, gen, trans_id) in enumerate(doc_id_generations):
+ doc_ids[i] = PyString_AsString(doc_id)
+ generations[i] = gen
+ trans_ids[i] = trans_id
+ target_gen = last_known_generation
+ if last_known_trans_id is not None:
+ target_trans_id = last_known_trans_id
+ with nogil:
+ status = self._st.sync_exchange_doc_ids(self._st, sdb._db,
+ num_doc_ids, doc_ids, generations, trans_ids,
+ &target_gen, &target_trans_id,
+ <void*>return_doc_cb, return_doc_cb_wrapper, NULL)
+ handle_status("sync_exchange_doc_ids", status)
+ if target_trans_id != NULL:
+ res_trans_id = target_trans_id
+ finally:
+ if target_trans_id != NULL:
+ free(target_trans_id)
+ if doc_ids != NULL:
+ free(<void *>doc_ids)
+ if generations != NULL:
+ free(generations)
+ if trans_ids != NULL:
+ free(trans_ids)
+ return target_gen, res_trans_id
+
+ def sync_exchange(self, docs_by_generations, source_replica_uid,
+ last_known_generation, last_known_trans_id,
+ return_doc_cb, ensure_callback=None):
+ cdef CDocument cur_doc
+ cdef u1db_document **docs = NULL
+ cdef int *generations = NULL
+ cdef const_char_ptr *trans_ids = NULL
+ cdef char *target_trans_id = NULL
+ cdef char *c_source_replica_uid = NULL
+ cdef int i, count, status, target_gen
+ assert ensure_callback is None # interface difference
+
+ self._check()
+ assert self._st.sync_exchange != NULL, "sync_exchange is NULL?"
+ count = len(docs_by_generations)
+ res_trans_id = ''
+ try:
+ docs = <u1db_document **>calloc(count, sizeof(u1db_document*))
+ if docs == NULL:
+ raise MemoryError
+ generations = <int*>calloc(count, sizeof(int))
+ if generations == NULL:
+ raise MemoryError
+ trans_ids = <const_char_ptr*>calloc(count, sizeof(char*))
+ if trans_ids == NULL:
+ raise MemoryError
+ for i from 0 <= i < count:
+ cur_doc = docs_by_generations[i][0]
+ generations[i] = docs_by_generations[i][1]
+ trans_ids[i] = docs_by_generations[i][2]
+ docs[i] = cur_doc._doc
+ target_gen = last_known_generation
+ if last_known_trans_id is not None:
+ target_trans_id = last_known_trans_id
+ c_source_replica_uid = source_replica_uid
+ with nogil:
+ status = self._st.sync_exchange(
+ self._st, c_source_replica_uid, count, docs, generations,
+ trans_ids, &target_gen, &target_trans_id,
+ <void *>return_doc_cb, return_doc_cb_wrapper, NULL)
+ handle_status("sync_exchange", status)
+ finally:
+ if docs != NULL:
+ free(docs)
+ if generations != NULL:
+ free(generations)
+ if trans_ids != NULL:
+ free(trans_ids)
+ if target_trans_id != NULL:
+ res_trans_id = target_trans_id
+ free(target_trans_id)
+ return target_gen, res_trans_id
+
+ def _set_trace_hook(self, cb):
+ self._check()
+ assert self._st._set_trace_hook != NULL, "_set_trace_hook is NULL?"
+ handle_status("_set_trace_hook",
+ self._st._set_trace_hook(self._st, <void*>cb, _trace_hook))
+
+ _set_trace_hook_shallow = _set_trace_hook
+
+
+cdef class CDatabase(object):
+ """A thin wrapper/shim to interact with the C implementation.
+
+ Functionality should not be written here. It is only provided as a way to
+ expose the C API to the python test suite.
+ """
+
+ cdef public object _filename
+ cdef u1database *_db
+ cdef public object _supports_indexes
+
+ def __init__(self, filename):
+ self._supports_indexes = False
+ self._filename = filename
+ self._db = u1db_open(self._filename)
+
+ def __dealloc__(self):
+ u1db_free(&self._db)
+
+ def close(self):
+ return u1db__sql_close(self._db)
+
+ def _copy(self, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ new_db = CDatabase(':memory:')
+ u1db_free(&new_db._db)
+ new_db._db = u1db__copy(self._db)
+ return new_db
+
+ def _sql_is_open(self):
+ if self._db == NULL:
+ return True
+ return u1db__sql_is_open(self._db)
+
+ property _replica_uid:
+ def __get__(self):
+ cdef const_char_ptr val
+ cdef int status
+ status = u1db_get_replica_uid(self._db, &val)
+ if status != 0:
+ if val != NULL:
+ err = str(val)
+ else:
+ err = "<unknown>"
+ raise RuntimeError("Failed to get_replica_uid: %d %s"
+ % (status, err))
+ if val == NULL:
+ return None
+ return str(val)
+
+ def _set_replica_uid(self, replica_uid):
+ cdef int status
+ status = u1db_set_replica_uid(self._db, replica_uid)
+ if status != 0:
+ raise RuntimeError('replica_uid could not be set to %s, error: %d'
+ % (replica_uid, status))
+
+ property document_size_limit:
+ def __get__(self):
+ cdef int limit
+ handle_status("document_size_limit",
+ u1db__get_document_size_limit(self._db, &limit))
+ return limit
+
+ def set_document_size_limit(self, limit):
+ cdef int status
+ status = u1db_set_document_size_limit(self._db, limit)
+ if status != 0:
+ raise RuntimeError(
+ "document_size_limit could not be set to %d, error: %d",
+ (limit, status))
+
+ def _allocate_doc_id(self):
+ cdef char *val
+ val = u1db__allocate_doc_id(self._db)
+ if val == NULL:
+ raise RuntimeError("Failed to allocate document id")
+ s = str(val)
+ free(val)
+ return s
+
+ def _run_sql(self, sql):
+ cdef u1db_table *tbl
+ cdef u1db_row *cur_row
+ cdef size_t n
+ cdef int i
+
+ if self._db == NULL:
+ raise RuntimeError("called _run_sql with a NULL pointer.")
+ tbl = u1db__sql_run(self._db, sql, len(sql))
+ if tbl == NULL:
+ raise MemoryError("Failed to allocate table memory.")
+ try:
+ if tbl.status != 0:
+ raise RuntimeError("Status was not 0: %d" % (tbl.status,))
+ # Now convert the table into python
+ res = []
+ cur_row = tbl.first_row
+ while cur_row != NULL:
+ row = []
+ for i from 0 <= i < cur_row.num_columns:
+ row.append(PyString_FromStringAndSize(
+ <char*>(cur_row.columns[i]), cur_row.column_sizes[i]))
+ res.append(tuple(row))
+ cur_row = cur_row.next
+ return res
+ finally:
+ u1db__free_table(&tbl)
+
+ def create_doc_from_json(self, json, doc_id=None):
+ cdef u1db_document *doc = NULL
+ cdef char *c_doc_id
+
+ if doc_id is None:
+ c_doc_id = NULL
+ else:
+ c_doc_id = doc_id
+ handle_status('Failed to create_doc',
+ u1db_create_doc_from_json(self._db, json, c_doc_id, &doc))
+ pydoc = CDocument()
+ pydoc._doc = doc
+ return pydoc
+
+ def put_doc(self, CDocument doc):
+ handle_status("Failed to put_doc",
+ u1db_put_doc(self._db, doc._doc))
+ return doc.rev
+
+ def _validate_source(self, replica_uid, replica_gen, replica_trans_id):
+ cdef const_char_ptr c_uid, c_trans_id
+ cdef int c_gen = 0
+
+ c_uid = replica_uid
+ c_trans_id = replica_trans_id
+ c_gen = replica_gen
+ handle_status(
+ "invalid generation or transaction id",
+ u1db__validate_source(self._db, c_uid, c_gen, c_trans_id))
+
+ def _put_doc_if_newer(self, CDocument doc, save_conflict, replica_uid=None,
+ replica_gen=None, replica_trans_id=None):
+ cdef char *c_uid, *c_trans_id
+ cdef int gen, state = 0, at_gen = -1
+
+ if replica_uid is None:
+ c_uid = NULL
+ else:
+ c_uid = replica_uid
+ if replica_trans_id is None:
+ c_trans_id = NULL
+ else:
+ c_trans_id = replica_trans_id
+ if replica_gen is None:
+ gen = 0
+ else:
+ gen = replica_gen
+ handle_status("Failed to _put_doc_if_newer",
+ u1db__put_doc_if_newer(self._db, doc._doc, save_conflict,
+ c_uid, gen, c_trans_id, &state, &at_gen))
+ if state == U1DB_INSERTED:
+ return 'inserted', at_gen
+ elif state == U1DB_SUPERSEDED:
+ return 'superseded', at_gen
+ elif state == U1DB_CONVERGED:
+ return 'converged', at_gen
+ elif state == U1DB_CONFLICTED:
+ return 'conflicted', at_gen
+ else:
+ raise RuntimeError("Unknown _put_doc_if_newer state: %d" % (state,))
+
+ def get_doc(self, doc_id, include_deleted=False):
+ cdef u1db_document *doc = NULL
+ deleted = 1 if include_deleted else 0
+ handle_status("get_doc failed",
+ u1db_get_doc(self._db, doc_id, deleted, &doc))
+ if doc == NULL:
+ return None
+ pydoc = CDocument()
+ pydoc._doc = doc
+ return pydoc
+
+ def get_docs(self, doc_ids, check_for_conflicts=True,
+ include_deleted=False):
+ cdef int n_doc_ids, conflicts
+ cdef const_char_ptr *c_doc_ids
+
+ _list_to_array(doc_ids, &c_doc_ids, &n_doc_ids)
+ deleted = 1 if include_deleted else 0
+ conflicts = 1 if check_for_conflicts else 0
+ a_list = []
+ handle_status("get_docs",
+ u1db_get_docs(self._db, n_doc_ids, c_doc_ids,
+ conflicts, deleted, <void*>a_list, _append_doc_to_list))
+ free(<void*>c_doc_ids)
+ return a_list
+
+ def get_all_docs(self, include_deleted=False):
+ cdef int c_generation
+
+ a_list = []
+ deleted = 1 if include_deleted else 0
+ generation = 0
+ c_generation = generation
+ handle_status(
+ "get_all_docs", u1db_get_all_docs(
+ self._db, deleted, &c_generation, <void*>a_list,
+ _append_doc_to_list))
+ return (c_generation, a_list)
+
+ def resolve_doc(self, CDocument doc, conflicted_doc_revs):
+ cdef const_char_ptr *revs
+ cdef int n_revs
+
+ _list_to_array(conflicted_doc_revs, &revs, &n_revs)
+ handle_status("resolve_doc",
+ u1db_resolve_doc(self._db, doc._doc, n_revs, revs))
+ free(<void*>revs)
+
+ def get_doc_conflicts(self, doc_id):
+ conflict_docs = []
+ handle_status("get_doc_conflicts",
+ u1db_get_doc_conflicts(self._db, doc_id, <void*>conflict_docs,
+ _append_doc_to_list))
+ return conflict_docs
+
+ def delete_doc(self, CDocument doc):
+ handle_status(
+ "Failed to delete %s" % (doc,),
+ u1db_delete_doc(self._db, doc._doc))
+
+ def whats_changed(self, generation=0):
+ cdef int c_generation
+ cdef int status
+ cdef char *trans_id = NULL
+
+ a_list = []
+ c_generation = generation
+ res_trans_id = ''
+ status = u1db_whats_changed(self._db, &c_generation, &trans_id,
+ <void*>a_list, _append_trans_info_to_list)
+ try:
+ handle_status("whats_changed", status)
+ finally:
+ if trans_id != NULL:
+ res_trans_id = trans_id
+ free(trans_id)
+ return c_generation, res_trans_id, a_list
+
+ def _get_transaction_log(self):
+ a_list = []
+ handle_status("_get_transaction_log",
+ u1db__get_transaction_log(self._db, <void*>a_list,
+ _append_trans_info_to_list))
+ return [(doc_id, trans_id) for doc_id, gen, trans_id in a_list]
+
+ def _get_generation(self):
+ cdef int generation
+ handle_status("get_generation",
+ u1db__get_generation(self._db, &generation))
+ return generation
+
+ def _get_generation_info(self):
+ cdef int generation
+ cdef char *trans_id
+ handle_status("get_generation_info",
+ u1db__get_generation_info(self._db, &generation, &trans_id))
+ raw_trans_id = None
+ if trans_id != NULL:
+ raw_trans_id = trans_id
+ free(trans_id)
+ return generation, raw_trans_id
+
+ def validate_gen_and_trans_id(self, generation, trans_id):
+ handle_status(
+ "validate_gen_and_trans_id",
+ u1db_validate_gen_and_trans_id(self._db, generation, trans_id))
+
+ def _get_trans_id_for_gen(self, generation):
+ cdef char *trans_id = NULL
+
+ handle_status(
+ "_get_trans_id_for_gen",
+ u1db__get_trans_id_for_gen(self._db, generation, &trans_id))
+ raw_trans_id = None
+ if trans_id != NULL:
+ raw_trans_id = trans_id
+ free(trans_id)
+ return raw_trans_id
+
+ def _get_replica_gen_and_trans_id(self, replica_uid):
+ cdef int generation, status
+ cdef char *trans_id = NULL
+
+ status = u1db__get_replica_gen_and_trans_id(
+ self._db, replica_uid, &generation, &trans_id)
+ handle_status("_get_replica_gen_and_trans_id", status)
+ raw_trans_id = None
+ if trans_id != NULL:
+ raw_trans_id = trans_id
+ free(trans_id)
+ return generation, raw_trans_id
+
+ def _set_replica_gen_and_trans_id(self, replica_uid, generation, trans_id):
+ handle_status("_set_replica_gen_and_trans_id",
+ u1db__set_replica_gen_and_trans_id(
+ self._db, replica_uid, generation, trans_id))
+
+ def create_index_list(self, index_name, index_expressions):
+ cdef const_char_ptr *expressions
+ cdef int n_expressions
+
+ # keep a reference to new_objs so that the pointers in expressions
+ # remain valid.
+ new_objs = _list_to_str_array(
+ index_expressions, &expressions, &n_expressions)
+ try:
+ status = u1db_create_index_list(
+ self._db, index_name, n_expressions, expressions)
+ finally:
+ free(<void*>expressions)
+ handle_status("create_index", status)
+
+ def create_index(self, index_name, *index_expressions):
+ extra = []
+ if len(index_expressions) == 0:
+ status = u1db_create_index(self._db, index_name, 0, NULL)
+ elif len(index_expressions) == 1:
+ status = u1db_create_index(
+ self._db, index_name, 1,
+ _ensure_str(index_expressions[0], extra))
+ elif len(index_expressions) == 2:
+ status = u1db_create_index(
+ self._db, index_name, 2,
+ _ensure_str(index_expressions[0], extra),
+ _ensure_str(index_expressions[1], extra))
+ elif len(index_expressions) == 3:
+ status = u1db_create_index(
+ self._db, index_name, 3,
+ _ensure_str(index_expressions[0], extra),
+ _ensure_str(index_expressions[1], extra),
+ _ensure_str(index_expressions[2], extra))
+ elif len(index_expressions) == 4:
+ status = u1db_create_index(
+ self._db, index_name, 4,
+ _ensure_str(index_expressions[0], extra),
+ _ensure_str(index_expressions[1], extra),
+ _ensure_str(index_expressions[2], extra),
+ _ensure_str(index_expressions[3], extra))
+ else:
+ status = U1DB_NOT_IMPLEMENTED
+ handle_status("create_index", status)
+
+ def sync(self, url, creds=None):
+ cdef const_char_ptr c_url
+ cdef int local_gen = 0
+ cdef u1db_oauth_creds _oauth_creds
+ cdef u1db_creds *_creds = NULL
+ c_url = url
+ if creds is not None:
+ _oauth_creds.auth_kind = U1DB_OAUTH_AUTH
+ _oauth_creds.consumer_key = creds['oauth']['consumer_key']
+ _oauth_creds.consumer_secret = creds['oauth']['consumer_secret']
+ _oauth_creds.token_key = creds['oauth']['token_key']
+ _oauth_creds.token_secret = creds['oauth']['token_secret']
+ _creds = <u1db_creds *>&_oauth_creds
+ with nogil:
+ status = u1db_sync(self._db, c_url, _creds, &local_gen)
+ handle_status("sync", status)
+ return local_gen
+
+ def list_indexes(self):
+ a_list = []
+ handle_status("list_indexes",
+ u1db_list_indexes(self._db, <void *>a_list,
+ _append_index_definition_to_list))
+ return a_list
+
+ def delete_index(self, index_name):
+ handle_status("delete_index",
+ u1db_delete_index(self._db, index_name))
+
+ def get_from_index_list(self, index_name, key_values):
+ cdef const_char_ptr *values
+ cdef int n_values
+ cdef CQuery query
+
+ query = self._query_init(index_name)
+ res = []
+ # keep a reference to new_objs so that the pointers in expressions
+ # remain valid.
+ new_objs = _list_to_str_array(key_values, &values, &n_values)
+ try:
+ handle_status(
+ "get_from_index", u1db_get_from_index_list(
+ self._db, query._query, <void*>res, _append_doc_to_list,
+ n_values, values))
+ finally:
+ free(<void*>values)
+ return res
+
+ def get_from_index(self, index_name, *key_values):
+ cdef CQuery query
+ cdef int status
+
+ extra = []
+ query = self._query_init(index_name)
+ res = []
+ status = U1DB_OK
+ if len(key_values) == 0:
+ status = u1db_get_from_index(self._db, query._query,
+ <void*>res, _append_doc_to_list, 0, NULL)
+ elif len(key_values) == 1:
+ status = u1db_get_from_index(self._db, query._query,
+ <void*>res, _append_doc_to_list, 1,
+ _ensure_str(key_values[0], extra))
+ elif len(key_values) == 2:
+ status = u1db_get_from_index(self._db, query._query,
+ <void*>res, _append_doc_to_list, 2,
+ _ensure_str(key_values[0], extra),
+ _ensure_str(key_values[1], extra))
+ elif len(key_values) == 3:
+ status = u1db_get_from_index(self._db, query._query,
+ <void*>res, _append_doc_to_list, 3,
+ _ensure_str(key_values[0], extra),
+ _ensure_str(key_values[1], extra),
+ _ensure_str(key_values[2], extra))
+ elif len(key_values) == 4:
+ status = u1db_get_from_index(self._db, query._query,
+ <void*>res, _append_doc_to_list, 4,
+ _ensure_str(key_values[0], extra),
+ _ensure_str(key_values[1], extra),
+ _ensure_str(key_values[2], extra),
+ _ensure_str(key_values[3], extra))
+ else:
+ status = U1DB_NOT_IMPLEMENTED
+ handle_status("get_from_index", status)
+ return res
+
+ def get_range_from_index(self, index_name, start_value=None,
+ end_value=None):
+ cdef CQuery query
+ cdef const_char_ptr *start_values
+ cdef int n_values
+ cdef const_char_ptr *end_values
+
+ if start_value is not None:
+ if isinstance(start_value, basestring):
+ start_value = (start_value,)
+ new_objs_1 = _list_to_str_array(
+ start_value, &start_values, &n_values)
+ else:
+ n_values = 0
+ start_values = NULL
+ if end_value is not None:
+ if isinstance(end_value, basestring):
+ end_value = (end_value,)
+ new_objs_2 = _list_to_str_array(
+ end_value, &end_values, &n_values)
+ else:
+ end_values = NULL
+ query = self._query_init(index_name)
+ res = []
+ try:
+ handle_status("get_range_from_index",
+ u1db_get_range_from_index(
+ self._db, query._query, <void*>res, _append_doc_to_list,
+ n_values, start_values, end_values))
+ finally:
+ if start_values != NULL:
+ free(<void*>start_values)
+ if end_values != NULL:
+ free(<void*>end_values)
+ return res
+
+ def get_index_keys(self, index_name):
+ cdef int status
+ keys = []
+ status = U1DB_OK
+ status = u1db_get_index_keys(
+ self._db, index_name, <void*>keys, _append_key_to_list)
+ handle_status("get_index_keys", status)
+ return keys
+
+ def _query_init(self, index_name):
+ cdef CQuery query
+ query = CQuery()
+ handle_status("query_init",
+ u1db_query_init(self._db, index_name, &query._query))
+ return query
+
+ def get_sync_target(self):
+ cdef CSyncTarget target
+ target = CSyncTarget()
+ target._db = self
+ handle_status("get_sync_target",
+ u1db__get_sync_target(target._db._db, &target._st))
+ return target
+
+
+cdef class VectorClockRev:
+
+ cdef u1db_vectorclock *_clock
+
+ def __init__(self, s):
+ if s is None:
+ self._clock = u1db__vectorclock_from_str(NULL)
+ else:
+ self._clock = u1db__vectorclock_from_str(s)
+
+ def __dealloc__(self):
+ u1db__free_vectorclock(&self._clock)
+
+ def __repr__(self):
+ cdef int status
+ cdef char *res
+ if self._clock == NULL:
+ return '%s(None)' % (self.__class__.__name__,)
+ status = u1db__vectorclock_as_str(self._clock, &res)
+ if status != U1DB_OK:
+ return '%s(<failure: %d>)' % (status,)
+ if res == NULL:
+ val = '%s(NULL)' % (self.__class__.__name__,)
+ else:
+ val = '%s(%s)' % (self.__class__.__name__, res)
+ free(res)
+ return val
+
+ def as_dict(self):
+ cdef u1db_vectorclock *cur
+ cdef int i
+ cdef int gen
+ if self._clock == NULL:
+ return None
+ res = {}
+ for i from 0 <= i < self._clock.num_items:
+ gen = self._clock.items[i].generation
+ res[self._clock.items[i].replica_uid] = gen
+ return res
+
+ def as_str(self):
+ cdef int status
+ cdef char *res
+
+ status = u1db__vectorclock_as_str(self._clock, &res)
+ if status != U1DB_OK:
+ raise RuntimeError("Failed to VectorClockRev.as_str(): %d" % (status,))
+ if res == NULL:
+ s = None
+ else:
+ s = res
+ free(res)
+ return s
+
+ def increment(self, replica_uid):
+ cdef int status
+
+ status = u1db__vectorclock_increment(self._clock, replica_uid)
+ if status != U1DB_OK:
+ raise RuntimeError("Failed to increment: %d" % (status,))
+
+ def maximize(self, vcr):
+ cdef int status
+ cdef VectorClockRev other
+
+ other = vcr
+ status = u1db__vectorclock_maximize(self._clock, other._clock)
+ if status != U1DB_OK:
+ raise RuntimeError("Failed to maximize: %d" % (status,))
+
+ def is_newer(self, vcr):
+ cdef int is_newer
+ cdef VectorClockRev other
+
+ other = vcr
+ is_newer = u1db__vectorclock_is_newer(self._clock, other._clock)
+ if is_newer == 0:
+ return False
+ elif is_newer == 1:
+ return True
+ else:
+ raise RuntimeError("Failed to is_newer: %d" % (is_newer,))
+
+
+def sync_db_to_target(db, target):
+ """Sync the data between a CDatabase and a CSyncTarget"""
+ cdef CDatabase cdb
+ cdef CSyncTarget ctarget
+ cdef int local_gen = 0, status
+
+ cdb = db
+ ctarget = target
+ with nogil:
+ status = u1db__sync_db_to_target(cdb._db, ctarget._st, &local_gen)
+ handle_status("sync_db_to_target", status)
+ return local_gen
+
+
+def create_http_sync_target(url):
+ cdef CSyncTarget target
+
+ target = CSyncTarget()
+ handle_status("create_http_sync_target",
+ u1db__create_http_sync_target(url, &target._st))
+ return target
+
+
+def create_oauth_http_sync_target(url, consumer_key, consumer_secret,
+ token_key, token_secret):
+ cdef CSyncTarget target
+
+ target = CSyncTarget()
+ handle_status("create_http_sync_target",
+ u1db__create_oauth_http_sync_target(url, consumer_key, consumer_secret,
+ token_key, token_secret,
+ &target._st))
+ return target
+
+
+def _format_sync_url(target, source_replica_uid):
+ cdef CSyncTarget st
+ cdef char *sync_url = NULL
+ cdef object res
+ st = target
+ handle_status("format_sync_url",
+ u1db__format_sync_url(st._st, source_replica_uid, &sync_url))
+ if sync_url == NULL:
+ res = None
+ else:
+ res = sync_url
+ free(sync_url)
+ return res
+
+
+def _get_oauth_authorization(target, method, url):
+ cdef CSyncTarget st
+ cdef char *auth = NULL
+
+ st = target
+ handle_status("get_oauth_authorization",
+ u1db__get_oauth_authorization(st._st, method, url, &auth))
+ res = None
+ if auth != NULL:
+ res = auth
+ free(auth)
+ return res
diff --git a/src/leap/soledad/u1db/tests/commandline/__init__.py b/src/leap/soledad/u1db/tests/commandline/__init__.py
new file mode 100644
index 00000000..007cecd3
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/commandline/__init__.py
@@ -0,0 +1,47 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+import errno
+import time
+
+
+def safe_close(process, timeout=0.1):
+ """Shutdown the process in the nicest fashion you can manage.
+
+ :param process: A subprocess.Popen object.
+ :param timeout: We'll try to send 'SIGTERM' but if the process is alive
+ longer that 'timeout', we'll send SIGKILL.
+ """
+ if process.poll() is not None:
+ return
+ try:
+ process.terminate()
+ except OSError, e:
+ if e.errno in (errno.ESRCH,):
+ # Process has exited
+ return
+ tend = time.time() + timeout
+ while time.time() < tend:
+ if process.poll() is not None:
+ return
+ time.sleep(0.01)
+ try:
+ process.kill()
+ except OSError, e:
+ if e.errno in (errno.ESRCH,):
+ # Process has exited
+ return
+ process.wait()
diff --git a/src/leap/soledad/u1db/tests/commandline/test_client.py b/src/leap/soledad/u1db/tests/commandline/test_client.py
new file mode 100644
index 00000000..78ca21eb
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/commandline/test_client.py
@@ -0,0 +1,916 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+import cStringIO
+import os
+import sys
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import subprocess
+
+from u1db import (
+ errors,
+ open as u1db_open,
+ tests,
+ vectorclock,
+ )
+from u1db.commandline import (
+ client,
+ serve,
+ )
+from u1db.tests.commandline import safe_close
+from u1db.tests import test_remote_sync_target
+
+
+class TestArgs(tests.TestCase):
+ """These tests are meant to test just the argument parsing.
+
+ Each Command should have at least one test, possibly more if it allows
+ optional arguments, etc.
+ """
+
+ def setUp(self):
+ super(TestArgs, self).setUp()
+ self.parser = client.client_commands.make_argparser()
+
+ def parse_args(self, args):
+ # ArgumentParser.parse_args doesn't play very nicely with a test suite,
+ # so we trap SystemExit in case something is wrong with the args we're
+ # parsing.
+ try:
+ return self.parser.parse_args(args)
+ except SystemExit:
+ raise AssertionError('got SystemExit')
+
+ def test_create(self):
+ args = self.parse_args(['create', 'test.db'])
+ self.assertEqual(client.CmdCreate, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual(None, args.doc_id)
+ self.assertEqual(None, args.infile)
+
+ def test_create_custom_doc_id(self):
+ args = self.parse_args(['create', '--id', 'xyz', 'test.db'])
+ self.assertEqual(client.CmdCreate, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('xyz', args.doc_id)
+ self.assertEqual(None, args.infile)
+
+ def test_delete(self):
+ args = self.parse_args(['delete', 'test.db', 'doc-id', 'doc-rev'])
+ self.assertEqual(client.CmdDelete, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+ self.assertEqual('doc-rev', args.doc_rev)
+
+ def test_get(self):
+ args = self.parse_args(['get', 'test.db', 'doc-id'])
+ self.assertEqual(client.CmdGet, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+ self.assertEqual(None, args.outfile)
+
+ def test_get_dash(self):
+ args = self.parse_args(['get', 'test.db', 'doc-id', '-'])
+ self.assertEqual(client.CmdGet, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+ self.assertEqual(sys.stdout, args.outfile)
+
+ def test_init_db(self):
+ args = self.parse_args(
+ ['init-db', 'test.db', '--replica-uid=replica-uid'])
+ self.assertEqual(client.CmdInitDB, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('replica-uid', args.replica_uid)
+
+ def test_init_db_no_replica(self):
+ args = self.parse_args(['init-db', 'test.db'])
+ self.assertEqual(client.CmdInitDB, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertIs(None, args.replica_uid)
+
+ def test_put(self):
+ args = self.parse_args(['put', 'test.db', 'doc-id', 'old-doc-rev'])
+ self.assertEqual(client.CmdPut, args.subcommand)
+ self.assertEqual('test.db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+ self.assertEqual('old-doc-rev', args.doc_rev)
+ self.assertEqual(None, args.infile)
+
+ def test_sync(self):
+ args = self.parse_args(['sync', 'source', 'target'])
+ self.assertEqual(client.CmdSync, args.subcommand)
+ self.assertEqual('source', args.source)
+ self.assertEqual('target', args.target)
+
+ def test_create_index(self):
+ args = self.parse_args(['create-index', 'db', 'index', 'expression'])
+ self.assertEqual(client.CmdCreateIndex, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('index', args.index)
+ self.assertEqual(['expression'], args.expression)
+
+ def test_create_index_multi_expression(self):
+ args = self.parse_args(['create-index', 'db', 'index', 'e1', 'e2'])
+ self.assertEqual(client.CmdCreateIndex, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('index', args.index)
+ self.assertEqual(['e1', 'e2'], args.expression)
+
+ def test_list_indexes(self):
+ args = self.parse_args(['list-indexes', 'db'])
+ self.assertEqual(client.CmdListIndexes, args.subcommand)
+ self.assertEqual('db', args.database)
+
+ def test_delete_index(self):
+ args = self.parse_args(['delete-index', 'db', 'index'])
+ self.assertEqual(client.CmdDeleteIndex, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('index', args.index)
+
+ def test_get_index_keys(self):
+ args = self.parse_args(['get-index-keys', 'db', 'index'])
+ self.assertEqual(client.CmdGetIndexKeys, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('index', args.index)
+
+ def test_get_from_index(self):
+ args = self.parse_args(['get-from-index', 'db', 'index', 'foo'])
+ self.assertEqual(client.CmdGetFromIndex, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('index', args.index)
+ self.assertEqual(['foo'], args.values)
+
+ def test_get_doc_conflicts(self):
+ args = self.parse_args(['get-doc-conflicts', 'db', 'doc-id'])
+ self.assertEqual(client.CmdGetDocConflicts, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+
+ def test_resolve(self):
+ args = self.parse_args(
+ ['resolve-doc', 'db', 'doc-id', 'rev:1', 'other:1'])
+ self.assertEqual(client.CmdResolve, args.subcommand)
+ self.assertEqual('db', args.database)
+ self.assertEqual('doc-id', args.doc_id)
+ self.assertEqual(['rev:1', 'other:1'], args.doc_revs)
+ self.assertEqual(None, args.infile)
+
+
+class TestCaseWithDB(tests.TestCase):
+ """These next tests are meant to have one class per Command.
+
+ It is meant to test the inner workings of each command. The detailed
+ testing should happen in these classes. Stuff like how it handles errors,
+ etc. should be done here.
+ """
+
+ def setUp(self):
+ super(TestCaseWithDB, self).setUp()
+ self.working_dir = self.createTempDir()
+ self.db_path = self.working_dir + '/test.db'
+ self.db = u1db_open(self.db_path, create=True)
+ self.db._set_replica_uid('test')
+ self.addCleanup(self.db.close)
+
+ def make_command(self, cls, stdin_content=''):
+ inf = cStringIO.StringIO(stdin_content)
+ out = cStringIO.StringIO()
+ err = cStringIO.StringIO()
+ return cls(inf, out, err)
+
+
+class TestCmdCreate(TestCaseWithDB):
+
+ def test_create(self):
+ cmd = self.make_command(client.CmdCreate)
+ inf = cStringIO.StringIO(tests.simple_doc)
+ cmd.run(self.db_path, inf, 'test-id')
+ doc = self.db.get_doc('test-id')
+ self.assertEqual(tests.simple_doc, doc.get_json())
+ self.assertFalse(doc.has_conflicts)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('id: test-id\nrev: %s\n' % (doc.rev,),
+ cmd.stderr.getvalue())
+
+
+class TestCmdDelete(TestCaseWithDB):
+
+ def test_delete(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ cmd = self.make_command(client.CmdDelete)
+ cmd.run(self.db_path, doc.doc_id, doc.rev)
+ doc2 = self.db.get_doc(doc.doc_id, include_deleted=True)
+ self.assertEqual(doc.doc_id, doc2.doc_id)
+ self.assertNotEqual(doc.rev, doc2.rev)
+ self.assertIs(None, doc2.get_json())
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('rev: %s\n' % (doc2.rev,), cmd.stderr.getvalue())
+
+ def test_delete_fails_if_nonexistent(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ db2_path = self.db_path + '.typo'
+ cmd = self.make_command(client.CmdDelete)
+ # TODO: We should really not be showing a traceback here. But we need
+ # to teach the commandline infrastructure how to handle
+ # exceptions.
+ # However, we *do* want to test that the db doesn't get created
+ # by accident.
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ cmd.run, db2_path, doc.doc_id, doc.rev)
+ self.assertFalse(os.path.exists(db2_path))
+
+ def test_delete_no_such_doc(self):
+ cmd = self.make_command(client.CmdDelete)
+ # TODO: We should really not be showing a traceback here. But we need
+ # to teach the commandline infrastructure how to handle
+ # exceptions.
+ self.assertRaises(errors.DocumentDoesNotExist,
+ cmd.run, self.db_path, 'no-doc-id', 'no-rev')
+
+ def test_delete_bad_rev(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ cmd = self.make_command(client.CmdDelete)
+ self.assertRaises(errors.RevisionConflict,
+ cmd.run, self.db_path, doc.doc_id, 'not-the-actual-doc-rev:1')
+ # TODO: Test that we get a pretty output.
+
+
+class TestCmdGet(TestCaseWithDB):
+
+ def setUp(self):
+ super(TestCmdGet, self).setUp()
+ self.doc = self.db.create_doc_from_json(
+ tests.simple_doc, doc_id='my-test-doc')
+
+ def test_get_simple(self):
+ cmd = self.make_command(client.CmdGet)
+ cmd.run(self.db_path, 'my-test-doc', None)
+ self.assertEqual(tests.simple_doc + "\n", cmd.stdout.getvalue())
+ self.assertEqual('rev: %s\n' % (self.doc.rev,),
+ cmd.stderr.getvalue())
+
+ def test_get_conflict(self):
+ doc = self.make_document('my-test-doc', 'other:1', '{}', False)
+ self.db._put_doc_if_newer(
+ doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ cmd = self.make_command(client.CmdGet)
+ cmd.run(self.db_path, 'my-test-doc', None)
+ self.assertEqual('{}\n', cmd.stdout.getvalue())
+ self.assertEqual('rev: %s\nDocument has conflicts.\n' % (doc.rev,),
+ cmd.stderr.getvalue())
+
+ def test_get_fail(self):
+ cmd = self.make_command(client.CmdGet)
+ result = cmd.run(self.db_path, 'doc-not-there', None)
+ self.assertEqual(1, result)
+ self.assertEqual("", cmd.stdout.getvalue())
+ self.assertTrue("not found" in cmd.stderr.getvalue())
+
+ def test_get_no_database(self):
+ cmd = self.make_command(client.CmdGet)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc", None)
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+
+class TestCmdGetDocConflicts(TestCaseWithDB):
+
+ def setUp(self):
+ super(TestCmdGetDocConflicts, self).setUp()
+ self.doc1 = self.db.create_doc_from_json(
+ tests.simple_doc, doc_id='my-doc')
+ self.doc2 = self.make_document('my-doc', 'other:1', '{}', False)
+ self.db._put_doc_if_newer(
+ self.doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+
+ def test_get_doc_conflicts_none(self):
+ self.db.create_doc_from_json(tests.simple_doc, doc_id='a-doc')
+ cmd = self.make_command(client.CmdGetDocConflicts)
+ cmd.run(self.db_path, 'a-doc')
+ self.assertEqual([], json.loads(cmd.stdout.getvalue()))
+ self.assertEqual('', cmd.stderr.getvalue())
+
+ def test_get_doc_conflicts_simple(self):
+ cmd = self.make_command(client.CmdGetDocConflicts)
+ cmd.run(self.db_path, 'my-doc')
+ self.assertEqual(
+ [dict(rev=self.doc2.rev, content=self.doc2.content),
+ dict(rev=self.doc1.rev, content=self.doc1.content)],
+ json.loads(cmd.stdout.getvalue()))
+ self.assertEqual('', cmd.stderr.getvalue())
+
+ def test_get_doc_conflicts_no_db(self):
+ cmd = self.make_command(client.CmdGetDocConflicts)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_get_doc_conflicts_no_doc(self):
+ cmd = self.make_command(client.CmdGetDocConflicts)
+ retval = cmd.run(self.db_path, "some-doc")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Document does not exist.\n')
+
+
+class TestCmdInit(TestCaseWithDB):
+
+ def test_init_new(self):
+ path = self.working_dir + '/test2.db'
+ self.assertFalse(os.path.exists(path))
+ cmd = self.make_command(client.CmdInitDB)
+ cmd.run(path, 'test-uid')
+ self.assertTrue(os.path.exists(path))
+ db = u1db_open(path, create=False)
+ self.assertEqual('test-uid', db._replica_uid)
+
+ def test_init_no_uid(self):
+ path = self.working_dir + '/test2.db'
+ cmd = self.make_command(client.CmdInitDB)
+ cmd.run(path, None)
+ self.assertTrue(os.path.exists(path))
+ db = u1db_open(path, create=False)
+ self.assertIsNot(None, db._replica_uid)
+
+
+class TestCmdPut(TestCaseWithDB):
+
+ def setUp(self):
+ super(TestCmdPut, self).setUp()
+ self.doc = self.db.create_doc_from_json(
+ tests.simple_doc, doc_id='my-test-doc')
+
+ def test_put_simple(self):
+ cmd = self.make_command(client.CmdPut)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ cmd.run(self.db_path, 'my-test-doc', self.doc.rev, inf)
+ doc = self.db.get_doc('my-test-doc')
+ self.assertNotEqual(self.doc.rev, doc.rev)
+ self.assertGetDoc(self.db, 'my-test-doc', doc.rev,
+ tests.nested_doc, False)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('rev: %s\n' % (doc.rev,),
+ cmd.stderr.getvalue())
+
+ def test_put_no_db(self):
+ cmd = self.make_command(client.CmdPut)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST",
+ 'my-test-doc', self.doc.rev, inf)
+ self.assertEqual(retval, 1)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('Database does not exist.\n', cmd.stderr.getvalue())
+
+ def test_put_no_doc(self):
+ cmd = self.make_command(client.CmdPut)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ retval = cmd.run(self.db_path, 'no-such-doc', 'wut:1', inf)
+ self.assertEqual(1, retval)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('Document does not exist.\n', cmd.stderr.getvalue())
+
+ def test_put_doc_old_rev(self):
+ rev = self.doc.rev
+ doc = self.make_document('my-test-doc', rev, '{}', False)
+ self.db.put_doc(doc)
+ cmd = self.make_command(client.CmdPut)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ retval = cmd.run(self.db_path, 'my-test-doc', rev, inf)
+ self.assertEqual(1, retval)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('Given revision is not current.\n',
+ cmd.stderr.getvalue())
+
+ def test_put_doc_w_conflicts(self):
+ doc = self.make_document('my-test-doc', 'other:1', '{}', False)
+ self.db._put_doc_if_newer(
+ doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ cmd = self.make_command(client.CmdPut)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ retval = cmd.run(self.db_path, 'my-test-doc', 'other:1', inf)
+ self.assertEqual(1, retval)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('Document has conflicts.\n'
+ 'Inspect with get-doc-conflicts, then resolve.\n',
+ cmd.stderr.getvalue())
+
+
+class TestCmdResolve(TestCaseWithDB):
+
+ def setUp(self):
+ super(TestCmdResolve, self).setUp()
+ self.doc1 = self.db.create_doc_from_json(
+ tests.simple_doc, doc_id='my-doc')
+ self.doc2 = self.make_document('my-doc', 'other:1', '{}', False)
+ self.db._put_doc_if_newer(
+ self.doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+
+ def test_resolve_simple(self):
+ self.assertTrue(self.db.get_doc('my-doc').has_conflicts)
+ cmd = self.make_command(client.CmdResolve)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ cmd.run(self.db_path, 'my-doc', [self.doc1.rev, self.doc2.rev], inf)
+ doc = self.db.get_doc('my-doc')
+ vec = vectorclock.VectorClockRev(doc.rev)
+ self.assertTrue(
+ vec.is_newer(vectorclock.VectorClockRev(self.doc1.rev)))
+ self.assertTrue(
+ vec.is_newer(vectorclock.VectorClockRev(self.doc2.rev)))
+ self.assertGetDoc(self.db, 'my-doc', doc.rev, tests.nested_doc, False)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual('rev: %s\n' % (doc.rev,),
+ cmd.stderr.getvalue())
+
+ def test_resolve_double(self):
+ moar = '{"x": 42}'
+ doc3 = self.make_document('my-doc', 'third:1', moar, False)
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ cmd = self.make_command(client.CmdResolve)
+ inf = cStringIO.StringIO(tests.nested_doc)
+ cmd.run(self.db_path, 'my-doc', [self.doc1.rev, self.doc2.rev], inf)
+ doc = self.db.get_doc('my-doc')
+ self.assertGetDoc(self.db, 'my-doc', doc.rev, moar, True)
+ self.assertEqual('', cmd.stdout.getvalue())
+ self.assertEqual(
+ 'rev: %s\nDocument still has conflicts.\n' % (doc.rev,),
+ cmd.stderr.getvalue())
+
+ def test_resolve_no_db(self):
+ cmd = self.make_command(client.CmdResolve)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "my-doc", [], None)
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_resolve_no_doc(self):
+ cmd = self.make_command(client.CmdResolve)
+ retval = cmd.run(self.db_path, "foo", [], None)
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Document does not exist.\n')
+
+
+class TestCmdSync(TestCaseWithDB):
+
+ def setUp(self):
+ super(TestCmdSync, self).setUp()
+ self.db2_path = self.working_dir + '/test2.db'
+ self.db2 = u1db_open(self.db2_path, create=True)
+ self.addCleanup(self.db2.close)
+ self.db2._set_replica_uid('test2')
+ self.doc = self.db.create_doc_from_json(
+ tests.simple_doc, doc_id='test-id')
+ self.doc2 = self.db2.create_doc_from_json(
+ tests.nested_doc, doc_id='my-test-id')
+
+ def test_sync(self):
+ cmd = self.make_command(client.CmdSync)
+ cmd.run(self.db_path, self.db2_path)
+ self.assertGetDoc(self.db2, 'test-id', self.doc.rev, tests.simple_doc,
+ False)
+ self.assertGetDoc(self.db, 'my-test-id', self.doc2.rev,
+ tests.nested_doc, False)
+
+
+class TestCmdSyncRemote(tests.TestCaseWithServer, TestCaseWithDB):
+
+ make_app_with_state = \
+ staticmethod(test_remote_sync_target.make_http_app)
+
+ def setUp(self):
+ super(TestCmdSyncRemote, self).setUp()
+ self.startServer()
+ self.db2 = self.request_state._create_database('test2.db')
+
+ def test_sync_remote(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ doc2 = self.db2.create_doc_from_json(tests.nested_doc)
+ db2_url = self.getURL('test2.db')
+ self.assertTrue(db2_url.startswith('http://'))
+ self.assertTrue(db2_url.endswith('/test2.db'))
+ cmd = self.make_command(client.CmdSync)
+ cmd.run(self.db_path, db2_url)
+ self.assertGetDoc(self.db2, doc1.doc_id, doc1.rev, tests.simple_doc,
+ False)
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc,
+ False)
+
+
+class TestCmdCreateIndex(TestCaseWithDB):
+
+ def test_create_index(self):
+ cmd = self.make_command(client.CmdCreateIndex)
+ retval = cmd.run(self.db_path, "foo", ["bar", "baz"])
+ self.assertEqual(self.db.list_indexes(), [('foo', ['bar', "baz"])])
+ self.assertEqual(retval, None) # conveniently mapped to 0
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_create_index_no_db(self):
+ cmd = self.make_command(client.CmdCreateIndex)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo", ["bar"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_create_dupe_index(self):
+ self.db.create_index("foo", "bar")
+ cmd = self.make_command(client.CmdCreateIndex)
+ retval = cmd.run(self.db_path, "foo", ["bar"])
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_create_dupe_index_different_expression(self):
+ self.db.create_index("foo", "bar")
+ cmd = self.make_command(client.CmdCreateIndex)
+ retval = cmd.run(self.db_path, "foo", ["baz"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(),
+ "There is already a different index named 'foo'.\n")
+
+ def test_create_index_bad_expression(self):
+ cmd = self.make_command(client.CmdCreateIndex)
+ retval = cmd.run(self.db_path, "foo", ["WAT()"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(),
+ 'Bad index expression.\n')
+
+
+class TestCmdListIndexes(TestCaseWithDB):
+
+ def test_list_no_indexes(self):
+ cmd = self.make_command(client.CmdListIndexes)
+ retval = cmd.run(self.db_path)
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_list_indexes(self):
+ self.db.create_index("foo", "bar", "baz")
+ cmd = self.make_command(client.CmdListIndexes)
+ retval = cmd.run(self.db_path)
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), 'foo: bar, baz\n')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_list_several_indexes(self):
+ self.db.create_index("foo", "bar", "baz")
+ self.db.create_index("bar", "baz", "foo")
+ self.db.create_index("baz", "foo", "bar")
+ cmd = self.make_command(client.CmdListIndexes)
+ retval = cmd.run(self.db_path)
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(),
+ 'bar: baz, foo\n'
+ 'baz: foo, bar\n'
+ 'foo: bar, baz\n'
+ )
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_list_indexes_no_db(self):
+ cmd = self.make_command(client.CmdListIndexes)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+
+class TestCmdDeleteIndex(TestCaseWithDB):
+
+ def test_delete_index(self):
+ self.db.create_index("foo", "bar", "baz")
+ cmd = self.make_command(client.CmdDeleteIndex)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+ self.assertEqual([], self.db.list_indexes())
+
+ def test_delete_index_no_db(self):
+ cmd = self.make_command(client.CmdDeleteIndex)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_delete_index_no_index(self):
+ cmd = self.make_command(client.CmdDeleteIndex)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+
+class TestCmdGetIndexKeys(TestCaseWithDB):
+
+ def test_get_index_keys(self):
+ self.db.create_index("foo", "bar")
+ self.db.create_doc_from_json('{"bar": 42}')
+ cmd = self.make_command(client.CmdGetIndexKeys)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '42\n')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_get_index_keys_nonascii(self):
+ self.db.create_index("foo", "bar")
+ self.db.create_doc_from_json('{"bar": "\u00a4"}')
+ cmd = self.make_command(client.CmdGetIndexKeys)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '\xc2\xa4\n')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_get_index_keys_empty(self):
+ self.db.create_index("foo", "bar")
+ cmd = self.make_command(client.CmdGetIndexKeys)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_get_index_keys_no_db(self):
+ cmd = self.make_command(client.CmdGetIndexKeys)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_get_index_keys_no_index(self):
+ cmd = self.make_command(client.CmdGetIndexKeys)
+ retval = cmd.run(self.db_path, "foo")
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Index does not exist.\n')
+
+
+class TestCmdGetFromIndex(TestCaseWithDB):
+
+ def test_get_from_index(self):
+ self.db.create_index("index", "key")
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ doc2 = self.db.create_doc_from_json(tests.nested_doc)
+ cmd = self.make_command(client.CmdGetFromIndex)
+ retval = cmd.run(self.db_path, "index", ["value"])
+ self.assertEqual(retval, None)
+ self.assertEqual(sorted(json.loads(cmd.stdout.getvalue())),
+ sorted([dict(id=doc1.doc_id,
+ rev=doc1.rev,
+ content=doc1.content),
+ dict(id=doc2.doc_id,
+ rev=doc2.rev,
+ content=doc2.content),
+ ]))
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_get_from_index_empty(self):
+ self.db.create_index("index", "key")
+ cmd = self.make_command(client.CmdGetFromIndex)
+ retval = cmd.run(self.db_path, "index", ["value"])
+ self.assertEqual(retval, None)
+ self.assertEqual(cmd.stdout.getvalue(), '[]\n')
+ self.assertEqual(cmd.stderr.getvalue(), '')
+
+ def test_get_from_index_no_db(self):
+ cmd = self.make_command(client.CmdGetFromIndex)
+ retval = cmd.run(self.db_path + "__DOES_NOT_EXIST", "foo", [])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Database does not exist.\n')
+
+ def test_get_from_index_no_index(self):
+ cmd = self.make_command(client.CmdGetFromIndex)
+ retval = cmd.run(self.db_path, "foo", [])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual(cmd.stderr.getvalue(), 'Index does not exist.\n')
+
+ def test_get_from_index_two_expr_instead_of_one(self):
+ self.db.create_index("index", "key1")
+ cmd = self.make_command(client.CmdGetFromIndex)
+ cmd.argv = ["XX", "YY"]
+ retval = cmd.run(self.db_path, "index", ["value1", "value2"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual("Invalid query: index 'index' requires"
+ " 1 query expression, not 2.\n"
+ "For example, the following would be valid:\n"
+ " XX YY %r 'index' 'value1'\n"
+ % self.db_path, cmd.stderr.getvalue())
+
+ def test_get_from_index_three_expr_instead_of_two(self):
+ self.db.create_index("index", "key1", "key2")
+ cmd = self.make_command(client.CmdGetFromIndex)
+ cmd.argv = ["XX", "YY"]
+ retval = cmd.run(self.db_path, "index", ["value1", "value2", "value3"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual("Invalid query: index 'index' requires"
+ " 2 query expressions, not 3.\n"
+ "For example, the following would be valid:\n"
+ " XX YY %r 'index' 'value1' 'value2'\n"
+ % self.db_path, cmd.stderr.getvalue())
+
+ def test_get_from_index_one_expr_instead_of_two(self):
+ self.db.create_index("index", "key1", "key2")
+ cmd = self.make_command(client.CmdGetFromIndex)
+ cmd.argv = ["XX", "YY"]
+ retval = cmd.run(self.db_path, "index", ["value1"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual("Invalid query: index 'index' requires"
+ " 2 query expressions, not 1.\n"
+ "For example, the following would be valid:\n"
+ " XX YY %r 'index' 'value1' '*'\n"
+ % self.db_path, cmd.stderr.getvalue())
+
+ def test_get_from_index_cant_bad_glob(self):
+ self.db.create_index("index", "key1", "key2")
+ cmd = self.make_command(client.CmdGetFromIndex)
+ cmd.argv = ["XX", "YY"]
+ retval = cmd.run(self.db_path, "index", ["value1*", "value2"])
+ self.assertEqual(retval, 1)
+ self.assertEqual(cmd.stdout.getvalue(), '')
+ self.assertEqual("Invalid query:"
+ " a star can only be followed by stars.\n"
+ "For example, the following would be valid:\n"
+ " XX YY %r 'index' 'value1*' '*'\n"
+ % self.db_path, cmd.stderr.getvalue())
+
+
+class RunMainHelper(object):
+
+ def run_main(self, args, stdin=None):
+ if stdin is not None:
+ self.patch(sys, 'stdin', cStringIO.StringIO(stdin))
+ stdout = cStringIO.StringIO()
+ stderr = cStringIO.StringIO()
+ self.patch(sys, 'stdout', stdout)
+ self.patch(sys, 'stderr', stderr)
+ try:
+ ret = client.main(args)
+ except SystemExit, e:
+ self.fail("Intercepted SystemExit: %s" % (e,))
+ if ret is None:
+ ret = 0
+ return ret, stdout.getvalue(), stderr.getvalue()
+
+
+class TestCommandLine(TestCaseWithDB, RunMainHelper):
+ """These are meant to test that the infrastructure is fully connected.
+
+ Each command is likely to only have one test here. Something that ensures
+ 'main()' knows about and can run the command correctly. Most logic-level
+ testing of the Command should go into its own test class above.
+ """
+
+ def _get_u1db_client_path(self):
+ from u1db import __path__ as u1db_path
+ u1db_parent_dir = os.path.dirname(u1db_path[0])
+ return os.path.join(u1db_parent_dir, 'u1db-client')
+
+ def runU1DBClient(self, args):
+ command = [sys.executable, self._get_u1db_client_path()]
+ command.extend(args)
+ p = subprocess.Popen(command, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ self.addCleanup(safe_close, p)
+ return p
+
+ def test_create_subprocess(self):
+ p = self.runU1DBClient(['create', '--id', 'test-id', self.db_path])
+ stdout, stderr = p.communicate(tests.simple_doc)
+ self.assertEqual(0, p.returncode)
+ self.assertEqual('', stdout)
+ doc = self.db.get_doc('test-id')
+ self.assertEqual(tests.simple_doc, doc.get_json())
+ self.assertFalse(doc.has_conflicts)
+ expected = 'id: test-id\nrev: %s\n' % (doc.rev,)
+ stripped = stderr.replace('\r\n', '\n')
+ if expected != stripped:
+ # When run under python-dbg, it prints out the refs after the
+ # actual content, so match it if we need to.
+ expected_re = expected + '\[\d+ refs\]\n'
+ self.assertRegexpMatches(stripped, expected_re)
+
+ def test_get(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id')
+ ret, stdout, stderr = self.run_main(['get', self.db_path, 'test-id'])
+ self.assertEqual(0, ret)
+ self.assertEqual(tests.simple_doc + "\n", stdout)
+ self.assertEqual('rev: %s\n' % (doc.rev,), stderr)
+ ret, stdout, stderr = self.run_main(['get', self.db_path, 'not-there'])
+ self.assertEqual(1, ret)
+
+ def test_delete(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id')
+ ret, stdout, stderr = self.run_main(
+ ['delete', self.db_path, 'test-id', doc.rev])
+ doc = self.db.get_doc('test-id', include_deleted=True)
+ self.assertEqual(0, ret)
+ self.assertEqual('', stdout)
+ self.assertEqual('rev: %s\n' % (doc.rev,), stderr)
+
+ def test_init_db(self):
+ path = self.working_dir + '/test2.db'
+ ret, stdout, stderr = self.run_main(['init-db', path])
+ u1db_open(path, create=False)
+
+ def test_put(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id')
+ ret, stdout, stderr = self.run_main(
+ ['put', self.db_path, 'test-id', doc.rev],
+ stdin=tests.nested_doc)
+ doc = self.db.get_doc('test-id')
+ self.assertFalse(doc.has_conflicts)
+ self.assertEqual(tests.nested_doc, doc.get_json())
+ self.assertEqual(0, ret)
+ self.assertEqual('', stdout)
+ self.assertEqual('rev: %s\n' % (doc.rev,), stderr)
+
+ def test_sync(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc, doc_id='test-id')
+ self.db2_path = self.working_dir + '/test2.db'
+ self.db2 = u1db_open(self.db2_path, create=True)
+ self.addCleanup(self.db2.close)
+ ret, stdout, stderr = self.run_main(
+ ['sync', self.db_path, self.db2_path])
+ self.assertEqual(0, ret)
+ self.assertEqual('', stdout)
+ self.assertEqual('', stderr)
+ self.assertGetDoc(
+ self.db2, 'test-id', doc.rev, tests.simple_doc, False)
+
+
+class TestHTTPIntegration(tests.TestCaseWithServer, RunMainHelper):
+ """Meant to test the cases where commands operate over http."""
+
+ def server_def(self):
+ def make_server(host_port, _application):
+ return serve.make_server(host_port[0], host_port[1],
+ self.working_dir)
+ return make_server, "shutdown", "http"
+
+ def setUp(self):
+ super(TestHTTPIntegration, self).setUp()
+ self.working_dir = self.createTempDir(prefix='u1db-http-server-')
+ self.startServer()
+
+ def getPath(self, dbname):
+ return os.path.join(self.working_dir, dbname)
+
+ def test_init_db(self):
+ url = self.getURL('new.db')
+ ret, stdout, stderr = self.run_main(['init-db', url])
+ u1db_open(self.getPath('new.db'), create=False)
+
+ def test_create_get_put_delete(self):
+ db = u1db_open(self.getPath('test.db'), create=True)
+ url = self.getURL('test.db')
+ doc_id = '%abcd'
+ ret, stdout, stderr = self.run_main(['create', url, '--id', doc_id],
+ stdin=tests.simple_doc)
+ self.assertEqual(0, ret)
+ ret, stdout, stderr = self.run_main(['get', url, doc_id])
+ self.assertEqual(0, ret)
+ self.assertTrue(stderr.startswith('rev: '))
+ doc_rev = stderr[len('rev: '):].rstrip()
+ ret, stdout, stderr = self.run_main(['put', url, doc_id, doc_rev],
+ stdin=tests.nested_doc)
+ self.assertEqual(0, ret)
+ self.assertTrue(stderr.startswith('rev: '))
+ doc_rev1 = stderr[len('rev: '):].rstrip()
+ self.assertGetDoc(db, doc_id, doc_rev1, tests.nested_doc, False)
+ ret, stdout, stderr = self.run_main(['delete', url, doc_id, doc_rev1])
+ self.assertEqual(0, ret)
+ self.assertTrue(stderr.startswith('rev: '))
+ doc_rev2 = stderr[len('rev: '):].rstrip()
+ self.assertGetDocIncludeDeleted(db, doc_id, doc_rev2, None, False)
diff --git a/src/leap/soledad/u1db/tests/commandline/test_command.py b/src/leap/soledad/u1db/tests/commandline/test_command.py
new file mode 100644
index 00000000..43580f23
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/commandline/test_command.py
@@ -0,0 +1,105 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+import cStringIO
+import argparse
+
+from u1db import (
+ tests,
+ )
+from u1db.commandline import (
+ command,
+ )
+
+
+class MyTestCommand(command.Command):
+ """Help String"""
+
+ name = 'mycmd'
+
+ @classmethod
+ def _populate_subparser(cls, parser):
+ parser.add_argument('foo')
+ parser.add_argument('--bar', dest='nbar', type=int)
+
+ def run(self, foo, nbar):
+ self.stdout.write('foo: %s nbar: %d' % (foo, nbar))
+ return 0
+
+
+def make_stdin_out_err():
+ return cStringIO.StringIO(), cStringIO.StringIO(), cStringIO.StringIO()
+
+
+class TestCommandGroup(tests.TestCase):
+
+ def trap_system_exit(self, func, *args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except SystemExit, e:
+ self.fail('Got SystemExit trying to run: %s' % (func,))
+
+ def parse_args(self, parser, args):
+ return self.trap_system_exit(parser.parse_args, args)
+
+ def test_register(self):
+ group = command.CommandGroup()
+ self.assertEqual({}, group.commands)
+ group.register(MyTestCommand)
+ self.assertEqual({'mycmd': MyTestCommand},
+ group.commands)
+
+ def test_make_argparser(self):
+ group = command.CommandGroup(description='test-foo')
+ parser = group.make_argparser()
+ self.assertIsInstance(parser, argparse.ArgumentParser)
+
+ def test_make_argparser_with_command(self):
+ group = command.CommandGroup(description='test-foo')
+ group.register(MyTestCommand)
+ parser = group.make_argparser()
+ args = self.parse_args(parser, ['mycmd', 'foozizle', '--bar=10'])
+ self.assertEqual('foozizle', args.foo)
+ self.assertEqual(10, args.nbar)
+ self.assertEqual(MyTestCommand, args.subcommand)
+
+ def test_run_argv(self):
+ group = command.CommandGroup()
+ group.register(MyTestCommand)
+ stdin, stdout, stderr = make_stdin_out_err()
+ ret = self.trap_system_exit(group.run_argv,
+ ['mycmd', 'foozizle', '--bar=10'],
+ stdin, stdout, stderr)
+ self.assertEqual(0, ret)
+
+
+class TestCommand(tests.TestCase):
+
+ def make_command(self):
+ stdin, stdout, stderr = make_stdin_out_err()
+ return command.Command(stdin, stdout, stderr)
+
+ def test__init__(self):
+ cmd = self.make_command()
+ self.assertIsNot(None, cmd.stdin)
+ self.assertIsNot(None, cmd.stdout)
+ self.assertIsNot(None, cmd.stderr)
+
+ def test_run_args(self):
+ stdin, stdout, stderr = make_stdin_out_err()
+ cmd = MyTestCommand(stdin, stdout, stderr)
+ res = cmd.run(foo='foozizle', nbar=10)
+ self.assertEqual('foo: foozizle nbar: 10', stdout.getvalue())
diff --git a/src/leap/soledad/u1db/tests/commandline/test_serve.py b/src/leap/soledad/u1db/tests/commandline/test_serve.py
new file mode 100644
index 00000000..6397eabe
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/commandline/test_serve.py
@@ -0,0 +1,101 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+import os
+import socket
+import subprocess
+import sys
+
+from u1db import (
+ __version__ as _u1db_version,
+ open as u1db_open,
+ tests,
+ )
+from u1db.remote import http_client
+from u1db.tests.commandline import safe_close
+
+
+class TestU1DBServe(tests.TestCase):
+
+ def _get_u1db_serve_path(self):
+ from u1db import __path__ as u1db_path
+ u1db_parent_dir = os.path.dirname(u1db_path[0])
+ return os.path.join(u1db_parent_dir, 'u1db-serve')
+
+ def startU1DBServe(self, args):
+ command = [sys.executable, self._get_u1db_serve_path()]
+ command.extend(args)
+ p = subprocess.Popen(command, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ self.addCleanup(safe_close, p)
+ return p
+
+ def test_help(self):
+ p = self.startU1DBServe(['--help'])
+ stdout, stderr = p.communicate()
+ if stderr != '':
+ # stderr should normally be empty, but if we are running under
+ # python-dbg, it contains the following string
+ self.assertRegexpMatches(stderr, r'\[\d+ refs\]')
+ self.assertEqual(0, p.returncode)
+ self.assertIn('Run the U1DB server', stdout)
+
+ def test_bind_to_port(self):
+ p = self.startU1DBServe([])
+ starts = 'listening on:'
+ x = p.stdout.readline()
+ self.assertTrue(x.startswith(starts))
+ port = int(x[len(starts):].split(":")[1])
+ url = "http://127.0.0.1:%s/" % port
+ c = http_client.HTTPClientBase(url)
+ self.addCleanup(c.close)
+ res, _ = c._request_json('GET', [])
+ self.assertEqual({'version': _u1db_version}, res)
+
+ def test_supply_port(self):
+ s = socket.socket()
+ s.bind(('127.0.0.1', 0))
+ host, port = s.getsockname()
+ s.close()
+ p = self.startU1DBServe(['--port', str(port)])
+ x = p.stdout.readline().strip()
+ self.assertEqual('listening on: 127.0.0.1:%s' % (port,), x)
+ url = "http://127.0.0.1:%s/" % port
+ c = http_client.HTTPClientBase(url)
+ self.addCleanup(c.close)
+ res, _ = c._request_json('GET', [])
+ self.assertEqual({'version': _u1db_version}, res)
+
+ def test_bind_to_host(self):
+ p = self.startU1DBServe(["--host", "localhost"])
+ starts = 'listening on: 127.0.0.1:'
+ x = p.stdout.readline()
+ self.assertTrue(x.startswith(starts))
+
+ def test_supply_working_dir(self):
+ tmp_dir = self.createTempDir('u1db-serve-test')
+ db = u1db_open(os.path.join(tmp_dir, 'landmark.db'), create=True)
+ db.close()
+ p = self.startU1DBServe(['--working-dir', tmp_dir])
+ starts = 'listening on:'
+ x = p.stdout.readline()
+ self.assertTrue(x.startswith(starts))
+ port = int(x[len(starts):].split(":")[1])
+ url = "http://127.0.0.1:%s/landmark.db" % port
+ c = http_client.HTTPClientBase(url)
+ self.addCleanup(c.close)
+ res, _ = c._request_json('GET', [])
+ self.assertEqual({}, res)
diff --git a/src/leap/soledad/u1db/tests/test_auth_middleware.py b/src/leap/soledad/u1db/tests/test_auth_middleware.py
new file mode 100644
index 00000000..e765f8a7
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_auth_middleware.py
@@ -0,0 +1,309 @@
+# Copyright 2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test OAuth wsgi middleware"""
+import paste.fixture
+from oauth import oauth
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import time
+
+from u1db import tests
+
+from u1db.remote.oauth_middleware import OAuthMiddleware
+from u1db.remote.basic_auth_middleware import BasicAuthMiddleware, Unauthorized
+
+
+BASE_URL = 'https://example.net'
+
+
+class TestBasicAuthMiddleware(tests.TestCase):
+
+ def setUp(self):
+ super(TestBasicAuthMiddleware, self).setUp()
+ self.got = []
+
+ def witness_app(environ, start_response):
+ start_response("200 OK", [("content-type", "text/plain")])
+ self.got.append((
+ environ['user_id'], environ['PATH_INFO'],
+ environ['QUERY_STRING']))
+ return ["ok"]
+
+ class MyAuthMiddleware(BasicAuthMiddleware):
+
+ def verify_user(self, environ, user, password):
+ if user != "correct_user":
+ raise Unauthorized
+ if password != "correct_password":
+ raise Unauthorized
+ environ['user_id'] = user
+
+ self.auth_midw = MyAuthMiddleware(witness_app, prefix="/pfx/")
+ self.app = paste.fixture.TestApp(self.auth_midw)
+
+ def test_expect_prefix(self):
+ url = BASE_URL + '/foo/doc/doc-id'
+ resp = self.app.delete(url, expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual('{"error": "bad request"}', resp.body)
+
+ def test_missing_auth(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ resp = self.app.delete(url, expect_errors=True)
+ self.assertEqual(401, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "unauthorized",
+ "message": "Missing Basic Authentication."},
+ json.loads(resp.body))
+
+ def test_correct_auth(self):
+ user = "correct_user"
+ password = "correct_password"
+ params = {'old_rev': 'old-rev'}
+ url = BASE_URL + '/pfx/foo/doc/doc-id?%s' % (
+ '&'.join("%s=%s" % (k, v) for k, v in params.items()))
+ auth = '%s:%s' % (user, password)
+ headers = {
+ 'Authorization': 'Basic %s' % (auth.encode('base64'),)}
+ resp = self.app.delete(url, headers=headers)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ [('correct_user', '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+ def test_incorrect_auth(self):
+ user = "correct_user"
+ password = "incorrect_password"
+ params = {'old_rev': 'old-rev'}
+ url = BASE_URL + '/pfx/foo/doc/doc-id?%s' % (
+ '&'.join("%s=%s" % (k, v) for k, v in params.items()))
+ auth = '%s:%s' % (user, password)
+ headers = {
+ 'Authorization': 'Basic %s' % (auth.encode('base64'),)}
+ resp = self.app.delete(url, headers=headers, expect_errors=True)
+ self.assertEqual(401, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "unauthorized",
+ "message": "Incorrect password or login."},
+ json.loads(resp.body))
+
+
+class TestOAuthMiddlewareDefaultPrefix(tests.TestCase):
+ def setUp(self):
+
+ super(TestOAuthMiddlewareDefaultPrefix, self).setUp()
+ self.got = []
+
+ def witness_app(environ, start_response):
+ start_response("200 OK", [("content-type", "text/plain")])
+ self.got.append((environ['token_key'], environ['PATH_INFO'],
+ environ['QUERY_STRING']))
+ return ["ok"]
+
+ class MyOAuthMiddleware(OAuthMiddleware):
+ get_oauth_data_store = lambda self: tests.testingOAuthStore
+
+ def verify(self, environ, oauth_req):
+ consumer, token = super(MyOAuthMiddleware, self).verify(
+ environ, oauth_req)
+ environ['token_key'] = token.key
+
+ self.oauth_midw = MyOAuthMiddleware(witness_app, BASE_URL)
+ self.app = paste.fixture.TestApp(self.oauth_midw)
+
+ def test_expect_tilde(self):
+ url = BASE_URL + '/foo/doc/doc-id'
+ resp = self.app.delete(url, expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual('{"error": "bad request"}', resp.body)
+
+ def test_oauth_in_header(self):
+ url = BASE_URL + '/~/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer2,
+ tests.token2,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ url = oauth_req.get_normalized_http_url() + '?' + (
+ '&'.join("%s=%s" % (k, v) for k, v in params.items()))
+ oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
+ tests.consumer2, tests.token2)
+ resp = self.app.delete(url, headers=oauth_req.to_header())
+ self.assertEqual(200, resp.status)
+ self.assertEqual([(tests.token2.key,
+ '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+ def test_oauth_in_query_string(self):
+ url = BASE_URL + '/~/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer1,
+ tests.token1,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
+ tests.consumer1, tests.token1)
+ resp = self.app.delete(oauth_req.to_url())
+ self.assertEqual(200, resp.status)
+ self.assertEqual([(tests.token1.key,
+ '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+
+class TestOAuthMiddleware(tests.TestCase):
+
+ def setUp(self):
+ super(TestOAuthMiddleware, self).setUp()
+ self.got = []
+
+ def witness_app(environ, start_response):
+ start_response("200 OK", [("content-type", "text/plain")])
+ self.got.append((environ['token_key'], environ['PATH_INFO'],
+ environ['QUERY_STRING']))
+ return ["ok"]
+
+ class MyOAuthMiddleware(OAuthMiddleware):
+ get_oauth_data_store = lambda self: tests.testingOAuthStore
+
+ def verify(self, environ, oauth_req):
+ consumer, token = super(MyOAuthMiddleware, self).verify(
+ environ, oauth_req)
+ environ['token_key'] = token.key
+
+ self.oauth_midw = MyOAuthMiddleware(
+ witness_app, BASE_URL, prefix='/pfx/')
+ self.app = paste.fixture.TestApp(self.oauth_midw)
+
+ def test_expect_prefix(self):
+ url = BASE_URL + '/foo/doc/doc-id'
+ resp = self.app.delete(url, expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual('{"error": "bad request"}', resp.body)
+
+ def test_missing_oauth(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ resp = self.app.delete(url, expect_errors=True)
+ self.assertEqual(401, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "unauthorized", "message": "Missing OAuth."},
+ json.loads(resp.body))
+
+ def test_oauth_in_query_string(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer1,
+ tests.token1,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
+ tests.consumer1, tests.token1)
+ resp = self.app.delete(oauth_req.to_url())
+ self.assertEqual(200, resp.status)
+ self.assertEqual([(tests.token1.key,
+ '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+ def test_oauth_invalid(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer1,
+ tests.token3,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
+ tests.consumer1, tests.token3)
+ resp = self.app.delete(oauth_req.to_url(),
+ expect_errors=True)
+ self.assertEqual(401, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ err = json.loads(resp.body)
+ self.assertEqual({"error": "unauthorized",
+ "message": err['message']},
+ err)
+
+ def test_oauth_in_header(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer2,
+ tests.token2,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ url = oauth_req.get_normalized_http_url() + '?' + (
+ '&'.join("%s=%s" % (k, v) for k, v in params.items()))
+ oauth_req.sign_request(tests.sign_meth_HMAC_SHA1,
+ tests.consumer2, tests.token2)
+ resp = self.app.delete(url, headers=oauth_req.to_header())
+ self.assertEqual(200, resp.status)
+ self.assertEqual([(tests.token2.key,
+ '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+ def test_oauth_plain_text(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer1,
+ tests.token1,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ oauth_req.sign_request(tests.sign_meth_PLAINTEXT,
+ tests.consumer1, tests.token1)
+ resp = self.app.delete(oauth_req.to_url())
+ self.assertEqual(200, resp.status)
+ self.assertEqual([(tests.token1.key,
+ '/foo/doc/doc-id', 'old_rev=old-rev')], self.got)
+
+ def test_oauth_timestamp_threshold(self):
+ url = BASE_URL + '/pfx/foo/doc/doc-id'
+ params = {'old_rev': 'old-rev'}
+ oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+ tests.consumer1,
+ tests.token1,
+ parameters=params,
+ http_url=url,
+ http_method='DELETE'
+ )
+ oauth_req.set_parameter('oauth_timestamp', int(time.time()) - 5)
+ oauth_req.sign_request(tests.sign_meth_PLAINTEXT,
+ tests.consumer1, tests.token1)
+ # tweak threshold
+ self.oauth_midw.timestamp_threshold = 1
+ resp = self.app.delete(oauth_req.to_url(), expect_errors=True)
+ self.assertEqual(401, resp.status)
+ err = json.loads(resp.body)
+ self.assertIn('Expired timestamp', err['message'])
+ self.assertIn('threshold 1', err['message'])
diff --git a/src/leap/soledad/u1db/tests/test_backends.py b/src/leap/soledad/u1db/tests/test_backends.py
new file mode 100644
index 00000000..7a3c9e5c
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_backends.py
@@ -0,0 +1,1895 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The backend class for U1DB. This deals with hiding storage details."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from u1db import (
+ DocumentBase,
+ errors,
+ tests,
+ vectorclock,
+ )
+
+simple_doc = tests.simple_doc
+nested_doc = tests.nested_doc
+
+from u1db.tests.test_remote_sync_target import (
+ make_http_app,
+ make_oauth_http_app,
+)
+
+from u1db.remote import (
+ http_database,
+ )
+
+try:
+ from u1db.tests import c_backend_wrapper
+except ImportError:
+ c_backend_wrapper = None # noqa
+
+
+def make_http_database_for_test(test, replica_uid, path='test'):
+ test.startServer()
+ test.request_state._create_database(replica_uid)
+ return http_database.HTTPDatabase(test.getURL(path))
+
+
+def copy_http_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ return test.request_state._copy_database(db)
+
+
+def make_oauth_http_database_for_test(test, replica_uid):
+ http_db = make_http_database_for_test(test, replica_uid, '~/test')
+ http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return http_db
+
+
+def copy_oauth_http_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ http_db = test.request_state._copy_database(db)
+ http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return http_db
+
+
+class TestAlternativeDocument(DocumentBase):
+ """A (not very) alternative implementation of Document."""
+
+
+class AllDatabaseTests(tests.DatabaseBaseTests, tests.TestCaseWithServer):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + [
+ ('http', {'make_database_for_test': make_http_database_for_test,
+ 'copy_database_for_test': copy_http_database_for_test,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'make_app_with_state': make_http_app}),
+ ('oauth_http', {'make_database_for_test':
+ make_oauth_http_database_for_test,
+ 'copy_database_for_test':
+ copy_oauth_http_database_for_test,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'make_app_with_state': make_oauth_http_app})
+ ] + tests.C_DATABASE_SCENARIOS
+
+ def test_close(self):
+ self.db.close()
+
+ def test_create_doc_allocating_doc_id(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertNotEqual(None, doc.doc_id)
+ self.assertNotEqual(None, doc.rev)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+
+ def test_create_doc_different_ids_same_db(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertNotEqual(doc1.doc_id, doc2.doc_id)
+
+ def test_create_doc_with_id(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my-id')
+ self.assertEqual('my-id', doc.doc_id)
+ self.assertNotEqual(None, doc.rev)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+
+ def test_create_doc_existing_id(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ new_content = '{"something": "else"}'
+ self.assertRaises(
+ errors.RevisionConflict, self.db.create_doc_from_json,
+ new_content, doc.doc_id)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+
+ def test_put_doc_creating_initial(self):
+ doc = self.make_document('my_doc_id', None, simple_doc)
+ new_rev = self.db.put_doc(doc)
+ self.assertIsNot(None, new_rev)
+ self.assertGetDoc(self.db, 'my_doc_id', new_rev, simple_doc, False)
+
+ def test_put_doc_space_in_id(self):
+ doc = self.make_document('my doc id', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_doc_update(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ orig_rev = doc.rev
+ doc.set_json('{"updated": "stuff"}')
+ new_rev = self.db.put_doc(doc)
+ self.assertNotEqual(new_rev, orig_rev)
+ self.assertGetDoc(self.db, 'my_doc_id', new_rev,
+ '{"updated": "stuff"}', False)
+ self.assertEqual(doc.rev, new_rev)
+
+ def test_put_non_ascii_key(self):
+ content = json.dumps({u'key\xe5': u'val'})
+ doc = self.db.create_doc_from_json(content, doc_id='my_doc')
+ self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False)
+
+ def test_put_non_ascii_value(self):
+ content = json.dumps({'key': u'\xe5'})
+ doc = self.db.create_doc_from_json(content, doc_id='my_doc')
+ self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False)
+
+ def test_put_doc_refuses_no_id(self):
+ doc = self.make_document(None, None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+ doc = self.make_document("", None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_doc_refuses_slashes(self):
+ doc = self.make_document('a/b', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+ doc = self.make_document(r'\b', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_doc_url_quoting_is_fine(self):
+ doc_id = "%2F%2Ffoo%2Fbar"
+ doc = self.make_document(doc_id, None, simple_doc)
+ new_rev = self.db.put_doc(doc)
+ self.assertGetDoc(self.db, doc_id, new_rev, simple_doc, False)
+
+ def test_put_doc_refuses_non_existing_old_rev(self):
+ doc = self.make_document('doc-id', 'test:4', simple_doc)
+ self.assertRaises(errors.RevisionConflict, self.db.put_doc, doc)
+
+ def test_put_doc_refuses_non_ascii_doc_id(self):
+ doc = self.make_document('d\xc3\xa5c-id', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_fails_with_bad_old_rev(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ old_rev = doc.rev
+ bad_doc = self.make_document(doc.doc_id, 'other:1',
+ '{"something": "else"}')
+ self.assertRaises(errors.RevisionConflict, self.db.put_doc, bad_doc)
+ self.assertGetDoc(self.db, 'my_doc_id', old_rev, simple_doc, False)
+
+ def test_create_succeeds_after_delete(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True)
+ deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev)
+ new_doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.assertGetDoc(self.db, 'my_doc_id', new_doc.rev, simple_doc, False)
+ new_vc = vectorclock.VectorClockRev(new_doc.rev)
+ self.assertTrue(
+ new_vc.is_newer(deleted_vc),
+ "%s does not supersede %s" % (new_doc.rev, deleted_doc.rev))
+
+ def test_put_succeeds_after_delete(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True)
+ deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev)
+ doc2 = self.make_document('my_doc_id', None, simple_doc)
+ self.db.put_doc(doc2)
+ self.assertGetDoc(self.db, 'my_doc_id', doc2.rev, simple_doc, False)
+ new_vc = vectorclock.VectorClockRev(doc2.rev)
+ self.assertTrue(
+ new_vc.is_newer(deleted_vc),
+ "%s does not supersede %s" % (doc2.rev, deleted_doc.rev))
+
+ def test_get_doc_after_put(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.assertGetDoc(self.db, 'my_doc_id', doc.rev, simple_doc, False)
+
+ def test_get_doc_nonexisting(self):
+ self.assertIs(None, self.db.get_doc('non-existing'))
+
+ def test_get_doc_deleted(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ self.assertIs(None, self.db.get_doc('my_doc_id'))
+
+ def test_get_doc_include_deleted(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+
+ def test_get_docs(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual([doc1, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id])))
+
+ def test_get_docs_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc1)
+ self.assertEqual([doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id])))
+
+ def test_get_docs_include_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc1)
+ self.assertEqual(
+ [doc1, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id],
+ include_deleted=True)))
+
+ def test_get_docs_request_ordered(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual([doc1, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id])))
+ self.assertEqual([doc2, doc1],
+ list(self.db.get_docs([doc2.doc_id, doc1.doc_id])))
+
+ def test_get_docs_empty_list(self):
+ self.assertEqual([], list(self.db.get_docs([])))
+
+ def test_handles_nested_content(self):
+ doc = self.db.create_doc_from_json(nested_doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False)
+
+ def test_handles_doc_with_null(self):
+ doc = self.db.create_doc_from_json('{"key": null}')
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, '{"key": null}', False)
+
+ def test_delete_doc(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+ orig_rev = doc.rev
+ self.db.delete_doc(doc)
+ self.assertNotEqual(orig_rev, doc.rev)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+ self.assertIs(None, self.db.get_doc(doc.doc_id))
+
+ def test_delete_doc_non_existent(self):
+ doc = self.make_document('non-existing', 'other:1', simple_doc)
+ self.assertRaises(errors.DocumentDoesNotExist, self.db.delete_doc, doc)
+
+ def test_delete_doc_already_deleted(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertRaises(errors.DocumentAlreadyDeleted,
+ self.db.delete_doc, doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+
+ def test_delete_doc_bad_rev(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+ doc2 = self.make_document(doc1.doc_id, 'other:1', simple_doc)
+ self.assertRaises(errors.RevisionConflict, self.db.delete_doc, doc2)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+
+ def test_delete_doc_sets_content_to_None(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertIs(None, doc.get_json())
+
+ def test_delete_doc_rev_supersedes(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc.set_json(nested_doc)
+ self.db.put_doc(doc)
+ doc.set_json('{"fishy": "content"}')
+ self.db.put_doc(doc)
+ old_rev = doc.rev
+ self.db.delete_doc(doc)
+ cur_vc = vectorclock.VectorClockRev(old_rev)
+ deleted_vc = vectorclock.VectorClockRev(doc.rev)
+ self.assertTrue(deleted_vc.is_newer(cur_vc),
+ "%s does not supersede %s" % (doc.rev, old_rev))
+
+ def test_delete_then_put(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+ doc.set_json(nested_doc)
+ self.db.put_doc(doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False)
+
+
+class DocumentSizeTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def test_put_doc_refuses_oversized_documents(self):
+ self.db.set_document_size_limit(1)
+ doc = self.make_document('doc-id', None, simple_doc)
+ self.assertRaises(errors.DocumentTooBig, self.db.put_doc, doc)
+
+ def test_create_doc_refuses_oversized_documents(self):
+ self.db.set_document_size_limit(1)
+ self.assertRaises(
+ errors.DocumentTooBig, self.db.create_doc_from_json, simple_doc,
+ doc_id='my_doc_id')
+
+ def test_set_document_size_limit_zero(self):
+ self.db.set_document_size_limit(0)
+ self.assertEqual(0, self.db.document_size_limit)
+
+ def test_set_document_size_limit(self):
+ self.db.set_document_size_limit(1000000)
+ self.assertEqual(1000000, self.db.document_size_limit)
+
+
+class LocalDatabaseTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def test_create_doc_different_ids_diff_db(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ db2 = self.create_database('other-uid')
+ doc2 = db2.create_doc_from_json(simple_doc)
+ self.assertNotEqual(doc1.doc_id, doc2.doc_id)
+
+ def test_put_doc_refuses_slashes_picky(self):
+ doc = self.make_document('/a', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_get_all_docs_empty(self):
+ self.assertEqual([], list(self.db.get_all_docs()[1]))
+
+ def test_get_all_docs(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual(
+ sorted([doc1, doc2]), sorted(list(self.db.get_all_docs()[1])))
+
+ def test_get_all_docs_exclude_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc2)
+ self.assertEqual([doc1], list(self.db.get_all_docs()[1]))
+
+ def test_get_all_docs_include_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc2)
+ self.assertEqual(
+ sorted([doc1, doc2]),
+ sorted(list(self.db.get_all_docs(include_deleted=True)[1])))
+
+ def test_get_all_docs_generation(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_doc_from_json(nested_doc)
+ self.assertEqual(2, self.db.get_all_docs()[0])
+
+ def test_simple_put_doc_if_newer(self):
+ doc = self.make_document('my-doc-id', 'test:1', simple_doc)
+ state_at_gen = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(('inserted', 1), state_at_gen)
+ self.assertGetDoc(self.db, 'my-doc-id', 'test:1', simple_doc, False)
+
+ def test_simple_put_doc_if_newer_deleted(self):
+ self.db.create_doc_from_json('{}', doc_id='my-doc-id')
+ doc = self.make_document('my-doc-id', 'test:2', None)
+ state_at_gen = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(('inserted', 2), state_at_gen)
+ self.assertGetDocIncludeDeleted(
+ self.db, 'my-doc-id', 'test:2', None, False)
+
+ def test_put_doc_if_newer_already_superseded(self):
+ orig_doc = '{"new": "doc"}'
+ doc1 = self.db.create_doc_from_json(orig_doc)
+ doc1_rev1 = doc1.rev
+ doc1.set_json(simple_doc)
+ self.db.put_doc(doc1)
+ doc1_rev2 = doc1.rev
+ # Nothing is inserted, because the document is already superseded
+ doc = self.make_document(doc1.doc_id, doc1_rev1, orig_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('superseded', state)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1_rev2, simple_doc, False)
+
+ def test_put_doc_if_newer_autoresolve(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ rev = doc1.rev
+ doc = self.make_document(doc1.doc_id, "whatever:1", doc1.get_json())
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('superseded', state)
+ doc2 = self.db.get_doc(doc1.doc_id)
+ v2 = vectorclock.VectorClockRev(doc2.rev)
+ self.assertTrue(v2.is_newer(vectorclock.VectorClockRev("whatever:1")))
+ self.assertTrue(v2.is_newer(vectorclock.VectorClockRev(rev)))
+ # strictly newer locally
+ self.assertTrue(rev not in doc2.rev)
+
+ def test_put_doc_if_newer_already_converged(self):
+ orig_doc = '{"new": "doc"}'
+ doc1 = self.db.create_doc_from_json(orig_doc)
+ state_at_gen = self.db._put_doc_if_newer(
+ doc1, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(('converged', 1), state_at_gen)
+
+ def test_put_doc_if_newer_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ # Nothing is inserted, the document id is returned as would-conflict
+ alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ state, _ = self.db._put_doc_if_newer(
+ alt_doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('conflicted', state)
+ # The database wasn't altered
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+
+ def test_put_doc_if_newer_newer_generation(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc = self.make_document('doc_id', 'other:2', simple_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='other', replica_gen=2,
+ replica_trans_id='T-irrelevant')
+ self.assertEqual('inserted', state)
+
+ def test_put_doc_if_newer_same_generation_same_txid(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.make_document(doc.doc_id, 'other:1', simple_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='other', replica_gen=1,
+ replica_trans_id='T-sid')
+ self.assertEqual('converged', state)
+
+ def test_put_doc_if_newer_wrong_transaction_id(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc = self.make_document('doc_id', 'other:1', simple_doc)
+ self.assertRaises(
+ errors.InvalidTransactionId,
+ self.db._put_doc_if_newer, doc, save_conflict=False,
+ replica_uid='other', replica_gen=1, replica_trans_id='T-sad')
+
+ def test_put_doc_if_newer_old_generation_older_doc(self):
+ orig_doc = '{"new": "doc"}'
+ doc = self.db.create_doc_from_json(orig_doc)
+ doc_rev1 = doc.rev
+ doc.set_json(simple_doc)
+ self.db.put_doc(doc)
+ self.db._set_replica_gen_and_trans_id('other', 3, 'T-sid')
+ older_doc = self.make_document(doc.doc_id, doc_rev1, simple_doc)
+ state, _ = self.db._put_doc_if_newer(
+ older_doc, save_conflict=False, replica_uid='other', replica_gen=8,
+ replica_trans_id='T-irrelevant')
+ self.assertEqual('superseded', state)
+
+ def test_put_doc_if_newer_old_generation_newer_doc(self):
+ self.db._set_replica_gen_and_trans_id('other', 5, 'T-sid')
+ doc = self.make_document('doc_id', 'other:1', simple_doc)
+ self.assertRaises(
+ errors.InvalidGeneration,
+ self.db._put_doc_if_newer, doc, save_conflict=False,
+ replica_uid='other', replica_gen=1, replica_trans_id='T-sad')
+
+ def test_put_doc_if_newer_replica_uid(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1',
+ nested_doc)
+ self.assertEqual('inserted',
+ self.db._put_doc_if_newer(doc2, save_conflict=False,
+ replica_uid='other', replica_gen=2,
+ replica_trans_id='T-id2')[0])
+ self.assertEqual((2, 'T-id2'), self.db._get_replica_gen_and_trans_id(
+ 'other'))
+ # Compare to the old rev, should be superseded
+ doc2 = self.make_document(doc1.doc_id, doc1.rev, nested_doc)
+ self.assertEqual('superseded',
+ self.db._put_doc_if_newer(doc2, save_conflict=False,
+ replica_uid='other', replica_gen=3,
+ replica_trans_id='T-id3')[0])
+ self.assertEqual(
+ (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other'))
+ # A conflict that isn't saved still records the sync gen, because we
+ # don't need to see it again
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|fourth:1',
+ '{}')
+ self.assertEqual('conflicted',
+ self.db._put_doc_if_newer(doc2, save_conflict=False,
+ replica_uid='other', replica_gen=4,
+ replica_trans_id='T-id4')[0])
+ self.assertEqual(
+ (4, 'T-id4'), self.db._get_replica_gen_and_trans_id('other'))
+
+ def test__get_replica_gen_and_trans_id(self):
+ self.assertEqual(
+ (0, ''), self.db._get_replica_gen_and_trans_id('other-db'))
+ self.db._set_replica_gen_and_trans_id('other-db', 2, 'T-transaction')
+ self.assertEqual(
+ (2, 'T-transaction'),
+ self.db._get_replica_gen_and_trans_id('other-db'))
+
+ def test_put_updates_transaction_log(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ doc.set_json('{"something": "else"}')
+ self.db.put_doc(doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]),
+ self.db.whats_changed())
+
+ def test_delete_updates_transaction_log(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ db_gen, _, _ = self.db.whats_changed()
+ self.db.delete_doc(doc)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]),
+ self.db.whats_changed(db_gen))
+
+ def test_whats_changed_initial_database(self):
+ self.assertEqual((0, '', []), self.db.whats_changed())
+
+ def test_whats_changed_returns_one_id_for_multiple_changes(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc.set_json('{"new": "contents"}')
+ self.db.put_doc(doc)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]),
+ self.db.whats_changed())
+ self.assertEqual((2, last_trans_id, []), self.db.whats_changed(2))
+
+ def test_whats_changed_returns_last_edits_ascending(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc.set_json('{"new": "contents"}')
+ self.db.delete_doc(doc1)
+ delete_trans_id = self.getLastTransId(self.db)
+ self.db.put_doc(doc)
+ put_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((4, put_trans_id,
+ [(doc1.doc_id, 3, delete_trans_id),
+ (doc.doc_id, 4, put_trans_id)]),
+ self.db.whats_changed())
+
+ def test_whats_changed_doesnt_include_old_gen(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((3, last_trans_id, [(doc2.doc_id, 3, last_trans_id)]),
+ self.db.whats_changed(2))
+
+
+class LocalDatabaseValidateGenNTransIdTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def test_validate_gen_and_trans_id(self):
+ self.db.create_doc_from_json(simple_doc)
+ gen, trans_id = self.db._get_generation_info()
+ self.db.validate_gen_and_trans_id(gen, trans_id)
+
+ def test_validate_gen_and_trans_id_invalid_txid(self):
+ self.db.create_doc_from_json(simple_doc)
+ gen, _ = self.db._get_generation_info()
+ self.assertRaises(
+ errors.InvalidTransactionId,
+ self.db.validate_gen_and_trans_id, gen, 'wrong')
+
+ def test_validate_gen_and_trans_id_invalid_gen(self):
+ self.db.create_doc_from_json(simple_doc)
+ gen, trans_id = self.db._get_generation_info()
+ self.assertRaises(
+ errors.InvalidGeneration,
+ self.db.validate_gen_and_trans_id, gen + 1, trans_id)
+
+
+class LocalDatabaseValidateSourceGenTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def test_validate_source_gen_and_trans_id_same(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ self.db._validate_source('other', 1, 'T-sid')
+
+ def test_validate_source_gen_newer(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ self.db._validate_source('other', 2, 'T-whatevs')
+
+ def test_validate_source_wrong_txid(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ self.assertRaises(
+ errors.InvalidTransactionId,
+ self.db._validate_source, 'other', 1, 'T-sad')
+
+
+class LocalDatabaseWithConflictsTests(tests.DatabaseBaseTests):
+ # test supporting/functionality around storing conflicts
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def test_get_docs_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual([doc2], list(self.db.get_docs([doc1.doc_id])))
+
+ def test_get_docs_conflicts_ignored(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ no_conflict_doc = self.make_document(doc1.doc_id, 'alternate:1',
+ nested_doc)
+ self.assertEqual([no_conflict_doc, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id],
+ check_for_conflicts=False)))
+
+ def test_get_doc_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual([alt_doc, doc],
+ self.db.get_doc_conflicts(doc.doc_id))
+
+ def test_get_all_docs_sees_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ _, docs = self.db.get_all_docs()
+ self.assertTrue(list(docs)[0].has_conflicts)
+
+ def test_get_doc_conflicts_unconflicted(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertEqual([], self.db.get_doc_conflicts(doc.doc_id))
+
+ def test_get_doc_conflicts_no_such_id(self):
+ self.assertEqual([], self.db.get_doc_conflicts('doc-id'))
+
+ def test_resolve_doc(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc.doc_id,
+ [('alternate:1', nested_doc), (doc.rev, simple_doc)])
+ orig_rev = doc.rev
+ self.db.resolve_doc(doc, [alt_doc.rev, doc.rev])
+ self.assertNotEqual(orig_rev, doc.rev)
+ self.assertFalse(doc.has_conflicts)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertGetDocConflicts(self.db, doc.doc_id, [])
+
+ def test_resolve_doc_picks_biggest_vcr(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, simple_doc)])
+ orig_doc1_rev = doc1.rev
+ self.db.resolve_doc(doc1, [doc2.rev, doc1.rev])
+ self.assertFalse(doc1.has_conflicts)
+ self.assertNotEqual(orig_doc1_rev, doc1.rev)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ vcr_1 = vectorclock.VectorClockRev(orig_doc1_rev)
+ vcr_2 = vectorclock.VectorClockRev(doc2.rev)
+ vcr_new = vectorclock.VectorClockRev(doc1.rev)
+ self.assertTrue(vcr_new.is_newer(vcr_1))
+ self.assertTrue(vcr_new.is_newer(vcr_2))
+
+ def test_resolve_doc_partial_not_winning(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, simple_doc)])
+ content3 = '{"key": "valin3"}'
+ doc3 = self.make_document(doc1.doc_id, 'third:1', content3)
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='bar')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc3.rev, content3),
+ (doc1.rev, simple_doc),
+ (doc2.rev, nested_doc)])
+ self.db.resolve_doc(doc1, [doc2.rev, doc1.rev])
+ self.assertTrue(doc1.has_conflicts)
+ self.assertGetDoc(self.db, doc1.doc_id, doc3.rev, content3, True)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc3.rev, content3),
+ (doc1.rev, simple_doc)])
+
+ def test_resolve_doc_partial_winning(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ content3 = '{"key": "valin3"}'
+ doc3 = self.make_document(doc1.doc_id, 'third:1', content3)
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='bar')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc3.rev, content3),
+ (doc1.rev, simple_doc),
+ (doc2.rev, nested_doc)])
+ self.db.resolve_doc(doc1, [doc3.rev, doc1.rev])
+ self.assertTrue(doc1.has_conflicts)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc1.rev, simple_doc),
+ (doc2.rev, nested_doc)])
+
+ def test_resolve_doc_with_delete_conflict(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc1)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, None)])
+ self.db.resolve_doc(doc2, [doc1.rev, doc2.rev])
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, False)
+
+ def test_resolve_doc_with_delete_to_delete(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc1)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, None)])
+ self.db.resolve_doc(doc1, [doc1.rev, doc2.rev])
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ self.assertGetDocIncludeDeleted(
+ self.db, doc1.doc_id, doc1.rev, None, False)
+
+ def test_put_doc_if_newer_save_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ # Document is inserted as a conflict
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('conflicted', state)
+ # The database was updated
+ self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, nested_doc, True)
+
+ def test_force_doc_conflict_supersedes_properly(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', '{"b": 1}')
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ doc3 = self.make_document(doc1.doc_id, 'altalt:1', '{"c": 1}')
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='bar')
+ doc22 = self.make_document(doc1.doc_id, 'alternate:2', '{"b": 2}')
+ self.db._put_doc_if_newer(
+ doc22, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='zed')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:2', doc22.get_json()),
+ ('altalt:1', doc3.get_json()),
+ (doc1.rev, simple_doc)])
+
+ def test_put_doc_if_newer_save_conflict_was_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc1)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertTrue(doc2.has_conflicts)
+ self.assertGetDoc(
+ self.db, doc1.doc_id, 'alternate:1', nested_doc, True)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:1', nested_doc), (doc1.rev, None)])
+
+ def test_put_doc_if_newer_propagates_full_resolution(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ resolved_vcr = vectorclock.VectorClockRev(doc1.rev)
+ vcr_2 = vectorclock.VectorClockRev(doc2.rev)
+ resolved_vcr.maximize(vcr_2)
+ resolved_vcr.increment('alternate')
+ doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(),
+ '{"good": 1}')
+ state, _ = self.db._put_doc_if_newer(
+ doc_resolved, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual('inserted', state)
+ self.assertFalse(doc_resolved.has_conflicts)
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ doc3 = self.db.get_doc(doc1.doc_id)
+ self.assertFalse(doc3.has_conflicts)
+
+ def test_put_doc_if_newer_propagates_partial_resolution(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'altalt:1', '{}')
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ doc3 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:1', nested_doc), ('test:1', simple_doc),
+ ('altalt:1', '{}')])
+ resolved_vcr = vectorclock.VectorClockRev(doc1.rev)
+ vcr_3 = vectorclock.VectorClockRev(doc3.rev)
+ resolved_vcr.maximize(vcr_3)
+ resolved_vcr.increment('alternate')
+ doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(),
+ '{"good": 1}')
+ state, _ = self.db._put_doc_if_newer(
+ doc_resolved, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual('inserted', state)
+ self.assertTrue(doc_resolved.has_conflicts)
+ doc4 = self.db.get_doc(doc1.doc_id)
+ self.assertTrue(doc4.has_conflicts)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:2|test:1', '{"good": 1}'), ('altalt:1', '{}')])
+
+ def test_put_doc_if_newer_replica_uid(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-id')
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1',
+ nested_doc)
+ self.db._put_doc_if_newer(doc2, save_conflict=True,
+ replica_uid='other', replica_gen=2,
+ replica_trans_id='T-id2')
+ # Conflict vs the current update
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|third:3',
+ '{}')
+ self.assertEqual('conflicted',
+ self.db._put_doc_if_newer(doc2, save_conflict=True,
+ replica_uid='other', replica_gen=3,
+ replica_trans_id='T-id3')[0])
+ self.assertEqual(
+ (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other'))
+
+ def test_put_doc_if_newer_autoresolve_2(self):
+ # this is an ordering variant of _3, but that already works
+ # adding the test explicitly to catch the regression easily
+ doc_a1 = self.db.create_doc_from_json(simple_doc)
+ doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', "{}")
+ doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1',
+ '{"a":"42"}')
+ doc_a3 = self.make_document(doc_a1.doc_id, 'test:2|other:1', "{}")
+ state, _ = self.db._put_doc_if_newer(
+ doc_a2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(state, 'inserted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual(state, 'conflicted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a3, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual(state, 'inserted')
+ self.assertFalse(self.db.get_doc(doc_a1.doc_id).has_conflicts)
+
+ def test_put_doc_if_newer_autoresolve_3(self):
+ doc_a1 = self.db.create_doc_from_json(simple_doc)
+ doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', "{}")
+ doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}')
+ doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', "{}")
+ state, _ = self.db._put_doc_if_newer(
+ doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(state, 'inserted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a2, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual(state, 'conflicted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a3, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual(state, 'superseded')
+ doc = self.db.get_doc(doc_a1.doc_id, True)
+ self.assertFalse(doc.has_conflicts)
+ rev = vectorclock.VectorClockRev(doc.rev)
+ rev_a3 = vectorclock.VectorClockRev('test:3')
+ rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1')
+ self.assertTrue(rev.is_newer(rev_a3))
+ self.assertTrue('test:4' in doc.rev) # locally increased
+ self.assertTrue(rev.is_newer(rev_a1b1))
+
+ def test_put_doc_if_newer_autoresolve_4(self):
+ doc_a1 = self.db.create_doc_from_json(simple_doc)
+ doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', None)
+ doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}')
+ doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', None)
+ state, _ = self.db._put_doc_if_newer(
+ doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(state, 'inserted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a2, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual(state, 'conflicted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a3, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual(state, 'superseded')
+ doc = self.db.get_doc(doc_a1.doc_id, True)
+ self.assertFalse(doc.has_conflicts)
+ rev = vectorclock.VectorClockRev(doc.rev)
+ rev_a3 = vectorclock.VectorClockRev('test:3')
+ rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1')
+ self.assertTrue(rev.is_newer(rev_a3))
+ self.assertTrue('test:4' in doc.rev) # locally increased
+ self.assertTrue(rev.is_newer(rev_a1b1))
+
+ def test_put_refuses_to_update_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ content2 = '{"key": "altval"}'
+ doc2 = self.make_document(doc1.doc_id, 'altrev:1', content2)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, content2, True)
+ content3 = '{"key": "local"}'
+ doc2.set_json(content3)
+ self.assertRaises(errors.ConflictedDoc, self.db.put_doc, doc2)
+
+ def test_delete_refuses_for_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'altrev:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, True)
+ self.assertRaises(errors.ConflictedDoc, self.db.delete_doc, doc2)
+
+
+class DatabaseIndexTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + tests.C_DATABASE_SCENARIOS
+
+ def assertParseError(self, definition):
+ self.db.create_doc_from_json(nested_doc)
+ self.assertRaises(
+ errors.IndexDefinitionParseError, self.db.create_index, 'idx',
+ definition)
+
+ def assertIndexCreatable(self, definition):
+ name = "idx"
+ self.db.create_doc_from_json(nested_doc)
+ self.db.create_index(name, definition)
+ self.assertEqual(
+ [(name, [definition])], self.db.list_indexes())
+
+ def test_create_index(self):
+ self.db.create_index('test-idx', 'name')
+ self.assertEqual([('test-idx', ['name'])],
+ self.db.list_indexes())
+
+ def test_create_index_on_non_ascii_field_name(self):
+ doc = self.db.create_doc_from_json(json.dumps({u'\xe5': 'value'}))
+ self.db.create_index('test-idx', u'\xe5')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_list_indexes_with_non_ascii_field_names(self):
+ self.db.create_index('test-idx', u'\xe5')
+ self.assertEqual(
+ [('test-idx', [u'\xe5'])], self.db.list_indexes())
+
+ def test_create_index_evaluates_it(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_wildcard_matches_unicode_value(self):
+ doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"}))
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', '*'))
+
+ def test_retrieve_unicode_value_from_index(self):
+ doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"}))
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', u"valu\xe5"))
+
+ def test_create_index_fails_if_name_taken(self):
+ self.db.create_index('test-idx', 'key')
+ self.assertRaises(errors.IndexNameTakenError,
+ self.db.create_index,
+ 'test-idx', 'stuff')
+
+ def test_create_index_does_not_fail_if_name_taken_with_same_index(self):
+ self.db.create_index('test-idx', 'key')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([('test-idx', ['key'])], self.db.list_indexes())
+
+ def test_create_index_does_not_duplicate_indexed_fields(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.db.delete_index('test-idx')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(1, len(self.db.get_from_index('test-idx', 'value')))
+
+ def test_delete_index_does_not_remove_fields_from_other_indexes(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.db.create_index('test-idx2', 'key')
+ self.db.delete_index('test-idx')
+ self.assertEqual(1, len(self.db.get_from_index('test-idx2', 'value')))
+
+ def test_create_index_after_deleting_document(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc2)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_delete_index(self):
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([('test-idx', ['key'])], self.db.list_indexes())
+ self.db.delete_index('test-idx')
+ self.assertEqual([], self.db.list_indexes())
+
+ def test_create_adds_to_index(self):
+ self.db.create_index('test-idx', 'key')
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_get_from_index_unmatched(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([], self.db.get_from_index('test-idx', 'novalue'))
+
+ def test_create_index_multiple_exact_matches(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ sorted([doc, doc2]),
+ sorted(self.db.get_from_index('test-idx', 'value')))
+
+ def test_get_from_index(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_get_from_index_multi(self):
+ content = '{"key": "value", "key2": "value2"}'
+ doc = self.db.create_doc_from_json(content)
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2'))
+
+ def test_get_from_index_multi_list(self):
+ doc = self.db.create_doc_from_json(
+ '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2-1'))
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2-2'))
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2-3'))
+ self.assertEqual(
+ [('value', 'value2-1'), ('value', 'value2-2'),
+ ('value', 'value2-3')],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_get_from_index_sees_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key', 'key2')
+ alt_doc = self.make_document(
+ doc.doc_id, 'alternate:1',
+ '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}')
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ docs = self.db.get_from_index('test-idx', 'value', 'value2-1')
+ self.assertTrue(docs[0].has_conflicts)
+
+ def test_get_index_keys_multi_list_list(self):
+ self.db.create_doc_from_json(
+ '{"key": "value1-1 value1-2 value1-3", '
+ '"key2": ["value2-1", "value2-2", "value2-3"]}')
+ self.db.create_index('test-idx', 'split_words(key)', 'key2')
+ self.assertEqual(
+ [(u'value1-1', u'value2-1'), (u'value1-1', u'value2-2'),
+ (u'value1-1', u'value2-3'), (u'value1-2', u'value2-1'),
+ (u'value1-2', u'value2-2'), (u'value1-2', u'value2-3'),
+ (u'value1-3', u'value2-1'), (u'value1-3', u'value2-2'),
+ (u'value1-3', u'value2-3')],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_get_from_index_multi_ordered(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2, doc1],
+ self.db.get_from_index('test-idx', 'v*', '*'))
+
+ def test_get_range_from_index_start_end(self):
+ doc1 = self.db.create_doc_from_json('{"key": "value3"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value2"}')
+ self.db.create_doc_from_json('{"key": "value4"}')
+ self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc2, doc1],
+ self.db.get_range_from_index('test-idx', 'value2', 'value3'))
+
+ def test_get_range_from_index_start(self):
+ doc1 = self.db.create_doc_from_json('{"key": "value3"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value2"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value4"}')
+ self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc2, doc1, doc3],
+ self.db.get_range_from_index('test-idx', 'value2'))
+
+ def test_get_range_from_index_sees_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ alt_doc = self.make_document(
+ doc.doc_id, 'alternate:1', '{"key": "valuedepalue"}')
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ docs = self.db.get_range_from_index('test-idx', 'a')
+ self.assertTrue(docs[0].has_conflicts)
+
+ def test_get_range_from_index_end(self):
+ self.db.create_doc_from_json('{"key": "value3"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value2"}')
+ self.db.create_doc_from_json('{"key": "value4"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc4, doc2],
+ self.db.get_range_from_index('test-idx', None, 'value2'))
+
+ def test_get_wildcard_range_from_index_start(self):
+ doc1 = self.db.create_doc_from_json('{"key": "value4"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value23"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value2"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value22"}')
+ self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc3, doc4, doc2, doc1],
+ self.db.get_range_from_index('test-idx', 'value2*'))
+
+ def test_get_wildcard_range_from_index_end(self):
+ self.db.create_doc_from_json('{"key": "value4"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value23"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value2"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value22"}')
+ doc5 = self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc5, doc3, doc4, doc2],
+ self.db.get_range_from_index('test-idx', None, 'value2*'))
+
+ def test_get_wildcard_range_from_index_start_end(self):
+ self.db.create_doc_from_json('{"key": "a"}')
+ self.db.create_doc_from_json('{"key": "boo3"}')
+ doc3 = self.db.create_doc_from_json('{"key": "catalyst"}')
+ doc4 = self.db.create_doc_from_json('{"key": "whaever"}')
+ self.db.create_doc_from_json('{"key": "zerg"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc3, doc4],
+ self.db.get_range_from_index('test-idx', 'cat*', 'zap*'))
+
+ def test_get_range_from_index_multi_column_start_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc3, doc2],
+ self.db.get_range_from_index(
+ 'test-idx', ('value2', 'value2'), ('value2', 'value3')))
+
+ def test_get_range_from_index_multi_column_start(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ self.db.create_doc_from_json('{"key": "value2", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc2, doc1],
+ self.db.get_range_from_index('test-idx', ('value2', 'value3')))
+
+ def test_get_range_from_index_multi_column_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2],
+ self.db.get_range_from_index(
+ 'test-idx', None, ('value2', 'value3')))
+
+ def test_get_wildcard_range_from_index_multi_column_start(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc3, doc2, doc1],
+ self.db.get_range_from_index('test-idx', ('value2', 'value2*')))
+
+ def test_get_wildcard_range_from_index_multi_column_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2],
+ self.db.get_range_from_index(
+ 'test-idx', None, ('value2', 'value2*')))
+
+ def test_get_glob_range_from_index_multi_column_start(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc2, doc1],
+ self.db.get_range_from_index('test-idx', ('value2', '*')))
+
+ def test_get_glob_range_from_index_multi_column_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2],
+ self.db.get_range_from_index('test-idx', None, ('value2', '*')))
+
+ def test_get_range_from_index_illegal_wildcard_order(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', ('*', 'v2'))
+
+ def test_get_range_from_index_illegal_glob_after_wildcard(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', ('*', 'v*'))
+
+ def test_get_range_from_index_illegal_wildcard_order_end(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', None, ('*', 'v2'))
+
+ def test_get_range_from_index_illegal_glob_after_wildcard_end(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', None, ('*', 'v*'))
+
+ def test_get_from_index_fails_if_no_index(self):
+ self.assertRaises(
+ errors.IndexDoesNotExist, self.db.get_from_index, 'foo')
+
+ def test_get_index_keys_fails_if_no_index(self):
+ self.assertRaises(errors.IndexDoesNotExist,
+ self.db.get_index_keys,
+ 'foo')
+
+ def test_get_index_keys_works_if_no_docs(self):
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([], self.db.get_index_keys('test-idx'))
+
+ def test_put_updates_index(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ new_content = '{"key": "altval"}'
+ doc.set_json(new_content)
+ self.db.put_doc(doc)
+ self.assertEqual([], self.db.get_from_index('test-idx', 'value'))
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'altval'))
+
+ def test_delete_updates_index(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ sorted([doc, doc2]),
+ sorted(self.db.get_from_index('test-idx', 'value')))
+ self.db.delete_doc(doc)
+ self.assertEqual([doc2], self.db.get_from_index('test-idx', 'value'))
+
+ def test_get_from_index_illegal_number_of_entries(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidValueForIndex, self.db.get_from_index, 'test-idx')
+ self.assertRaises(
+ errors.InvalidValueForIndex,
+ self.db.get_from_index, 'test-idx', 'v1')
+ self.assertRaises(
+ errors.InvalidValueForIndex,
+ self.db.get_from_index, 'test-idx', 'v1', 'v2', 'v3')
+
+ def test_get_from_index_illegal_wildcard_order(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', '*', 'v2')
+
+ def test_get_from_index_illegal_glob_after_wildcard(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', '*', 'v*')
+
+ def test_get_all_from_index(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ # This one should not be in the index
+ self.db.create_doc_from_json('{"no": "key"}')
+ diff_value_doc = '{"key": "diff value"}'
+ doc4 = self.db.create_doc_from_json(diff_value_doc)
+ # This is essentially a 'prefix' match, but we match every entry.
+ self.assertEqual(
+ sorted([doc1, doc2, doc4]),
+ sorted(self.db.get_from_index('test-idx', '*')))
+
+ def test_get_all_from_index_ordered(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json('{"key": "value x"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value b"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value a"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value m"}')
+ # This is essentially a 'prefix' match, but we match every entry.
+ self.assertEqual(
+ [doc3, doc2, doc4, doc1], self.db.get_from_index('test-idx', '*'))
+
+ def test_put_updates_when_adding_key(self):
+ doc = self.db.create_doc_from_json("{}")
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([], self.db.get_from_index('test-idx', '*'))
+ doc.set_json(simple_doc)
+ self.db.put_doc(doc)
+ self.assertEqual([doc], self.db.get_from_index('test-idx', '*'))
+
+ def test_get_from_index_empty_string(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ content2 = '{"key": ""}'
+ doc2 = self.db.create_doc_from_json(content2)
+ self.assertEqual([doc2], self.db.get_from_index('test-idx', ''))
+ # Empty string matches the wildcard.
+ self.assertEqual(
+ sorted([doc1, doc2]),
+ sorted(self.db.get_from_index('test-idx', '*')))
+
+ def test_get_from_index_not_null(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.create_doc_from_json('{"key": null}')
+ self.assertEqual([doc1], self.db.get_from_index('test-idx', '*'))
+
+ def test_get_partial_from_index(self):
+ content1 = '{"k1": "v1", "k2": "v2"}'
+ content2 = '{"k1": "v1", "k2": "x2"}'
+ content3 = '{"k1": "v1", "k2": "y2"}'
+ # doc4 has a different k1 value, so it doesn't match the prefix.
+ content4 = '{"k1": "NN", "k2": "v2"}'
+ doc1 = self.db.create_doc_from_json(content1)
+ doc2 = self.db.create_doc_from_json(content2)
+ doc3 = self.db.create_doc_from_json(content3)
+ self.db.create_doc_from_json(content4)
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertEqual(
+ sorted([doc1, doc2, doc3]),
+ sorted(self.db.get_from_index('test-idx', "v1", "*")))
+
+ def test_get_glob_match(self):
+ # Note: the exact glob syntax is probably subject to change
+ content1 = '{"k1": "v1", "k2": "v1"}'
+ content2 = '{"k1": "v1", "k2": "v2"}'
+ content3 = '{"k1": "v1", "k2": "v3"}'
+ # doc4 has a different k2 prefix value, so it doesn't match
+ content4 = '{"k1": "v1", "k2": "ZZ"}'
+ self.db.create_index('test-idx', 'k1', 'k2')
+ doc1 = self.db.create_doc_from_json(content1)
+ doc2 = self.db.create_doc_from_json(content2)
+ doc3 = self.db.create_doc_from_json(content3)
+ self.db.create_doc_from_json(content4)
+ self.assertEqual(
+ sorted([doc1, doc2, doc3]),
+ sorted(self.db.get_from_index('test-idx', "v1", "v*")))
+
+ def test_nested_index(self):
+ doc = self.db.create_doc_from_json(nested_doc)
+ self.db.create_index('test-idx', 'sub.doc')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'underneath'))
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual(
+ sorted([doc, doc2]),
+ sorted(self.db.get_from_index('test-idx', 'underneath')))
+
+ def test_nested_nonexistent(self):
+ self.db.create_doc_from_json(nested_doc)
+ # sub exists, but sub.foo does not:
+ self.db.create_index('test-idx', 'sub.foo')
+ self.assertEqual([], self.db.get_from_index('test-idx', '*'))
+
+ def test_nested_nonexistent2(self):
+ self.db.create_doc_from_json(nested_doc)
+ self.db.create_index('test-idx', 'sub.foo.bar.baz.qux.fnord')
+ self.assertEqual([], self.db.get_from_index('test-idx', '*'))
+
+ def test_nested_traverses_lists(self):
+ # subpath finds dicts in list
+ doc = self.db.create_doc_from_json(
+ '{"foo": [{"zap": "bar"}, {"zap": "baz"}]}')
+ # subpath only finds dicts in list
+ self.db.create_doc_from_json('{"foo": ["zap", "baz"]}')
+ self.db.create_index('test-idx', 'foo.zap')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'bar'))
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'baz'))
+
+ def test_nested_list_traversal(self):
+ # subpath finds dicts in list
+ doc = self.db.create_doc_from_json(
+ '{"foo": [{"zap": [{"qux": "fnord"}, {"qux": "zombo"}]},'
+ '{"zap": "baz"}]}')
+ # subpath only finds dicts in list
+ self.db.create_index('test-idx', 'foo.zap.qux')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'fnord'))
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'zombo'))
+
+ def test_index_list1(self):
+ self.db.create_index("index", "name")
+ content = '{"name": ["foo", "bar"]}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_index_list2(self):
+ self.db.create_index("index", "name")
+ content = '{"name": ["foo", "bar"]}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_case_sensitive(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.assertEqual([], self.db.get_from_index('test-idx', 'V*'))
+ self.assertEqual([doc1], self.db.get_from_index('test-idx', 'v*'))
+
+ def test_get_from_index_illegal_glob_before_value(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', 'v*', 'v2')
+
+ def test_get_from_index_illegal_glob_after_glob(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', 'v*', 'v*')
+
+ def test_get_from_index_with_sql_wildcards(self):
+ self.db.create_index('test-idx', 'key')
+ content1 = '{"key": "va%lue"}'
+ content2 = '{"key": "value"}'
+ content3 = '{"key": "va_lue"}'
+ doc1 = self.db.create_doc_from_json(content1)
+ self.db.create_doc_from_json(content2)
+ doc3 = self.db.create_doc_from_json(content3)
+ # The '%' in the search should be treated literally, not as a sql
+ # globbing character.
+ self.assertEqual([doc1], self.db.get_from_index('test-idx', 'va%*'))
+ # Same for '_'
+ self.assertEqual([doc3], self.db.get_from_index('test-idx', 'va_*'))
+
+ def test_get_from_index_with_lower(self):
+ self.db.create_index("index", "lower(name)")
+ content = '{"name": "Foo"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_lower_matches_same_case(self):
+ self.db.create_index("index", "lower(name)")
+ content = '{"name": "foo"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_lower_doesnt_match_different_case(self):
+ self.db.create_index("index", "lower(name)")
+ content = '{"name": "Foo"}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "Foo")
+ self.assertEqual([], rows)
+
+ def test_index_lower_doesnt_match_other_index(self):
+ self.db.create_index("index", "lower(name)")
+ self.db.create_index("other_index", "name")
+ content = '{"name": "Foo"}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "Foo")
+ self.assertEqual(0, len(rows))
+
+ def test_index_split_words_match_first(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_match_second(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_match_both(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo foo"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_double_space(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_leading_space(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": " foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_trailing_space(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar "}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_number(self):
+ self.db.create_index("index", "number(foo, 5)")
+ content = '{"foo": 12}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "00012")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_number_bigger_than_padding(self):
+ self.db.create_index("index", "number(foo, 5)")
+ content = '{"foo": 123456}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "123456")
+ self.assertEqual([doc], rows)
+
+ def test_number_mapping_ignores_non_numbers(self):
+ self.db.create_index("index", "number(foo, 5)")
+ content = '{"foo": 56}'
+ doc1 = self.db.create_doc_from_json(content)
+ content = '{"foo": "this is not a maigret painting"}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "*")
+ self.assertEqual([doc1], rows)
+
+ def test_get_from_index_with_bool(self):
+ self.db.create_index("index", "bool(foo)")
+ content = '{"foo": true}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "1")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_bool_false(self):
+ self.db.create_index("index", "bool(foo)")
+ content = '{"foo": false}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "0")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_non_bool(self):
+ self.db.create_index("index", "bool(foo)")
+ content = '{"foo": 42}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "*")
+ self.assertEqual([], rows)
+
+ def test_get_from_index_with_combine(self):
+ self.db.create_index("index", "combine(foo, bar)")
+ content = '{"foo": "value1", "bar": "value2"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "value1")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "value2")
+ self.assertEqual([doc], rows)
+
+ def test_get_complex_combine(self):
+ self.db.create_index(
+ "index", "combine(number(foo, 5), lower(bar), split_words(baz))")
+ content = '{"foo": 12, "bar": "ALLCAPS", "baz": "qux nox"}'
+ doc = self.db.create_doc_from_json(content)
+ content = '{"foo": "not a number", "bar": "something"}'
+ doc2 = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "00012")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "allcaps")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "nox")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "something")
+ self.assertEqual([doc2], rows)
+
+ def test_get_index_keys_from_index(self):
+ self.db.create_index('test-idx', 'key')
+ content1 = '{"key": "value1"}'
+ content2 = '{"key": "value2"}'
+ content3 = '{"key": "value2"}'
+ self.db.create_doc_from_json(content1)
+ self.db.create_doc_from_json(content2)
+ self.db.create_doc_from_json(content3)
+ self.assertEqual(
+ [('value1',), ('value2',)],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_get_index_keys_from_multicolumn_index(self):
+ self.db.create_index('test-idx', 'key1', 'key2')
+ content1 = '{"key1": "value1", "key2": "val2-1"}'
+ content2 = '{"key1": "value2", "key2": "val2-2"}'
+ content3 = '{"key1": "value2", "key2": "val2-2"}'
+ content4 = '{"key1": "value2", "key2": "val3"}'
+ self.db.create_doc_from_json(content1)
+ self.db.create_doc_from_json(content2)
+ self.db.create_doc_from_json(content3)
+ self.db.create_doc_from_json(content4)
+ self.assertEqual([
+ ('value1', 'val2-1'),
+ ('value2', 'val2-2'),
+ ('value2', 'val3')],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_empty_expr(self):
+ self.assertParseError('')
+
+ def test_nested_unknown_operation(self):
+ self.assertParseError('unknown_operation(field1)')
+
+ def test_parse_missing_close_paren(self):
+ self.assertParseError("lower(a")
+
+ def test_parse_trailing_close_paren(self):
+ self.assertParseError("lower(ab))")
+
+ def test_parse_trailing_chars(self):
+ self.assertParseError("lower(ab)adsf")
+
+ def test_parse_empty_op(self):
+ self.assertParseError("(ab)")
+
+ def test_parse_top_level_commas(self):
+ self.assertParseError("a, b")
+
+ def test_invalid_field_name(self):
+ self.assertParseError("a.")
+
+ def test_invalid_inner_field_name(self):
+ self.assertParseError("lower(a.)")
+
+ def test_gobbledigook(self):
+ self.assertParseError("(@#@cc @#!*DFJSXV(()jccd")
+
+ def test_leading_space(self):
+ self.assertIndexCreatable(" lower(a)")
+
+ def test_trailing_space(self):
+ self.assertIndexCreatable("lower(a) ")
+
+ def test_spaces_before_open_paren(self):
+ self.assertIndexCreatable("lower (a)")
+
+ def test_spaces_after_open_paren(self):
+ self.assertIndexCreatable("lower( a)")
+
+ def test_spaces_before_close_paren(self):
+ self.assertIndexCreatable("lower(a )")
+
+ def test_spaces_before_comma(self):
+ self.assertIndexCreatable("combine(a , b , c)")
+
+ def test_spaces_after_comma(self):
+ self.assertIndexCreatable("combine(a, b, c)")
+
+ def test_all_together_now(self):
+ self.assertParseError(' (a) ')
+
+ def test_all_together_now2(self):
+ self.assertParseError('combine(lower(x)x,foo)')
+
+
+class PythonBackendTests(tests.DatabaseBaseTests):
+
+ def setUp(self):
+ super(PythonBackendTests, self).setUp()
+ self.simple_doc = json.loads(simple_doc)
+
+ def test_create_doc_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id')
+ self.assertTrue(isinstance(doc, TestAlternativeDocument))
+
+ def test_get_doc_after_put_with_factory(self):
+ doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id')
+ self.db.set_document_factory(TestAlternativeDocument)
+ result = self.db.get_doc('my_doc_id')
+ self.assertTrue(isinstance(result, TestAlternativeDocument))
+ self.assertEqual(doc.doc_id, result.doc_id)
+ self.assertEqual(doc.rev, result.rev)
+ self.assertEqual(doc.get_json(), result.get_json())
+ self.assertEqual(False, result.has_conflicts)
+
+ def test_get_doc_nonexisting_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ self.assertIs(None, self.db.get_doc('non-existing'))
+
+ def test_get_all_docs_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ self.db.create_doc(self.simple_doc)
+ self.assertTrue(isinstance(
+ list(self.db.get_all_docs()[1])[0], TestAlternativeDocument))
+
+ def test_get_docs_conflicted_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ doc1 = self.db.create_doc(self.simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertTrue(
+ isinstance(
+ list(self.db.get_docs([doc1.doc_id]))[0],
+ TestAlternativeDocument))
+
+ def test_get_from_index_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ self.db.create_doc(self.simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertTrue(
+ isinstance(
+ self.db.get_from_index('test-idx', 'value')[0],
+ TestAlternativeDocument))
+
+ def test_sync_exchange_updates_indexes(self):
+ doc = self.db.create_doc(self.simple_doc)
+ self.db.create_index('test-idx', 'key')
+ new_content = '{"key": "altval"}'
+ other_rev = 'test:1|z:2'
+ st = self.db.get_sync_target()
+
+ def ignore(doc_id, doc_rev, doc):
+ pass
+
+ doc_other = self.make_document(doc.doc_id, other_rev, new_content)
+ docs_by_gen = [(doc_other, 10, 'T-sid')]
+ st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=ignore)
+ self.assertGetDoc(self.db, doc.doc_id, other_rev, new_content, False)
+ self.assertEqual(
+ [doc_other], self.db.get_from_index('test-idx', 'altval'))
+ self.assertEqual([], self.db.get_from_index('test-idx', 'value'))
+
+
+# Use a custom loader to apply the scenarios at load time.
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/u1db/tests/test_c_backend.py b/src/leap/soledad/u1db/tests/test_c_backend.py
new file mode 100644
index 00000000..bdd2aec7
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_c_backend.py
@@ -0,0 +1,634 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from u1db import (
+ Document,
+ errors,
+ tests,
+ )
+from u1db.tests import c_backend_wrapper, c_backend_error
+from u1db.tests.test_remote_sync_target import (
+ make_http_app,
+ make_oauth_http_app
+ )
+
+
+class TestCDatabaseExists(tests.TestCase):
+
+ def test_c_backend_compiled(self):
+ if c_backend_wrapper is None:
+ self.fail("Could not import the c_backend_wrapper module."
+ " Was it compiled properly?\n%s" % (c_backend_error,))
+
+
+# Rather than lots of failing tests, we have the above check to test that the
+# module exists, and all these tests just get skipped
+class BackendTests(tests.TestCase):
+
+ def setUp(self):
+ super(BackendTests, self).setUp()
+ if c_backend_wrapper is None:
+ self.skipTest("The c_backend_wrapper could not be imported")
+
+
+class TestCDatabase(BackendTests):
+
+ def test_exists(self):
+ if c_backend_wrapper is None:
+ self.fail("Could not import the c_backend_wrapper module."
+ " Was it compiled properly?")
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertEqual(':memory:', db._filename)
+
+ def test__is_closed(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertTrue(db._sql_is_open())
+ db.close()
+ self.assertFalse(db._sql_is_open())
+
+ def test__run_sql(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertTrue(db._sql_is_open())
+ self.assertEqual([], db._run_sql('CREATE TABLE test (id INTEGER)'))
+ self.assertEqual([], db._run_sql('INSERT INTO test VALUES (1)'))
+ self.assertEqual([('1',)], db._run_sql('SELECT * FROM test'))
+
+ def test__get_generation(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertEqual(0, db._get_generation())
+ db.create_doc_from_json(tests.simple_doc)
+ self.assertEqual(1, db._get_generation())
+
+ def test__get_generation_info(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertEqual((0, ''), db._get_generation_info())
+ db.create_doc_from_json(tests.simple_doc)
+ info = db._get_generation_info()
+ self.assertEqual(1, info[0])
+ self.assertTrue(info[1].startswith('T-'))
+
+ def test__set_replica_uid(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertIsNot(None, db._replica_uid)
+ db._set_replica_uid('foo')
+ self.assertEqual([('foo',)], db._run_sql(
+ "SELECT value FROM u1db_config WHERE name='replica_uid'"))
+
+ def test_default_replica_uid(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.assertIsNot(None, self.db._replica_uid)
+ self.assertEqual(32, len(self.db._replica_uid))
+ # casting to an int from the uid *is* the check for correct behavior.
+ int(self.db._replica_uid, 16)
+
+ def test_get_conflicts_with_borked_data(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ # We add an entry to conflicts, but not to documents, which is an
+ # invalid situation
+ self.db._run_sql("INSERT INTO conflicts"
+ " VALUES ('doc-id', 'doc-rev', '{}')")
+ self.assertRaises(Exception, self.db.get_doc_conflicts, 'doc-id')
+
+ def test_create_index_list(self):
+ # We manually poke data into the DB, so that we test just the "get_doc"
+ # code, rather than also testing the index management code.
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ self.db.create_index_list("key-idx", ["key"])
+ docs = self.db.get_from_index('key-idx', 'value')
+ self.assertEqual([doc], docs)
+
+ def test_create_index_list_on_non_ascii_field_name(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(json.dumps({u'\xe5': 'value'}))
+ self.db.create_index_list('test-idx', [u'\xe5'])
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_list_indexes_with_non_ascii_field_names(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_index_list('test-idx', [u'\xe5'])
+ self.assertEqual(
+ [('test-idx', [u'\xe5'])], self.db.list_indexes())
+
+ def test_create_index_evaluates_it(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ self.db.create_index_list('test-idx', ['key'])
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_wildcard_matches_unicode_value(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"}))
+ self.db.create_index_list('test-idx', ['key'])
+ self.assertEqual([doc], self.db.get_from_index('test-idx', '*'))
+
+ def test_create_index_fails_if_name_taken(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_index_list('test-idx', ['key'])
+ self.assertRaises(errors.IndexNameTakenError,
+ self.db.create_index_list,
+ 'test-idx', ['stuff'])
+
+ def test_create_index_does_not_fail_if_name_taken_with_same_index(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_index_list('test-idx', ['key'])
+ self.db.create_index_list('test-idx', ['key'])
+ self.assertEqual([('test-idx', ['key'])], self.db.list_indexes())
+
+ def test_create_index_after_deleting_document(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ doc2 = self.db.create_doc_from_json(tests.simple_doc)
+ self.db.delete_doc(doc2)
+ self.db.create_index_list('test-idx', ['key'])
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_get_from_index(self):
+ # We manually poke data into the DB, so that we test just the "get_doc"
+ # code, rather than also testing the index management code.
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ self.db.create_index("key-idx", "key")
+ docs = self.db.get_from_index('key-idx', 'value')
+ self.assertEqual([doc], docs)
+
+ def test_get_from_index_list(self):
+ # We manually poke data into the DB, so that we test just the "get_doc"
+ # code, rather than also testing the index management code.
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ self.db.create_index("key-idx", "key")
+ docs = self.db.get_from_index_list('key-idx', ['value'])
+ self.assertEqual([doc], docs)
+
+ def test_get_from_index_list_multi(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ content = '{"key": "value", "key2": "value2"}'
+ doc = self.db.create_doc_from_json(content)
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc],
+ self.db.get_from_index_list('test-idx', ['value', 'value2']))
+
+ def test_get_from_index_list_multi_ordered(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2, doc1],
+ self.db.get_from_index_list('test-idx', ['v*', '*']))
+
+ def test_get_from_index_2(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ doc = self.db.create_doc_from_json(tests.nested_doc)
+ self.db.create_index("multi-idx", "key", "sub.doc")
+ docs = self.db.get_from_index('multi-idx', 'value', 'underneath')
+ self.assertEqual([doc], docs)
+
+ def test_get_index_keys(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_doc_from_json(tests.simple_doc)
+ self.db.create_index("key-idx", "key")
+ keys = self.db.get_index_keys('key-idx')
+ self.assertEqual([("value",)], keys)
+
+ def test__query_init_one_field(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_index("key-idx", "key")
+ query = self.db._query_init("key-idx")
+ self.assertEqual("key-idx", query.index_name)
+ self.assertEqual(1, query.num_fields)
+ self.assertEqual(["key"], query.fields)
+
+ def test__query_init_two_fields(self):
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.db.create_index("two-idx", "key", "key2")
+ query = self.db._query_init("two-idx")
+ self.assertEqual("two-idx", query.index_name)
+ self.assertEqual(2, query.num_fields)
+ self.assertEqual(["key", "key2"], query.fields)
+
+ def assertFormatQueryEquals(self, expected, wildcards, fields):
+ val, w = c_backend_wrapper._format_query(fields)
+ self.assertEqual(expected, val)
+ self.assertEqual(wildcards, w)
+
+ def test__format_query(self):
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id FROM document_fields d0"
+ " WHERE d0.field_name = ? AND d0.value = ? ORDER BY d0.value",
+ [0], ["1"])
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id"
+ " FROM document_fields d0, document_fields d1"
+ " WHERE d0.field_name = ? AND d0.value = ?"
+ " AND d0.doc_id = d1.doc_id"
+ " AND d1.field_name = ? AND d1.value = ?"
+ " ORDER BY d0.value, d1.value",
+ [0, 0], ["1", "2"])
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id"
+ " FROM document_fields d0, document_fields d1, document_fields d2"
+ " WHERE d0.field_name = ? AND d0.value = ?"
+ " AND d0.doc_id = d1.doc_id"
+ " AND d1.field_name = ? AND d1.value = ?"
+ " AND d0.doc_id = d2.doc_id"
+ " AND d2.field_name = ? AND d2.value = ?"
+ " ORDER BY d0.value, d1.value, d2.value",
+ [0, 0, 0], ["1", "2", "3"])
+
+ def test__format_query_wildcard(self):
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id FROM document_fields d0"
+ " WHERE d0.field_name = ? AND d0.value NOT NULL ORDER BY d0.value",
+ [1], ["*"])
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id"
+ " FROM document_fields d0, document_fields d1"
+ " WHERE d0.field_name = ? AND d0.value = ?"
+ " AND d0.doc_id = d1.doc_id"
+ " AND d1.field_name = ? AND d1.value NOT NULL"
+ " ORDER BY d0.value, d1.value",
+ [0, 1], ["1", "*"])
+
+ def test__format_query_glob(self):
+ self.assertFormatQueryEquals(
+ "SELECT d0.doc_id FROM document_fields d0"
+ " WHERE d0.field_name = ? AND d0.value GLOB ? ORDER BY d0.value",
+ [2], ["1*"])
+
+
+class TestCSyncTarget(BackendTests):
+
+ def setUp(self):
+ super(TestCSyncTarget, self).setUp()
+ self.db = c_backend_wrapper.CDatabase(':memory:')
+ self.st = self.db.get_sync_target()
+
+ def test_attached_to_db(self):
+ self.assertEqual(
+ self.db._replica_uid, self.st.get_sync_info("misc")[0])
+
+ def test_get_sync_exchange(self):
+ exc = self.st._get_sync_exchange("source-uid", 10)
+ self.assertIsNot(None, exc)
+
+ def test_sync_exchange_insert_doc_from_source(self):
+ exc = self.st._get_sync_exchange("source-uid", 5)
+ doc = c_backend_wrapper.make_document('doc-id', 'replica:1',
+ tests.simple_doc)
+ self.assertEqual([], exc.get_seen_ids())
+ exc.insert_doc_from_source(doc, 10, 'T-sid')
+ self.assertGetDoc(self.db, 'doc-id', 'replica:1', tests.simple_doc,
+ False)
+ self.assertEqual(
+ (10, 'T-sid'), self.db._get_replica_gen_and_trans_id('source-uid'))
+ self.assertEqual(['doc-id'], exc.get_seen_ids())
+
+ def test_sync_exchange_conflicted_doc(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ exc = self.st._get_sync_exchange("source-uid", 5)
+ doc2 = c_backend_wrapper.make_document(doc.doc_id, 'replica:1',
+ tests.nested_doc)
+ self.assertEqual([], exc.get_seen_ids())
+ # The insert should be rejected and the doc_id not considered 'seen'
+ exc.insert_doc_from_source(doc2, 10, 'T-sid')
+ self.assertGetDoc(
+ self.db, doc.doc_id, doc.rev, tests.simple_doc, False)
+ self.assertEqual([], exc.get_seen_ids())
+
+ def test_sync_exchange_find_doc_ids(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ exc = self.st._get_sync_exchange("source-uid", 0)
+ self.assertEqual(0, exc.target_gen)
+ exc.find_doc_ids_to_return()
+ doc_id = exc.get_doc_ids_to_return()[0]
+ self.assertEqual(
+ (doc.doc_id, 1), doc_id[:-1])
+ self.assertTrue(doc_id[-1].startswith('T-'))
+ self.assertEqual(1, exc.target_gen)
+
+ def test_sync_exchange_find_doc_ids_not_including_recently_inserted(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ doc2 = self.db.create_doc_from_json(tests.nested_doc)
+ exc = self.st._get_sync_exchange("source-uid", 0)
+ doc3 = c_backend_wrapper.make_document(doc1.doc_id,
+ doc1.rev + "|zreplica:2", tests.simple_doc)
+ exc.insert_doc_from_source(doc3, 10, 'T-sid')
+ exc.find_doc_ids_to_return()
+ self.assertEqual(
+ (doc2.doc_id, 2), exc.get_doc_ids_to_return()[0][:-1])
+ self.assertEqual(3, exc.target_gen)
+
+ def test_sync_exchange_return_docs(self):
+ returned = []
+
+ def return_doc_cb(doc, gen, trans_id):
+ returned.append((doc, gen, trans_id))
+
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ exc = self.st._get_sync_exchange("source-uid", 0)
+ exc.find_doc_ids_to_return()
+ exc.return_docs(return_doc_cb)
+ self.assertEqual((doc1, 1), returned[0][:-1])
+
+ def test_sync_exchange_doc_ids(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc, doc_id='doc-1')
+ db2 = c_backend_wrapper.CDatabase(':memory:')
+ doc2 = db2.create_doc_from_json(tests.nested_doc, doc_id='doc-2')
+ returned = []
+
+ def return_doc_cb(doc, gen, trans_id):
+ returned.append((doc, gen, trans_id))
+
+ val = self.st.sync_exchange_doc_ids(
+ db2, [(doc2.doc_id, 1, 'T-sid')], 0, None, return_doc_cb)
+ last_trans_id = self.db._get_transaction_log()[-1][1]
+ self.assertEqual(2, self.db._get_generation())
+ self.assertEqual((2, last_trans_id), val)
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc,
+ False)
+ self.assertEqual((doc1, 1), returned[0][:-1])
+
+
+class TestCHTTPSyncTarget(BackendTests):
+
+ def test_format_sync_url(self):
+ target = c_backend_wrapper.create_http_sync_target("http://base_url")
+ self.assertEqual("http://base_url/sync-from/replica-uid",
+ c_backend_wrapper._format_sync_url(target, "replica-uid"))
+
+ def test_format_sync_url_escapes(self):
+ # The base_url should not get munged (we assume it is already a
+ # properly formed URL), but the replica-uid should get properly escaped
+ target = c_backend_wrapper.create_http_sync_target(
+ "http://host/base%2Ctest/")
+ self.assertEqual("http://host/base%2Ctest/sync-from/replica%2Cuid",
+ c_backend_wrapper._format_sync_url(target, "replica,uid"))
+
+ def test_format_refuses_non_http(self):
+ db = c_backend_wrapper.CDatabase(':memory:')
+ target = db.get_sync_target()
+ self.assertRaises(RuntimeError,
+ c_backend_wrapper._format_sync_url, target, 'replica,uid')
+
+ def test_oauth_credentials(self):
+ target = c_backend_wrapper.create_oauth_http_sync_target(
+ "http://host/base%2Ctest/",
+ 'consumer-key', 'consumer-secret', 'token-key', 'token-secret')
+ auth = c_backend_wrapper._get_oauth_authorization(target,
+ "GET", "http://host/base%2Ctest/sync-from/abcd-efg")
+ self.assertIsNot(None, auth)
+ self.assertTrue(auth.startswith('Authorization: OAuth realm="", '))
+ self.assertNotIn('http://host/base', auth)
+ self.assertIn('oauth_nonce="', auth)
+ self.assertIn('oauth_timestamp="', auth)
+ self.assertIn('oauth_consumer_key="consumer-key"', auth)
+ self.assertIn('oauth_signature_method="HMAC-SHA1"', auth)
+ self.assertIn('oauth_version="1.0"', auth)
+ self.assertIn('oauth_token="token-key"', auth)
+ self.assertIn('oauth_signature="', auth)
+
+
+class TestSyncCtoHTTPViaC(tests.TestCaseWithServer):
+
+ make_app_with_state = staticmethod(make_http_app)
+
+ def setUp(self):
+ super(TestSyncCtoHTTPViaC, self).setUp()
+ if c_backend_wrapper is None:
+ self.skipTest("The c_backend_wrapper could not be imported")
+ self.startServer()
+
+ def test_trivial_sync(self):
+ mem_db = self.request_state._create_database('test.db')
+ mem_doc = mem_db.create_doc_from_json(tests.nested_doc)
+ url = self.getURL('test.db')
+ target = c_backend_wrapper.create_http_sync_target(url)
+ db = c_backend_wrapper.CDatabase(':memory:')
+ doc = db.create_doc_from_json(tests.simple_doc)
+ c_backend_wrapper.sync_db_to_target(db, target)
+ self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False)
+ self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(),
+ False)
+
+ def test_unavailable(self):
+ mem_db = self.request_state._create_database('test.db')
+ mem_db.create_doc_from_json(tests.nested_doc)
+ tries = []
+
+ def wrapper(instance, *args, **kwargs):
+ tries.append(None)
+ raise errors.Unavailable
+
+ mem_db.whats_changed = wrapper
+ url = self.getURL('test.db')
+ target = c_backend_wrapper.create_http_sync_target(url)
+ db = c_backend_wrapper.CDatabase(':memory:')
+ db.create_doc_from_json(tests.simple_doc)
+ self.assertRaises(
+ errors.Unavailable, c_backend_wrapper.sync_db_to_target, db,
+ target)
+ self.assertEqual(5, len(tries))
+
+ def test_unavailable_then_available(self):
+ mem_db = self.request_state._create_database('test.db')
+ mem_doc = mem_db.create_doc_from_json(tests.nested_doc)
+ orig_whatschanged = mem_db.whats_changed
+ tries = []
+
+ def wrapper(instance, *args, **kwargs):
+ if len(tries) < 1:
+ tries.append(None)
+ raise errors.Unavailable
+ return orig_whatschanged(instance, *args, **kwargs)
+
+ mem_db.whats_changed = wrapper
+ url = self.getURL('test.db')
+ target = c_backend_wrapper.create_http_sync_target(url)
+ db = c_backend_wrapper.CDatabase(':memory:')
+ doc = db.create_doc_from_json(tests.simple_doc)
+ c_backend_wrapper.sync_db_to_target(db, target)
+ self.assertEqual(1, len(tries))
+ self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False)
+ self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(),
+ False)
+
+ def test_db_sync(self):
+ mem_db = self.request_state._create_database('test.db')
+ mem_doc = mem_db.create_doc_from_json(tests.nested_doc)
+ url = self.getURL('test.db')
+ db = c_backend_wrapper.CDatabase(':memory:')
+ doc = db.create_doc_from_json(tests.simple_doc)
+ local_gen_before_sync = db.sync(url)
+ gen, _, changes = db.whats_changed(local_gen_before_sync)
+ self.assertEqual(1, len(changes))
+ self.assertEqual(mem_doc.doc_id, changes[0][0])
+ self.assertEqual(1, gen - local_gen_before_sync)
+ self.assertEqual(1, local_gen_before_sync)
+ self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False)
+ self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(),
+ False)
+
+
+class TestSyncCtoOAuthHTTPViaC(tests.TestCaseWithServer):
+
+ make_app_with_state = staticmethod(make_oauth_http_app)
+
+ def setUp(self):
+ super(TestSyncCtoOAuthHTTPViaC, self).setUp()
+ if c_backend_wrapper is None:
+ self.skipTest("The c_backend_wrapper could not be imported")
+ self.startServer()
+
+ def test_trivial_sync(self):
+ mem_db = self.request_state._create_database('test.db')
+ mem_doc = mem_db.create_doc_from_json(tests.nested_doc)
+ url = self.getURL('~/test.db')
+ target = c_backend_wrapper.create_oauth_http_sync_target(url,
+ tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ db = c_backend_wrapper.CDatabase(':memory:')
+ doc = db.create_doc_from_json(tests.simple_doc)
+ c_backend_wrapper.sync_db_to_target(db, target)
+ self.assertGetDoc(mem_db, doc.doc_id, doc.rev, doc.get_json(), False)
+ self.assertGetDoc(db, mem_doc.doc_id, mem_doc.rev, mem_doc.get_json(),
+ False)
+
+
+class TestVectorClock(BackendTests):
+
+ def create_vcr(self, rev):
+ return c_backend_wrapper.VectorClockRev(rev)
+
+ def test_parse_empty(self):
+ self.assertEqual('VectorClockRev()',
+ repr(self.create_vcr('')))
+
+ def test_parse_invalid(self):
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('x')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('x:a')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:a')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('x:a|y:1')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:2a')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1||')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:2|')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:2|:')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:2|m:')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|x:|m:3')))
+ self.assertEqual('VectorClockRev(None)',
+ repr(self.create_vcr('y:1|:|m:3')))
+
+ def test_parse_single(self):
+ self.assertEqual('VectorClockRev(test:1)',
+ repr(self.create_vcr('test:1')))
+
+ def test_parse_multi(self):
+ self.assertEqual('VectorClockRev(test:1|z:2)',
+ repr(self.create_vcr('test:1|z:2')))
+ self.assertEqual('VectorClockRev(ab:1|bc:2|cd:3|de:4|ef:5)',
+ repr(self.create_vcr('ab:1|bc:2|cd:3|de:4|ef:5')))
+ self.assertEqual('VectorClockRev(a:2|b:1)',
+ repr(self.create_vcr('b:1|a:2')))
+
+
+class TestCDocument(BackendTests):
+
+ def make_document(self, *args, **kwargs):
+ return c_backend_wrapper.make_document(*args, **kwargs)
+
+ def test_create(self):
+ self.make_document('doc-id', 'uid:1', tests.simple_doc)
+
+ def assertPyDocEqualCDoc(self, *args, **kwargs):
+ cdoc = self.make_document(*args, **kwargs)
+ pydoc = Document(*args, **kwargs)
+ self.assertEqual(pydoc, cdoc)
+ self.assertEqual(cdoc, pydoc)
+
+ def test_cmp_to_pydoc_equal(self):
+ self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc)
+ self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc,
+ has_conflicts=False)
+ self.assertPyDocEqualCDoc('doc-id', 'uid:1', tests.simple_doc,
+ has_conflicts=True)
+
+ def test_cmp_to_pydoc_not_equal_conflicts(self):
+ cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ pydoc = Document('doc-id', 'uid:1', tests.simple_doc,
+ has_conflicts=True)
+ self.assertNotEqual(cdoc, pydoc)
+ self.assertNotEqual(pydoc, cdoc)
+
+ def test_cmp_to_pydoc_not_equal_doc_id(self):
+ cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ pydoc = Document('doc2-id', 'uid:1', tests.simple_doc)
+ self.assertNotEqual(cdoc, pydoc)
+ self.assertNotEqual(pydoc, cdoc)
+
+ def test_cmp_to_pydoc_not_equal_doc_rev(self):
+ cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ pydoc = Document('doc-id', 'uid:2', tests.simple_doc)
+ self.assertNotEqual(cdoc, pydoc)
+ self.assertNotEqual(pydoc, cdoc)
+
+ def test_cmp_to_pydoc_not_equal_content(self):
+ cdoc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ pydoc = Document('doc-id', 'uid:1', tests.nested_doc)
+ self.assertNotEqual(cdoc, pydoc)
+ self.assertNotEqual(pydoc, cdoc)
+
+
+class TestUUID(BackendTests):
+
+ def test_uuid4_conformance(self):
+ uuids = set()
+ for i in range(20):
+ uuid = c_backend_wrapper.generate_hex_uuid()
+ self.assertIsInstance(uuid, str)
+ self.assertEqual(32, len(uuid))
+ # This will raise ValueError if it isn't a valid hex string
+ long(uuid, 16)
+ # Version 4 uuids have 2 other requirements, the high 4 bits of the
+ # seventh byte are always '0x4', and the middle bits of byte 9 are
+ # always set
+ self.assertEqual('4', uuid[12])
+ self.assertTrue(uuid[16] in '89ab')
+ self.assertTrue(uuid not in uuids)
+ uuids.add(uuid)
diff --git a/src/leap/soledad/u1db/tests/test_common_backend.py b/src/leap/soledad/u1db/tests/test_common_backend.py
new file mode 100644
index 00000000..8c7c7ed9
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_common_backend.py
@@ -0,0 +1,33 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test common backend bits."""
+
+from u1db import (
+ backends,
+ tests,
+ )
+
+
+class TestCommonBackendImpl(tests.TestCase):
+
+ def test__allocate_doc_id(self):
+ db = backends.CommonBackend()
+ doc_id1 = db._allocate_doc_id()
+ self.assertTrue(doc_id1.startswith('D-'))
+ self.assertEqual(34, len(doc_id1))
+ int(doc_id1[len('D-'):], 16)
+ self.assertNotEqual(doc_id1, db._allocate_doc_id())
diff --git a/src/leap/soledad/u1db/tests/test_document.py b/src/leap/soledad/u1db/tests/test_document.py
new file mode 100644
index 00000000..20f254b9
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_document.py
@@ -0,0 +1,148 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+
+from u1db import errors, tests
+
+
+class TestDocument(tests.TestCase):
+
+ scenarios = ([(
+ 'py', {'make_document_for_test': tests.make_document_for_test})] +
+ tests.C_DATABASE_SCENARIOS)
+
+ def test_create_doc(self):
+ doc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ self.assertEqual('doc-id', doc.doc_id)
+ self.assertEqual('uid:1', doc.rev)
+ self.assertEqual(tests.simple_doc, doc.get_json())
+ self.assertFalse(doc.has_conflicts)
+
+ def test__repr__(self):
+ doc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ self.assertEqual(
+ '%s(doc-id, uid:1, \'{"key": "value"}\')'
+ % (doc.__class__.__name__,),
+ repr(doc))
+
+ def test__repr__conflicted(self):
+ doc = self.make_document('doc-id', 'uid:1', tests.simple_doc,
+ has_conflicts=True)
+ self.assertEqual(
+ '%s(doc-id, uid:1, conflicted, \'{"key": "value"}\')'
+ % (doc.__class__.__name__,),
+ repr(doc))
+
+ def test__lt__(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ doc_b = self.make_document('b', 'b', '{}')
+ self.assertTrue(doc_a < doc_b)
+ self.assertTrue(doc_b > doc_a)
+ doc_aa = self.make_document('a', 'a', '{}')
+ self.assertTrue(doc_aa < doc_a)
+
+ def test__eq__(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ doc_b = self.make_document('a', 'b', '{}')
+ self.assertTrue(doc_a == doc_b)
+ doc_b = self.make_document('a', 'b', '{}', has_conflicts=True)
+ self.assertFalse(doc_a == doc_b)
+
+ def test_non_json_dict(self):
+ self.assertRaises(
+ errors.InvalidJSON, self.make_document, 'id', 'uid:1',
+ '"not a json dictionary"')
+
+ def test_non_json(self):
+ self.assertRaises(
+ errors.InvalidJSON, self.make_document, 'id', 'uid:1',
+ 'not a json dictionary')
+
+ def test_get_size(self):
+ doc_a = self.make_document('a', 'b', '{"some": "content"}')
+ self.assertEqual(
+ len('a' + 'b' + '{"some": "content"}'), doc_a.get_size())
+
+ def test_get_size_empty_document(self):
+ doc_a = self.make_document('a', 'b', None)
+ self.assertEqual(len('a' + 'b'), doc_a.get_size())
+
+
+class TestPyDocument(tests.TestCase):
+
+ scenarios = ([(
+ 'py', {'make_document_for_test': tests.make_document_for_test})])
+
+ def test_get_content(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertEqual({"content": ""}, doc.content)
+ doc.set_json('{"content": "new"}')
+ self.assertEqual({"content": "new"}, doc.content)
+
+ def test_set_content(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ doc.content = {"content": "new"}
+ self.assertEqual('{"content": "new"}', doc.get_json())
+
+ def test_set_bad_content(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertRaises(
+ errors.InvalidContent, setattr, doc, 'content',
+ '{"content": "new"}')
+
+ def test_is_tombstone(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ self.assertFalse(doc_a.is_tombstone())
+ doc_a.set_json(None)
+ self.assertTrue(doc_a.is_tombstone())
+
+ def test_make_tombstone(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ self.assertFalse(doc_a.is_tombstone())
+ doc_a.make_tombstone()
+ self.assertTrue(doc_a.is_tombstone())
+
+ def test_same_content_as(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ doc_b = self.make_document('d', 'e', '{}')
+ self.assertTrue(doc_a.same_content_as(doc_b))
+ doc_b = self.make_document('p', 'q', '{}', has_conflicts=True)
+ self.assertTrue(doc_a.same_content_as(doc_b))
+ doc_b.content['key'] = 'value'
+ self.assertFalse(doc_a.same_content_as(doc_b))
+
+ def test_same_content_as_json_order(self):
+ doc_a = self.make_document(
+ 'a', 'b', '{"key1": "val1", "key2": "val2"}')
+ doc_b = self.make_document(
+ 'c', 'd', '{"key2": "val2", "key1": "val1"}')
+ self.assertTrue(doc_a.same_content_as(doc_b))
+
+ def test_set_json(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ doc.set_json('{"content": "new"}')
+ self.assertEqual('{"content": "new"}', doc.get_json())
+
+ def test_set_json_non_dict(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertRaises(errors.InvalidJSON, doc.set_json, '"is not a dict"')
+
+ def test_set_json_error(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertRaises(errors.InvalidJSON, doc.set_json, 'is not json')
+
+
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/u1db/tests/test_errors.py b/src/leap/soledad/u1db/tests/test_errors.py
new file mode 100644
index 00000000..0e089ede
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_errors.py
@@ -0,0 +1,61 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests error infrastructure."""
+
+from u1db import (
+ errors,
+ tests,
+ )
+
+
+class TestError(tests.TestCase):
+
+ def test_error_base(self):
+ err = errors.U1DBError()
+ self.assertEqual("error", err.wire_description)
+ self.assertIs(None, err.message)
+
+ err = errors.U1DBError("Message.")
+ self.assertEqual("error", err.wire_description)
+ self.assertEqual("Message.", err.message)
+
+ def test_HTTPError(self):
+ err = errors.HTTPError(500)
+ self.assertEqual(500, err.status)
+ self.assertIs(None, err.wire_description)
+ self.assertIs(None, err.message)
+
+ err = errors.HTTPError(500, "Crash.")
+ self.assertEqual(500, err.status)
+ self.assertIs(None, err.wire_description)
+ self.assertEqual("Crash.", err.message)
+
+ def test_HTTPError_str(self):
+ err = errors.HTTPError(500)
+ self.assertEqual("HTTPError(500)", str(err))
+
+ err = errors.HTTPError(500, "ERROR")
+ self.assertEqual("HTTPError(500, 'ERROR')", str(err))
+
+ def test_Unvailable(self):
+ err = errors.Unavailable()
+ self.assertEqual(503, err.status)
+ self.assertEqual("Unavailable()", str(err))
+
+ err = errors.Unavailable("DOWN")
+ self.assertEqual("DOWN", err.message)
+ self.assertEqual("Unavailable('DOWN')", str(err))
diff --git a/src/leap/soledad/u1db/tests/test_http_app.py b/src/leap/soledad/u1db/tests/test_http_app.py
new file mode 100644
index 00000000..13522693
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_http_app.py
@@ -0,0 +1,1133 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test the WSGI app."""
+
+import paste.fixture
+import sys
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import StringIO
+
+from u1db import (
+ __version__ as _u1db_version,
+ errors,
+ sync,
+ tests,
+ )
+
+from u1db.remote import (
+ http_app,
+ http_errors,
+ )
+
+
+class TestFencedReader(tests.TestCase):
+
+ def test_init(self):
+ reader = http_app._FencedReader(StringIO.StringIO(""), 25, 100)
+ self.assertEqual(25, reader.remaining)
+
+ def test_read_chunk(self):
+ inp = StringIO.StringIO("abcdef")
+ reader = http_app._FencedReader(inp, 5, 10)
+ data = reader.read_chunk(2)
+ self.assertEqual("ab", data)
+ self.assertEqual(2, inp.tell())
+ self.assertEqual(3, reader.remaining)
+
+ def test_read_chunk_remaining(self):
+ inp = StringIO.StringIO("abcdef")
+ reader = http_app._FencedReader(inp, 4, 10)
+ data = reader.read_chunk(9999)
+ self.assertEqual("abcd", data)
+ self.assertEqual(4, inp.tell())
+ self.assertEqual(0, reader.remaining)
+
+ def test_read_chunk_nothing_left(self):
+ inp = StringIO.StringIO("abc")
+ reader = http_app._FencedReader(inp, 2, 10)
+ reader.read_chunk(2)
+ self.assertEqual(2, inp.tell())
+ self.assertEqual(0, reader.remaining)
+ data = reader.read_chunk(2)
+ self.assertEqual("", data)
+ self.assertEqual(2, inp.tell())
+ self.assertEqual(0, reader.remaining)
+
+ def test_read_chunk_kept(self):
+ inp = StringIO.StringIO("abcde")
+ reader = http_app._FencedReader(inp, 4, 10)
+ reader._kept = "xyz"
+ data = reader.read_chunk(2) # atmost ignored
+ self.assertEqual("xyz", data)
+ self.assertEqual(0, inp.tell())
+ self.assertEqual(4, reader.remaining)
+ self.assertIsNone(reader._kept)
+
+ def test_getline(self):
+ inp = StringIO.StringIO("abc\r\nde")
+ reader = http_app._FencedReader(inp, 6, 10)
+ reader.MAXCHUNK = 6
+ line = reader.getline()
+ self.assertEqual("abc\r\n", line)
+ self.assertEqual("d", reader._kept)
+
+ def test_getline_exact(self):
+ inp = StringIO.StringIO("abcd\r\nef")
+ reader = http_app._FencedReader(inp, 6, 10)
+ reader.MAXCHUNK = 6
+ line = reader.getline()
+ self.assertEqual("abcd\r\n", line)
+ self.assertIs(None, reader._kept)
+
+ def test_getline_no_newline(self):
+ inp = StringIO.StringIO("abcd")
+ reader = http_app._FencedReader(inp, 4, 10)
+ reader.MAXCHUNK = 6
+ line = reader.getline()
+ self.assertEqual("abcd", line)
+
+ def test_getline_many_chunks(self):
+ inp = StringIO.StringIO("abcde\r\nf")
+ reader = http_app._FencedReader(inp, 8, 10)
+ reader.MAXCHUNK = 4
+ line = reader.getline()
+ self.assertEqual("abcde\r\n", line)
+ self.assertEqual("f", reader._kept)
+ line = reader.getline()
+ self.assertEqual("f", line)
+
+ def test_getline_empty(self):
+ inp = StringIO.StringIO("")
+ reader = http_app._FencedReader(inp, 0, 10)
+ reader.MAXCHUNK = 4
+ line = reader.getline()
+ self.assertEqual("", line)
+ line = reader.getline()
+ self.assertEqual("", line)
+
+ def test_getline_just_newline(self):
+ inp = StringIO.StringIO("\r\n")
+ reader = http_app._FencedReader(inp, 2, 10)
+ reader.MAXCHUNK = 4
+ line = reader.getline()
+ self.assertEqual("\r\n", line)
+ line = reader.getline()
+ self.assertEqual("", line)
+
+ def test_getline_too_large(self):
+ inp = StringIO.StringIO("x" * 50)
+ reader = http_app._FencedReader(inp, 50, 25)
+ reader.MAXCHUNK = 4
+ self.assertRaises(http_app.BadRequest, reader.getline)
+
+ def test_getline_too_large_complete(self):
+ inp = StringIO.StringIO("x" * 25 + "\r\n")
+ reader = http_app._FencedReader(inp, 50, 25)
+ reader.MAXCHUNK = 4
+ self.assertRaises(http_app.BadRequest, reader.getline)
+
+
+class TestHTTPMethodDecorator(tests.TestCase):
+
+ def test_args(self):
+ @http_app.http_method()
+ def f(self, a, b):
+ return self, a, b
+ res = f("self", {"a": "x", "b": "y"}, None)
+ self.assertEqual(("self", "x", "y"), res)
+
+ def test_args_missing(self):
+ @http_app.http_method()
+ def f(self, a, b):
+ return a, b
+ self.assertRaises(http_app.BadRequest, f, "self", {"a": "x"}, None)
+
+ def test_args_unexpected(self):
+ @http_app.http_method()
+ def f(self, a):
+ return a
+ self.assertRaises(http_app.BadRequest, f, "self",
+ {"a": "x", "c": "z"}, None)
+
+ def test_args_default(self):
+ @http_app.http_method()
+ def f(self, a, b="z"):
+ return a, b
+ res = f("self", {"a": "x"}, None)
+ self.assertEqual(("x", "z"), res)
+
+ def test_args_conversion(self):
+ @http_app.http_method(b=int)
+ def f(self, a, b):
+ return self, a, b
+ res = f("self", {"a": "x", "b": "2"}, None)
+ self.assertEqual(("self", "x", 2), res)
+
+ self.assertRaises(http_app.BadRequest, f, "self",
+ {"a": "x", "b": "foo"}, None)
+
+ def test_args_conversion_with_default(self):
+ @http_app.http_method(b=str)
+ def f(self, a, b=None):
+ return self, a, b
+ res = f("self", {"a": "x"}, None)
+ self.assertEqual(("self", "x", None), res)
+
+ def test_args_content(self):
+ @http_app.http_method()
+ def f(self, a, content):
+ return a, content
+ res = f(self, {"a": "x"}, "CONTENT")
+ self.assertEqual(("x", "CONTENT"), res)
+
+ def test_args_content_as_args(self):
+ @http_app.http_method(b=int, content_as_args=True)
+ def f(self, a, b):
+ return self, a, b
+ res = f("self", {"a": "x"}, '{"b": "2"}')
+ self.assertEqual(("self", "x", 2), res)
+
+ self.assertRaises(http_app.BadRequest, f, "self", {}, 'not-json')
+
+ def test_args_content_no_query(self):
+ @http_app.http_method(no_query=True,
+ content_as_args=True)
+ def f(self, a='a', b='b'):
+ return a, b
+ res = f("self", {}, '{"b": "y"}')
+ self.assertEqual(('a', 'y'), res)
+
+ self.assertRaises(http_app.BadRequest, f, "self", {'a': 'x'},
+ '{"b": "y"}')
+
+
+class TestResource(object):
+
+ @http_app.http_method()
+ def get(self, a, b):
+ self.args = dict(a=a, b=b)
+ return 'Get'
+
+ @http_app.http_method()
+ def put(self, a, content):
+ self.args = dict(a=a)
+ self.content = content
+ return 'Put'
+
+ @http_app.http_method(content_as_args=True)
+ def put_args(self, a, b):
+ self.args = dict(a=a, b=b)
+ self.order = ['a']
+ self.entries = []
+
+ @http_app.http_method()
+ def put_stream_entry(self, content):
+ self.entries.append(content)
+ self.order.append('s')
+
+ def put_end(self):
+ self.order.append('e')
+ return "Put/end"
+
+
+class parameters:
+ max_request_size = 200000
+ max_entry_size = 100000
+
+
+class TestHTTPInvocationByMethodWithBody(tests.TestCase):
+
+ def test_get(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'GET'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ res = invoke()
+ self.assertEqual('Get', res)
+ self.assertEqual({'a': '1', 'b': '2'}, resource.args)
+
+ def test_put_json(self):
+ resource = TestResource()
+ body = '{"body": true}'
+ environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ res = invoke()
+ self.assertEqual('Put', res)
+ self.assertEqual({'a': '1'}, resource.args)
+ self.assertEqual('{"body": true}', resource.content)
+
+ def test_put_sync_stream(self):
+ resource = TestResource()
+ body = (
+ '[\r\n'
+ '{"b": 2},\r\n' # args
+ '{"entry": "x"},\r\n' # stream entry
+ '{"entry": "y"}\r\n' # stream entry
+ ']'
+ )
+ environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/x-u1db-sync-stream'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ res = invoke()
+ self.assertEqual('Put/end', res)
+ self.assertEqual({'a': '1', 'b': 2}, resource.args)
+ self.assertEqual(
+ ['{"entry": "x"}', '{"entry": "y"}'], resource.entries)
+ self.assertEqual(['a', 's', 's', 'e'], resource.order)
+
+ def _put_sync_stream(self, body):
+ resource = TestResource()
+ environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/x-u1db-sync-stream'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ invoke()
+
+ def test_put_sync_stream_wrong_start(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "{}\r\n]")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "\r\n{}\r\n]")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "")
+
+ def test_put_sync_stream_wrong_end(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{}")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{}\r\n]\r\n...")
+
+ def test_put_sync_stream_missing_comma(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{}\r\n{}\r\n]")
+
+ def test_put_sync_stream_extra_comma(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{},\r\n]")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{},\r\n{},\r\n]")
+
+ def test_bad_request_decode_failure(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': 'a=\xff', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '2',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_content_type(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '2',
+ 'CONTENT_TYPE': 'text/plain'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_content_length_too_large(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '10000',
+ 'CONTENT_TYPE': 'text/plain'}
+
+ resource.max_request_size = 5000
+ resource.max_entry_size = sys.maxint # we don't get to use this
+
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_no_content_length(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('a'),
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_invalid_content_length(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('abc'),
+ 'CONTENT_LENGTH': '1unk',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_empty_body(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(''),
+ 'CONTENT_LENGTH': '0',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_method_get_like(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'DELETE'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_method_put_like(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '2',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_method_put_like_multi_json(self):
+ resource = TestResource()
+ body = '{}\r\n{}\r\n'
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'POST',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/x-u1db-multi-json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+
+class TestHTTPResponder(tests.TestCase):
+
+ def start_response(self, status, headers):
+ self.status = status
+ self.headers = dict(headers)
+ self.response_body = []
+
+ def write(data):
+ self.response_body.append(data)
+
+ return write
+
+ def test_send_response_content_w_headers(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.send_response_content('foo', headers={'x-a': '1'})
+ self.assertEqual('200 OK', self.status)
+ self.assertEqual({'content-type': 'application/json',
+ 'cache-control': 'no-cache',
+ 'x-a': '1', 'content-length': '3'}, self.headers)
+ self.assertEqual([], self.response_body)
+ self.assertEqual(['foo'], responder.content)
+
+ def test_send_response_json(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.send_response_json(value='success')
+ self.assertEqual('200 OK', self.status)
+ expected_body = '{"value": "success"}\r\n'
+ self.assertEqual({'content-type': 'application/json',
+ 'content-length': str(len(expected_body)),
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual([], self.response_body)
+ self.assertEqual([expected_body], responder.content)
+
+ def test_send_response_json_status_fail(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.send_response_json(400)
+ self.assertEqual('400 Bad Request', self.status)
+ expected_body = '{}\r\n'
+ self.assertEqual({'content-type': 'application/json',
+ 'content-length': str(len(expected_body)),
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual([], self.response_body)
+ self.assertEqual([expected_body], responder.content)
+
+ def test_start_finish_response_status_fail(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.start_response(404, {'error': 'not found'})
+ responder.finish_response()
+ self.assertEqual('404 Not Found', self.status)
+ self.assertEqual({'content-type': 'application/json',
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual(['{"error": "not found"}\r\n'], self.response_body)
+ self.assertEqual([], responder.content)
+
+ def test_send_stream_entry(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.content_type = "application/x-u1db-multi-json"
+ responder.start_response(200)
+ responder.start_stream()
+ responder.stream_entry({'entry': 1})
+ responder.stream_entry({'entry': 2})
+ responder.end_stream()
+ responder.finish_response()
+ self.assertEqual('200 OK', self.status)
+ self.assertEqual({'content-type': 'application/x-u1db-multi-json',
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual(['[',
+ '\r\n', '{"entry": 1}',
+ ',\r\n', '{"entry": 2}',
+ '\r\n]\r\n'], self.response_body)
+ self.assertEqual([], responder.content)
+
+ def test_send_stream_w_error(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.content_type = "application/x-u1db-multi-json"
+ responder.start_response(200)
+ responder.start_stream()
+ responder.stream_entry({'entry': 1})
+ responder.send_response_json(503, error="unavailable")
+ self.assertEqual('200 OK', self.status)
+ self.assertEqual({'content-type': 'application/x-u1db-multi-json',
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual(['[',
+ '\r\n', '{"entry": 1}'], self.response_body)
+ self.assertEqual([',\r\n', '{"error": "unavailable"}\r\n'],
+ responder.content)
+
+
+class TestHTTPApp(tests.TestCase):
+
+ def setUp(self):
+ super(TestHTTPApp, self).setUp()
+ self.state = tests.ServerStateForTests()
+ self.http_app = http_app.HTTPApp(self.state)
+ self.app = paste.fixture.TestApp(self.http_app)
+ self.db0 = self.state._create_database('db0')
+
+ def test_bad_request_broken(self):
+ resp = self.app.put('/db0/doc/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/foo'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_bad_request_dispatch(self):
+ resp = self.app.put('/db0/foo/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_version(self):
+ resp = self.app.get('/')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"version": _u1db_version}, json.loads(resp.body))
+
+ def test_create_database(self):
+ resp = self.app.put('/db1', params='{}',
+ headers={'content-type': 'application/json'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'ok': True}, json.loads(resp.body))
+
+ resp = self.app.put('/db1', params='{}',
+ headers={'content-type': 'application/json'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'ok': True}, json.loads(resp.body))
+
+ def test_delete_database(self):
+ resp = self.app.delete('/db0')
+ self.assertEqual(200, resp.status)
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.state.check_database, 'db0')
+
+ def test_get_database(self):
+ resp = self.app.get('/db0')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({}, json.loads(resp.body))
+
+ def test_valid_database_names(self):
+ resp = self.app.get('/a-database', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/db1', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/0', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/0-0', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/org.future', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ def test_invalid_database_names(self):
+ resp = self.app.get('/.a', expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ resp = self.app.get('/-a', expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ resp = self.app.get('/_a', expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_put_doc_create(self):
+ resp = self.app.put('/db0/doc/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/json'})
+ doc = self.db0.get_doc('doc1')
+ self.assertEqual(201, resp.status) # created
+ self.assertEqual('{"x": 1}', doc.get_json())
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'rev': doc.rev}, json.loads(resp.body))
+
+ def test_put_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev,
+ params='{"x": 2}',
+ headers={'content-type': 'application/json'})
+ doc = self.db0.get_doc('doc1')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('{"x": 2}', doc.get_json())
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'rev': doc.rev}, json.loads(resp.body))
+
+ def test_put_doc_too_large(self):
+ self.http_app.max_request_size = 15000
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev,
+ params='{"%s": 2}' % ('z' * 16000),
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_delete_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.delete('/db0/doc/doc1?old_rev=%s' % doc.rev)
+ doc = self.db0.get_doc('doc1', include_deleted=True)
+ self.assertEqual(None, doc.content)
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'rev': doc.rev}, json.loads(resp.body))
+
+ def test_get_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.get('/db0/doc/%s' % doc.doc_id)
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual('{"x": 1}', resp.body)
+ self.assertEqual(doc.rev, resp.header('x-u1db-rev'))
+ self.assertEqual('false', resp.header('x-u1db-has-conflicts'))
+
+ def test_get_doc_non_existing(self):
+ resp = self.app.get('/db0/doc/not-there', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "document does not exist"}, json.loads(resp.body))
+ self.assertEqual('', resp.header('x-u1db-rev'))
+ self.assertEqual('false', resp.header('x-u1db-has-conflicts'))
+
+ def test_get_doc_deleted(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.db0.delete_doc(doc)
+ resp = self.app.get('/db0/doc/doc1', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": errors.DocumentDoesNotExist.wire_description},
+ json.loads(resp.body))
+
+ def test_get_doc_deleted_explicit_exclude(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.db0.delete_doc(doc)
+ resp = self.app.get(
+ '/db0/doc/doc1?include_deleted=false', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": errors.DocumentDoesNotExist.wire_description},
+ json.loads(resp.body))
+
+ def test_get_deleted_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.db0.delete_doc(doc)
+ resp = self.app.get(
+ '/db0/doc/doc1?include_deleted=true', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": errors.DOCUMENT_DELETED}, json.loads(resp.body))
+ self.assertEqual(doc.rev, resp.header('x-u1db-rev'))
+ self.assertEqual('false', resp.header('x-u1db-has-conflicts'))
+
+ def test_get_doc_non_existing_dabase(self):
+ resp = self.app.get('/not-there/doc/doc1', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "database does not exist"}, json.loads(resp.body))
+
+ def test_get_docs(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1",
+ "has_conflicts": False},
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_docs_missing_doc_ids(self):
+ resp = self.app.get('/db0/docs', expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "missing document ids"}, json.loads(resp.body))
+
+ def test_get_docs_empty_doc_ids(self):
+ resp = self.app.get('/db0/docs?doc_ids=', expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "missing document ids"}, json.loads(resp.body))
+
+ def test_get_docs_percent(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc%1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc%1",
+ "has_conflicts": False},
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_docs_deleted(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ self.db0.delete_doc(doc2)
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_docs_include_deleted(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ self.db0.delete_doc(doc2)
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s&include_deleted=true' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1",
+ "has_conflicts": False},
+ {"content": None, "doc_rev": "db0:2", "doc_id": "doc2",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_sync_info(self):
+ self.db0._set_replica_gen_and_trans_id('other-id', 1, 'T-transid')
+ resp = self.app.get('/db0/sync-from/other-id')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(dict(target_replica_uid='db0',
+ target_replica_generation=0,
+ target_replica_transaction_id='',
+ source_replica_uid='other-id',
+ source_replica_generation=1,
+ source_transaction_id='T-transid'),
+ json.loads(resp.body))
+
+ def test_record_sync_info(self):
+ resp = self.app.put('/db0/sync-from/other-id',
+ params='{"generation": 2, "transaction_id": "T-transid"}',
+ headers={'content-type': 'application/json'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'ok': True}, json.loads(resp.body))
+ self.assertEqual(
+ (2, 'T-transid'),
+ self.db0._get_replica_gen_and_trans_id('other-id'))
+
+ def test_sync_exchange_send(self):
+ entries = {
+ 10: {'id': 'doc-here', 'rev': 'replica:1', 'content':
+ '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'},
+ 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content':
+ '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'}
+ }
+
+ gens = []
+ _do_set_replica_gen_and_trans_id = \
+ self.db0._do_set_replica_gen_and_trans_id
+
+ def set_sync_generation_witness(other_uid, other_gen, other_trans_id):
+ gens.append((other_uid, other_gen))
+ _do_set_replica_gen_and_trans_id(
+ other_uid, other_gen, other_trans_id)
+ self.assertGetDoc(self.db0, entries[other_gen]['id'],
+ entries[other_gen]['rev'],
+ entries[other_gen]['content'], False)
+
+ self.patch(
+ self.db0, '_do_set_replica_gen_and_trans_id',
+ set_sync_generation_witness)
+
+ args = dict(last_known_generation=0)
+ body = ("[\r\n" +
+ "%s,\r\n" % json.dumps(args) +
+ "%s,\r\n" % json.dumps(entries[10]) +
+ "%s\r\n" % json.dumps(entries[11]) +
+ "]\r\n")
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ bits = resp.body.split('\r\n')
+ self.assertEqual('[', bits[0])
+ last_trans_id = self.db0._get_transaction_log()[-1][1]
+ self.assertEqual({'new_generation': 2,
+ 'new_transaction_id': last_trans_id},
+ json.loads(bits[1]))
+ self.assertEqual(']', bits[2])
+ self.assertEqual('', bits[3])
+ self.assertEqual([('replica', 10), ('replica', 11)], gens)
+
+ def test_sync_exchange_send_ensure(self):
+ entries = {
+ 10: {'id': 'doc-here', 'rev': 'replica:1', 'content':
+ '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'},
+ 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content':
+ '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'}
+ }
+
+ args = dict(last_known_generation=0, ensure=True)
+ body = ("[\r\n" +
+ "%s,\r\n" % json.dumps(args) +
+ "%s,\r\n" % json.dumps(entries[10]) +
+ "%s\r\n" % json.dumps(entries[11]) +
+ "]\r\n")
+ resp = self.app.post('/dbnew/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ bits = resp.body.split('\r\n')
+ self.assertEqual('[', bits[0])
+ dbnew = self.state.open_database("dbnew")
+ last_trans_id = dbnew._get_transaction_log()[-1][1]
+ self.assertEqual({'new_generation': 2,
+ 'new_transaction_id': last_trans_id,
+ 'replica_uid': dbnew._replica_uid},
+ json.loads(bits[1]))
+ self.assertEqual(']', bits[2])
+ self.assertEqual('', bits[3])
+
+ def test_sync_exchange_send_entry_too_large(self):
+ self.patch(http_app.SyncResource, 'max_request_size', 20000)
+ self.patch(http_app.SyncResource, 'max_entry_size', 10000)
+ entries = {
+ 10: {'id': 'doc-here', 'rev': 'replica:1', 'content':
+ '{"value": "%s"}' % ('H' * 11000), 'gen': 10},
+ }
+ args = dict(last_known_generation=0)
+ body = ("[\r\n" +
+ "%s,\r\n" % json.dumps(args) +
+ "%s\r\n" % json.dumps(entries[10]) +
+ "]\r\n")
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_sync_exchange_receive(self):
+ doc = self.db0.create_doc_from_json('{"value": "there"}')
+ doc2 = self.db0.create_doc_from_json('{"value": "there2"}')
+ args = dict(last_known_generation=0)
+ body = "[\r\n%s\r\n]" % json.dumps(args)
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ parts = resp.body.splitlines()
+ self.assertEqual(5, len(parts))
+ self.assertEqual('[', parts[0])
+ last_trans_id = self.db0._get_transaction_log()[-1][1]
+ self.assertEqual({'new_generation': 2,
+ 'new_transaction_id': last_trans_id},
+ json.loads(parts[1].rstrip(",")))
+ part2 = json.loads(parts[2].rstrip(","))
+ self.assertTrue(part2['trans_id'].startswith('T-'))
+ self.assertEqual('{"value": "there"}', part2['content'])
+ self.assertEqual(doc.rev, part2['rev'])
+ self.assertEqual(doc.doc_id, part2['id'])
+ self.assertEqual(1, part2['gen'])
+ part3 = json.loads(parts[3].rstrip(","))
+ self.assertTrue(part3['trans_id'].startswith('T-'))
+ self.assertEqual('{"value": "there2"}', part3['content'])
+ self.assertEqual(doc2.rev, part3['rev'])
+ self.assertEqual(doc2.doc_id, part3['id'])
+ self.assertEqual(2, part3['gen'])
+ self.assertEqual(']', parts[4])
+
+ def test_sync_exchange_error_in_stream(self):
+ args = dict(last_known_generation=0)
+ body = "[\r\n%s\r\n]" % json.dumps(args)
+
+ def boom(self, return_doc_cb):
+ raise errors.Unavailable
+
+ self.patch(sync.SyncExchange, 'return_docs',
+ boom)
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ parts = resp.body.splitlines()
+ self.assertEqual(3, len(parts))
+ self.assertEqual('[', parts[0])
+ self.assertEqual({'new_generation': 0, 'new_transaction_id': ''},
+ json.loads(parts[1].rstrip(",")))
+ self.assertEqual({'error': 'unavailable'}, json.loads(parts[2]))
+
+
+class TestRequestHooks(tests.TestCase):
+
+ def setUp(self):
+ super(TestRequestHooks, self).setUp()
+ self.state = tests.ServerStateForTests()
+ self.http_app = http_app.HTTPApp(self.state)
+ self.app = paste.fixture.TestApp(self.http_app)
+ self.db0 = self.state._create_database('db0')
+
+ def test_begin_and_done(self):
+ calls = []
+
+ def begin(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('begin')
+
+ def done(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('done')
+
+ self.http_app.request_begin = begin
+ self.http_app.request_done = done
+
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.app.get('/db0/doc/%s' % doc.doc_id)
+
+ self.assertEqual(['begin', 'done'], calls)
+
+ def test_bad_request(self):
+ calls = []
+
+ def begin(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('begin')
+
+ def bad_request(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('bad-request')
+
+ self.http_app.request_begin = begin
+ self.http_app.request_bad_request = bad_request
+ # shouldn't be called
+ self.http_app.request_done = lambda env: 1 / 0
+
+ resp = self.app.put('/db0/foo/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual(['begin', 'bad-request'], calls)
+
+
+class TestHTTPErrors(tests.TestCase):
+
+ def test_wire_description_to_status(self):
+ self.assertNotIn("error", http_errors.wire_description_to_status)
+
+
+class TestHTTPAppErrorHandling(tests.TestCase):
+
+ def setUp(self):
+ super(TestHTTPAppErrorHandling, self).setUp()
+ self.exc = None
+ self.state = tests.ServerStateForTests()
+
+ class ErroringResource(object):
+
+ def post(_, args, content):
+ raise self.exc
+
+ def lookup_resource(environ, responder):
+ return ErroringResource()
+
+ self.http_app = http_app.HTTPApp(self.state)
+ self.http_app._lookup_resource = lookup_resource
+ self.app = paste.fixture.TestApp(self.http_app)
+
+ def test_RevisionConflict_etc(self):
+ self.exc = errors.RevisionConflict()
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(409, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"error": "revision conflict"},
+ json.loads(resp.body))
+
+ def test_Unavailable(self):
+ self.exc = errors.Unavailable
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(503, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"error": "unavailable"},
+ json.loads(resp.body))
+
+ def test_generic_u1db_errors(self):
+ self.exc = errors.U1DBError()
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(500, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"error": "error"},
+ json.loads(resp.body))
+
+ def test_generic_u1db_errors_hooks(self):
+ calls = []
+
+ def begin(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('begin')
+
+ def u1db_error(environ, exc):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append(('error', exc))
+
+ self.http_app.request_begin = begin
+ self.http_app.request_u1db_error = u1db_error
+ # shouldn't be called
+ self.http_app.request_done = lambda env: 1 / 0
+
+ self.exc = errors.U1DBError()
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(500, resp.status)
+ self.assertEqual(['begin', ('error', self.exc)], calls)
+
+ def test_failure(self):
+ class Failure(Exception):
+ pass
+ self.exc = Failure()
+ self.assertRaises(Failure, self.app.post, '/req', params='{}',
+ headers={'content-type': 'application/json'})
+
+ def test_failure_hooks(self):
+ class Failure(Exception):
+ pass
+ calls = []
+
+ def begin(environ):
+ calls.append('begin')
+
+ def failed(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append(('failed', sys.exc_info()))
+
+ self.http_app.request_begin = begin
+ self.http_app.request_failed = failed
+ # shouldn't be called
+ self.http_app.request_done = lambda env: 1 / 0
+
+ self.exc = Failure()
+ self.assertRaises(Failure, self.app.post, '/req', params='{}',
+ headers={'content-type': 'application/json'})
+
+ self.assertEqual(2, len(calls))
+ self.assertEqual('begin', calls[0])
+ marker, (exc_type, exc, tb) = calls[1]
+ self.assertEqual('failed', marker)
+ self.assertEqual(self.exc, exc)
+
+
+class TestPluggableSyncExchange(tests.TestCase):
+
+ def setUp(self):
+ super(TestPluggableSyncExchange, self).setUp()
+ self.state = tests.ServerStateForTests()
+ self.state.ensure_database('foo')
+
+ def test_plugging(self):
+
+ class MySyncExchange(object):
+ def __init__(self, db, source_replica_uid, last_known_generation):
+ pass
+
+ class MySyncResource(http_app.SyncResource):
+ sync_exchange_class = MySyncExchange
+
+ sync_res = MySyncResource('foo', 'src', self.state, None)
+ sync_res.post_args(
+ {'last_known_generation': 0, 'last_known_trans_id': None}, '{}')
+ self.assertIsInstance(sync_res.sync_exch, MySyncExchange)
diff --git a/src/leap/soledad/u1db/tests/test_http_client.py b/src/leap/soledad/u1db/tests/test_http_client.py
new file mode 100644
index 00000000..115c8aaa
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_http_client.py
@@ -0,0 +1,361 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for HTTPDatabase"""
+
+from oauth import oauth
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ errors,
+ tests,
+ )
+from u1db.remote import (
+ http_client,
+ )
+
+
+class TestEncoder(tests.TestCase):
+
+ def test_encode_string(self):
+ self.assertEqual("foo", http_client._encode_query_parameter("foo"))
+
+ def test_encode_true(self):
+ self.assertEqual("true", http_client._encode_query_parameter(True))
+
+ def test_encode_false(self):
+ self.assertEqual("false", http_client._encode_query_parameter(False))
+
+
+class TestHTTPClientBase(tests.TestCaseWithServer):
+
+ def setUp(self):
+ super(TestHTTPClientBase, self).setUp()
+ self.errors = 0
+
+ def app(self, environ, start_response):
+ if environ['PATH_INFO'].endswith('echo'):
+ start_response("200 OK", [('Content-Type', 'application/json')])
+ ret = {}
+ for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'):
+ ret[name] = environ[name]
+ if environ['REQUEST_METHOD'] in ('PUT', 'POST'):
+ ret['CONTENT_TYPE'] = environ['CONTENT_TYPE']
+ content_length = int(environ['CONTENT_LENGTH'])
+ ret['body'] = environ['wsgi.input'].read(content_length)
+ return [json.dumps(ret)]
+ elif environ['PATH_INFO'].endswith('error_then_accept'):
+ if self.errors >= 3:
+ start_response(
+ "200 OK", [('Content-Type', 'application/json')])
+ ret = {}
+ for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'):
+ ret[name] = environ[name]
+ if environ['REQUEST_METHOD'] in ('PUT', 'POST'):
+ ret['CONTENT_TYPE'] = environ['CONTENT_TYPE']
+ content_length = int(environ['CONTENT_LENGTH'])
+ ret['body'] = '{"oki": "doki"}'
+ return [json.dumps(ret)]
+ self.errors += 1
+ content_length = int(environ['CONTENT_LENGTH'])
+ error = json.loads(
+ environ['wsgi.input'].read(content_length))
+ response = error['response']
+ # In debug mode, wsgiref has an assertion that the status parameter
+ # is a 'str' object. However error['status'] returns a unicode
+ # object.
+ status = str(error['status'])
+ if isinstance(response, unicode):
+ response = str(response)
+ if isinstance(response, str):
+ start_response(status, [('Content-Type', 'text/plain')])
+ return [str(response)]
+ else:
+ start_response(status, [('Content-Type', 'application/json')])
+ return [json.dumps(response)]
+ elif environ['PATH_INFO'].endswith('error'):
+ self.errors += 1
+ content_length = int(environ['CONTENT_LENGTH'])
+ error = json.loads(
+ environ['wsgi.input'].read(content_length))
+ response = error['response']
+ # In debug mode, wsgiref has an assertion that the status parameter
+ # is a 'str' object. However error['status'] returns a unicode
+ # object.
+ status = str(error['status'])
+ if isinstance(response, unicode):
+ response = str(response)
+ if isinstance(response, str):
+ start_response(status, [('Content-Type', 'text/plain')])
+ return [str(response)]
+ else:
+ start_response(status, [('Content-Type', 'application/json')])
+ return [json.dumps(response)]
+ elif '/oauth' in environ['PATH_INFO']:
+ base_url = self.getURL('').rstrip('/')
+ oauth_req = oauth.OAuthRequest.from_request(
+ http_method=environ['REQUEST_METHOD'],
+ http_url=base_url + environ['PATH_INFO'],
+ headers={'Authorization': environ['HTTP_AUTHORIZATION']},
+ query_string=environ['QUERY_STRING']
+ )
+ oauth_server = oauth.OAuthServer(tests.testingOAuthStore)
+ oauth_server.add_signature_method(tests.sign_meth_HMAC_SHA1)
+ try:
+ consumer, token, params = oauth_server.verify_request(
+ oauth_req)
+ except oauth.OAuthError, e:
+ start_response("401 Unauthorized",
+ [('Content-Type', 'application/json')])
+ return [json.dumps({"error": "unauthorized",
+ "message": e.message})]
+ start_response("200 OK", [('Content-Type', 'application/json')])
+ return [json.dumps([environ['PATH_INFO'], token.key, params])]
+
+ def make_app(self):
+ return self.app
+
+ def getClient(self, **kwds):
+ self.startServer()
+ return http_client.HTTPClientBase(self.getURL('dbase'), **kwds)
+
+ def test_construct(self):
+ self.startServer()
+ url = self.getURL()
+ cli = http_client.HTTPClientBase(url)
+ self.assertEqual(url, cli._url.geturl())
+ self.assertIs(None, cli._conn)
+
+ def test_parse_url(self):
+ cli = http_client.HTTPClientBase(
+ '%s://127.0.0.1:12345/' % self.url_scheme)
+ self.assertEqual(self.url_scheme, cli._url.scheme)
+ self.assertEqual('127.0.0.1', cli._url.hostname)
+ self.assertEqual(12345, cli._url.port)
+ self.assertEqual('/', cli._url.path)
+
+ def test__ensure_connection(self):
+ cli = self.getClient()
+ self.assertIs(None, cli._conn)
+ cli._ensure_connection()
+ self.assertIsNot(None, cli._conn)
+ conn = cli._conn
+ cli._ensure_connection()
+ self.assertIs(conn, cli._conn)
+
+ def test_close(self):
+ cli = self.getClient()
+ cli._ensure_connection()
+ cli.close()
+ self.assertIs(None, cli._conn)
+
+ def test__request(self):
+ cli = self.getClient()
+ res, headers = cli._request('PUT', ['echo'], {}, {})
+ self.assertEqual({'CONTENT_TYPE': 'application/json',
+ 'PATH_INFO': '/dbase/echo',
+ 'QUERY_STRING': '',
+ 'body': '{}',
+ 'REQUEST_METHOD': 'PUT'}, json.loads(res))
+
+ res, headers = cli._request('GET', ['doc', 'echo'], {'a': 1})
+ self.assertEqual({'PATH_INFO': '/dbase/doc/echo',
+ 'QUERY_STRING': 'a=1',
+ 'REQUEST_METHOD': 'GET'}, json.loads(res))
+
+ res, headers = cli._request('GET', ['doc', '%FFFF', 'echo'], {'a': 1})
+ self.assertEqual({'PATH_INFO': '/dbase/doc/%FFFF/echo',
+ 'QUERY_STRING': 'a=1',
+ 'REQUEST_METHOD': 'GET'}, json.loads(res))
+
+ res, headers = cli._request('POST', ['echo'], {'b': 2}, 'Body',
+ 'application/x-test')
+ self.assertEqual({'CONTENT_TYPE': 'application/x-test',
+ 'PATH_INFO': '/dbase/echo',
+ 'QUERY_STRING': 'b=2',
+ 'body': 'Body',
+ 'REQUEST_METHOD': 'POST'}, json.loads(res))
+
+ def test__request_json(self):
+ cli = self.getClient()
+ res, headers = cli._request_json(
+ 'POST', ['echo'], {'b': 2}, {'a': 'x'})
+ self.assertEqual('application/json', headers['content-type'])
+ self.assertEqual({'CONTENT_TYPE': 'application/json',
+ 'PATH_INFO': '/dbase/echo',
+ 'QUERY_STRING': 'b=2',
+ 'body': '{"a": "x"}',
+ 'REQUEST_METHOD': 'POST'}, res)
+
+ def test_unspecified_http_error(self):
+ cli = self.getClient()
+ self.assertRaises(errors.HTTPError,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "500 Internal Error",
+ 'response': "Crash."})
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "500 Internal Error",
+ 'response': "Fail."})
+ except errors.HTTPError, e:
+ pass
+
+ self.assertEqual(500, e.status)
+ self.assertEqual("Fail.", e.message)
+ self.assertTrue("content-type" in e.headers)
+
+ def test_revision_conflict(self):
+ cli = self.getClient()
+ self.assertRaises(errors.RevisionConflict,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "409 Conflict",
+ 'response': {"error": "revision conflict"}})
+
+ def test_unavailable_proper(self):
+ cli = self.getClient()
+ cli._delays = (0, 0, 0, 0, 0)
+ self.assertRaises(errors.Unavailable,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "503 Service Unavailable",
+ 'response': {"error": "unavailable"}})
+ self.assertEqual(5, self.errors)
+
+ def test_unavailable_then_available(self):
+ cli = self.getClient()
+ cli._delays = (0, 0, 0, 0, 0)
+ res, headers = cli._request_json(
+ 'POST', ['error_then_accept'], {'b': 2},
+ {'status': "503 Service Unavailable",
+ 'response': {"error": "unavailable"}})
+ self.assertEqual('application/json', headers['content-type'])
+ self.assertEqual({'CONTENT_TYPE': 'application/json',
+ 'PATH_INFO': '/dbase/error_then_accept',
+ 'QUERY_STRING': 'b=2',
+ 'body': '{"oki": "doki"}',
+ 'REQUEST_METHOD': 'POST'}, res)
+ self.assertEqual(3, self.errors)
+
+ def test_unavailable_random_source(self):
+ cli = self.getClient()
+ cli._delays = (0, 0, 0, 0, 0)
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "503 Service Unavailable",
+ 'response': "random unavailable."})
+ except errors.Unavailable, e:
+ pass
+
+ self.assertEqual(503, e.status)
+ self.assertEqual("random unavailable.", e.message)
+ self.assertTrue("content-type" in e.headers)
+ self.assertEqual(5, self.errors)
+
+ def test_document_too_big(self):
+ cli = self.getClient()
+ self.assertRaises(errors.DocumentTooBig,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "403 Forbidden",
+ 'response': {"error": "document too big"}})
+
+ def test_user_quota_exceeded(self):
+ cli = self.getClient()
+ self.assertRaises(errors.UserQuotaExceeded,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "403 Forbidden",
+ 'response': {"error": "user quota exceeded"}})
+
+ def test_user_needs_subscription(self):
+ cli = self.getClient()
+ self.assertRaises(errors.SubscriptionNeeded,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "403 Forbidden",
+ 'response': {"error": "user needs subscription"}})
+
+ def test_generic_u1db_error(self):
+ cli = self.getClient()
+ self.assertRaises(errors.U1DBError,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': {"error": "error"}})
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': {"error": "error"}})
+ except errors.U1DBError, e:
+ pass
+ self.assertIs(e.__class__, errors.U1DBError)
+
+ def test_unspecified_bad_request(self):
+ cli = self.getClient()
+ self.assertRaises(errors.HTTPError,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': "<Bad Request>"})
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': "<Bad Request>"})
+ except errors.HTTPError, e:
+ pass
+
+ self.assertEqual(400, e.status)
+ self.assertEqual("<Bad Request>", e.message)
+ self.assertTrue("content-type" in e.headers)
+
+ def test_oauth(self):
+ cli = self.getClient()
+ cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ params = {'x': u'\xf0', 'y': "foo"}
+ res, headers = cli._request('GET', ['doc', 'oauth'], params)
+ self.assertEqual(
+ ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res))
+
+ # oauth does its own internal quoting
+ params = {'x': u'\xf0', 'y': "foo"}
+ res, headers = cli._request('GET', ['doc', 'oauth', 'foo bar'], params)
+ self.assertEqual(
+ ['/dbase/doc/oauth/foo bar', tests.token1.key, params],
+ json.loads(res))
+
+ def test_oauth_ctr_creds(self):
+ cli = self.getClient(creds={'oauth': {
+ 'consumer_key': tests.consumer1.key,
+ 'consumer_secret': tests.consumer1.secret,
+ 'token_key': tests.token1.key,
+ 'token_secret': tests.token1.secret,
+ }})
+ params = {'x': u'\xf0', 'y': "foo"}
+ res, headers = cli._request('GET', ['doc', 'oauth'], params)
+ self.assertEqual(
+ ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res))
+
+ def test_unknown_creds(self):
+ self.assertRaises(errors.UnknownAuthMethod,
+ self.getClient, creds={'foo': {}})
+ self.assertRaises(errors.UnknownAuthMethod,
+ self.getClient, creds={})
+
+ def test_oauth_Unauthorized(self):
+ cli = self.getClient()
+ cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, "WRONG")
+ params = {'y': 'foo'}
+ self.assertRaises(errors.Unauthorized, cli._request, 'GET',
+ ['doc', 'oauth'], params)
diff --git a/src/leap/soledad/u1db/tests/test_http_database.py b/src/leap/soledad/u1db/tests/test_http_database.py
new file mode 100644
index 00000000..c8e7eb76
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_http_database.py
@@ -0,0 +1,256 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for HTTPDatabase"""
+
+import inspect
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ errors,
+ Document,
+ tests,
+ )
+from u1db.remote import (
+ http_database,
+ http_target,
+ )
+from u1db.tests.test_remote_sync_target import (
+ make_http_app,
+)
+
+
+class TestHTTPDatabaseSimpleOperations(tests.TestCase):
+
+ def setUp(self):
+ super(TestHTTPDatabaseSimpleOperations, self).setUp()
+ self.db = http_database.HTTPDatabase('dbase')
+ self.db._conn = object() # crash if used
+ self.got = None
+ self.response_val = None
+
+ def _request(method, url_parts, params=None, body=None,
+ content_type=None):
+ self.got = method, url_parts, params, body, content_type
+ if isinstance(self.response_val, Exception):
+ raise self.response_val
+ return self.response_val
+
+ def _request_json(method, url_parts, params=None, body=None,
+ content_type=None):
+ self.got = method, url_parts, params, body, content_type
+ if isinstance(self.response_val, Exception):
+ raise self.response_val
+ return self.response_val
+
+ self.db._request = _request
+ self.db._request_json = _request_json
+
+ def test__sanity_same_signature(self):
+ my_request_sig = inspect.getargspec(self.db._request)
+ my_request_sig = (['self'] + my_request_sig[0],) + my_request_sig[1:]
+ self.assertEqual(my_request_sig,
+ inspect.getargspec(http_database.HTTPDatabase._request))
+ my_request_json_sig = inspect.getargspec(self.db._request_json)
+ my_request_json_sig = ((['self'] + my_request_json_sig[0],) +
+ my_request_json_sig[1:])
+ self.assertEqual(my_request_json_sig,
+ inspect.getargspec(http_database.HTTPDatabase._request_json))
+
+ def test__ensure(self):
+ self.response_val = {'ok': True}, {}
+ self.db._ensure()
+ self.assertEqual(('PUT', [], {}, {}, None), self.got)
+
+ def test__delete(self):
+ self.response_val = {'ok': True}, {}
+ self.db._delete()
+ self.assertEqual(('DELETE', [], {}, {}, None), self.got)
+
+ def test__check(self):
+ self.response_val = {}, {}
+ res = self.db._check()
+ self.assertEqual({}, res)
+ self.assertEqual(('GET', [], None, None, None), self.got)
+
+ def test_put_doc(self):
+ self.response_val = {'rev': 'doc-rev'}, {}
+ doc = Document('doc-id', None, '{"v": 1}')
+ res = self.db.put_doc(doc)
+ self.assertEqual('doc-rev', res)
+ self.assertEqual('doc-rev', doc.rev)
+ self.assertEqual(('PUT', ['doc', 'doc-id'], {},
+ '{"v": 1}', 'application/json'), self.got)
+
+ self.response_val = {'rev': 'doc-rev-2'}, {}
+ doc.content = {"v": 2}
+ res = self.db.put_doc(doc)
+ self.assertEqual('doc-rev-2', res)
+ self.assertEqual('doc-rev-2', doc.rev)
+ self.assertEqual(('PUT', ['doc', 'doc-id'], {'old_rev': 'doc-rev'},
+ '{"v": 2}', 'application/json'), self.got)
+
+ def test_get_doc(self):
+ self.response_val = '{"v": 2}', {'x-u1db-rev': 'doc-rev',
+ 'x-u1db-has-conflicts': 'false'}
+ self.assertGetDoc(self.db, 'doc-id', 'doc-rev', '{"v": 2}', False)
+ self.assertEqual(
+ ('GET', ['doc', 'doc-id'], {'include_deleted': False}, None, None),
+ self.got)
+
+ def test_get_doc_non_existing(self):
+ self.response_val = errors.DocumentDoesNotExist()
+ self.assertIs(None, self.db.get_doc('not-there'))
+ self.assertEqual(
+ ('GET', ['doc', 'not-there'], {'include_deleted': False}, None,
+ None), self.got)
+
+ def test_get_doc_deleted(self):
+ self.response_val = errors.DocumentDoesNotExist()
+ self.assertIs(None, self.db.get_doc('deleted'))
+ self.assertEqual(
+ ('GET', ['doc', 'deleted'], {'include_deleted': False}, None,
+ None), self.got)
+
+ def test_get_doc_deleted_include_deleted(self):
+ self.response_val = errors.HTTPError(404,
+ json.dumps(
+ {"error": errors.DOCUMENT_DELETED}
+ ),
+ {'x-u1db-rev': 'doc-rev-gone',
+ 'x-u1db-has-conflicts': 'false'})
+ doc = self.db.get_doc('deleted', include_deleted=True)
+ self.assertEqual('deleted', doc.doc_id)
+ self.assertEqual('doc-rev-gone', doc.rev)
+ self.assertIs(None, doc.content)
+ self.assertEqual(
+ ('GET', ['doc', 'deleted'], {'include_deleted': True}, None, None),
+ self.got)
+
+ def test_get_doc_pass_through_errors(self):
+ self.response_val = errors.HTTPError(500, 'Crash.')
+ self.assertRaises(errors.HTTPError,
+ self.db.get_doc, 'something-something')
+
+ def test_create_doc_with_id(self):
+ self.response_val = {'rev': 'doc-rev'}, {}
+ new_doc = self.db.create_doc_from_json('{"v": 1}', doc_id='doc-id')
+ self.assertEqual('doc-rev', new_doc.rev)
+ self.assertEqual('doc-id', new_doc.doc_id)
+ self.assertEqual('{"v": 1}', new_doc.get_json())
+ self.assertEqual(('PUT', ['doc', 'doc-id'], {},
+ '{"v": 1}', 'application/json'), self.got)
+
+ def test_create_doc_without_id(self):
+ self.response_val = {'rev': 'doc-rev-2'}, {}
+ new_doc = self.db.create_doc_from_json('{"v": 3}')
+ self.assertEqual('D-', new_doc.doc_id[:2])
+ self.assertEqual('doc-rev-2', new_doc.rev)
+ self.assertEqual('{"v": 3}', new_doc.get_json())
+ self.assertEqual(('PUT', ['doc', new_doc.doc_id], {},
+ '{"v": 3}', 'application/json'), self.got)
+
+ def test_delete_doc(self):
+ self.response_val = {'rev': 'doc-rev-gone'}, {}
+ doc = Document('doc-id', 'doc-rev', None)
+ self.db.delete_doc(doc)
+ self.assertEqual('doc-rev-gone', doc.rev)
+ self.assertEqual(('DELETE', ['doc', 'doc-id'], {'old_rev': 'doc-rev'},
+ None, None), self.got)
+
+ def test_get_sync_target(self):
+ st = self.db.get_sync_target()
+ self.assertIsInstance(st, http_target.HTTPSyncTarget)
+ self.assertEqual(st._url, self.db._url)
+
+ def test_get_sync_target_inherits_oauth_credentials(self):
+ self.db.set_oauth_credentials(tests.consumer1.key,
+ tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ st = self.db.get_sync_target()
+ self.assertEqual(self.db._creds, st._creds)
+
+
+class TestHTTPDatabaseCtrWithCreds(tests.TestCase):
+
+ def test_ctr_with_creds(self):
+ db1 = http_database.HTTPDatabase('http://dbs/db', creds={'oauth': {
+ 'consumer_key': tests.consumer1.key,
+ 'consumer_secret': tests.consumer1.secret,
+ 'token_key': tests.token1.key,
+ 'token_secret': tests.token1.secret
+ }})
+ self.assertIn('oauth', db1._creds)
+
+
+class TestHTTPDatabaseIntegration(tests.TestCaseWithServer):
+
+ make_app_with_state = staticmethod(make_http_app)
+
+ def setUp(self):
+ super(TestHTTPDatabaseIntegration, self).setUp()
+ self.startServer()
+
+ def test_non_existing_db(self):
+ db = http_database.HTTPDatabase(self.getURL('not-there'))
+ self.assertRaises(errors.DatabaseDoesNotExist, db.get_doc, 'doc1')
+
+ def test__ensure(self):
+ db = http_database.HTTPDatabase(self.getURL('new'))
+ db._ensure()
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test__delete(self):
+ self.request_state._create_database('db0')
+ db = http_database.HTTPDatabase(self.getURL('db0'))
+ db._delete()
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.request_state.check_database, 'db0')
+
+ def test_open_database_existing(self):
+ self.request_state._create_database('db0')
+ db = http_database.HTTPDatabase.open_database(self.getURL('db0'),
+ create=False)
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test_open_database_non_existing(self):
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ http_database.HTTPDatabase.open_database,
+ self.getURL('not-there'),
+ create=False)
+
+ def test_open_database_create(self):
+ db = http_database.HTTPDatabase.open_database(self.getURL('new'),
+ create=True)
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test_delete_database_existing(self):
+ self.request_state._create_database('db0')
+ http_database.HTTPDatabase.delete_database(self.getURL('db0'))
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.request_state.check_database, 'db0')
+
+ def test_doc_ids_needing_quoting(self):
+ db0 = self.request_state._create_database('db0')
+ db = http_database.HTTPDatabase.open_database(self.getURL('db0'),
+ create=False)
+ doc = Document('%fff', None, '{}')
+ db.put_doc(doc)
+ self.assertGetDoc(db0, '%fff', doc.rev, '{}', False)
+ self.assertGetDoc(db, '%fff', doc.rev, '{}', False)
diff --git a/src/leap/soledad/u1db/tests/test_https.py b/src/leap/soledad/u1db/tests/test_https.py
new file mode 100644
index 00000000..67681c8a
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_https.py
@@ -0,0 +1,117 @@
+"""Test support for client-side https support."""
+
+import os
+import ssl
+import sys
+
+from paste import httpserver
+
+from u1db import (
+ tests,
+ )
+from u1db.remote import (
+ http_client,
+ http_target,
+ )
+
+from u1db.tests.test_remote_sync_target import (
+ make_oauth_http_app,
+ )
+
+
+def https_server_def():
+ def make_server(host_port, application):
+ from OpenSSL import SSL
+ cert_file = os.path.join(os.path.dirname(__file__), 'testing-certs',
+ 'testing.cert')
+ key_file = os.path.join(os.path.dirname(__file__), 'testing-certs',
+ 'testing.key')
+ ssl_context = SSL.Context(SSL.SSLv23_METHOD)
+ ssl_context.use_privatekey_file(key_file)
+ ssl_context.use_certificate_chain_file(cert_file)
+ srv = httpserver.WSGIServerBase(application, host_port,
+ httpserver.WSGIHandler,
+ ssl_context=ssl_context
+ )
+
+ def shutdown_request(req):
+ req.shutdown()
+ srv.close_request(req)
+
+ srv.shutdown_request = shutdown_request
+ application.base_url = "https://localhost:%s" % srv.server_address[1]
+ return srv
+ return make_server, "shutdown", "https"
+
+
+def oauth_https_sync_target(test, host, path):
+ _, port = test.server.server_address
+ st = http_target.HTTPSyncTarget('https://%s:%d/~/%s' % (host, port, path))
+ st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return st
+
+
+class TestHttpSyncTargetHttpsSupport(tests.TestCaseWithServer):
+
+ scenarios = [
+ ('oauth_https', {'server_def': https_server_def,
+ 'make_app_with_state': make_oauth_http_app,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'sync_target': oauth_https_sync_target
+ }),
+ ]
+
+ def setUp(self):
+ try:
+ import OpenSSL # noqa
+ except ImportError:
+ self.skipTest("Requires pyOpenSSL")
+ self.cacert_pem = os.path.join(os.path.dirname(__file__),
+ 'testing-certs', 'cacert.pem')
+ super(TestHttpSyncTargetHttpsSupport, self).setUp()
+
+ def getSyncTarget(self, host, path=None):
+ if self.server is None:
+ self.startServer()
+ return self.sync_target(self, host, path)
+
+ def test_working(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ self.patch(http_client, 'CA_CERTS', self.cacert_pem)
+ remote_target = self.getSyncTarget('localhost', 'test')
+ remote_target.record_sync_info('other-id', 2, 'T-id')
+ self.assertEqual(
+ (2, 'T-id'), db._get_replica_gen_and_trans_id('other-id'))
+
+ def test_cannot_verify_cert(self):
+ if not sys.platform.startswith('linux'):
+ self.skipTest(
+ "XXX certificate verification happens on linux only for now")
+ self.startServer()
+ # don't print expected traceback server-side
+ self.server.handle_error = lambda req, cli_addr: None
+ self.request_state._create_database('test')
+ remote_target = self.getSyncTarget('localhost', 'test')
+ try:
+ remote_target.record_sync_info('other-id', 2, 'T-id')
+ except ssl.SSLError, e:
+ self.assertIn("certificate verify failed", str(e))
+ else:
+ self.fail("certificate verification should have failed.")
+
+ def test_host_mismatch(self):
+ if not sys.platform.startswith('linux'):
+ self.skipTest(
+ "XXX certificate verification happens on linux only for now")
+ self.startServer()
+ self.request_state._create_database('test')
+ self.patch(http_client, 'CA_CERTS', self.cacert_pem)
+ remote_target = self.getSyncTarget('127.0.0.1', 'test')
+ self.assertRaises(
+ http_client.CertificateError, remote_target.record_sync_info,
+ 'other-id', 2, 'T-id')
+
+
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/u1db/tests/test_inmemory.py b/src/leap/soledad/u1db/tests/test_inmemory.py
new file mode 100644
index 00000000..255a1e08
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_inmemory.py
@@ -0,0 +1,128 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test in-memory backend internals."""
+
+from u1db import (
+ errors,
+ tests,
+ )
+from u1db.backends import inmemory
+
+
+simple_doc = '{"key": "value"}'
+
+
+class TestInMemoryDatabaseInternals(tests.TestCase):
+
+ def setUp(self):
+ super(TestInMemoryDatabaseInternals, self).setUp()
+ self.db = inmemory.InMemoryDatabase('test')
+
+ def test__allocate_doc_rev_from_None(self):
+ self.assertEqual('test:1', self.db._allocate_doc_rev(None))
+
+ def test__allocate_doc_rev_incremental(self):
+ self.assertEqual('test:2', self.db._allocate_doc_rev('test:1'))
+
+ def test__allocate_doc_rev_other(self):
+ self.assertEqual('replica:1|test:1',
+ self.db._allocate_doc_rev('replica:1'))
+
+ def test__get_replica_uid(self):
+ self.assertEqual('test', self.db._replica_uid)
+
+
+class TestInMemoryIndex(tests.TestCase):
+
+ def test_has_name_and_definition(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ self.assertEqual('idx-name', idx._name)
+ self.assertEqual(['key'], idx._definition)
+
+ def test_evaluate_json(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ self.assertEqual(['value'], idx.evaluate_json(simple_doc))
+
+ def test_evaluate_json_field_None(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['missing'])
+ self.assertEqual([], idx.evaluate_json(simple_doc))
+
+ def test_evaluate_json_subfield_None(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key', 'missing'])
+ self.assertEqual([], idx.evaluate_json(simple_doc))
+
+ def test_evaluate_multi_index(self):
+ doc = '{"key": "value", "key2": "value2"}'
+ idx = inmemory.InMemoryIndex('idx-name', ['key', 'key2'])
+ self.assertEqual(['value\x01value2'],
+ idx.evaluate_json(doc))
+
+ def test_update_ignores_None(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['nokey'])
+ idx.add_json('doc-id', simple_doc)
+ self.assertEqual({}, idx._values)
+
+ def test_update_adds_entry(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ self.assertEqual({'value': ['doc-id']}, idx._values)
+
+ def test_remove_json(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ self.assertEqual({'value': ['doc-id']}, idx._values)
+ idx.remove_json('doc-id', simple_doc)
+ self.assertEqual({}, idx._values)
+
+ def test_remove_json_multiple(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ idx.add_json('doc2-id', simple_doc)
+ self.assertEqual({'value': ['doc-id', 'doc2-id']}, idx._values)
+ idx.remove_json('doc-id', simple_doc)
+ self.assertEqual({'value': ['doc2-id']}, idx._values)
+
+ def test_keys(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ self.assertEqual(['value'], idx.keys())
+
+ def test_lookup(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ self.assertEqual(['doc-id'], idx.lookup(['value']))
+
+ def test_lookup_multi(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['key'])
+ idx.add_json('doc-id', simple_doc)
+ idx.add_json('doc2-id', simple_doc)
+ self.assertEqual(['doc-id', 'doc2-id'], idx.lookup(['value']))
+
+ def test__find_non_wildcards(self):
+ idx = inmemory.InMemoryIndex('idx-name', ['k1', 'k2', 'k3'])
+ self.assertEqual(-1, idx._find_non_wildcards(('a', 'b', 'c')))
+ self.assertEqual(2, idx._find_non_wildcards(('a', 'b', '*')))
+ self.assertEqual(3, idx._find_non_wildcards(('a', 'b', 'c*')))
+ self.assertEqual(2, idx._find_non_wildcards(('a', 'b*', '*')))
+ self.assertEqual(0, idx._find_non_wildcards(('*', '*', '*')))
+ self.assertEqual(1, idx._find_non_wildcards(('a*', '*', '*')))
+ self.assertRaises(errors.InvalidValueForIndex,
+ idx._find_non_wildcards, ('a', 'b'))
+ self.assertRaises(errors.InvalidValueForIndex,
+ idx._find_non_wildcards, ('a', 'b', 'c', 'd'))
+ self.assertRaises(errors.InvalidGlobbing,
+ idx._find_non_wildcards, ('*', 'b', 'c'))
diff --git a/src/leap/soledad/u1db/tests/test_open.py b/src/leap/soledad/u1db/tests/test_open.py
new file mode 100644
index 00000000..fbeb0cfd
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_open.py
@@ -0,0 +1,69 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test u1db.open"""
+
+import os
+
+from u1db import (
+ errors,
+ open as u1db_open,
+ tests,
+ )
+from u1db.backends import sqlite_backend
+from u1db.tests.test_backends import TestAlternativeDocument
+
+
+class TestU1DBOpen(tests.TestCase):
+
+ def setUp(self):
+ super(TestU1DBOpen, self).setUp()
+ tmpdir = self.createTempDir()
+ self.db_path = tmpdir + '/test.db'
+
+ def test_open_no_create(self):
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ u1db_open, self.db_path, create=False)
+ self.assertFalse(os.path.exists(self.db_path))
+
+ def test_open_create(self):
+ db = u1db_open(self.db_path, create=True)
+ self.addCleanup(db.close)
+ self.assertTrue(os.path.exists(self.db_path))
+ self.assertIsInstance(db, sqlite_backend.SQLiteDatabase)
+
+ def test_open_with_factory(self):
+ db = u1db_open(self.db_path, create=True,
+ document_factory=TestAlternativeDocument)
+ self.addCleanup(db.close)
+ self.assertEqual(TestAlternativeDocument, db._factory)
+
+ def test_open_existing(self):
+ db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path)
+ self.addCleanup(db.close)
+ doc = db.create_doc_from_json(tests.simple_doc)
+ # Even though create=True, we shouldn't wipe the db
+ db2 = u1db_open(self.db_path, create=True)
+ self.addCleanup(db2.close)
+ doc2 = db2.get_doc(doc.doc_id)
+ self.assertEqual(doc, doc2)
+
+ def test_open_existing_no_create(self):
+ db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path)
+ self.addCleanup(db.close)
+ db2 = u1db_open(self.db_path, create=False)
+ self.addCleanup(db2.close)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
diff --git a/src/leap/soledad/u1db/tests/test_query_parser.py b/src/leap/soledad/u1db/tests/test_query_parser.py
new file mode 100644
index 00000000..ee374267
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_query_parser.py
@@ -0,0 +1,443 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+from u1db import (
+ errors,
+ query_parser,
+ tests,
+ )
+
+
+trivial_raw_doc = {}
+
+
+class TestFieldName(tests.TestCase):
+
+ def test_check_fieldname_valid(self):
+ self.assertIsNone(query_parser.check_fieldname("foo"))
+
+ def test_check_fieldname_invalid(self):
+ self.assertRaises(
+ errors.IndexDefinitionParseError, query_parser.check_fieldname,
+ "foo.")
+
+
+class TestMakeTree(tests.TestCase):
+
+ def setUp(self):
+ super(TestMakeTree, self).setUp()
+ self.parser = query_parser.Parser()
+
+ def assertParseError(self, definition):
+ self.assertRaises(
+ errors.IndexDefinitionParseError, self.parser.parse,
+ definition)
+
+ def test_single_field(self):
+ self.assertIsInstance(
+ self.parser.parse('f'), query_parser.ExtractField)
+
+ def test_single_mapping(self):
+ self.assertIsInstance(
+ self.parser.parse('bool(field1)'), query_parser.Bool)
+
+ def test_nested_mapping(self):
+ self.assertIsInstance(
+ self.parser.parse('lower(split_words(field1))'),
+ query_parser.Lower)
+
+ def test_nested_branching_mapping(self):
+ self.assertIsInstance(
+ self.parser.parse(
+ 'combine(lower(field1), split_words(field2), '
+ 'number(field3, 5))'), query_parser.Combine)
+
+ def test_single_mapping_multiple_fields(self):
+ self.assertIsInstance(
+ self.parser.parse('number(field1, 5)'), query_parser.Number)
+
+ def test_unknown_mapping(self):
+ self.assertParseError('mapping(whatever)')
+
+ def test_parse_missing_close_paren(self):
+ self.assertParseError("lower(a")
+
+ def test_parse_trailing_chars(self):
+ self.assertParseError("lower(ab))")
+
+ def test_parse_empty_op(self):
+ self.assertParseError("(ab)")
+
+ def test_parse_top_level_commas(self):
+ self.assertParseError("a, b")
+
+ def test_invalid_field_name(self):
+ self.assertParseError("a.")
+
+ def test_invalid_inner_field_name(self):
+ self.assertParseError("lower(a.)")
+
+ def test_gobbledigook(self):
+ self.assertParseError("(@#@cc @#!*DFJSXV(()jccd")
+
+ def test_leading_space(self):
+ self.assertIsInstance(
+ self.parser.parse(" lower(a)"), query_parser.Lower)
+
+ def test_trailing_space(self):
+ self.assertIsInstance(
+ self.parser.parse("lower(a) "), query_parser.Lower)
+
+ def test_spaces_before_open_paren(self):
+ self.assertIsInstance(
+ self.parser.parse("lower (a)"), query_parser.Lower)
+
+ def test_spaces_after_open_paren(self):
+ self.assertIsInstance(
+ self.parser.parse("lower( a)"), query_parser.Lower)
+
+ def test_spaces_before_close_paren(self):
+ self.assertIsInstance(
+ self.parser.parse("lower(a )"), query_parser.Lower)
+
+ def test_spaces_before_comma(self):
+ self.assertIsInstance(
+ self.parser.parse("number(a , 5)"), query_parser.Number)
+
+ def test_spaces_after_comma(self):
+ self.assertIsInstance(
+ self.parser.parse("number(a, 5)"), query_parser.Number)
+
+
+class TestStaticGetter(tests.TestCase):
+
+ def test_returns_string(self):
+ getter = query_parser.StaticGetter('foo')
+ self.assertEqual(['foo'], getter.get(trivial_raw_doc))
+
+ def test_returns_int(self):
+ getter = query_parser.StaticGetter(9)
+ self.assertEqual([9], getter.get(trivial_raw_doc))
+
+ def test_returns_float(self):
+ getter = query_parser.StaticGetter(9.2)
+ self.assertEqual([9.2], getter.get(trivial_raw_doc))
+
+ def test_returns_None(self):
+ getter = query_parser.StaticGetter(None)
+ self.assertEqual([], getter.get(trivial_raw_doc))
+
+ def test_returns_list(self):
+ getter = query_parser.StaticGetter(['a', 'b'])
+ self.assertEqual(['a', 'b'], getter.get(trivial_raw_doc))
+
+
+class TestExtractField(tests.TestCase):
+
+ def assertExtractField(self, expected, field_name, raw_doc):
+ getter = query_parser.ExtractField(field_name)
+ self.assertEqual(expected, getter.get(raw_doc))
+
+ def test_get_value(self):
+ self.assertExtractField(['bar'], 'foo', {'foo': 'bar'})
+
+ def test_get_value_None(self):
+ self.assertExtractField([], 'foo', {'foo': None})
+
+ def test_get_value_missing_key(self):
+ self.assertExtractField([], 'foo', {})
+
+ def test_get_value_subfield(self):
+ self.assertExtractField(['bar'], 'foo.baz', {'foo': {'baz': 'bar'}})
+
+ def test_get_value_subfield_missing(self):
+ self.assertExtractField([], 'foo.baz', {'foo': 'bar'})
+
+ def test_get_value_dict(self):
+ self.assertExtractField([], 'foo', {'foo': {'baz': 'bar'}})
+
+ def test_get_value_list(self):
+ self.assertExtractField(['bar', 'zap'], 'foo', {'foo': ['bar', 'zap']})
+
+ def test_get_value_mixed_list(self):
+ self.assertExtractField(['bar', 'zap'], 'foo',
+ {'foo': ['bar', ['baa'], 'zap', {'bing': 9}]})
+
+ def test_get_value_list_of_dicts(self):
+ self.assertExtractField([], 'foo', {'foo': [{'zap': 'bar'}]})
+
+ def test_get_value_list_of_dicts2(self):
+ self.assertExtractField(
+ ['bar', 'baz'], 'foo.zap',
+ {'foo': [{'zap': 'bar'}, {'zap': 'baz'}]})
+
+ def test_get_value_int(self):
+ self.assertExtractField([9], 'foo', {'foo': 9})
+
+ def test_get_value_float(self):
+ self.assertExtractField([9.2], 'foo', {'foo': 9.2})
+
+ def test_get_value_bool(self):
+ self.assertExtractField([True], 'foo', {'foo': True})
+ self.assertExtractField([False], 'foo', {'foo': False})
+
+
+class TestLower(tests.TestCase):
+
+ def assertLowerGets(self, expected, input_val):
+ getter = query_parser.Lower(query_parser.StaticGetter(input_val))
+ out_val = getter.get(trivial_raw_doc)
+ self.assertEqual(sorted(expected), sorted(out_val))
+
+ def test_inner_returns_None(self):
+ self.assertLowerGets([], None)
+
+ def test_inner_returns_string(self):
+ self.assertLowerGets(['foo'], 'fOo')
+
+ def test_inner_returns_list(self):
+ self.assertLowerGets(['foo', 'bar'], ['fOo', 'bAr'])
+
+ def test_inner_returns_int(self):
+ self.assertLowerGets([], 9)
+
+ def test_inner_returns_float(self):
+ self.assertLowerGets([], 9.0)
+
+ def test_inner_returns_bool(self):
+ self.assertLowerGets([], True)
+
+ def test_inner_returns_list_containing_int(self):
+ self.assertLowerGets(['foo', 'bar'], ['fOo', 9, 'bAr'])
+
+ def test_inner_returns_list_containing_float(self):
+ self.assertLowerGets(['foo', 'bar'], ['fOo', 9.2, 'bAr'])
+
+ def test_inner_returns_list_containing_bool(self):
+ self.assertLowerGets(['foo', 'bar'], ['fOo', True, 'bAr'])
+
+ def test_inner_returns_list_containing_list(self):
+ # TODO: Should this be unfolding the inner list?
+ self.assertLowerGets(['foo', 'bar'], ['fOo', ['bAa'], 'bAr'])
+
+ def test_inner_returns_list_containing_dict(self):
+ self.assertLowerGets(['foo', 'bar'], ['fOo', {'baa': 'xam'}, 'bAr'])
+
+
+class TestSplitWords(tests.TestCase):
+
+ def assertSplitWords(self, expected, value):
+ getter = query_parser.SplitWords(query_parser.StaticGetter(value))
+ self.assertEqual(sorted(expected), sorted(getter.get(trivial_raw_doc)))
+
+ def test_inner_returns_None(self):
+ self.assertSplitWords([], None)
+
+ def test_inner_returns_string(self):
+ self.assertSplitWords(['foo', 'bar'], 'foo bar')
+
+ def test_inner_returns_list(self):
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', 'bar sux'])
+
+ def test_deduplicates(self):
+ self.assertSplitWords(['bar'], ['bar', 'bar', 'bar'])
+
+ def test_inner_returns_int(self):
+ self.assertSplitWords([], 9)
+
+ def test_inner_returns_float(self):
+ self.assertSplitWords([], 9.2)
+
+ def test_inner_returns_bool(self):
+ self.assertSplitWords([], True)
+
+ def test_inner_returns_list_containing_int(self):
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', 9, 'bar sux'])
+
+ def test_inner_returns_list_containing_float(self):
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', 9.2, 'bar sux'])
+
+ def test_inner_returns_list_containing_bool(self):
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', True, 'bar sux'])
+
+ def test_inner_returns_list_containing_list(self):
+ # TODO: Expand sub-lists?
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', ['baa'], 'bar sux'])
+
+ def test_inner_returns_list_containing_dict(self):
+ self.assertSplitWords(['foo', 'baz', 'bar', 'sux'],
+ ['foo baz', {'baa': 'xam'}, 'bar sux'])
+
+
+class TestNumber(tests.TestCase):
+
+ def assertNumber(self, expected, value, padding=5):
+ """Assert number transformation produced expected values."""
+ getter = query_parser.Number(query_parser.StaticGetter(value), padding)
+ self.assertEqual(expected, getter.get(trivial_raw_doc))
+
+ def test_inner_returns_None(self):
+ """None is thrown away."""
+ self.assertNumber([], None)
+
+ def test_inner_returns_int(self):
+ """A single integer is converted to zero padded strings."""
+ self.assertNumber(['00009'], 9)
+
+ def test_inner_returns_list(self):
+ """Integers are converted to zero padded strings."""
+ self.assertNumber(['00009', '00235'], [9, 235])
+
+ def test_inner_returns_string(self):
+ """A string is thrown away."""
+ self.assertNumber([], 'foo bar')
+
+ def test_inner_returns_float(self):
+ """A float is thrown away."""
+ self.assertNumber([], 9.2)
+
+ def test_inner_returns_bool(self):
+ """A boolean is thrown away."""
+ self.assertNumber([], True)
+
+ def test_inner_returns_list_containing_strings(self):
+ """Strings in a list are thrown away."""
+ self.assertNumber(['00009'], ['foo baz', 9, 'bar sux'])
+
+ def test_inner_returns_list_containing_float(self):
+ """Floats in a list are thrown away."""
+ self.assertNumber(
+ ['00083', '00073'], [83, 9.2, 73])
+
+ def test_inner_returns_list_containing_bool(self):
+ """Booleans in a list are thrown away."""
+ self.assertNumber(
+ ['00083', '00073'], [83, True, 73])
+
+ def test_inner_returns_list_containing_list(self):
+ """Lists in a list are thrown away."""
+ # TODO: Expand sub-lists?
+ self.assertNumber(
+ ['00012', '03333'], [12, [29], 3333])
+
+ def test_inner_returns_list_containing_dict(self):
+ """Dicts in a list are thrown away."""
+ self.assertNumber(
+ ['00012', '00001'], [12, {54: 89}, 1])
+
+
+class TestIsNull(tests.TestCase):
+
+ def assertIsNull(self, value):
+ getter = query_parser.IsNull(query_parser.StaticGetter(value))
+ self.assertEqual([True], getter.get(trivial_raw_doc))
+
+ def assertIsNotNull(self, value):
+ getter = query_parser.IsNull(query_parser.StaticGetter(value))
+ self.assertEqual([False], getter.get(trivial_raw_doc))
+
+ def test_inner_returns_None(self):
+ self.assertIsNull(None)
+
+ def test_inner_returns_string(self):
+ self.assertIsNotNull('foo')
+
+ def test_inner_returns_list(self):
+ self.assertIsNotNull(['foo', 'bar'])
+
+ def test_inner_returns_empty_list(self):
+ # TODO: is this the behavior we want?
+ self.assertIsNull([])
+
+ def test_inner_returns_int(self):
+ self.assertIsNotNull(9)
+
+ def test_inner_returns_float(self):
+ self.assertIsNotNull(9.2)
+
+ def test_inner_returns_bool(self):
+ self.assertIsNotNull(True)
+
+ # TODO: What about a dict? Inner is likely to return None, even though the
+ # attribute does exist...
+
+
+class TestParser(tests.TestCase):
+
+ def parse(self, spec):
+ parser = query_parser.Parser()
+ return parser.parse(spec)
+
+ def parse_all(self, specs):
+ parser = query_parser.Parser()
+ return parser.parse_all(specs)
+
+ def assertParseError(self, definition):
+ self.assertRaises(errors.IndexDefinitionParseError, self.parse,
+ definition)
+
+ def test_parse_empty_string(self):
+ self.assertRaises(errors.IndexDefinitionParseError, self.parse, "")
+
+ def test_parse_field(self):
+ getter = self.parse("a")
+ self.assertIsInstance(getter, query_parser.ExtractField)
+ self.assertEqual(["a"], getter.field)
+
+ def test_parse_dotted_field(self):
+ getter = self.parse("a.b")
+ self.assertIsInstance(getter, query_parser.ExtractField)
+ self.assertEqual(["a", "b"], getter.field)
+
+ def test_parse_dotted_field_nothing_after_dot(self):
+ self.assertParseError("a.")
+
+ def test_parse_missing_close_on_transformation(self):
+ self.assertParseError("lower(a")
+
+ def test_parse_missing_field_in_transformation(self):
+ self.assertParseError("lower()")
+
+ def test_parse_trailing_chars(self):
+ self.assertParseError("lower(ab))")
+
+ def test_parse_empty_op(self):
+ self.assertParseError("(ab)")
+
+ def test_parse_unknown_op(self):
+ self.assertParseError("no_such_operation(field)")
+
+ def test_parse_wrong_arg_type(self):
+ self.assertParseError("number(field, fnord)")
+
+ def test_parse_transformation(self):
+ getter = self.parse("lower(a)")
+ self.assertIsInstance(getter, query_parser.Lower)
+ self.assertIsInstance(getter.inner, query_parser.ExtractField)
+ self.assertEqual(["a"], getter.inner.field)
+
+ def test_parse_all(self):
+ getters = self.parse_all(["a", "b"])
+ self.assertEqual(2, len(getters))
+ self.assertIsInstance(getters[0], query_parser.ExtractField)
+ self.assertEqual(["a"], getters[0].field)
+ self.assertIsInstance(getters[1], query_parser.ExtractField)
+ self.assertEqual(["b"], getters[1].field)
diff --git a/src/leap/soledad/u1db/tests/test_remote_sync_target.py b/src/leap/soledad/u1db/tests/test_remote_sync_target.py
new file mode 100644
index 00000000..3e0d8995
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_remote_sync_target.py
@@ -0,0 +1,314 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for the remote sync targets"""
+
+import cStringIO
+
+from u1db import (
+ errors,
+ tests,
+ )
+from u1db.remote import (
+ http_app,
+ http_target,
+ oauth_middleware,
+ )
+
+
+class TestHTTPSyncTargetBasics(tests.TestCase):
+
+ def test_parse_url(self):
+ remote_target = http_target.HTTPSyncTarget('http://127.0.0.1:12345/')
+ self.assertEqual('http', remote_target._url.scheme)
+ self.assertEqual('127.0.0.1', remote_target._url.hostname)
+ self.assertEqual(12345, remote_target._url.port)
+ self.assertEqual('/', remote_target._url.path)
+
+
+class TestParsingSyncStream(tests.TestCase):
+
+ def test_wrong_start(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "{}\r\n]", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "\r\n{}\r\n]", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "", None)
+
+ def test_wrong_end(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n{}", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n", None)
+
+ def test_missing_comma(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream,
+ '[\r\n{}\r\n{"id": "i", "rev": "r", '
+ '"content": "c", "gen": 3}\r\n]', None)
+
+ def test_no_entries(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n]", None)
+
+ def test_extra_comma(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n{},\r\n]", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream,
+ '[\r\n{},\r\n{"id": "i", "rev": "r", '
+ '"content": "{}", "gen": 3, "trans_id": "T-sid"}'
+ ',\r\n]',
+ lambda doc, gen, trans_id: None)
+
+ def test_error_in_stream(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.Unavailable,
+ tgt._parse_sync_stream,
+ '[\r\n{"new_generation": 0},'
+ '\r\n{"error": "unavailable"}\r\n', None)
+
+ self.assertRaises(errors.Unavailable,
+ tgt._parse_sync_stream,
+ '[\r\n{"error": "unavailable"}\r\n', None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream,
+ '[\r\n{"error": "?"}\r\n', None)
+
+
+def make_http_app(state):
+ return http_app.HTTPApp(state)
+
+
+def http_sync_target(test, path):
+ return http_target.HTTPSyncTarget(test.getURL(path))
+
+
+def make_oauth_http_app(state):
+ app = http_app.HTTPApp(state)
+ application = oauth_middleware.OAuthMiddleware(app, None, prefix='/~/')
+ application.get_oauth_data_store = lambda: tests.testingOAuthStore
+ return application
+
+
+def oauth_http_sync_target(test, path):
+ st = http_sync_target(test, '~/' + path)
+ st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return st
+
+
+class TestRemoteSyncTargets(tests.TestCaseWithServer):
+
+ scenarios = [
+ ('http', {'make_app_with_state': make_http_app,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'sync_target': http_sync_target}),
+ ('oauth_http', {'make_app_with_state': make_oauth_http_app,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'sync_target': oauth_http_sync_target}),
+ ]
+
+ def getSyncTarget(self, path=None):
+ if self.server is None:
+ self.startServer()
+ return self.sync_target(self, path)
+
+ def test_get_sync_info(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ db._set_replica_gen_and_trans_id('other-id', 1, 'T-transid')
+ remote_target = self.getSyncTarget('test')
+ self.assertEqual(('test', 0, '', 1, 'T-transid'),
+ remote_target.get_sync_info('other-id'))
+
+ def test_record_sync_info(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ remote_target = self.getSyncTarget('test')
+ remote_target.record_sync_info('other-id', 2, 'T-transid')
+ self.assertEqual(
+ (2, 'T-transid'), db._get_replica_gen_and_trans_id('other-id'))
+
+ def test_sync_exchange_send(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ remote_target = self.getSyncTarget('test')
+ other_docs = []
+
+ def receive_doc(doc):
+ other_docs.append((doc.doc_id, doc.rev, doc.get_json()))
+
+ doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}')
+ new_gen, trans_id = remote_target.sync_exchange(
+ [(doc, 10, 'T-sid')], 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=receive_doc)
+ self.assertEqual(1, new_gen)
+ self.assertGetDoc(
+ db, 'doc-here', 'replica:1', '{"value": "here"}', False)
+
+ def test_sync_exchange_send_failure_and_retry_scenario(self):
+ self.startServer()
+
+ def blackhole_getstderr(inst):
+ return cStringIO.StringIO()
+
+ self.patch(self.server.RequestHandlerClass, 'get_stderr',
+ blackhole_getstderr)
+ db = self.request_state._create_database('test')
+ _put_doc_if_newer = db._put_doc_if_newer
+ trigger_ids = ['doc-here2']
+
+ def bomb_put_doc_if_newer(doc, save_conflict,
+ replica_uid=None, replica_gen=None,
+ replica_trans_id=None):
+ if doc.doc_id in trigger_ids:
+ raise Exception
+ return _put_doc_if_newer(doc, save_conflict=save_conflict,
+ replica_uid=replica_uid, replica_gen=replica_gen,
+ replica_trans_id=replica_trans_id)
+ self.patch(db, '_put_doc_if_newer', bomb_put_doc_if_newer)
+ remote_target = self.getSyncTarget('test')
+ other_changes = []
+
+ def receive_doc(doc, gen, trans_id):
+ other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ doc1 = self.make_document('doc-here', 'replica:1', '{"value": "here"}')
+ doc2 = self.make_document('doc-here2', 'replica:1',
+ '{"value": "here2"}')
+ self.assertRaises(
+ errors.HTTPError,
+ remote_target.sync_exchange,
+ [(doc1, 10, 'T-sid'), (doc2, 11, 'T-sud')],
+ 'replica', last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=receive_doc)
+ self.assertGetDoc(db, 'doc-here', 'replica:1', '{"value": "here"}',
+ False)
+ self.assertEqual(
+ (10, 'T-sid'), db._get_replica_gen_and_trans_id('replica'))
+ self.assertEqual([], other_changes)
+ # retry
+ trigger_ids = []
+ new_gen, trans_id = remote_target.sync_exchange(
+ [(doc2, 11, 'T-sud')], 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=receive_doc)
+ self.assertGetDoc(db, 'doc-here2', 'replica:1', '{"value": "here2"}',
+ False)
+ self.assertEqual(
+ (11, 'T-sud'), db._get_replica_gen_and_trans_id('replica'))
+ self.assertEqual(2, new_gen)
+ # bounced back to us
+ self.assertEqual(
+ ('doc-here', 'replica:1', '{"value": "here"}', 1),
+ other_changes[0][:-1])
+
+ def test_sync_exchange_in_stream_error(self):
+ self.startServer()
+
+ def blackhole_getstderr(inst):
+ return cStringIO.StringIO()
+
+ self.patch(self.server.RequestHandlerClass, 'get_stderr',
+ blackhole_getstderr)
+ db = self.request_state._create_database('test')
+ doc = db.create_doc_from_json('{"value": "there"}')
+
+ def bomb_get_docs(doc_ids, check_for_conflicts=None,
+ include_deleted=False):
+ yield doc
+ # delayed failure case
+ raise errors.Unavailable
+
+ self.patch(db, 'get_docs', bomb_get_docs)
+ remote_target = self.getSyncTarget('test')
+ other_changes = []
+
+ def receive_doc(doc, gen, trans_id):
+ other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ self.assertRaises(
+ errors.Unavailable, remote_target.sync_exchange, [], 'replica',
+ last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=receive_doc)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, '{"value": "there"}', 1),
+ other_changes[0][:-1])
+
+ def test_sync_exchange_receive(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ doc = db.create_doc_from_json('{"value": "there"}')
+ remote_target = self.getSyncTarget('test')
+ other_changes = []
+
+ def receive_doc(doc, gen, trans_id):
+ other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ new_gen, trans_id = remote_target.sync_exchange(
+ [], 'replica', last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=receive_doc)
+ self.assertEqual(1, new_gen)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, '{"value": "there"}', 1),
+ other_changes[0][:-1])
+
+ def test_sync_exchange_send_ensure_callback(self):
+ self.startServer()
+ remote_target = self.getSyncTarget('test')
+ other_docs = []
+ replica_uid_box = []
+
+ def receive_doc(doc):
+ other_docs.append((doc.doc_id, doc.rev, doc.get_json()))
+
+ def ensure_cb(replica_uid):
+ replica_uid_box.append(replica_uid)
+
+ doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}')
+ new_gen, trans_id = remote_target.sync_exchange(
+ [(doc, 10, 'T-sid')], 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=receive_doc,
+ ensure_callback=ensure_cb)
+ self.assertEqual(1, new_gen)
+ db = self.request_state.open_database('test')
+ self.assertEqual(1, len(replica_uid_box))
+ self.assertEqual(db._replica_uid, replica_uid_box[0])
+ self.assertGetDoc(
+ db, 'doc-here', 'replica:1', '{"value": "here"}', False)
+
+
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/u1db/tests/test_remote_utils.py b/src/leap/soledad/u1db/tests/test_remote_utils.py
new file mode 100644
index 00000000..959cd882
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_remote_utils.py
@@ -0,0 +1,36 @@
+# Copyright 2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for protocol details utils."""
+
+from u1db.tests import TestCase
+from u1db.remote import utils
+
+
+class TestUtils(TestCase):
+
+ def test_check_and_strip_comma(self):
+ line, comma = utils.check_and_strip_comma("abc,")
+ self.assertTrue(comma)
+ self.assertEqual("abc", line)
+
+ line, comma = utils.check_and_strip_comma("abc")
+ self.assertFalse(comma)
+ self.assertEqual("abc", line)
+
+ line, comma = utils.check_and_strip_comma("")
+ self.assertFalse(comma)
+ self.assertEqual("", line)
diff --git a/src/leap/soledad/u1db/tests/test_server_state.py b/src/leap/soledad/u1db/tests/test_server_state.py
new file mode 100644
index 00000000..fc3f1282
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_server_state.py
@@ -0,0 +1,93 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for server state object."""
+
+import os
+
+from u1db import (
+ errors,
+ tests,
+ )
+from u1db.remote import (
+ server_state,
+ )
+from u1db.backends import sqlite_backend
+
+
+class TestServerState(tests.TestCase):
+
+ def setUp(self):
+ super(TestServerState, self).setUp()
+ self.state = server_state.ServerState()
+
+ def test_set_workingdir(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ self.assertTrue(self.state._relpath('path').startswith(tempdir))
+
+ def test_open_database(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ path = tempdir + '/test.db'
+ self.assertFalse(os.path.exists(path))
+ # Create the db, but don't do anything with it
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db = self.state.open_database('test.db')
+ self.assertIsInstance(db, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test_check_database(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ path = tempdir + '/test.db'
+ self.assertFalse(os.path.exists(path))
+
+ # doesn't exist => raises
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.state.check_database, 'test.db')
+
+ # Create the db, but don't do anything with it
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ # exists => returns
+ res = self.state.check_database('test.db')
+ self.assertIsNone(res)
+
+ def test_ensure_database(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ path = tempdir + '/test.db'
+ self.assertFalse(os.path.exists(path))
+ db, replica_uid = self.state.ensure_database('test.db')
+ self.assertIsInstance(db, sqlite_backend.SQLitePartialExpandDatabase)
+ self.assertEqual(db._replica_uid, replica_uid)
+ self.assertTrue(os.path.exists(path))
+ db2 = self.state.open_database('test.db')
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test_delete_database(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ path = tempdir + '/test.db'
+ db, _ = self.state.ensure_database('test.db')
+ db.close()
+ self.state.delete_database('test.db')
+ self.assertFalse(os.path.exists(path))
+
+ def test_delete_database_DoesNotExist(self):
+ tempdir = self.createTempDir()
+ self.state.set_workingdir(tempdir)
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.state.delete_database, 'test.db')
diff --git a/src/leap/soledad/u1db/tests/test_sqlite_backend.py b/src/leap/soledad/u1db/tests/test_sqlite_backend.py
new file mode 100644
index 00000000..73330789
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_sqlite_backend.py
@@ -0,0 +1,493 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test sqlite backend internals."""
+
+import os
+import time
+import threading
+
+from sqlite3 import dbapi2
+
+from u1db import (
+ errors,
+ tests,
+ query_parser,
+ )
+from u1db.backends import sqlite_backend
+from u1db.tests.test_backends import TestAlternativeDocument
+
+
+simple_doc = '{"key": "value"}'
+nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}'
+
+
+class TestSQLiteDatabase(tests.TestCase):
+
+ def test_atomic_initialize(self):
+ tmpdir = self.createTempDir()
+ dbname = os.path.join(tmpdir, 'atomic.db')
+
+ t2 = None # will be a thread
+
+ class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ _index_storage_value = "testing"
+
+ def __init__(self, dbname, ntry):
+ self._try = ntry
+ self._is_initialized_invocations = 0
+ super(SQLiteDatabaseTesting, self).__init__(dbname)
+
+ def _is_initialized(self, c):
+ res = super(SQLiteDatabaseTesting, self)._is_initialized(c)
+ if self._try == 1:
+ self._is_initialized_invocations += 1
+ if self._is_initialized_invocations == 2:
+ t2.start()
+ # hard to do better and have a generic test
+ time.sleep(0.05)
+ return res
+
+ outcome2 = []
+
+ def second_try():
+ try:
+ db2 = SQLiteDatabaseTesting(dbname, 2)
+ except Exception, e:
+ outcome2.append(e)
+ else:
+ outcome2.append(db2)
+
+ t2 = threading.Thread(target=second_try)
+ db1 = SQLiteDatabaseTesting(dbname, 1)
+ t2.join()
+
+ self.assertIsInstance(outcome2[0], SQLiteDatabaseTesting)
+ db2 = outcome2[0]
+ self.assertTrue(db2._is_initialized(db1._get_sqlite_handle().cursor()))
+
+
+class TestSQLitePartialExpandDatabase(tests.TestCase):
+
+ def setUp(self):
+ super(TestSQLitePartialExpandDatabase, self).setUp()
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.db._set_replica_uid('test')
+
+ def test_create_database(self):
+ raw_db = self.db._get_sqlite_handle()
+ self.assertNotEqual(None, raw_db)
+
+ def test_default_replica_uid(self):
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.assertIsNot(None, self.db._replica_uid)
+ self.assertEqual(32, len(self.db._replica_uid))
+ int(self.db._replica_uid, 16)
+
+ def test__close_sqlite_handle(self):
+ raw_db = self.db._get_sqlite_handle()
+ self.db._close_sqlite_handle()
+ self.assertRaises(dbapi2.ProgrammingError,
+ raw_db.cursor)
+
+ def test_create_database_initializes_schema(self):
+ raw_db = self.db._get_sqlite_handle()
+ c = raw_db.cursor()
+ c.execute("SELECT * FROM u1db_config")
+ config = dict([(r[0], r[1]) for r in c.fetchall()])
+ self.assertEqual({'sql_schema': '0', 'replica_uid': 'test',
+ 'index_storage': 'expand referenced'}, config)
+
+ # These tables must exist, though we don't care what is in them yet
+ c.execute("SELECT * FROM transaction_log")
+ c.execute("SELECT * FROM document")
+ c.execute("SELECT * FROM document_fields")
+ c.execute("SELECT * FROM sync_log")
+ c.execute("SELECT * FROM conflicts")
+ c.execute("SELECT * FROM index_definitions")
+
+ def test__parse_index(self):
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ g = self.db._parse_index_definition('fieldname')
+ self.assertIsInstance(g, query_parser.ExtractField)
+ self.assertEqual(['fieldname'], g.field)
+
+ def test__update_indexes(self):
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ g = self.db._parse_index_definition('fieldname')
+ c = self.db._get_sqlite_handle().cursor()
+ self.db._update_indexes('doc-id', {'fieldname': 'val'},
+ [('fieldname', g)], c)
+ c.execute('SELECT doc_id, field_name, value FROM document_fields')
+ self.assertEqual([('doc-id', 'fieldname', 'val')],
+ c.fetchall())
+
+ def test__set_replica_uid(self):
+ # Start from scratch, so that replica_uid isn't set.
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.assertIsNot(None, self.db._real_replica_uid)
+ self.assertIsNot(None, self.db._replica_uid)
+ self.db._set_replica_uid('foo')
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT value FROM u1db_config WHERE name='replica_uid'")
+ self.assertEqual(('foo',), c.fetchone())
+ self.assertEqual('foo', self.db._real_replica_uid)
+ self.assertEqual('foo', self.db._replica_uid)
+ self.db._close_sqlite_handle()
+ self.assertEqual('foo', self.db._replica_uid)
+
+ def test__get_generation(self):
+ self.assertEqual(0, self.db._get_generation())
+
+ def test__get_generation_info(self):
+ self.assertEqual((0, ''), self.db._get_generation_info())
+
+ def test_create_index(self):
+ self.db.create_index('test-idx', "key")
+ self.assertEqual([('test-idx', ["key"])], self.db.list_indexes())
+
+ def test_create_index_multiple_fields(self):
+ self.db.create_index('test-idx', "key", "key2")
+ self.assertEqual([('test-idx', ["key", "key2"])],
+ self.db.list_indexes())
+
+ def test__get_index_definition(self):
+ self.db.create_index('test-idx', "key", "key2")
+ # TODO: How would you test that an index is getting used for an SQL
+ # request?
+ self.assertEqual(["key", "key2"],
+ self.db._get_index_definition('test-idx'))
+
+ def test_list_index_mixed(self):
+ # Make sure that we properly order the output
+ c = self.db._get_sqlite_handle().cursor()
+ # We intentionally insert the data in weird ordering, to make sure the
+ # query still gets it back correctly.
+ c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
+ [('idx-1', 0, 'key10'),
+ ('idx-2', 2, 'key22'),
+ ('idx-1', 1, 'key11'),
+ ('idx-2', 0, 'key20'),
+ ('idx-2', 1, 'key21')])
+ self.assertEqual([('idx-1', ['key10', 'key11']),
+ ('idx-2', ['key20', 'key21', 'key22'])],
+ self.db.list_indexes())
+
+ def test_no_indexes_no_document_fields(self):
+ self.db.create_doc_from_json(
+ '{"key1": "val1", "key2": "val2"}')
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([], c.fetchall())
+
+ def test_create_extracts_fields(self):
+ doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}')
+ doc2 = self.db.create_doc_from_json('{"key1": "valx", "key2": "valy"}')
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([], c.fetchall())
+ self.db.create_index('test', 'key1', 'key2')
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual(sorted(
+ [(doc1.doc_id, "key1", "val1"),
+ (doc1.doc_id, "key2", "val2"),
+ (doc2.doc_id, "key1", "valx"),
+ (doc2.doc_id, "key2", "valy"),
+ ]), sorted(c.fetchall()))
+
+ def test_put_updates_fields(self):
+ self.db.create_index('test', 'key1', 'key2')
+ doc1 = self.db.create_doc_from_json(
+ '{"key1": "val1", "key2": "val2"}')
+ doc1.content = {"key1": "val1", "key2": "valy"}
+ self.db.put_doc(doc1)
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, "key1", "val1"),
+ (doc1.doc_id, "key2", "valy"),
+ ], c.fetchall())
+
+ def test_put_updates_nested_fields(self):
+ self.db.create_index('test', 'key', 'sub.doc')
+ doc1 = self.db.create_doc_from_json(nested_doc)
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, "key", "value"),
+ (doc1.doc_id, "sub.doc", "underneath"),
+ ], c.fetchall())
+
+ def test__ensure_schema_rollback(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/rollback.db'
+
+ class SQLitePartialExpandDbTesting(
+ sqlite_backend.SQLitePartialExpandDatabase):
+
+ def _set_replica_uid_in_transaction(self, uid):
+ super(SQLitePartialExpandDbTesting,
+ self)._set_replica_uid_in_transaction(uid)
+ if fail:
+ raise Exception()
+
+ db = SQLitePartialExpandDbTesting.__new__(SQLitePartialExpandDbTesting)
+ db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed
+ fail = True
+ self.assertRaises(Exception, db._ensure_schema)
+ fail = False
+ db._initialize(db._db_handle.cursor())
+
+ def test__open_database(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/test.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase._open_database(path)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test__open_database_with_factory(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/test.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase._open_database(
+ path, document_factory=TestAlternativeDocument)
+ self.assertEqual(TestAlternativeDocument, db2._factory)
+
+ def test__open_database_non_existent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/non-existent.sqlite'
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase._open_database, path)
+
+ def test__open_database_during_init(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/initialised.db'
+ db = sqlite_backend.SQLitePartialExpandDatabase.__new__(
+ sqlite_backend.SQLitePartialExpandDatabase)
+ db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed
+ self.addCleanup(db.close)
+ observed = []
+
+ class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1
+
+ @classmethod
+ def _which_index_storage(cls, c):
+ res = super(SQLiteDatabaseTesting, cls)._which_index_storage(c)
+ db._ensure_schema() # init db
+ observed.append(res[0])
+ return res
+
+ db2 = SQLiteDatabaseTesting._open_database(path)
+ self.addCleanup(db2.close)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+ self.assertEqual([None,
+ sqlite_backend.SQLitePartialExpandDatabase._index_storage_value],
+ observed)
+
+ def test__open_database_invalid(self):
+ class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path1 = temp_dir + '/invalid1.db'
+ with open(path1, 'wb') as f:
+ f.write("")
+ self.assertRaises(dbapi2.OperationalError,
+ SQLiteDatabaseTesting._open_database, path1)
+ with open(path1, 'wb') as f:
+ f.write("invalid")
+ self.assertRaises(dbapi2.DatabaseError,
+ SQLiteDatabaseTesting._open_database, path1)
+
+ def test_open_database_existing(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/existing.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test_open_database_with_factory(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/existing.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase.open_database(
+ path, create=False, document_factory=TestAlternativeDocument)
+ self.assertEqual(TestAlternativeDocument, db2._factory)
+
+ def test_open_database_create(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/new.sqlite'
+ sqlite_backend.SQLiteDatabase.open_database(path, create=True)
+ db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test_open_database_non_existent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/non-existent.sqlite'
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase.open_database, path,
+ create=False)
+
+ def test_delete_database_existent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/new.sqlite'
+ db = sqlite_backend.SQLiteDatabase.open_database(path, create=True)
+ db.close()
+ sqlite_backend.SQLiteDatabase.delete_database(path)
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase.open_database, path,
+ create=False)
+
+ def test_delete_database_nonexistent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/non-existent.sqlite'
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase.delete_database, path)
+
+ def test__get_indexed_fields(self):
+ self.db.create_index('idx1', 'a', 'b')
+ self.assertEqual(set(['a', 'b']), self.db._get_indexed_fields())
+ self.db.create_index('idx2', 'b', 'c')
+ self.assertEqual(set(['a', 'b', 'c']), self.db._get_indexed_fields())
+
+ def test_indexed_fields_expanded(self):
+ self.db.create_index('idx1', 'key1')
+ doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}')
+ self.assertEqual(set(['key1']), self.db._get_indexed_fields())
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall())
+
+ def test_create_index_updates_fields(self):
+ doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}')
+ self.db.create_index('idx1', 'key1')
+ self.assertEqual(set(['key1']), self.db._get_indexed_fields())
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall())
+
+ def assertFormatQueryEquals(self, exp_statement, exp_args, definition,
+ values):
+ statement, args = self.db._format_query(definition, values)
+ self.assertEqual(exp_statement, statement)
+ self.assertEqual(exp_args, args)
+
+ def test__format_query(self):
+ self.assertFormatQueryEquals(
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, document_fields d0 LEFT OUTER JOIN conflicts c ON "
+ "c.doc_id = d.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name "
+ "= ? AND d0.value = ? GROUP BY d.doc_id, d.doc_rev, d.content "
+ "ORDER BY d0.value;", ["key1", "a"],
+ ["key1"], ["a"])
+
+ def test__format_query2(self):
+ self.assertFormatQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value = ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value = ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ["key1", "a", "key2", "b", "key3", "c"],
+ ["key1", "key2", "key3"], ["a", "b", "c"])
+
+ def test__format_query_wildcard(self):
+ self.assertFormatQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value GLOB ? AND d.doc_id = d2.doc_id AND d2.field_name = ? '
+ 'AND d2.value NOT NULL GROUP BY d.doc_id, d.doc_rev, d.content '
+ 'ORDER BY d0.value, d1.value, d2.value;',
+ ["key1", "a", "key2", "b*", "key3"], ["key1", "key2", "key3"],
+ ["a", "b*", "*"])
+
+ def assertFormatRangeQueryEquals(self, exp_statement, exp_args, definition,
+ start_value, end_value):
+ statement, args = self.db._format_range_query(
+ definition, start_value, end_value)
+ self.assertEqual(exp_statement, statement)
+ self.assertEqual(exp_args, args)
+
+ def test__format_range_query(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value >= ? AND d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'c', 'key1', 'p', 'key2', 'q',
+ 'key3', 'r'],
+ ["key1", "key2", "key3"], ["a", "b", "c"], ["p", "q", "r"])
+
+ def test__format_range_query_no_start(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'c'],
+ ["key1", "key2", "key3"], None, ["a", "b", "c"])
+
+ def test__format_range_query_no_end(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value >= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'c'],
+ ["key1", "key2", "key3"], ["a", "b", "c"], None)
+
+ def test__format_range_query_wildcard(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value NOT NULL AND d.doc_id = d0.doc_id AND d0.field_name = ? '
+ 'AND d0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? '
+ 'AND (d1.value < ? OR d1.value GLOB ?) AND d.doc_id = d2.doc_id '
+ 'AND d2.field_name = ? AND d2.value NOT NULL GROUP BY d.doc_id, '
+ 'd.doc_rev, d.content ORDER BY d0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'key1', 'p', 'key2', 'q', 'q*',
+ 'key3'],
+ ["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"])
diff --git a/src/leap/soledad/u1db/tests/test_sync.py b/src/leap/soledad/u1db/tests/test_sync.py
new file mode 100644
index 00000000..f2a925f0
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_sync.py
@@ -0,0 +1,1285 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The Synchronization class for U1DB."""
+
+import os
+from wsgiref import simple_server
+
+from u1db import (
+ errors,
+ sync,
+ tests,
+ vectorclock,
+ SyncTarget,
+ )
+from u1db.backends import (
+ inmemory,
+ )
+from u1db.remote import (
+ http_target,
+ )
+
+from u1db.tests.test_remote_sync_target import (
+ make_http_app,
+ make_oauth_http_app,
+ )
+
+simple_doc = tests.simple_doc
+nested_doc = tests.nested_doc
+
+
+def _make_local_db_and_target(test):
+ db = test.create_database('test')
+ st = db.get_sync_target()
+ return db, st
+
+
+def _make_local_db_and_http_target(test, path='test'):
+ test.startServer()
+ db = test.request_state._create_database(os.path.basename(path))
+ st = http_target.HTTPSyncTarget.connect(test.getURL(path))
+ return db, st
+
+
+def _make_c_db_and_c_http_target(test, path='test'):
+ test.startServer()
+ db = test.request_state._create_database(os.path.basename(path))
+ url = test.getURL(path)
+ st = tests.c_backend_wrapper.create_http_sync_target(url)
+ return db, st
+
+
+def _make_local_db_and_oauth_http_target(test):
+ db, st = _make_local_db_and_http_target(test, '~/test')
+ st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return db, st
+
+
+def _make_c_db_and_oauth_http_target(test, path='~/test'):
+ test.startServer()
+ db = test.request_state._create_database(os.path.basename(path))
+ url = test.getURL(path)
+ st = tests.c_backend_wrapper.create_oauth_http_sync_target(url,
+ tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return db, st
+
+
+target_scenarios = [
+ ('local', {'create_db_and_target': _make_local_db_and_target}),
+ ('http', {'create_db_and_target': _make_local_db_and_http_target,
+ 'make_app_with_state': make_http_app}),
+ ('oauth_http', {'create_db_and_target':
+ _make_local_db_and_oauth_http_target,
+ 'make_app_with_state': make_oauth_http_app}),
+ ]
+
+c_db_scenarios = [
+ ('local,c', {'create_db_and_target': _make_local_db_and_target,
+ 'make_database_for_test': tests.make_c_database_for_test,
+ 'copy_database_for_test': tests.copy_c_database_for_test,
+ 'make_document_for_test': tests.make_c_document_for_test,
+ 'whitebox': False}),
+ ('http,c', {'create_db_and_target': _make_c_db_and_c_http_target,
+ 'make_database_for_test': tests.make_c_database_for_test,
+ 'copy_database_for_test': tests.copy_c_database_for_test,
+ 'make_document_for_test': tests.make_c_document_for_test,
+ 'make_app_with_state': make_http_app,
+ 'whitebox': False}),
+ ('oauth_http,c', {'create_db_and_target': _make_c_db_and_oauth_http_target,
+ 'make_database_for_test': tests.make_c_database_for_test,
+ 'copy_database_for_test': tests.copy_c_database_for_test,
+ 'make_document_for_test': tests.make_c_document_for_test,
+ 'make_app_with_state': make_oauth_http_app,
+ 'whitebox': False}),
+ ]
+
+
+class DatabaseSyncTargetTests(tests.DatabaseBaseTests,
+ tests.TestCaseWithServer):
+
+ scenarios = (tests.multiply_scenarios(tests.DatabaseBaseTests.scenarios,
+ target_scenarios)
+ + c_db_scenarios)
+ # whitebox true means self.db is the actual local db object
+ # against which the sync is performed
+ whitebox = True
+
+ def setUp(self):
+ super(DatabaseSyncTargetTests, self).setUp()
+ self.db, self.st = self.create_db_and_target(self)
+ self.other_changes = []
+
+ def tearDown(self):
+ # We delete them explicitly, so that connections are cleanly closed
+ del self.st
+ self.db.close()
+ del self.db
+ super(DatabaseSyncTargetTests, self).tearDown()
+
+ def receive_doc(self, doc, gen, trans_id):
+ self.other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ def set_trace_hook(self, callback, shallow=False):
+ setter = (self.st._set_trace_hook if not shallow else
+ self.st._set_trace_hook_shallow)
+ try:
+ setter(callback)
+ except NotImplementedError:
+ self.skipTest("%s does not implement _set_trace_hook"
+ % (self.st.__class__.__name__,))
+
+ def test_get_sync_target(self):
+ self.assertIsNot(None, self.st)
+
+ def test_get_sync_info(self):
+ self.assertEqual(
+ ('test', 0, '', 0, ''), self.st.get_sync_info('other'))
+
+ def test_create_doc_updates_sync_info(self):
+ self.assertEqual(
+ ('test', 0, '', 0, ''), self.st.get_sync_info('other'))
+ self.db.create_doc_from_json(simple_doc)
+ self.assertEqual(1, self.st.get_sync_info('other')[1])
+
+ def test_record_sync_info(self):
+ self.st.record_sync_info('replica', 10, 'T-transid')
+ self.assertEqual(
+ ('test', 0, '', 10, 'T-transid'), self.st.get_sync_info('replica'))
+
+ def test_sync_exchange(self):
+ docs_by_gen = [
+ (self.make_document('doc-id', 'replica:1', simple_doc), 10,
+ 'T-sid')]
+ new_gen, trans_id = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False)
+ self.assertTransactionLog(['doc-id'], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 1, last_trans_id),
+ (self.other_changes, new_gen, last_trans_id))
+ self.assertEqual(10, self.st.get_sync_info('replica')[3])
+
+ def test_sync_exchange_deleted(self):
+ doc = self.db.create_doc_from_json('{}')
+ edit_rev = 'replica:1|' + doc.rev
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, edit_rev, None), 10, 'T-sid')]
+ new_gen, trans_id = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, edit_rev, None, False)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 2, last_trans_id),
+ (self.other_changes, new_gen, trans_id))
+ self.assertEqual(10, self.st.get_sync_info('replica')[3])
+
+ def test_sync_exchange_push_many(self):
+ docs_by_gen = [
+ (self.make_document('doc-id', 'replica:1', simple_doc), 10, 'T-1'),
+ (self.make_document('doc-id2', 'replica:1', nested_doc), 11,
+ 'T-2')]
+ new_gen, trans_id = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False)
+ self.assertGetDoc(self.db, 'doc-id2', 'replica:1', nested_doc, False)
+ self.assertTransactionLog(['doc-id', 'doc-id2'], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 2, last_trans_id),
+ (self.other_changes, new_gen, trans_id))
+ self.assertEqual(11, self.st.get_sync_info('replica')[3])
+
+ def test_sync_exchange_refuses_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'replica:1', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1])
+ self.assertEqual(1, new_gen)
+ if self.whitebox:
+ self.assertEqual(self.db._last_exchange_log['return'],
+ {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]})
+
+ def test_sync_exchange_ignores_convergence(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ gen, txid = self.db._get_generation_info()
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, doc.rev, simple_doc), 10, 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=gen,
+ last_known_trans_id=txid, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertEqual(([], 1), (self.other_changes, new_gen))
+
+ def test_sync_exchange_returns_new_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_gen, _ = self.st.sync_exchange(
+ [], 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1])
+ self.assertEqual(1, new_gen)
+ if self.whitebox:
+ self.assertEqual(self.db._last_exchange_log['return'],
+ {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]})
+
+ def test_sync_exchange_returns_deleted_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ new_gen, _ = self.st.sync_exchange(
+ [], 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, None, 2), self.other_changes[0][:-1])
+ self.assertEqual(2, new_gen)
+ if self.whitebox:
+ self.assertEqual(self.db._last_exchange_log['return'],
+ {'last_gen': 2, 'docs': [(doc.doc_id, doc.rev)]})
+
+ def test_sync_exchange_returns_many_new_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db)
+ new_gen, _ = self.st.sync_exchange(
+ [], 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db)
+ self.assertEqual(2, new_gen)
+ self.assertEqual(
+ [(doc.doc_id, doc.rev, simple_doc, 1),
+ (doc2.doc_id, doc2.rev, nested_doc, 2)],
+ [c[:-1] for c in self.other_changes])
+ if self.whitebox:
+ self.assertEqual(
+ self.db._last_exchange_log['return'],
+ {'last_gen': 2, 'docs':
+ [(doc.doc_id, doc.rev), (doc2.doc_id, doc2.rev)]})
+
+ def test_sync_exchange_getting_newer_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ self.assertEqual(([], 2), (self.other_changes, new_gen))
+
+ def test_sync_exchange_with_concurrent_updates_of_synced_doc(self):
+ expected = []
+
+ def before_whatschanged_cb(state):
+ if state != 'before whats_changed':
+ return
+ cont = '{"key": "cuncurrent"}'
+ conc_rev = self.db.put_doc(
+ self.make_document(doc.doc_id, 'test:1|z:2', cont))
+ expected.append((doc.doc_id, conc_rev, cont, 3))
+
+ self.set_trace_hook(before_whatschanged_cb)
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertEqual(expected, [c[:-1] for c in self.other_changes])
+ self.assertEqual(3, new_gen)
+
+ def test_sync_exchange_with_concurrent_updates(self):
+
+ def after_whatschanged_cb(state):
+ if state != 'after whats_changed':
+ return
+ self.db.create_doc_from_json('{"new": "doc"}')
+
+ self.set_trace_hook(after_whatschanged_cb)
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertEqual(([], 2), (self.other_changes, new_gen))
+
+ def test_sync_exchange_converged_handling(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ docs_by_gen = [
+ (self.make_document('new', 'other:1', '{}'), 4, 'T-foo'),
+ (self.make_document(doc.doc_id, doc.rev, doc.get_json()), 5,
+ 'T-bar')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertEqual(([], 2), (self.other_changes, new_gen))
+
+ def test_sync_exchange_detect_incomplete_exchange(self):
+ def before_get_docs_explode(state):
+ if state != 'before get_docs':
+ return
+ raise errors.U1DBError("fail")
+ self.set_trace_hook(before_get_docs_explode)
+ # suppress traceback printing in the wsgiref server
+ self.patch(simple_server.ServerHandler,
+ 'log_exception', lambda h, exc_info: None)
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertRaises(
+ (errors.U1DBError, errors.BrokenSyncStream),
+ self.st.sync_exchange, [], 'other-replica',
+ last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=self.receive_doc)
+
+ def test_sync_exchange_doc_ids(self):
+ sync_exchange_doc_ids = getattr(self.st, 'sync_exchange_doc_ids', None)
+ if sync_exchange_doc_ids is None:
+ self.skipTest("sync_exchange_doc_ids not implemented")
+ db2 = self.create_database('test2')
+ doc = db2.create_doc_from_json(simple_doc)
+ new_gen, trans_id = sync_exchange_doc_ids(
+ db2, [(doc.doc_id, 10, 'T-sid')], 0, None,
+ return_doc_cb=self.receive_doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 1, last_trans_id),
+ (self.other_changes, new_gen, trans_id))
+ self.assertEqual(10, self.st.get_sync_info(db2._replica_uid)[3])
+
+ def test__set_trace_hook(self):
+ called = []
+
+ def cb(state):
+ called.append(state)
+
+ self.set_trace_hook(cb)
+ self.st.sync_exchange([], 'replica', 0, None, self.receive_doc)
+ self.st.record_sync_info('replica', 0, 'T-sid')
+ self.assertEqual(['before whats_changed',
+ 'after whats_changed',
+ 'before get_docs',
+ 'record_sync_info',
+ ],
+ called)
+
+ def test__set_trace_hook_shallow(self):
+ if (self.st._set_trace_hook_shallow == self.st._set_trace_hook
+ or self.st._set_trace_hook_shallow.im_func ==
+ SyncTarget._set_trace_hook_shallow.im_func):
+ # shallow same as full
+ expected = ['before whats_changed',
+ 'after whats_changed',
+ 'before get_docs',
+ 'record_sync_info',
+ ]
+ else:
+ expected = ['sync_exchange', 'record_sync_info']
+
+ called = []
+
+ def cb(state):
+ called.append(state)
+
+ self.set_trace_hook(cb, shallow=True)
+ self.st.sync_exchange([], 'replica', 0, None, self.receive_doc)
+ self.st.record_sync_info('replica', 0, 'T-sid')
+ self.assertEqual(expected, called)
+
+
+def sync_via_synchronizer(test, db_source, db_target, trace_hook=None,
+ trace_hook_shallow=None):
+ target = db_target.get_sync_target()
+ trace_hook = trace_hook or trace_hook_shallow
+ if trace_hook:
+ target._set_trace_hook(trace_hook)
+ return sync.Synchronizer(db_source, target).sync()
+
+
+sync_scenarios = []
+for name, scenario in tests.LOCAL_DATABASES_SCENARIOS:
+ scenario = dict(scenario)
+ scenario['do_sync'] = sync_via_synchronizer
+ sync_scenarios.append((name, scenario))
+ scenario = dict(scenario)
+
+
+def make_database_for_http_test(test, replica_uid):
+ if test.server is None:
+ test.startServer()
+ db = test.request_state._create_database(replica_uid)
+ try:
+ http_at = test._http_at
+ except AttributeError:
+ http_at = test._http_at = {}
+ http_at[db] = replica_uid
+ return db
+
+
+def copy_database_for_http_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR HOUSE.
+ if test.server is None:
+ test.startServer()
+ new_db = test.request_state._copy_database(db)
+ try:
+ http_at = test._http_at
+ except AttributeError:
+ http_at = test._http_at = {}
+ path = db._replica_uid
+ while path in http_at.values():
+ path += 'copy'
+ http_at[new_db] = path
+ return new_db
+
+
+def sync_via_synchronizer_and_http(test, db_source, db_target,
+ trace_hook=None, trace_hook_shallow=None):
+ if trace_hook:
+ test.skipTest("full trace hook unsupported over http")
+ path = test._http_at[db_target]
+ target = http_target.HTTPSyncTarget.connect(test.getURL(path))
+ if trace_hook_shallow:
+ target._set_trace_hook_shallow(trace_hook_shallow)
+ return sync.Synchronizer(db_source, target).sync()
+
+
+sync_scenarios.append(('pyhttp', {
+ 'make_database_for_test': make_database_for_http_test,
+ 'copy_database_for_test': copy_database_for_http_test,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'make_app_with_state': make_http_app,
+ 'do_sync': sync_via_synchronizer_and_http
+ }))
+
+
+if tests.c_backend_wrapper is not None:
+ # TODO: We should hook up sync tests with an HTTP target
+ def sync_via_c_sync(test, db_source, db_target, trace_hook=None,
+ trace_hook_shallow=None):
+ target = db_target.get_sync_target()
+ trace_hook = trace_hook or trace_hook_shallow
+ if trace_hook:
+ target._set_trace_hook(trace_hook)
+ return tests.c_backend_wrapper.sync_db_to_target(db_source, target)
+
+ for name, scenario in tests.C_DATABASE_SCENARIOS:
+ scenario = dict(scenario)
+ scenario['do_sync'] = sync_via_synchronizer
+ sync_scenarios.append((name + ',pysync', scenario))
+ scenario = dict(scenario)
+ scenario['do_sync'] = sync_via_c_sync
+ sync_scenarios.append((name + ',csync', scenario))
+
+
+class DatabaseSyncTests(tests.DatabaseBaseTests,
+ tests.TestCaseWithServer):
+
+ scenarios = sync_scenarios
+ do_sync = None # set by scenarios
+
+ def create_database(self, replica_uid, sync_role=None):
+ if replica_uid == 'test' and sync_role is None:
+ # created up the chain by base class but unused
+ return None
+ db = self.create_database_for_role(replica_uid, sync_role)
+ if sync_role:
+ self._use_tracking[db] = (replica_uid, sync_role)
+ return db
+
+ def create_database_for_role(self, replica_uid, sync_role):
+ # hook point for reuse
+ return super(DatabaseSyncTests, self).create_database(replica_uid)
+
+ def copy_database(self, db, sync_role=None):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
+ # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
+ # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
+ # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
+ # NINJA TO YOUR HOUSE.
+ db_copy = super(DatabaseSyncTests, self).copy_database(db)
+ name, orig_sync_role = self._use_tracking[db]
+ self._use_tracking[db_copy] = (name + '(copy)', sync_role
+ or orig_sync_role)
+ return db_copy
+
+ def sync(self, db_from, db_to, trace_hook=None,
+ trace_hook_shallow=None):
+ from_name, from_sync_role = self._use_tracking[db_from]
+ to_name, to_sync_role = self._use_tracking[db_to]
+ if from_sync_role not in ('source', 'both'):
+ raise Exception("%s marked for %s use but used as source" %
+ (from_name, from_sync_role))
+ if to_sync_role not in ('target', 'both'):
+ raise Exception("%s marked for %s use but used as target" %
+ (to_name, to_sync_role))
+ return self.do_sync(self, db_from, db_to, trace_hook,
+ trace_hook_shallow)
+
+ def setUp(self):
+ self._use_tracking = {}
+ super(DatabaseSyncTests, self).setUp()
+
+ def assertLastExchangeLog(self, db, expected):
+ log = getattr(db, '_last_exchange_log', None)
+ if log is None:
+ return
+ self.assertEqual(expected, log)
+
+ def test_sync_tracks_db_generation_of_other(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.assertEqual(0, self.sync(self.db1, self.db2))
+ self.assertEqual(
+ (0, ''), self.db1._get_replica_gen_and_trans_id('test2'))
+ self.assertEqual(
+ (0, ''), self.db2._get_replica_gen_and_trans_id('test1'))
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [], 'last_known_gen': 0},
+ 'return': {'docs': [], 'last_gen': 0}})
+
+ def test_sync_autoresolves(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc1 = self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ rev1 = doc1.rev
+ doc2 = self.db2.create_doc_from_json(simple_doc, doc_id='doc')
+ rev2 = doc2.rev
+ self.sync(self.db1, self.db2)
+ doc = self.db1.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ self.assertEqual(doc.rev, self.db2.get_doc('doc').rev)
+ v = vectorclock.VectorClockRev(doc.rev)
+ self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev1)))
+ self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev2)))
+
+ def test_sync_autoresolves_moar(self):
+ # here we test that when a database that has a conflicted document is
+ # the source of a sync, and the target database has a revision of the
+ # conflicted document that is newer than the source database's, and
+ # that target's database's document's content is the same as the
+ # source's document's conflict's, the source's document's conflict gets
+ # autoresolved, and the source's document's revision bumped.
+ #
+ # idea is as follows:
+ # A B
+ # a1 -
+ # `------->
+ # a1 a1
+ # v v
+ # a2 a1b1
+ # `------->
+ # a1b1+a2 a1b1
+ # v
+ # a1b1+a2 a1b2 (a1b2 has same content as a2)
+ # `------->
+ # a3b2 a1b2 (autoresolved)
+ # `------->
+ # a3b2 a3b2
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ self.sync(self.db1, self.db2)
+ for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]:
+ doc = db.get_doc('doc')
+ doc.set_json(content)
+ db.put_doc(doc)
+ self.sync(self.db1, self.db2)
+ # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict
+ doc = self.db1.get_doc('doc')
+ rev1 = doc.rev
+ self.assertTrue(doc.has_conflicts)
+ # set db2 to have a doc of {} (same as db1 before the conflict)
+ doc = self.db2.get_doc('doc')
+ doc.set_json('{}')
+ self.db2.put_doc(doc)
+ rev2 = doc.rev
+ # sync it across
+ self.sync(self.db1, self.db2)
+ # tadaa!
+ doc = self.db1.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ vec1 = vectorclock.VectorClockRev(rev1)
+ vec2 = vectorclock.VectorClockRev(rev2)
+ vec3 = vectorclock.VectorClockRev(doc.rev)
+ self.assertTrue(vec3.is_newer(vec1))
+ self.assertTrue(vec3.is_newer(vec2))
+ # because the conflict is on the source, sync it another time
+ self.sync(self.db1, self.db2)
+ # make sure db2 now has the exact same thing
+ self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc'))
+
+ def test_sync_autoresolves_moar_backwards(self):
+ # here we test that when a database that has a conflicted document is
+ # the target of a sync, and the source database has a revision of the
+ # conflicted document that is newer than the target database's, and
+ # that source's database's document's content is the same as the
+ # target's document's conflict's, the target's document's conflict gets
+ # autoresolved, and the document's revision bumped.
+ #
+ # idea is as follows:
+ # A B
+ # a1 -
+ # `------->
+ # a1 a1
+ # v v
+ # a2 a1b1
+ # `------->
+ # a1b1+a2 a1b1
+ # v
+ # a1b1+a2 a1b2 (a1b2 has same content as a2)
+ # <-------'
+ # a3b2 a3b2 (autoresolved and propagated)
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'both')
+ self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ self.sync(self.db1, self.db2)
+ for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]:
+ doc = db.get_doc('doc')
+ doc.set_json(content)
+ db.put_doc(doc)
+ self.sync(self.db1, self.db2)
+ # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict
+ doc = self.db1.get_doc('doc')
+ rev1 = doc.rev
+ self.assertTrue(doc.has_conflicts)
+ revc = self.db1.get_doc_conflicts('doc')[-1].rev
+ # set db2 to have a doc of {} (same as db1 before the conflict)
+ doc = self.db2.get_doc('doc')
+ doc.set_json('{}')
+ self.db2.put_doc(doc)
+ rev2 = doc.rev
+ # sync it across
+ self.sync(self.db2, self.db1)
+ # tadaa!
+ doc = self.db1.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ vec1 = vectorclock.VectorClockRev(rev1)
+ vec2 = vectorclock.VectorClockRev(rev2)
+ vec3 = vectorclock.VectorClockRev(doc.rev)
+ vecc = vectorclock.VectorClockRev(revc)
+ self.assertTrue(vec3.is_newer(vec1))
+ self.assertTrue(vec3.is_newer(vec2))
+ self.assertTrue(vec3.is_newer(vecc))
+ # make sure db2 now has the exact same thing
+ self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc'))
+
+ def test_sync_autoresolves_moar_backwards_three(self):
+ # same as autoresolves_moar_backwards, but with three databases (note
+ # all the syncs go in the same direction -- this is a more natural
+ # scenario):
+ #
+ # A B C
+ # a1 - -
+ # `------->
+ # a1 a1 -
+ # `------->
+ # a1 a1 a1
+ # v v
+ # a2 a1b1 a1
+ # `------------------->
+ # a2 a1b1 a2
+ # `------->
+ # a2+a1b1 a2
+ # v
+ # a2 a2+a1b1 a2c1 (same as a1b1)
+ # `------------------->
+ # a2c1 a2+a1b1 a2c1
+ # `------->
+ # a2b2c1 a2b2c1 a2c1
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'both')
+ self.db3 = self.create_database('test3', 'target')
+ self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ self.sync(self.db1, self.db2)
+ self.sync(self.db2, self.db3)
+ for db, content in [(self.db2, '{"hi": 42}'),
+ (self.db1, '{}'),
+ ]:
+ doc = db.get_doc('doc')
+ doc.set_json(content)
+ db.put_doc(doc)
+ self.sync(self.db1, self.db3)
+ self.sync(self.db2, self.db3)
+ # db2 and db3 now both have a doc of {}, but db2 has a
+ # conflict
+ doc = self.db2.get_doc('doc')
+ self.assertTrue(doc.has_conflicts)
+ revc = self.db2.get_doc_conflicts('doc')[-1].rev
+ self.assertEqual('{}', doc.get_json())
+ self.assertEqual(self.db3.get_doc('doc').get_json(), doc.get_json())
+ self.assertEqual(self.db3.get_doc('doc').rev, doc.rev)
+ # set db3 to have a doc of {hi:42} (same as db2 before the conflict)
+ doc = self.db3.get_doc('doc')
+ doc.set_json('{"hi": 42}')
+ self.db3.put_doc(doc)
+ rev3 = doc.rev
+ # sync it across to db1
+ self.sync(self.db1, self.db3)
+ # db1 now has hi:42, with a rev that is newer than db2's doc
+ doc = self.db1.get_doc('doc')
+ rev1 = doc.rev
+ self.assertFalse(doc.has_conflicts)
+ self.assertEqual('{"hi": 42}', doc.get_json())
+ VCR = vectorclock.VectorClockRev
+ self.assertTrue(VCR(rev1).is_newer(VCR(self.db2.get_doc('doc').rev)))
+ # so sync it to db2
+ self.sync(self.db1, self.db2)
+ # tadaa!
+ doc = self.db2.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ # db2's revision of the document is strictly newer than db1's before
+ # the sync, and db3's before that sync way back when
+ self.assertTrue(VCR(doc.rev).is_newer(VCR(rev1)))
+ self.assertTrue(VCR(doc.rev).is_newer(VCR(rev3)))
+ self.assertTrue(VCR(doc.rev).is_newer(VCR(revc)))
+ # make sure both dbs now have the exact same thing
+ self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc'))
+
+ def test_sync_puts_changes(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ self.assertEqual(1, self.sync(self.db1, self.db2))
+ self.assertGetDoc(self.db2, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0])
+ self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0])
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [(doc.doc_id, doc.rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 1, 'last_known_gen': 0},
+ 'return': {'docs': [], 'last_gen': 1}})
+
+ def test_sync_pulls_changes(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db2.create_doc_from_json(simple_doc)
+ self.db1.create_index('test-idx', 'key')
+ self.assertEqual(0, self.sync(self.db1, self.db2))
+ self.assertGetDoc(self.db1, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0])
+ self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0])
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [], 'last_known_gen': 0},
+ 'return': {'docs': [(doc.doc_id, doc.rev)],
+ 'last_gen': 1}})
+ self.assertEqual([doc], self.db1.get_from_index('test-idx', 'value'))
+
+ def test_sync_pulling_doesnt_update_other_if_changed(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db2.create_doc_from_json(simple_doc)
+ # After the local side has sent its list of docs, before we start
+ # receiving the "targets" response, we update the local database with a
+ # new record.
+ # When we finish synchronizing, we can notice that something locally
+ # was updated, and we cannot tell c2 our new updated generation
+
+ def before_get_docs(state):
+ if state != 'before get_docs':
+ return
+ self.db1.create_doc_from_json(simple_doc)
+
+ self.assertEqual(0, self.sync(self.db1, self.db2,
+ trace_hook=before_get_docs))
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [], 'last_known_gen': 0},
+ 'return': {'docs': [(doc.doc_id, doc.rev)],
+ 'last_gen': 1}})
+ self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0])
+ # c2 should not have gotten a '_record_sync_info' call, because the
+ # local database had been updated more than just by the messages
+ # returned from c2.
+ self.assertEqual(
+ (0, ''), self.db2._get_replica_gen_and_trans_id('test1'))
+
+ def test_sync_doesnt_update_other_if_nothing_pulled(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(simple_doc)
+
+ def no_record_sync_info(state):
+ if state != 'record_sync_info':
+ return
+ self.fail('SyncTarget.record_sync_info was called')
+ self.assertEqual(1, self.sync(self.db1, self.db2,
+ trace_hook_shallow=no_record_sync_info))
+ self.assertEqual(
+ 1,
+ self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid)[0])
+
+ def test_sync_ignores_convergence(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'both')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ self.db3 = self.create_database('test3', 'target')
+ self.assertEqual(1, self.sync(self.db1, self.db3))
+ self.assertEqual(0, self.sync(self.db2, self.db3))
+ self.assertEqual(1, self.sync(self.db1, self.db2))
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [(doc.doc_id, doc.rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 1, 'last_known_gen': 0},
+ 'return': {'docs': [], 'last_gen': 1}})
+
+ def test_sync_ignores_superseded(self):
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'both')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ doc_rev1 = doc.rev
+ self.db3 = self.create_database('test3', 'target')
+ self.sync(self.db1, self.db3)
+ self.sync(self.db2, self.db3)
+ new_content = '{"key": "altval"}'
+ doc.set_json(new_content)
+ self.db1.put_doc(doc)
+ doc_rev2 = doc.rev
+ self.sync(self.db2, self.db1)
+ self.assertLastExchangeLog(self.db1,
+ {'receive': {'docs': [(doc.doc_id, doc_rev1)],
+ 'source_uid': 'test2',
+ 'source_gen': 1, 'last_known_gen': 0},
+ 'return': {'docs': [(doc.doc_id, doc_rev2)],
+ 'last_gen': 2}})
+ self.assertGetDoc(self.db1, doc.doc_id, doc_rev2, new_content, False)
+
+ def test_sync_sees_remote_conflicted(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc1 = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc1.doc_id
+ doc1_rev = doc1.rev
+ self.db1.create_index('test-idx', 'key')
+ new_doc = '{"key": "altval"}'
+ doc2 = self.db2.create_doc_from_json(new_doc, doc_id=doc_id)
+ doc2_rev = doc2.rev
+ self.assertTransactionLog([doc1.doc_id], self.db1)
+ self.sync(self.db1, self.db2)
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [(doc_id, doc1_rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 1, 'last_known_gen': 0},
+ 'return': {'docs': [(doc_id, doc2_rev)],
+ 'last_gen': 1}})
+ self.assertTransactionLog([doc_id, doc_id], self.db1)
+ self.assertGetDoc(self.db1, doc_id, doc2_rev, new_doc, True)
+ self.assertGetDoc(self.db2, doc_id, doc2_rev, new_doc, False)
+ from_idx = self.db1.get_from_index('test-idx', 'altval')[0]
+ self.assertEqual(doc2.doc_id, from_idx.doc_id)
+ self.assertEqual(doc2.rev, from_idx.rev)
+ self.assertTrue(from_idx.has_conflicts)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+
+ def test_sync_sees_remote_delete_conflicted(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc1 = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc1.doc_id
+ self.db1.create_index('test-idx', 'key')
+ self.sync(self.db1, self.db2)
+ doc2 = self.make_document(doc1.doc_id, doc1.rev, doc1.get_json())
+ new_doc = '{"key": "altval"}'
+ doc1.set_json(new_doc)
+ self.db1.put_doc(doc1)
+ self.db2.delete_doc(doc2)
+ self.assertTransactionLog([doc_id, doc_id], self.db1)
+ self.sync(self.db1, self.db2)
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [(doc_id, doc1.rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 2, 'last_known_gen': 1},
+ 'return': {'docs': [(doc_id, doc2.rev)],
+ 'last_gen': 2}})
+ self.assertTransactionLog([doc_id, doc_id, doc_id], self.db1)
+ self.assertGetDocIncludeDeleted(self.db1, doc_id, doc2.rev, None, True)
+ self.assertGetDocIncludeDeleted(
+ self.db2, doc_id, doc2.rev, None, False)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+
+ def test_sync_local_race_conflicted(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc.doc_id
+ doc1_rev = doc.rev
+ self.db1.create_index('test-idx', 'key')
+ self.sync(self.db1, self.db2)
+ content1 = '{"key": "localval"}'
+ content2 = '{"key": "altval"}'
+ doc.set_json(content2)
+ self.db2.put_doc(doc)
+ doc2_rev2 = doc.rev
+ triggered = []
+
+ def after_whatschanged(state):
+ if state != 'after whats_changed':
+ return
+ triggered.append(True)
+ doc = self.make_document(doc_id, doc1_rev, content1)
+ self.db1.put_doc(doc)
+
+ self.sync(self.db1, self.db2, trace_hook=after_whatschanged)
+ self.assertEqual([True], triggered)
+ self.assertGetDoc(self.db1, doc_id, doc2_rev2, content2, True)
+ from_idx = self.db1.get_from_index('test-idx', 'altval')[0]
+ self.assertEqual(doc.doc_id, from_idx.doc_id)
+ self.assertEqual(doc.rev, from_idx.rev)
+ self.assertTrue(from_idx.has_conflicts)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'localval'))
+
+ def test_sync_propagates_deletes(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'both')
+ doc1 = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc1.doc_id
+ self.db1.create_index('test-idx', 'key')
+ self.sync(self.db1, self.db2)
+ self.db2.create_index('test-idx', 'key')
+ self.db3 = self.create_database('test3', 'target')
+ self.sync(self.db1, self.db3)
+ self.db1.delete_doc(doc1)
+ deleted_rev = doc1.rev
+ self.sync(self.db1, self.db2)
+ self.assertLastExchangeLog(self.db2,
+ {'receive': {'docs': [(doc_id, deleted_rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 2, 'last_known_gen': 1},
+ 'return': {'docs': [], 'last_gen': 2}})
+ self.assertGetDocIncludeDeleted(
+ self.db1, doc_id, deleted_rev, None, False)
+ self.assertGetDocIncludeDeleted(
+ self.db2, doc_id, deleted_rev, None, False)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+ self.assertEqual([], self.db2.get_from_index('test-idx', 'value'))
+ self.sync(self.db2, self.db3)
+ self.assertLastExchangeLog(self.db3,
+ {'receive': {'docs': [(doc_id, deleted_rev)],
+ 'source_uid': 'test2',
+ 'source_gen': 2, 'last_known_gen': 0},
+ 'return': {'docs': [], 'last_gen': 2}})
+ self.assertGetDocIncludeDeleted(
+ self.db3, doc_id, deleted_rev, None, False)
+
+ def test_sync_propagates_resolution(self):
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'both')
+ doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc')
+ db3 = self.create_database('test3', 'both')
+ self.sync(self.db2, self.db1)
+ self.assertEqual(
+ self.db1._get_generation_info(),
+ self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid))
+ self.assertEqual(
+ self.db2._get_generation_info(),
+ self.db1._get_replica_gen_and_trans_id(self.db2._replica_uid))
+ self.sync(db3, self.db1)
+ # update on 2
+ doc2 = self.make_document('the-doc', doc1.rev, '{"a": 2}')
+ self.db2.put_doc(doc2)
+ self.sync(self.db2, db3)
+ self.assertEqual(db3.get_doc('the-doc').rev, doc2.rev)
+ # update on 1
+ doc1.set_json('{"a": 3}')
+ self.db1.put_doc(doc1)
+ # conflicts
+ self.sync(self.db2, self.db1)
+ self.sync(db3, self.db1)
+ self.assertTrue(self.db2.get_doc('the-doc').has_conflicts)
+ self.assertTrue(db3.get_doc('the-doc').has_conflicts)
+ # resolve
+ conflicts = self.db2.get_doc_conflicts('the-doc')
+ doc4 = self.make_document('the-doc', None, '{"a": 4}')
+ revs = [doc.rev for doc in conflicts]
+ self.db2.resolve_doc(doc4, revs)
+ doc2 = self.db2.get_doc('the-doc')
+ self.assertEqual(doc4.get_json(), doc2.get_json())
+ self.assertFalse(doc2.has_conflicts)
+ self.sync(self.db2, db3)
+ doc3 = db3.get_doc('the-doc')
+ self.assertEqual(doc4.get_json(), doc3.get_json())
+ self.assertFalse(doc3.has_conflicts)
+
+ def test_sync_supersedes_conflicts(self):
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'target')
+ db3 = self.create_database('test3', 'both')
+ doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc')
+ self.db2.create_doc_from_json('{"b": 1}', doc_id='the-doc')
+ db3.create_doc_from_json('{"c": 1}', doc_id='the-doc')
+ self.sync(db3, self.db1)
+ self.assertEqual(
+ self.db1._get_generation_info(),
+ db3._get_replica_gen_and_trans_id(self.db1._replica_uid))
+ self.assertEqual(
+ db3._get_generation_info(),
+ self.db1._get_replica_gen_and_trans_id(db3._replica_uid))
+ self.sync(db3, self.db2)
+ self.assertEqual(
+ self.db2._get_generation_info(),
+ db3._get_replica_gen_and_trans_id(self.db2._replica_uid))
+ self.assertEqual(
+ db3._get_generation_info(),
+ self.db2._get_replica_gen_and_trans_id(db3._replica_uid))
+ self.assertEqual(3, len(db3.get_doc_conflicts('the-doc')))
+ doc1.set_json('{"a": 2}')
+ self.db1.put_doc(doc1)
+ self.sync(db3, self.db1)
+ # original doc1 should have been removed from conflicts
+ self.assertEqual(3, len(db3.get_doc_conflicts('the-doc')))
+
+ def test_sync_stops_after_get_sync_info(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc)
+ self.sync(self.db1, self.db2)
+
+ def put_hook(state):
+ self.fail("Tracehook triggered for %s" % (state,))
+
+ self.sync(self.db1, self.db2, trace_hook_shallow=put_hook)
+
+ def test_sync_detects_rollback_in_source(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1')
+ self.sync(self.db1, self.db2)
+ db1_copy = self.copy_database(self.db1)
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidGeneration, self.sync, db1_copy, self.db2)
+
+ def test_sync_detects_rollback_in_target(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ db2_copy = self.copy_database(self.db2)
+ self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidGeneration, self.sync, self.db1, db2_copy)
+
+ def test_sync_detects_diverged_source(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ db3 = self.copy_database(self.db1)
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ db3.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, db3, self.db2)
+
+ def test_sync_detects_diverged_target(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ db3 = self.copy_database(self.db2)
+ db3.create_doc_from_json(tests.nested_doc, doc_id="divergent")
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, self.db1, db3)
+
+ def test_sync_detects_rollback_and_divergence_in_source(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1')
+ self.sync(self.db1, self.db2)
+ db1_copy = self.copy_database(self.db1)
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.sync(self.db1, self.db2)
+ db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, db1_copy, self.db2)
+
+ def test_sync_detects_rollback_and_divergence_in_target(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ db2_copy = self.copy_database(self.db2)
+ self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.sync(self.db1, self.db2)
+ db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, self.db1, db2_copy)
+
+
+class TestDbSync(tests.TestCaseWithServer):
+ """Test db.sync remote sync shortcut"""
+
+ scenarios = [
+ ('py-http', {
+ 'make_app_with_state': make_http_app,
+ 'make_database_for_test': tests.make_memory_database_for_test,
+ }),
+ ('c-http', {
+ 'make_app_with_state': make_http_app,
+ 'make_database_for_test': tests.make_c_database_for_test
+ }),
+ ('py-oauth-http', {
+ 'make_app_with_state': make_oauth_http_app,
+ 'make_database_for_test': tests.make_memory_database_for_test,
+ 'oauth': True
+ }),
+ ('c-oauth-http', {
+ 'make_app_with_state': make_oauth_http_app,
+ 'make_database_for_test': tests.make_c_database_for_test,
+ 'oauth': True
+ }),
+ ]
+
+ oauth = False
+
+ def do_sync(self, target_name):
+ if self.oauth:
+ path = '~/' + target_name
+ extra = dict(creds={'oauth': {
+ 'consumer_key': tests.consumer1.key,
+ 'consumer_secret': tests.consumer1.secret,
+ 'token_key': tests.token1.key,
+ 'token_secret': tests.token1.secret
+ }})
+ else:
+ path = target_name
+ extra = {}
+ target_url = self.getURL(path)
+ return self.db.sync(target_url, **extra)
+
+ def setUp(self):
+ super(TestDbSync, self).setUp()
+ self.startServer()
+ self.db = self.make_database_for_test(self, 'test1')
+ self.db2 = self.request_state._create_database('test2.db')
+
+ def test_db_sync(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ doc2 = self.db2.create_doc_from_json(tests.nested_doc)
+ local_gen_before_sync = self.do_sync('test2.db')
+ gen, _, changes = self.db.whats_changed(local_gen_before_sync)
+ self.assertEqual(1, len(changes))
+ self.assertEqual(doc2.doc_id, changes[0][0])
+ self.assertEqual(1, gen - local_gen_before_sync)
+ self.assertGetDoc(self.db2, doc1.doc_id, doc1.rev, tests.simple_doc,
+ False)
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc,
+ False)
+
+ def test_db_sync_autocreate(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ local_gen_before_sync = self.do_sync('test3.db')
+ gen, _, changes = self.db.whats_changed(local_gen_before_sync)
+ self.assertEqual(0, gen - local_gen_before_sync)
+ db3 = self.request_state.open_database('test3.db')
+ gen, _, changes = db3.whats_changed()
+ self.assertEqual(1, len(changes))
+ self.assertEqual(doc1.doc_id, changes[0][0])
+ self.assertGetDoc(db3, doc1.doc_id, doc1.rev, tests.simple_doc,
+ False)
+ t_gen, _ = self.db._get_replica_gen_and_trans_id('test3.db')
+ s_gen, _ = db3._get_replica_gen_and_trans_id('test1')
+ self.assertEqual(1, t_gen)
+ self.assertEqual(1, s_gen)
+
+
+class TestRemoteSyncIntegration(tests.TestCaseWithServer):
+ """Integration tests for the most common sync scenario local -> remote"""
+
+ make_app_with_state = staticmethod(make_http_app)
+
+ def setUp(self):
+ super(TestRemoteSyncIntegration, self).setUp()
+ self.startServer()
+ self.db1 = inmemory.InMemoryDatabase('test1')
+ self.db2 = self.request_state._create_database('test2')
+
+ def test_sync_tracks_generations_incrementally(self):
+ doc11 = self.db1.create_doc_from_json('{"a": 1}')
+ doc12 = self.db1.create_doc_from_json('{"a": 2}')
+ doc21 = self.db2.create_doc_from_json('{"b": 1}')
+ doc22 = self.db2.create_doc_from_json('{"b": 2}')
+ #sanity
+ self.assertEqual(2, len(self.db1._get_transaction_log()))
+ self.assertEqual(2, len(self.db2._get_transaction_log()))
+ progress1 = []
+ progress2 = []
+ _do_set_replica_gen_and_trans_id = \
+ self.db1._do_set_replica_gen_and_trans_id
+
+ def set_sync_generation_witness1(other_uid, other_gen, trans_id):
+ progress1.append((other_uid, other_gen,
+ [d for d, t in self.db1._get_transaction_log()[2:]]))
+ _do_set_replica_gen_and_trans_id(other_uid, other_gen, trans_id)
+ self.patch(self.db1, '_do_set_replica_gen_and_trans_id',
+ set_sync_generation_witness1)
+ _do_set_replica_gen_and_trans_id2 = \
+ self.db2._do_set_replica_gen_and_trans_id
+
+ def set_sync_generation_witness2(other_uid, other_gen, trans_id):
+ progress2.append((other_uid, other_gen,
+ [d for d, t in self.db2._get_transaction_log()[2:]]))
+ _do_set_replica_gen_and_trans_id2(other_uid, other_gen, trans_id)
+ self.patch(self.db2, '_do_set_replica_gen_and_trans_id',
+ set_sync_generation_witness2)
+
+ db2_url = self.getURL('test2')
+ self.db1.sync(db2_url)
+
+ self.assertEqual([('test2', 1, [doc21.doc_id]),
+ ('test2', 2, [doc21.doc_id, doc22.doc_id]),
+ ('test2', 4, [doc21.doc_id, doc22.doc_id])],
+ progress1)
+ self.assertEqual([('test1', 1, [doc11.doc_id]),
+ ('test1', 2, [doc11.doc_id, doc12.doc_id]),
+ ('test1', 4, [doc11.doc_id, doc12.doc_id])],
+ progress2)
+
+
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/u1db/tests/test_test_infrastructure.py b/src/leap/soledad/u1db/tests/test_test_infrastructure.py
new file mode 100644
index 00000000..b79e0516
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_test_infrastructure.py
@@ -0,0 +1,41 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for test infrastructure bits"""
+
+from wsgiref import simple_server
+
+from u1db import (
+ tests,
+ )
+
+
+class TestTestCaseWithServer(tests.TestCaseWithServer):
+
+ def make_app(self):
+ return "app"
+
+ @staticmethod
+ def server_def():
+ def make_server(host_port, application):
+ assert application == "app"
+ return simple_server.WSGIServer(host_port, None)
+ return (make_server, "shutdown", "http")
+
+ def test_getURL(self):
+ self.startServer()
+ url = self.getURL()
+ self.assertTrue(url.startswith('http://127.0.0.1:'))
diff --git a/src/leap/soledad/u1db/tests/test_vectorclock.py b/src/leap/soledad/u1db/tests/test_vectorclock.py
new file mode 100644
index 00000000..72baf246
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/test_vectorclock.py
@@ -0,0 +1,121 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""VectorClockRev helper class tests."""
+
+from u1db import tests, vectorclock
+
+try:
+ from u1db.tests import c_backend_wrapper
+except ImportError:
+ c_backend_wrapper = None
+
+
+c_vectorclock_scenarios = []
+if c_backend_wrapper is not None:
+ c_vectorclock_scenarios.append(
+ ('c', {'create_vcr': c_backend_wrapper.VectorClockRev}))
+
+
+class TestVectorClockRev(tests.TestCase):
+
+ scenarios = [('py', {'create_vcr': vectorclock.VectorClockRev})
+ ] + c_vectorclock_scenarios
+
+ def assertIsNewer(self, newer_rev, older_rev):
+ new_vcr = self.create_vcr(newer_rev)
+ old_vcr = self.create_vcr(older_rev)
+ self.assertTrue(new_vcr.is_newer(old_vcr))
+ self.assertFalse(old_vcr.is_newer(new_vcr))
+
+ def assertIsConflicted(self, rev_a, rev_b):
+ vcr_a = self.create_vcr(rev_a)
+ vcr_b = self.create_vcr(rev_b)
+ self.assertFalse(vcr_a.is_newer(vcr_b))
+ self.assertFalse(vcr_b.is_newer(vcr_a))
+
+ def assertRoundTrips(self, rev):
+ self.assertEqual(rev, self.create_vcr(rev).as_str())
+
+ def test__is_newer_doc_rev(self):
+ self.assertIsNewer('test:1', None)
+ self.assertIsNewer('test:2', 'test:1')
+ self.assertIsNewer('other:2|test:1', 'other:1|test:1')
+ self.assertIsNewer('other:1|test:1', 'other:1')
+ self.assertIsNewer('a:2|b:1', 'b:1')
+ self.assertIsNewer('a:1|b:2', 'a:1')
+ self.assertIsConflicted('other:2|test:1', 'other:1|test:2')
+ self.assertIsConflicted('other:1|test:1', 'other:2')
+ self.assertIsConflicted('test:1', 'test:1')
+
+ def test_None(self):
+ vcr = self.create_vcr(None)
+ self.assertEqual('', vcr.as_str())
+
+ def test_round_trips(self):
+ self.assertRoundTrips('test:1')
+ self.assertRoundTrips('a:1|b:2')
+ self.assertRoundTrips('alternate:2|test:1')
+
+ def test_handles_sort_order(self):
+ self.assertEqual('a:1|b:2', self.create_vcr('b:2|a:1').as_str())
+ # Last one out of place
+ self.assertEqual('a:1|b:2|c:3|d:4|e:5|f:6',
+ self.create_vcr('f:6|a:1|b:2|c:3|d:4|e:5').as_str())
+ # Fully reversed
+ self.assertEqual('a:1|b:2|c:3|d:4|e:5|f:6',
+ self.create_vcr('f:6|e:5|d:4|c:3|b:2|a:1').as_str())
+
+ def assertIncrement(self, original, replica_uid, after_increment):
+ vcr = self.create_vcr(original)
+ vcr.increment(replica_uid)
+ self.assertEqual(after_increment, vcr.as_str())
+
+ def test_increment(self):
+ self.assertIncrement(None, 'test', 'test:1')
+ self.assertIncrement('test:1', 'test', 'test:2')
+
+ def test_increment_adds_uid(self):
+ self.assertIncrement('other:1', 'test', 'other:1|test:1')
+ self.assertIncrement('a:1|ab:2', 'aa', 'a:1|aa:1|ab:2')
+
+ def test_increment_update_partial(self):
+ self.assertIncrement('a:1|ab:2', 'a', 'a:2|ab:2')
+ self.assertIncrement('a:2|ab:2', 'ab', 'a:2|ab:3')
+
+ def test_increment_appends_uid(self):
+ self.assertIncrement('b:2', 'c', 'b:2|c:1')
+
+ def assertMaximize(self, rev1, rev2, maximized):
+ vcr1 = self.create_vcr(rev1)
+ vcr2 = self.create_vcr(rev2)
+ vcr1.maximize(vcr2)
+ self.assertEqual(maximized, vcr1.as_str())
+ # reset vcr1 to maximize the other way
+ vcr1 = self.create_vcr(rev1)
+ vcr2.maximize(vcr1)
+ self.assertEqual(maximized, vcr2.as_str())
+
+ def test_maximize(self):
+ self.assertMaximize(None, None, '')
+ self.assertMaximize(None, 'x:1', 'x:1')
+ self.assertMaximize('x:1', 'y:1', 'x:1|y:1')
+ self.assertMaximize('x:2', 'x:1', 'x:2')
+ self.assertMaximize('x:2', 'x:1|y:2', 'x:2|y:2')
+ self.assertMaximize('a:1|c:2|e:3', 'b:3|d:4|f:5',
+ 'a:1|b:3|c:2|d:4|e:3|f:5')
+
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/u1db/tests/testing-certs/Makefile b/src/leap/soledad/u1db/tests/testing-certs/Makefile
new file mode 100644
index 00000000..2385e75b
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/testing-certs/Makefile
@@ -0,0 +1,35 @@
+CATOP=./demoCA
+ORIG_CONF=/usr/lib/ssl/openssl.cnf
+ELEVEN_YEARS=-days 4015
+
+init:
+ cp $(ORIG_CONF) ca.conf
+ install -d $(CATOP)
+ install -d $(CATOP)/certs
+ install -d $(CATOP)/crl
+ install -d $(CATOP)/newcerts
+ install -d $(CATOP)/private
+ touch $(CATOP)/index.txt
+ echo 01>$(CATOP)/crlnumber
+ @echo '**** Making CA certificate ...'
+ openssl req -nodes -new \
+ -newkey rsa -keyout $(CATOP)/private/cakey.pem \
+ -out $(CATOP)/careq.pem \
+ -multivalue-rdn \
+ -subj "/C=UK/ST=-/O=u1db LOCAL TESTING ONLY, DO NO TRUST/CN=u1db testing CA"
+ openssl ca -config ./ca.conf -create_serial \
+ -out $(CATOP)/cacert.pem $(ELEVEN_YEARS) -batch \
+ -keyfile $(CATOP)/private/cakey.pem -selfsign \
+ -extensions v3_ca -infiles $(CATOP)/careq.pem
+
+pems:
+ cp ./demoCA/cacert.pem .
+ openssl req -new -config ca.conf \
+ -multivalue-rdn \
+ -subj "/O=u1db LOCAL TESTING ONLY, DO NOT TRUST/CN=localhost" \
+ -nodes -keyout testing.key -out newreq.pem $(ELEVEN_YEARS)
+ openssl ca -batch -config ./ca.conf $(ELEVEN_YEARS) \
+ -policy policy_anything \
+ -out testing.cert -infiles newreq.pem
+
+.PHONY: init pems
diff --git a/src/leap/soledad/u1db/tests/testing-certs/cacert.pem b/src/leap/soledad/u1db/tests/testing-certs/cacert.pem
new file mode 100644
index 00000000..c019a730
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/testing-certs/cacert.pem
@@ -0,0 +1,58 @@
+Certificate:
+ Data:
+ Version: 3 (0x2)
+ Serial Number:
+ e4:de:01:76:c4:78:78:7e
+ Signature Algorithm: sha1WithRSAEncryption
+ Issuer: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA
+ Validity
+ Not Before: May 3 11:11:11 2012 GMT
+ Not After : May 1 11:11:11 2023 GMT
+ Subject: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (1024 bit)
+ Modulus:
+ 00:bc:91:a5:7f:7d:37:f7:06:c7:db:5b:83:6a:6b:
+ 63:c3:8b:5c:f7:84:4d:97:6d:d4:be:bf:e7:79:a8:
+ c1:03:57:ec:90:d4:20:e7:02:95:d9:a6:49:e3:f9:
+ 9a:ea:37:b9:b2:02:62:ab:40:d3:42:bb:4a:4e:a2:
+ 47:71:0f:1d:a2:c5:94:a1:cf:35:d3:23:32:42:c0:
+ 1e:8d:cb:08:58:fb:8a:5c:3e:ea:eb:d5:2c:ed:d6:
+ aa:09:b4:b5:7d:e3:45:c9:ae:c2:82:b2:ae:c0:81:
+ bc:24:06:65:a9:e7:e0:61:ac:25:ee:53:d3:d7:be:
+ 22:f7:00:a2:ad:c6:0e:3a:39
+ Exponent: 65537 (0x10001)
+ X509v3 extensions:
+ X509v3 Subject Key Identifier:
+ DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D
+ X509v3 Authority Key Identifier:
+ keyid:DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D
+
+ X509v3 Basic Constraints:
+ CA:TRUE
+ Signature Algorithm: sha1WithRSAEncryption
+ 72:9b:c1:f7:07:65:83:36:25:4e:01:2f:b7:4a:f2:a4:00:28:
+ 80:c7:56:2c:32:39:90:13:61:4b:bb:12:c5:44:9d:42:57:85:
+ 28:19:70:69:e1:43:c8:bd:11:f6:94:df:91:2d:c3:ea:82:8d:
+ b4:8f:5d:47:a3:00:99:53:29:93:27:6c:c5:da:c1:20:6f:ab:
+ ec:4a:be:34:f3:8f:02:e5:0c:c0:03:ac:2b:33:41:71:4f:0a:
+ 72:5a:b4:26:1a:7f:81:bc:c0:95:8a:06:87:a8:11:9f:5c:73:
+ 38:df:5a:69:40:21:29:ad:46:23:56:75:e1:e9:8b:10:18:4c:
+ 7b:54
+-----BEGIN CERTIFICATE-----
+MIICkjCCAfugAwIBAgIJAOTeAXbEeHh+MA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV
+BAYTAlVLMQowCAYDVQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcg
+T05MWSwgRE8gTk8gVFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTAeFw0x
+MjA1MDMxMTExMTFaFw0yMzA1MDExMTExMTFaMGIxCzAJBgNVBAYTAlVLMQowCAYD
+VQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcgT05MWSwgRE8gTk8g
+VFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTCBnzANBgkqhkiG9w0BAQEF
+AAOBjQAwgYkCgYEAvJGlf3039wbH21uDamtjw4tc94RNl23Uvr/neajBA1fskNQg
+5wKV2aZJ4/ma6je5sgJiq0DTQrtKTqJHcQ8dosWUoc810yMyQsAejcsIWPuKXD7q
+69Us7daqCbS1feNFya7CgrKuwIG8JAZlqefgYawl7lPT174i9wCircYOOjkCAwEA
+AaNQME4wHQYDVR0OBBYEFNs9k1FsMhVUjxBQ/ElPNhUou5VtMB8GA1UdIwQYMBaA
+FNs9k1FsMhVUjxBQ/ElPNhUou5VtMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF
+BQADgYEAcpvB9wdlgzYlTgEvt0rypAAogMdWLDI5kBNhS7sSxUSdQleFKBlwaeFD
+yL0R9pTfkS3D6oKNtI9dR6MAmVMpkydsxdrBIG+r7Eq+NPOPAuUMwAOsKzNBcU8K
+clq0Jhp/gbzAlYoGh6gRn1xzON9aaUAhKa1GI1Z14emLEBhMe1Q=
+-----END CERTIFICATE-----
diff --git a/src/leap/soledad/u1db/tests/testing-certs/testing.cert b/src/leap/soledad/u1db/tests/testing-certs/testing.cert
new file mode 100644
index 00000000..985684fb
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/testing-certs/testing.cert
@@ -0,0 +1,61 @@
+Certificate:
+ Data:
+ Version: 3 (0x2)
+ Serial Number:
+ e4:de:01:76:c4:78:78:7f
+ Signature Algorithm: sha1WithRSAEncryption
+ Issuer: C=UK, ST=-, O=u1db LOCAL TESTING ONLY, DO NO TRUST, CN=u1db testing CA
+ Validity
+ Not Before: May 3 11:11:14 2012 GMT
+ Not After : May 1 11:11:14 2023 GMT
+ Subject: O=u1db LOCAL TESTING ONLY, DO NOT TRUST, CN=localhost
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (1024 bit)
+ Modulus:
+ 00:c6:1d:72:d3:c5:e4:fc:d1:4c:d9:e4:08:3e:90:
+ 10:ce:3f:1f:87:4a:1d:4f:7f:2a:5a:52:c9:65:4f:
+ d9:2c:bf:69:75:18:1a:b5:c9:09:32:00:47:f5:60:
+ aa:c6:dd:3a:87:37:5f:16:be:de:29:b5:ea:fc:41:
+ 7e:eb:77:bb:df:63:c3:06:1e:ed:e9:a0:67:1a:f1:
+ ec:e1:9d:f7:9c:8f:1c:fa:c3:66:7b:39:dc:70:ae:
+ 09:1b:9c:c0:9a:c4:90:77:45:8e:39:95:a9:2f:92:
+ 43:bd:27:07:5a:99:51:6e:76:a0:af:dd:b1:2c:8f:
+ ca:8b:8c:47:0d:f6:6e:fc:69
+ Exponent: 65537 (0x10001)
+ X509v3 extensions:
+ X509v3 Basic Constraints:
+ CA:FALSE
+ Netscape Comment:
+ OpenSSL Generated Certificate
+ X509v3 Subject Key Identifier:
+ 1C:63:85:E1:1D:F3:89:2E:6C:4E:3F:FB:D0:10:64:5A:C1:22:6A:2A
+ X509v3 Authority Key Identifier:
+ keyid:DB:3D:93:51:6C:32:15:54:8F:10:50:FC:49:4F:36:15:28:BB:95:6D
+
+ Signature Algorithm: sha1WithRSAEncryption
+ 1d:6d:3e:bd:93:fd:bd:3e:17:b8:9f:f0:99:7f:db:50:5c:b2:
+ 01:42:03:b5:d5:94:05:d3:f6:8e:80:82:55:47:1f:58:f2:18:
+ 6c:ab:ef:43:2c:2f:10:e1:7c:c4:5c:cc:ac:50:50:22:42:aa:
+ 35:33:f5:b9:f3:a6:66:55:d9:36:f4:f2:e4:d4:d9:b5:2c:52:
+ 66:d4:21:17:97:22:b8:9b:d7:0e:7c:3d:ce:85:19:ca:c4:d2:
+ 58:62:31:c6:18:3e:44:fc:f4:30:b6:95:87:ee:21:4a:08:f0:
+ af:3c:8f:c4:ba:5e:a1:5c:37:1a:7d:7b:fe:66:ae:62:50:17:
+ 31:ca
+-----BEGIN CERTIFICATE-----
+MIICnzCCAgigAwIBAgIJAOTeAXbEeHh/MA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV
+BAYTAlVLMQowCAYDVQQIDAEtMS0wKwYDVQQKDCR1MWRiIExPQ0FMIFRFU1RJTkcg
+T05MWSwgRE8gTk8gVFJVU1QxGDAWBgNVBAMMD3UxZGIgdGVzdGluZyBDQTAeFw0x
+MjA1MDMxMTExMTRaFw0yMzA1MDExMTExMTRaMEQxLjAsBgNVBAoMJXUxZGIgTE9D
+QUwgVEVTVElORyBPTkxZLCBETyBOT1QgVFJVU1QxEjAQBgNVBAMMCWxvY2FsaG9z
+dDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAxh1y08Xk/NFM2eQIPpAQzj8f
+h0odT38qWlLJZU/ZLL9pdRgatckJMgBH9WCqxt06hzdfFr7eKbXq/EF+63e732PD
+Bh7t6aBnGvHs4Z33nI8c+sNmeznccK4JG5zAmsSQd0WOOZWpL5JDvScHWplRbnag
+r92xLI/Ki4xHDfZu/GkCAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0E
+HxYdT3BlblNTTCBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFBxjheEd
+84kubE4/+9AQZFrBImoqMB8GA1UdIwQYMBaAFNs9k1FsMhVUjxBQ/ElPNhUou5Vt
+MA0GCSqGSIb3DQEBBQUAA4GBAB1tPr2T/b0+F7if8Jl/21BcsgFCA7XVlAXT9o6A
+glVHH1jyGGyr70MsLxDhfMRczKxQUCJCqjUz9bnzpmZV2Tb08uTU2bUsUmbUIReX
+Irib1w58Pc6FGcrE0lhiMcYYPkT89DC2lYfuIUoI8K88j8S6XqFcNxp9e/5mrmJQ
+FzHK
+-----END CERTIFICATE-----
diff --git a/src/leap/soledad/u1db/tests/testing-certs/testing.key b/src/leap/soledad/u1db/tests/testing-certs/testing.key
new file mode 100644
index 00000000..d83d4920
--- /dev/null
+++ b/src/leap/soledad/u1db/tests/testing-certs/testing.key
@@ -0,0 +1,16 @@
+-----BEGIN PRIVATE KEY-----
+MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMYdctPF5PzRTNnk
+CD6QEM4/H4dKHU9/KlpSyWVP2Sy/aXUYGrXJCTIAR/VgqsbdOoc3Xxa+3im16vxB
+fut3u99jwwYe7emgZxrx7OGd95yPHPrDZns53HCuCRucwJrEkHdFjjmVqS+SQ70n
+B1qZUW52oK/dsSyPyouMRw32bvxpAgMBAAECgYBs3lXxhjg1rhabTjIxnx19GTcM
+M3Az9V+izweZQu3HJ1CeZiaXauhAr+LbNsniCkRVddotN6oCJdQB10QVxXBZc9Jz
+HPJ4zxtZfRZlNMTMmG7eLWrfxpgWnb/BUjDb40yy1nhr9yhDUnI/8RoHDRHnAEHZ
+/CnHGUrqcVcrY5zJAQJBAPLhBJg9W88JVmcOKdWxRgs7dLHnZb999Kv1V5mczmAi
+jvGvbUmucqOqke6pTUHNYyNHqU6pySzGUi2cH+BAkFECQQDQ0VoAOysg6FVoT15v
+tGh57t5sTiCZZ7PS8jwvtThsgA+vcf6c16XWzXgjGXSap4r2QDOY2rI5lsWLaQ8T
++fyZAkAfyFJRmbXp4c7srW3MCOahkaYzoZQu+syJtBFCiMJ40gzik5I5khpuUGPI
+V19EvRu8AiSlppIsycb3MPb64XgBAkEAy7DrUf5le5wmc7G4NM6OeyJ+5LbxJbL6
+vnJ8My1a9LuWkVVpQCU7J+UVo2dZTuLPspW9vwTVhUeFOxAoHRxlQQJAFem93f7m
+el2BkB2EFqU3onPejkZ5UrDmfmeOQR1axMQNSXqSxcJxqa16Ru1BWV2gcWRbwajQ
+oc+kuJThu/r/Ug==
+-----END PRIVATE KEY-----
diff --git a/src/leap/soledad/u1db/vectorclock.py b/src/leap/soledad/u1db/vectorclock.py
new file mode 100644
index 00000000..42bceaa8
--- /dev/null
+++ b/src/leap/soledad/u1db/vectorclock.py
@@ -0,0 +1,89 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""VectorClockRev helper class."""
+
+
+class VectorClockRev(object):
+ """Track vector clocks for multiple replica ids.
+
+ This allows simple comparison to determine if one VectorClockRev is
+ newer/older/in-conflict-with another VectorClockRev without having to
+ examine history. Every replica has a strictly increasing revision. When
+ creating a new revision, they include all revisions for all other replicas
+ which the new revision dominates, and increment their own revision to
+ something greater than the current value.
+ """
+
+ def __init__(self, value):
+ self._values = self._expand(value)
+
+ def __repr__(self):
+ s = self.as_str()
+ return '%s(%s)' % (self.__class__.__name__, s)
+
+ def as_str(self):
+ s = '|'.join(['%s:%d' % (m, r) for m, r
+ in sorted(self._values.items())])
+ return s
+
+ def _expand(self, value):
+ result = {}
+ if value is None:
+ return result
+ for replica_info in value.split('|'):
+ replica_uid, counter = replica_info.split(':')
+ counter = int(counter)
+ result[replica_uid] = counter
+ return result
+
+ def is_newer(self, other):
+ """Is this VectorClockRev strictly newer than other.
+ """
+ if not self._values:
+ return False
+ if not other._values:
+ return True
+ this_is_newer = False
+ other_expand = dict(other._values)
+ for key, value in self._values.iteritems():
+ if key in other_expand:
+ other_value = other_expand.pop(key)
+ if other_value > value:
+ return False
+ elif other_value < value:
+ this_is_newer = True
+ else:
+ this_is_newer = True
+ if other_expand:
+ return False
+ return this_is_newer
+
+ def increment(self, replica_uid):
+ """Increase the 'replica_uid' section of this vector clock.
+
+ :return: A string representing the new vector clock value
+ """
+ self._values[replica_uid] = self._values.get(replica_uid, 0) + 1
+
+ def maximize(self, other_vcr):
+ for replica_uid, counter in other_vcr._values.iteritems():
+ if replica_uid not in self._values:
+ self._values[replica_uid] = counter
+ else:
+ this_counter = self._values[replica_uid]
+ if this_counter < counter:
+ self._values[replica_uid] = counter
diff --git a/src/leap/testing/__init__.py b/src/leap/testing/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/testing/__init__.py
diff --git a/src/leap/testing/basetest.py b/src/leap/testing/basetest.py
new file mode 100644
index 00000000..3186e1eb
--- /dev/null
+++ b/src/leap/testing/basetest.py
@@ -0,0 +1,85 @@
+import os
+import platform
+import shutil
+import tempfile
+
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+from leap.base.config import get_username, get_groupname
+from leap.util.fileutil import mkdir_p, check_and_fix_urw_only
+
+_system = platform.system()
+
+
+class BaseLeapTest(unittest.TestCase):
+
+ __name__ = "leap_test"
+
+ @classmethod
+ def setUpClass(cls):
+ cls.old_path = os.environ['PATH']
+ cls.old_home = os.environ['HOME']
+ cls.tempdir = tempfile.mkdtemp(prefix="leap_tests-")
+ cls.home = cls.tempdir
+ bin_tdir = os.path.join(
+ cls.tempdir,
+ 'bin')
+ os.environ["PATH"] = bin_tdir
+ os.environ["HOME"] = cls.tempdir
+
+ @classmethod
+ def tearDownClass(cls):
+ os.environ["PATH"] = cls.old_path
+ os.environ["HOME"] = cls.old_home
+ # safety check
+ assert cls.tempdir.startswith('/tmp/leap_tests-')
+ shutil.rmtree(cls.tempdir)
+
+ # you have to override these methods
+ # this way we ensure we did not put anything
+ # here that you can forget to call.
+
+ def setUp(self):
+ raise NotImplementedError("abstract base class")
+
+ def tearDown(self):
+ raise NotImplementedError("abstract base class")
+
+ #
+ # helper methods
+ #
+
+ def get_tempfile(self, filename):
+ return os.path.join(self.tempdir, filename)
+
+ def get_username(self):
+ return get_username()
+
+ def get_groupname(self):
+ return get_groupname()
+
+ def _missing_test_for_plat(self, do_raise=False):
+ if do_raise:
+ raise NotImplementedError(
+ "This test is not implemented "
+ "for the running platform: %s" %
+ _system)
+
+ def touch(self, filepath):
+ folder, filename = os.path.split(filepath)
+ if not os.path.isdir(folder):
+ mkdir_p(folder)
+ # XXX should move to test_basetest
+ self.assertTrue(os.path.isdir(folder))
+
+ with open(filepath, 'w') as fp:
+ fp.write(' ')
+
+ # XXX should move to test_basetest
+ self.assertTrue(os.path.isfile(filepath))
+
+ def chmod600(self, filepath):
+ check_and_fix_urw_only(filepath)
diff --git a/src/leap/testing/cacert.pem b/src/leap/testing/cacert.pem
new file mode 100644
index 00000000..6989c480
--- /dev/null
+++ b/src/leap/testing/cacert.pem
@@ -0,0 +1,23 @@
+-----BEGIN CERTIFICATE-----
+MIID1TCCAr2gAwIBAgIJAOv0BS09D8byMA0GCSqGSIb3DQEBBQUAMIGAMQswCQYD
+VQQGEwJVUzETMBEGA1UECAwKY3liZXJzcGFjZTEnMCUGA1UECgweTEVBUCBFbmNy
+eXB0aW9uIEFjY2VzcyBQcm9qZWN0MRYwFAYDVQQDDA10ZXN0cy1sZWFwLnNlMRsw
+GQYJKoZIhvcNAQkBFgxpbmZvQGxlYXAuc2UwHhcNMTIwODMxMTYyNjMwWhcNMTUw
+ODMxMTYyNjMwWjCBgDELMAkGA1UEBhMCVVMxEzARBgNVBAgMCmN5YmVyc3BhY2Ux
+JzAlBgNVBAoMHkxFQVAgRW5jcnlwdGlvbiBBY2Nlc3MgUHJvamVjdDEWMBQGA1UE
+AwwNdGVzdHMtbGVhcC5zZTEbMBkGCSqGSIb3DQEJARYMaW5mb0BsZWFwLnNlMIIB
+IjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1pU7OU+abrUXFZwp6X0LlF0f
+xQvC1Nmr5sFH7N9RTu3bdwY2t57ECP2TPkH6+x7oOvCTgAMxIE1scWEEkfgKViqW
+FH/Om1UW1PMaiDYGtFuqEuxM95FvaYxp2K6rzA37WNsedA28sCYzhRD+/5HqbCNT
+3rRS2cPaVO8kXI/5bgd8bUk3009pWTg4SvTtOW/9MWJbBH5f5JWmMn7Ayt6hIdT/
+E6npofEK/UCqAlEscARYFXSB/F8nK1whjo9mGFjMUd7d/25UbFHqOk4K7ishD4DH
+F7LaS84rS+Sjwn3YtDdDQblGghJfz8X1AfPSGivGnvLVdkmMF9Y2hJlSQ7+C5wID
+AQABo1AwTjAdBgNVHQ4EFgQUnpJEv4FnlqKbfm7mprudKdrnOAowHwYDVR0jBBgw
+FoAUnpJEv4FnlqKbfm7mprudKdrnOAowDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0B
+AQUFAAOCAQEAGW66qwdK/ATRVZkTpI2sgi+2dWD5tY4VyZuJIrRwfXsGPeVvmdsa
+zDmwW5dMkth1Of5yO6o7ijvUvfnw/UCLNLNICKZhH5G0DHstfBeFc0jnP2MqOZCp
+puRGPBlO2nxUCvoGcPRUKGQK9XSYmxcmaSFyzKVDMLnmH+Lakj5vaY9a8ZAcZTz7
+T5qePxKAxg+RIlH8Ftc485QP3fhqPYPrRsL3g6peiqCvIRshoP1MSoh19boI+1uX
+wHQ/NyDkL5ErKC5JCSpaeF8VG1ek570kKWQLuQAbnlXZw+Sqfu35CIdizHaYGEcx
+xA8oXH4L2JaT2x9GKDSpCmB2xXy/NVamUg==
+-----END CERTIFICATE-----
diff --git a/src/leap/testing/https_server.py b/src/leap/testing/https_server.py
new file mode 100644
index 00000000..21191c32
--- /dev/null
+++ b/src/leap/testing/https_server.py
@@ -0,0 +1,68 @@
+from BaseHTTPServer import HTTPServer
+import os
+import ssl
+import SocketServer
+import threading
+import unittest
+
+_where = os.path.split(__file__)[0]
+
+
+def where(filename):
+ return os.path.join(_where, filename)
+
+
+class HTTPSServer(HTTPServer):
+ def server_bind(self):
+ SocketServer.TCPServer.server_bind(self)
+ self.socket = ssl.wrap_socket(
+ self.socket, server_side=True,
+ certfile=where("leaptestscert.pem"),
+ keyfile=where("leaptestskey.pem"),
+ ca_certs=where("cacert.pem"),
+ ssl_version=ssl.PROTOCOL_SSLv23)
+
+
+class TestServerThread(threading.Thread):
+ def __init__(self, test_object, request_handler):
+ threading.Thread.__init__(self)
+ self.request_handler = request_handler
+ self.test_object = test_object
+
+ def run(self):
+ self.server = HTTPSServer(('localhost', 0), self.request_handler)
+ host, port = self.server.socket.getsockname()
+ self.test_object.HOST, self.test_object.PORT = host, port
+ self.test_object.server_started.set()
+ self.test_object = None
+ try:
+ self.server.serve_forever(0.05)
+ finally:
+ self.server.server_close()
+
+ def stop(self):
+ self.server.shutdown()
+
+
+class BaseHTTPSServerTestCase(unittest.TestCase):
+ """
+ derived classes need to implement a request_handler
+ """
+ def setUp(self):
+ self.server_started = threading.Event()
+ self.thread = TestServerThread(self, self.request_handler)
+ self.thread.start()
+ self.server_started.wait()
+
+ def tearDown(self):
+ self.thread.stop()
+
+ def get_server(self):
+ host, port = self.HOST, self.PORT
+ if host == "127.0.0.1":
+ host = "localhost"
+ return "%s:%s" % (host, port)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/testing/leaptestscert.pem b/src/leap/testing/leaptestscert.pem
new file mode 100644
index 00000000..65596b1a
--- /dev/null
+++ b/src/leap/testing/leaptestscert.pem
@@ -0,0 +1,84 @@
+Certificate:
+ Data:
+ Version: 3 (0x2)
+ Serial Number:
+ eb:f4:05:2d:3d:0f:c6:f3
+ Signature Algorithm: sha1WithRSAEncryption
+ Issuer: C=US, ST=cyberspace, O=LEAP Encryption Access Project, CN=tests-leap.se/emailAddress=info@leap.se
+ Validity
+ Not Before: Aug 31 16:30:17 2012 GMT
+ Not After : Aug 31 16:30:17 2013 GMT
+ Subject: C=US, ST=cyberspace, L=net, O=LEAP Encryption Access Project, CN=localhost/emailAddress=info@leap.se
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (2048 bit)
+ Modulus:
+ 00:bc:f1:c4:05:ce:4b:d5:9b:9a:fa:c1:a5:0c:89:
+ 15:7e:05:69:b6:a4:62:38:3a:d6:14:4a:36:aa:3c:
+ 31:70:54:2e:bf:7d:05:19:ad:7b:0c:a9:a6:7d:46:
+ be:83:62:cb:ea:b9:48:6c:7d:78:a0:10:0b:ad:8a:
+ 74:7a:b8:ff:32:85:64:36:90:dc:38:dd:90:6e:07:
+ 82:70:ae:5f:4e:1f:f4:46:98:f3:98:b4:fa:08:65:
+ bf:d6:ec:a9:ba:7e:a8:f0:40:a2:d0:1a:cb:e6:fc:
+ 95:c5:54:63:92:5b:b8:0a:36:cc:26:d3:2b:ad:16:
+ ff:49:53:f4:65:7c:64:27:9a:f5:12:75:11:a5:0c:
+ 5a:ea:1e:e4:31:f3:a6:2b:db:0e:4a:5d:aa:47:3a:
+ f0:5e:2a:d5:6f:74:b6:f8:bc:9a:73:d0:fa:8a:be:
+ a8:69:47:9b:07:45:d9:b5:cd:1c:9b:c5:41:9a:65:
+ cc:99:a0:bd:bf:b5:e8:9f:66:5f:69:c9:6d:c8:68:
+ 50:68:74:ae:8e:12:7e:9c:24:4f:dc:05:61:b7:8a:
+ 6d:2a:95:43:d9:3f:fe:d8:c9:a7:ae:63:cd:30:d5:
+ 95:84:18:2d:12:b5:2d:a6:fe:37:dd:74:b8:f8:a5:
+ 59:18:8f:ca:f7:ae:63:0d:9d:66:51:7d:9c:40:48:
+ 9b:a1
+ Exponent: 65537 (0x10001)
+ X509v3 extensions:
+ X509v3 Basic Constraints:
+ CA:FALSE
+ Netscape Comment:
+ OpenSSL Generated Certificate
+ X509v3 Subject Key Identifier:
+ B2:50:B4:C6:38:8F:BA:C4:3B:69:4C:6B:45:7C:CF:08:48:36:02:E0
+ X509v3 Authority Key Identifier:
+ keyid:9E:92:44:BF:81:67:96:A2:9B:7E:6E:E6:A6:BB:9D:29:DA:E7:38:0A
+
+ Signature Algorithm: sha1WithRSAEncryption
+ aa:ab:d4:27:e3:cb:42:05:55:fd:24:b3:e5:55:7d:fb:ce:6c:
+ ff:c7:96:f0:7d:30:a1:53:4a:04:eb:a4:24:5e:96:ee:65:ef:
+ e5:aa:08:47:9d:aa:95:2a:bb:6a:28:9f:51:62:63:d9:7d:1a:
+ 81:a0:72:f7:9f:33:6b:3b:f4:dc:85:cd:2a:ee:83:a9:93:3d:
+ 75:53:91:fa:0b:1b:10:83:11:2c:03:4e:ac:bf:c3:e6:25:74:
+ 9f:14:13:4a:43:66:c2:d7:1c:6c:94:3e:a6:f3:a5:bd:01:2c:
+ 9f:20:29:2e:62:82:12:d8:8b:70:1b:88:2b:18:68:5a:45:80:
+ 46:2a:6a:d5:df:1f:d3:e8:57:39:0a:be:1a:d8:b0:3e:e5:b6:
+ c3:69:b7:5e:c0:7b:b3:a8:a6:78:ee:0a:3d:a0:74:40:fb:42:
+ 9f:f4:98:7f:47:cc:15:28:eb:b1:95:77:82:a8:65:9b:46:c3:
+ 4f:f9:f4:72:be:bd:24:28:5c:0d:b3:89:e4:13:71:c8:a7:54:
+ 1b:26:15:f3:c1:b2:a9:13:77:54:c2:b9:b0:c7:24:39:00:4c:
+ 1a:a7:9b:e7:ad:4a:3a:32:c2:81:0d:13:2d:27:ea:98:00:a9:
+ 0e:9e:38:3b:8f:80:34:17:17:3d:49:7e:f4:a5:19:05:28:08:
+ 7d:de:d3:1f
+-----BEGIN CERTIFICATE-----
+MIIECjCCAvKgAwIBAgIJAOv0BS09D8bzMA0GCSqGSIb3DQEBBQUAMIGAMQswCQYD
+VQQGEwJVUzETMBEGA1UECAwKY3liZXJzcGFjZTEnMCUGA1UECgweTEVBUCBFbmNy
+eXB0aW9uIEFjY2VzcyBQcm9qZWN0MRYwFAYDVQQDDA10ZXN0cy1sZWFwLnNlMRsw
+GQYJKoZIhvcNAQkBFgxpbmZvQGxlYXAuc2UwHhcNMTIwODMxMTYzMDE3WhcNMTMw
+ODMxMTYzMDE3WjCBijELMAkGA1UEBhMCVVMxEzARBgNVBAgMCmN5YmVyc3BhY2Ux
+DDAKBgNVBAcMA25ldDEnMCUGA1UECgweTEVBUCBFbmNyeXB0aW9uIEFjY2VzcyBQ
+cm9qZWN0MRIwEAYDVQQDDAlsb2NhbGhvc3QxGzAZBgkqhkiG9w0BCQEWDGluZm9A
+bGVhcC5zZTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALzxxAXOS9Wb
+mvrBpQyJFX4FabakYjg61hRKNqo8MXBULr99BRmtewyppn1GvoNiy+q5SGx9eKAQ
+C62KdHq4/zKFZDaQ3DjdkG4HgnCuX04f9EaY85i0+ghlv9bsqbp+qPBAotAay+b8
+lcVUY5JbuAo2zCbTK60W/0lT9GV8ZCea9RJ1EaUMWuoe5DHzpivbDkpdqkc68F4q
+1W90tvi8mnPQ+oq+qGlHmwdF2bXNHJvFQZplzJmgvb+16J9mX2nJbchoUGh0ro4S
+fpwkT9wFYbeKbSqVQ9k//tjJp65jzTDVlYQYLRK1Lab+N910uPilWRiPyveuYw2d
+ZlF9nEBIm6ECAwEAAaN7MHkwCQYDVR0TBAIwADAsBglghkgBhvhCAQ0EHxYdT3Bl
+blNTTCBHZW5lcmF0ZWQgQ2VydGlmaWNhdGUwHQYDVR0OBBYEFLJQtMY4j7rEO2lM
+a0V8zwhINgLgMB8GA1UdIwQYMBaAFJ6SRL+BZ5aim35u5qa7nSna5zgKMA0GCSqG
+SIb3DQEBBQUAA4IBAQCqq9Qn48tCBVX9JLPlVX37zmz/x5bwfTChU0oE66QkXpbu
+Ze/lqghHnaqVKrtqKJ9RYmPZfRqBoHL3nzNrO/Tchc0q7oOpkz11U5H6CxsQgxEs
+A06sv8PmJXSfFBNKQ2bC1xxslD6m86W9ASyfICkuYoIS2ItwG4grGGhaRYBGKmrV
+3x/T6Fc5Cr4a2LA+5bbDabdewHuzqKZ47go9oHRA+0Kf9Jh/R8wVKOuxlXeCqGWb
+RsNP+fRyvr0kKFwNs4nkE3HIp1QbJhXzwbKpE3dUwrmwxyQ5AEwap5vnrUo6MsKB
+DRMtJ+qYAKkOnjg7j4A0Fxc9SX70pRkFKAh93tMf
+-----END CERTIFICATE-----
diff --git a/src/leap/testing/leaptestskey.pem b/src/leap/testing/leaptestskey.pem
new file mode 100644
index 00000000..fe6291a1
--- /dev/null
+++ b/src/leap/testing/leaptestskey.pem
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpQIBAAKCAQEAvPHEBc5L1Zua+sGlDIkVfgVptqRiODrWFEo2qjwxcFQuv30F
+Ga17DKmmfUa+g2LL6rlIbH14oBALrYp0erj/MoVkNpDcON2QbgeCcK5fTh/0Rpjz
+mLT6CGW/1uypun6o8ECi0BrL5vyVxVRjklu4CjbMJtMrrRb/SVP0ZXxkJ5r1EnUR
+pQxa6h7kMfOmK9sOSl2qRzrwXirVb3S2+Lyac9D6ir6oaUebB0XZtc0cm8VBmmXM
+maC9v7Xon2ZfacltyGhQaHSujhJ+nCRP3AVht4ptKpVD2T/+2MmnrmPNMNWVhBgt
+ErUtpv433XS4+KVZGI/K965jDZ1mUX2cQEiboQIDAQABAoIBAQCh/+yhSbrtoCgm
+PegEsnix/3QfPBxWt+Obq/HozglZlWQrnMbFuF+bgM4V9ZUdU5UhYNF+66mEG53X
+orGyE3IDYCmHO3cGbroKDPhDIs7mTjGEYlniIbGLh6oPXgU8uKKis9ik84TGPOUx
+NuTUtT07zLYHx+FX3DLwLUKLzTaWWSRgA7nxNwCY8aPqDxCkXEyZHvSlm9KYZnhe
+nVevycoHR+chxL6X/ebbBt2FKR7tl4328mlDXvMXr0vahPH94CuXEvfTj+f6ZxZF
+OctdikyRfd8O3ebrUw0XjafPYyTsDMH0/rQovEBVlecEHqh6Z9dBFlogRq5DSun9
+jem4bBXRAoGBAPGPi4g21pTQPqTFxpqea8TsPqIfo3csfMDPdzT246MxzALHqCfG
+yZi4g2JYJrReSWHulZDORO5skSKNEb5VTA/3xFhKLt8CULZOakKBDLkzRXlnDFXg
+Jsu9vtjDWjQcJsdsRx1tc5V6s+hmel70aaUu/maUlEYZnyIXaTe+1SB1AoGBAMg9
+EMEO5YN52pOI5qPH8j7uyVKtZWKRiR6jb5KA5TxWqZalSdPV6YwDqV/e+HjWrZNw
+kSEFONY0seKpIHwXchx91aym7rDHUgOoBQfCWufRMYvRXLhfOTBu4X+U52++i8wt
+FvKgh6eSmc7VayAaDfHp7yfrIfS03IiN0T35mGj9AoGAPCoXg7a83VW8tId5/trE
+VsjMlM6yhSU0cUV7GFsBuYzWlj6qODX/0iTqvFzeTwBI4LZu1CE78/Jgd62RJMnT
+5wo8Ag1//RVziuSe/K9tvtbxT9qFrQHmR8qbtRt65Q257uOeFstDBZEJLDIR+oJ/
+qZ+5x0zsXUVWaERSdYr3RF0CgYEApKDgN3oB5Ti4Jnh1984aMver+heptYKmU9RX
+lQH4dsVhpQO8UTgcTgtso+/0JZWLHB9+ksFyW1rzrcETfjLglOA4XzzYHeuiWHM5
+v4lhqBpsO+Ij80oHAPUI3RYVud/VnEauCUlGftWfM1hwPPJu6KhHAnDleAWDE5pV
+oDinwBkCgYEAnn/OceaqA2fNYp1IRegbFzpewjUlHLq3bXiCIVhO7W/HqsdfUxjE
+VVdjEno/pAG7ZCO5j8u+rLkG2ZIVY3qsUENUiXz52Q08qEltgM8nfirK7vIQkfd9
+YISRE3QHYJd+ArY4v+7rNeF1O5eIEyzPAbvG5raeZFcZ6POxy66uWKo=
+-----END RSA PRIVATE KEY-----
diff --git a/src/leap/testing/test_basetest.py b/src/leap/testing/test_basetest.py
new file mode 100644
index 00000000..14d8f8a3
--- /dev/null
+++ b/src/leap/testing/test_basetest.py
@@ -0,0 +1,91 @@
+"""becase it's oh so meta"""
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+import os
+import StringIO
+
+from leap.testing.basetest import BaseLeapTest
+
+# global for tempdir checking
+_tempdir = None
+
+
+class _TestCaseRunner(object):
+ def run_testcase(self, testcase=None):
+ if not testcase:
+ return None
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromTestCase(testcase)
+
+ # Create runner, and run testcase
+ io = StringIO.StringIO()
+ runner = unittest.TextTestRunner(stream=io)
+ results = runner.run(suite)
+ return results
+
+
+class TestAbstractBaseLeapTest(unittest.TestCase, _TestCaseRunner):
+
+ def test_abstract_base_class(self):
+ class _BaseTest(BaseLeapTest):
+ def test_dummy_method(self):
+ pass
+
+ def test_tautology(self):
+ assert True
+
+ results = self.run_testcase(_BaseTest)
+
+ # should be 2 errors: NotImplemented
+ # raised for setUp/tearDown
+ self.assertEquals(results.testsRun, 2)
+ self.assertEquals(len(results.failures), 0)
+ self.assertEquals(len(results.errors), 2)
+
+
+class TestInitBaseLeapTest(BaseLeapTest):
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def test_path_is_changed(self):
+ os_path = os.environ['PATH']
+ self.assertTrue(os_path.startswith(self.tempdir))
+
+ def test_old_path_is_saved(self):
+ self.assertTrue(len(self.old_path) > 1)
+
+
+class TestCleanedBaseLeapTest(unittest.TestCase, _TestCaseRunner):
+
+ def test_tempdir_is_cleaned_after_tests(self):
+ class _BaseTest(BaseLeapTest):
+ def setUp(self):
+ global _tempdir
+ _tempdir = self.tempdir
+
+ def tearDown(self):
+ pass
+
+ def test_tempdir_created(self):
+ self.assertTrue(os.path.isdir(self.tempdir))
+
+ def test_tempdir_created_on_setupclass(self):
+ self.assertEqual(_tempdir, self.tempdir)
+
+ results = self.run_testcase(_BaseTest)
+ self.assertEquals(results.testsRun, 2)
+ self.assertEquals(len(results.failures), 0)
+ self.assertEquals(len(results.errors), 0)
+
+ # did we cleaned the tempdir?
+ self.assertFalse(os.path.isdir(_tempdir))
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/tests/fakeclient.py b/src/leap/tests/fakeclient.py
deleted file mode 100644
index 45de2cd6..00000000
--- a/src/leap/tests/fakeclient.py
+++ /dev/null
@@ -1,63 +0,0 @@
-fakeoutput = """
-mullvad Sun Jun 17 14:34:57 2012 OpenVPN 2.2.1 i486-linux-gnu [SSL] [LZO2] [EPOLL] [PKCS11] [eurephia] [MH] [PF_INET6] [IPv6 payload 20110424-2 (2.2RC2)] built
- on Mar 23 2012
-Sun Jun 17 14:34:57 2012 MANAGEMENT: TCP Socket listening on [AF_INET]127.0.0.1:7505
-Sun Jun 17 14:34:57 2012 NOTE: the current --script-security setting may allow this configuration to call user-defined scripts
-Sun Jun 17 14:34:57 2012 WARNING: file 'ssl/1021380964266.key' is group or others accessible
-Sun Jun 17 14:34:57 2012 LZO compression initialized
-Sun Jun 17 14:34:57 2012 Control Channel MTU parms [ L:1542 D:138 EF:38 EB:0 ET:0 EL:0 ]
-Sun Jun 17 14:34:57 2012 Socket Buffers: R=[163840->131072] S=[163840->131072]
-Sun Jun 17 14:34:57 2012 Data Channel MTU parms [ L:1542 D:1450 EF:42 EB:135 ET:0 EL:0 AF:3/1 ]
-Sun Jun 17 14:34:57 2012 Local Options hash (VER=V4): '41690919'
-Sun Jun 17 14:34:57 2012 Expected Remote Options hash (VER=V4): '530fdded'
-Sun Jun 17 14:34:57 2012 UDPv4 link local: [undef]
-Sun Jun 17 14:34:57 2012 UDPv4 link remote: [AF_INET]46.21.99.25:1197
-Sun Jun 17 14:34:57 2012 TLS: Initial packet from [AF_INET]46.21.99.25:1197, sid=63c29ace 1d3060d0
-Sun Jun 17 14:34:58 2012 VERIFY OK: depth=2, /C=NA/ST=None/L=None/O=Mullvad/CN=Mullvad_CA/emailAddress=info@mullvad.net
-Sun Jun 17 14:34:58 2012 VERIFY OK: depth=1, /C=NA/ST=None/L=None/O=Mullvad/CN=master.mullvad.net/emailAddress=info@mullvad.net
-Sun Jun 17 14:34:58 2012 Validating certificate key usage
-Sun Jun 17 14:34:58 2012 ++ Certificate has key usage 00a0, expects 00a0
-Sun Jun 17 14:34:58 2012 VERIFY KU OK
-Sun Jun 17 14:34:58 2012 Validating certificate extended key usage
-Sun Jun 17 14:34:58 2012 ++ Certificate has EKU (str) TLS Web Server Authentication, expects TLS Web Server Authentication
-Sun Jun 17 14:34:58 2012 VERIFY EKU OK
-Sun Jun 17 14:34:58 2012 VERIFY OK: depth=0, /C=NA/ST=None/L=None/O=Mullvad/CN=se2.mullvad.net/emailAddress=info@mullvad.net
-Sun Jun 17 14:34:59 2012 Data Channel Encrypt: Cipher 'BF-CBC' initialized with 128 bit key
-Sun Jun 17 14:34:59 2012 Data Channel Encrypt: Using 160 bit message hash 'SHA1' for HMAC authentication
-Sun Jun 17 14:34:59 2012 Data Channel Decrypt: Cipher 'BF-CBC' initialized with 128 bit key
-Sun Jun 17 14:34:59 2012 Data Channel Decrypt: Using 160 bit message hash 'SHA1' for HMAC authentication
-Sun Jun 17 14:34:59 2012 Control Channel: TLSv1, cipher TLSv1/SSLv3 DHE-RSA-AES256-SHA, 2048 bit RSA
-Sun Jun 17 14:34:59 2012 [se2.mullvad.net] Peer Connection Initiated with [AF_INET]46.21.99.25:1197
-Sun Jun 17 14:35:01 2012 SENT CONTROL [se2.mullvad.net]: 'PUSH_REQUEST' (status=1)
-Sun Jun 17 14:35:02 2012 PUSH: Received control message: 'PUSH_REPLY,redirect-gateway def1 bypass-dhcp,dhcp-option DNS 10.11.0.1,route 10.11.0.1,topology net30,ifconfig 10.11.0.202 10.11.0.201'
-Sun Jun 17 14:35:02 2012 OPTIONS IMPORT: --ifconfig/up options modified
-Sun Jun 17 14:35:02 2012 OPTIONS IMPORT: route options modified
-Sun Jun 17 14:35:02 2012 OPTIONS IMPORT: --ip-win32 and/or --dhcp-option options modified
-Sun Jun 17 14:35:02 2012 ROUTE default_gateway=192.168.0.1
-Sun Jun 17 14:35:02 2012 TUN/TAP device tun0 opened
-Sun Jun 17 14:35:02 2012 TUN/TAP TX queue length set to 100
-Sun Jun 17 14:35:02 2012 do_ifconfig, tt->ipv6=0, tt->did_ifconfig_ipv6_setup=0
-Sun Jun 17 14:35:02 2012 /sbin/ifconfig tun0 10.11.0.202 pointopoint 10.11.0.201 mtu 1500
-Sun Jun 17 14:35:02 2012 /etc/openvpn/update-resolv-conf tun0 1500 1542 10.11.0.202 10.11.0.201 init
-dhcp-option DNS 10.11.0.1
-Sun Jun 17 14:35:05 2012 /sbin/route add -net 46.21.99.25 netmask 255.255.255.255 gw 192.168.0.1
-Sun Jun 17 14:35:05 2012 /sbin/route add -net 0.0.0.0 netmask 128.0.0.0 gw 10.11.0.201
-Sun Jun 17 14:35:05 2012 /sbin/route add -net 128.0.0.0 netmask 128.0.0.0 gw 10.11.0.201
-Sun Jun 17 14:35:05 2012 /sbin/route add -net 10.11.0.1 netmask 255.255.255.255 gw 10.11.0.201
-Sun Jun 17 14:35:05 2012 Initialization Sequence Completed
-Sun Jun 17 14:34:57 2012 MANAGEMENT: TCP Socket listening on [AF_INET]127.0.0.1:7505
-"""
-
-import time
-import sys
-
-
-def write_output():
- for line in fakeoutput.split('\n'):
- sys.stdout.write(line + '\n')
- sys.stdout.flush()
- #print(line)
- time.sleep(0.1)
-
-if __name__ == "__main__":
- write_output()
diff --git a/src/leap/tests/mocks/__init__.py b/src/leap/tests/mocks/__init__.py
deleted file mode 100644
index 06f96870..00000000
--- a/src/leap/tests/mocks/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-import manager
diff --git a/src/leap/tests/mocks/manager.py b/src/leap/tests/mocks/manager.py
deleted file mode 100644
index 564631cd..00000000
--- a/src/leap/tests/mocks/manager.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from mock import Mock
-
-from eip_client.vpnmanager import OpenVPNManager
-
-vpn_commands = {
- 'status': [
- 'OpenVPN STATISTICS', 'Updated,Mon Jun 25 11:51:21 2012',
- 'TUN/TAP read bytes,306170', 'TUN/TAP write bytes,872102',
- 'TCP/UDP read bytes,986177', 'TCP/UDP write bytes,439329',
- 'Auth read bytes,872102'],
- 'state': ['1340616463,CONNECTED,SUCCESS,172.28.0.2,198.252.153.38'],
- # XXX add more tests
- }
-
-
-def get_openvpn_manager_mocks():
- manager = OpenVPNManager()
- manager.status = Mock(return_value='\n'.join(vpn_commands['status']))
- manager.state = Mock(return_value=vpn_commands['state'][0])
- return manager
diff --git a/src/leap/util/__init__.py b/src/leap/util/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/util/__init__.py
diff --git a/src/leap/utils/coroutines.py b/src/leap/util/coroutines.py
index 5e25eb63..0657fc04 100644
--- a/src/leap/utils/coroutines.py
+++ b/src/leap/util/coroutines.py
@@ -4,10 +4,13 @@
from __future__ import division, print_function
+import logging
from subprocess import PIPE, Popen
import sys
from threading import Thread
+logger = logging.getLogger(__name__)
+
ON_POSIX = 'posix' in sys.builtin_module_names
@@ -38,8 +41,7 @@ for each event
if callable(callback):
callback(m)
else:
- #XXX log instead
- print('not a callable passed')
+ logger.debug('not a callable passed')
except GeneratorExit:
return
@@ -72,10 +74,10 @@ def watch_output(out, observers):
:type out: fd
:param observers: tuple of coroutines to send data\
for each event
- :type ovservers: tuple
+ :type observers: tuple
"""
- observer_dict = {observer: process_events(observer)
- for observer in observers}
+ observer_dict = dict(((observer, process_events(observer))
+ for observer in observers))
for line in iter(out.readline, b''):
for obs in observer_dict:
observer_dict[obs].send(line)
diff --git a/src/leap/util/dicts.py b/src/leap/util/dicts.py
new file mode 100644
index 00000000..001ca96b
--- /dev/null
+++ b/src/leap/util/dicts.py
@@ -0,0 +1,268 @@
+# Backport of OrderedDict() class that runs
+# on Python 2.4, 2.5, 2.6, 2.7 and pypy.
+# Passes Python2.7's test suite and incorporates all the latest updates.
+
+try:
+ from thread import get_ident as _get_ident
+except ImportError:
+ from dummy_thread import get_ident as _get_ident
+
+try:
+ from _abcoll import KeysView, ValuesView, ItemsView
+except ImportError:
+ pass
+
+
+class OrderedDict(dict):
+ 'Dictionary that remembers insertion order'
+ # An inherited dict maps keys to values.
+ # The inherited dict provides __getitem__, __len__, __contains__, and get.
+ # The remaining methods are order-aware.
+ # Big-O running times for all methods are the same as for regular
+ # dictionaries.
+
+ # The internal self.__map dictionary maps keys to links in a doubly
+ # linked list.
+ # The circular doubly linked list starts and ends with a sentinel element.
+ # The sentinel element never gets deleted (this simplifies the algorithm).
+ # Each link is stored as a list of length three: [PREV, NEXT, KEY].
+
+ def __init__(self, *args, **kwds):
+ '''Initialize an ordered dictionary. Signature is the same as for
+ regular dictionaries, but keyword arguments are not recommended
+ because their insertion order is arbitrary.
+
+ '''
+ if len(args) > 1:
+ raise TypeError('expected at most 1 arguments, got %d' % len(args))
+ try:
+ self.__root
+ except AttributeError:
+ self.__root = root = [] # sentinel node
+ root[:] = [root, root, None]
+ self.__map = {}
+ self.__update(*args, **kwds)
+
+ def __setitem__(self, key, value, dict_setitem=dict.__setitem__):
+ 'od.__setitem__(i, y) <==> od[i]=y'
+ # Setting a new item creates a new link which goes at the end
+ # of the linked list, and the inherited dictionary is updated
+ # with the new key/value pair.
+ if key not in self:
+ root = self.__root
+ last = root[0]
+ last[1] = root[0] = self.__map[key] = [last, root, key]
+ dict_setitem(self, key, value)
+
+ def __delitem__(self, key, dict_delitem=dict.__delitem__):
+ 'od.__delitem__(y) <==> del od[y]'
+ # Deleting an existing item uses self.__map to find the link which is
+ # then removed by updating the links in the predecessor and successor
+ # nodes.
+ dict_delitem(self, key)
+ link_prev, link_next, key = self.__map.pop(key)
+ link_prev[1] = link_next
+ link_next[0] = link_prev
+
+ def __iter__(self):
+ 'od.__iter__() <==> iter(od)'
+ root = self.__root
+ curr = root[1]
+ while curr is not root:
+ yield curr[2]
+ curr = curr[1]
+
+ def __reversed__(self):
+ 'od.__reversed__() <==> reversed(od)'
+ root = self.__root
+ curr = root[0]
+ while curr is not root:
+ yield curr[2]
+ curr = curr[0]
+
+ def clear(self):
+ 'od.clear() -> None. Remove all items from od.'
+ try:
+ for node in self.__map.itervalues():
+ del node[:]
+ root = self.__root
+ root[:] = [root, root, None]
+ self.__map.clear()
+ except AttributeError:
+ pass
+ dict.clear(self)
+
+ def popitem(self, last=True):
+ '''od.popitem() -> (k, v), return and remove a (key, value) pair.
+ Pairs are returned in LIFO order if last is true or FIFO order if
+ false.
+ '''
+ if not self:
+ raise KeyError('dictionary is empty')
+ root = self.__root
+ if last:
+ link = root[0]
+ link_prev = link[0]
+ link_prev[1] = root
+ root[0] = link_prev
+ else:
+ link = root[1]
+ link_next = link[1]
+ root[1] = link_next
+ link_next[0] = root
+ key = link[2]
+ del self.__map[key]
+ value = dict.pop(self, key)
+ return key, value
+
+ # -- the following methods do not depend on the internal structure --
+
+ def keys(self):
+ 'od.keys() -> list of keys in od'
+ return list(self)
+
+ def values(self):
+ 'od.values() -> list of values in od'
+ return [self[key] for key in self]
+
+ def items(self):
+ 'od.items() -> list of (key, value) pairs in od'
+ return [(key, self[key]) for key in self]
+
+ def iterkeys(self):
+ 'od.iterkeys() -> an iterator over the keys in od'
+ return iter(self)
+
+ def itervalues(self):
+ 'od.itervalues -> an iterator over the values in od'
+ for k in self:
+ yield self[k]
+
+ def iteritems(self):
+ 'od.iteritems -> an iterator over the (key, value) items in od'
+ for k in self:
+ yield (k, self[k])
+
+ def update(*args, **kwds):
+ '''od.update(E, **F) -> None. Update od from dict/iterable E and F.
+
+ If E is a dict instance, does: for k in E: od[k] = E[k]
+ If E has a .keys() method, does: for k in E.keys():
+ od[k] = E[k]
+ Or if E is an iterable of items, does: for k, v in E: od[k] = v
+ In either case, this is followed by: for k, v in F.items():
+ od[k] = v
+ '''
+
+ if len(args) > 2:
+ raise TypeError('update() takes at most 2 positional '
+ 'arguments (%d given)' % (len(args),))
+ elif not args:
+ raise TypeError('update() takes at least 1 argument (0 given)')
+ self = args[0]
+ # Make progressively weaker assumptions about "other"
+ other = ()
+ if len(args) == 2:
+ other = args[1]
+ if isinstance(other, dict):
+ for key in other:
+ self[key] = other[key]
+ elif hasattr(other, 'keys'):
+ for key in other.keys():
+ self[key] = other[key]
+ else:
+ for key, value in other:
+ self[key] = value
+ for key, value in kwds.items():
+ self[key] = value
+
+ __update = update # let subclasses override update
+ # without breaking __init__
+
+ __marker = object()
+
+ def pop(self, key, default=__marker):
+ '''od.pop(k[,d]) -> v
+ remove specified key and return the corresponding value.
+ If key is not found, d is returned if given,
+ otherwise KeyError is raised.
+
+ '''
+ if key in self:
+ result = self[key]
+ del self[key]
+ return result
+ if default is self.__marker:
+ raise KeyError(key)
+ return default
+
+ def setdefault(self, key, default=None):
+ 'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od'
+ if key in self:
+ return self[key]
+ self[key] = default
+ return default
+
+ def __repr__(self, _repr_running={}):
+ 'od.__repr__() <==> repr(od)'
+ call_key = id(self), _get_ident()
+ if call_key in _repr_running:
+ return '...'
+ _repr_running[call_key] = 1
+ try:
+ if not self:
+ return '%s()' % (self.__class__.__name__,)
+ return '%s(%r)' % (self.__class__.__name__, self.items())
+ finally:
+ del _repr_running[call_key]
+
+ def __reduce__(self):
+ 'Return state information for pickling'
+ items = [[k, self[k]] for k in self]
+ inst_dict = vars(self).copy()
+ for k in vars(OrderedDict()):
+ inst_dict.pop(k, None)
+ if inst_dict:
+ return (self.__class__, (items,), inst_dict)
+ return self.__class__, (items,)
+
+ def copy(self):
+ 'od.copy() -> a shallow copy of od'
+ return self.__class__(self)
+
+ @classmethod
+ def fromkeys(cls, iterable, value=None):
+ '''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S
+ and values equal to v (which defaults to None).
+
+ '''
+ d = cls()
+ for key in iterable:
+ d[key] = value
+ return d
+
+ def __eq__(self, other):
+ '''od.__eq__(y) <==> od==y.
+ Comparison to another OD is order-sensitive
+ while comparison to a regular mapping is order-insensitive.
+ '''
+ if isinstance(other, OrderedDict):
+ return len(self) == len(other) and self.items() == other.items()
+ return dict.__eq__(self, other)
+
+ def __ne__(self, other):
+ return not self == other
+
+ # -- the following methods are only used in Python 2.7 --
+
+ def viewkeys(self):
+ "od.viewkeys() -> a set-like object providing a view on od's keys"
+ return KeysView(self)
+
+ def viewvalues(self):
+ "od.viewvalues() -> an object providing a view on od's values"
+ return ValuesView(self)
+
+ def viewitems(self):
+ "od.viewitems() -> a set-like object providing a view on od's items"
+ return ItemsView(self)
diff --git a/src/leap/util/fileutil.py b/src/leap/util/fileutil.py
new file mode 100644
index 00000000..aef4cfe0
--- /dev/null
+++ b/src/leap/util/fileutil.py
@@ -0,0 +1,115 @@
+import errno
+from itertools import chain
+import logging
+import os
+import platform
+import stat
+
+
+logger = logging.getLogger()
+
+
+def is_user_executable(fpath):
+ st = os.stat(fpath)
+ return bool(st.st_mode & stat.S_IXUSR)
+
+
+def extend_path():
+ ourplatform = platform.system()
+ if ourplatform == "Linux":
+ return "/usr/local/sbin:/usr/sbin"
+ # XXX add mac / win extended search paths?
+
+
+def which(program, path=None):
+ """
+ an implementation of which
+ that extends the path with
+ other locations, like sbin
+ (f.i., openvpn binary is likely to be there)
+ @param program: a string representing the binary we're looking for.
+ """
+ def is_exe(fpath):
+ """
+ check that path exists,
+ it's a file,
+ and is executable by the owner
+ """
+ # we would check for access,
+ # but it's likely that we're
+ # using uid 0 + polkitd
+
+ return os.path.isfile(fpath)\
+ and is_user_executable(fpath)
+
+ def ext_candidates(fpath):
+ yield fpath
+ for ext in os.environ.get("PATHEXT", "").split(os.pathsep):
+ yield fpath + ext
+
+ def iter_path(pathset):
+ """
+ returns iterator with
+ full path for a given path list
+ and the current target bin.
+ """
+ for path in pathset.split(os.pathsep):
+ exe_file = os.path.join(path, program)
+ #print 'file=%s' % exe_file
+ for candidate in ext_candidates(exe_file):
+ if is_exe(candidate):
+ yield candidate
+
+ fpath, fname = os.path.split(program)
+ if fpath:
+ if is_exe(program):
+ return program
+ else:
+ # extended iterator
+ # with extra path
+ if path is None:
+ path = os.environ['PATH']
+ extended_path = chain(
+ iter_path(path),
+ iter_path(extend_path()))
+ for candidate in extended_path:
+ if candidate is not None:
+ return candidate
+
+ # sorry bro.
+ return None
+
+
+def mkdir_p(path):
+ """
+ implements mkdir -p functionality
+ """
+ try:
+ os.makedirs(path)
+ except OSError as exc:
+ if exc.errno == errno.EEXIST:
+ pass
+ else:
+ raise
+
+
+def check_and_fix_urw_only(_file):
+ """
+ test for 600 mode and try
+ to set it if anything different found
+ """
+ mode = stat.S_IMODE(
+ os.stat(_file).st_mode)
+
+ if mode != int('600', 8):
+ try:
+ logger.warning(
+ 'bad permission on %s '
+ 'attempting to set 600',
+ _file)
+ os.chmod(_file, stat.S_IRUSR | stat.S_IWUSR)
+ except OSError:
+ logger.error(
+ 'error while trying to chmod 600 %s',
+ _file)
+ raise
diff --git a/src/leap/util/leap_argparse.py b/src/leap/util/leap_argparse.py
new file mode 100644
index 00000000..2f996a31
--- /dev/null
+++ b/src/leap/util/leap_argparse.py
@@ -0,0 +1,41 @@
+import argparse
+
+
+def build_parser():
+ """
+ all the options for the leap arg parser
+ Some of these could be switched on only if debug flag is present!
+ """
+ epilog = "Copyright 2012 The Leap Project"
+ parser = argparse.ArgumentParser(description="""
+Launches main LEAP Client""", epilog=epilog)
+ parser.add_argument('-d', '--debug', action="store_true",
+ help='launches in debug mode')
+ parser.add_argument('-c', '--config', metavar="CONFIG FILE", nargs='?',
+ action="store", dest="config_file",
+ type=argparse.FileType('r'),
+ help='optional config file')
+ parser.add_argument('--logfile', metavar="LOG FILE", nargs='?',
+ action="store", dest="log_file",
+ #type=argparse.FileType('w'),
+ help='optional log file')
+ parser.add_argument('--openvpn-verbosity', nargs='?',
+ type=int,
+ action="store", dest="openvpn_verb",
+ help='verbosity level for openvpn logs [1-6]')
+ parser.add_argument('-l', '--no-provider-checks',
+ action="store_true", default=False,
+ help="skips download of provider config files. gets "
+ "config from local files only. Will fail if cannot "
+ "find any")
+ parser.add_argument('-k', '--no-ca-verify',
+ action="store_true", default=False,
+ help="(insecure). Skips verification of the server "
+ "certificate used in TLS handshake.")
+ return parser
+
+
+def init_leapc_args():
+ parser = build_parser()
+ opts = parser.parse_args()
+ return parser, opts
diff --git a/src/leap/util/tests/__init__.py b/src/leap/util/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/util/tests/__init__.py
diff --git a/src/leap/util/tests/test_fileutil.py b/src/leap/util/tests/test_fileutil.py
new file mode 100644
index 00000000..f5131b3d
--- /dev/null
+++ b/src/leap/util/tests/test_fileutil.py
@@ -0,0 +1,100 @@
+import os
+import platform
+import shutil
+import stat
+import tempfile
+import unittest
+
+from leap.util import fileutil
+
+
+class FileUtilTest(unittest.TestCase):
+ """
+ test our file utils
+ """
+
+ def setUp(self):
+ self.system = platform.system()
+ self.create_temp_dir()
+
+ def tearDown(self):
+ self.remove_temp_dir()
+
+ #
+ # helpers
+ #
+
+ def create_temp_dir(self):
+ self.tmpdir = tempfile.mkdtemp()
+
+ def remove_temp_dir(self):
+ shutil.rmtree(self.tmpdir)
+
+ def get_file_path(self, filename):
+ return os.path.join(
+ self.tmpdir,
+ filename)
+
+ def touch_exec_file(self):
+ fp = self.get_file_path('testexec')
+ open(fp, 'w').close()
+ os.chmod(
+ fp,
+ stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
+ return fp
+
+ def get_mode(self, fp):
+ return stat.S_IMODE(os.stat(fp).st_mode)
+
+ #
+ # tests
+ #
+
+ def test_is_user_executable(self):
+ """
+ touch_exec_file creates in mode 700?
+ """
+ # XXX could check access X_OK
+
+ fp = self.touch_exec_file()
+ mode = self.get_mode(fp)
+ self.assertEqual(mode, int('700', 8))
+
+ def test_which(self):
+ """
+ which implementation ok?
+ not a very reliable test,
+ but I cannot think of anything smarter now
+ I guess it's highly improbable that copy
+ """
+ # XXX yep, we can change the syspath
+ # for the test... !
+
+ if self.system == "Linux":
+ self.assertEqual(
+ fileutil.which('cp'),
+ '/bin/cp')
+
+ def test_mkdir_p(self):
+ """
+ our own mkdir -p implementation ok?
+ """
+ testdir = self.get_file_path(
+ os.path.join('test', 'foo', 'bar'))
+ self.assertEqual(os.path.isdir(testdir), False)
+ fileutil.mkdir_p(testdir)
+ self.assertEqual(os.path.isdir(testdir), True)
+
+ def test_check_and_fix_urw_only(self):
+ """
+ ensure check_and_fix_urx_only ok?
+ """
+ fp = self.touch_exec_file()
+ mode = self.get_mode(fp)
+ self.assertEqual(mode, int('700', 8))
+ fileutil.check_and_fix_urw_only(fp)
+ mode = self.get_mode(fp)
+ self.assertEqual(mode, int('600', 8))
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/util/tests/test_leap_argparse.py b/src/leap/util/tests/test_leap_argparse.py
new file mode 100644
index 00000000..082919b7
--- /dev/null
+++ b/src/leap/util/tests/test_leap_argparse.py
@@ -0,0 +1,35 @@
+from argparse import Namespace
+import unittest
+
+from leap.util import leap_argparse
+
+
+class LeapArgParseTest(unittest.TestCase):
+ """
+ Test argparse options for eip client
+ """
+
+ def setUp(self):
+ """
+ get the parser
+ """
+ self.parser = leap_argparse.build_parser()
+
+ def test_debug_mode(self):
+ """
+ test debug mode option
+ """
+ opts = self.parser.parse_args(
+ ['--debug'])
+ self.assertEqual(
+ opts,
+ Namespace(
+ config_file=None,
+ debug=True,
+ log_file=None,
+ no_provider_checks=False,
+ no_ca_verify=False,
+ openvpn_verb=None))
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/util/web.py b/src/leap/util/web.py
new file mode 100644
index 00000000..b2aef058
--- /dev/null
+++ b/src/leap/util/web.py
@@ -0,0 +1,39 @@
+"""
+web related utilities
+"""
+
+
+class UsageError(Exception):
+ """ """
+
+
+def get_https_domain_and_port(full_domain):
+ """
+ returns a tuple with domain and port
+ from a full_domain string that can
+ contain a colon
+ """
+ if full_domain is None:
+ return None, None
+
+ https_sch = "https://"
+ http_sch = "http://"
+
+ if full_domain.startswith(https_sch):
+ full_domain = full_domain.lstrip(https_sch)
+ elif full_domain.startswith(http_sch):
+ raise UsageError(
+ "cannot be called with a domain "
+ "that begins with 'http://'")
+
+ domain_split = full_domain.split(':')
+ _len = len(domain_split)
+ if _len == 1:
+ domain, port = full_domain, 443
+ elif _len == 2:
+ domain, port = domain_split
+ else:
+ raise UsageError(
+ "must be called with one only parameter"
+ "in the form domain[:port]")
+ return domain, port
diff --git a/src/leap/utils/leap_argparse.py b/src/leap/utils/leap_argparse.py
deleted file mode 100644
index 9c355134..00000000
--- a/src/leap/utils/leap_argparse.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import argparse
-
-
-def build_parser():
- epilog = "Copyright 2012 The Leap Project"
- parser = argparse.ArgumentParser(description="""
-Launches main LEAP Client""", epilog=epilog)
- parser.add_argument('--debug', action="store_true",
- help='launches in debug mode')
- parser.add_argument('--config', metavar="CONFIG FILE", nargs='?',
- action="store", dest="config_file",
- type=argparse.FileType('r'),
- help='optional config file')
- return parser
-
-
-def init_leapc_args():
- parser = build_parser()
- opts = parser.parse_args()
- return parser, opts