summaryrefslogtreecommitdiff
path: root/src/leap
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap')
-rw-r--r--src/leap/__init__.py3
-rw-r--r--src/leap/app.py42
-rw-r--r--src/leap/base/auth.py355
-rw-r--r--src/leap/base/checks.py179
-rw-r--r--src/leap/base/config.py136
-rw-r--r--src/leap/base/connection.py10
-rw-r--r--src/leap/base/constants.py36
-rw-r--r--src/leap/base/exceptions.py39
-rw-r--r--src/leap/base/network.py37
-rw-r--r--src/leap/base/pluggableconfig.py38
-rw-r--r--src/leap/base/providers.py14
-rw-r--r--src/leap/base/specs.py22
-rw-r--r--src/leap/base/tests/__init__.py0
-rw-r--r--src/leap/base/tests/test_auth.py58
-rw-r--r--src/leap/base/tests/test_checks.py116
-rw-r--r--src/leap/base/tests/test_providers.py33
-rw-r--r--src/leap/baseapp/dialogs.py9
-rw-r--r--src/leap/baseapp/eip.py44
-rw-r--r--src/leap/baseapp/leap_app.py25
-rw-r--r--src/leap/baseapp/log.py20
-rw-r--r--src/leap/baseapp/mainwindow.py144
-rw-r--r--src/leap/baseapp/network.py45
-rw-r--r--src/leap/baseapp/systray.py109
-rw-r--r--src/leap/crypto/certs.py112
-rw-r--r--src/leap/crypto/certs_gnutls.py112
-rw-r--r--src/leap/crypto/leapkeyring.py2
-rw-r--r--src/leap/crypto/tests/__init__.py0
-rw-r--r--src/leap/crypto/tests/test_certs.py22
-rw-r--r--src/leap/eip/checks.py309
-rw-r--r--src/leap/eip/config.py190
-rw-r--r--src/leap/eip/eipconnection.py271
-rw-r--r--src/leap/eip/exceptions.py64
-rw-r--r--src/leap/eip/openvpnconnection.py521
-rw-r--r--src/leap/eip/specs.py69
-rw-r--r--src/leap/eip/tests/data.py40
-rw-r--r--src/leap/eip/tests/test_checks.py43
-rw-r--r--src/leap/eip/tests/test_config.py188
-rw-r--r--src/leap/eip/tests/test_eipconnection.py39
-rw-r--r--src/leap/eip/tests/test_openvpnconnection.py34
-rw-r--r--src/leap/gui/__init__.py11
-rw-r--r--src/leap/gui/constants.py13
-rw-r--r--src/leap/gui/firstrun/__init__.py28
-rw-r--r--src/leap/gui/firstrun/connect.py214
-rw-r--r--src/leap/gui/firstrun/constants.py0
-rw-r--r--src/leap/gui/firstrun/intro.py68
-rw-r--r--src/leap/gui/firstrun/last.py109
-rw-r--r--src/leap/gui/firstrun/login.py332
-rw-r--r--src/leap/gui/firstrun/mixins.py18
-rw-r--r--src/leap/gui/firstrun/providerinfo.py106
-rw-r--r--src/leap/gui/firstrun/providerselect.py471
-rw-r--r--src/leap/gui/firstrun/providersetup.py157
-rw-r--r--src/leap/gui/firstrun/register.py387
-rwxr-xr-xsrc/leap/gui/firstrun/tests/integration/fake_provider.py302
-rwxr-xr-xsrc/leap/gui/firstrun/wizard.py309
-rwxr-xr-xsrc/leap/gui/firstrunwizard.py507
-rw-r--r--src/leap/gui/locale_rc.py132
-rw-r--r--src/leap/gui/mainwindow_rc.py284
-rw-r--r--src/leap/gui/progress.py488
-rw-r--r--src/leap/gui/styles.py16
-rw-r--r--src/leap/gui/tests/__init__.py0
-rw-r--r--src/leap/gui/tests/integration/fake_user_signup.py6
-rw-r--r--src/leap/gui/tests/test_firstrun_login.py212
-rw-r--r--src/leap/gui/tests/test_firstrun_providerselect.py203
-rw-r--r--src/leap/gui/tests/test_firstrun_register.py244
-rw-r--r--src/leap/gui/tests/test_firstrun_wizard.py137
-rw-r--r--src/leap/gui/tests/test_mainwindow_rc.py (renamed from src/leap/gui/test_mainwindow_rc.py)12
-rw-r--r--src/leap/gui/tests/test_progress.py449
-rw-r--r--src/leap/gui/tests/test_threads.py27
-rw-r--r--src/leap/gui/threads.py21
-rw-r--r--src/leap/gui/utils.py34
-rw-r--r--src/leap/soledad/README37
-rw-r--r--src/leap/soledad/__init__.py212
-rw-r--r--src/leap/soledad/backends/__init__.py5
-rw-r--r--src/leap/soledad/backends/couch.py217
-rw-r--r--src/leap/soledad/backends/leap_backend.py198
-rw-r--r--src/leap/soledad/backends/objectstore.py109
-rw-r--r--src/leap/soledad/backends/openstack.py98
-rw-r--r--src/leap/soledad/backends/sqlcipher.py159
-rw-r--r--src/leap/soledad/tests/__init__.py0
-rw-r--r--src/leap/soledad/tests/test_couch.py201
-rw-r--r--src/leap/soledad/tests/test_encrypted.py205
-rw-r--r--src/leap/soledad/tests/test_leap_backend.py379
-rw-r--r--src/leap/soledad/tests/test_logs.py96
-rw-r--r--src/leap/soledad/tests/test_sqlcipher.py374
-rw-r--r--src/leap/soledad/tests/u1db_tests/README34
-rw-r--r--src/leap/soledad/tests/u1db_tests/__init__.py421
-rw-r--r--src/leap/soledad/tests/u1db_tests/test_backends.py1907
-rw-r--r--src/leap/soledad/tests/u1db_tests/test_document.py150
-rw-r--r--src/leap/soledad/tests/u1db_tests/test_http_app.py1135
-rw-r--r--src/leap/soledad/tests/u1db_tests/test_http_client.py363
-rw-r--r--src/leap/soledad/tests/u1db_tests/test_http_database.py260
-rw-r--r--src/leap/soledad/tests/u1db_tests/test_https.py117
-rw-r--r--src/leap/soledad/tests/u1db_tests/test_open.py69
-rw-r--r--src/leap/soledad/tests/u1db_tests/test_remote_sync_target.py317
-rw-r--r--src/leap/soledad/tests/u1db_tests/test_sqlite_backend.py494
-rw-r--r--src/leap/soledad/tests/u1db_tests/test_sync.py1242
-rw-r--r--src/leap/soledad/util.py55
-rw-r--r--src/leap/testing/pyqt.py52
-rw-r--r--src/leap/testing/qunittest.py302
-rw-r--r--src/leap/util/__init__.py9
-rw-r--r--src/leap/util/certs.py18
-rw-r--r--src/leap/util/dicts.py268
-rw-r--r--src/leap/util/fileutil.py5
-rw-r--r--src/leap/util/geo.py32
-rw-r--r--src/leap/util/leap_argparse.py2
-rw-r--r--src/leap/util/misc.py37
-rw-r--r--src/leap/util/tests/test_translations.py22
-rw-r--r--src/leap/util/translations.py82
-rw-r--r--src/leap/util/web.py40
109 files changed, 17156 insertions, 1394 deletions
diff --git a/src/leap/__init__.py b/src/leap/__init__.py
index 5e003931..2adbb34a 100644
--- a/src/leap/__init__.py
+++ b/src/leap/__init__.py
@@ -6,8 +6,9 @@ website: U{https://leap.se/}
from leap import eip
from leap import baseapp
from leap import util
+from leap import soledad
-__all__ = [eip, baseapp, util]
+__all__ = [eip, baseapp, util, soledad]
__version__ = "unknown"
try:
diff --git a/src/leap/app.py b/src/leap/app.py
index a1251ca8..eb38751c 100644
--- a/src/leap/app.py
+++ b/src/leap/app.py
@@ -1,13 +1,25 @@
# vim: tabstop=8 expandtab shiftwidth=4 softtabstop=4
+from functools import partial
import logging
+import signal
+
# This is only needed for Python v2 but is harmless for Python v3.
import sip
sip.setapi('QVariant', 2)
sip.setapi('QString', 2)
from PyQt4.QtGui import (QApplication, QSystemTrayIcon, QMessageBox)
+from PyQt4 import QtCore
from leap import __version__ as VERSION
from leap.baseapp.mainwindow import LeapWindow
+from leap.gui import locale_rc
+
+
+def sigint_handler(*args, **kwargs):
+ logger = kwargs.get('logger', None)
+ logger.debug('SIGINT catched. shutting down...')
+ mainwindow = args[0]
+ mainwindow.shutdownSignal.emit()
def main():
@@ -36,7 +48,6 @@ def main():
console.setFormatter(formatter)
logger.addHandler(console)
- #logger.debug(opts)
logger.info('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
logger.info('LEAP client version %s', VERSION)
logger.info('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
@@ -51,6 +62,16 @@ def main():
logger.info('Starting app')
app = QApplication(sys.argv)
+ # To test:
+ # $ LANG=es ./app.py
+ locale = QtCore.QLocale.system().name()
+ qtTranslator = QtCore.QTranslator()
+ if qtTranslator.load("qt_%s" % locale, ":/translations"):
+ app.installTranslator(qtTranslator)
+ appTranslator = QtCore.QTranslator()
+ if appTranslator.load("leap_client_%s" % locale, ":/translations"):
+ app.installTranslator(appTranslator)
+
# needed for initializing qsettings
# it will write .config/leap/leap.conf
# top level app settings
@@ -59,6 +80,10 @@ def main():
app.setApplicationName("leap")
app.setOrganizationDomain("leap.se")
+ # XXX we could check here
+ # if leap-client is already running, and abort
+ # gracefully in that case.
+
if not QSystemTrayIcon.isSystemTrayAvailable():
QMessageBox.critical(None, "Systray",
"I couldn't detect"
@@ -68,12 +93,27 @@ def main():
QApplication.setQuitOnLastWindowClosed(False)
window = LeapWindow(opts)
+
+ # this dummy timer ensures that
+ # control is given to the outside loop, so we
+ # can hook our sigint handler.
+ timer = QtCore.QTimer()
+ timer.start(500)
+ timer.timeout.connect(lambda: None)
+
+ sigint_window = partial(sigint_handler, window, logger=logger)
+ signal.signal(signal.SIGINT, sigint_window)
+
if debug:
# we only show the main window
# if debug mode active.
# if not, it will be set visible
# from the systray menu.
window.show()
+ if sys.platform == "darwin":
+ window.raise_()
+
+ # run main loop
sys.exit(app.exec_())
if __name__ == "__main__":
diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py
new file mode 100644
index 00000000..c2d3f424
--- /dev/null
+++ b/src/leap/base/auth.py
@@ -0,0 +1,355 @@
+import binascii
+import json
+import logging
+#import urlparse
+
+import requests
+import srp
+
+from PyQt4 import QtCore
+
+from leap.base import constants as baseconstants
+from leap.crypto import leapkeyring
+from leap.util.misc import null_check
+from leap.util.web import get_https_domain_and_port
+
+logger = logging.getLogger(__name__)
+
+SIGNUP_TIMEOUT = getattr(baseconstants, 'SIGNUP_TIMEOUT', 5)
+
+"""
+Registration and authentication classes for the
+SRP auth mechanism used in the leap platform.
+
+We're using the srp library which uses a c-based implementation
+of the protocol if the c extension is available, and a python-based
+one if not.
+"""
+
+
+class SRPAuthenticationError(Exception):
+ """
+ exception raised
+ for authentication errors
+ """
+
+
+safe_unhexlify = lambda x: binascii.unhexlify(x) \
+ if (len(x) % 2 == 0) else binascii.unhexlify('0' + x)
+
+
+class LeapSRPRegister(object):
+
+ def __init__(self,
+ schema="https",
+ provider=None,
+ verify=True,
+ register_path="1/users",
+ method="POST",
+ fetcher=requests,
+ srp=srp,
+ hashfun=srp.SHA256,
+ ng_constant=srp.NG_1024):
+
+ null_check(provider, "provider")
+
+ self.schema = schema
+
+ domain, port = get_https_domain_and_port(provider)
+ self.provider = domain
+ self.port = port
+
+ self.verify = verify
+ self.register_path = register_path
+ self.method = method
+ self.fetcher = fetcher
+ self.srp = srp
+ self.HASHFUN = hashfun
+ self.NG = ng_constant
+
+ self.init_session()
+
+ def init_session(self):
+ self.session = self.fetcher.session()
+
+ def get_registration_uri(self):
+ # XXX assert is https!
+ # use urlparse
+ if self.port:
+ uri = "%s://%s:%s/%s" % (
+ self.schema,
+ self.provider,
+ self.port,
+ self.register_path)
+ else:
+ uri = "%s://%s/%s" % (
+ self.schema,
+ self.provider,
+ self.register_path)
+
+ return uri
+
+ def register_user(self, username, password, keep=False):
+ """
+ @rtype: tuple
+ @rparam: (ok, request)
+ """
+ salt, vkey = self.srp.create_salted_verification_key(
+ username,
+ password,
+ self.HASHFUN,
+ self.NG)
+
+ user_data = {
+ 'user[login]': username,
+ 'user[password_verifier]': binascii.hexlify(vkey),
+ 'user[password_salt]': binascii.hexlify(salt)}
+
+ uri = self.get_registration_uri()
+ logger.debug('post to uri: %s' % uri)
+
+ # XXX get self.method
+ req = self.session.post(
+ uri, data=user_data,
+ timeout=SIGNUP_TIMEOUT,
+ verify=self.verify)
+ # we catch it in the form
+ #req.raise_for_status()
+ return (req.ok, req)
+
+
+class SRPAuth(requests.auth.AuthBase):
+
+ def __init__(self, username, password, server=None, verify=None):
+ # sanity check
+ null_check(server, 'server')
+ self.username = username
+ self.password = password
+ self.server = server
+ self.verify = verify
+
+ logger.debug('SRPAuth. verify=%s' % verify)
+ logger.debug('server: %s. username=%s' % (server, username))
+
+ self.init_data = None
+ self.session = requests.session()
+
+ self.init_srp()
+
+ def init_srp(self):
+ usr = srp.User(
+ self.username,
+ self.password,
+ srp.SHA256,
+ srp.NG_1024)
+ uname, A = usr.start_authentication()
+
+ self.srp_usr = usr
+ self.A = A
+
+ def get_auth_data(self):
+ return {
+ 'login': self.username,
+ 'A': binascii.hexlify(self.A)
+ }
+
+ def get_init_data(self):
+ try:
+ init_session = self.session.post(
+ self.server + '/1/sessions/',
+ data=self.get_auth_data(),
+ verify=self.verify)
+ except requests.exceptions.ConnectionError:
+ raise SRPAuthenticationError(
+ "No connection made (salt).")
+ except:
+ raise SRPAuthenticationError(
+ "Unknown error (salt).")
+ if init_session.status_code not in (200, ):
+ raise SRPAuthenticationError(
+ "No valid response (salt).")
+
+ self.init_data = init_session.json
+ return self.init_data
+
+ def get_server_proof_data(self):
+ try:
+ auth_result = self.session.put(
+ #self.server + '/1/sessions.json/' + self.username,
+ self.server + '/1/sessions/' + self.username,
+ data={'client_auth': binascii.hexlify(self.M)},
+ verify=self.verify)
+ except requests.exceptions.ConnectionError:
+ raise SRPAuthenticationError(
+ "No connection made (HAMK).")
+
+ if auth_result.status_code not in (200, ):
+ raise SRPAuthenticationError(
+ "No valid response (HAMK).")
+
+ self.auth_data = auth_result.json
+ return self.auth_data
+
+ def authenticate(self):
+ logger.debug('start authentication...')
+
+ init_data = self.get_init_data()
+ salt = init_data.get('salt', None)
+ B = init_data.get('B', None)
+
+ # XXX refactor this function
+ # move checks and un-hex
+ # to routines
+
+ if not salt or not B:
+ raise SRPAuthenticationError(
+ "Server did not send initial data.")
+
+ try:
+ unhex_salt = safe_unhexlify(salt)
+ except TypeError:
+ raise SRPAuthenticationError(
+ "Bad data from server (salt)")
+ try:
+ unhex_B = safe_unhexlify(B)
+ except TypeError:
+ raise SRPAuthenticationError(
+ "Bad data from server (B)")
+
+ self.M = self.srp_usr.process_challenge(
+ unhex_salt,
+ unhex_B
+ )
+
+ proof_data = self.get_server_proof_data()
+
+ HAMK = proof_data.get("M2", None)
+ if not HAMK:
+ errors = proof_data.get('errors', None)
+ if errors:
+ logger.error(errors)
+ raise SRPAuthenticationError("Server did not send HAMK.")
+
+ try:
+ unhex_HAMK = safe_unhexlify(HAMK)
+ except TypeError:
+ raise SRPAuthenticationError(
+ "Bad data from server (HAMK)")
+
+ self.srp_usr.verify_session(
+ unhex_HAMK)
+
+ try:
+ assert self.srp_usr.authenticated()
+ logger.debug('user is authenticated!')
+ except (AssertionError):
+ raise SRPAuthenticationError(
+ "Auth verification failed.")
+
+ def __call__(self, req):
+ self.authenticate()
+ req.cookies = self.session.cookies
+ return req
+
+
+def srpauth_protected(user=None, passwd=None, server=None, verify=True):
+ """
+ decorator factory that accepts
+ user and password keyword arguments
+ and add those to the decorated request
+ """
+ def srpauth(fn):
+ def wrapper(*args, **kwargs):
+ if user and passwd:
+ auth = SRPAuth(user, passwd, server, verify)
+ kwargs['auth'] = auth
+ kwargs['verify'] = verify
+ if not args:
+ logger.warning('attempting to get from empty uri!')
+ return fn(*args, **kwargs)
+ return wrapper
+ return srpauth
+
+
+def get_leap_credentials():
+ settings = QtCore.QSettings()
+ full_username = settings.value('username')
+ username, domain = full_username.split('@')
+ seed = settings.value('%s_seed' % domain, None)
+ password = leapkeyring.leap_get_password(full_username, seed=seed)
+ return (username, password)
+
+
+# XXX TODO
+# Pass verify as single argument,
+# in srpauth_protected style
+
+def magick_srpauth(fn):
+ """
+ decorator that gets user and password
+ from the config file and adds those to
+ the decorated request
+ """
+ logger.debug('magick srp auth decorator called')
+
+ def wrapper(*args, **kwargs):
+ #uri = args[0]
+ # XXX Ugh!
+ # Problem with this approach.
+ # This won't work when we're using
+ # api.foo.bar
+ # Unless we keep a table with the
+ # equivalencies...
+ user, passwd = get_leap_credentials()
+
+ # XXX pass verify and server too
+ # (pop)
+ auth = SRPAuth(user, passwd)
+ kwargs['auth'] = auth
+ return fn(*args, **kwargs)
+ return wrapper
+
+
+if __name__ == "__main__":
+ """
+ To test against test_provider (twisted version)
+ Register an user: (will be valid during the session)
+ >>> python auth.py add test password
+
+ Test login with that user:
+ >>> python auth.py login test password
+ """
+
+ import sys
+
+ if len(sys.argv) not in (4, 5):
+ print 'Usage: auth <add|login> <user> <pass> [server]'
+ sys.exit(0)
+
+ action = sys.argv[1]
+ user = sys.argv[2]
+ passwd = sys.argv[3]
+
+ if len(sys.argv) == 5:
+ SERVER = sys.argv[4]
+ else:
+ SERVER = "https://localhost:8443"
+
+ if action == "login":
+
+ @srpauth_protected(
+ user=user, passwd=passwd, server=SERVER, verify=False)
+ def test_srp_protected_get(*args, **kwargs):
+ req = requests.get(*args, **kwargs)
+ req.raise_for_status
+ return req
+
+ #req = test_srp_protected_get('https://localhost:8443/1/cert')
+ req = test_srp_protected_get('%s/1/cert' % SERVER)
+ #print 'cert :', req.content[:200] + "..."
+ print req.content
+ sys.exit(0)
+
+ if action == "add":
+ auth = LeapSRPRegister(provider=SERVER, verify=False)
+ auth.register_user(user, passwd)
diff --git a/src/leap/base/checks.py b/src/leap/base/checks.py
index 7285e74f..0bf44f59 100644
--- a/src/leap/base/checks.py
+++ b/src/leap/base/checks.py
@@ -1,15 +1,22 @@
# -*- coding: utf-8 -*-
import logging
import platform
+import re
+import socket
import netifaces
-import ping
-import requests
+import sh
from leap.base import constants
from leap.base import exceptions
logger = logging.getLogger(name=__name__)
+_platform = platform.system()
+
+#EVENTS OF NOTE
+EVENT_CONNECT_REFUSED = "[ECONNREFUSED]: Connection refused (code=111)"
+
+ICMP_TARGET = "8.8.8.8"
class LeapNetworkChecker(object):
@@ -23,7 +30,7 @@ class LeapNetworkChecker(object):
def run_all(self, checker=None):
if not checker:
checker = self
- self.error = None # ?
+ #self.error = None # ?
# for MVS
checker.check_tunnel_default_interface()
@@ -33,79 +40,132 @@ class LeapNetworkChecker(object):
if self.provider_gateway:
checker.ping_gateway(self.provider_gateway)
+ checker.parse_log_and_react([], ())
+
def check_internet_connection(self):
- try:
- # XXX remove this hardcoded random ip
- # ping leap.se or eip provider instead...?
- requests.get('http://216.172.161.165')
-
- except (requests.HTTPError, requests.RequestException) as e:
- raise exceptions.NoInternetConnection(e.message)
- except requests.ConnectionError as e:
- error = "Unidentified Connection Error"
- if e.message == "[Errno 113] No route to host":
+ if _platform == "Linux":
+ try:
+ output = sh.ping("-c", "5", "-w", "5", ICMP_TARGET)
+ # XXX should redirect this to netcheck logger.
+ # and don't clutter main log.
+ logger.debug('Network appears to be up.')
+ except sh.ErrorReturnCode_1 as e:
+ packet_loss = re.findall("\d+% packet loss", e.message)[0]
+ logger.debug("Unidentified Connection Error: " + packet_loss)
if not self.is_internet_up():
error = "No valid internet connection found."
else:
error = "Provider server appears to be down."
- logger.error(error)
- raise exceptions.NoInternetConnection(error)
- logger.debug('Network appears to be up.')
- def is_internet_up(self):
- iface, gateway = self.get_default_interface_gateway()
- self.ping_gateway(self.provider_gateway)
+ logger.error(error)
+ raise exceptions.NoInternetConnection(error)
- def check_tunnel_default_interface(self):
- """
- Raises an TunnelNotDefaultRouteError
- (including when no routes are present)
- """
- if not platform.system() == "Linux":
+ else:
raise NotImplementedError
+ def is_internet_up(self):
+ iface, gateway = self.get_default_interface_gateway()
+ try:
+ self.ping_gateway(self.provider_gateway)
+ except exceptions.NoConnectionToGateway:
+ return False
+ return True
+
+ def _get_route_table_linux(self):
+ # do not use context manager, tests pass a StringIO
f = open("/proc/net/route")
route_table = f.readlines()
f.close()
#toss out header
route_table.pop(0)
-
if not route_table:
- raise exceptions.TunnelNotDefaultRouteError()
+ raise exceptions.NoDefaultInterfaceFoundError
+ return route_table
+ def _get_def_iface_osx(self):
+ default_iface = None
+ #gateway = None
+ routes = list(sh.route('-n', 'get', ICMP_TARGET, _iter=True))
+ iface = filter(lambda l: "interface" in l, routes)
+ if not iface:
+ return None, None
+ def_ifacel = re.findall('\w+\d', iface[0])
+ default_iface = def_ifacel[0] if def_ifacel else None
+ if not default_iface:
+ return None, None
+ _gw = filter(lambda l: "gateway" in l, routes)
+ gw = re.findall('\d+\.\d+\.\d+\.\d+', _gw[0])[0]
+ return default_iface, gw
+
+ def _get_tunnel_iface_linux(self):
+ # XXX review.
+ # valid also when local router has a default entry?
+ route_table = self._get_route_table_linux()
line = route_table.pop(0)
iface, destination = line.split('\t')[0:2]
if not destination == '00000000' or not iface == 'tun0':
raise exceptions.TunnelNotDefaultRouteError()
+ return True
- def get_default_interface_gateway(self):
- """only impletemented for linux so far."""
- if not platform.system() == "Linux":
+ def check_tunnel_default_interface(self):
+ """
+ Raises an TunnelNotDefaultRouteError
+ if tun0 is not the chosen default route
+ (including when no routes are present)
+ """
+ #logger.debug('checking tunnel default interface...')
+
+ if _platform == "Linux":
+ valid = self._get_tunnel_iface_linux()
+ return valid
+ elif _platform == "Darwin":
+ default_iface, gw = self._get_def_iface_osx()
+ #logger.debug('iface: %s', default_iface)
+ if default_iface != "tun0":
+ logger.debug('tunnel not default route! gw: %s', default_iface)
+ # XXX should catch this and act accordingly...
+ # but rather, this test should only be launched
+ # when we have successfully completed a connection
+ # ... TRIGGER: Connection stablished (or whatever it is)
+ # in the logs
+ raise exceptions.TunnelNotDefaultRouteError
+ else:
+ #logger.debug('PLATFORM !!! %s', _platform)
raise NotImplementedError
- # XXX use psutil
- f = open("/proc/net/route")
- route_table = f.readlines()
- f.close()
- #toss out header
- route_table.pop(0)
-
+ def _get_def_iface_linux(self):
default_iface = None
gateway = None
+
+ route_table = self._get_route_table_linux()
while route_table:
line = route_table.pop(0)
iface, destination, gateway = line.split('\t')[0:3]
if destination == '00000000':
default_iface = iface
break
+ return default_iface, gateway
+
+ def get_default_interface_gateway(self):
+ """
+ gets the interface we are going thru.
+ (this should be merged with check tunnel default interface,
+ imo...)
+ """
+ if _platform == "Linux":
+ default_iface, gw = self._get_def_iface_linux()
+ elif _platform == "Darwin":
+ default_iface, gw = self._get_def_iface_osx()
+ else:
+ raise NotImplementedError
if not default_iface:
raise exceptions.NoDefaultInterfaceFoundError
if default_iface not in netifaces.interfaces():
raise exceptions.InterfaceNotFoundError
-
- return default_iface, gateway
+ logger.debug('-- default iface %s', default_iface)
+ return default_iface, gw
def ping_gateway(self, gateway):
# TODO: Discuss how much packet loss (%) is acceptable.
@@ -114,15 +174,40 @@ class LeapNetworkChecker(object):
# -- is it a valid ip? (there's something in util)
# -- is it a domain?
# -- can we resolve? -- raise NoDNSError if not.
- packet_loss = ping.quiet_ping(gateway)[0]
+
+ # XXX -- sh.ping implemtation needs review!
+ try:
+ output = sh.ping("-c", "10", gateway).stdout
+ except sh.ErrorReturnCode_1 as e:
+ output = e.message
+ finally:
+ packet_loss = int(re.findall("(\d+)% packet loss", output)[0])
+
+ logger.debug('packet loss %s%%' % packet_loss)
if packet_loss > constants.MAX_ICMP_PACKET_LOSS:
raise exceptions.NoConnectionToGateway
- # XXX check for name resolution servers
- # dunno what's the best way to do this...
- # check for etc/resolv entries or similar?
- # just try to resolve?
- # is there something in psutil?
+ def check_name_resolution(self, domain_name):
+ try:
+ socket.gethostbyname(domain_name)
+ return True
+ except socket.gaierror:
+ raise exceptions.CannotResolveDomainError
- # def check_name_resolution(self):
- # pass
+ def parse_log_and_react(self, log, error_matrix=None):
+ """
+ compares the recent openvpn status log to
+ strings passed in and executes the callbacks passed in.
+ @param log: openvpn log
+ @type log: list of strings
+ @param error_matrix: tuples of strings and tuples of callbacks
+ @type error_matrix: tuples strings and call backs
+ """
+ for line in log:
+ # we could compile a regex here to save some cycles up -- kali
+ for each in error_matrix:
+ error, callbacks = each
+ if error in line:
+ for cb in callbacks:
+ if callable(cb):
+ cb()
diff --git a/src/leap/base/config.py b/src/leap/base/config.py
index cf01d1aa..e2f0beba 100644
--- a/src/leap/base/config.py
+++ b/src/leap/base/config.py
@@ -4,12 +4,15 @@ Configuration Base Class
import grp
import json
import logging
+import re
import socket
-import tempfile
+import time
import os
logger = logging.getLogger(name=__name__)
+from dateutil import parser as dateparser
+from dirspec import basedir
import requests
from leap.base import exceptions
@@ -118,23 +121,50 @@ class JSONLeapConfig(BaseLeapConfig):
" derived class")
assert issubclass(self.spec, PluggableConfig)
+ self.domain = kwargs.pop('domain', None)
self._config = self.spec(format="json")
self._config.load()
self.fetcher = kwargs.pop('fetcher', requests)
# mandatory baseconfig interface
- def save(self, to=None):
- if to is None:
- to = self.filename
- folder, filename = os.path.split(to)
- if folder and not os.path.isdir(folder):
- mkdir_p(folder)
- self._config.serialize(to)
+ def save(self, to=None, force=False):
+ """
+ force param will skip the dirty check.
+ :type force: bool
+ """
+ # XXX this force=True does not feel to right
+ # but still have to look for a better way
+ # of dealing with dirtiness and the
+ # trick of loading remote config only
+ # when newer.
+
+ if force:
+ do_save = True
+ else:
+ do_save = self._config.is_dirty()
+
+ if do_save:
+ if to is None:
+ to = self.filename
+ folder, filename = os.path.split(to)
+ if folder and not os.path.isdir(folder):
+ mkdir_p(folder)
+ self._config.serialize(to)
+ return True
+
+ else:
+ return False
+
+ def load(self, fromfile=None, from_uri=None, fetcher=None,
+ force_download=False, verify=True):
- def load(self, fromfile=None, from_uri=None, fetcher=None, verify=False):
if from_uri is not None:
- fetched = self.fetch(from_uri, fetcher=fetcher, verify=verify)
+ fetched = self.fetch(
+ from_uri,
+ fetcher=fetcher,
+ verify=verify,
+ force_dl=force_download)
if fetched:
return
if fromfile is None:
@@ -145,33 +175,68 @@ class JSONLeapConfig(BaseLeapConfig):
logger.error('tried to load config from non-existent path')
logger.error('Not Found: %s', fromfile)
- def fetch(self, uri, fetcher=None, verify=True):
+ def fetch(self, uri, fetcher=None, verify=True, force_dl=False):
if not fetcher:
fetcher = self.fetcher
- logger.debug('verify: %s', verify)
- logger.debug('uri: %s', uri)
- request = fetcher.get(uri, verify=verify)
- # XXX should send a if-modified-since header
- # XXX get 404, ...
- # and raise a UnableToFetch...
+ logger.debug('uri: %s (verify: %s)' % (uri, verify))
+
+ rargs = (uri, )
+ rkwargs = {'verify': verify}
+ headers = {}
+
+ curmtime = self.get_mtime() if not force_dl else None
+ if curmtime:
+ logger.debug('requesting with if-modified-since %s' % curmtime)
+ headers['if-modified-since'] = curmtime
+ rkwargs['headers'] = headers
+
+ #request = fetcher.get(uri, verify=verify)
+ request = fetcher.get(*rargs, **rkwargs)
request.raise_for_status()
- fd, fname = tempfile.mkstemp(suffix=".json")
- if request.json:
- self._config.load(json.dumps(request.json))
+ if request.status_code == 304:
+ logger.debug('...304 Not Changed')
+ # On this point, we have to assume that
+ # we HAD the filename. If that filename is corruct,
+ # we should enforce a force_download in the load
+ # method above.
+ self._config.load(fromfile=self.filename)
+ return True
+ if request.json:
+ mtime = None
+ last_modified = request.headers.get('last-modified', None)
+ if last_modified:
+ _mtime = dateparser.parse(last_modified)
+ mtime = int(_mtime.strftime("%s"))
+ if callable(request.json):
+ _json = request.json()
+ else:
+ # back-compat
+ _json = request.json
+ self._config.load(json.dumps(_json), mtime=mtime)
+ self._config.set_dirty()
else:
# not request.json
# might be server did not announce content properly,
# let's try deserializing all the same.
try:
self._config.load(request.content)
+ self._config.set_dirty()
except ValueError:
raise eipexceptions.LeapBadConfigFetchedError
return True
+ def get_mtime(self):
+ try:
+ _mtime = os.stat(self.filename)[8]
+ mtime = time.strftime("%c GMT", time.gmtime(_mtime))
+ return mtime
+ except OSError:
+ return None
+
def get_config(self):
return self._config.config
@@ -216,15 +281,13 @@ def get_config_dir():
@rparam: config path
@rtype: string
"""
- # TODO
- # check for $XDG_CONFIG_HOME var?
- # get a more sensible path for win/mac
- # kclair: opinion? ^^
-
- return os.path.expanduser(
- os.path.join('~',
- '.config',
- 'leap'))
+ home = os.path.expanduser("~")
+ if re.findall("leap_tests-[a-zA-Z0-9]{6}", home):
+ # we're inside a test! :)
+ return os.path.join(home, ".config/leap")
+ else:
+ return os.path.join(basedir.default_config_home,
+ 'leap')
def get_config_file(filename, folder=None):
@@ -252,6 +315,15 @@ def get_default_provider_path():
return default_provider_path
+def get_provider_path(domain):
+ # XXX if not domain, return get_default_provider_path
+ default_subpath = os.path.join("providers", domain)
+ provider_path = get_config_file(
+ '',
+ folder=default_subpath)
+ return provider_path
+
+
def validate_ip(ip_str):
"""
raises exception if the ip_str is
@@ -261,7 +333,11 @@ def validate_ip(ip_str):
def get_username():
- return os.getlogin()
+ try:
+ return os.getlogin()
+ except OSError as e:
+ import pwd
+ return pwd.getpwuid(os.getuid())[0]
def get_groupname():
diff --git a/src/leap/base/connection.py b/src/leap/base/connection.py
index e478538d..41d13935 100644
--- a/src/leap/base/connection.py
+++ b/src/leap/base/connection.py
@@ -37,11 +37,11 @@ class Connection(Authentication):
"""
pass
- def shutdown(self):
- """
- shutdown and quit
- """
- self.desired_con_state = self.status.DISCONNECTED
+ #def shutdown(self):
+ #"""
+ #shutdown and quit
+ #"""
+ #self.desired_con_state = self.status.DISCONNECTED
def connection_state(self):
"""
diff --git a/src/leap/base/constants.py b/src/leap/base/constants.py
index f7be8d98..f5665e5f 100644
--- a/src/leap/base/constants.py
+++ b/src/leap/base/constants.py
@@ -1,6 +1,7 @@
"""constants to be used in base module"""
from leap import __branding
-APP_NAME = __branding.get("short_name", "leap")
+APP_NAME = __branding.get("short_name", "leap-client")
+OPENVPN_BIN = "openvpn"
# default provider placeholder
# using `example.org` we make sure that this
@@ -14,18 +15,27 @@ DEFAULT_PROVIDER = __branding.get(
DEFINITION_EXPECTED_PATH = "provider.json"
DEFAULT_PROVIDER_DEFINITION = {
- u'api_uri': u'https://api.%s/' % DEFAULT_PROVIDER,
- u'api_version': u'0.1.0',
- u'ca_cert_fingerprint': u'8aab80ae4326fd30721689db813733783fe0bd7e',
- u'ca_cert_uri': u'https://%s/cacert.pem' % DEFAULT_PROVIDER,
- u'description': {u'en': u'This is a test provider'},
- u'display_name': {u'en': u'Test Provider'},
- u'domain': u'%s' % DEFAULT_PROVIDER,
- u'enrollment_policy': u'open',
- u'public_key': u'cb7dbd679f911e85bc2e51bd44afd7308ee19c21',
- u'serial': 1,
- u'services': [u'eip'],
- u'version': u'0.1.0'}
+ u"api_uri": "https://api.%s/" % DEFAULT_PROVIDER,
+ u"api_version": u"1",
+ u"ca_cert_fingerprint": "SHA256: fff",
+ u"ca_cert_uri": u"https://%s/ca.crt" % DEFAULT_PROVIDER,
+ u"default_language": u"en",
+ u"description": {
+ u"en": u"A demonstration service provider using the LEAP platform"
+ },
+ u"domain": "%s" % DEFAULT_PROVIDER,
+ u"enrollment_policy": u"open",
+ u"languages": [
+ u"en"
+ ],
+ u"name": {
+ u"en": u"Test Provider"
+ },
+ u"services": [
+ "openvpn"
+ ]
+}
+
MAX_ICMP_PACKET_LOSS = 10
diff --git a/src/leap/base/exceptions.py b/src/leap/base/exceptions.py
index f12a49d5..2e31b33b 100644
--- a/src/leap/base/exceptions.py
+++ b/src/leap/base/exceptions.py
@@ -14,6 +14,7 @@ Exception attributes and their meaning/uses
* usermessage: the message that will be passed to user in ErrorDialogs
in Qt-land.
"""
+from leap.util.translations import translate
class LeapException(Exception):
@@ -22,6 +23,7 @@ class LeapException(Exception):
sets some parameters that we will check
during error checking routines
"""
+
critical = False
failfirst = False
warning = False
@@ -46,27 +48,50 @@ class ImproperlyConfigured(Exception):
pass
-class NoDefaultInterfaceFoundError(LeapException):
- message = "no default interface found"
- usermessage = "Looks like your computer is not connected to the internet"
+# NOTE: "Errors" (context) has to be a explicit string!
class InterfaceNotFoundError(LeapException):
# XXX should take iface arg on init maybe?
message = "interface not found"
+ usermessage = translate(
+ "Errors",
+ "Interface not found")
+
+
+class NoDefaultInterfaceFoundError(LeapException):
+ message = "no default interface found"
+ usermessage = translate(
+ "Errors",
+ "Looks like your computer "
+ "is not connected to the internet")
class NoConnectionToGateway(CriticalError):
message = "no connection to gateway"
- usermessage = "Looks like there are problems with your internet connection"
+ usermessage = translate(
+ "Errors",
+ "Looks like there are problems "
+ "with your internet connection")
class NoInternetConnection(CriticalError):
message = "No Internet connection found"
- usermessage = "It looks like there is no internet connection."
+ usermessage = translate(
+ "Errors",
+ "It looks like there is no internet connection.")
# and now we try to connect to our web to troubleshoot LOL :P
-class TunnelNotDefaultRouteError(CriticalError):
+class CannotResolveDomainError(LeapException):
+ message = "Cannot resolve domain"
+ usermessage = translate(
+ "Errors",
+ "Domain cannot be found")
+
+
+class TunnelNotDefaultRouteError(LeapException):
message = "Tunnel connection dissapeared. VPN down?"
- usermessage = "The Encrypted Connection was lost. Shutting down..."
+ usermessage = translate(
+ "Errors",
+ "The Encrypted Connection was lost.")
diff --git a/src/leap/base/network.py b/src/leap/base/network.py
index 3891b00a..d841e692 100644
--- a/src/leap/base/network.py
+++ b/src/leap/base/network.py
@@ -3,10 +3,11 @@ from __future__ import (print_function)
import logging
import threading
-from leap.eip.config import get_eip_gateway
+from leap.eip import config as eipconfig
from leap.base.checks import LeapNetworkChecker
from leap.base.constants import ROUTE_CHECK_INTERVAL
from leap.base.exceptions import TunnelNotDefaultRouteError
+from leap.util.misc import null_check
from leap.util.coroutines import (launch_thread, process_events)
from time import sleep
@@ -20,24 +21,34 @@ class NetworkCheckerThread(object):
connection.
"""
def __init__(self, *args, **kwargs):
+
self.status_signals = kwargs.pop('status_signals', None)
- #self.watcher_cb = kwargs.pop('status_signals', None)
self.error_cb = kwargs.pop(
'error_cb',
lambda exc: logger.error("%s", exc.message))
self.shutdown = threading.Event()
- # XXX get provider_gateway and pass it to checker
- # see in eip.config for function
- # #718
+ # XXX get provider passed here
+ provider = kwargs.pop('provider', None)
+ null_check(provider, 'provider')
+
+ eipconf = eipconfig.EIPConfig(domain=provider)
+ eipconf.load()
+ eipserviceconf = eipconfig.EIPServiceConfig(domain=provider)
+ eipserviceconf.load()
+
+ gw = eipconfig.get_eip_gateway(
+ eipconfig=eipconf,
+ eipserviceconfig=eipserviceconf)
self.checker = LeapNetworkChecker(
- provider_gw = get_eip_gateway())
+ provider_gw=gw)
def start(self):
self.process_handle = self._launch_recurrent_network_checks(
(self.error_cb,))
def stop(self):
+ self.process_handle.join(timeout=0.1)
self.shutdown.set()
logger.debug("network checked stopped.")
@@ -49,6 +60,7 @@ class NetworkCheckerThread(object):
#here all the observers in fail_callbacks expect one positional argument,
#which is exception so we can try by passing a lambda with logger to
#check it works.
+
def _network_checks_thread(self, fail_callbacks):
#TODO: replace this with waiting for a signal from openvpn
while True:
@@ -59,11 +71,15 @@ class NetworkCheckerThread(object):
# XXX ??? why do we sleep here???
# aa: If the openvpn isn't up and running yet,
# let's give it a moment to breath.
+ #logger.error('NOT DEFAULT ROUTE!----')
+ # Instead of this, we should flag when the
+ # iface IS SUPPOSED to be up imo. -- kali
sleep(1)
fail_observer_dict = dict(((
observer,
process_events(observer)) for observer in fail_callbacks))
+
while not self.shutdown.is_set():
try:
self.checker.check_tunnel_default_interface()
@@ -73,11 +89,18 @@ class NetworkCheckerThread(object):
for obs in fail_observer_dict:
fail_observer_dict[obs].send(exc)
sleep(ROUTE_CHECK_INTERVAL)
+
#reset event
+ # I see a problem with this. You cannot stop it, it
+ # resets itself forever. -- kali
+
+ # XXX use QTimer for the recurrent triggers,
+ # and ditch the sleeps.
+ logger.debug('resetting event')
self.shutdown.clear()
def _launch_recurrent_network_checks(self, fail_callbacks):
- #we need to wrap the fail callback in a tuple
+ # XXX reimplement using QTimer -- kali
watcher = launch_thread(
self._network_checks_thread,
(fail_callbacks,))
diff --git a/src/leap/base/pluggableconfig.py b/src/leap/base/pluggableconfig.py
index b8615ad8..3517db6b 100644
--- a/src/leap/base/pluggableconfig.py
+++ b/src/leap/base/pluggableconfig.py
@@ -10,6 +10,8 @@ import urlparse
import jsonschema
+from leap.util.translations import LEAPTranslatable
+
logger = logging.getLogger(__name__)
@@ -118,7 +120,6 @@ adaptors['json'] = JSONAdaptor()
# to proper python types.
# TODO:
-# - multilingual object.
# - HTTPS uri
@@ -132,6 +133,20 @@ class DateType(object):
return time.strftime(self.fmt, data)
+class TranslatableType(object):
+ """
+ a type that casts to LEAPTranslatable objects.
+ Used for labels we get from providers and stuff.
+ """
+
+ def to_python(self, data):
+ return LEAPTranslatable(data)
+
+ # needed? we already have an extended dict...
+ #def get_prep_value(self, data):
+ #return dict(data)
+
+
class URIType(object):
def to_python(self, data):
@@ -164,6 +179,7 @@ types = {
'date': DateType(),
'uri': URIType(),
'https-uri': HTTPSURIType(),
+ 'translatable': TranslatableType(),
}
@@ -180,6 +196,8 @@ class PluggableConfig(object):
self.adaptors = adaptors
self.types = types
self._format = format
+ self.mtime = None
+ self.dirty = False
@property
def option_dict(self):
@@ -319,6 +337,13 @@ class PluggableConfig(object):
serializable = self.prep_value(config)
adaptor.write(serializable, filename)
+ if self.mtime:
+ self.touch_mtime(filename)
+
+ def touch_mtime(self, filename):
+ mtime = self.mtime
+ os.utime(filename, (mtime, mtime))
+
def deserialize(self, string=None, fromfile=None, format=None):
"""
load configuration from a file or string
@@ -364,6 +389,12 @@ class PluggableConfig(object):
content = _try_deserialize()
return content
+ def set_dirty(self):
+ self.dirty = True
+
+ def is_dirty(self):
+ return self.dirty
+
def load(self, *args, **kwargs):
"""
load from string or file
@@ -373,6 +404,8 @@ class PluggableConfig(object):
"""
string = args[0] if args else None
fromfile = kwargs.get("fromfile", None)
+ mtime = kwargs.pop("mtime", None)
+ self.mtime = mtime
content = None
# start with defaults, so we can
@@ -402,7 +435,8 @@ class PluggableConfig(object):
return True
-def testmain():
+def testmain(): # pragma: no cover
+
from tests import test_validation as t
import pprint
diff --git a/src/leap/base/providers.py b/src/leap/base/providers.py
index 7b219cc7..d41f3695 100644
--- a/src/leap/base/providers.py
+++ b/src/leap/base/providers.py
@@ -7,20 +7,20 @@ class LeapProviderDefinition(baseconfig.JSONLeapConfig):
spec = specs.leap_provider_spec
def _get_slug(self):
- provider_path = baseconfig.get_default_provider_path()
+ domain = getattr(self, 'domain', None)
+ if domain:
+ path = baseconfig.get_provider_path(domain)
+ else:
+ path = baseconfig.get_default_provider_path()
+
return baseconfig.get_config_file(
- 'provider.json',
- folder=provider_path)
+ 'provider.json', folder=path)
def _set_slug(self, *args, **kwargs):
raise AttributeError("you cannot set slug")
slug = property(_get_slug, _set_slug)
- # TODO (MVS+)
- # we will construct slug from providers/%s/definition.json
- # where %s is domain name. we can get that on __init__
-
class LeapProviderSet(object):
# we gather them from the filesystem
diff --git a/src/leap/base/specs.py b/src/leap/base/specs.py
index b4bb8dcf..f57d7e9c 100644
--- a/src/leap/base/specs.py
+++ b/src/leap/base/specs.py
@@ -2,28 +2,36 @@ leap_provider_spec = {
'description': 'provider definition',
'type': 'object',
'properties': {
- 'serial': {
- 'type': int,
- 'default': 1,
- 'required': True,
- },
+ #'serial': {
+ #'type': int,
+ #'default': 1,
+ #'required': True,
+ #},
'version': {
'type': unicode,
'default': '0.1.0'
#'required': True
},
+ "default_language": {
+ 'type': unicode,
+ 'default': 'en'
+ },
'domain': {
'type': unicode, # XXX define uri type
'default': 'testprovider.example.org'
#'required': True,
},
- 'display_name': {
- 'type': dict, # XXX multilingual object?
+ 'name': {
+ #'type': LEAPTranslatable,
+ 'type': dict,
+ 'format': 'translatable',
'default': {u'en': u'Test Provider'}
#'required': True
},
'description': {
+ #'type': LEAPTranslatable,
'type': dict,
+ 'format': 'translatable',
'default': {u'en': u'Test provider'}
},
'enrollment_policy': {
diff --git a/src/leap/base/tests/__init__.py b/src/leap/base/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/base/tests/__init__.py
diff --git a/src/leap/base/tests/test_auth.py b/src/leap/base/tests/test_auth.py
new file mode 100644
index 00000000..b3009a9b
--- /dev/null
+++ b/src/leap/base/tests/test_auth.py
@@ -0,0 +1,58 @@
+from BaseHTTPServer import BaseHTTPRequestHandler
+import urlparse
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+import requests
+#from mock import Mock
+
+from leap.base import auth
+#from leap.base import exceptions
+from leap.eip.tests.test_checks import NoLogRequestHandler
+from leap.testing.basetest import BaseLeapTest
+from leap.testing.https_server import BaseHTTPSServerTestCase
+
+
+class LeapSRPRegisterTests(BaseHTTPSServerTestCase, BaseLeapTest):
+ __name__ = "leap_srp_register_test"
+ provider = "testprovider.example.org"
+
+ class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler):
+ responses = {
+ '/': ['OK', '']}
+
+ def do_GET(self):
+ path = urlparse.urlparse(self.path)
+ message = '\n'.join(self.responses.get(
+ path.path, None))
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(message)
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def test_srp_auth_should_implement_check_methods(self):
+ SERVER = "https://localhost:8443"
+ srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False)
+
+ self.assertTrue(hasattr(srp_auth, "init_session"),
+ "missing meth")
+ self.assertTrue(hasattr(srp_auth, "get_registration_uri"),
+ "missing meth")
+ self.assertTrue(hasattr(srp_auth, "register_user"),
+ "missing meth")
+
+ def test_srp_auth_basic_functionality(self):
+ SERVER = "https://localhost:8443"
+ srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False)
+
+ self.assertIsInstance(srp_auth.session, requests.sessions.Session)
+ self.assertEqual(
+ srp_auth.get_registration_uri(),
+ "https://localhost:8443/1/users")
diff --git a/src/leap/base/tests/test_checks.py b/src/leap/base/tests/test_checks.py
index bec09ce6..8126755b 100644
--- a/src/leap/base/tests/test_checks.py
+++ b/src/leap/base/tests/test_checks.py
@@ -3,13 +3,11 @@ try:
except ImportError:
import unittest
import os
+import sh
from mock import (patch, Mock)
from StringIO import StringIO
-import ping
-import requests
-
from leap.base import checks
from leap.base import exceptions
from leap.testing.basetest import BaseLeapTest
@@ -21,6 +19,7 @@ class LeapNetworkCheckTest(BaseLeapTest):
__name__ = "leap_network_check_tests"
def setUp(self):
+ os.environ['PATH'] += ':/bin'
pass
def tearDown(self):
@@ -37,16 +36,27 @@ class LeapNetworkCheckTest(BaseLeapTest):
"missing meth")
self.assertTrue(hasattr(checker, "ping_gateway"),
"missing meth")
+ self.assertTrue(hasattr(checker, "parse_log_and_react"),
+ "missing meth")
def test_checker_should_actually_call_all_tests(self):
checker = checks.LeapNetworkChecker()
+ mc = Mock()
+ checker.run_all(checker=mc)
+ self.assertTrue(mc.check_internet_connection.called, "not called")
+ self.assertTrue(mc.check_tunnel_default_interface.called, "not called")
+ self.assertTrue(mc.is_internet_up.called, "not called")
+ self.assertTrue(mc.parse_log_and_react.called, "not called")
+ # ping gateway only called if we pass provider_gw
+ checker = checks.LeapNetworkChecker(provider_gw="0.0.0.0")
mc = Mock()
checker.run_all(checker=mc)
self.assertTrue(mc.check_internet_connection.called, "not called")
self.assertTrue(mc.check_tunnel_default_interface.called, "not called")
self.assertTrue(mc.ping_gateway.called, "not called")
self.assertTrue(mc.is_internet_up.called, "not called")
+ self.assertTrue(mc.parse_log_and_react.called, "not called")
def test_get_default_interface_no_interface(self):
checker = checks.LeapNetworkChecker()
@@ -65,14 +75,6 @@ class LeapNetworkCheckTest(BaseLeapTest):
mock_open.return_value = StringIO(
"Iface\tDestination Gateway\t"
"Flags\tRefCntd\tUse\tMetric\t"
- "Mask\tMTU\tWindow\tIRTT")
- checker.check_tunnel_default_interface()
-
- with patch('leap.base.checks.open', create=True) as mock_open:
- with self.assertRaises(exceptions.TunnelNotDefaultRouteError):
- mock_open.return_value = StringIO(
- "Iface\tDestination Gateway\t"
- "Flags\tRefCntd\tUse\tMetric\t"
"Mask\tMTU\tWindow\tIRTT\n"
"wlan0\t00000000\t0102A8C0\t"
"0003\t0\t0\t0\t00000000\t0\t0\t0")
@@ -88,30 +90,88 @@ class LeapNetworkCheckTest(BaseLeapTest):
def test_ping_gateway_fail(self):
checker = checks.LeapNetworkChecker()
- with patch.object(ping, "quiet_ping") as mocked_ping:
+ with patch.object(sh, "ping") as mocked_ping:
with self.assertRaises(exceptions.NoConnectionToGateway):
- mocked_ping.return_value = [11, "", ""]
+ mocked_ping.return_value = Mock
+ mocked_ping.return_value.stdout = "11% packet loss"
checker.ping_gateway("4.2.2.2")
- def test_check_internet_connection_failures(self):
+ def test_ping_gateway(self):
checker = checks.LeapNetworkChecker()
- with patch.object(requests, "get") as mocked_get:
- mocked_get.side_effect = requests.HTTPError
- with self.assertRaises(exceptions.NoInternetConnection):
- checker.check_internet_connection()
+ with patch.object(sh, "ping") as mocked_ping:
+ mocked_ping.return_value = Mock
+ mocked_ping.return_value.stdout = """
+PING 4.2.2.2 (4.2.2.2) 56(84) bytes of data.
+64 bytes from 4.2.2.2: icmp_req=1 ttl=54 time=33.8 ms
+64 bytes from 4.2.2.2: icmp_req=2 ttl=54 time=30.6 ms
+64 bytes from 4.2.2.2: icmp_req=3 ttl=54 time=31.4 ms
+64 bytes from 4.2.2.2: icmp_req=4 ttl=54 time=36.1 ms
+64 bytes from 4.2.2.2: icmp_req=5 ttl=54 time=30.8 ms
+64 bytes from 4.2.2.2: icmp_req=6 ttl=54 time=30.4 ms
+64 bytes from 4.2.2.2: icmp_req=7 ttl=54 time=30.7 ms
+64 bytes from 4.2.2.2: icmp_req=8 ttl=54 time=32.7 ms
+64 bytes from 4.2.2.2: icmp_req=9 ttl=54 time=31.4 ms
+64 bytes from 4.2.2.2: icmp_req=10 ttl=54 time=33.3 ms
+
+--- 4.2.2.2 ping statistics ---
+10 packets transmitted, 10 received, 0% packet loss, time 9016ms
+rtt min/avg/max/mdev = 30.497/32.172/36.161/1.755 ms"""
+ checker.ping_gateway("4.2.2.2")
- with patch.object(requests, "get") as mocked_get:
- mocked_get.side_effect = requests.RequestException
+ def test_check_internet_connection_failures(self):
+ checker = checks.LeapNetworkChecker()
+ TimeoutError = get_ping_timeout_error()
+ with patch.object(sh, "ping") as mocked_ping:
+ mocked_ping.side_effect = TimeoutError
with self.assertRaises(exceptions.NoInternetConnection):
- checker.check_internet_connection()
+ with patch.object(checker, "ping_gateway") as mock_gateway:
+ mock_gateway.side_effect = exceptions.NoConnectionToGateway
+ checker.check_internet_connection()
- #TODO: Mock possible errors that can be raised by is_internet_up
- with patch.object(requests, "get") as mocked_get:
- mocked_get.side_effect = requests.ConnectionError
+ with patch.object(sh, "ping") as mocked_ping:
+ mocked_ping.side_effect = TimeoutError
with self.assertRaises(exceptions.NoInternetConnection):
- checker.check_internet_connection()
+ with patch.object(checker, "ping_gateway") as mock_gateway:
+ mock_gateway.return_value = True
+ checker.check_internet_connection()
- @unittest.skipUnless(_uid == 0, "root only")
- def test_ping_gateway(self):
+ def test_parse_log_and_react(self):
checker = checks.LeapNetworkChecker()
- checker.ping_gateway("4.2.2.2")
+ to_call = Mock()
+ log = [("leap.openvpn - INFO - Mon Nov 19 13:36:24 2012 "
+ "read UDPv4 [ECONNREFUSED]: Connection refused (code=111)")]
+ err_matrix = [(checks.EVENT_CONNECT_REFUSED, (to_call, ))]
+ checker.parse_log_and_react(log, err_matrix)
+ self.assertTrue(to_call.called)
+
+ log = [("2012-11-19 13:36:26,177 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:24 2012 ERROR: Linux route delete command "
+ "failed: external program exited"),
+ ("2012-11-19 13:36:26,178 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:24 2012 ERROR: Linux route delete command "
+ "failed: external program exited"),
+ ("2012-11-19 13:36:26,180 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:24 2012 ERROR: Linux route delete command "
+ "failed: external program exited"),
+ ("2012-11-19 13:36:26,181 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:24 2012 /sbin/ifconfig tun0 0.0.0.0"),
+ ("2012-11-19 13:36:26,182 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:24 2012 Linux ip addr del failed: external "
+ "program exited with error stat"),
+ ("2012-11-19 13:36:26,183 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:26 2012 SIGTERM[hard,] received, process"
+ "exiting"), ]
+ to_call.reset_mock()
+ checker.parse_log_and_react(log, err_matrix)
+ self.assertFalse(to_call.called)
+
+ to_call.reset_mock()
+ checker.parse_log_and_react([], err_matrix)
+ self.assertFalse(to_call.called)
+
+
+def get_ping_timeout_error():
+ try:
+ sh.ping("-c", "1", "-w", "1", "8.8.7.7")
+ except Exception as e:
+ return e
diff --git a/src/leap/base/tests/test_providers.py b/src/leap/base/tests/test_providers.py
index 8d3b8847..f257f54d 100644
--- a/src/leap/base/tests/test_providers.py
+++ b/src/leap/base/tests/test_providers.py
@@ -8,18 +8,22 @@ import os
import jsonschema
-from leap import __branding as BRANDING
+#from leap import __branding as BRANDING
from leap.testing.basetest import BaseLeapTest
from leap.base import providers
EXPECTED_DEFAULT_CONFIG = {
u"api_version": u"0.1.0",
- u"description": {u'en': u"Test provider"},
- u"display_name": {u'en': u"Test Provider"},
+ #u"description": "LEAPTranslatable<{u'en': u'Test provider'}>",
+ u"description": {u'en': u'Test provider'},
+ u"default_language": u"en",
+ #u"display_name": {u'en': u"Test Provider"},
u"domain": u"testprovider.example.org",
+ #u'name': "LEAPTranslatable<{u'en': u'Test Provider'}>",
+ u'name': {u'en': u'Test Provider'},
u"enrollment_policy": u"open",
- u"serial": 1,
+ #u"serial": 1,
u"services": [
u"eip"
],
@@ -30,9 +34,11 @@ EXPECTED_DEFAULT_CONFIG = {
class TestLeapProviderDefinition(BaseLeapTest):
def setUp(self):
- self.definition = providers.LeapProviderDefinition()
- self.definition.save()
- self.definition.load()
+ self.domain = "testprovider.example.org"
+ self.definition = providers.LeapProviderDefinition(
+ domain=self.domain)
+ self.definition.save(force=True)
+ self.definition.load() # why have to load after save??
self.config = self.definition.config
def tearDown(self):
@@ -51,7 +57,7 @@ class TestLeapProviderDefinition(BaseLeapTest):
os.path.join(
self.home,
'.config', 'leap', 'providers',
- '%s' % BRANDING.get('provider_domain'),
+ '%s' % self.domain,
'provider.json'))
with self.assertRaises(AttributeError):
self.definition.slug = 23
@@ -59,9 +65,10 @@ class TestLeapProviderDefinition(BaseLeapTest):
def test_provider_dump(self):
# check a good provider definition is dumped to disk
self.testfile = self.get_tempfile('test.json')
- self.definition.save(to=self.testfile)
+ self.definition.save(to=self.testfile, force=True)
deserialized = json.load(open(self.testfile, 'rb'))
self.maxDiff = None
+ #import ipdb;ipdb.set_trace()
self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG)
def test_provider_dump_to_slug(self):
@@ -80,13 +87,15 @@ class TestLeapProviderDefinition(BaseLeapTest):
with open(self.testfile, 'w') as wf:
wf.write(json.dumps(EXPECTED_DEFAULT_CONFIG))
self.definition.load(fromfile=self.testfile)
- self.assertDictEqual(self.config,
- EXPECTED_DEFAULT_CONFIG)
+ #self.assertDictEqual(self.config,
+ #EXPECTED_DEFAULT_CONFIG)
+ self.assertItemsEqual(self.config, EXPECTED_DEFAULT_CONFIG)
def test_provider_validation(self):
self.definition.validate(self.config)
_config = copy.deepcopy(self.config)
- _config['serial'] = 'aaa'
+ # bad type, raise validation error
+ _config['domain'] = 111
with self.assertRaises(jsonschema.ValidationError):
self.definition.validate(_config)
diff --git a/src/leap/baseapp/dialogs.py b/src/leap/baseapp/dialogs.py
index 3cb539cf..d256fc99 100644
--- a/src/leap/baseapp/dialogs.py
+++ b/src/leap/baseapp/dialogs.py
@@ -23,7 +23,8 @@ class ErrorDialog(QDialog):
def warningMessage(self, msg, label):
msgBox = QMessageBox(QMessageBox.Warning,
- "QMessageBox.warning()", msg,
+ "LEAP Client Error",
+ msg,
QMessageBox.NoButton, self)
msgBox.addButton("&Ok", QMessageBox.AcceptRole)
if msgBox.exec_() == QMessageBox.AcceptRole:
@@ -34,7 +35,8 @@ class ErrorDialog(QDialog):
def criticalMessage(self, msg, label):
msgBox = QMessageBox(QMessageBox.Critical,
- "QMessageBox.critical()", msg,
+ "LEAP Client Error",
+ msg,
QMessageBox.NoButton, self)
msgBox.addButton("&Ok", QMessageBox.AcceptRole)
msgBox.exec_()
@@ -49,7 +51,8 @@ class ErrorDialog(QDialog):
def confirmMessage(self, msg, label, action):
msgBox = QMessageBox(QMessageBox.Critical,
- "QMessageBox.critical()", msg,
+ self.tr("LEAP Client Error"),
+ msg,
QMessageBox.NoButton, self)
msgBox.addButton("&Ok", QMessageBox.AcceptRole)
msgBox.addButton("&Cancel", QMessageBox.RejectRole)
diff --git a/src/leap/baseapp/eip.py b/src/leap/baseapp/eip.py
index 93dce3ac..adc9ba68 100644
--- a/src/leap/baseapp/eip.py
+++ b/src/leap/baseapp/eip.py
@@ -9,6 +9,8 @@ from leap.baseapp.dialogs import ErrorDialog
from leap.baseapp import constants
from leap.eip import exceptions as eip_exceptions
from leap.eip.eipconnection import EIPConnection
+from leap.base.checks import EVENT_CONNECT_REFUSED
+from leap.util import geo
logger = logging.getLogger(name=__name__)
@@ -21,10 +23,12 @@ class EIPConductorAppMixin(object):
Connects the eip connect/disconnect logic
to the switches in the app (buttons/menu items).
"""
+ ERR_DIALOG = False
def __init__(self, *args, **kwargs):
opts = kwargs.pop('opts')
config_file = getattr(opts, 'config_file', None)
+ provider = kwargs.pop('provider')
self.eip_service_started = False
@@ -36,10 +40,11 @@ class EIPConductorAppMixin(object):
self.conductor = EIPConnection(
watcher_cb=self.newLogLine.emit,
config_file=config_file,
- checker_signals=(self.changeLeapStatus.emit, ),
- status_signals=(self.statusChange.emit, ),
+ checker_signals=(self.eipStatusChange.emit, ),
+ status_signals=(self.openvpnStatusChange.emit, ),
debug=self.debugmode,
- ovpn_verbosity=opts.openvpn_verb)
+ ovpn_verbosity=opts.openvpn_verb,
+ provider=provider)
self.skip_download = opts.no_provider_checks
self.skip_verify = opts.no_ca_verify
@@ -91,6 +96,15 @@ class EIPConductorAppMixin(object):
in the future we plan to derive errors to
our log viewer.
"""
+ if self.ERR_DIALOG:
+ logger.warning('another error dialog suppressed')
+ return
+
+ # XXX this is actually a one-shot.
+ # On the dialog there should be
+ # a reset signal binded to the ok button
+ # or something like that.
+ self.ERR_DIALOG = True
if getattr(error, 'usermessage', None):
message = error.usermessage
@@ -110,6 +124,7 @@ class EIPConductorAppMixin(object):
ErrorDialog(errtype="critical",
msg=message,
label="critical error")
+
elif error.warning:
logger.warning(error.message)
@@ -137,14 +152,14 @@ class EIPConductorAppMixin(object):
# is not ready yet.
return
- if self.conductor.with_errors:
+ #if self.conductor.with_errors:
#XXX how to wait on pkexec???
#something better that this workaround, plz!!
#I removed the pkexec pass authentication at all.
#time.sleep(5)
#logger.debug('timeout')
- logger.error('errors. disconnect')
- self.start_or_stopVPN() # is stop
+ #logger.error('errors. disconnect')
+ #self.start_or_stopVPN() # is stop
state = self.conductor.poll_connection_state()
if not state:
@@ -160,6 +175,8 @@ class EIPConductorAppMixin(object):
self.status_label.setText(con_status)
self.ip_label.setText(ip)
self.remote_label.setText(remote)
+ self.remote_country.setText(
+ geo.get_country_name(remote))
# status i/o
@@ -172,19 +189,27 @@ class EIPConductorAppMixin(object):
self.tun_read_bytes.setText(tun_read)
self.tun_write_bytes.setText(tun_write)
+ # connection information via management interface
+ log = self.conductor.get_log()
+ error_matrix = [(EVENT_CONNECT_REFUSED, (self.start_or_stopVPN, ))]
+ if hasattr(self.network_checker, 'checker'):
+ self.network_checker.checker.parse_log_and_react(log, error_matrix)
+
@QtCore.pyqtSlot()
- def start_or_stopVPN(self):
+ def start_or_stopVPN(self, **kwargs):
"""
stub for running child process with vpn
"""
if self.conductor.has_errors():
logger.debug('not starting vpn; conductor has errors')
+ return
if self.eip_service_started is False:
try:
self.conductor.connect()
except eip_exceptions.EIPNoCommandError as exc:
+ logger.error('tried to run openvpn but no command is set')
self.triggerEIPError.emit(exc)
except Exception as err:
@@ -193,7 +218,7 @@ class EIPConductorAppMixin(object):
else:
# no errors, so go on.
if self.debugmode:
- self.startStopButton.setText('&Disconnect')
+ self.startStopButton.setText(self.tr('&Disconnect'))
self.eip_service_started = True
self.toggleEIPAct()
@@ -201,14 +226,13 @@ class EIPConductorAppMixin(object):
# we could bring Timer Init to this Mixin
# or to its own Mixin.
self.timer.start(constants.TIMER_MILLISECONDS)
- self.network_checker.start()
return
if self.eip_service_started is True:
self.network_checker.stop()
self.conductor.disconnect()
if self.debugmode:
- self.startStopButton.setText('&Connect')
+ self.startStopButton.setText(self.tr('&Connect'))
self.eip_service_started = False
self.toggleEIPAct()
self.timer.stop()
diff --git a/src/leap/baseapp/leap_app.py b/src/leap/baseapp/leap_app.py
index 6ffb08a8..4d3aebd6 100644
--- a/src/leap/baseapp/leap_app.py
+++ b/src/leap/baseapp/leap_app.py
@@ -52,7 +52,7 @@ class MainWindowMixin(object):
self.firstRunWizardAct = QtGui.QAction(
"&First run wizard...", self,
- triggered=self.launch_first_run_wizard)
+ triggered=self.stop_connection_and_launch_first_run_wizard)
self.aboutAct = QtGui.QAction("&About", self, triggered=self.about)
#self.aboutQtAct = QtGui.QAction("About &Qt", self,
@@ -74,16 +74,21 @@ class MainWindowMixin(object):
self.menuBar().addMenu(self.settingsMenu)
self.menuBar().addMenu(self.helpMenu)
- def launch_first_run_wizard(self):
+ def stop_connection_and_launch_first_run_wizard(self):
settings = QtCore.QSettings()
settings.setValue('FirstRunWizardDone', False)
logger.debug('should run first run wizard again...')
- from leap.gui.firstrunwizard import FirstRunWizard
- wizard = FirstRunWizard(
- parent=self,
- success_cb=self.initReady.emit)
- wizard.show()
+ status = self.conductor.get_icon_name()
+ if status != "disconnected":
+ self.start_or_stopVPN()
+
+ self.launch_first_run_wizard()
+ #from leap.gui.firstrunwizard import FirstRunWizard
+ #wizard = FirstRunWizard(
+ #parent=self,
+ #success_cb=self.initReady.emit)
+ #wizard.show()
def set_app_icon(self):
icon = QtGui.QIcon(APP_LOGO)
@@ -127,8 +132,8 @@ class MainWindowMixin(object):
"context menu of the system tray entry.")
self.hide()
event.ignore()
- if self.debugmode:
- self.cleanupAndQuit()
+ return
+ self.cleanupAndQuit()
def cleanupAndQuit(self):
"""
@@ -143,6 +148,6 @@ class MainWindowMixin(object):
# in conductor
# XXX send signal instead?
logger.info('Shutting down')
- self.conductor.cleanup()
+ self.conductor.disconnect(shutdown=True)
logger.info('Exiting. Bye.')
QtGui.qApp.quit()
diff --git a/src/leap/baseapp/log.py b/src/leap/baseapp/log.py
index 8a7f81c3..636e5bae 100644
--- a/src/leap/baseapp/log.py
+++ b/src/leap/baseapp/log.py
@@ -11,6 +11,7 @@ class LogPaneMixin(object):
a simple log pane
that writes new lines as they come
"""
+ EXCLUDES = ('MANAGEMENT',)
def createLogBrowser(self):
"""
@@ -21,7 +22,7 @@ class LogPaneMixin(object):
logging_layout = QtGui.QVBoxLayout()
self.logbrowser = QtGui.QTextBrowser()
- startStopButton = QtGui.QPushButton("&Connect")
+ startStopButton = QtGui.QPushButton(self.tr("&Connect"))
self.startStopButton = startStopButton
logging_layout.addWidget(self.logbrowser)
@@ -34,9 +35,10 @@ class LogPaneMixin(object):
grid = QtGui.QGridLayout()
self.updateTS = QtGui.QLabel('')
- self.status_label = QtGui.QLabel('Disconnected')
+ self.status_label = QtGui.QLabel(self.tr('Disconnected'))
self.ip_label = QtGui.QLabel('')
self.remote_label = QtGui.QLabel('')
+ self.remote_country = QtGui.QLabel('')
tun_read_label = QtGui.QLabel("tun read")
self.tun_read_bytes = QtGui.QLabel("0")
@@ -47,10 +49,11 @@ class LogPaneMixin(object):
grid.addWidget(self.status_label, 0, 1)
grid.addWidget(self.ip_label, 1, 0)
grid.addWidget(self.remote_label, 1, 1)
- grid.addWidget(tun_read_label, 2, 0)
- grid.addWidget(self.tun_read_bytes, 2, 1)
- grid.addWidget(tun_write_label, 3, 0)
- grid.addWidget(self.tun_write_bytes, 3, 1)
+ grid.addWidget(self.remote_country, 2, 1)
+ grid.addWidget(tun_read_label, 3, 0)
+ grid.addWidget(self.tun_read_bytes, 3, 1)
+ grid.addWidget(tun_write_label, 4, 0)
+ grid.addWidget(self.tun_write_bytes, 4, 1)
self.statusBox.setLayout(grid)
@@ -60,6 +63,7 @@ class LogPaneMixin(object):
simple slot: writes new line to logger Pane.
"""
msg = line[:-1]
- if self.debugmode:
+ if self.debugmode and all(map(lambda w: w not in msg,
+ LogPaneMixin.EXCLUDES)):
self.logbrowser.append(msg)
- vpnlogger.info(msg)
+ vpnlogger.info(msg)
diff --git a/src/leap/baseapp/mainwindow.py b/src/leap/baseapp/mainwindow.py
index 3b6cb544..91b0dc61 100644
--- a/src/leap/baseapp/mainwindow.py
+++ b/src/leap/baseapp/mainwindow.py
@@ -2,6 +2,10 @@
#!/usr/bin/env python
import logging
+import sip
+sip.setapi('QString', 2)
+sip.setapi('QVariant', 2)
+
from PyQt4 import QtCore
from PyQt4 import QtGui
@@ -10,6 +14,8 @@ from leap.baseapp.log import LogPaneMixin
from leap.baseapp.systray import StatusAwareTrayIconMixin
from leap.baseapp.network import NetworkCheckerAppMixin
from leap.baseapp.leap_app import MainWindowMixin
+from leap.eip.checks import ProviderCertChecker
+from leap.gui.threads import FunThread
logger = logging.getLogger(name=__name__)
@@ -34,12 +40,13 @@ class LeapWindow(QtGui.QMainWindow,
networkError = QtCore.pyqtSignal([object])
triggerEIPError = QtCore.pyqtSignal([object])
start_eipconnection = QtCore.pyqtSignal([])
+ shutdownSignal = QtCore.pyqtSignal([])
+ initNetworkChecker = QtCore.pyqtSignal([])
- # XXX fix nomenclature here
- # this is eip status change got from vpn management
- statusChange = QtCore.pyqtSignal([object])
- # this is global leap status
- changeLeapStatus = QtCore.pyqtSignal([str])
+ # this is status change got from openvpn management
+ openvpnStatusChange = QtCore.pyqtSignal([object])
+ # this is global eip status
+ eipStatusChange = QtCore.pyqtSignal([str])
def __init__(self, opts):
logger.debug('init leap window')
@@ -48,26 +55,37 @@ class LeapWindow(QtGui.QMainWindow,
if self.debugmode:
self.createLogBrowser()
- EIPConductorAppMixin.__init__(self, opts=opts)
+ settings = QtCore.QSettings()
+ self.provider_domain = settings.value("provider_domain", None)
+ self.username = settings.value("username", None)
+
+ logger.debug('provider: %s', self.provider_domain)
+ logger.debug('username: %s', self.username)
+
+ provider = self.provider_domain
+ EIPConductorAppMixin.__init__(
+ self, opts=opts, provider=provider)
StatusAwareTrayIconMixin.__init__(self)
- NetworkCheckerAppMixin.__init__(self)
- MainWindowMixin.__init__(self)
- settings = QtCore.QSettings()
+ # XXX network checker should probably not
+ # trigger run_checks on init... but wait
+ # for ready signal instead...
+ NetworkCheckerAppMixin.__init__(self, provider=provider)
+ MainWindowMixin.__init__(self)
geom_key = "DebugGeometry" if self.debugmode else "Geometry"
geom = settings.value(geom_key)
-
- geom = settings.value("Geometry")
if geom:
self.restoreGeometry(geom)
+
+ # XXX check for wizard
self.wizard_done = settings.value("FirstRunWizardDone")
- self.initchecks = InitChecksThread(self.run_eip_checks)
+ self.initchecks = FunThread(self.run_eip_checks)
# bind signals
self.initchecks.finished.connect(
- lambda: logger.debug('Initial checks finished'))
+ lambda: logger.debug('Initial checks thread finished'))
self.trayIcon.activated.connect(self.iconActivated)
self.newLogLine.connect(
lambda line: self.onLoggerNewLine(line))
@@ -82,48 +100,92 @@ class LeapWindow(QtGui.QMainWindow,
self.startStopButton.clicked.connect(
lambda: self.start_or_stopVPN())
self.start_eipconnection.connect(
- lambda: self.start_or_stopVPN())
+ self.do_start_eipconnection)
+ self.shutdownSignal.connect(
+ self.cleanupAndQuit)
+ self.initNetworkChecker.connect(
+ lambda: self.init_network_checker(self.conductor.provider))
# status change.
# TODO unify
- self.statusChange.connect(
- lambda status: self.onStatusChange(status))
- self.changeLeapStatus.connect(
- lambda newstatus: self.onChangeLeapConnStatus(newstatus))
-
- # do frwizard and init signals
+ self.openvpnStatusChange.connect(
+ lambda status: self.onOpenVPNStatusChange(status))
+ self.eipStatusChange.connect(
+ lambda newstatus: self.onEIPConnStatusChange(newstatus))
+ self.eipStatusChange.connect(
+ lambda newstatus: self.toggleEIPAct())
+
+ # do first run wizard and init signals
self.mainappReady.connect(self.do_first_run_wizard_check)
self.initReady.connect(self.runchecks_and_eipconnect)
# ... all ready. go!
- # calls do_first_run_wizard_check
+ # connected to do_first_run_wizard_check
self.mainappReady.emit()
def do_first_run_wizard_check(self):
+ """
+ checks whether first run wizard needs to be run
+ launches it if needed
+ and emits initReady signal if not.
+ """
+
logger.debug('first run wizard check...')
- if self.wizard_done:
- self.initReady.emit()
- else:
- # need to run first-run-wizard
- logger.debug('running first run wizard')
- from leap.gui.firstrunwizard import FirstRunWizard
- wizard = FirstRunWizard(
- parent=self,
- success_cb=self.initReady.emit)
- wizard.show()
+ need_wizard = False
- def runchecks_and_eipconnect(self):
- self.initchecks.begin()
+ # do checks (can overlap if wizard was interrupted)
+ if not self.wizard_done:
+ need_wizard = True
+ if not self.provider_domain:
+ need_wizard = True
+ else:
+ pcertchecker = ProviderCertChecker(domain=self.provider_domain)
+ if not pcertchecker.is_cert_valid(do_raise=False):
+ logger.warning('missing valid client cert. need wizard')
+ need_wizard = True
-class InitChecksThread(QtCore.QThread):
+ # launch wizard if needed
+ if need_wizard:
+ logger.debug('running first run wizard')
+ self.launch_first_run_wizard()
+ else: # no wizard needed
+ self.initReady.emit()
- def __init__(self, fun, parent=None):
- QtCore.QThread.__init__(self, parent)
- self.fun = fun
+ def launch_first_run_wizard(self):
+ """
+ launches wizard and blocks
+ """
+ from leap.gui.firstrun.wizard import FirstRunWizard
+ wizard = FirstRunWizard(
+ self.conductor,
+ parent=self,
+ username=self.username,
+ start_eipconnection_signal=self.start_eipconnection,
+ eip_statuschange_signal=self.eipStatusChange,
+ quitcallback=self.onWizardCancel)
+ wizard.show()
+
+ def onWizardCancel(self):
+ if not self.wizard_done:
+ logger.debug(
+ 'clicked on Cancel during first '
+ 'run wizard. shutting down')
+ self.cleanupAndQuit()
- def run(self):
- self.fun()
+ def runchecks_and_eipconnect(self):
+ """
+ shows icon and run init checks
+ """
+ self.show_systray_icon()
+ self.initchecks.begin()
- def begin(self):
- self.start()
+ def do_start_eipconnection(self):
+ """
+ shows icon and init eip connection
+ called from the end of wizard
+ """
+ self.show_systray_icon()
+ # this will setup the command
+ self.conductor.run_openvpn_checks()
+ self.start_or_stopVPN()
diff --git a/src/leap/baseapp/network.py b/src/leap/baseapp/network.py
index 077d5164..dc5182a4 100644
--- a/src/leap/baseapp/network.py
+++ b/src/leap/baseapp/network.py
@@ -9,19 +9,34 @@ from PyQt4 import QtCore
from leap.baseapp.dialogs import ErrorDialog
from leap.base.network import NetworkCheckerThread
+from leap.util.misc import null_check
+
class NetworkCheckerAppMixin(object):
"""
initialize an instance of the Network Checker,
which gathers error and passes them on.
"""
+ ERR_NETERR = False
def __init__(self, *args, **kwargs):
- self.network_checker = NetworkCheckerThread(
- error_cb=self.networkError.emit,
- debug=self.debugmode)
+ provider = kwargs.pop('provider', None)
+ self.network_checker = None
+ if provider:
+ self.init_network_checker(provider)
+
+ def init_network_checker(self, provider):
+ null_check(provider, "provider_domain")
+ if not self.network_checker:
+ self.network_checker = NetworkCheckerThread(
+ error_cb=self.networkError.emit,
+ debug=self.debugmode,
+ provider=provider)
+ self.network_checker.start()
- # XXX move run_checks to slot
+ @QtCore.pyqtSlot(object)
+ def runNetworkChecks(self):
+ logger.debug('running checks (from NetworkChecker Mixin slot)')
self.network_checker.run_checks()
@QtCore.pyqtSlot(object)
@@ -30,11 +45,19 @@ class NetworkCheckerAppMixin(object):
slot that receives a network exceptions
and raises a user error message
"""
- logger.debug('handling network exception')
- logger.error(exc.message)
- dialog = ErrorDialog(parent=self)
+ # FIXME this should not HANDLE anything after
+ # the network check thread has been stopped.
- if exc.critical:
- dialog.criticalMessage(exc.usermessage, "network error")
- else:
- dialog.warningMessage(exc.usermessage, "network error")
+ logger.debug('handling network exception')
+ if not self.ERR_NETERR:
+ self.ERR_NETERR = True
+
+ logger.error(exc.message)
+ dialog = ErrorDialog(parent=self)
+ if exc.critical:
+ dialog.criticalMessage(exc.usermessage, "network error")
+ else:
+ dialog.warningMessage(exc.usermessage, "network error")
+
+ self.start_or_stopVPN()
+ self.network_checker.stop()
diff --git a/src/leap/baseapp/systray.py b/src/leap/baseapp/systray.py
index cc5d89df..77eb3fe9 100644
--- a/src/leap/baseapp/systray.py
+++ b/src/leap/baseapp/systray.py
@@ -1,4 +1,9 @@
import logging
+import sys
+
+import sip
+sip.setapi('QString', 2)
+sip.setapi('QVariant', 2)
from PyQt4 import QtCore
from PyQt4 import QtGui
@@ -41,12 +46,14 @@ class StatusAwareTrayIconMixin(object):
self.createIconGroupBox()
self.createActions()
self.createTrayIcon()
- #logger.debug('showing tray icon................')
- self.trayIcon.show()
# not sure if this really belongs here, but...
self.timer = QtCore.QTimer()
+ def show_systray_icon(self):
+ #logger.debug('showing tray icon................')
+ self.trayIcon.show()
+
def createIconGroupBox(self):
"""
dummy icongroupbox
@@ -68,7 +75,8 @@ class StatusAwareTrayIconMixin(object):
self.iconpath['connected'])),
self.ConnectionWidgets = con_widgets
- self.statusIconBox = QtGui.QGroupBox("EIP Connection Status")
+ self.statusIconBox = QtGui.QGroupBox(
+ self.tr("EIP Connection Status"))
statusIconLayout = QtGui.QHBoxLayout()
statusIconLayout.addWidget(self.ConnectionWidgets['disconnected'])
statusIconLayout.addWidget(self.ConnectionWidgets['connecting'])
@@ -76,7 +84,8 @@ class StatusAwareTrayIconMixin(object):
statusIconLayout.itemAt(1).widget().hide()
statusIconLayout.itemAt(2).widget().hide()
- self.leapConnStatus = QtGui.QLabel("<b>disconnected</b>")
+ self.leapConnStatus = QtGui.QLabel(
+ self.tr("<b>disconnected</b>"))
statusIconLayout.addWidget(self.leapConnStatus)
self.statusIconBox.setLayout(statusIconLayout)
@@ -92,7 +101,9 @@ class StatusAwareTrayIconMixin(object):
self.trayIconMenu.addAction(self.detailsAct)
self.trayIconMenu.addSeparator()
self.trayIconMenu.addAction(self.aboutAct)
- self.trayIconMenu.addAction(self.aboutQtAct)
+ # we should get this hidden inside the "about" dialog
+ # (as a little button maybe)
+ #self.trayIconMenu.addAction(self.aboutQtAct)
self.trayIconMenu.addSeparator()
self.trayIconMenu.addAction(self.quitAction)
@@ -104,35 +115,52 @@ class StatusAwareTrayIconMixin(object):
#self.trayIconMenu.customContextMenuRequested.connect(
#self.on_context_menu)
- def bad(self):
- logger.error('this should not be called')
+ #def bad(self):
+ #logger.error('this should not be called')
def createActions(self):
"""
creates actions to be binded to tray icon
"""
# XXX change action name on (dis)connect
- self.connAct = QtGui.QAction("Encryption ON turn &off", self,
- triggered=lambda: self.start_or_stopVPN())
-
- self.detailsAct = QtGui.QAction("&Details...",
- self,
- triggered=self.detailsWin)
- self.aboutAct = QtGui.QAction("&About", self,
- triggered=self.about)
- self.aboutQtAct = QtGui.QAction("About Q&t", self,
- triggered=QtGui.qApp.aboutQt)
- self.quitAction = QtGui.QAction("&Quit", self,
- triggered=self.cleanupAndQuit)
+ self.connAct = QtGui.QAction(
+ self.tr("Encryption ON turn &off"),
+ self,
+ triggered=lambda: self.start_or_stopVPN())
+
+ self.detailsAct = QtGui.QAction(
+ self.tr("&Details..."),
+ self,
+ triggered=self.detailsWin)
+ self.aboutAct = QtGui.QAction(
+ self.tr("&About"), self,
+ triggered=self.about)
+ self.aboutQtAct = QtGui.QAction(
+ self.tr("About Q&t"), self,
+ triggered=QtGui.qApp.aboutQt)
+ self.quitAction = QtGui.QAction(
+ self.tr("&Quit"), self,
+ triggered=self.cleanupAndQuit)
def toggleEIPAct(self):
# this is too simple by now.
- # XXX We need to get the REAL info for Encryption state.
- # (now is ON as soon as vpn launched)
- if self.eip_service_started is True:
- self.connAct.setText('Encryption ON turn o&ff')
- else:
- self.connAct.setText('Encryption OFF turn &on')
+ # XXX get STATUS CONSTANTS INSTEAD
+
+ icon_status = self.conductor.get_icon_name()
+ if icon_status == "connected":
+ self.connAct.setEnabled(True)
+ self.connAct.setText(
+ self.tr('Encryption ON turn o&ff'))
+ return
+ if icon_status == "disconnected":
+ self.connAct.setEnabled(True)
+ self.connAct.setText(
+ self.tr('Encryption OFF turn &on'))
+ return
+ if icon_status == "connecting":
+ self.connAct.setDisabled(True)
+ self.connAct.setText(self.tr('connecting...'))
+ return
def detailsWin(self):
visible = self.isVisible()
@@ -140,18 +168,21 @@ class StatusAwareTrayIconMixin(object):
self.hide()
else:
self.show()
+ if sys.platform == "darwin":
+ self.raise_()
def about(self):
# move to widget
flavor = BRANDING.get('short_name', None)
- content = ("LEAP client<br>"
- "(version <b>%s</b>)<br>" % VERSION)
+ content = self.tr(
+ ("LEAP client<br>"
+ "(version <b>%s</b>)<br>" % VERSION))
if flavor:
content = content + ('<br>Flavor: <i>%s</i><br>' % flavor)
content = content + (
"<br><a href='https://leap.se/'>"
"https://leap.se</a>")
- QtGui.QMessageBox.about(self, "About", content)
+ QtGui.QMessageBox.about(self, self.tr("About"), content)
def setConnWidget(self, icon_name):
oldlayout = self.statusIconBox.layout()
@@ -189,6 +220,7 @@ class StatusAwareTrayIconMixin(object):
# is failing in a way beyond my understanding.
# (not working the first time it's clicked).
# this works however.
+ # XXX in osx it shows some glitches.
context_menu.exec_(self.trayIcon.geometry().center())
@QtCore.pyqtSlot()
@@ -196,31 +228,38 @@ class StatusAwareTrayIconMixin(object):
self.statusUpdate()
@QtCore.pyqtSlot(object)
- def onStatusChange(self, status):
+ def onOpenVPNStatusChange(self, status):
"""
- updates icon
+ updates icon, according to the openvpn status change.
"""
icon_name = self.conductor.get_icon_name()
+ if not icon_name:
+ return
# XXX refactor. Use QStateMachine
if icon_name in ("disconnected", "connected"):
- self.changeLeapStatus.emit(icon_name)
+ self.eipStatusChange.emit(icon_name)
if icon_name in ("connecting"):
# let's see how it matches
leap_status_name = self.conductor.get_leap_status()
- self.changeLeapStatus.emit(leap_status_name)
+ self.eipStatusChange.emit(leap_status_name)
+
+ if icon_name == "connected":
+ # When we change to "connected', we launch
+ # the network checker.
+ self.initNetworkChecker.emit()
self.setIcon(icon_name)
# change connection pixmap widget
self.setConnWidget(icon_name)
@QtCore.pyqtSlot(str)
- def onChangeLeapConnStatus(self, newstatus):
+ def onEIPConnStatusChange(self, newstatus):
"""
- slot for LEAP status changes
- not to be confused with onStatusChange.
+ slot for EIP status changes
+ not to be confused with onOpenVPNStatusChange.
this only updates the non-debug LEAP Status line
next to the connection icon.
"""
diff --git a/src/leap/crypto/certs.py b/src/leap/crypto/certs.py
new file mode 100644
index 00000000..cbb5725a
--- /dev/null
+++ b/src/leap/crypto/certs.py
@@ -0,0 +1,112 @@
+import logging
+import os
+from StringIO import StringIO
+import ssl
+import time
+
+from dateutil.parser import parse
+from OpenSSL import crypto
+
+from leap.util.misc import null_check
+
+logger = logging.getLogger(__name__)
+
+
+class BadCertError(Exception):
+ """
+ raised for malformed certs
+ """
+
+
+class NoCertError(Exception):
+ """
+ raised for cert not found in given path
+ """
+
+
+def get_https_cert_from_domain(domain, port=443):
+ """
+ @param domain: a domain name to get a certificate from.
+ """
+ cert = ssl.get_server_certificate((domain, port))
+ x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
+ return x509
+
+
+def get_cert_from_file(_file):
+ null_check(_file, "pem file")
+ if isinstance(_file, (str, unicode)):
+ if not os.path.isfile(_file):
+ raise NoCertError
+ with open(_file) as f:
+ cert = f.read()
+ else:
+ cert = _file.read()
+ x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
+ return x509
+
+
+def get_pkey_from_file(_file):
+ getkey = lambda f: crypto.load_privatekey(
+ crypto.FILETYPE_PEM, f.read())
+
+ if isinstance(_file, str):
+ with open(_file) as f:
+ key = getkey(f)
+ else:
+ key = getkey(_file)
+ return key
+
+
+def can_load_cert_and_pkey(string):
+ """
+ loads certificate and private key from
+ a buffer
+ """
+ try:
+ f = StringIO(string)
+ cert = get_cert_from_file(f)
+
+ f = StringIO(string)
+ key = get_pkey_from_file(f)
+
+ null_check(cert, 'certificate')
+ null_check(key, 'private key')
+ except Exception as exc:
+ logger.error(type(exc), exc.message)
+ raise BadCertError
+ else:
+ return True
+
+
+def get_cert_fingerprint(domain=None, port=443, filepath=None,
+ hash_type="SHA256", sep=":"):
+ """
+ @param domain: a domain name to get a fingerprint from
+ @type domain: str
+ @param filepath: path to a file containing a PEM file
+ @type filepath: str
+ @param hash_type: the hash function to be used in the fingerprint.
+ must be one of SHA1, SHA224, SHA256, SHA384, SHA512
+ @type hash_type: str
+ @rparam: hex_fpr, a hexadecimal representation of a bytestring
+ containing the fingerprint.
+ @rtype: string
+ """
+ if domain:
+ cert = get_https_cert_from_domain(domain, port=port)
+ if filepath:
+ cert = get_cert_from_file(filepath)
+ hex_fpr = cert.digest(hash_type)
+ return hex_fpr
+
+
+def get_time_boundaries(certfile):
+ cert = get_cert_from_file(certfile)
+ null_check(cert, 'certificate')
+
+ fromts, tots = (cert.get_notBefore(), cert.get_notAfter())
+ from_, to_ = map(
+ lambda ts: time.gmtime(time.mktime(parse(ts).timetuple())),
+ (fromts, tots))
+ return from_, to_
diff --git a/src/leap/crypto/certs_gnutls.py b/src/leap/crypto/certs_gnutls.py
new file mode 100644
index 00000000..20c0e043
--- /dev/null
+++ b/src/leap/crypto/certs_gnutls.py
@@ -0,0 +1,112 @@
+'''
+We're using PyOpenSSL now
+
+import ctypes
+from StringIO import StringIO
+import socket
+
+import gnutls.connection
+import gnutls.crypto
+import gnutls.library
+
+from leap.util.misc import null_check
+
+
+class BadCertError(Exception):
+ """raised for malformed certs"""
+
+
+def get_https_cert_from_domain(domain):
+ """
+ @param domain: a domain name to get a certificate from.
+ """
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ cred = gnutls.connection.X509Credentials()
+
+ session = gnutls.connection.ClientSession(sock, cred)
+ session.connect((domain, 443))
+ session.handshake()
+ cert = session.peer_certificate
+ return cert
+
+
+def get_cert_from_file(_file):
+ getcert = lambda f: gnutls.crypto.X509Certificate(f.read())
+ if isinstance(_file, str):
+ with open(_file) as f:
+ cert = getcert(f)
+ else:
+ cert = getcert(_file)
+ return cert
+
+
+def get_pkey_from_file(_file):
+ getkey = lambda f: gnutls.crypto.X509PrivateKey(f.read())
+ if isinstance(_file, str):
+ with open(_file) as f:
+ key = getkey(f)
+ else:
+ key = getkey(_file)
+ return key
+
+
+def can_load_cert_and_pkey(string):
+ try:
+ f = StringIO(string)
+ cert = get_cert_from_file(f)
+
+ f = StringIO(string)
+ key = get_pkey_from_file(f)
+
+ null_check(cert, 'certificate')
+ null_check(key, 'private key')
+ except:
+ # XXX catch GNUTLSError?
+ raise BadCertError
+ else:
+ return True
+
+def get_cert_fingerprint(domain=None, filepath=None,
+ hash_type="SHA256", sep=":"):
+ """
+ @param domain: a domain name to get a fingerprint from
+ @type domain: str
+ @param filepath: path to a file containing a PEM file
+ @type filepath: str
+ @param hash_type: the hash function to be used in the fingerprint.
+ must be one of SHA1, SHA224, SHA256, SHA384, SHA512
+ @type hash_type: str
+ @rparam: hex_fpr, a hexadecimal representation of a bytestring
+ containing the fingerprint.
+ @rtype: string
+ """
+ if domain:
+ cert = get_https_cert_from_domain(domain)
+ if filepath:
+ cert = get_cert_from_file(filepath)
+
+ _buffer = ctypes.create_string_buffer(64)
+ buffer_length = ctypes.c_size_t(64)
+
+ SUPPORTED_DIGEST_FUN = ("SHA1", "SHA224", "SHA256", "SHA384", "SHA512")
+ if hash_type in SUPPORTED_DIGEST_FUN:
+ digestfunction = getattr(
+ gnutls.library.constants,
+ "GNUTLS_DIG_%s" % hash_type)
+ else:
+ # XXX improperlyconfigured or something
+ raise Exception("digest function not supported")
+
+ gnutls.library.functions.gnutls_x509_crt_get_fingerprint(
+ cert._c_object, digestfunction,
+ ctypes.byref(_buffer), ctypes.byref(buffer_length))
+
+ # deinit
+ #server_cert._X509Certificate__deinit(server_cert._c_object)
+ # needed? is segfaulting
+
+ fpr = ctypes.string_at(_buffer, buffer_length.value)
+ hex_fpr = sep.join(u"%02X" % ord(char) for char in fpr)
+
+ return hex_fpr
+'''
diff --git a/src/leap/crypto/leapkeyring.py b/src/leap/crypto/leapkeyring.py
index bceadc75..c241d0bc 100644
--- a/src/leap/crypto/leapkeyring.py
+++ b/src/leap/crypto/leapkeyring.py
@@ -53,12 +53,14 @@ class LeapCryptedFileKeyring(keyring.backend.CryptedFileKeyring):
def leap_set_password(key, value, seed="xxx"):
+ key, value = map(unicode, (key, value))
keyring.set_keyring(LeapCryptedFileKeyring(seed=seed))
keyring.set_password('leap', key, value)
def leap_get_password(key, seed="xxx"):
keyring.set_keyring(LeapCryptedFileKeyring(seed=seed))
+ #import ipdb;ipdb.set_trace()
return keyring.get_password('leap', key)
diff --git a/src/leap/crypto/tests/__init__.py b/src/leap/crypto/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/crypto/tests/__init__.py
diff --git a/src/leap/crypto/tests/test_certs.py b/src/leap/crypto/tests/test_certs.py
new file mode 100644
index 00000000..e476b630
--- /dev/null
+++ b/src/leap/crypto/tests/test_certs.py
@@ -0,0 +1,22 @@
+import unittest
+
+from leap.testing.https_server import where
+from leap.crypto import certs
+
+
+class CertTestCase(unittest.TestCase):
+
+ def test_can_load_client_and_pkey(self):
+ with open(where('leaptestscert.pem')) as cf:
+ cs = cf.read()
+ with open(where('leaptestskey.pem')) as kf:
+ ks = kf.read()
+ certs.can_load_cert_and_pkey(cs + ks)
+
+ with self.assertRaises(certs.BadCertError):
+ # screw header
+ certs.can_load_cert_and_pkey(cs.replace("BEGIN", "BEGINN") + ks)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py
index f739c3e8..9a34a428 100644
--- a/src/leap/eip/checks.py
+++ b/src/leap/eip/checks.py
@@ -1,23 +1,24 @@
import logging
-import ssl
-#import platform
import time
import os
+import sys
-from gnutls import crypto
-#import netifaces
-#import ping
import requests
from leap import __branding as BRANDING
-from leap import certs
+from leap import certs as leapcerts
+from leap.base.auth import srpauth_protected, magick_srpauth
+from leap.base import config as baseconfig
from leap.base import constants as baseconstants
from leap.base import providers
+from leap.crypto import certs
from leap.eip import config as eipconfig
from leap.eip import constants as eipconstants
from leap.eip import exceptions as eipexceptions
from leap.eip import specs as eipspecs
+from leap.util.certs import get_mac_cabundle
from leap.util.fileutil import mkdir_p
+from leap.util.web import get_https_domain_and_port
logger = logging.getLogger(name=__name__)
@@ -42,10 +43,11 @@ reachable and testable as a whole.
"""
-def get_ca_cert():
+def get_branding_ca_cert(domain):
+ # deprecated
ca_file = BRANDING.get('provider_ca_file')
if ca_file:
- return certs.where(ca_file)
+ return leapcerts.where(ca_file)
class ProviderCertChecker(object):
@@ -54,18 +56,29 @@ class ProviderCertChecker(object):
client certs and checking tls connection
with provider.
"""
- def __init__(self, fetcher=requests):
+ def __init__(self, fetcher=requests,
+ domain=None):
+
self.fetcher = fetcher
- self.cacert = get_ca_cert()
+ self.domain = domain
+ #XXX needs some kind of autoinit
+ #right now we set by hand
+ #by loading and reading provider config
+ self.apidomain = None
+ self.cacert = eipspecs.provider_ca_path(domain)
+
+ def run_all(
+ self, checker=None,
+ skip_download=False, skip_verify=False):
- def run_all(self, checker=None, skip_download=False, skip_verify=False):
if not checker:
checker = self
do_verify = not skip_verify
logger.debug('do_verify: %s', do_verify)
- # For MVS+
# checker.download_ca_cert()
+
+ # For MVS+
# checker.download_ca_signature()
# checker.get_ca_signatures()
# checker.is_there_trust_path()
@@ -73,13 +86,44 @@ class ProviderCertChecker(object):
# For MVS
checker.is_there_provider_ca()
- # XXX FAKE IT!!!
- checker.is_https_working(verify=do_verify)
+ checker.is_https_working(verify=do_verify, autocacert=False)
checker.check_new_cert_needed(verify=do_verify)
- def download_ca_cert(self):
- # MVS+
- raise NotImplementedError
+ def download_ca_cert(self, uri=None, verify=True):
+ req = self.fetcher.get(uri, verify=verify)
+ req.raise_for_status()
+
+ # should check domain exists
+ capath = self._get_ca_cert_path(self.domain)
+ with open(capath, 'w') as f:
+ f.write(req.content)
+
+ def check_ca_cert_fingerprint(
+ self, hash_type="SHA256",
+ fingerprint=None):
+ """
+ compares the fingerprint in
+ the ca cert with a string
+ we are passed
+ returns True if they are equal, False if not.
+ @param hash_type: digest function
+ @type hash_type: str
+ @param fingerprint: the fingerprint to compare with.
+ @type fingerprint: str (with : separator)
+ @rtype bool
+ """
+ ca_cert_path = self.ca_cert_path
+ ca_cert_fpr = certs.get_cert_fingerprint(
+ filepath=ca_cert_path)
+ return ca_cert_fpr == fingerprint
+
+ def verify_api_https(self, uri):
+ assert uri.startswith('https://')
+ cacert = self.ca_cert_path
+ verify = cacert and cacert or True
+ req = self.fetcher.get(uri, verify=verify)
+ req.raise_for_status()
+ return True
def download_ca_signature(self):
# MVS+
@@ -94,80 +138,110 @@ class ProviderCertChecker(object):
raise NotImplementedError
def is_there_provider_ca(self):
- from leap import certs
- logger.debug('do we have provider_ca?')
- cacert_path = BRANDING.get('provider_ca_file', None)
- if not cacert_path:
- logger.debug('False')
+ if not self.cacert:
return False
- self.cacert = certs.where(cacert_path)
- logger.debug('True')
- return True
+ cacert_exists = os.path.isfile(self.cacert)
+ if cacert_exists:
+ logger.debug('True')
+ return True
+ logger.debug('False!')
+ return False
- def is_https_working(self, uri=None, verify=True):
+ def is_https_working(
+ self, uri=None, verify=True,
+ autocacert=False):
if uri is None:
uri = self._get_root_uri()
# XXX raise InsecureURI or something better
- assert uri.startswith('https')
- if verify is True and self.cacert is not None:
+ try:
+ assert uri.startswith('https')
+ except AssertionError:
+ raise AssertionError(
+ "uri passed should start with https")
+ if autocacert and verify is True and self.cacert is not None:
logger.debug('verify cert: %s', self.cacert)
verify = self.cacert
- logger.debug('is https working?')
+ if sys.platform == "darwin":
+ verify = get_mac_cabundle()
+ logger.debug('checking https connection')
logger.debug('uri: %s (verify:%s)', uri, verify)
+
try:
self.fetcher.get(uri, verify=verify)
+
except requests.exceptions.SSLError as exc:
- logger.warning('False! CERT VERIFICATION FAILED! '
- '(this should be CRITICAL)')
- logger.warning('SSLError: %s', exc.message)
- # XXX RAISE! See #638
- #raise eipexceptions.EIPBadCertError
- # XXX get requests.exceptions.ConnectionError Errno 110
- # Connection timed out, and raise ours.
+ raise eipexceptions.HttpsBadCertError
+
+ except requests.exceptions.ConnectionError:
+ logger.error('ConnectionError')
+ raise eipexceptions.HttpsNotSupported
+
else:
- logger.debug('True')
return True
def check_new_cert_needed(self, skip_download=False, verify=True):
- logger.debug('is new cert needed?')
+ # XXX add autocacert
if not self.is_cert_valid(do_raise=False):
- logger.debug('True')
+ logger.debug('cert needed: true')
self.download_new_client_cert(
skip_download=skip_download,
verify=verify)
return True
- logger.debug('False')
+ logger.debug('cert needed: false')
return False
def download_new_client_cert(self, uri=None, verify=True,
- skip_download=False):
+ skip_download=False,
+ credentials=None):
logger.debug('download new client cert')
if skip_download:
return True
if uri is None:
uri = self._get_client_cert_uri()
# XXX raise InsecureURI or something better
- assert uri.startswith('https')
+ #assert uri.startswith('https')
+
if verify is True and self.cacert is not None:
verify = self.cacert
+ logger.debug('verify = %s', verify)
+
+ fgetfn = self.fetcher.get
+
+ if credentials:
+ user, passwd = credentials
+ logger.debug('apidomain = %s', self.apidomain)
+
+ @srpauth_protected(user, passwd,
+ server="https://%s" % self.apidomain,
+ verify=verify)
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
+
+ else:
+ # XXX FIXME fix decorated args
+ @magick_srpauth(verify)
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
try:
- # XXX FIXME!!!!
- # verify=verify
- # Workaround for #638. return to verification
- # when That's done!!!
-
- # XXX HOOK SRP here...
- # will have to be more generic in the future.
- req = self.fetcher.get(uri, verify=False)
+
+ req = getfn(uri, verify=verify)
req.raise_for_status()
+
except requests.exceptions.SSLError:
logger.warning('SSLError while fetching cert. '
'Look below for stack trace.')
# XXX raise better exception
- raise
+ return self.fail("SSLError")
+ except Exception as exc:
+ return self.fail(exc.message)
+
try:
+ logger.debug('validating cert...')
pemfile_content = req.content
- self.is_valid_pemfile(pemfile_content)
+ valid = self.is_valid_pemfile(pemfile_content)
+ if not valid:
+ logger.warning('invalid cert')
+ return False
cert_path = self._get_client_cert_path()
self.write_cert(pemfile_content, to=cert_path)
except:
@@ -196,11 +270,8 @@ class ProviderCertChecker(object):
def is_cert_not_expired(self, certfile=None, now=time.gmtime):
if certfile is None:
certfile = self._get_client_cert_path()
- with open(certfile) as cf:
- cert_s = cf.read()
- cert = crypto.X509Certificate(cert_s)
- from_ = time.gmtime(cert.activation_time)
- to_ = time.gmtime(cert.expiration_time)
+ from_, to_ = certs.get_time_boundaries(certfile)
+
return from_ < now() < to_
def is_valid_pemfile(self, cert_s=None):
@@ -216,33 +287,39 @@ class ProviderCertChecker(object):
with open(certfile) as cf:
cert_s = cf.read()
try:
- # XXX get a real cert validation
- # so far this is only checking begin/end
- # delimiters :)
- # XXX use gnutls for get proper
- # validation.
- # crypto.X509Certificate(cert_s)
- sep = "-" * 5 + "BEGIN CERTIFICATE" + "-" * 5
- # we might have private key and cert in the same file
- certparts = cert_s.split(sep)
- if len(certparts) > 1:
- cert_s = sep + certparts[1]
- ssl.PEM_cert_to_DER_cert(cert_s)
- except:
- # XXX raise proper exception
- raise
- return True
+ valid = certs.can_load_cert_and_pkey(cert_s)
+ except certs.BadCertError:
+ logger.warning("Not valid pemfile")
+ valid = False
+ return valid
+
+ @property
+ def ca_cert_path(self):
+ return self._get_ca_cert_path(self.domain)
def _get_root_uri(self):
- return u"https://%s/" % baseconstants.DEFAULT_PROVIDER
+ return u"https://%s/" % self.domain
def _get_client_cert_uri(self):
- # XXX get the whole thing from constants
- return "https://%s/1/cert" % (baseconstants.DEFAULT_PROVIDER)
+ return "https://%s/1/cert" % self.apidomain
def _get_client_cert_path(self):
- # MVS+ : get provider path
- return eipspecs.client_cert_path()
+ return eipspecs.client_cert_path(domain=self.domain)
+
+ def _get_ca_cert_path(self, domain):
+ # XXX this folder path will be broken for win
+ # and this should be moved to eipspecs.ca_path
+
+ # XXX use baseconfig.get_provider_path(folder=Foo)
+ # !!!
+
+ capath = baseconfig.get_config_file(
+ 'cacert.pem',
+ folder='providers/%s/keys/ca' % domain)
+ folder, fname = os.path.split(capath)
+ if not os.path.isdir(folder):
+ mkdir_p(folder)
+ return capath
def write_cert(self, pemfile_content, to=None):
folder, filename = os.path.split(to)
@@ -251,6 +328,9 @@ class ProviderCertChecker(object):
with open(to, 'w') as cert_f:
cert_f.write(pemfile_content)
+ def set_api_domain(self, domain):
+ self.apidomain = domain
+
class EIPConfigChecker(object):
"""
@@ -260,16 +340,25 @@ class EIPConfigChecker(object):
use run_all to run all checks.
"""
- def __init__(self, fetcher=requests):
+ def __init__(self, fetcher=requests, domain=None):
# we do not want to accept too many
# argument on init.
# we want tests
# to be explicitely run.
+
self.fetcher = fetcher
- self.eipconfig = eipconfig.EIPConfig()
- self.defaultprovider = providers.LeapProviderDefinition()
- self.eipserviceconfig = eipconfig.EIPServiceConfig()
+ # if not domain, get from config
+ self.domain = domain
+ self.apidomain = None
+ self.cacert = eipspecs.provider_ca_path(domain)
+
+ self.defaultprovider = providers.LeapProviderDefinition(domain=domain)
+ self.defaultprovider.load()
+ self.eipconfig = eipconfig.EIPConfig(domain=domain)
+ self.set_api_domain()
+ self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain)
+ self.eipserviceconfig.load()
def run_all(self, checker=None, skip_download=False):
"""
@@ -330,7 +419,9 @@ class EIPConfigChecker(object):
return True
def fetch_definition(self, skip_download=False,
- config=None, uri=None):
+ force_download=False,
+ config=None, uri=None,
+ domain=None):
"""
fetches a definition file from server
"""
@@ -347,27 +438,45 @@ class EIPConfigChecker(object):
if config is None:
config = self.defaultprovider.config
if uri is None:
- domain = config.get('provider', None)
+ if not domain:
+ domain = config.get('provider', None)
uri = self._get_provider_definition_uri(domain=domain)
- # FIXME! Pass ca path verify!!!
+ if sys.platform == "darwin":
+ verify = get_mac_cabundle()
+ else:
+ verify = True
+
self.defaultprovider.load(
from_uri=uri,
fetcher=self.fetcher,
- verify=False)
+ verify=verify)
self.defaultprovider.save()
def fetch_eip_service_config(self, skip_download=False,
- config=None, uri=None):
+ force_download=False,
+ config=None, uri=None, # domain=None,
+ autocacert=True, verify=True):
if skip_download:
return True
if config is None:
+ self.eipserviceconfig.load()
config = self.eipserviceconfig.config
if uri is None:
- domain = config.get('provider', None)
- uri = self._get_eip_service_uri(domain=domain)
+ #XXX
+ #if not domain:
+ #domain = self.domain or config.get('provider', None)
+ uri = self._get_eip_service_uri(
+ domain=self.apidomain)
+
+ if autocacert and self.cacert is not None:
+ verify = self.cacert
- self.eipserviceconfig.load(from_uri=uri, fetcher=self.fetcher)
+ self.eipserviceconfig.load(
+ from_uri=uri,
+ fetcher=self.fetcher,
+ force_download=force_download,
+ verify=verify)
self.eipserviceconfig.save()
def check_complete_eip_config(self, config=None):
@@ -375,7 +484,6 @@ class EIPConfigChecker(object):
if config is None:
config = self.eipconfig.config
try:
- 'trying assertions'
assert 'provider' in config
assert config['provider'] is not None
# XXX assert there is gateway !!
@@ -395,11 +503,11 @@ class EIPConfigChecker(object):
return self.eipconfig.exists()
def _dump_default_eipconfig(self):
- self.eipconfig.save()
+ self.eipconfig.save(force=True)
def _get_provider_definition_uri(self, domain=None, path=None):
if domain is None:
- domain = baseconstants.DEFAULT_PROVIDER
+ domain = self.domain or baseconstants.DEFAULT_PROVIDER
if path is None:
path = baseconstants.DEFINITION_EXPECTED_PATH
uri = u"https://%s/%s" % (domain, path)
@@ -408,9 +516,22 @@ class EIPConfigChecker(object):
def _get_eip_service_uri(self, domain=None, path=None):
if domain is None:
- domain = baseconstants.DEFAULT_PROVIDER
+ domain = self.domain or baseconstants.DEFAULT_PROVIDER
if path is None:
path = eipconstants.EIP_SERVICE_EXPECTED_PATH
uri = "https://%s/%s" % (domain, path)
logger.debug('getting eip service file from %s', uri)
return uri
+
+ def set_api_domain(self):
+ """sets api domain from defaultprovider config object"""
+ api = self.defaultprovider.config.get('api_uri', None)
+ # the caller is responsible for having loaded the config
+ # object at this point
+ if api:
+ api_dom = get_https_domain_and_port(api)
+ self.apidomain = "%s:%s" % api_dom
+
+ def get_api_domain(self):
+ """gets api domain"""
+ return self.apidomain
diff --git a/src/leap/eip/config.py b/src/leap/eip/config.py
index ef0f52b4..917871da 100644
--- a/src/leap/eip/config.py
+++ b/src/leap/eip/config.py
@@ -1,10 +1,12 @@
import logging
import os
import platform
+import re
import tempfile
from leap import __branding as BRANDING
from leap import certs
+from leap.util.misc import null_check
from leap.util.fileutil import (which, mkdir_p, check_and_fix_urw_only)
from leap.base import config as baseconfig
@@ -16,6 +18,8 @@ from leap.eip import specs as eipspecs
logger = logging.getLogger(name=__name__)
provider_ca_file = BRANDING.get('provider_ca_file', None)
+_platform = platform.system()
+
class EIPConfig(baseconfig.JSONLeapConfig):
spec = eipspecs.eipconfig_spec
@@ -35,9 +39,13 @@ class EIPServiceConfig(baseconfig.JSONLeapConfig):
spec = eipspecs.eipservice_config_spec
def _get_slug(self):
+ domain = getattr(self, 'domain', None)
+ if domain:
+ path = baseconfig.get_provider_path(domain)
+ else:
+ path = baseconfig.get_default_provider_path()
return baseconfig.get_config_file(
- 'eip-service.json',
- folder=baseconfig.get_default_provider_path())
+ 'eip-service.json', folder=path)
def _set_slug(self):
raise AttributeError("you cannot set slug")
@@ -49,45 +57,96 @@ def get_socket_path():
socket_path = os.path.join(
tempfile.mkdtemp(prefix="leap-tmp"),
'openvpn.socket')
- logger.debug('socket path: %s', socket_path)
+ #logger.debug('socket path: %s', socket_path)
return socket_path
-def get_eip_gateway():
+def get_eip_gateway(eipconfig=None, eipserviceconfig=None):
"""
return the first host in eip service config
that matches the name defined in the eip.json config
file.
"""
- placeholder = "testprovider.example.org"
- eipconfig = EIPConfig()
- #import ipdb;ipdb.set_trace()
- eipconfig.load()
+ # XXX eventually we should move to a more clever
+ # gateway selection. maybe we could return
+ # all gateways that match our cluster.
+
+ null_check(eipconfig, "eipconfig")
+ null_check(eipserviceconfig, "eipserviceconfig")
+ PLACEHOLDER = "testprovider.example.org"
+
conf = eipconfig.config
+ eipsconf = eipserviceconfig.config
primary_gateway = conf.get('primary_gateway', None)
if not primary_gateway:
- return placeholder
+ return PLACEHOLDER
- eipserviceconfig = EIPServiceConfig()
- eipserviceconfig.load()
- eipsconf = eipserviceconfig.get_config()
gateways = eipsconf.get('gateways', None)
if not gateways:
logger.error('missing gateways in eip service config')
- return placeholder
+ return PLACEHOLDER
+
if len(gateways) > 0:
for gw in gateways:
- if gw['name'] == primary_gateway:
- hosts = gw['hosts']
- if len(hosts) > 0:
- return hosts[0]
- else:
- logger.error('no hosts')
+ clustername = gw.get('cluster', None)
+ if not clustername:
+ logger.error('no cluster name')
+ return
+
+ if clustername == primary_gateway:
+ # XXX at some moment, we must
+ # make this a more generic function,
+ # and return ports, protocols...
+ ipaddress = gw.get('ip_address', None)
+ if not ipaddress:
+ logger.error('no ip_address')
+ return
+ return ipaddress
logger.error('could not find primary gateway in provider'
'gateway list')
+def get_cipher_options(eipserviceconfig=None):
+ """
+ gathers optional cipher options from eip-service config.
+ :param eipserviceconfig: EIPServiceConfig instance
+ """
+ null_check(eipserviceconfig, 'eipserviceconfig')
+ eipsconf = eipserviceconfig.get_config()
+
+ ALLOWED_KEYS = ("auth", "cipher", "tls-cipher")
+ CIPHERS_REGEX = re.compile("[A-Z0-9\-]+")
+ opts = []
+ if 'openvpn_configuration' in eipsconf:
+ config = eipserviceconfig.config.get(
+ "openvpn_configuration", {})
+ for key, value in config.items():
+ if key in ALLOWED_KEYS and value is not None:
+ sanitized_val = CIPHERS_REGEX.findall(value)
+ if len(sanitized_val) != 0:
+ _val = sanitized_val[0]
+ opts.append('--%s' % key)
+ opts.append('%s' % _val)
+ return opts
+
+LINUX_UP_DOWN_SCRIPT = "/etc/leap/resolv-update"
+OPENVPN_DOWN_ROOT = "/usr/lib/openvpn/openvpn-down-root.so"
+
+
+def has_updown_scripts():
+ """
+ checks the existence of the up/down scripts
+ """
+ # XXX should check permissions too
+ is_file = os.path.isfile(LINUX_UP_DOWN_SCRIPT)
+ if not is_file:
+ logger.warning(
+ "Could not find up/down scripts at %s! "
+ "Risk of DNS Leaks!!!")
+ return is_file
+
+
def build_ovpn_options(daemon=False, socket_path=None, **kwargs):
"""
build a list of options
@@ -103,6 +162,12 @@ def build_ovpn_options(daemon=False, socket_path=None, **kwargs):
# since we will need to take some
# things from there if present.
+ provider = kwargs.pop('provider', None)
+ eipconfig = EIPConfig(domain=provider)
+ eipconfig.load()
+ eipserviceconfig = EIPServiceConfig(domain=provider)
+ eipserviceconfig.load()
+
# get user/group name
# also from config.
user = baseconfig.get_username()
@@ -123,18 +188,32 @@ def build_ovpn_options(daemon=False, socket_path=None, **kwargs):
opts.append('--verb')
opts.append("%s" % verbosity)
- # remote
+ # remote ##############################
+ # (server, port, protocol)
+
opts.append('--remote')
- gw = get_eip_gateway()
+
+ gw = get_eip_gateway(eipconfig=eipconfig,
+ eipserviceconfig=eipserviceconfig)
logger.debug('setting eip gateway to %s', gw)
opts.append(str(gw))
+
+ # get port/protocol from eipservice too
opts.append('1194')
+ #opts.append('80')
opts.append('udp')
opts.append('--tls-client')
opts.append('--remote-cert-tls')
opts.append('server')
+ # get ciphers #######################
+
+ ciphers = get_cipher_options(
+ eipserviceconfig=eipserviceconfig)
+ for cipheropt in ciphers:
+ opts.append(str(cipheropt))
+
# set user and group
opts.append('--user')
opts.append('%s' % user)
@@ -149,8 +228,13 @@ def build_ovpn_options(daemon=False, socket_path=None, **kwargs):
# interface. unix sockets or telnet interface for win.
# XXX take them from the config object.
- ourplatform = platform.system()
- if ourplatform in ("Linux", "Mac"):
+ if _platform == "Windows":
+ opts.append('--management')
+ opts.append('localhost')
+ # XXX which is a good choice?
+ opts.append('7777')
+
+ if _platform in ("Linux", "Darwin"):
opts.append('--management')
if socket_path is None:
@@ -158,19 +242,30 @@ def build_ovpn_options(daemon=False, socket_path=None, **kwargs):
opts.append(socket_path)
opts.append('unix')
- if ourplatform == "Windows":
- opts.append('--management')
- opts.append('localhost')
- # XXX which is a good choice?
- opts.append('7777')
+ opts.append('--script-security')
+ opts.append('2')
+
+ if _platform == "Linux":
+ if has_updown_scripts():
+ opts.append("--up")
+ opts.append(LINUX_UP_DOWN_SCRIPT)
+ opts.append("--down")
+ opts.append(LINUX_UP_DOWN_SCRIPT)
+ opts.append("--plugin")
+ opts.append(OPENVPN_DOWN_ROOT)
+ opts.append("'script_type=down %s'" % LINUX_UP_DOWN_SCRIPT)
# certs
+ client_cert_path = eipspecs.client_cert_path(provider)
+ ca_cert_path = eipspecs.provider_ca_path(provider)
+
+ # XXX FIX paths for MAC
opts.append('--cert')
- opts.append(eipspecs.client_cert_path())
+ opts.append(client_cert_path)
opts.append('--key')
- opts.append(eipspecs.client_cert_path())
+ opts.append(client_cert_path)
opts.append('--ca')
- opts.append(eipspecs.provider_ca_path())
+ opts.append(ca_cert_path)
# we cannot run in daemon mode
# with the current subp setting.
@@ -178,7 +273,7 @@ def build_ovpn_options(daemon=False, socket_path=None, **kwargs):
#if daemon is True:
#opts.append('--daemon')
- logger.debug('vpn options: %s', opts)
+ logger.debug('vpn options: %s', ' '.join(opts))
return opts
@@ -198,7 +293,7 @@ def build_ovpn_command(debug=False, do_pkexec_check=True, vpnbin=None,
# XXX get use_pkexec from config instead.
- if platform.system() == "Linux" and use_pkexec and do_pkexec_check:
+ if _platform == "Linux" and use_pkexec and do_pkexec_check:
# check for both pkexec
# AND a suitable authentication
@@ -218,8 +313,16 @@ def build_ovpn_command(debug=False, do_pkexec_check=True, vpnbin=None,
raise eip_exceptions.EIPNoPolkitAuthAgentAvailable
command.append('pkexec')
+
if vpnbin is None:
- ovpn = which('openvpn')
+ if _platform == "Darwin":
+ # XXX Should hardcode our installed path
+ # /Applications/LEAPClient.app/Contents/Resources/openvpn.leap
+ openvpn_bin = "openvpn.leap"
+ else:
+ openvpn_bin = "openvpn"
+ #XXX hardcode for darwin
+ ovpn = which(openvpn_bin)
else:
ovpn = vpnbin
if ovpn:
@@ -235,10 +338,20 @@ def build_ovpn_command(debug=False, do_pkexec_check=True, vpnbin=None,
# XXX check len and raise proper error
- return [command[0], command[1:]]
+ if _platform == "Darwin":
+ OSX_ASADMIN = 'do shell script "%s" with administrator privileges'
+ # XXX fix workaround for Nones
+ _command = [x if x else " " for x in command]
+ # XXX debugging!
+ # XXX get openvpn log path from debug flags
+ _command.append('--log')
+ _command.append('/tmp/leap_openvpn.log')
+ return ["osascript", ["-e", OSX_ASADMIN % ' '.join(_command)]]
+ else:
+ return [command[0], command[1:]]
-def check_vpn_keys():
+def check_vpn_keys(provider=None):
"""
performs an existance and permission check
over the openvpn keys file.
@@ -246,8 +359,9 @@ def check_vpn_keys():
per provider, containing the CA cert,
the provider key, and our client certificate
"""
- provider_ca = eipspecs.provider_ca_path()
- client_cert = eipspecs.client_cert_path()
+ assert provider is not None
+ provider_ca = eipspecs.provider_ca_path(provider)
+ client_cert = eipspecs.client_cert_path(provider)
logger.debug('provider ca = %s', provider_ca)
logger.debug('client cert = %s', client_cert)
diff --git a/src/leap/eip/eipconnection.py b/src/leap/eip/eipconnection.py
index f0e7861e..d012c567 100644
--- a/src/leap/eip/eipconnection.py
+++ b/src/leap/eip/eipconnection.py
@@ -5,6 +5,9 @@ from __future__ import (absolute_import,)
import logging
import Queue
import sys
+import time
+
+from dateutil.parser import parse as dateparse
from leap.eip.checks import ProviderCertChecker
from leap.eip.checks import EIPConfigChecker
@@ -15,20 +18,149 @@ from leap.eip.openvpnconnection import OpenVPNConnection
logger = logging.getLogger(name=__name__)
-class EIPConnection(OpenVPNConnection):
+class StatusMixIn(object):
+
+ # a bunch of methods related with querying the connection
+ # state/status and displaying useful info.
+ # Needs to get clear on what is what, and
+ # separate functions.
+ # Should separate EIPConnectionStatus (self.status)
+ # from the OpenVPN state/status command and parsing.
+
+ ERR_CONNREFUSED = False
+
+ def connection_state(self):
+ """
+ returns the current connection state
+ """
+ return self.status.current
+
+ def get_icon_name(self):
+ """
+ get icon name from status object
+ """
+ return self.status.get_state_icon()
+
+ def get_leap_status(self):
+ return self.status.get_leap_status()
+
+ def poll_connection_state(self):
+ """
+ """
+ try:
+ state = self.get_connection_state()
+ except eip_exceptions.ConnectionRefusedError:
+ # connection refused. might be not ready yet.
+ if not self.ERR_CONNREFUSED:
+ logger.warning('connection refused')
+ self.ERR_CONNREFUSED = True
+ return
+ if not state:
+ #logger.debug('no state')
+ return
+ (ts, status_step,
+ ok, ip, remote) = state
+ self.status.set_vpn_state(status_step)
+ status_step = self.status.get_readable_status()
+ return (ts, status_step, ok, ip, remote)
+
+ def make_error(self):
+ """
+ capture error and wrap it in an
+ understandable format
+ """
+ # mostly a hack to display errors in the debug UI
+ # w/o breaking the polling.
+ #XXX get helpful error codes
+ self.with_errors = True
+ now = int(time.time())
+ return '%s,LAUNCHER ERROR,ERROR,-,-' % now
+
+ def state(self):
+ """
+ Sends OpenVPN command: state
+ """
+ state = self._send_command("state")
+ if not state:
+ return None
+ if isinstance(state, str):
+ return state
+ if isinstance(state, list):
+ if len(state) == 1:
+ return state[0]
+ else:
+ return state[-1]
+
+ def vpn_status(self):
+ """
+ OpenVPN command: status
+ """
+ status = self._send_command("status")
+ return status
+
+ def vpn_status2(self):
+ """
+ OpenVPN command: last 2 statuses
+ """
+ return self._send_command("status 2")
+
+ #
+ # parse info as the UI expects
+ #
+
+ def get_status_io(self):
+ status = self.vpn_status()
+ if isinstance(status, str):
+ lines = status.split('\n')
+ if isinstance(status, list):
+ lines = status
+ try:
+ (header, when, tun_read, tun_write,
+ tcp_read, tcp_write, auth_read) = tuple(lines)
+ except ValueError:
+ return None
+
+ when_ts = dateparse(when.split(',')[1]).timetuple()
+ sep = ','
+ # XXX clean up this!
+ tun_read = tun_read.split(sep)[1]
+ tun_write = tun_write.split(sep)[1]
+ tcp_read = tcp_read.split(sep)[1]
+ tcp_write = tcp_write.split(sep)[1]
+ auth_read = auth_read.split(sep)[1]
+
+ # XXX this could be a named tuple. prettier.
+ return when_ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read)
+
+ def get_connection_state(self):
+ state = self.state()
+ if state is not None:
+ ts, status_step, ok, ip, remote = state.split(',')
+ ts = time.gmtime(float(ts))
+ # XXX this could be a named tuple. prettier.
+ return ts, status_step, ok, ip, remote
+
+
+class EIPConnection(OpenVPNConnection, StatusMixIn):
"""
+ Aka conductor.
Manages the execution of the OpenVPN process, auto starts, monitors the
network connection, handles configuration, fixes leaky hosts, handles
errors, etc.
Status updates (connected, bandwidth, etc) are signaled to the GUI.
"""
+ # XXX change name to EIPConductor ??
+
def __init__(self,
provider_cert_checker=ProviderCertChecker,
config_checker=EIPConfigChecker,
*args, **kwargs):
- self.settingsfile = kwargs.get('settingsfile', None)
- self.logfile = kwargs.get('logfile', None)
+ #self.settingsfile = kwargs.get('settingsfile', None)
+ #self.logfile = kwargs.get('logfile', None)
+ self.provider = kwargs.pop('provider', None)
+ self._providercertchecker = provider_cert_checker
+ self._configchecker = config_checker
self.error_queue = Queue.Queue()
@@ -38,17 +170,51 @@ class EIPConnection(OpenVPNConnection):
checker_signals = kwargs.pop('checker_signals', None)
self.checker_signals = checker_signals
- self.provider_cert_checker = provider_cert_checker()
- self.config_checker = config_checker()
+ self.init_checkers()
host = eipconfig.get_socket_path()
kwargs['host'] = host
super(EIPConnection, self).__init__(*args, **kwargs)
+ def connect(self, **kwargs):
+ """
+ entry point for connection process
+ """
+ # in OpenVPNConnection
+ self.try_openvpn_connection()
+
+ def disconnect(self, shutdown=False):
+ """
+ disconnects client
+ """
+ self.terminate_openvpn_connection(shutdown=shutdown)
+ self.status.change_to(self.status.DISCONNECTED)
+
def has_errors(self):
return True if self.error_queue.qsize() != 0 else False
+ def init_checkers(self):
+ """
+ initialize checkers
+ """
+ self.provider_cert_checker = self._providercertchecker(
+ domain=self.provider)
+ self.config_checker = self._configchecker(domain=self.provider)
+
+ def set_provider_domain(self, domain):
+ """
+ sets the provider domain.
+ used from the first run wizard when we launch the run_checks
+ and connect process after having initialized the conductor.
+ """
+ # This looks convoluted, right.
+ # We have to reinstantiate checkers cause we're passing
+ # the domain param that we did not know at the beginning
+ # (only for the firstrunwizard case)
+ self.provider = domain
+ self.init_checkers()
+
def run_checks(self, skip_download=False, skip_verify=False):
"""
run all eip checks previous to attempting a connection
@@ -80,100 +246,6 @@ class EIPConnection(OpenVPNConnection):
except Exception as exc:
push_err(exc)
- def connect(self):
- """
- entry point for connection process
- """
- #self.forget_errors()
- self._try_connection()
-
- def disconnect(self):
- """
- disconnects client
- """
- self.cleanup()
- logger.debug("disconnect: clicked.")
- self.status.change_to(self.status.DISCONNECTED)
-
- def shutdown(self):
- """
- shutdown and quit
- """
- self.desired_con_state = self.status.DISCONNECTED
-
- def connection_state(self):
- """
- returns the current connection state
- """
- return self.status.current
-
- def poll_connection_state(self):
- """
- """
- # XXX this separation does not
- # make sense anymore after having
- # merged Connection and Manager classes.
- # XXX GET RID OF THIS FUNCTION HERE!
- try:
- state = self.get_connection_state()
- except eip_exceptions.ConnectionRefusedError:
- # connection refused. might be not ready yet.
- logger.warning('connection refused')
- return
- if not state:
- #logger.debug('no state')
- return
- (ts, status_step,
- ok, ip, remote) = state
- self.status.set_vpn_state(status_step)
- status_step = self.status.get_readable_status()
- return (ts, status_step, ok, ip, remote)
-
- def get_icon_name(self):
- """
- get icon name from status object
- """
- return self.status.get_state_icon()
-
- def get_leap_status(self):
- return self.status.get_leap_status()
-
- #
- # private methods
- #
-
- #def _disconnect(self):
- # """
- # private method for disconnecting
- # """
- # if self.subp is not None:
- # logger.debug('disconnecting...')
- # self.subp.terminate()
- # self.subp = None
-
- #def _is_alive(self):
- #"""
- #don't know yet
- #"""
- #pass
-
- def _connect(self):
- """
- entry point for connection cascade methods.
- """
- try:
- conn_result = self._try_connection()
- except eip_exceptions.UnrecoverableError as except_msg:
- logger.error("FATAL: %s" % unicode(except_msg))
- conn_result = self.status.UNRECOVERABLE
-
- # XXX enqueue exceptions themselves instead?
- except Exception as except_msg:
- self.error_queue.append(except_msg)
- logger.error("Failed Connection: %s" %
- unicode(except_msg))
- return conn_result
-
class EIPConnectionStatus(object):
"""
@@ -247,6 +319,7 @@ class EIPConnectionStatus(object):
def get_leap_status(self):
# XXX improve nomenclature
leap_status = {
+ 0: 'disconnected',
1: 'connecting to gateway',
2: 'connecting to gateway',
3: 'authenticating',
diff --git a/src/leap/eip/exceptions.py b/src/leap/eip/exceptions.py
index 11bfd620..b7d398c3 100644
--- a/src/leap/eip/exceptions.py
+++ b/src/leap/eip/exceptions.py
@@ -32,8 +32,11 @@ TODO:
* gettext / i18n for user messages.
"""
+from leap.base.exceptions import LeapException
+from leap.util.translations import translate
+# This should inherit from LeapException
class EIPClientError(Exception):
"""
base EIPClient exception
@@ -60,45 +63,70 @@ class Warning(EIPClientError):
class EIPNoPolkitAuthAgentAvailable(CriticalError):
message = "No polkit authentication agent could be found"
- usermessage = ("We could not find any authentication "
- "agent in your system.<br/>"
- "Make sure you have "
- "<b>polkit-gnome-authentication-agent-1</b> "
- "running and try again.")
+ usermessage = translate(
+ "EIPErrors",
+ "We could not find any authentication "
+ "agent in your system.<br/>"
+ "Make sure you have "
+ "<b>polkit-gnome-authentication-agent-1</b> "
+ "running and try again.")
class EIPNoPkexecAvailable(Warning):
message = "No pkexec binary found"
- usermessage = ("We could not find <b>pkexec</b> in your "
- "system.<br/> Do you want to try "
- "<b>setuid workaround</b>? "
- "(<i>DOES NOTHING YET</i>)")
+ usermessage = translate(
+ "EIPErrors",
+ "We could not find <b>pkexec</b> in your "
+ "system.<br/> Do you want to try "
+ "<b>setuid workaround</b>? "
+ "(<i>DOES NOTHING YET</i>)")
failfirst = True
class EIPNoCommandError(EIPClientError):
message = "no suitable openvpn command found"
- usermessage = ("No suitable openvpn command found. "
- "<br/>(Might be a permissions problem)")
+ usermessage = translate(
+ "EIPErrors",
+ "No suitable openvpn command found. "
+ "<br/>(Might be a permissions problem)")
class EIPBadCertError(Warning):
# XXX this should be critical and fail close
message = "cert verification failed"
- usermessage = "there is a problem with provider certificate"
+ usermessage = translate(
+ "EIPErrors",
+ "there is a problem with provider certificate")
class LeapBadConfigFetchedError(Warning):
message = "provider sent a malformed json file"
- usermessage = "an error occurred during configuratio of leap services"
+ usermessage = translate(
+ "EIPErrors",
+ "an error occurred during configuratio of leap services")
-class OpenVPNAlreadyRunning(EIPClientError):
+class OpenVPNAlreadyRunning(CriticalError):
message = "Another OpenVPN Process is already running."
- usermessage = ("Another OpenVPN Process has been detected."
- "Please close it before starting leap-client")
+ usermessage = translate(
+ "EIPErrors",
+ "Another OpenVPN Process has been detected. "
+ "Please close it before starting leap-client")
+class HttpsNotSupported(LeapException):
+ message = "connection refused while accessing via https"
+ usermessage = translate(
+ "EIPErrors",
+ "Server does not allow secure connections")
+
+
+class HttpsBadCertError(LeapException):
+ message = "verification error on cert"
+ usermessage = translate(
+ "EIPErrors",
+ "Server certificate could not be verified")
+
#
# errors still needing some love
#
@@ -106,7 +134,9 @@ class OpenVPNAlreadyRunning(EIPClientError):
class EIPInitNoKeyFileError(CriticalError):
message = "No vpn keys found in the expected path"
- usermessage = "We could not find your eip certs in the expected path"
+ usermessage = translate(
+ "EIPErrors",
+ "We could not find your eip certs in the expected path")
class EIPInitBadKeyFilePermError(Warning):
diff --git a/src/leap/eip/openvpnconnection.py b/src/leap/eip/openvpnconnection.py
index d93bc40f..455735c8 100644
--- a/src/leap/eip/openvpnconnection.py
+++ b/src/leap/eip/openvpnconnection.py
@@ -2,30 +2,150 @@
OpenVPN Connection
"""
from __future__ import (print_function)
+from functools import partial
import logging
+import os
import psutil
+import shutil
+import select
import socket
-import time
-from functools import partial
+from time import sleep
logger = logging.getLogger(name=__name__)
from leap.base.connection import Connection
+from leap.base.constants import OPENVPN_BIN
from leap.util.coroutines import spawn_and_watch_process
+from leap.util.misc import get_openvpn_pids
from leap.eip.udstelnet import UDSTelnet
from leap.eip import config as eip_config
from leap.eip import exceptions as eip_exceptions
-class OpenVPNConnection(Connection):
+class OpenVPNManagement(object):
+
+ # TODO explain a little bit how management interface works
+ # and our telnet interface with support for unix sockets.
+
+ """
+ for more information, read openvpn management notes.
+ zcat `dpkg -L openvpn | grep management`
+ """
+
+ def _connect_to_management(self):
+ """
+ Connect to openvpn management interface
+ """
+ if hasattr(self, 'tn'):
+ self._close_management_socket()
+ self.tn = UDSTelnet(self.host, self.port)
+
+ # XXX make password optional
+ # specially for win. we should generate
+ # the pass on the fly when invoking manager
+ # from conductor
+
+ #self.tn.read_until('ENTER PASSWORD:', 2)
+ #self.tn.write(self.password + '\n')
+ #self.tn.read_until('SUCCESS:', 2)
+ if self.tn:
+ self._seek_to_eof()
+ return True
+
+ def _close_management_socket(self, announce=True):
+ """
+ Close connection to openvpn management interface
+ """
+ logger.debug('closing socket')
+ if announce:
+ self.tn.write("quit\n")
+ self.tn.read_all()
+ self.tn.get_socket().close()
+ del self.tn
+
+ def _seek_to_eof(self):
+ """
+ Read as much as available. Position seek pointer to end of stream
+ """
+ try:
+ b = self.tn.read_eager()
+ except EOFError:
+ logger.debug("Could not read from socket. Assuming it died.")
+ return
+ while b:
+ try:
+ b = self.tn.read_eager()
+ except EOFError:
+ logger.debug("Could not read from socket. Assuming it died.")
+
+ def _send_command(self, cmd):
+ """
+ Send a command to openvpn and return response as list
+ """
+ if not self.connected():
+ try:
+ self._connect_to_management()
+ except eip_exceptions.MissingSocketError:
+ #logger.warning('missing management socket')
+ return []
+ try:
+ if hasattr(self, 'tn'):
+ self.tn.write(cmd + "\n")
+ except socket.error:
+ logger.error('socket error')
+ self._close_management_socket(announce=False)
+ return []
+ try:
+ buf = self.tn.read_until(b"END", 2)
+ self._seek_to_eof()
+ blist = buf.split('\r\n')
+ if blist[-1].startswith('END'):
+ del blist[-1]
+ return blist
+ else:
+ return []
+ except socket.error as exc:
+ logger.debug('socket error: %s' % exc.message)
+ except select.error as exc:
+ logger.debug('select error: %s' % exc.message)
+
+ def _send_short_command(self, cmd):
+ """
+ parse output from commands that are
+ delimited by "success" instead
+ """
+ if not self.connected():
+ self.connect()
+ self.tn.write(cmd + "\n")
+ # XXX not working?
+ buf = self.tn.read_until(b"SUCCESS", 2)
+ self._seek_to_eof()
+ blist = buf.split('\r\n')
+ return blist
+
+ #
+ # random maybe useful vpn commands
+ #
+
+ def pid(self):
+ #XXX broken
+ return self._send_short_command("pid")
+
+
+class OpenVPNConnection(Connection, OpenVPNManagement):
"""
All related to invocation
- of the openvpn binary
+ of the openvpn binary.
+ It's extended by EIPConnection.
"""
+ # XXX Inheriting from Connection was an early design idea
+ # but currently that's an empty class.
+ # We can get rid of that if we don't use it for sharing
+ # state with other leap modules.
+
def __init__(self,
- #config_file=None,
watcher_cb=None,
debug=False,
host=None,
@@ -33,24 +153,21 @@ class OpenVPNConnection(Connection):
password=None,
*args, **kwargs):
"""
- :param config_file: configuration file to read from
:param watcher_cb: callback to be \
called for each line in watched stdout
:param signal_map: dictionary of signal names and callables \
to be triggered for each one of them.
- :type config_file: str
:type watcher_cb: function
:type signal_map: dict
"""
#XXX FIXME
#change watcher_cb to line_observer
+ # XXX if not host: raise ImproperlyConfigured
logger.debug('init openvpn connection')
self.debug = debug
- # XXX if not host: raise ImproperlyConfigured
self.ovpn_verbosity = kwargs.get('ovpn_verbosity', None)
- #self.config_file = config_file
self.watcher_cb = watcher_cb
#self.signal_maps = signal_maps
@@ -61,21 +178,13 @@ to be triggered for each one of them.
self.port = None
self.proto = None
- #XXX workaround for signaling
- #the ui that we don't know how to
- #manage a connection error
- self.with_errors = False
-
self.command = None
self.args = None
# XXX get autostart from config
self.autostart = True
- #
- # management init methods
- #
-
+ # management interface init
self.host = host
if isinstance(port, str) and port.isdigit():
port = int(port)
@@ -87,15 +196,108 @@ to be triggered for each one of them.
self.password = password
def run_openvpn_checks(self):
+ """
+ runs check needed before launching
+ openvpn subprocess. will raise if errors found.
+ """
logger.debug('running openvpn checks')
+ # XXX I think that "check_if_running" should be called
+ # from try openvpn connection instead. -- kali.
+ # let's prepare tests for that before changing it...
self._check_if_running_instance()
self._set_ovpn_command()
self._check_vpn_keys()
+ def try_openvpn_connection(self):
+ """
+ attempts to connect
+ """
+ # XXX should make public method
+ if self.command is None:
+ raise eip_exceptions.EIPNoCommandError
+ if self.subp is not None:
+ logger.debug('cowardly refusing to launch subprocess again')
+ # XXX this is not returning ???!!
+ # FIXME -- so it's calling it all the same!!
+
+ self._launch_openvpn()
+
+ def connected(self):
+ """
+ Returns True if connected
+ rtype: bool
+ """
+ # XXX make a property
+ return hasattr(self, 'tn')
+
+ def terminate_openvpn_connection(self, shutdown=False):
+ """
+ terminates openvpn child subprocess
+ """
+ if self.subp:
+ try:
+ self._stop_openvpn()
+ except eip_exceptions.ConnectionRefusedError:
+ logger.warning(
+ 'unable to send sigterm signal to openvpn: '
+ 'connection refused.')
+
+ # XXX kali --
+ # XXX review-me
+ # I think this will block if child process
+ # does not return.
+ # Maybe we can .poll() for a given
+ # interval and exit in any case.
+
+ RETCODE = self.subp.wait()
+ if RETCODE:
+ logger.error(
+ 'cannot terminate subprocess! Retcode %s'
+ '(We might have left openvpn running)' % RETCODE)
+
+ if shutdown:
+ self._cleanup_tempfiles()
+
+ def _cleanup_tempfiles(self):
+ """
+ remove all temporal files
+ we might have left behind
+ """
+ # if self.port is 'unix', we have
+ # created a temporal socket path that, under
+ # normal circumstances, we should be able to
+ # delete
+
+ if self.port == "unix":
+ logger.debug('cleaning socket file temp folder')
+
+ tempfolder = os.path.split(self.host)[0]
+ if os.path.isdir(tempfolder):
+ try:
+ shutil.rmtree(tempfolder)
+ except OSError:
+ logger.error('could not delete tmpfolder %s' % tempfolder)
+
+ # checks
+
+ def _check_if_running_instance(self):
+ """
+ check if openvpn is already running
+ """
+ openvpn_pids = get_openvpn_pids()
+ if openvpn_pids:
+ logger.debug('an openvpn instance is already running.')
+ logger.debug('attempting to stop openvpn instance.')
+ if not self._stop_openvpn():
+ raise eip_exceptions.OpenVPNAlreadyRunning
+ return
+ else:
+ logger.debug('no openvpn instance found.')
+
def _set_ovpn_command(self):
- # XXX check also for command-line --command flag
try:
command, args = eip_config.build_ovpn_command(
+ provider=self.provider,
debug=self.debug,
socket_path=self.host,
ovpn_verbosity=self.ovpn_verbosity)
@@ -115,12 +317,14 @@ to be triggered for each one of them.
checks for correct permissions on vpn keys
"""
try:
- eip_config.check_vpn_keys()
+ eip_config.check_vpn_keys(provider=self.provider)
except eip_exceptions.EIPInitBadKeyFilePermError:
logger.error('Bad VPN Keys permission!')
# do nothing now
# and raise the rest ...
+ # starting and stopping openvpn subprocess
+
def _launch_openvpn(self):
"""
invocation of openvpn binaries in a subprocess.
@@ -129,12 +333,13 @@ to be triggered for each one of them.
#deprecate watcher_cb,
#use _only_ signal_maps instead
- logger.debug('_launch_openvpn called')
+ #logger.debug('_launch_openvpn called')
if self.watcher_cb is not None:
linewrite_callback = self.watcher_cb
else:
#XXX get logger instead
- linewrite_callback = lambda line: print('watcher: %s' % line)
+ linewrite_callback = lambda line: logger.debug(
+ 'watcher: %s' % line)
# the partial is not
# being applied now because we're not observing the process
@@ -142,7 +347,8 @@ to be triggered for each one of them.
# here since it will be handy for observing patterns in the
# thru-the-manager updates (with regex)
observers = (linewrite_callback,
- partial(lambda con_status, line: None, self.status))
+ partial(lambda con_status,
+ line: linewrite_callback, self.status))
subp, watcher = spawn_and_watch_process(
self.command,
self.args,
@@ -150,264 +356,55 @@ to be triggered for each one of them.
self.subp = subp
self.watcher = watcher
- def _try_connection(self):
- """
- attempts to connect
- """
- if self.command is None:
- raise eip_exceptions.EIPNoCommandError
- if self.subp is not None:
- logger.debug('cowardly refusing to launch subprocess again')
-
- self._launch_openvpn()
-
- def _check_if_running_instance(self):
- """
- check if openvpn is already running
- """
- for process in psutil.get_process_list():
- if process.name == "openvpn":
- logger.debug('an openvpn instance is already running.')
- logger.debug('attempting to stop openvpn instance.')
- if not self._stop():
- raise eip_exceptions.OpenVPNAlreadyRunning
-
- logger.debug('no openvpn instance found.')
-
- def cleanup(self):
- """
- terminates openvpn child subprocess
- """
- if self.subp:
- self._stop()
- RETCODE = self.subp.wait()
- if RETCODE:
- logger.error('cannot terminate subprocess! '
- '(maybe openvpn still running?)')
-
- def _stop(self):
+ def _stop_openvpn(self):
"""
stop openvpn process
+ by sending SIGTERM to the management
+ interface
"""
- logger.debug("disconnecting...")
- self._send_command("signal SIGTERM\n")
-
- if self.subp:
- return True
+ # XXX method a bit too long, split
+ logger.debug("atempting to terminate openvpn process...")
+ if self.connected():
+ try:
+ self._send_command("signal SIGTERM\n")
+ sleep(1)
+ if not self.subp: # XXX ???
+ return True
+ except socket.error:
+ logger.warning('management socket died')
+ return
#shutting openvpn failured
#try patching in old openvpn host and trying again
+ # XXX could be more than one!
process = self._get_openvpn_process()
if process:
- self.host = \
- process.cmdline[process.cmdline.index("--management") + 1]
- self._send_command("signal SIGTERM\n")
+ logger.debug('process: %s' % process.name)
+ cmdline = process.cmdline
+
+ manag_flag = "--management"
+ if isinstance(cmdline, list) and manag_flag in cmdline:
+ _index = cmdline.index(manag_flag)
+ self.host = cmdline[_index + 1]
+ self._send_command("signal SIGTERM\n")
#make sure the process was terminated
process = self._get_openvpn_process()
if not process:
- logger.debug("Exisiting OpenVPN Process Terminated")
+ logger.debug("Existing OpenVPN Process Terminated")
return True
else:
- logger.error("Unable to terminate exisiting OpenVPN Process.")
+ logger.error("Unable to terminate existing OpenVPN Process.")
return False
return True
def _get_openvpn_process(self):
- for process in psutil.get_process_list():
- if process.name == "openvpn":
+ for process in psutil.process_iter():
+ if OPENVPN_BIN in process.name:
return process
return None
- # management methods
- #
- # XXX REVIEW-ME
- # REFACTOR INFO: (former "manager".
- # Can we move to another
- # base class to test independently?)
- #
-
- #def forget_errors(self):
- #logger.debug('forgetting errors')
- #self.with_errors = False
-
- def connect_to_management(self):
- """Connect to openvpn management interface"""
- #logger.debug('connecting socket')
- if hasattr(self, 'tn'):
- self.close()
- self.tn = UDSTelnet(self.host, self.port)
-
- # XXX make password optional
- # specially for win. we should generate
- # the pass on the fly when invoking manager
- # from conductor
-
- #self.tn.read_until('ENTER PASSWORD:', 2)
- #self.tn.write(self.password + '\n')
- #self.tn.read_until('SUCCESS:', 2)
-
- self._seek_to_eof()
- return True
-
- def _seek_to_eof(self):
- """
- Read as much as available. Position seek pointer to end of stream
- """
- try:
- b = self.tn.read_eager()
- except EOFError:
- logger.debug("Could not read from socket. Assuming it died.")
- return
- while b:
- try:
- b = self.tn.read_eager()
- except EOFError:
- logger.debug("Could not read from socket. Assuming it died.")
-
- def connected(self):
- """
- Returns True if connected
- rtype: bool
- """
- return hasattr(self, 'tn')
-
- def close(self, announce=True):
- """
- Close connection to openvpn management interface
- """
- logger.debug('closing socket')
- if announce:
- self.tn.write("quit\n")
- self.tn.read_all()
- self.tn.get_socket().close()
- del self.tn
-
- def _send_command(self, cmd):
- """
- Send a command to openvpn and return response as list
- """
- if not self.connected():
- try:
- self.connect_to_management()
- except eip_exceptions.MissingSocketError:
- logger.warning('missing management socket')
- # This should only happen briefly during
- # the first invocation. Race condition make
- # the polling begin before management socket
- # is ready
- return []
- #return self.make_error()
- try:
- if hasattr(self, 'tn'):
- self.tn.write(cmd + "\n")
- except socket.error:
- logger.error('socket error')
- self.close(announce=False)
- return []
- buf = self.tn.read_until(b"END", 2)
- self._seek_to_eof()
- blist = buf.split('\r\n')
- if blist[-1].startswith('END'):
- del blist[-1]
- return blist
- else:
- return []
-
- def _send_short_command(self, cmd):
- """
- parse output from commands that are
- delimited by "success" instead
- """
- if not self.connected():
- self.connect()
- self.tn.write(cmd + "\n")
- # XXX not working?
- buf = self.tn.read_until(b"SUCCESS", 2)
- self._seek_to_eof()
- blist = buf.split('\r\n')
- return blist
-
- #
- # useful vpn commands
- #
-
- def pid(self):
- #XXX broken
- return self._send_short_command("pid")
-
- def make_error(self):
- """
- capture error and wrap it in an
- understandable format
- """
- #XXX get helpful error codes
- self.with_errors = True
- now = int(time.time())
- return '%s,LAUNCHER ERROR,ERROR,-,-' % now
-
- def state(self):
- """
- OpenVPN command: state
- """
- state = self._send_command("state")
- if not state:
- return None
- if isinstance(state, str):
- return state
- if isinstance(state, list):
- if len(state) == 1:
- return state[0]
- else:
- return state[-1]
-
- def vpn_status(self):
- """
- OpenVPN command: status
- """
- #logger.debug('status called')
- status = self._send_command("status")
- return status
-
- def vpn_status2(self):
- """
- OpenVPN command: last 2 statuses
- """
- return self._send_command("status 2")
-
- #
- # parse info
- #
-
- def get_status_io(self):
- status = self.vpn_status()
- if isinstance(status, str):
- lines = status.split('\n')
- if isinstance(status, list):
- lines = status
- try:
- (header, when, tun_read, tun_write,
- tcp_read, tcp_write, auth_read) = tuple(lines)
- except ValueError:
- return None
-
- when_ts = time.strptime(when.split(',')[1], "%a %b %d %H:%M:%S %Y")
- sep = ','
- # XXX cleanup!
- tun_read = tun_read.split(sep)[1]
- tun_write = tun_write.split(sep)[1]
- tcp_read = tcp_read.split(sep)[1]
- tcp_write = tcp_write.split(sep)[1]
- auth_read = auth_read.split(sep)[1]
-
- # XXX this could be a named tuple. prettier.
- return when_ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read)
-
- def get_connection_state(self):
- state = self.state()
- if state is not None:
- ts, status_step, ok, ip, remote = state.split(',')
- ts = time.gmtime(float(ts))
- # XXX this could be a named tuple. prettier.
- return ts, status_step, ok, ip, remote
+ def get_log(self, lines=1):
+ log = self._send_command("log %s" % lines)
+ return log
diff --git a/src/leap/eip/specs.py b/src/leap/eip/specs.py
index 1a670b0e..c41fd29b 100644
--- a/src/leap/eip/specs.py
+++ b/src/leap/eip/specs.py
@@ -4,11 +4,20 @@ import os
from leap import __branding
from leap.base import config as baseconfig
+# XXX move provider stuff to base config
+
PROVIDER_CA_CERT = __branding.get(
'provider_ca_file',
- 'testprovider-ca-cert.pem')
+ 'cacert.pem')
+
+provider_ca_path = lambda domain: str(os.path.join(
+ #baseconfig.get_default_provider_path(),
+ baseconfig.get_provider_path(domain),
+ 'keys', 'ca',
+ 'cacert.pem'
+)) if domain else None
-provider_ca_path = lambda: str(os.path.join(
+default_provider_ca_path = lambda: str(os.path.join(
baseconfig.get_default_provider_path(),
'keys', 'ca',
PROVIDER_CA_CERT
@@ -17,7 +26,13 @@ provider_ca_path = lambda: str(os.path.join(
PROVIDER_DOMAIN = __branding.get('provider_domain', 'testprovider.example.org')
-client_cert_path = lambda: unicode(os.path.join(
+client_cert_path = lambda domain: unicode(os.path.join(
+ baseconfig.get_provider_path(domain),
+ 'keys', 'client',
+ 'openvpn.pem'
+)) if domain else None
+
+default_client_cert_path = lambda: unicode(os.path.join(
baseconfig.get_default_provider_path(),
'keys', 'client',
'openvpn.pem'
@@ -46,11 +61,11 @@ eipconfig_spec = {
},
'openvpn_ca_certificate': {
'type': unicode, # path
- 'default': provider_ca_path
+ 'default': default_provider_ca_path
},
'openvpn_client_certificate': {
'type': unicode, # path
- 'default': client_cert_path
+ 'default': default_client_cert_path
},
'connect_on_login': {
'type': bool,
@@ -62,12 +77,12 @@ eipconfig_spec = {
},
'primary_gateway': {
'type': unicode,
- 'default': u"turkey",
+ 'default': u"location_unknown",
#'required': True
},
'secondary_gateway': {
'type': unicode,
- 'default': u"france"
+ 'default': u"location_unknown2"
},
'management_password': {
'type': unicode
@@ -85,25 +100,37 @@ eipservice_config_spec = {
'default': 1
},
'version': {
- 'type': unicode,
+ 'type': int,
'required': True,
- 'default': "0.1.0"
+ 'default': 1
},
- 'capabilities': {
- 'type': dict,
- 'default': {
- "transport": ["openvpn"],
- "ports": ["80", "53"],
- "protocols": ["udp", "tcp"],
- "static_ips": True,
- "adblock": True}
+ 'clusters': {
+ 'type': list,
+ 'default': [
+ {"label": {
+ "en": "Location Unknown"},
+ "name": "location_unknown"}]
},
'gateways': {
'type': list,
- 'default': [{"country_code": "us",
- "label": {"en":"west"},
- "capabilities": {},
- "hosts": ["1.2.3.4", "1.2.3.5"]}]
+ 'default': [
+ {"capabilities": {
+ "adblock": True,
+ "filter_dns": True,
+ "ports": ["80", "53", "443", "1194"],
+ "protocols": ["udp", "tcp"],
+ "transport": ["openvpn"],
+ "user_ips": False},
+ "cluster": "location_unknown",
+ "host": "location.example.org",
+ "ip_address": "127.0.0.1"}]
+ },
+ 'openvpn_configuration': {
+ 'type': dict,
+ 'default': {
+ "auth": None,
+ "cipher": None,
+ "tls-cipher": None}
}
}
}
diff --git a/src/leap/eip/tests/data.py b/src/leap/eip/tests/data.py
index 43df2013..a7fe1853 100644
--- a/src/leap/eip/tests/data.py
+++ b/src/leap/eip/tests/data.py
@@ -1,11 +1,12 @@
from __future__ import unicode_literals
import os
-from leap import __branding
+#from leap import __branding
# sample data used in tests
-PROVIDER = __branding.get('provider_domain')
+#PROVIDER = __branding.get('provider_domain')
+PROVIDER = "testprovider.example.org"
EIP_SAMPLE_CONFIG = {
"provider": "%s" % PROVIDER,
@@ -15,33 +16,36 @@ EIP_SAMPLE_CONFIG = {
"openvpn_ca_certificate": os.path.expanduser(
"~/.config/leap/providers/"
"%s/"
- "keys/ca/testprovider-ca-cert.pem" % PROVIDER),
+ "keys/ca/cacert.pem" % PROVIDER),
"openvpn_client_certificate": os.path.expanduser(
"~/.config/leap/providers/"
"%s/"
"keys/client/openvpn.pem" % PROVIDER),
"connect_on_login": True,
"block_cleartext_traffic": True,
- "primary_gateway": "turkey",
- "secondary_gateway": "france",
+ "primary_gateway": "location_unknown",
+ "secondary_gateway": "location_unknown2",
#"management_password": "oph7Que1othahwiech6J"
}
EIP_SAMPLE_SERVICE = {
"serial": 1,
- "version": "0.1.0",
- "capabilities": {
- "transport": ["openvpn"],
- "ports": ["80", "53"],
- "protocols": ["udp", "tcp"],
- "static_ips": True,
- "adblock": True
- },
+ "version": 1,
+ "clusters": [
+ {"label": {
+ "en": "Location Unknown"},
+ "name": "location_unknown"}
+ ],
"gateways": [
- {"country_code": "tr",
- "name": "turkey",
- "label": {"en":"Ankara, Turkey"},
- "capabilities": {},
- "hosts": ["94.103.43.4"]}
+ {"capabilities": {
+ "adblock": True,
+ "filter_dns": True,
+ "ports": ["80", "53", "443", "1194"],
+ "protocols": ["udp", "tcp"],
+ "transport": ["openvpn"],
+ "user_ips": False},
+ "cluster": "location_unknown",
+ "host": "location.example.org",
+ "ip_address": "192.0.43.10"}
]
}
diff --git a/src/leap/eip/tests/test_checks.py b/src/leap/eip/tests/test_checks.py
index 58ce473f..ab11037a 100644
--- a/src/leap/eip/tests/test_checks.py
+++ b/src/leap/eip/tests/test_checks.py
@@ -25,6 +25,7 @@ from leap.eip.tests import data as testdata
from leap.testing.basetest import BaseLeapTest
from leap.testing.https_server import BaseHTTPSServerTestCase
from leap.testing.https_server import where as where_cert
+from leap.util.fileutil import mkdir_f
class NoLogRequestHandler:
@@ -39,6 +40,8 @@ class NoLogRequestHandler:
class EIPCheckTest(BaseLeapTest):
__name__ = "eip_check_tests"
+ provider = "testprovider.example.org"
+ maxDiff = None
def setUp(self):
pass
@@ -49,7 +52,7 @@ class EIPCheckTest(BaseLeapTest):
# test methods are there, and can be called from run_all
def test_checker_should_implement_check_methods(self):
- checker = eipchecks.EIPConfigChecker()
+ checker = eipchecks.EIPConfigChecker(domain=self.provider)
self.assertTrue(hasattr(checker, "check_default_eipconfig"),
"missing meth")
@@ -62,7 +65,7 @@ class EIPCheckTest(BaseLeapTest):
"missing meth")
def test_checker_should_actually_call_all_tests(self):
- checker = eipchecks.EIPConfigChecker()
+ checker = eipchecks.EIPConfigChecker(domain=self.provider)
mc = Mock()
checker.run_all(checker=mc)
@@ -79,7 +82,7 @@ class EIPCheckTest(BaseLeapTest):
# test individual check methods
def test_check_default_eipconfig(self):
- checker = eipchecks.EIPConfigChecker()
+ checker = eipchecks.EIPConfigChecker(domain=self.provider)
# no eip config (empty home)
eipconfig_path = checker.eipconfig.filename
self.assertFalse(os.path.isfile(eipconfig_path))
@@ -93,15 +96,15 @@ class EIPCheckTest(BaseLeapTest):
# small workaround for evaluating home dirs correctly
EIP_SAMPLE_CONFIG = copy.copy(testdata.EIP_SAMPLE_CONFIG)
EIP_SAMPLE_CONFIG['openvpn_client_certificate'] = \
- eipspecs.client_cert_path()
+ eipspecs.client_cert_path(self.provider)
EIP_SAMPLE_CONFIG['openvpn_ca_certificate'] = \
- eipspecs.provider_ca_path()
+ eipspecs.provider_ca_path(self.provider)
self.assertEqual(deserialized, EIP_SAMPLE_CONFIG)
# TODO: shold ALSO run validation methods.
def test_check_is_there_default_provider(self):
- checker = eipchecks.EIPConfigChecker()
+ checker = eipchecks.EIPConfigChecker(domain=self.provider)
# we do dump a sample eip config, but lacking a
# default provider entry.
# This error will be possible catched in a different
@@ -116,6 +119,7 @@ class EIPCheckTest(BaseLeapTest):
sampleconfig = copy.copy(testdata.EIP_SAMPLE_CONFIG)
sampleconfig['provider'] = None
eipcfg_path = checker.eipconfig.filename
+ mkdir_f(eipcfg_path)
with open(eipcfg_path, 'w') as fp:
json.dump(sampleconfig, fp)
#with self.assertRaises(eipexceptions.EIPMissingDefaultProvider):
@@ -136,6 +140,8 @@ class EIPCheckTest(BaseLeapTest):
def test_fetch_definition(self):
with patch.object(requests, "get") as mocked_get:
mocked_get.return_value.status_code = 200
+ mocked_get.return_value.headers = {
+ 'last-modified': "Wed Dec 12 12:12:12 GMT 2012"}
mocked_get.return_value.json = DEFAULT_PROVIDER_DEFINITION
checker = eipchecks.EIPConfigChecker(fetcher=requests)
sampleconfig = testdata.EIP_SAMPLE_CONFIG
@@ -154,6 +160,8 @@ class EIPCheckTest(BaseLeapTest):
def test_fetch_eip_service_config(self):
with patch.object(requests, "get") as mocked_get:
mocked_get.return_value.status_code = 200
+ mocked_get.return_value.headers = {
+ 'last-modified': "Wed Dec 12 12:12:12 GMT 2012"}
mocked_get.return_value.json = testdata.EIP_SAMPLE_SERVICE
checker = eipchecks.EIPConfigChecker(fetcher=requests)
sampleconfig = testdata.EIP_SAMPLE_CONFIG
@@ -178,6 +186,7 @@ class EIPCheckTest(BaseLeapTest):
class ProviderCertCheckerTest(BaseLeapTest):
__name__ = "provider_cert_checker_tests"
+ provider = "testprovider.example.org"
def setUp(self):
pass
@@ -226,13 +235,20 @@ class ProviderCertCheckerTest(BaseLeapTest):
# test individual check methods
+ @unittest.skip
def test_is_there_provider_ca(self):
+ # XXX commenting out this test.
+ # With the generic client this does not make sense,
+ # we should dump one there.
+ # or test conductor logic.
checker = eipchecks.ProviderCertChecker()
self.assertTrue(
checker.is_there_provider_ca())
class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase, BaseLeapTest):
+ provider = "testprovider.example.org"
+
class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler):
responses = {
'/': ['OK', ''],
@@ -292,12 +308,19 @@ class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase, BaseLeapTest):
# same, but get cacert from leap.custom
# XXX TODO!
+ @unittest.skip
def test_download_new_client_cert(self):
+ # FIXME
+ # Magick srp decorator broken right now...
+ # Have to mock the decorator and inject something that
+ # can bypass the authentication
+
uri = "https://%s/client.cert" % (self.get_server())
cacert = where_cert('cacert.pem')
- checker = eipchecks.ProviderCertChecker()
+ checker = eipchecks.ProviderCertChecker(domain=self.provider)
+ credentials = "testuser", "testpassword"
self.assertTrue(checker.download_new_client_cert(
- uri=uri, verify=cacert))
+ credentials=credentials, uri=uri, verify=cacert))
# now download a malformed cert
uri = "https://%s/badclient.cert" % (self.get_server())
@@ -305,7 +328,7 @@ class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase, BaseLeapTest):
checker = eipchecks.ProviderCertChecker()
with self.assertRaises(ValueError):
self.assertTrue(checker.download_new_client_cert(
- uri=uri, verify=cacert))
+ credentials=credentials, uri=uri, verify=cacert))
# did we write cert to its path?
clientcertfile = eipspecs.client_cert_path()
@@ -339,7 +362,7 @@ class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase, BaseLeapTest):
def test_check_new_cert_needed(self):
# check: missing cert
- checker = eipchecks.ProviderCertChecker()
+ checker = eipchecks.ProviderCertChecker(domain=self.provider)
self.assertTrue(checker.check_new_cert_needed(skip_download=True))
# TODO check: malformed cert
# TODO check: expired cert
diff --git a/src/leap/eip/tests/test_config.py b/src/leap/eip/tests/test_config.py
index 6759b522..72ab3c8e 100644
--- a/src/leap/eip/tests/test_config.py
+++ b/src/leap/eip/tests/test_config.py
@@ -1,3 +1,4 @@
+from collections import OrderedDict
import json
import os
import platform
@@ -10,21 +11,24 @@ except ImportError:
#from leap.base import constants
#from leap.eip import config as eip_config
-from leap import __branding as BRANDING
+#from leap import __branding as BRANDING
from leap.eip import config as eipconfig
from leap.eip.tests.data import EIP_SAMPLE_CONFIG, EIP_SAMPLE_SERVICE
from leap.testing.basetest import BaseLeapTest
-from leap.util.fileutil import mkdir_p
+from leap.util.fileutil import mkdir_p, mkdir_f
_system = platform.system()
-PROVIDER = BRANDING.get('provider_domain')
-PROVIDER_SHORTNAME = BRANDING.get('short_name')
+#PROVIDER = BRANDING.get('provider_domain')
+#PROVIDER_SHORTNAME = BRANDING.get('short_name')
class EIPConfigTest(BaseLeapTest):
__name__ = "eip_config_tests"
+ provider = "testprovider.example.org"
+
+ maxDiff = None
def setUp(self):
pass
@@ -46,11 +50,22 @@ class EIPConfigTest(BaseLeapTest):
open(tfile, 'wb').close()
os.chmod(tfile, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
- def write_sample_eipservice(self):
+ def write_sample_eipservice(self, vpnciphers=False, extra_vpnopts=None,
+ gateways=None):
conf = eipconfig.EIPServiceConfig()
- folder, f = os.path.split(conf.filename)
- if not os.path.isdir(folder):
- mkdir_p(folder)
+ mkdir_f(conf.filename)
+ if gateways:
+ EIP_SAMPLE_SERVICE['gateways'] = gateways
+ if vpnciphers:
+ openvpnconfig = OrderedDict({
+ "auth": "SHA1",
+ "cipher": "AES-128-CBC",
+ "tls-cipher": "DHE-RSA-AES128-SHA"})
+ if extra_vpnopts:
+ for k, v in extra_vpnopts.items():
+ openvpnconfig[k] = v
+ EIP_SAMPLE_SERVICE['openvpn_configuration'] = openvpnconfig
+
with open(conf.filename, 'w') as fd:
fd.write(json.dumps(EIP_SAMPLE_SERVICE))
@@ -62,8 +77,17 @@ class EIPConfigTest(BaseLeapTest):
with open(conf.filename, 'w') as fd:
fd.write(json.dumps(EIP_SAMPLE_CONFIG))
- def get_expected_openvpn_args(self):
+ def get_expected_openvpn_args(self, with_openvpn_ciphers=False):
+ """
+ yeah, this is almost as duplicating the
+ code for building the command
+ """
args = []
+ eipconf = eipconfig.EIPConfig(domain=self.provider)
+ eipconf.load()
+ eipsconf = eipconfig.EIPServiceConfig(domain=self.provider)
+ eipsconf.load()
+
username = self.get_username()
groupname = self.get_groupname()
@@ -74,7 +98,10 @@ class EIPConfigTest(BaseLeapTest):
args.append('--persist-tun')
args.append('--persist-key')
args.append('--remote')
- args.append('%s' % eipconfig.get_eip_gateway())
+
+ args.append('%s' % eipconfig.get_eip_gateway(
+ eipconfig=eipconf,
+ eipserviceconfig=eipsconf))
# XXX get port!?
args.append('1194')
# XXX get proto
@@ -83,6 +110,14 @@ class EIPConfigTest(BaseLeapTest):
args.append('--remote-cert-tls')
args.append('server')
+ if with_openvpn_ciphers:
+ CIPHERS = [
+ "--tls-cipher", "DHE-RSA-AES128-SHA",
+ "--cipher", "AES-128-CBC",
+ "--auth", "SHA1"]
+ for opt in CIPHERS:
+ args.append(opt)
+
args.append('--user')
args.append(username)
args.append('--group')
@@ -97,29 +132,43 @@ class EIPConfigTest(BaseLeapTest):
args.append('/tmp/test.socket')
args.append('unix')
+ args.append('--script-security')
+ args.append('2')
+
+ if _system == "Linux":
+ UPDOWN_SCRIPT = "/etc/leap/resolv-update"
+ if os.path.isfile(UPDOWN_SCRIPT):
+ args.append('--up')
+ args.append('/etc/leap/resolv-update')
+ args.append('--down')
+ args.append('/etc/leap/resolv-update')
+ args.append('--plugin')
+ args.append('/usr/lib/openvpn/openvpn-down-root.so')
+ args.append("'script_type=down /etc/leap/resolv-update'")
+
# certs
# XXX get values from specs?
args.append('--cert')
args.append(os.path.join(
self.home,
'.config', 'leap', 'providers',
- '%s' % PROVIDER,
+ '%s' % self.provider,
'keys', 'client',
'openvpn.pem'))
args.append('--key')
args.append(os.path.join(
self.home,
'.config', 'leap', 'providers',
- '%s' % PROVIDER,
+ '%s' % self.provider,
'keys', 'client',
'openvpn.pem'))
args.append('--ca')
args.append(os.path.join(
self.home,
'.config', 'leap', 'providers',
- '%s' % PROVIDER,
+ '%s' % self.provider,
'keys', 'ca',
- '%s-cacert.pem' % PROVIDER_SHORTNAME))
+ 'cacert.pem'))
return args
# build command string
@@ -128,6 +177,55 @@ class EIPConfigTest(BaseLeapTest):
# params in the function call, to disable
# some checks.
+ def test_get_eip_gateway(self):
+ self.write_sample_eipconfig()
+ eipconf = eipconfig.EIPConfig(domain=self.provider)
+
+ # default eipservice
+ self.write_sample_eipservice()
+ eipsconf = eipconfig.EIPServiceConfig(domain=self.provider)
+
+ gateway = eipconfig.get_eip_gateway(
+ eipconfig=eipconf,
+ eipserviceconfig=eipsconf)
+
+ # in spec is local gateway by default
+ self.assertEqual(gateway, '127.0.0.1')
+
+ # change eipservice
+ # right now we only check that cluster == selected primary gw in
+ # eip.json, and pick first matching ip
+ eipconf._config.config['primary_gateway'] = "foo_provider"
+ newgateways = [{"cluster": "foo_provider",
+ "ip_address": "127.0.0.99"}]
+ self.write_sample_eipservice(gateways=newgateways)
+ eipsconf = eipconfig.EIPServiceConfig(domain=self.provider)
+ # load from disk file
+ eipsconf.load()
+
+ gateway = eipconfig.get_eip_gateway(
+ eipconfig=eipconf,
+ eipserviceconfig=eipsconf)
+ self.assertEqual(gateway, '127.0.0.99')
+
+ # change eipservice, several gateways
+ # right now we only check that cluster == selected primary gw in
+ # eip.json, and pick first matching ip
+ eipconf._config.config['primary_gateway'] = "bar_provider"
+ newgateways = [{"cluster": "foo_provider",
+ "ip_address": "127.0.0.99"},
+ {'cluster': "bar_provider",
+ "ip_address": "127.0.0.88"}]
+ self.write_sample_eipservice(gateways=newgateways)
+ eipsconf = eipconfig.EIPServiceConfig(domain=self.provider)
+ # load from disk file
+ eipsconf.load()
+
+ gateway = eipconfig.get_eip_gateway(
+ eipconfig=eipconf,
+ eipserviceconfig=eipsconf)
+ self.assertEqual(gateway, '127.0.0.88')
+
def test_build_ovpn_command_empty_config(self):
self.touch_exec()
self.write_sample_eipservice()
@@ -137,13 +235,63 @@ class EIPConfigTest(BaseLeapTest):
from leap.util.fileutil import which
path = os.environ['PATH']
vpnbin = which('openvpn', path=path)
- print 'path =', path
- print 'vpnbin = ', vpnbin
- command, args = eipconfig.build_ovpn_command(
+ #print 'path =', path
+ #print 'vpnbin = ', vpnbin
+ vpncommand, vpnargs = eipconfig.build_ovpn_command(
+ do_pkexec_check=False, vpnbin=vpnbin,
+ socket_path="/tmp/test.socket",
+ provider=self.provider)
+ self.assertEqual(vpncommand, self.home + '/bin/openvpn')
+ self.assertEqual(vpnargs, self.get_expected_openvpn_args())
+
+ def test_build_ovpn_command_openvpnoptions(self):
+ self.touch_exec()
+
+ from leap.eip import config as eipconfig
+ from leap.util.fileutil import which
+ path = os.environ['PATH']
+ vpnbin = which('openvpn', path=path)
+
+ self.write_sample_eipconfig()
+
+ # regular run, everything normal
+ self.write_sample_eipservice(vpnciphers=True)
+ vpncommand, vpnargs = eipconfig.build_ovpn_command(
+ do_pkexec_check=False, vpnbin=vpnbin,
+ socket_path="/tmp/test.socket",
+ provider=self.provider)
+ self.assertEqual(vpncommand, self.home + '/bin/openvpn')
+ expected = self.get_expected_openvpn_args(
+ with_openvpn_ciphers=True)
+ self.assertEqual(vpnargs, expected)
+
+ # bad options -- illegal options
+ self.write_sample_eipservice(
+ vpnciphers=True,
+ # WE ONLY ALLOW vpn options in auth, cipher, tls-cipher
+ extra_vpnopts={"notallowedconfig": "badvalue"})
+ vpncommand, vpnargs = eipconfig.build_ovpn_command(
+ do_pkexec_check=False, vpnbin=vpnbin,
+ socket_path="/tmp/test.socket",
+ provider=self.provider)
+ self.assertEqual(vpncommand, self.home + '/bin/openvpn')
+ expected = self.get_expected_openvpn_args(
+ with_openvpn_ciphers=True)
+ self.assertEqual(vpnargs, expected)
+
+ # bad options -- illegal chars
+ self.write_sample_eipservice(
+ vpnciphers=True,
+ # WE ONLY ALLOW A-Z09\-
+ extra_vpnopts={"cipher": "AES-128-CBC;FOOTHING"})
+ vpncommand, vpnargs = eipconfig.build_ovpn_command(
do_pkexec_check=False, vpnbin=vpnbin,
- socket_path="/tmp/test.socket")
- self.assertEqual(command, self.home + '/bin/openvpn')
- self.assertEqual(args, self.get_expected_openvpn_args())
+ socket_path="/tmp/test.socket",
+ provider=self.provider)
+ self.assertEqual(vpncommand, self.home + '/bin/openvpn')
+ expected = self.get_expected_openvpn_args(
+ with_openvpn_ciphers=True)
+ self.assertEqual(vpnargs, expected)
if __name__ == "__main__":
diff --git a/src/leap/eip/tests/test_eipconnection.py b/src/leap/eip/tests/test_eipconnection.py
index bb643ae0..163f8d45 100644
--- a/src/leap/eip/tests/test_eipconnection.py
+++ b/src/leap/eip/tests/test_eipconnection.py
@@ -1,6 +1,8 @@
+import glob
import logging
import platform
-import os
+#import os
+import shutil
logging.basicConfig()
logger = logging.getLogger(name=__name__)
@@ -19,6 +21,8 @@ from leap.testing.basetest import BaseLeapTest
_system = platform.system()
+PROVIDER = "testprovider.example.org"
+
class NotImplementedError(Exception):
pass
@@ -27,6 +31,7 @@ class NotImplementedError(Exception):
@patch('OpenVPNConnection._get_or_create_config')
@patch('OpenVPNConnection._set_ovpn_command')
class MockedEIPConnection(EIPConnection):
+
def _set_ovpn_command(self):
self.command = "mock_command"
self.args = [1, 2, 3]
@@ -35,6 +40,7 @@ class MockedEIPConnection(EIPConnection):
class EIPConductorTest(BaseLeapTest):
__name__ = "eip_conductor_tests"
+ provider = PROVIDER
def setUp(self):
# XXX there's a conceptual/design
@@ -51,8 +57,8 @@ class EIPConductorTest(BaseLeapTest):
# XXX change to keys_checker invocation
# (see config_checker)
- keyfiles = (eipspecs.provider_ca_path(),
- eipspecs.client_cert_path())
+ keyfiles = (eipspecs.provider_ca_path(domain=self.provider),
+ eipspecs.client_cert_path(domain=self.provider))
for filepath in keyfiles:
self.touch(filepath)
self.chmod600(filepath)
@@ -61,11 +67,27 @@ class EIPConductorTest(BaseLeapTest):
# some methods mocked
self.manager = Mock(name="openvpnmanager_mock")
self.con = MockedEIPConnection()
+ self.con.provider = self.provider
+
+ # XXX watch out. This sometimes is throwing the following error:
+ # NoSuchProcess: process no longer exists (pid=6571)
+ # because of a bad implementation of _check_if_running_instance
+
self.con.run_openvpn_checks()
def tearDown(self):
+ pass
+
+ def doCleanups(self):
+ super(BaseLeapTest, self).doCleanups()
+ self.cleanupSocketDir()
del self.con
+ def cleanupSocketDir(self):
+ ptt = ('/tmp/leap-tmp*')
+ for tmpdir in glob.glob(ptt):
+ shutil.rmtree(tmpdir)
+
#
# tests
#
@@ -76,6 +98,7 @@ class EIPConductorTest(BaseLeapTest):
"""
con = self.con
self.assertEqual(con.autostart, True)
+ # XXX moar!
def test_ovpn_command(self):
"""
@@ -93,6 +116,7 @@ class EIPConductorTest(BaseLeapTest):
# needed to run tests. (roughly 3 secs for this only)
# We should modularize and inject Mocks on more places.
+ oldcon = self.con
del(self.con)
config_checker = Mock()
self.con = MockedEIPConnection(config_checker=config_checker)
@@ -102,6 +126,7 @@ class EIPConductorTest(BaseLeapTest):
skip_download=False)
# XXX test for cert_checker also
+ self.con = oldcon
# connect/disconnect calls
@@ -118,8 +143,14 @@ class EIPConductorTest(BaseLeapTest):
self.con.status.CONNECTED)
# disconnect
+ self.con.terminate_openvpn_connection = Mock()
self.con.disconnect()
- self.con._disconnect.assert_called_once_with()
+ self.con.terminate_openvpn_connection.assert_called_once_with(
+ shutdown=False)
+ self.con.terminate_openvpn_connection = Mock()
+ self.con.disconnect(shutdown=True)
+ self.con.terminate_openvpn_connection.assert_called_once_with(
+ shutdown=True)
# new status should be disconnected
# XXX this should evolve and check no errors
diff --git a/src/leap/eip/tests/test_openvpnconnection.py b/src/leap/eip/tests/test_openvpnconnection.py
index 61769f04..95bfb2f0 100644
--- a/src/leap/eip/tests/test_openvpnconnection.py
+++ b/src/leap/eip/tests/test_openvpnconnection.py
@@ -58,16 +58,27 @@ class OpenVPNConnectionTest(BaseLeapTest):
def setUp(self):
# XXX this will have to change for win, host=localhost
host = eipconfig.get_socket_path()
+ self.host = host
self.manager = MockedOpenVPNConnection(host=host)
def tearDown(self):
+ pass
+
+ def doCleanups(self):
+ super(BaseLeapTest, self).doCleanups()
+ self.cleanupSocketDir()
+
+ def cleanupSocketDir(self):
# remove the socket folder.
# XXX only if posix. in win, host is localhost, so nothing
# has to be done.
- if self.manager.host:
- folder, fpath = os.path.split(self.manager.host)
- assert folder.startswith('/tmp/leap-tmp') # safety check
- shutil.rmtree(folder)
+ if self.host:
+ folder, fpath = os.path.split(self.host)
+ try:
+ assert folder.startswith('/tmp/leap-tmp') # safety check
+ shutil.rmtree(folder)
+ except:
+ self.fail("could not remove temp file")
del self.manager
@@ -76,13 +87,18 @@ class OpenVPNConnectionTest(BaseLeapTest):
#
def test_detect_vpn(self):
+ # XXX review, not sure if captured all the logic
+ # while fixing. kali.
openvpn_connection = openvpnconnection.OpenVPNConnection()
- with patch.object(psutil, "get_process_list") as mocked_psutil:
+
+ with patch.object(psutil, "process_iter") as mocked_psutil:
+ mocked_process = Mock()
+ mocked_process.name = "openvpn"
+ mocked_process.cmdline = ["openvpn", "-foo", "-bar", "-gaaz"]
+ mocked_psutil.return_value = [mocked_process]
with self.assertRaises(eipexceptions.OpenVPNAlreadyRunning):
- mocked_process = Mock()
- mocked_process.name = "openvpn"
- mocked_psutil.return_value = [mocked_process]
openvpn_connection._check_if_running_instance()
+
openvpn_connection._check_if_running_instance()
@unittest.skipIf(_system == "Windows", "lin/mac only")
@@ -104,12 +120,14 @@ class OpenVPNConnectionTest(BaseLeapTest):
self.assertEqual(self.manager.port, 7777)
def test_port_types_init(self):
+ oldmanager = self.manager
self.manager = MockedOpenVPNConnection(port="42")
self.assertEqual(self.manager.port, 42)
self.manager = MockedOpenVPNConnection()
self.assertEqual(self.manager.port, "unix")
self.manager = MockedOpenVPNConnection(port="bad")
self.assertEqual(self.manager.port, None)
+ self.manager = oldmanager
def test_uds_telnet_called_on_connect(self):
self.manager.connect_to_management()
diff --git a/src/leap/gui/__init__.py b/src/leap/gui/__init__.py
index e69de29b..804bfbc1 100644
--- a/src/leap/gui/__init__.py
+++ b/src/leap/gui/__init__.py
@@ -0,0 +1,11 @@
+try:
+ import sip
+ sip.setapi('QString', 2)
+ sip.setapi('QVariant', 2)
+except ValueError:
+ pass
+
+import firstrun
+import firstrun.wizard
+
+__all__ = ['firstrun', 'firstrun.wizard']
diff --git a/src/leap/gui/constants.py b/src/leap/gui/constants.py
new file mode 100644
index 00000000..277f3540
--- /dev/null
+++ b/src/leap/gui/constants.py
@@ -0,0 +1,13 @@
+import time
+
+APP_LOGO = ':/images/leap-color-small.png'
+
+# bare is the username portion of a JID
+# full includes the "at" and some extra chars
+# that can be allowed for fqdn
+
+BARE_USERNAME_REGEX = r"^[A-Za-z\d_]+$"
+FULL_USERNAME_REGEX = r"^[A-Za-z\d_@.-]+$"
+
+GUI_PAUSE_FOR_USER_SECONDS = 1
+pause_for_user = lambda: time.sleep(GUI_PAUSE_FOR_USER_SECONDS)
diff --git a/src/leap/gui/firstrun/__init__.py b/src/leap/gui/firstrun/__init__.py
new file mode 100644
index 00000000..2a523d6a
--- /dev/null
+++ b/src/leap/gui/firstrun/__init__.py
@@ -0,0 +1,28 @@
+try:
+ import sip
+ sip.setapi('QString', 2)
+ sip.setapi('QVariant', 2)
+except ValueError:
+ pass
+
+import intro
+import connect
+import last
+import login
+import mixins
+import providerinfo
+import providerselect
+import providersetup
+import register
+
+__all__ = [
+ 'intro',
+ 'connect',
+ 'last',
+ 'login',
+ 'mixins',
+ 'providerinfo',
+ 'providerselect',
+ 'providersetup',
+ 'register',
+ ] # ,'wizard']
diff --git a/src/leap/gui/firstrun/connect.py b/src/leap/gui/firstrun/connect.py
new file mode 100644
index 00000000..ad7bb13a
--- /dev/null
+++ b/src/leap/gui/firstrun/connect.py
@@ -0,0 +1,214 @@
+"""
+Provider Setup Validation Page,
+used in First Run Wizard
+"""
+import logging
+
+from PyQt4 import QtGui
+
+#import requests
+
+from leap.gui.progress import ValidationPage
+from leap.util.web import get_https_domain_and_port
+
+from leap.base import auth
+from leap.gui.constants import APP_LOGO
+
+logger = logging.getLogger(__name__)
+
+
+class ConnectionPage(ValidationPage):
+
+ def __init__(self, parent=None):
+ super(ConnectionPage, self).__init__(parent)
+ self.current_page = "connect"
+
+ title = self.tr("Connecting...")
+ subtitle = self.tr("Setting up a encrypted "
+ "connection with the provider")
+
+ self.setTitle(title)
+ self.setSubTitle(subtitle)
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ def _do_checks(self, update_signal=None):
+ """
+ executes actual checks in a separate thread
+
+ we initialize the srp protocol register
+ and try to register user.
+ """
+ wizard = self.wizard()
+ full_domain = self.field('provider_domain')
+ domain, port = get_https_domain_and_port(full_domain)
+
+ pconfig = wizard.eipconfigchecker(domain=domain)
+ # this should be persisted...
+ pconfig.defaultprovider.load()
+ pconfig.set_api_domain()
+
+ pCertChecker = wizard.providercertchecker(
+ domain=domain)
+ pCertChecker.set_api_domain(pconfig.apidomain)
+
+ ###########################################
+ # Set Credentials.
+ # username and password are in different fields
+ # if they were stored in log_in or sign_up pages.
+ from_login = wizard.from_login
+
+ unamek_base = 'userName'
+ passwk_base = 'userPassword'
+ unamek = 'login_%s' % unamek_base if from_login else unamek_base
+ passwk = 'login_%s' % passwk_base if from_login else passwk_base
+
+ username = self.field(unamek)
+ password = self.field(passwk)
+ credentials = username, password
+
+ yield(("head_sentinel", 0), lambda: None)
+
+ ##################################################
+ # 1) fetching eip service config
+ ##################################################
+ def fetcheipconf():
+ try:
+ pconfig.fetch_eip_service_config()
+
+ # XXX get specific exception
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ yield((self.tr("Getting EIP configuration files"), 40),
+ fetcheipconf)
+
+ ##################################################
+ # 2) getting client certificate
+ ##################################################
+
+ def fetcheipcert():
+ try:
+ downloaded = pCertChecker.download_new_client_cert(
+ credentials=credentials)
+ if not downloaded:
+ logger.error('Could not download client cert')
+ return False
+
+ except auth.SRPAuthenticationError as exc:
+ return self.fail(self.tr(
+ "Authentication error: %s" % exc.message))
+
+ except Exception as exc:
+ return self.fail(exc.message)
+ else:
+ return True
+
+ yield((self.tr("Getting EIP certificate"), 80),
+ fetcheipcert)
+
+ ################
+ # end !
+ ################
+ self.set_done()
+ yield(("end_sentinel", 100), lambda: None)
+
+ def on_checks_validation_ready(self):
+ """
+ called after _do_checks has finished
+ (connected to checker thread finished signal)
+ """
+ # here we go! :)
+ if self.is_done():
+ nextbutton = self.wizard().button(QtGui.QWizard.NextButton)
+ nextbutton.setFocus()
+
+ full_domain = self.field('provider_domain')
+ domain, port = get_https_domain_and_port(full_domain)
+ _domain = u"%s:%s" % (
+ domain, port) if port != 443 else unicode(domain)
+ self.run_eip_checks_for_provider_and_connect(_domain)
+
+ def run_eip_checks_for_provider_and_connect(self, domain):
+ wizard = self.wizard()
+ conductor = wizard.conductor
+ start_eip_signal = getattr(
+ wizard,
+ 'start_eipconnection_signal', None)
+
+ if conductor:
+ conductor.set_provider_domain(domain)
+ # we could run some of the checks to be
+ # sure everything is in order, but
+ # I see no point in doing it, we assume
+ # we've gone thru all checks during the wizard.
+ #conductor.run_checks()
+ #self.conductor = conductor
+ #errors = self.eip_error_check()
+ #if not errors and start_eip_signal:
+ if start_eip_signal:
+ start_eip_signal.emit()
+
+ else:
+ logger.warning(
+ "No conductor found. This means that "
+ "probably the wizard has been launched "
+ "in an stand-alone way.")
+
+ self.set_done()
+
+ #def eip_error_check(self):
+ #"""
+ #a version of the main app error checker,
+ #but integrated within the connecting page of the wizard.
+ #consumes the conductor error queue.
+ #pops errors, and add those to the wizard page
+ #"""
+ # TODO handle errors.
+ # We should redirect them to the log viewer
+ # with a brief message.
+ # XXX move to LAST PAGE instead.
+ #logger.debug('eip error check from connecting page')
+ #errq = self.conductor.error_queue
+
+ #def _do_validation(self):
+ #"""
+ #called after _do_checks has finished
+ #(connected to checker thread finished signal)
+ #"""
+ #from_login = self.wizard().from_login
+ #prevpage = "login" if from_login else "signup"
+
+ #wizard = self.wizard()
+ #if self.errors:
+ #logger.debug('going back with errors')
+ #logger.error(self.errors)
+ #name, first_error = self.pop_first_error()
+ #wizard.set_validation_error(
+ #prevpage,
+ #first_error)
+ #self.go_back()
+
+ def nextId(self):
+ wizard = self.wizard()
+ return wizard.get_page_index('lastpage')
+
+ def initializePage(self):
+ super(ConnectionPage, self).initializePage()
+ self.set_undone()
+ cancelbutton = self.wizard().button(QtGui.QWizard.CancelButton)
+ cancelbutton.hide()
+ self.completeChanged.emit()
+
+ wizard = self.wizard()
+ eip_statuschange_signal = wizard.eip_statuschange_signal
+ if eip_statuschange_signal:
+ eip_statuschange_signal.connect(
+ lambda status: self.send_status(
+ status))
+
+ def send_status(self, status):
+ wizard = self.wizard()
+ wizard.openvpn_status.append(status)
diff --git a/src/leap/gui/firstrun/constants.py b/src/leap/gui/firstrun/constants.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/gui/firstrun/constants.py
diff --git a/src/leap/gui/firstrun/intro.py b/src/leap/gui/firstrun/intro.py
new file mode 100644
index 00000000..b519362f
--- /dev/null
+++ b/src/leap/gui/firstrun/intro.py
@@ -0,0 +1,68 @@
+"""
+Intro page used in first run wizard
+"""
+
+from PyQt4 import QtGui
+
+from leap.gui.constants import APP_LOGO
+
+
+class IntroPage(QtGui.QWizardPage):
+ def __init__(self, parent=None):
+ super(IntroPage, self).__init__(parent)
+
+ self.setTitle(self.tr("First run wizard"))
+
+ #self.setPixmap(
+ #QtGui.QWizard.WatermarkPixmap,
+ #QtGui.QPixmap(':/images/watermark1.png'))
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ label = QtGui.QLabel(self.tr(
+ "Now we will guide you through "
+ "some configuration that is needed before you "
+ "can connect for the first time.<br><br>"
+ "If you ever need to modify these options again, "
+ "you can find the wizard in the '<i>Settings</i>' menu from the "
+ "main window.<br><br>"
+ "Do you want to <b>sign up</b> for a new account, or <b>log "
+ "in</b> with an already existing username?<br>"))
+ label.setWordWrap(True)
+
+ radiobuttonGroup = QtGui.QGroupBox()
+
+ self.sign_up = QtGui.QRadioButton(
+ self.tr("Sign up for a new account"))
+ self.sign_up.setChecked(True)
+ self.log_in = QtGui.QRadioButton(
+ self.tr("Log In with my credentials"))
+
+ radiobLayout = QtGui.QVBoxLayout()
+ radiobLayout.addWidget(self.sign_up)
+ radiobLayout.addWidget(self.log_in)
+ radiobuttonGroup.setLayout(radiobLayout)
+
+ layout = QtGui.QVBoxLayout()
+ layout.addWidget(label)
+ layout.addWidget(radiobuttonGroup)
+ self.setLayout(layout)
+
+ #self.registerField('is_signup', self.sign_up)
+
+ def validatePage(self):
+ return True
+
+ def nextId(self):
+ """
+ returns next id
+ in a non-linear wizard
+ """
+ if self.sign_up.isChecked():
+ next_ = 'providerselection'
+ if self.log_in.isChecked():
+ next_ = 'login'
+ wizard = self.wizard()
+ return wizard.get_page_index(next_)
diff --git a/src/leap/gui/firstrun/last.py b/src/leap/gui/firstrun/last.py
new file mode 100644
index 00000000..32d98acc
--- /dev/null
+++ b/src/leap/gui/firstrun/last.py
@@ -0,0 +1,109 @@
+"""
+Last Page, used in First Run Wizard
+"""
+import logging
+
+from PyQt4 import QtGui
+
+from leap.util.coroutines import coroutine
+from leap.gui.constants import APP_LOGO
+
+logger = logging.getLogger(__name__)
+
+
+class LastPage(QtGui.QWizardPage):
+ def __init__(self, parent=None):
+ super(LastPage, self).__init__(parent)
+
+ self.setTitle(self.tr(
+ "Connecting to Encrypted Internet Proxy service..."))
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ #self.setPixmap(
+ #QtGui.QWizard.WatermarkPixmap,
+ #QtGui.QPixmap(':/images/watermark2.png'))
+
+ self.label = QtGui.QLabel()
+ self.label.setWordWrap(True)
+
+ # XXX REFACTOR to a Validating Page...
+ self.status_line_1 = QtGui.QLabel()
+ self.status_line_2 = QtGui.QLabel()
+ self.status_line_3 = QtGui.QLabel()
+ self.status_line_4 = QtGui.QLabel()
+ self.status_line_5 = QtGui.QLabel()
+
+ layout = QtGui.QVBoxLayout()
+ layout.addWidget(self.label)
+
+ # make loop
+ layout.addWidget(self.status_line_1)
+ layout.addWidget(self.status_line_2)
+ layout.addWidget(self.status_line_3)
+ layout.addWidget(self.status_line_4)
+ layout.addWidget(self.status_line_5)
+
+ self.setLayout(layout)
+
+ def set_status_line(self, line, status):
+ statusline = getattr(self, 'status_line_%s' % line)
+ if statusline:
+ statusline.setText(status)
+
+ def set_finished_status(self):
+ self.setTitle(self.tr('You are now using an encrypted connection!'))
+ finishText = self.wizard().buttonText(
+ QtGui.QWizard.FinishButton)
+ finishText = finishText.replace('&', '')
+ self.label.setText(self.tr(
+ "Click '<i>%s</i>' to end the wizard and "
+ "save your settings." % finishText))
+ # XXX init network checker
+ # trigger signal
+
+ @coroutine
+ def eip_status_handler(self):
+ # XXX this can be changed to use
+ # signals. See progress.py
+ logger.debug('logging status in last page')
+ self.validation_done = False
+ status_count = 1
+ try:
+ while True:
+ status = (yield)
+ status_count += 1
+ # XXX add to line...
+ logger.debug('status --> %s', status)
+ self.set_status_line(status_count, status)
+ if status == "connected":
+ self.set_finished_status()
+ break
+ except GeneratorExit:
+ pass
+ except StopIteration:
+ pass
+
+ def initializePage(self):
+ super(LastPage, self).initializePage()
+ wizard = self.wizard()
+ handler = self.eip_status_handler()
+
+ # get statuses done in prev page
+ for st in wizard.openvpn_status:
+ self.send_status(handler.send, st)
+
+ # bind signal for events yet to come
+ eip_statuschange_signal = wizard.eip_statuschange_signal
+ if eip_statuschange_signal:
+ eip_statuschange_signal.connect(
+ lambda status: self.send_status(
+ handler.send, status))
+
+ def send_status(self, cb, status):
+ try:
+ cb(status)
+ except StopIteration:
+ pass
diff --git a/src/leap/gui/firstrun/login.py b/src/leap/gui/firstrun/login.py
new file mode 100644
index 00000000..3707d3ff
--- /dev/null
+++ b/src/leap/gui/firstrun/login.py
@@ -0,0 +1,332 @@
+"""
+LogIn Page, used inf First Run Wizard
+"""
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+import requests
+
+from leap.base import auth
+from leap.gui.firstrun.mixins import UserFormMixIn
+from leap.gui.progress import InlineValidationPage
+from leap.gui import styles
+
+from leap.gui.constants import APP_LOGO, FULL_USERNAME_REGEX
+
+
+class LogInPage(InlineValidationPage, UserFormMixIn): # InlineValidationPage
+
+ def __init__(self, parent=None):
+
+ super(LogInPage, self).__init__(parent)
+ self.current_page = "login"
+
+ self.setTitle(self.tr("Log In"))
+ self.setSubTitle(self.tr("Log in with your credentials"))
+ self.current_page = "login"
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ self.setupSteps()
+ self.setupUI()
+
+ self.do_confirm_next = False
+
+ def setupUI(self):
+ userNameLabel = QtGui.QLabel(self.tr("User &name:"))
+ userNameLineEdit = QtGui.QLineEdit()
+ userNameLineEdit.cursorPositionChanged.connect(
+ self.reset_validation_status)
+ userNameLabel.setBuddy(userNameLineEdit)
+
+ # let's add regex validator
+ usernameRe = QtCore.QRegExp(FULL_USERNAME_REGEX)
+ userNameLineEdit.setValidator(
+ QtGui.QRegExpValidator(usernameRe, self))
+
+ #userNameLineEdit.setPlaceholderText(
+ #'username@provider.example.org')
+ self.userNameLineEdit = userNameLineEdit
+
+ userPasswordLabel = QtGui.QLabel(self.tr("&Password:"))
+ self.userPasswordLineEdit = QtGui.QLineEdit()
+ self.userPasswordLineEdit.setEchoMode(
+ QtGui.QLineEdit.Password)
+ userPasswordLabel.setBuddy(self.userPasswordLineEdit)
+
+ self.registerField('login_userName*', self.userNameLineEdit)
+ self.registerField('login_userPassword*', self.userPasswordLineEdit)
+
+ layout = QtGui.QGridLayout()
+ layout.setColumnMinimumWidth(0, 20)
+
+ validationMsg = QtGui.QLabel("")
+ validationMsg.setStyleSheet(styles.ErrorLabelStyleSheet)
+ self.validationMsg = validationMsg
+
+ layout.addWidget(validationMsg, 0, 3)
+ layout.addWidget(userNameLabel, 1, 0)
+ layout.addWidget(self.userNameLineEdit, 1, 3)
+ layout.addWidget(userPasswordLabel, 2, 0)
+ layout.addWidget(self.userPasswordLineEdit, 2, 3)
+
+ # add validation frame
+ self.setupValidationFrame()
+ layout.addWidget(self.valFrame, 4, 2, 4, 2)
+ self.valFrame.hide()
+
+ self.nextText(self.tr("Log in"))
+ self.setLayout(layout)
+
+ #self.registerField('is_login_wizard')
+
+ # actual checks
+
+ def _do_checks(self):
+
+ full_username = self.userNameLineEdit.text()
+ ###########################
+ # 0) check user@domain form
+ ###########################
+
+ def checkusername():
+ if full_username.count('@') != 1:
+ return self.fail(
+ self.tr(
+ "Username must be in the username@provider form."))
+ else:
+ return True
+
+ yield(("head_sentinel", 0), checkusername)
+
+ username, domain = full_username.split('@')
+ password = self.userPasswordLineEdit.text()
+
+ # We try a call to an authenticated
+ # page here as a mean to catch
+ # srp authentication errors while
+ wizard = self.wizard()
+ eipconfigchecker = wizard.eipconfigchecker(domain=domain)
+
+ ########################
+ # 1) try name resolution
+ ########################
+ # show the frame before going on...
+ QtCore.QMetaObject.invokeMethod(
+ self, "showStepsFrame")
+
+ # Able to contact domain?
+ # can get definition?
+ # two-by-one
+ def resolvedomain():
+ try:
+ eipconfigchecker.fetch_definition(domain=domain)
+
+ # we're using requests here for all
+ # the possible error cases that it catches.
+ except requests.exceptions.ConnectionError as exc:
+ return self.fail(exc.message[1])
+ except requests.exceptions.HTTPError as exc:
+ return self.fail(exc.message)
+ except Exception as exc:
+ # XXX get catchall error msg
+ return self.fail(
+ exc.message)
+ else:
+ return True
+
+ yield((self.tr("Resolving domain name"), 20), resolvedomain)
+
+ wizard.set_providerconfig(
+ eipconfigchecker.defaultprovider.config)
+
+ ########################
+ # 2) do authentication
+ ########################
+ credentials = username, password
+ pCertChecker = wizard.providercertchecker(
+ domain=domain)
+
+ def validate_credentials():
+ #################
+ # FIXME #BUG #638
+ verify = False
+
+ try:
+ pCertChecker.download_new_client_cert(
+ credentials=credentials,
+ verify=verify)
+
+ except auth.SRPAuthenticationError as exc:
+ return self.fail(
+ self.tr("Authentication error: %s" % exc.message))
+
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ else:
+ return True
+
+ yield(('Validating credentials', 60), validate_credentials)
+
+ self.set_done()
+ yield(("end_sentinel", 100), lambda: None)
+
+ def green_validation_status(self):
+ val = self.validationMsg
+ val.setText(self.tr('Credentials validated.'))
+ val.setStyleSheet(styles.GreenLineEdit)
+
+ def on_checks_validation_ready(self):
+ """
+ after checks
+ """
+ if self.is_done():
+ self.disableFields()
+ self.cleanup_errormsg()
+ self.clean_wizard_errors(self.current_page)
+ # make the user confirm the transition
+ # to next page.
+ self.nextText('&Next')
+ self.nextFocus()
+ self.green_validation_status()
+ self.do_confirm_next = True
+
+ # ui update
+
+ def nextText(self, text):
+ self.setButtonText(
+ QtGui.QWizard.NextButton, text)
+
+ def nextFocus(self):
+ self.wizard().button(
+ QtGui.QWizard.NextButton).setFocus()
+
+ def disableNextButton(self):
+ self.wizard().button(
+ QtGui.QWizard.NextButton).setDisabled(True)
+
+ def onUserNamePositionChanged(self, *args):
+ if self.initial_username_sample:
+ self.userNameLineEdit.setText('')
+ # XXX set regular color
+ self.initial_username_sample = None
+
+ def onUserNameTextChanged(self, *args):
+ if self.initial_username_sample:
+ k = args[0][-1]
+ self.initial_username_sample = None
+ self.userNameLineEdit.setText(k)
+
+ def disableFields(self):
+ for field in (self.userNameLineEdit,
+ self.userPasswordLineEdit):
+ field.setDisabled(True)
+
+ def populateErrors(self):
+ # XXX could move this to ValidationMixin
+ # used in providerselect and register too
+
+ errors = self.wizard().get_validation_error(
+ self.current_page)
+ showerr = self.validationMsg.setText
+
+ if errors:
+ bad_str = getattr(self, 'bad_string', None)
+ cur_str = self.userNameLineEdit.text()
+
+ if bad_str is None:
+ # first time we fall here.
+ # save the current bad_string value
+ self.bad_string = cur_str
+ showerr(errors)
+ else:
+ # not the first time
+ if cur_str == bad_str:
+ showerr(errors)
+ else:
+ self.focused_field = False
+ showerr('')
+
+ def cleanup_errormsg(self):
+ """
+ we reset bad_string to None
+ should be called before leaving the page
+ """
+ self.bad_string = None
+
+ def paintEvent(self, event):
+ """
+ we hook our populate errors
+ on paintEvent because we need it to catch
+ when user enters the page coming from next,
+ and initializePage does not cover that case.
+ Maybe there's a better event to hook upon.
+ """
+ super(LogInPage, self).paintEvent(event)
+ self.populateErrors()
+
+ def set_prevalidation_error(self, error):
+ self.prevalidation_error = error
+
+ # pagewizard methods
+
+ def nextId(self):
+ wizard = self.wizard()
+ if not wizard:
+ return
+ if wizard.is_provider_setup is False:
+ next_ = 'providersetupvalidation'
+ if wizard.is_provider_setup is True:
+ # XXX bad name, ok, gonna change that
+ next_ = 'signupvalidation'
+ return wizard.get_page_index(next_)
+
+ def initializePage(self):
+ super(LogInPage, self).initializePage()
+ username = self.userNameLineEdit
+ username.setText('username@provider.example.org')
+ username.cursorPositionChanged.connect(
+ self.onUserNamePositionChanged)
+ username.textChanged.connect(
+ self.onUserNameTextChanged)
+ self.initial_username_sample = True
+ self.validationMsg.setText('')
+ self.valFrame.hide()
+
+ def reset_validation_status(self):
+ """
+ empty the validation msg
+ and clean the inline validation widget.
+ """
+ self.validationMsg.setText('')
+ self.steps.removeAllSteps()
+ self.clearTable()
+
+ def validatePage(self):
+ """
+ if not register done, do checks.
+ if done, wait for click.
+ """
+ self.disableNextButton()
+ self.cleanup_errormsg()
+ self.clean_wizard_errors(self.current_page)
+
+ if self.do_confirm_next:
+ full_username = self.userNameLineEdit.text()
+ password = self.userPasswordLineEdit.text()
+ username, domain = full_username.split('@')
+ self.setField('provider_domain', domain)
+ self.setField('login_userName', username)
+ self.setField('login_userPassword', password)
+ self.wizard().from_login = True
+
+ return True
+
+ if not self.is_done():
+ self.reset_validation_status()
+ self.do_checks()
+
+ return self.is_done()
diff --git a/src/leap/gui/firstrun/mixins.py b/src/leap/gui/firstrun/mixins.py
new file mode 100644
index 00000000..c4731893
--- /dev/null
+++ b/src/leap/gui/firstrun/mixins.py
@@ -0,0 +1,18 @@
+"""
+mixins used in First Run Wizard
+"""
+
+
+class UserFormMixIn(object):
+
+ def reset_validation_status(self):
+ """
+ empty the validation msg
+ """
+ self.validationMsg.setText('')
+
+ def set_validation_status(self, msg):
+ """
+ set generic validation status
+ """
+ self.validationMsg.setText(msg)
diff --git a/src/leap/gui/firstrun/providerinfo.py b/src/leap/gui/firstrun/providerinfo.py
new file mode 100644
index 00000000..cff4caca
--- /dev/null
+++ b/src/leap/gui/firstrun/providerinfo.py
@@ -0,0 +1,106 @@
+"""
+Provider Info Page, used in First run Wizard
+"""
+import logging
+
+from PyQt4 import QtGui
+
+from leap.gui.constants import APP_LOGO
+from leap.util.translations import translate
+
+logger = logging.getLogger(__name__)
+
+
+class ProviderInfoPage(QtGui.QWizardPage):
+
+ def __init__(self, parent=None):
+ super(ProviderInfoPage, self).__init__(parent)
+
+ self.setTitle(self.tr("Provider Information"))
+ self.setSubTitle(self.tr(
+ "Services offered by this provider"))
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ self.create_info_panel()
+
+ def create_info_panel(self):
+ # Use stacked widget instead
+ # of reparenting the layout.
+
+ infoWidget = QtGui.QStackedWidget()
+
+ info = QtGui.QWidget()
+ layout = QtGui.QVBoxLayout()
+
+ displayName = QtGui.QLabel("")
+ description = QtGui.QLabel("")
+ enrollment_policy = QtGui.QLabel("")
+
+ # XXX set stylesheet...
+ # prettify a little bit.
+ # bigger fonts and so on...
+
+ # We could use a QFrame here
+
+ layout.addWidget(displayName)
+ layout.addWidget(description)
+ layout.addWidget(enrollment_policy)
+ layout.addStretch(1)
+
+ info.setLayout(layout)
+ infoWidget.addWidget(info)
+
+ pageLayout = QtGui.QVBoxLayout()
+ pageLayout.addWidget(infoWidget)
+ self.setLayout(pageLayout)
+
+ # add refs to self to allow for
+ # updates.
+ # Watch out! Have to get rid of these references!
+ # this should be better handled with signals !!
+ self.displayName = displayName
+ self.description = description
+ self.description.setWordWrap(True)
+ self.enrollment_policy = enrollment_policy
+
+ def show_provider_info(self):
+
+ # XXX get multilingual objects
+ # directly from the config object
+
+ lang = "en"
+ pconfig = self.wizard().providerconfig
+
+ dn = pconfig.get('name')
+ display_name = dn[lang] if dn else ''
+ domain_name = self.field('provider_domain')
+
+ self.displayName.setText(
+ "<b>%s</b> https://%s" % (display_name, domain_name))
+
+ desc = pconfig.get('description')
+
+ #description_text = desc[lang] if desc else ''
+ description_text = translate(desc) if desc else ''
+
+ self.description.setText(
+ "<i>%s</i>" % description_text)
+
+ # XXX should translate this...
+ enroll = pconfig.get('enrollment_policy')
+ if enroll:
+ self.enrollment_policy.setText(
+ '<b>%s</b>: <em>%s</em>' % (
+ self.tr('enrollment policy'),
+ enroll))
+
+ def nextId(self):
+ wizard = self.wizard()
+ next_ = "providersetupvalidation"
+ return wizard.get_page_index(next_)
+
+ def initializePage(self):
+ self.show_provider_info()
diff --git a/src/leap/gui/firstrun/providerselect.py b/src/leap/gui/firstrun/providerselect.py
new file mode 100644
index 00000000..917b16fd
--- /dev/null
+++ b/src/leap/gui/firstrun/providerselect.py
@@ -0,0 +1,471 @@
+"""
+Select Provider Page, used in First Run Wizard
+"""
+import logging
+
+import requests
+
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+from leap.base import exceptions as baseexceptions
+#from leap.crypto import certs
+from leap.eip import exceptions as eipexceptions
+from leap.gui.progress import InlineValidationPage
+from leap.gui import styles
+from leap.gui.utils import delay
+from leap.util.web import get_https_domain_and_port
+
+from leap.gui.constants import APP_LOGO
+
+logger = logging.getLogger(__name__)
+
+
+class SelectProviderPage(InlineValidationPage):
+
+ launchChecks = QtCore.pyqtSignal()
+
+ def __init__(self, parent=None, providers=None):
+ super(SelectProviderPage, self).__init__(parent)
+ self.current_page = 'providerselection'
+
+ self.setTitle(self.tr("Enter Provider"))
+ self.setSubTitle(self.tr(
+ "Please enter the domain of the provider you want "
+ "to use for your connection")
+ )
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ self.did_cert_check = False
+
+ self.done = False
+
+ self.setupSteps()
+ self.setupUI()
+
+ self.launchChecks.connect(
+ self.launch_checks)
+
+ self.providerNameEdit.editingFinished.connect(
+ lambda: self.providerCheckButton.setFocus(True))
+
+ def setupUI(self):
+ """
+ initializes the UI
+ """
+ providerNameLabel = QtGui.QLabel("h&ttps://")
+ # note that we expect the bare domain name
+ # we will add the scheme later
+ providerNameEdit = QtGui.QLineEdit()
+ providerNameEdit.cursorPositionChanged.connect(
+ self.reset_validation_status)
+ providerNameLabel.setBuddy(providerNameEdit)
+
+ # add regex validator
+ providerDomainRe = QtCore.QRegExp(r"^[a-z1-9_\-\.]+$")
+ providerNameEdit.setValidator(
+ QtGui.QRegExpValidator(providerDomainRe, self))
+ self.providerNameEdit = providerNameEdit
+
+ # Eventually we will seed a list of
+ # well known providers here.
+
+ #providercombo = QtGui.QComboBox()
+ #if providers:
+ #for provider in providers:
+ #providercombo.addItem(provider)
+ #providerNameSelect = providercombo
+
+ self.registerField("provider_domain*", self.providerNameEdit)
+ #self.registerField('provider_name_index', providerNameSelect)
+
+ validationMsg = QtGui.QLabel("")
+ validationMsg.setStyleSheet(styles.ErrorLabelStyleSheet)
+ self.validationMsg = validationMsg
+ providerCheckButton = QtGui.QPushButton(self.tr("chec&k!"))
+ self.providerCheckButton = providerCheckButton
+
+ # cert info
+
+ # this is used in the callback
+ # for the checkbox changes.
+ # tricky, since the first time came
+ # from the exception message.
+ # should get string from exception too!
+ self.bad_cert_status = self.tr(
+ "Server certificate could not be verified.")
+
+ self.certInfo = QtGui.QLabel("")
+ self.certInfo.setWordWrap(True)
+ self.certWarning = QtGui.QLabel("")
+ self.trustProviderCertCheckBox = QtGui.QCheckBox(
+ self.tr("&Trust this provider certificate."))
+
+ self.trustProviderCertCheckBox.stateChanged.connect(
+ self.onTrustCheckChanged)
+ self.providerNameEdit.textChanged.connect(
+ self.onProviderChanged)
+ self.providerCheckButton.clicked.connect(
+ self.onCheckButtonClicked)
+
+ layout = QtGui.QGridLayout()
+ layout.addWidget(validationMsg, 0, 2)
+ layout.addWidget(providerNameLabel, 1, 1)
+ layout.addWidget(providerNameEdit, 1, 2)
+ layout.addWidget(providerCheckButton, 1, 3)
+
+ # add certinfo group
+ # XXX not shown now. should move to validation box.
+ #layout.addWidget(certinfoGroup, 4, 1, 4, 2)
+ #self.certinfoGroup = certinfoGroup
+ #self.certinfoGroup.hide()
+
+ # add validation frame
+ self.setupValidationFrame()
+ layout.addWidget(self.valFrame, 4, 2, 4, 2)
+ self.valFrame.hide()
+
+ self.setLayout(layout)
+
+ # certinfo
+
+ def setupCertInfoGroup(self): # pragma: no cover
+ # XXX not used now.
+ certinfoGroup = QtGui.QGroupBox(
+ self.tr("Certificate validation"))
+ certinfoLayout = QtGui.QVBoxLayout()
+ certinfoLayout.addWidget(self.certInfo)
+ certinfoLayout.addWidget(self.certWarning)
+ certinfoLayout.addWidget(self.trustProviderCertCheckBox)
+ certinfoGroup.setLayout(certinfoLayout)
+ self.certinfoGroup = self.certinfoGroup
+
+ # progress frame
+
+ def setupValidationFrame(self):
+ qframe = QtGui.QFrame
+ valFrame = qframe()
+ valFrame.setFrameStyle(qframe.NoFrame)
+ valframeLayout = QtGui.QVBoxLayout()
+ zeros = (0, 0, 0, 0)
+ valframeLayout.setContentsMargins(*zeros)
+
+ valframeLayout.addWidget(self.stepsTableWidget)
+ valFrame.setLayout(valframeLayout)
+ self.valFrame = valFrame
+
+ @QtCore.pyqtSlot()
+ def onDisableCheckButton(self):
+ #print 'CHECK BUTTON DISABLED!!!'
+ self.providerCheckButton.setDisabled(True)
+
+ @QtCore.pyqtSlot()
+ def launch_checks(self):
+ self.do_checks()
+
+ def onCheckButtonClicked(self):
+ QtCore.QMetaObject.invokeMethod(
+ self, "onDisableCheckButton")
+
+ QtCore.QMetaObject.invokeMethod(
+ self, "showStepsFrame")
+
+ delay(self, "launch_checks")
+
+ def _do_checks(self):
+ """
+ generator that yields actual checks
+ that are executed in a separate thread
+ """
+
+ wizard = self.wizard()
+ full_domain = self.providerNameEdit.text()
+
+ # we check if we have a port in the domain string.
+ domain, port = get_https_domain_and_port(full_domain)
+ _domain = u"%s:%s" % (domain, port) if port != 443 else unicode(domain)
+
+ netchecker = wizard.netchecker()
+ providercertchecker = wizard.providercertchecker()
+ eipconfigchecker = wizard.eipconfigchecker(domain=_domain)
+
+ yield(("head_sentinel", 0), lambda: None)
+
+ ########################
+ # 1) try name resolution
+ ########################
+
+ def namecheck():
+ """
+ in which we check if
+ we are able to name resolve
+ this domain
+ """
+ try:
+ #import ipdb;ipdb.set_trace()
+ netchecker.check_name_resolution(
+ domain)
+
+ except baseexceptions.LeapException as exc:
+ logger.error(exc.message)
+ return self.fail(exc.usermessage)
+
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ else:
+ return True
+
+ logger.debug('checking name resolution')
+ yield((self.tr("Checking if it is a valid provider"), 20), namecheck)
+
+ #########################
+ # 2) try https connection
+ #########################
+
+ def httpscheck():
+ """
+ in which we check
+ if the provider
+ is offering service over
+ https
+ """
+ try:
+ providercertchecker.is_https_working(
+ "https://%s" % _domain,
+ verify=True)
+
+ except eipexceptions.HttpsBadCertError as exc:
+ logger.debug('exception')
+ return self.fail(exc.usermessage)
+ # XXX skipping for now...
+ ##############################################
+ # We had this validation logic
+ # in the provider selection page before
+ ##############################################
+ #if self.trustProviderCertCheckBox.isChecked():
+ #pass
+ #else:
+ #fingerprint = certs.get_cert_fingerprint(
+ #domain=domain, sep=" ")
+
+ # it's ok if we've trusted this fgprt before
+ #trustedcrts = wizard.trusted_certs
+ #if trustedcrts and \
+ # fingerprint.replace(' ', '') in trustedcrts:
+ #pass
+ #else:
+ # let your user face panick :P
+ #self.add_cert_info(fingerprint)
+ #self.did_cert_check = True
+ #self.completeChanged.emit()
+ #return False
+
+ except baseexceptions.LeapException as exc:
+ return self.fail(exc.usermessage)
+
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ else:
+ return True
+
+ logger.debug('checking https connection')
+ yield((self.tr("Checking for a secure connection"), 40), httpscheck)
+
+ ##################################
+ # 3) try download provider info...
+ ##################################
+
+ def fetchinfo():
+ try:
+ # XXX we already set _domain in the initialization
+ # so it should not be needed here.
+ eipconfigchecker.fetch_definition(domain=_domain)
+ wizard.set_providerconfig(
+ eipconfigchecker.defaultprovider.config)
+ except requests.exceptions.SSLError:
+ return self.fail(self.tr(
+ "Could not get info from provider."))
+ except requests.exceptions.ConnectionError:
+ return self.fail(self.tr(
+ "Could not download provider info "
+ "(refused conn.)."))
+
+ except Exception as exc:
+ return self.fail(
+ self.tr(exc.message))
+ else:
+ return True
+
+ yield((self.tr("Getting info from the provider"), 80), fetchinfo)
+
+ # done!
+
+ self.done = True
+ yield(("end_sentinel", 100), lambda: None)
+
+ def on_checks_validation_ready(self):
+ """
+ called after _do_checks has finished.
+ """
+ self.domain_checked = True
+ self.completeChanged.emit()
+ # let's set focus...
+ if self.is_done():
+ self.wizard().clean_validation_error(self.current_page)
+ nextbutton = self.wizard().button(QtGui.QWizard.NextButton)
+ nextbutton.setFocus()
+ else:
+ self.providerNameEdit.setFocus()
+
+ # cert trust verification
+ # (disabled for now)
+
+ def is_insecure_cert_trusted(self):
+ return self.trustProviderCertCheckBox.isChecked()
+
+ def onTrustCheckChanged(self, state): # pragma: no cover XXX
+ checked = False
+ if state == 2:
+ checked = True
+
+ if checked:
+ self.reset_validation_status()
+ else:
+ self.set_validation_status(self.bad_cert_status)
+
+ # trigger signal to redraw next button
+ self.completeChanged.emit()
+
+ def add_cert_info(self, certinfo): # pragma: no cover XXX
+ self.certWarning.setText(
+ self.tr("Do you want to <b>trust this provider certificate?</b>"))
+ # XXX Check if this needs to abstracted to remove certinfo
+ self.certInfo.setText(
+ self.tr('SHA-256 fingerprint: <i>%s</i><br>' % certinfo))
+ self.certInfo.setWordWrap(True)
+ self.certinfoGroup.show()
+
+ def onProviderChanged(self, text):
+ self.done = False
+ provider = self.providerNameEdit.text()
+ if provider:
+ self.providerCheckButton.setDisabled(False)
+ else:
+ self.providerCheckButton.setDisabled(True)
+ self.completeChanged.emit()
+
+ def reset_validation_status(self):
+ """
+ empty the validation msg
+ and clean the inline validation widget.
+ """
+ self.validationMsg.setText('')
+ self.steps.removeAllSteps()
+ self.clearTable()
+ self.domain_checked = False
+
+ # pagewizard methods
+
+ def isComplete(self):
+ provider = self.providerNameEdit.text()
+
+ if not self.is_done():
+ return False
+
+ if not provider:
+ return False
+ else:
+ if self.is_insecure_cert_trusted():
+ return True
+ if not self.did_cert_check:
+ if self.is_done():
+ # XXX sure?
+ return True
+ return False
+
+ def populateErrors(self):
+ # XXX could move this to ValidationMixin
+ # with some defaults for the validating fields
+ # (now it only allows one field, manually specified)
+
+ #logger.debug('getting errors')
+ errors = self.wizard().get_validation_error(
+ self.current_page)
+ if errors:
+ bad_str = getattr(self, 'bad_string', None)
+ cur_str = self.providerNameEdit.text()
+ showerr = self.validationMsg.setText
+ markred = lambda: self.providerNameEdit.setStyleSheet(
+ styles.ErrorLineEdit)
+ umarkrd = lambda: self.providerNameEdit.setStyleSheet(
+ styles.RegularLineEdit)
+ if bad_str is None:
+ # first time we fall here.
+ # save the current bad_string value
+ self.bad_string = cur_str
+ showerr(errors)
+ markred()
+ else:
+ # not the first time
+ # XXX hey, this is getting convoluted.
+ # roll out this.
+ # but be careful about all the possibilities
+ # with going back and forth once you
+ # enter a domain.
+ if cur_str == bad_str:
+ showerr(errors)
+ markred()
+ else:
+ if not getattr(self, 'domain_checked', None):
+ showerr('')
+ umarkrd()
+ else:
+ self.bad_string = cur_str
+ showerr(errors)
+
+ def cleanup_errormsg(self):
+ """
+ we reset bad_string to None
+ should be called before leaving the page
+ """
+ self.bad_string = None
+ self.domain_checked = False
+
+ def paintEvent(self, event):
+ """
+ we hook our populate errors
+ on paintEvent because we need it to catch
+ when user enters the page coming from next,
+ and initializePage does not cover that case.
+ Maybe there's a better event to hook upon.
+ """
+ super(SelectProviderPage, self).paintEvent(event)
+ self.populateErrors()
+
+ def initializePage(self):
+ self.validationMsg.setText('')
+ if hasattr(self, 'certinfoGroup'):
+ # XXX remove ?
+ self.certinfoGroup.hide()
+ self.done = False
+ self.providerCheckButton.setDisabled(True)
+ self.valFrame.hide()
+ self.steps.removeAllSteps()
+ self.clearTable()
+
+ def validatePage(self):
+ # some cleanup before we leave the page
+ self.cleanup_errormsg()
+
+ # go
+ return True
+
+ def nextId(self):
+ wizard = self.wizard()
+ if not wizard:
+ return
+ return wizard.get_page_index('providerinfo')
diff --git a/src/leap/gui/firstrun/providersetup.py b/src/leap/gui/firstrun/providersetup.py
new file mode 100644
index 00000000..47060f6e
--- /dev/null
+++ b/src/leap/gui/firstrun/providersetup.py
@@ -0,0 +1,157 @@
+"""
+Provider Setup Validation Page,
+used if First Run Wizard
+"""
+import logging
+
+import requests
+
+from PyQt4 import QtGui
+
+from leap.base import exceptions as baseexceptions
+from leap.gui.progress import ValidationPage
+
+from leap.gui.constants import APP_LOGO
+
+logger = logging.getLogger(__name__)
+
+
+class ProviderSetupValidationPage(ValidationPage):
+ def __init__(self, parent=None):
+ super(ProviderSetupValidationPage, self).__init__(parent)
+ self.current_page = "providersetupvalidation"
+
+ # XXX needed anymore?
+ #is_signup = self.field("is_signup")
+ #self.is_signup = is_signup
+
+ self.setTitle(self.tr("Provider setup"))
+ self.setSubTitle(
+ self.tr("Gathering configuration options for this provider"))
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ def _do_checks(self):
+ """
+ generator that yields actual checks
+ that are executed in a separate thread
+ """
+
+ full_domain = self.field('provider_domain')
+ wizard = self.wizard()
+ pconfig = wizard.providerconfig
+
+ #pCertChecker = wizard.providercertchecker
+ #certchecker = pCertChecker(domain=full_domain)
+ pCertChecker = wizard.providercertchecker(
+ domain=full_domain)
+
+ yield(("head_sentinel", 0), lambda: None)
+
+ ########################
+ # 1) fetch ca cert
+ ########################
+
+ def fetchcacert():
+ if pconfig:
+ ca_cert_uri = pconfig.get('ca_cert_uri').geturl()
+ else:
+ ca_cert_uri = None
+
+ # XXX check scheme == "https"
+ # XXX passing verify == False because
+ # we have trusted right before.
+ # We should check it's the same domain!!!
+ # (Check with the trusted fingerprints dict
+ # or something smart)
+ try:
+ pCertChecker.download_ca_cert(
+ uri=ca_cert_uri,
+ verify=False)
+
+ except baseexceptions.LeapException as exc:
+ logger.error(exc.message)
+ # XXX this should be _ method
+ return self.fail(self.tr(exc.usermessage))
+
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ else:
+ return True
+
+ yield((self.tr('Fetching CA certificate'), 30),
+ fetchcacert)
+
+ #########################
+ # 2) check CA fingerprint
+ #########################
+
+ def checkcafingerprint():
+ # XXX get the real thing!!!
+ pass
+ #ca_cert_fingerprint = pconfig.get('ca_cert_fingerprint', None)
+
+ # XXX get fingerprint dict (types)
+ #sha256_fpr = ca_cert_fingerprint.split('=')[1]
+
+ #validate_fpr = pCertChecker.check_ca_cert_fingerprint(
+ #fingerprint=sha256_fpr)
+ #if not validate_fpr:
+ # XXX update validationMsg
+ # should catch exception
+ #return False
+
+ yield((self.tr("Checking CA fingerprint"), 60),
+ checkcafingerprint)
+
+ #########################
+ # 2) check CA fingerprint
+ #########################
+
+ def validatecacert():
+ api_uri = pconfig.get('api_uri', None)
+ try:
+ pCertChecker.verify_api_https(api_uri)
+ except requests.exceptions.SSLError as exc:
+ return self.fail("Validation Error")
+ except Exception as exc:
+ return self.fail(exc.msg)
+ else:
+ return True
+
+ yield((self.tr('Validating api certificate'), 90), validatecacert)
+
+ self.set_done()
+ yield(('end_sentinel', 100), lambda: None)
+
+ def on_checks_validation_ready(self):
+ """
+ called after _do_checks has finished
+ (connected to checker thread finished signal)
+ """
+ wizard = self.wizard()
+ prevpage = "login" if wizard.from_login else "providerselection"
+
+ if self.errors:
+ logger.debug('going back with errors')
+ name, first_error = self.pop_first_error()
+ wizard.set_validation_error(
+ prevpage,
+ first_error)
+
+ def nextId(self):
+ wizard = self.wizard()
+ from_login = wizard.from_login
+ if from_login:
+ next_ = 'connect'
+ else:
+ next_ = 'signup'
+ return wizard.get_page_index(next_)
+
+ def initializePage(self):
+ super(ProviderSetupValidationPage, self).initializePage()
+ self.set_undone()
+ self.completeChanged.emit()
diff --git a/src/leap/gui/firstrun/register.py b/src/leap/gui/firstrun/register.py
new file mode 100644
index 00000000..15278330
--- /dev/null
+++ b/src/leap/gui/firstrun/register.py
@@ -0,0 +1,387 @@
+"""
+Register User Page, used in First Run Wizard
+"""
+import json
+import logging
+import socket
+
+import requests
+
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+from leap.gui.firstrun.mixins import UserFormMixIn
+
+logger = logging.getLogger(__name__)
+
+from leap.base import auth
+from leap.gui import styles
+from leap.gui.constants import APP_LOGO, BARE_USERNAME_REGEX
+from leap.gui.progress import InlineValidationPage
+from leap.gui.styles import ErrorLabelStyleSheet
+
+
+class RegisterUserPage(InlineValidationPage, UserFormMixIn):
+
+ def __init__(self, parent=None):
+
+ super(RegisterUserPage, self).__init__(parent)
+ self.current_page = "signup"
+
+ self.setTitle(self.tr("Sign Up"))
+ # subtitle is set in the initializePage
+
+ self.setPixmap(
+ QtGui.QWizard.LogoPixmap,
+ QtGui.QPixmap(APP_LOGO))
+
+ # commit page means there's no way back after this...
+ # XXX should change the text on the "commit" button...
+ self.setCommitPage(True)
+
+ self.setupSteps()
+ self.setupUI()
+ self.do_confirm_next = False
+ self.focused_field = False
+
+ def setupUI(self):
+ userNameLabel = QtGui.QLabel(self.tr("User &name:"))
+ userNameLineEdit = QtGui.QLineEdit()
+ userNameLineEdit.cursorPositionChanged.connect(
+ self.reset_validation_status)
+ userNameLabel.setBuddy(userNameLineEdit)
+
+ # let's add regex validator
+ usernameRe = QtCore.QRegExp(BARE_USERNAME_REGEX)
+ userNameLineEdit.setValidator(
+ QtGui.QRegExpValidator(usernameRe, self))
+ self.userNameLineEdit = userNameLineEdit
+
+ userPasswordLabel = QtGui.QLabel(self.tr("&Password:"))
+ self.userPasswordLineEdit = QtGui.QLineEdit()
+ self.userPasswordLineEdit.setEchoMode(
+ QtGui.QLineEdit.Password)
+ userPasswordLabel.setBuddy(self.userPasswordLineEdit)
+
+ userPassword2Label = QtGui.QLabel(self.tr("Password (again):"))
+ self.userPassword2LineEdit = QtGui.QLineEdit()
+ self.userPassword2LineEdit.setEchoMode(
+ QtGui.QLineEdit.Password)
+ userPassword2Label.setBuddy(self.userPassword2LineEdit)
+
+ rememberPasswordCheckBox = QtGui.QCheckBox(
+ self.tr("&Remember username and password."))
+ rememberPasswordCheckBox.setChecked(True)
+
+ self.registerField('userName*', self.userNameLineEdit)
+ self.registerField('userPassword*', self.userPasswordLineEdit)
+ self.registerField('userPassword2*', self.userPassword2LineEdit)
+
+ # XXX missing password confirmation
+ # XXX validator!
+
+ self.registerField('rememberPassword', rememberPasswordCheckBox)
+
+ layout = QtGui.QGridLayout()
+ layout.setColumnMinimumWidth(0, 20)
+
+ validationMsg = QtGui.QLabel("")
+ validationMsg.setStyleSheet(ErrorLabelStyleSheet)
+
+ self.validationMsg = validationMsg
+
+ layout.addWidget(validationMsg, 0, 3)
+ layout.addWidget(userNameLabel, 1, 0)
+ layout.addWidget(self.userNameLineEdit, 1, 3)
+ layout.addWidget(userPasswordLabel, 2, 0)
+ layout.addWidget(userPassword2Label, 3, 0)
+ layout.addWidget(self.userPasswordLineEdit, 2, 3)
+ layout.addWidget(self.userPassword2LineEdit, 3, 3)
+ layout.addWidget(rememberPasswordCheckBox, 4, 3, 4, 4)
+
+ # add validation frame
+ self.setupValidationFrame()
+ layout.addWidget(self.valFrame, 5, 2, 5, 2)
+ self.valFrame.hide()
+
+ self.setLayout(layout)
+ self.commitText("Sign up!")
+
+ # commit button
+
+ def commitText(self, text):
+ # change "commit" button text
+ self.setButtonText(
+ QtGui.QWizard.CommitButton, text)
+
+ @property
+ def commitButton(self):
+ return self.wizard().button(QtGui.QWizard.CommitButton)
+
+ def commitFocus(self):
+ self.commitButton.setFocus()
+
+ def disableCommitButton(self):
+ self.commitButton.setDisabled(True)
+
+ def disableFields(self):
+ for field in (self.userNameLineEdit,
+ self.userPasswordLineEdit,
+ self.userPassword2LineEdit):
+ field.setDisabled(True)
+
+ # error painting
+ def paintEvent(self, event):
+ """
+ we hook our populate errors
+ on paintEvent because we need it to catch
+ when user enters the page coming from next,
+ and initializePage does not cover that case.
+ Maybe there's a better event to hook upon.
+ """
+ super(RegisterUserPage, self).paintEvent(event)
+ self.populateErrors()
+
+ def markRedAndGetFocus(self, field):
+ field.setStyleSheet(styles.ErrorLineEdit)
+ if not self.focused_field:
+ self.focused_field = True
+ field.setFocus(QtCore.Qt.OtherFocusReason)
+
+ def markRegular(self, field):
+ field.setStyleSheet(styles.RegularLineEdit)
+
+ def populateErrors(self):
+ def showerr(text):
+ self.validationMsg.setText(text)
+ err_lower = text.lower()
+ if "username" in err_lower:
+ self.markRedAndGetFocus(
+ self.userNameLineEdit)
+ if "password" in err_lower:
+ self.markRedAndGetFocus(
+ self.userPasswordLineEdit)
+
+ def unmarkred():
+ for field in (self.userNameLineEdit,
+ self.userPasswordLineEdit,
+ self.userPassword2LineEdit):
+ self.markRegular(field)
+
+ errors = self.wizard().get_validation_error(
+ self.current_page)
+ if errors:
+ bad_str = getattr(self, 'bad_string', None)
+ cur_str = self.userNameLineEdit.text()
+ #prev_er = getattr(self, 'prevalidation_error', None)
+
+ if bad_str is None:
+ # first time we fall here.
+ # save the current bad_string value
+ self.bad_string = cur_str
+ showerr(errors)
+ else:
+ #if prev_er:
+ #showerr(prev_er)
+ #return
+ # not the first time
+ if cur_str == bad_str:
+ showerr(errors)
+ else:
+ self.focused_field = False
+ showerr('')
+ unmarkred()
+ else:
+ # no errors
+ self.focused_field = False
+ unmarkred()
+
+ def cleanup_errormsg(self):
+ """
+ we reset bad_string to None
+ should be called before leaving the page
+ """
+ self.bad_string = None
+
+ def green_validation_status(self):
+ val = self.validationMsg
+ val.setText(self.tr('Registration succeeded!'))
+ val.setStyleSheet(styles.GreenLineEdit)
+
+ def reset_validation_status(self):
+ """
+ empty the validation msg
+ and clean the inline validation widget.
+ """
+ self.validationMsg.setText('')
+ self.steps.removeAllSteps()
+ self.clearTable()
+
+ # actual checks
+
+ def _do_checks(self):
+ """
+ generator that yields actual checks
+ that are executed in a separate thread
+ """
+ wizard = self.wizard()
+
+ provider = self.field('provider_domain')
+ username = self.userNameLineEdit.text()
+ password = self.userPasswordLineEdit.text()
+ password2 = self.userPassword2LineEdit.text()
+
+ pconfig = wizard.eipconfigchecker(domain=provider)
+ pconfig.defaultprovider.load()
+ pconfig.set_api_domain()
+
+ def checkpass():
+ # we better have here
+ # some call to a password checker...
+ # to assess strenght and avoid silly stuff.
+
+ if password != password2:
+ return self.fail(self.tr('Password does not match..'))
+
+ if len(password) < 6:
+ #self.set_prevalidation_error('Password too short.')
+ return self.fail(self.tr('Password too short.'))
+
+ if password == "123456":
+ # joking, but not too much.
+ #self.set_prevalidation_error('Password too obvious.')
+ return self.fail(self.tr('Password too obvious.'))
+
+ # go
+ return True
+
+ yield(("head_sentinel", 0), checkpass)
+
+ # XXX should emit signal for .show the frame!
+ # XXX HERE!
+
+ ##################################################
+ # 1) register user
+ ##################################################
+
+ # show the frame before going on...
+ QtCore.QMetaObject.invokeMethod(
+ self, "showStepsFrame")
+
+ def register():
+
+ signup = auth.LeapSRPRegister(
+ schema="https",
+ provider=pconfig.apidomain,
+ verify=pconfig.cacert)
+ try:
+ ok, req = signup.register_user(
+ username, password)
+
+ except socket.timeout:
+ return self.fail(
+ self.tr("Error connecting to provider (timeout)"))
+
+ except requests.exceptions.ConnectionError as exc:
+ logger.error(exc.message)
+ return self.fail(
+ self.tr('Error Connecting to provider (connerr).'))
+ except Exception as exc:
+ return self.fail(exc.message)
+
+ # XXX check for != OK instead???
+
+ if req.status_code in (404, 500):
+ return self.fail(
+ self.tr(
+ "Error during registration (%s)") % req.status_code)
+
+ try:
+ validation_msgs = json.loads(req.content)
+ errors = validation_msgs.get('errors', None)
+ logger.debug('validation errors: %s' % validation_msgs)
+ except ValueError:
+ # probably bad json returned
+ return self.fail(
+ self.tr(
+ "Could not register (bad response)"))
+
+ if errors and errors.get('login', None):
+ # XXX this sometimes catch the blank username
+ # but we're not allowing that (soon)
+ return self.fail(
+ self.tr('Username not available.'))
+
+ return True
+
+ logger.debug('registering user')
+ yield(("Registering username", 40), register)
+
+ self.set_done()
+ yield(("end_sentinel", 100), lambda: None)
+
+ def on_checks_validation_ready(self):
+ """
+ after checks
+ """
+ if self.is_done():
+ self.disableFields()
+ self.cleanup_errormsg()
+ self.clean_wizard_errors(self.current_page)
+ # make the user confirm the transition
+ # to next page.
+ self.commitText('Connect!')
+ self.commitFocus()
+ self.green_validation_status()
+ self.do_confirm_next = True
+
+ # pagewizard methods
+
+ def validatePage(self):
+ """
+ if not register done, do checks.
+ if done, wait for click.
+ """
+ self.disableCommitButton()
+ self.cleanup_errormsg()
+ self.clean_wizard_errors(self.current_page)
+
+ # After a successful validation
+ # (ie, success register with server)
+ # we change the commit button text
+ # and set this flag to True.
+ if self.do_confirm_next:
+ return True
+
+ if not self.is_done():
+ # calls checks, which after successful
+ # execution will call on_checks_validation_ready
+ self.reset_validation_status()
+ self.do_checks()
+
+ return self.is_done()
+
+ def initializePage(self):
+ """
+ inits wizard page
+ """
+ provider = unicode(self.field('provider_domain'))
+ if provider:
+ # here we should have provider
+ # but in tests we might not.
+
+ # XXX this error causes a segfault on free()
+ # that we might want to get fixed ...
+ #self.setSubTitle(
+ #self.tr("Register a new user with provider %s.") %
+ #provider)
+ self.setSubTitle(
+ self.tr("Register a new user with provider <em>%s</em>" %
+ provider))
+ self.validationMsg.setText('')
+ self.userPassword2LineEdit.setText('')
+ self.valFrame.hide()
+
+ def nextId(self):
+ wizard = self.wizard()
+ return wizard.get_page_index('connect')
diff --git a/src/leap/gui/firstrun/tests/integration/fake_provider.py b/src/leap/gui/firstrun/tests/integration/fake_provider.py
new file mode 100755
index 00000000..668db5d1
--- /dev/null
+++ b/src/leap/gui/firstrun/tests/integration/fake_provider.py
@@ -0,0 +1,302 @@
+#!/usr/bin/env python
+"""A server faking some of the provider resources and apis,
+used for testing Leap Client requests
+
+It needs that you create a subfolder named 'certs',
+and that you place the following files:
+
+[ ] certs/leaptestscert.pem
+[ ] certs/leaptestskey.pem
+[ ] certs/cacert.pem
+[ ] certs/openvpn.pem
+
+[ ] provider.json
+[ ] eip-service.json
+"""
+# XXX NOTE: intended for manual debug.
+# I intend to include this as a regular test after 0.2.0 release
+# (so we can add twisted as a dep there)
+import binascii
+import json
+import os
+import sys
+
+# python SRP LIB (! important MUST be >=1.0.1 !)
+import srp
+
+# GnuTLS Example -- is not working as expected
+#from gnutls import crypto
+#from gnutls.constants import COMP_LZO, COMP_DEFLATE, COMP_NULL
+#from gnutls.interfaces.twisted import X509Credentials
+
+# Going with OpenSSL as a workaround instead
+# But we DO NOT want to introduce this dependency.
+from OpenSSL import SSL
+
+from zope.interface import Interface, Attribute, implements
+
+from twisted.web.server import Site
+from twisted.web.static import File
+from twisted.web.resource import Resource
+from twisted.internet import reactor
+
+from leap.testing.https_server import where
+
+# See
+# http://twistedmatrix.com/documents/current/web/howto/web-in-60/index.htmln
+# for more examples
+
+"""
+Testing the FAKE_API:
+#####################
+
+ 1) register an user
+ >> curl -d "user[login]=me" -d "user[password_salt]=foo" \
+ -d "user[password_verifier]=beef" http://localhost:8000/1/users.json
+ << {"errors": null}
+
+ 2) check that if you try to register again, it will fail:
+ >> curl -d "user[login]=me" -d "user[password_salt]=foo" \
+ -d "user[password_verifier]=beef" http://localhost:8000/1/users.json
+ << {"errors": {"login": "already taken!"}}
+
+"""
+
+# Globals to mock user/sessiondb
+
+USERDB = {}
+SESSIONDB = {}
+
+
+safe_unhexlify = lambda x: binascii.unhexlify(x) \
+ if (len(x) % 2 == 0) else binascii.unhexlify('0' + x)
+
+
+class IUser(Interface):
+ login = Attribute("User login.")
+ salt = Attribute("Password salt.")
+ verifier = Attribute("Password verifier.")
+ session = Attribute("Session.")
+ svr = Attribute("Server verifier.")
+
+
+class User(object):
+ implements(IUser)
+
+ def __init__(self, login, salt, verifier):
+ self.login = login
+ self.salt = salt
+ self.verifier = verifier
+ self.session = None
+
+ def set_server_verifier(self, svr):
+ self.svr = svr
+
+ def set_session(self, session):
+ SESSIONDB[session] = self
+ self.session = session
+
+
+class FakeUsers(Resource):
+ def __init__(self, name):
+ self.name = name
+
+ def render_POST(self, request):
+ args = request.args
+
+ login = args['user[login]'][0]
+ salt = args['user[password_salt]'][0]
+ verifier = args['user[password_verifier]'][0]
+
+ if login in USERDB:
+ return "%s\n" % json.dumps(
+ {'errors': {'login': 'already taken!'}})
+
+ print login, verifier, salt
+ user = User(login, salt, verifier)
+ USERDB[login] = user
+ return json.dumps({'errors': None})
+
+
+def get_user(request):
+ login = request.args.get('login')
+ if login:
+ user = USERDB.get(login[0], None)
+ if user:
+ return user
+
+ session = request.getSession()
+ user = SESSIONDB.get(session, None)
+ return user
+
+
+class FakeSession(Resource):
+ def __init__(self, name):
+ self.name = name
+
+ def render_GET(self, request):
+ return "%s\n" % json.dumps({'errors': None})
+
+ def render_POST(self, request):
+
+ user = get_user(request)
+
+ if not user:
+ # XXX get real error from demo provider
+ return json.dumps({'errors': 'no such user'})
+
+ A = request.args['A'][0]
+
+ _A = safe_unhexlify(A)
+ _salt = safe_unhexlify(user.salt)
+ _verifier = safe_unhexlify(user.verifier)
+
+ svr = srp.Verifier(
+ user.login,
+ _salt,
+ _verifier,
+ _A,
+ hash_alg=srp.SHA256,
+ ng_type=srp.NG_1024)
+
+ s, B = svr.get_challenge()
+
+ _B = binascii.hexlify(B)
+
+ print 'login = %s' % user.login
+ print 'salt = %s' % user.salt
+ print 'len(_salt) = %s' % len(_salt)
+ print 'vkey = %s' % user.verifier
+ print 'len(vkey) = %s' % len(_verifier)
+ print 's = %s' % binascii.hexlify(s)
+ print 'B = %s' % _B
+ print 'len(B) = %s' % len(_B)
+
+ session = request.getSession()
+ user.set_session(session)
+ user.set_server_verifier(svr)
+
+ # yep, this is tricky.
+ # some things are *already* unhexlified.
+ data = {
+ 'salt': user.salt,
+ 'B': _B,
+ 'errors': None}
+
+ return json.dumps(data)
+
+ def render_PUT(self, request):
+
+ # XXX check session???
+ user = get_user(request)
+
+ if not user:
+ print 'NO USER'
+ return json.dumps({'errors': 'no such user'})
+
+ data = request.content.read()
+ auth = data.split("client_auth=")
+ M = auth[1] if len(auth) > 1 else None
+ # if not H, return
+ if not M:
+ return json.dumps({'errors': 'no M proof passed by client'})
+
+ svr = user.svr
+ HAMK = svr.verify_session(binascii.unhexlify(M))
+ if HAMK is None:
+ print 'verification failed!!!'
+ raise Exception("Authentication failed!")
+ #import ipdb;ipdb.set_trace()
+
+ assert svr.authenticated()
+ print "***"
+ print 'server authenticated user SRP!'
+ print "***"
+
+ return json.dumps(
+ {'M2': binascii.hexlify(HAMK), 'errors': None})
+
+
+class API_Sessions(Resource):
+ def getChild(self, name, request):
+ return FakeSession(name)
+
+
+def get_certs_path():
+ script_path = os.path.realpath(os.path.dirname(sys.argv[0]))
+ certs_path = os.path.join(script_path, 'certs')
+ return certs_path
+
+
+def get_TLS_credentials():
+ # XXX this is giving errors
+ # XXX REview! We want to use gnutls!
+
+ cert = crypto.X509Certificate(
+ open(where('leaptestscert.pem')).read())
+ key = crypto.X509PrivateKey(
+ open(where('leaptestskey.pem')).read())
+ ca = crypto.X509Certificate(
+ open(where('cacert.pem')).read())
+ #crl = crypto.X509CRL(open(certs_path + '/crl.pem').read())
+ #cred = crypto.X509Credentials(cert, key, [ca], [crl])
+ cred = X509Credentials(cert, key, [ca])
+ cred.verify_peer = True
+ cred.session_params.compressions = (COMP_LZO, COMP_DEFLATE, COMP_NULL)
+ return cred
+
+
+class OpenSSLServerContextFactory:
+ # XXX workaround for broken TLS interface
+ # from gnuTLS.
+
+ def getContext(self):
+ """Create an SSL context.
+ This is a sample implementation that loads a certificate from a file
+ called 'server.pem'."""
+
+ ctx = SSL.Context(SSL.SSLv23_METHOD)
+ #certs_path = get_certs_path()
+ #ctx.use_certificate_file(certs_path + '/leaptestscert.pem')
+ #ctx.use_privatekey_file(certs_path + '/leaptestskey.pem')
+ ctx.use_certificate_file(where('leaptestscert.pem'))
+ ctx.use_privatekey_file(where('leaptestskey.pem'))
+ return ctx
+
+
+def serve_fake_provider():
+ root = Resource()
+ root.putChild("provider.json", File("./provider.json"))
+ config = Resource()
+ config.putChild(
+ "eip-service.json",
+ File("./eip-service.json"))
+ apiv1 = Resource()
+ apiv1.putChild("config", config)
+ apiv1.putChild("sessions.json", API_Sessions())
+ apiv1.putChild("users.json", FakeUsers(None))
+ apiv1.putChild("cert", File(get_certs_path() + '/openvpn.pem'))
+ root.putChild("1", apiv1)
+
+ cred = get_TLS_credentials()
+
+ factory = Site(root)
+
+ # regular http (for debugging with curl)
+ reactor.listenTCP(8000, factory)
+
+ # TLS with gnutls --- seems broken :(
+ #reactor.listenTLS(8003, factory, cred)
+
+ # OpenSSL
+ reactor.listenSSL(8443, factory, OpenSSLServerContextFactory())
+
+ reactor.run()
+
+
+if __name__ == "__main__":
+
+ from twisted.python import log
+ log.startLogging(sys.stdout)
+
+ serve_fake_provider()
diff --git a/src/leap/gui/firstrun/wizard.py b/src/leap/gui/firstrun/wizard.py
new file mode 100755
index 00000000..f198dca0
--- /dev/null
+++ b/src/leap/gui/firstrun/wizard.py
@@ -0,0 +1,309 @@
+#!/usr/bin/env python
+import logging
+
+import sip
+try:
+ sip.setapi('QString', 2)
+ sip.setapi('QVariant', 2)
+except ValueError:
+ pass
+
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+from leap.base import checks as basechecks
+from leap.crypto import leapkeyring
+from leap.eip import checks as eipchecks
+
+from leap.gui import firstrun
+
+from leap.gui import mainwindow_rc
+
+try:
+ from collections import OrderedDict
+except ImportError:
+ # We must be in 2.6
+ from leap.util.dicts import OrderedDict
+
+logger = logging.getLogger(__name__)
+
+"""
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+Work in progress!
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+This wizard still needs to be refactored out.
+
+TODO-ish:
+
+[X] Break file in wizard / pages files (and its own folder).
+[ ] Separate presentation from logic.
+[ ] Have a "manager" class for connections, that can be
+ dep-injected for testing.
+[ ] Document signals used / expected.
+[ ] Separate style from widgets.
+[ ] Fix TOFU Widget for provider cert.
+[X] Refactor widgets out.
+[ ] Follow more MVC style.
+[ ] Maybe separate "first run wizard" into different wizards
+ that share some of the pages?
+"""
+
+
+def get_pages_dict():
+ return OrderedDict((
+ ('intro', firstrun.intro.IntroPage),
+ ('providerselection',
+ firstrun.providerselect.SelectProviderPage),
+ ('login', firstrun.login.LogInPage),
+ ('providerinfo', firstrun.providerinfo.ProviderInfoPage),
+ ('providersetupvalidation',
+ firstrun.providersetup.ProviderSetupValidationPage),
+ ('signup', firstrun.register.RegisterUserPage),
+ ('connect',
+ firstrun.connect.ConnectionPage),
+ ('lastpage', firstrun.last.LastPage)
+ ))
+
+
+class FirstRunWizard(QtGui.QWizard):
+
+ def __init__(
+ self,
+ conductor_instance,
+ parent=None,
+ pages_dict=None,
+ username=None,
+ providers=None,
+ success_cb=None, is_provider_setup=False,
+ trusted_certs=None,
+ netchecker=basechecks.LeapNetworkChecker,
+ providercertchecker=eipchecks.ProviderCertChecker,
+ eipconfigchecker=eipchecks.EIPConfigChecker,
+ start_eipconnection_signal=None,
+ eip_statuschange_signal=None,
+ debug_server=None,
+ quitcallback=None):
+ super(FirstRunWizard, self).__init__(
+ parent,
+ QtCore.Qt.WindowStaysOnTopHint)
+
+ # we keep a reference to the conductor
+ # to be able to launch eip checks and connection
+ # in the connection page, before the wizard has ended.
+ self.conductor = conductor_instance
+
+ self.username = username
+ self.providers = providers
+
+ # success callback
+ self.success_cb = success_cb
+
+ # is provider setup?
+ self.is_provider_setup = is_provider_setup
+
+ # a dict with trusted fingerprints
+ # in the form {'nospacesfingerprint': ['host1', 'host2']}
+ self.trusted_certs = trusted_certs
+
+ # Checkers
+ self.netchecker = netchecker
+ self.providercertchecker = providercertchecker
+ self.eipconfigchecker = eipconfigchecker
+
+ # debug server
+ self.debug_server = debug_server
+
+ # Signals
+ # will be emitted in connecting page
+ self.start_eipconnection_signal = start_eipconnection_signal
+ self.eip_statuschange_signal = eip_statuschange_signal
+
+ if quitcallback is not None:
+ self.button(
+ QtGui.QWizard.CancelButton).clicked.connect(
+ quitcallback)
+
+ self.providerconfig = None
+ # previously registered
+ # if True, jumps to LogIn page.
+ # by setting 1st page??
+ #self.is_previously_registered = is_previously_registered
+ # XXX ??? ^v
+ self.is_previously_registered = bool(self.username)
+ self.from_login = False
+
+ pages_dict = pages_dict or get_pages_dict()
+ self.add_pages_from_dict(pages_dict)
+
+ self.validation_errors = {}
+ self.openvpn_status = []
+
+ self.setPixmap(
+ QtGui.QWizard.BannerPixmap,
+ QtGui.QPixmap(':/images/banner.png'))
+ self.setPixmap(
+ QtGui.QWizard.BackgroundPixmap,
+ QtGui.QPixmap(':/images/background.png'))
+
+ # set options
+ self.setOption(QtGui.QWizard.IndependentPages, on=False)
+ self.setOption(QtGui.QWizard.NoBackButtonOnStartPage, on=True)
+
+ self.setWindowTitle("First Run Wizard")
+
+ # TODO: set style for MAC / windows ...
+ #self.setWizardStyle()
+
+ #
+ # setup pages in wizard
+ #
+
+ def add_pages_from_dict(self, pages_dict):
+ """
+ @param pages_dict: the dictionary with pages, where
+ values are a tuple of InstanceofWizardPage, kwargs.
+ @type pages_dict: dict
+ """
+ for name, page in pages_dict.items():
+ # XXX check for is_previously registered
+ # and skip adding the signup branch if so
+ self.addPage(page())
+ self.pages_dict = pages_dict
+
+ def get_page_index(self, page_name):
+ """
+ returns the index of the given page
+ @param page_name: the name of the desired page
+ @type page_name: str
+ @rparam: index of page in wizard
+ @rtype: int
+ """
+ return self.pages_dict.keys().index(page_name)
+
+ #
+ # validation errors
+ #
+
+ def set_validation_error(self, pagename, error):
+ self.validation_errors[pagename] = error
+
+ def clean_validation_error(self, pagename):
+ vald = self.validation_errors
+ if pagename in vald:
+ del vald[pagename]
+
+ def get_validation_error(self, pagename):
+ return self.validation_errors.get(pagename, None)
+
+ def accept(self):
+ """
+ final step in the wizard.
+ gather the info, update settings
+ and call the success callback if any has been passed.
+ """
+ super(FirstRunWizard, self).accept()
+
+ # username and password are in different fields
+ # if they were stored in log_in or sign_up pages.
+ from_login = self.from_login
+ unamek_base = 'userName'
+ passwk_base = 'userPassword'
+ unamek = 'login_%s' % unamek_base if from_login else unamek_base
+ passwk = 'login_%s' % passwk_base if from_login else passwk_base
+
+ username = self.field(unamek)
+ password = self.field(passwk)
+ provider = self.field('provider_domain')
+ remember_pass = self.field('rememberPassword')
+
+ logger.debug('chosen provider: %s', provider)
+ logger.debug('username: %s', username)
+ logger.debug('remember password: %s', remember_pass)
+
+ # we are assuming here that we only remember one username
+ # in the form username@provider.domain
+ # We probably could extend this to support some form of
+ # profiles.
+
+ settings = QtCore.QSettings()
+
+ settings.setValue("FirstRunWizardDone", True)
+ settings.setValue("provider_domain", provider)
+ full_username = "%s@%s" % (username, provider)
+
+ settings.setValue("remember_user_and_pass", remember_pass)
+
+ if remember_pass:
+ settings.setValue("username", full_username)
+ seed = self.get_random_str(10)
+ settings.setValue("%s_seed" % provider, seed)
+
+ # XXX #744: comment out for 0.2.0 release
+ # if we need to have a version of python-keyring < 0.9
+ leapkeyring.leap_set_password(
+ full_username, password, seed=seed)
+
+ logger.debug('First Run Wizard Done.')
+ cb = self.success_cb
+ if cb and callable(cb):
+ self.success_cb()
+
+ # misc helpers
+
+ def get_random_str(self, n):
+ """
+ returns a random string
+ :param n: the length of the desired string
+ :rvalue: str
+ """
+ from string import (ascii_uppercase, ascii_lowercase, digits)
+ from random import choice
+ return ''.join(choice(
+ ascii_uppercase +
+ ascii_lowercase +
+ digits) for x in range(n))
+
+ def set_providerconfig(self, providerconfig):
+ """
+ sets a providerconfig attribute
+ used when we fetch and parse a json configuration
+ """
+ self.providerconfig = providerconfig
+
+ def get_provider_by_index(self): # pragma: no cover
+ """
+ returns the value of a provider given its index.
+ this was used in the select provider page,
+ in the case where we were preseeding providers in a combobox
+ """
+ # Leaving it here for the moment when we go back at the
+ # option of preseeding with known provider values.
+ provider = self.field('provider_index')
+ return self.providers[provider]
+
+
+if __name__ == '__main__':
+ # standalone test
+ # it can be (somehow) run against
+ # gui/tests/integration/fake_user_signup.py
+
+ import sys
+ import logging
+ logging.basicConfig()
+ logger = logging.getLogger()
+ logger.setLevel(logging.DEBUG)
+
+ app = QtGui.QApplication(sys.argv)
+ server = sys.argv[1] if len(sys.argv) > 1 else None
+
+ trusted_certs = {
+ "3DF83F316BFA0186"
+ "0A11A5C9C7FC24B9"
+ "18C62B941192CC1A"
+ "49AE62218B2A4B7C": ['springbok']}
+
+ wizard = FirstRunWizard(
+ None, trusted_certs=trusted_certs,
+ debug_server=server)
+ wizard.show()
+ sys.exit(app.exec_())
diff --git a/src/leap/gui/firstrunwizard.py b/src/leap/gui/firstrunwizard.py
deleted file mode 100755
index a76865fd..00000000
--- a/src/leap/gui/firstrunwizard.py
+++ /dev/null
@@ -1,507 +0,0 @@
-#!/usr/bin/env python
-import logging
-import json
-import socket
-
-import sip
-sip.setapi('QString', 2)
-sip.setapi('QVariant', 2)
-
-from PyQt4 import QtCore
-from PyQt4 import QtGui
-
-from leap.crypto import leapkeyring
-from leap.gui import mainwindow_rc
-
-logger = logging.getLogger(__name__)
-
-APP_LOGO = ':/images/leap-color-small.png'
-
-# registration ######################
-# move to base/
-import binascii
-
-import requests
-import srp
-
-from leap.base import constants as baseconstants
-
-SIGNUP_TIMEOUT = getattr(baseconstants, 'SIGNUP_TIMEOUT', 5)
-
-
-class LeapSRPRegister(object):
-
- def __init__(self,
- schema="https",
- provider=None,
- port=None,
- register_path="1/users.json",
- method="POST",
- fetcher=requests,
- srp=srp,
- hashfun=srp.SHA256,
- ng_constant=srp.NG_1024):
-
- self.schema = schema
- self.provider = provider
- self.port = port
- self.register_path = register_path
- self.method = method
- self.fetcher = fetcher
- self.srp = srp
- self.HASHFUN = hashfun
- self.NG = ng_constant
-
- self.init_session()
-
- def init_session(self):
- self.session = self.fetcher.session()
-
- def get_registration_uri(self):
- # XXX assert is https!
- # use urlparse
- if self.port:
- uri = "%s://%s:%s/%s" % (
- self.schema,
- self.provider,
- self.port,
- self.register_path)
- else:
- uri = "%s://%s/%s" % (
- self.schema,
- self.provider,
- self.register_path)
-
- return uri
-
- def register_user(self, username, password, keep=False):
- """
- @rtype: tuple
- @rvalue: (ok, request)
- """
- salt, vkey = self.srp.create_salted_verification_key(
- username,
- password,
- self.HASHFUN,
- self.NG)
-
- user_data = {
- 'user[login]': username,
- 'user[password_verifier]': binascii.hexlify(vkey),
- 'user[password_salt]': binascii.hexlify(salt)}
-
- uri = self.get_registration_uri()
- logger.debug('post to uri: %s' % uri)
-
- # XXX get self.method
- req = self.session.post(
- uri, data=user_data,
- timeout=SIGNUP_TIMEOUT)
- logger.debug(req)
- logger.debug('user_data: %s', user_data)
- #logger.debug('response: %s', req.text)
- # we catch it in the form
- #req.raise_for_status()
- return (req.ok, req)
-
-######################################
-
-ErrorLabelStyleSheet = """
-QLabel { color: red;
- font-weight: bold}
-"""
-
-
-class FirstRunWizard(QtGui.QWizard):
-
- def __init__(
- self, parent=None, providers=None,
- success_cb=None):
- super(FirstRunWizard, self).__init__(
- parent,
- QtCore.Qt.WindowStaysOnTopHint)
-
- # XXX hardcoded for tests
- if not providers:
- providers = ('springbok',)
- self.providers = providers
-
- # success callback
- self.success_cb = success_cb
-
- self.addPage(IntroPage())
- self.addPage(SelectProviderPage(providers=providers))
-
- self.addPage(RegisterUserPage(wizard=self))
- #self.addPage(GlobalEIPSettings())
- self.addPage(LastPage())
-
- self.setPixmap(
- QtGui.QWizard.BannerPixmap,
- QtGui.QPixmap(':/images/banner.png'))
- self.setPixmap(
- QtGui.QWizard.BackgroundPixmap,
- QtGui.QPixmap(':/images/background.png'))
-
- self.setWindowTitle("First Run Wizard")
-
- # TODO: set style for MAC / windows ...
- #self.setWizardStyle()
-
- def setWindowFlags(self, flags):
- logger.debug('setting window flags')
- QtGui.QWizard.setWindowFlags(self, flags)
-
- def focusOutEvent(self, event):
- # needed ?
- self.setFocus(True)
- self.activateWindow()
- self.raise_()
- self.show()
-
- def accept(self):
- """
- final step in the wizard.
- gather the info, update settings
- and call the success callback.
- """
- provider = self.get_provider()
- username = self.field('userName')
- #password = self.field('userPassword')
- remember_pass = self.field('rememberPassword')
-
- logger.debug('chosen provider: %s', provider)
- logger.debug('username: %s', username)
- logger.debug('remember password: %s', remember_pass)
- super(FirstRunWizard, self).accept()
-
- settings = QtCore.QSettings()
- settings.setValue("FirstRunWizardDone", True)
- settings.setValue(
- "eip_%s_username" % provider,
- username)
- settings.setValue("%s_remember_pass" % provider, remember_pass)
-
- seed = self.get_random_str(10)
- settings.setValue("%s_seed" % provider, seed)
-
- # Commenting out for 0.2.0 release
- # since we did not fix #744 on time.
-
- #leapkeyring.leap_set_password(username, password, seed=seed)
-
- logger.debug('First Run Wizard Done.')
- cb = self.success_cb
- if cb and callable(cb):
- self.success_cb()
-
- def get_provider(self):
- provider = self.field('provider_index')
- return self.providers[provider]
-
- def get_random_str(self, n):
- from string import (ascii_uppercase, ascii_lowercase, digits)
- from random import choice
- return ''.join(choice(
- ascii_uppercase +
- ascii_lowercase +
- digits) for x in range(n))
-
-
-class IntroPage(QtGui.QWizardPage):
- def __init__(self, parent=None):
- super(IntroPage, self).__init__(parent)
-
- self.setTitle("First run wizard.")
-
- #self.setPixmap(
- #QtGui.QWizard.WatermarkPixmap,
- #QtGui.QPixmap(':/images/watermark1.png'))
-
- label = QtGui.QLabel(
- "Now we will guide you through "
- "some configuration that is needed before you "
- "can connect for the first time.<br><br>"
- "If you ever need to modify these options again, "
- "you can find the wizard in the '<i>Settings</i>' menu from the "
- "main window of the Leap App.")
-
- label.setWordWrap(True)
-
- layout = QtGui.QVBoxLayout()
- layout.addWidget(label)
- self.setLayout(layout)
-
-
-class SelectProviderPage(QtGui.QWizardPage):
- def __init__(self, parent=None, providers=None):
- super(SelectProviderPage, self).__init__(parent)
-
- self.setTitle("Select Provider")
- self.setSubTitle(
- "Please select which provider do you want "
- "to use for your connection."
- )
- self.setPixmap(
- QtGui.QWizard.LogoPixmap,
- QtGui.QPixmap(APP_LOGO))
-
- providerNameLabel = QtGui.QLabel("&Provider:")
-
- providercombo = QtGui.QComboBox()
- if providers:
- for provider in providers:
- providercombo.addItem(provider)
- providerNameSelect = providercombo
-
- providerNameLabel.setBuddy(providerNameSelect)
-
- self.registerField('provider_index', providerNameSelect)
-
- layout = QtGui.QGridLayout()
- layout.addWidget(providerNameLabel, 0, 0)
- layout.addWidget(providerNameSelect, 0, 1)
- self.setLayout(layout)
-
-
-class RegisterUserPage(QtGui.QWizardPage):
- setSigningUpStatus = QtCore.pyqtSignal([])
-
- def __init__(self, parent=None, wizard=None):
- super(RegisterUserPage, self).__init__(parent)
-
- # bind wizard page signals
- self.setSigningUpStatus.connect(
- self.set_status_validating)
-
- # XXX check for no wizard pased
- # getting provider from previous step
- provider = wizard.get_provider()
-
- self.setTitle("User registration")
- self.setSubTitle(
- "Register a new user with provider %s." %
- provider)
- self.setPixmap(
- QtGui.QWizard.LogoPixmap,
- QtGui.QPixmap(APP_LOGO))
-
- rememberPasswordCheckBox = QtGui.QCheckBox(
- "&Remember password.")
- rememberPasswordCheckBox.setChecked(True)
-
- userNameLabel = QtGui.QLabel("User &name:")
- userNameLineEdit = QtGui.QLineEdit()
- userNameLineEdit.cursorPositionChanged.connect(
- self.reset_validation_status)
- userNameLabel.setBuddy(userNameLineEdit)
-
- # add regex validator
- usernameRe = QtCore.QRegExp(r"^[A-Za-z\d_]+$")
- userNameLineEdit.setValidator(
- QtGui.QRegExpValidator(usernameRe, self))
- self.userNameLineEdit = userNameLineEdit
-
- userPasswordLabel = QtGui.QLabel("&Password:")
- self.userPasswordLineEdit = QtGui.QLineEdit()
- self.userPasswordLineEdit.setEchoMode(
- QtGui.QLineEdit.Password)
-
- userPasswordLabel.setBuddy(self.userPasswordLineEdit)
-
- self.registerField('userName', self.userNameLineEdit)
- self.registerField('userPassword', self.userPasswordLineEdit)
- self.registerField('rememberPassword', rememberPasswordCheckBox)
-
- layout = QtGui.QGridLayout()
- layout.setColumnMinimumWidth(0, 20)
-
- validationMsg = QtGui.QLabel("")
- validationMsg.setStyleSheet(ErrorLabelStyleSheet)
-
- self.validationMsg = validationMsg
-
- layout.addWidget(validationMsg, 0, 3)
-
- layout.addWidget(userNameLabel, 1, 0)
- layout.addWidget(self.userNameLineEdit, 1, 3)
-
- layout.addWidget(userPasswordLabel, 2, 0)
- layout.addWidget(self.userPasswordLineEdit, 2, 3)
-
- layout.addWidget(rememberPasswordCheckBox, 3, 3, 3, 4)
- self.setLayout(layout)
-
- def reset_validation_status(self):
- """
- empty the validation msg
- """
- self.validationMsg.setText('')
-
- def set_status_validating(self):
- """
- set validation msg to 'registering...'
- """
- # XXX this is NOT WORKING.
- # My guess is that, even if we are using
- # signals to trigger this, it does
- # not show until the validate function
- # returns.
- # I guess it is because there is no delay...
- logger.debug('registering........')
- self.validationMsg.setText('registering...')
- # need to call update somehow???
-
- def set_status_invalid_username(self):
- """
- set validation msg to
- not available user
- """
- self.validationMsg.setText('Username not available.')
-
- def set_status_server_500(self):
- """
- set validation msg to
- internal server error
- """
- self.validationMsg.setText("Error during registration (500)")
-
- def set_status_timeout(self):
- """
- set validation msg to
- timeout
- """
- self.validationMsg.setText("Error connecting to provider (timeout)")
-
- def set_status_connerror(self):
- """
- set validation msg to
- connection refused
- """
- self.validationMsg.setText(
- "Error connecting to provider "
- "(connection error)")
-
- def set_status_unknown_error(self):
- """
- set validation msg to
- unknown error
- """
- self.validationMsg.setText("Error during signup")
-
- # overwritten methods
-
- def initializePage(self):
- """
- inits wizard page
- """
- self.validationMsg.setText('')
-
- def validatePage(self):
- """
- validation
- we initialize the srp protocol register
- and try to register user. if error
- returned we write validation error msg
- above the form.
- """
- # the slot for this signal is not doing
- # what's expected. Investigate why,
- # right now we're not giving any feedback
- # to the user re. what's going on. The only
- # thing I can see as a workaround is setting
- # a low timeout.
- self.setSigningUpStatus.emit()
-
- username = self.userNameLineEdit.text()
- password = self.userPasswordLineEdit.text()
-
- # XXX TODO -- remove debug info
- # XXX get from provider info
- # XXX enforce https
- # and pass a verify value
-
- signup = LeapSRPRegister(
- schema="http",
- provider="springbok",
-
- # debug -----
- #provider="localhost",
- #register_path="timeout",
- #port=8000
- )
- try:
- ok, req = signup.register_user(username, password)
- except socket.timeout:
- self.set_status_timeout()
- return False
-
- except requests.exceptions.ConnectionError as exc:
- logger.error(exc)
- self.set_status_connerror()
- return False
-
- if ok:
- return True
-
- # something went wrong.
- # not registered, let's catch what.
- # get timeout
- # ...
- if req.status_code == 500:
- self.set_status_server_500()
- return False
-
- validation_msgs = json.loads(req.content)
- logger.debug('validation errors: %s' % validation_msgs)
- errors = validation_msgs.get('errors', None)
- if errors and errors.get('login', None):
- self.set_status_invalid_username()
- else:
- self.set_status_unknown_error()
- return False
-
-
-class GlobalEIPSettings(QtGui.QWizardPage):
- def __init__(self, parent=None):
- super(GlobalEIPSettings, self).__init__(parent)
-
-
-class LastPage(QtGui.QWizardPage):
- def __init__(self, parent=None):
- super(LastPage, self).__init__(parent)
-
- self.setTitle("Ready to go!")
-
- #self.setPixmap(
- #QtGui.QWizard.WatermarkPixmap,
- #QtGui.QPixmap(':/images/watermark2.png'))
-
- self.label = QtGui.QLabel()
- self.label.setWordWrap(True)
-
- layout = QtGui.QVBoxLayout()
- layout.addWidget(self.label)
- self.setLayout(layout)
-
- def initializePage(self):
- finishText = self.wizard().buttonText(
- QtGui.QWizard.FinishButton)
- finishText = finishText.replace('&', '')
- self.label.setText(
- "Click '<i>%s</i>' to end the wizard and start "
- "encrypting your connection." % finishText)
-
-
-if __name__ == '__main__':
- # standalone test
- import sys
- import logging
- logging.basicConfig()
- logger = logging.getLogger()
- logger.setLevel(logging.DEBUG)
-
- app = QtGui.QApplication(sys.argv)
- wizard = FirstRunWizard()
- wizard.show()
- sys.exit(app.exec_())
diff --git a/src/leap/gui/locale_rc.py b/src/leap/gui/locale_rc.py
new file mode 100644
index 00000000..f165ff8e
--- /dev/null
+++ b/src/leap/gui/locale_rc.py
@@ -0,0 +1,132 @@
+# -*- coding: utf-8 -*-
+
+# Resource object code
+#
+# Created: vie nov 16 22:33:33 2012
+# by: The Resource Compiler for PyQt (Qt v4.8.2)
+#
+# WARNING! All changes made in this file will be lost!
+
+from PyQt4 import QtCore
+
+qt_resource_data = "\
+\x00\x00\x05\xaa\
+\x3c\
+\xb8\x64\x18\xca\xef\x9c\x95\xcd\x21\x1c\xbf\x60\xa1\xbd\xdd\x42\
+\x00\x00\x00\x20\x09\xfc\x2c\x8e\x00\x00\x04\xfb\x0a\x74\xb8\x1e\
+\x00\x00\x00\xd6\x0a\xfd\x99\xfe\x00\x00\x00\x51\x0c\x44\x41\xbe\
+\x00\x00\x00\x00\x69\x00\x00\x05\x69\x03\x00\x00\x00\x22\x00\x50\
+\x00\x72\x00\x69\x00\x6d\x00\x65\x00\x72\x00\x61\x00\x20\x00\x63\
+\x00\x6f\x00\x6e\x00\x65\x00\x78\x00\x69\x00\x6f\x00\x6e\x00\x2e\
+\x08\x00\x00\x00\x00\x06\x00\x00\x00\x11\x46\x69\x72\x73\x74\x20\
+\x72\x75\x6e\x20\x77\x69\x7a\x61\x72\x64\x2e\x07\x00\x00\x00\x09\
+\x49\x6e\x74\x72\x6f\x50\x61\x67\x65\x01\x03\x00\x00\x00\x4c\x00\
+\x4c\x00\x6f\x00\x67\x00\x75\x00\x65\x00\x61\x00\x72\x00\x6d\x00\
+\x65\x00\x20\x00\x63\x00\x6f\x00\x6e\x00\x20\x00\x6d\x00\x69\x00\
+\x20\x00\x75\x00\x73\x00\x75\x00\x61\x00\x72\x00\x69\x00\x6f\x00\
+\x20\x00\x79\x00\x20\x00\x63\x00\x6f\x00\x6e\x00\x74\x00\x72\x00\
+\x61\x00\x73\x00\x65\x00\x6e\x00\x61\x00\x2e\x08\x00\x00\x00\x00\
+\x06\x00\x00\x00\x1b\x4c\x6f\x67\x20\x49\x6e\x20\x77\x69\x74\x68\
+\x20\x6d\x79\x20\x63\x72\x65\x64\x65\x6e\x74\x69\x61\x6c\x73\x2e\
+\x07\x00\x00\x00\x09\x49\x6e\x74\x72\x6f\x50\x61\x67\x65\x01\x03\
+\x00\x00\x02\xaa\x00\x56\x00\x61\x00\x6d\x00\x6f\x00\x73\x00\x20\
+\x00\x61\x00\x20\x00\x72\x00\x65\x00\x75\x00\x6e\x00\x69\x00\x72\
+\x00\x20\x00\x6c\x00\x61\x00\x20\x00\x69\x00\x6e\x00\x66\x00\x6f\
+\x00\x72\x00\x6d\x00\x61\x00\x63\x00\x69\x00\x6f\x00\x6e\x00\x20\
+\x00\x71\x00\x75\x00\x65\x00\x20\x00\x6e\x00\x65\x00\x63\x00\x65\
+\x00\x73\x00\x69\x00\x74\x00\x61\x00\x73\x00\x20\x00\x61\x00\x6e\
+\x00\x74\x00\x65\x00\x73\x00\x20\x00\x64\x00\x65\x00\x20\x00\x6c\
+\x00\x61\x00\x20\x00\x70\x00\x72\x00\x69\x00\x6d\x00\x65\x00\x72\
+\x00\x61\x00\x20\x00\x63\x00\x6f\x00\x6e\x00\x65\x00\x78\x00\x69\
+\x00\x6f\x00\x6e\x00\x2e\x00\x3c\x00\x62\x00\x72\x00\x3e\x00\x3c\
+\x00\x62\x00\x72\x00\x3e\x00\x53\x00\x69\x00\x20\x00\x61\x00\x6c\
+\x00\x67\x00\x75\x00\x6e\x00\x61\x00\x20\x00\x76\x00\x65\x00\x7a\
+\x00\x20\x00\x6e\x00\x65\x00\x63\x00\x65\x00\x73\x00\x69\x00\x74\
+\x00\x61\x00\x73\x00\x20\x00\x6d\x00\x6f\x00\x64\x00\x69\x00\x66\
+\x00\x69\x00\x63\x00\x61\x00\x72\x00\x20\x00\x65\x00\x73\x00\x74\
+\x00\x61\x00\x73\x00\x20\x00\x6f\x00\x70\x00\x63\x00\x69\x00\x6f\
+\x00\x6e\x00\x65\x00\x73\x00\x20\x00\x64\x00\x65\x00\x20\x00\x6e\
+\x00\x75\x00\x65\x00\x76\x00\x6f\x00\x2c\x00\x20\x00\x70\x00\x75\
+\x00\x65\x00\x64\x00\x65\x00\x73\x00\x20\x00\x65\x00\x6e\x00\x63\
+\x00\x6f\x00\x6e\x00\x74\x00\x72\x00\x61\x00\x72\x00\x20\x00\x65\
+\x00\x73\x00\x74\x00\x65\x00\x20\x00\x61\x00\x73\x00\x69\x00\x73\
+\x00\x74\x00\x65\x00\x6e\x00\x74\x00\x65\x00\x20\x00\x65\x00\x6e\
+\x00\x20\x00\x65\x00\x6c\x00\x20\x00\x6d\x00\x65\x00\x6e\x00\x75\
+\x00\x20\x00\x3c\x00\x69\x00\x3e\x00\x4f\x00\x70\x00\x63\x00\x69\
+\x00\x6f\x00\x6e\x00\x65\x00\x73\x00\x3c\x00\x2f\x00\x69\x00\x3e\
+\x00\x20\x00\x65\x00\x6e\x00\x20\x00\x6c\x00\x61\x00\x20\x00\x76\
+\x00\x65\x00\x6e\x00\x74\x00\x61\x00\x6e\x00\x61\x00\x20\x00\x70\
+\x00\x72\x00\x69\x00\x6e\x00\x63\x00\x69\x00\x70\x00\x61\x00\x6c\
+\x00\x2e\x00\x3c\x00\x62\x00\x72\x00\x3e\x00\x3c\x00\x62\x00\x72\
+\x00\x3e\x00\x51\x00\x75\x00\x65\x00\x20\x00\x64\x00\x65\x00\x73\
+\x00\x65\x00\x61\x00\x73\x00\x20\x00\x68\x00\x61\x00\x63\x00\x65\
+\x00\x72\x00\x20\x00\x61\x00\x68\x00\x6f\x00\x72\x00\x61\x00\x3f\
+\x00\x20\x00\x50\x00\x75\x00\x65\x00\x64\x00\x65\x00\x73\x00\x20\
+\x00\x3c\x00\x62\x00\x3e\x00\x72\x00\x65\x00\x67\x00\x69\x00\x73\
+\x00\x74\x00\x72\x00\x61\x00\x72\x00\x3c\x00\x2f\x00\x62\x00\x3e\
+\x00\x20\x00\x75\x00\x6e\x00\x61\x00\x20\x00\x6e\x00\x75\x00\x65\
+\x00\x76\x00\x61\x00\x20\x00\x63\x00\x75\x00\x65\x00\x6e\x00\x74\
+\x00\x61\x00\x20\x00\x6f\x00\x20\x00\x3c\x00\x62\x00\x3e\x00\x6c\
+\x00\x6f\x00\x67\x00\x75\x00\x65\x00\x61\x00\x72\x00\x74\x00\x65\
+\x00\x3c\x00\x2f\x00\x62\x00\x3e\x00\x20\x00\x63\x00\x6f\x00\x6e\
+\x00\x20\x00\x75\x00\x6e\x00\x61\x00\x20\x00\x71\x00\x75\x00\x65\
+\x00\x20\x00\x79\x00\x61\x00\x20\x00\x74\x00\x69\x00\x65\x00\x6e\
+\x00\x65\x00\x73\x00\x3f\x00\x3c\x00\x62\x00\x72\x00\x3e\x08\x00\
+\x00\x00\x00\x06\x00\x00\x01\x5d\x4e\x6f\x77\x20\x77\x65\x20\x77\
+\x69\x6c\x6c\x20\x67\x75\x69\x64\x65\x20\x79\x6f\x75\x20\x74\x68\
+\x72\x6f\x75\x67\x68\x20\x73\x6f\x6d\x65\x20\x63\x6f\x6e\x66\x69\
+\x67\x75\x72\x61\x74\x69\x6f\x6e\x20\x74\x68\x61\x74\x20\x69\x73\
+\x20\x6e\x65\x65\x64\x65\x64\x20\x62\x65\x66\x6f\x72\x65\x20\x79\
+\x6f\x75\x20\x63\x61\x6e\x20\x63\x6f\x6e\x6e\x65\x63\x74\x20\x66\
+\x6f\x72\x20\x74\x68\x65\x20\x66\x69\x72\x73\x74\x20\x74\x69\x6d\
+\x65\x2e\x3c\x62\x72\x3e\x3c\x62\x72\x3e\x49\x66\x20\x79\x6f\x75\
+\x20\x65\x76\x65\x72\x20\x6e\x65\x65\x64\x20\x74\x6f\x20\x6d\x6f\
+\x64\x69\x66\x79\x20\x74\x68\x65\x73\x65\x20\x6f\x70\x74\x69\x6f\
+\x6e\x73\x20\x61\x67\x61\x69\x6e\x2c\x20\x79\x6f\x75\x20\x63\x61\
+\x6e\x20\x66\x69\x6e\x64\x20\x74\x68\x65\x20\x77\x69\x7a\x61\x72\
+\x64\x20\x69\x6e\x20\x74\x68\x65\x20\x27\x3c\x69\x3e\x53\x65\x74\
+\x74\x69\x6e\x67\x73\x3c\x2f\x69\x3e\x27\x20\x6d\x65\x6e\x75\x20\
+\x66\x72\x6f\x6d\x20\x74\x68\x65\x20\x6d\x61\x69\x6e\x20\x77\x69\
+\x6e\x64\x6f\x77\x2e\x3c\x62\x72\x3e\x3c\x62\x72\x3e\x44\x6f\x20\
+\x79\x6f\x75\x20\x77\x61\x6e\x74\x20\x74\x6f\x20\x3c\x62\x3e\x73\
+\x69\x67\x6e\x20\x75\x70\x3c\x2f\x62\x3e\x20\x66\x6f\x72\x20\x61\
+\x20\x6e\x65\x77\x20\x61\x63\x63\x6f\x75\x6e\x74\x2c\x20\x6f\x72\
+\x20\x3c\x62\x3e\x6c\x6f\x67\x20\x69\x6e\x3c\x2f\x62\x3e\x20\x77\
+\x69\x74\x68\x20\x61\x6e\x20\x61\x6c\x72\x65\x61\x64\x79\x20\x65\
+\x78\x69\x73\x74\x69\x6e\x67\x20\x75\x73\x65\x72\x6e\x61\x6d\x65\
+\x3f\x3c\x62\x72\x3e\x07\x00\x00\x00\x09\x49\x6e\x74\x72\x6f\x50\
+\x61\x67\x65\x01\x03\x00\x00\x00\x36\x00\x52\x00\x65\x00\x67\x00\
+\x69\x00\x73\x00\x74\x00\x72\x00\x61\x00\x72\x00\x20\x00\x75\x00\
+\x6e\x00\x61\x00\x20\x00\x63\x00\x75\x00\x65\x00\x6e\x00\x74\x00\
+\x61\x00\x20\x00\x6e\x00\x75\x00\x65\x00\x76\x00\x61\x00\x2e\x08\
+\x00\x00\x00\x00\x06\x00\x00\x00\x1a\x53\x69\x67\x6e\x20\x75\x70\
+\x20\x66\x6f\x72\x20\x61\x20\x6e\x65\x77\x20\x61\x63\x63\x6f\x75\
+\x6e\x74\x2e\x07\x00\x00\x00\x09\x49\x6e\x74\x72\x6f\x50\x61\x67\
+\x65\x01\x88\x00\x00\x00\x02\x01\x01\
+"
+
+qt_resource_name = "\
+\x00\x0c\
+\x0d\xfc\x11\x13\
+\x00\x74\
+\x00\x72\x00\x61\x00\x6e\x00\x73\x00\x6c\x00\x61\x00\x74\x00\x69\x00\x6f\x00\x6e\x00\x73\
+\x00\x14\
+\x08\xa9\x0f\x1d\
+\x00\x6c\
+\x00\x65\x00\x61\x00\x70\x00\x5f\x00\x63\x00\x6c\x00\x69\x00\x65\x00\x6e\x00\x74\x00\x5f\x00\x65\x00\x73\x00\x5f\x00\x45\x00\x53\
+\x00\x2e\x00\x71\x00\x6d\
+"
+
+qt_resource_struct = "\
+\x00\x00\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x01\
+\x00\x00\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x02\
+\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\
+"
+
+def qInitResources():
+ QtCore.qRegisterResourceData(0x01, qt_resource_struct, qt_resource_name, qt_resource_data)
+
+def qCleanupResources():
+ QtCore.qUnregisterResourceData(0x01, qt_resource_struct, qt_resource_name, qt_resource_data)
+
+qInitResources()
diff --git a/src/leap/gui/mainwindow_rc.py b/src/leap/gui/mainwindow_rc.py
index be575159..5bee35c7 100644
--- a/src/leap/gui/mainwindow_rc.py
+++ b/src/leap/gui/mainwindow_rc.py
@@ -2,7 +2,7 @@
# Resource object code
#
-# Created: Thu Sep 13 16:12:58 2012
+# Created: Wed Nov 21 04:25:36 2012
# by: The Resource Compiler for PyQt (Qt v4.8.2)
#
# WARNING! All changes made in this file will be lost!
@@ -236,6 +236,87 @@ qt_resource_data = "\
\x71\xa4\x40\xda\x14\x7a\xd1\x73\x1f\xf4\x7f\xb7\xf9\x1f\xc2\x26\
\x56\xd5\x70\x45\xfc\x8a\x00\x00\x00\x00\x49\x45\x4e\x44\xae\x42\
\x60\x82\
+\x00\x00\x04\xec\
+\x89\
+\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\
+\x00\x00\x18\x00\x00\x00\x18\x08\x06\x00\x00\x00\xe0\x77\x3d\xf8\
+\x00\x00\x00\x04\x73\x42\x49\x54\x08\x08\x08\x08\x7c\x08\x64\x88\
+\x00\x00\x00\x09\x70\x48\x59\x73\x00\x00\x06\xec\x00\x00\x06\xec\
+\x01\x1e\x75\x38\x35\x00\x00\x00\x19\x74\x45\x58\x74\x53\x6f\x66\
+\x74\x77\x61\x72\x65\x00\x77\x77\x77\x2e\x69\x6e\x6b\x73\x63\x61\
+\x70\x65\x2e\x6f\x72\x67\x9b\xee\x3c\x1a\x00\x00\x00\x13\x74\x45\
+\x58\x74\x41\x75\x74\x68\x6f\x72\x00\x52\x6f\x64\x6e\x65\x79\x20\
+\x44\x61\x77\x65\x73\x0e\xd8\x7e\x1d\x00\x00\x04\x4a\x49\x44\x41\
+\x54\x48\x89\x8d\x96\x5d\x6c\x53\x65\x18\xc7\x7f\xef\x39\x6b\xbb\
+\x7e\x9c\x75\x65\xad\x2b\x9b\xfb\xd0\x31\xdd\x14\xb6\x8c\x19\x44\
+\x90\x44\x63\x82\x42\x88\x5e\x90\x98\xcc\x19\x15\x13\xd4\x18\x76\
+\x61\xd4\x18\xe3\x85\x57\xca\x05\xe1\xc2\x0c\xa3\xa8\x51\xd0\x4c\
+\x12\xe3\x85\x31\x80\x26\x6a\xe2\x85\x23\xb0\x38\xb6\xc1\x1c\xce\
+\xb1\x40\x59\xf6\xe5\xca\xda\xae\xed\xfa\x75\x7a\x5e\x2f\x4e\xd7\
+\x59\xd6\x32\xfe\xc9\x7b\xf3\x9e\xe7\xf9\xff\x9f\xe7\xff\x9e\xf3\
+\x9c\x57\x48\x29\x59\x0f\xbd\x7b\x85\x0d\x17\xed\x1e\xbb\xb2\x07\
+\x20\x94\x30\x7e\x22\xc6\x48\xcf\x59\x99\x5a\x2f\x57\x94\x12\xf8\
+\xec\x55\x61\x71\x65\x6d\x47\xfc\xbe\xda\x47\x9d\x5a\xa5\xbf\xda\
+\x69\xaf\xda\xe0\x28\x2f\x07\x58\x5c\x4e\x26\xe7\xe3\x89\x9b\xf1\
+\x68\x78\x6e\x6e\x61\xfa\x8f\x98\x9a\x7a\xfb\x95\xe3\x32\x73\xc7\
+\x02\x9f\x76\x89\x8e\xba\xda\xda\x2f\xb7\x37\xdf\xdf\xe6\x2a\x13\
+\x8a\x94\x06\x82\xc2\x38\x89\x40\x08\x85\x98\x2e\x8d\xf3\x13\xe3\
+\x97\xa6\xa6\xa7\x5f\x7e\xed\x94\x1c\x5a\x57\xa0\xef\xa0\xfd\x70\
+\x5b\xf3\x96\x03\xcd\xde\x8a\x6a\x61\x64\xd7\x73\xc0\x14\x53\x54\
+\x26\x82\x4b\xf3\x97\x26\x2e\x7f\xd5\xfd\x79\xe2\xdd\x92\x02\x27\
+\x5f\x2a\x7b\xe1\x89\xce\x1d\xc7\xbc\x76\x55\x13\xc5\x98\xac\x4e\
+\x10\x0a\xa4\xa2\x6b\x45\x80\x60\x22\x1b\xfd\x6d\xf0\xdc\xa1\x17\
+\x4f\xe8\x5f\xaf\x11\x38\xfa\x9c\xf0\x6e\xdb\xf4\xc0\xf9\x6d\xf5\
+\xfe\x26\x30\xf2\x89\xca\xc6\x76\xd4\x07\xf7\xa3\xd4\x74\x80\xd5\
+\x65\x6e\xa6\xe3\x64\x03\xfd\x64\x2f\x9e\x40\x46\x67\xff\x27\xa3\
+\x30\x70\x63\x6e\x72\xe0\xea\xd8\xf6\x37\xbf\x95\x41\x73\x27\x87\
+\x06\x8f\xa7\x6f\x6b\x7d\x4d\x01\x39\x80\x52\xff\x08\x4a\xe3\xae\
+\x55\xf2\x5c\x27\x6a\xf3\x6e\x2c\x7b\x8f\x9a\x5d\xe5\x61\xb0\xb5\
+\xbe\xa6\xa9\xc1\xe3\xe9\x5b\x95\x04\x7a\xbb\x44\x47\x5b\x53\xcb\
+\x4e\x15\xbd\x98\x31\xc8\x70\x00\xfd\xfc\xc7\x64\xce\xbc\x81\x7e\
+\xe1\x13\xc8\x75\x2d\xb4\x8d\x28\xb5\x0f\x15\xc4\xaa\xe8\xb4\x35\
+\xb5\xec\xec\xed\x12\x1d\x00\x65\x00\xee\x72\x65\x9f\x5f\x73\x38\
+\x05\x6b\x0f\x35\x3b\xf6\x03\xfa\xc0\xf1\x3c\x29\xb3\xc3\xa8\xf7\
+\x3e\x8e\xf0\xb5\x98\x22\xf6\x0d\x05\xf1\x02\xf0\x6b\x0e\xa7\xbb\
+\x5c\xd9\x07\x0c\x29\x00\x9a\xc3\xd5\x69\x55\xd5\xe2\xd5\x47\xe7\
+\x56\xc9\x01\xe1\xbe\x1b\xe1\xb9\x67\xf5\x79\x70\x7c\x4d\x8e\x55\
+\x55\xd1\x1c\xae\xce\xbc\x45\x15\x6e\x5f\x9d\x90\xc5\xed\x29\xa8\
+\xae\xa2\x06\xcb\x53\x47\xa0\xcc\x66\x76\x37\xfa\x3d\xd9\xa9\x81\
+\xb5\x71\x52\xa7\xc2\xed\xab\x83\x9c\x45\x76\xbb\x56\x25\xa5\xa4\
+\xe8\xab\xb9\x02\x9b\x86\x65\xf7\x87\x08\xcd\x6f\x92\x8f\x9f\x21\
+\xf5\xdd\xf3\xa0\xa7\x10\xe5\x6e\x44\x45\x2d\x38\x7d\x08\x21\x90\
+\xd2\xe4\xcc\x0b\x24\x12\xd1\x9b\x42\xbd\xab\x81\x6c\xba\x28\xb7\
+\x94\x06\x65\xcd\x4f\x22\x2a\x1b\x00\x30\xa6\xff\x24\xd5\xb7\x1f\
+\x74\x73\x14\xc9\x64\x04\x99\x8c\x80\xc5\x8e\xe2\xae\x03\xab\x93\
+\x44\x22\x7a\x33\x6f\xd1\x52\x64\x61\x0a\xb5\xbc\x28\xb1\xb1\x34\
+\x83\x91\xb3\xc1\x98\x1d\xc1\x98\x1d\x41\x3f\xd7\x9b\x27\x2f\x40\
+\x26\x81\x11\xfc\x07\x99\x8a\x99\x9c\x2b\x1d\x44\x97\x63\x83\xc9\
+\xe8\xfc\x33\x36\x23\x05\xaa\x05\xd2\xcb\xc8\x74\xcc\xfc\x88\x72\
+\x5d\xa5\x7f\x3c\x74\x3b\x03\x0b\x90\x52\xed\x44\x97\x63\x83\x79\
+\x81\x48\xd2\x38\x3d\x1b\xcf\xbc\x53\x1f\xb9\xe4\x44\x1a\x45\x93\
+\xac\xcf\x7e\x83\xda\xb8\xcb\x2c\xf4\xd7\xf7\xd1\x2f\x9e\x2c\xce\
+\x2e\x14\xe6\xd2\x65\xf1\x48\xd2\x38\x0d\x39\x8b\x7a\x4e\xc9\xa1\
+\xd1\xc0\xb5\xfe\xac\xb7\xb5\x64\x55\xc2\xe5\x47\x54\x36\x98\xe7\
+\x60\xd3\x4a\xc6\x65\xbd\xad\x8c\x06\xae\xf5\xf7\xe4\x26\x6b\x7e\
+\x54\x04\x42\xa1\xee\xe1\x90\x31\x29\x1c\xde\xd2\xbd\xaf\x03\xe1\
+\xf0\x32\x1c\x32\x26\x03\xa1\x50\x77\x7e\xef\xd6\x69\xfa\x58\x7b\
+\xe7\x31\x5f\x78\x54\x23\xb3\x5c\x90\xac\xf8\xdb\x10\x0e\xf3\xab\
+\x35\x82\x13\xc8\xa5\xe9\x42\x76\x8b\x83\x85\xca\xcd\xd1\xdf\x47\
+\x06\x8b\x4f\xd3\x15\xf4\x1d\xb4\x1f\xde\xd2\xd4\x7a\x60\x93\x1a\
+\xaa\x26\x74\xfd\xce\x4a\xf7\x34\x72\x35\xeb\x99\xbf\x3c\x79\xe5\
+\xf6\xff\x83\x15\x7c\xf0\xb4\xd8\xbe\xb9\xa9\xe6\x8b\x1d\x0d\xd5\
+\xad\xae\xd8\x94\x22\x13\x21\x90\xb7\xcc\x29\xa1\x22\xec\x1e\x62\
+\xae\x3a\xa3\xff\xfa\xfc\xdf\xe7\xc6\x66\x5e\x3f\xf2\x0b\xfd\x52\
+\x16\x8e\x84\x02\x01\x21\x84\x0a\x54\x01\x95\x9a\x1d\xdf\x7b\x7b\
+\xac\x6f\xdd\x57\xb7\xb1\x6d\x83\xbb\xd2\x53\xe3\x10\x2e\x9f\xcd\
+\xb0\x00\xfc\x9b\x54\xf4\x99\x84\x8c\x2d\x86\xc3\xe1\x2b\x81\xd9\
+\xbf\x0e\xff\x9c\xfe\x28\x9e\x22\x08\x84\x80\xb0\x94\x32\x5c\xb2\
+\x03\x21\x84\x13\xf0\x00\xee\xdc\xd2\x5c\x56\x3c\x5b\xeb\x69\x79\
+\xb8\x51\x74\x18\x12\xe5\xc2\x75\x39\x3c\x74\x83\xc9\x78\x86\x10\
+\x10\x03\x96\x80\x48\x6e\x2d\x4a\xb9\x7a\x01\x28\x79\xab\xc8\x89\
+\x59\x00\x2b\x60\xcb\x2d\x0b\xa0\x02\x3a\x90\x02\xd2\x40\x12\xc8\
+\x48\x79\xab\x87\x26\xfe\x03\x26\x93\xd5\x41\x51\x76\x98\xdb\x00\
+\x00\x00\x00\x49\x45\x4e\x44\xae\x42\x60\x82\
\x00\x00\x0b\xd7\
\x89\
\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\
@@ -1491,6 +1572,180 @@ qt_resource_data = "\
\xc3\x25\x0d\x25\x35\x01\xd7\x0f\x5b\xb5\x7e\x8e\x93\x83\xff\x0f\
\x92\x04\x28\x92\xfd\x58\xc9\xac\x00\x00\x00\x00\x49\x45\x4e\x44\
\xae\x42\x60\x82\
+\x00\x00\x05\x24\
+\x89\
+\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\
+\x00\x00\x18\x00\x00\x00\x18\x08\x06\x00\x00\x00\xe0\x77\x3d\xf8\
+\x00\x00\x00\x04\x73\x42\x49\x54\x08\x08\x08\x08\x7c\x08\x64\x88\
+\x00\x00\x00\x09\x70\x48\x59\x73\x00\x00\x06\xec\x00\x00\x06\xec\
+\x01\x1e\x75\x38\x35\x00\x00\x00\x19\x74\x45\x58\x74\x53\x6f\x66\
+\x74\x77\x61\x72\x65\x00\x77\x77\x77\x2e\x69\x6e\x6b\x73\x63\x61\
+\x70\x65\x2e\x6f\x72\x67\x9b\xee\x3c\x1a\x00\x00\x00\x13\x74\x45\
+\x58\x74\x41\x75\x74\x68\x6f\x72\x00\x52\x6f\x64\x6e\x65\x79\x20\
+\x44\x61\x77\x65\x73\x0e\xd8\x7e\x1d\x00\x00\x04\x82\x49\x44\x41\
+\x54\x48\x89\x8d\x96\x6d\x88\x54\x55\x18\xc7\x7f\xe7\xbe\xcd\xdc\
+\xb9\xb3\xb3\x33\xfb\x1a\xeb\xbe\xe9\x6e\xad\x2f\x69\xea\x4a\x68\
+\x8a\x22\x24\x94\x58\x7e\xaa\x48\x23\x30\x41\xa1\xec\x5b\x51\x41\
+\xf8\x31\xa3\x3e\x44\x94\x44\x92\x09\xa5\x46\x54\x84\x21\x59\x44\
+\x18\x96\x19\xe8\xb2\x5a\x9b\xba\xea\xa4\xfb\x66\xee\xce\xce\xec\
+\xec\xcc\xdc\xbd\xf3\x76\xef\xe9\x43\xb3\x63\xe6\xbe\x3d\xf0\x7c\
+\xfb\x9f\xff\xef\x3c\xcf\x3d\xf7\x39\x47\x48\x29\x99\x2d\x36\xbf\
+\x27\x7c\x6a\x40\x7f\xc0\x6f\xaa\x8f\x02\x64\x1d\xf7\x84\x3b\x51\
+\xb8\xf0\xed\x8b\x32\x37\xdb\x5a\x31\x1d\x60\xf7\x01\xa1\x67\x2a\
+\x42\x6f\x37\xd7\x76\xac\xaf\x0a\xd6\xd5\x07\x2c\x7f\xb5\x69\x05\
+\x7c\x52\xba\x38\xb6\x93\xb3\xed\x6c\x62\xdc\x8e\xdd\xea\x8b\x5d\
+\x39\x15\x4c\xa7\x5e\xfe\x70\x97\x2c\xcc\x19\xb0\xf5\xa0\xb1\xa2\
+\xb9\x7e\xfe\xa1\x65\x1d\xab\x96\x7a\x5a\x56\xc9\xbb\x0e\x92\x3b\
+\x75\x02\x81\xa1\x9a\x50\x50\xbd\x9e\x2b\xe7\xff\xe8\x1f\xbe\xb1\
+\xe3\xd8\xce\x7c\xf7\xac\x80\x6d\x9f\x87\xf7\x2d\x69\x59\xf9\x5c\
+\x5d\x7d\x4d\x9d\xe3\xa6\x67\xeb\x00\x00\x01\x35\x44\x7c\x78\x74\
+\xe4\x42\x5f\xd7\xc7\x47\x9f\x4a\xbe\x36\x2d\xe0\x89\xc3\xe6\xb3\
+\x6b\x96\xae\xdf\x6f\x04\xf5\xa0\x27\xdd\x39\x99\x87\xb4\x7a\xaa\
+\xb4\x05\xa4\x8a\x43\x8c\x24\xaf\x65\xce\xf4\xfc\xfa\xc2\x17\xcf\
+\x38\x9f\xdc\x05\x78\xfc\x80\xa8\x59\xdc\xb6\xfc\xb7\xc6\xd6\x79\
+\x6d\x73\x33\x17\x6c\xa8\xdc\x43\x83\xb1\x0c\x9f\x08\x72\xc9\xf9\
+\x9e\x73\x99\x4f\x19\xb8\xd1\x17\xbd\x18\xed\x59\xfd\xcd\x2e\x39\
+\x0a\xa0\x4c\xca\xab\xab\xea\x8e\x34\x35\x37\xcc\xd1\x1c\x56\x5a\
+\x4f\xd2\xea\x5b\x8d\x5f\x54\x10\x2b\x46\x39\x97\x39\x8c\x2b\x5d\
+\x5a\x9a\x5b\xdb\xaa\xab\x6a\x8f\x4c\xea\x14\xf8\xf7\xa3\xde\xd7\
+\xd2\xb1\xd6\x15\x73\x33\xaf\xd1\xdb\xb9\xd7\xdc\x88\x82\x4a\xda\
+\x1d\xe1\x97\xd4\x07\xb8\xb2\x08\x40\x51\xb8\xb4\x37\xcf\x5f\xbb\
+\xf5\xa0\xb1\xa2\x0c\xf0\xfb\xd4\x2d\x56\xd0\xb4\x40\x20\xc4\x54\
+\xa9\x94\x53\x53\xfc\xac\x0b\xed\x22\xa0\x86\xc9\x4b\x87\x4b\xce\
+\x09\xd2\xee\x2d\x14\xa1\xa2\x08\x0d\x21\x54\x02\x96\xdf\xf2\xfb\
+\xd4\x2d\x00\x1a\x80\x65\x55\x76\xaa\x86\x86\xbc\xdd\xb1\x52\x97\
+\xc5\x5d\xbb\x5f\x57\xb1\x9b\x88\xd6\x02\xc0\xad\x42\x0f\x97\xb3\
+\x3f\xa2\x08\xed\x0e\x8d\x66\xe8\x58\x66\xa8\xb3\x0c\x08\x9b\x91\
+\x26\x4f\x4a\x54\x45\x9d\xd1\x7c\x81\xff\x21\x1a\x7d\xcb\xf1\x3c\
+\x8f\x91\x42\x2f\xc7\xe3\x7b\x31\xd5\x10\x86\x12\xb8\x43\x57\x14\
+\x0a\xe1\x40\xa8\xa9\x0c\x30\x0d\xb3\x5a\x11\x3a\x0a\x0a\x4c\x61\
+\x0c\x10\x50\x23\x2c\x31\x1f\x43\xb8\x3a\x49\xf7\x26\xc7\xe3\x7b\
+\x89\xe5\xa3\x00\x58\x6a\x15\x11\xa3\x11\x53\x0d\x23\x10\xe4\xbd\
+\x1c\xa6\xe1\xab\x2e\x03\x9c\xfc\x44\x5c\x57\x8c\x96\xff\x9a\x77\
+\x98\x9b\x08\x28\x61\xba\xed\x2f\xf1\x64\x91\x4e\x73\x1b\x16\xb5\
+\x14\x65\x8e\x0b\xf6\x57\x0c\xe5\xcf\x97\xb5\xb6\x9b\xc0\x76\x12\
+\xf8\x14\x8b\xb0\xde\x48\xde\x4b\xe1\xe4\xb3\xf1\x32\x20\x69\x8f\
+\x0d\x08\xe9\xad\x14\xc2\x00\x60\x51\xe0\x11\x96\x04\x36\x23\x50\
+\x18\x2b\x0c\x52\x70\x73\xdc\xa3\x2d\x06\xa0\x3f\x77\x96\x53\xa9\
+\xfd\x53\x56\x99\xf3\x6c\x86\x73\xbd\x54\xa9\xb5\x24\x27\xc6\x07\
+\xca\x00\xdb\x49\x75\x25\x9d\xa1\xad\xe8\x3a\xa6\x5a\x89\x81\x85\
+\xf0\x34\x3c\xcf\xa3\xcd\xd8\x80\x2e\x4c\x14\xa1\x13\x2b\x5c\xe5\
+\x58\xe2\x55\x60\xe6\x09\xac\xba\x3a\xb6\x93\xe9\x2a\x03\xb2\x39\
+\xf7\x78\x2e\x53\x7c\x65\x3c\xd0\x67\x49\x3c\xbc\xa2\xa4\xb5\x76\
+\x0d\x9a\xf0\x53\x55\x3a\x31\xb6\x1b\xe7\x64\xea\x1d\x26\xbc\xc4\
+\x8c\xe6\x02\x85\x82\x2d\xed\x6c\xce\x3d\x0e\xa5\xff\xe0\xd8\xce\
+\x7c\x77\x74\xa0\xff\x74\x8d\xd2\x00\xc0\x60\xbe\x9b\x44\xb1\xbf\
+\xbc\xc8\x93\x2e\x97\x9c\xef\xb8\xea\xfc\x34\xa3\x39\x40\x8d\x52\
+\x47\x74\x70\xf0\xf4\xe4\x64\x2d\x1f\xfc\x78\x22\xb6\x3d\x31\xe4\
+\x44\x4d\xc5\xc2\xc3\xe5\x77\xfb\x6b\x3c\x59\xa4\x20\x1d\x06\xf2\
+\xe7\xf8\x21\xf9\xe6\xac\xe6\xa6\x12\x20\x31\xe4\x44\xe3\x89\xf8\
+\xf6\x72\x45\xff\x9f\xa6\x9d\x0b\xef\x7f\x3f\xe9\xff\xbb\xc2\x95\
+\x1e\x0b\xcd\x87\xc9\xca\x34\xd7\xb3\x67\x98\xad\xef\x9a\xd0\x08\
+\x3b\xb5\xe9\xae\xde\xde\x3d\x53\x4e\xd3\xc9\xd8\xf6\x59\x78\x5f\
+\xc7\xbc\x96\x1d\x6a\x4d\xa1\x3e\xe9\xc6\x67\xdd\x35\x40\x58\x8d\
+\xe0\x8e\x6a\xc3\xbd\x43\xfd\x87\x8e\x3e\x9d\x9a\xfe\x3e\x98\x8c\
+\x4d\x6f\x19\xab\xdb\xda\xeb\x3f\x5a\xd0\x5e\xb7\x28\xad\x24\x94\
+\xac\x9c\xfa\x46\xf3\x0b\x3f\x15\x6e\xd8\xfb\xeb\xda\xc8\xe5\x8b\
+\x5d\xb1\xe7\x7f\xde\x57\x3c\x2d\x65\x69\xea\x4d\x05\x10\x42\xa8\
+\x40\x35\x10\x36\x2b\xa8\xdd\xf8\x7a\xf0\xa5\x96\xe6\xea\x65\xa1\
+\x50\x20\x12\x08\x2b\x41\xa3\x42\xd1\x01\x72\x29\xaf\x38\x91\x2c\
+\x66\xd2\xa9\x6c\xf2\xfa\xf5\xd8\x9f\x27\xdf\x98\x78\x37\x67\x33\
+\x0a\x8c\x01\x49\x29\x65\x72\xda\x0a\x84\x10\x16\x10\x01\x2a\x4b\
+\x59\x61\x04\x89\x34\x3c\xa8\x2c\x6c\x5a\xa5\xad\x90\x12\x65\xe0\
+\x6c\xf1\xfc\xcd\xb3\x5e\xb4\x60\x33\x06\x64\x80\x14\x30\x5e\xca\
+\x84\x94\xb7\x1f\x00\xd3\xbe\x2a\x4a\x30\x1d\x30\x00\x5f\x29\x75\
+\x40\x05\x8a\x40\x0e\xc8\x03\x59\xa0\x20\xe5\xd4\x37\xd5\x3f\x13\
+\x05\x02\x8c\xec\xcf\x7e\xae\x00\x00\x00\x00\x49\x45\x4e\x44\xae\
+\x42\x60\x82\
+\x00\x00\x05\x64\
+\x89\
+\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52\x00\
+\x00\x00\x18\x00\x00\x00\x18\x08\x06\x00\x00\x00\xe0\x77\x3d\xf8\
+\x00\x00\x00\x04\x73\x42\x49\x54\x08\x08\x08\x08\x7c\x08\x64\x88\
+\x00\x00\x00\x09\x70\x48\x59\x73\x00\x00\x0d\xd7\x00\x00\x0d\xd7\
+\x01\x42\x28\x9b\x78\x00\x00\x00\x19\x74\x45\x58\x74\x53\x6f\x66\
+\x74\x77\x61\x72\x65\x00\x77\x77\x77\x2e\x69\x6e\x6b\x73\x63\x61\
+\x70\x65\x2e\x6f\x72\x67\x9b\xee\x3c\x1a\x00\x00\x04\xe1\x49\x44\
+\x41\x54\x48\x89\xb5\x95\x4b\x6c\x54\x55\x18\xc7\x7f\xe7\x9c\x3b\
+\xd3\x99\x32\xb5\xf4\x41\x1f\x38\x02\xa5\x8d\x4c\x54\xb0\x25\x48\
+\xa2\x12\x62\x14\x17\x46\x12\x64\x81\x26\xb0\x31\x18\x12\x36\x04\
+\xd2\x45\xe9\xca\x84\x5d\x29\x1a\x43\x80\x6e\x0c\x2b\xdc\x88\x09\
+\xa0\x21\x04\x35\x18\xa2\x26\x3e\xa3\x62\xa8\x4f\x7c\x61\x95\x81\
+\xb6\x4c\xcb\xbc\x67\xee\x39\x9f\x8b\xdb\xde\x4e\x2b\x10\x37\xde\
+\xe4\x4b\xee\xf3\xf7\xff\xfe\xdf\xf9\xee\x77\x94\x88\xf0\x7f\x1e\
+\xde\xdd\x1e\x9e\x53\xaa\x23\x6a\xcc\x33\xf1\x58\xec\xc9\xc6\xae\
+\x15\x2b\xeb\x97\x76\x2e\xb5\xc5\x62\x29\xfb\xd7\x5f\x63\xb9\x6b\
+\xd7\x7f\x2c\x55\x2a\xe7\x2b\xd6\x7e\xb0\x59\xa4\x70\x27\x86\xba\
+\xad\x03\xa5\xd4\x87\xb1\xd8\xcb\xc9\xc7\x1e\xdd\xdd\xd9\xde\xd6\
+\x11\xab\x8b\xa3\x4a\x45\xc8\xe5\xc1\x18\x48\x2c\xc2\x79\x86\x5c\
+\x21\xef\x7e\x1f\xbd\x7c\x65\xe2\xa7\x5f\xf7\x3c\x59\xad\xbe\xf7\
+\x9f\x04\xce\x2b\xb5\xa2\x75\x69\xe7\x1b\xa9\x0d\x8f\xaf\x4f\x54\
+\x6d\x84\x52\x29\x7c\x26\xc0\xec\xfb\x32\x13\x34\x24\xb8\x91\x9f\
+\x9e\xfa\xe5\xe3\x4f\xdf\xb6\xb7\xb2\xbb\x9f\x10\x29\xd5\xf2\xe6\
+\x09\xbc\xa7\x54\x6a\xc5\xda\xde\xf7\x7b\xba\xbb\x93\x3a\x57\x98\
+\x83\x2c\x80\x2e\x14\x11\xcf\xc3\xc6\x23\xf2\xdd\xa7\x9f\x7d\x91\
+\x1f\xbb\xf6\xf8\x13\x22\xfe\x2c\x53\xcf\x9e\x1c\x50\x4a\xb7\x2e\
+\x5f\x76\xa2\x67\xf9\xb2\x24\xd9\x3c\x4e\x24\x0c\x0b\xd8\x9a\x6b\
+\x07\xc1\x3d\xc0\x01\xce\xf7\x21\x5b\x54\x3d\x0f\xaf\x5e\xa7\xeb\
+\xeb\x5f\xab\x75\x10\x0a\x6c\x8a\xc7\x87\xee\x7f\xe8\x81\x3e\x29\
+\x56\x02\x80\x08\x95\xed\xdb\xa9\x6e\xd9\x32\x27\x52\x03\xb5\x9e\
+\x87\x1d\x18\xc0\xef\xed\x0d\xc5\x8d\x55\x3a\xd9\xf7\xe0\x0b\xef\
+\x46\xa3\xeb\x67\xb9\x1e\xc0\x05\xa5\xee\x5f\xf9\xd8\xfa\x17\xeb\
+\x7c\x67\x1c\x0a\x01\xec\x8e\x1d\xf8\xcf\x3d\x17\x64\xa1\x14\xfa\
+\xf4\xe9\xb0\x3c\x2e\x12\x81\xfd\xfb\xa1\xb7\x17\xfa\xfa\x70\x43\
+\x43\xf0\xf5\xd7\x88\x15\x9a\x1b\x9b\x96\x34\x25\xdb\x5f\x47\xa9\
+\x5e\x44\x44\x03\x78\x9e\xf7\x52\xc7\xe2\xa6\x25\x0e\x15\x64\xe9\
+\x79\xd8\x9e\x9e\xd0\xa6\xdb\xbe\x1d\x7f\xeb\xd6\x20\xd3\x68\x14\
+\x06\x07\x51\x7d\x7d\x28\xa5\x50\x91\x08\x74\x77\x63\x95\xc2\x29\
+\x85\xad\x38\x5a\xdb\x5a\x97\x9d\x85\x9e\xd0\xc1\xa2\xd6\x96\x07\
+\x11\x70\x5a\x07\x59\x3a\x07\x43\x43\xa8\xc1\x41\x58\xbd\x3a\x50\
+\xd9\xb1\x03\x17\x89\xa0\x52\x29\xd4\x9a\x35\xa1\xb8\x3d\x79\x12\
+\xff\xcc\x19\xc4\x18\x10\x41\xb4\xa6\x2e\x16\x5b\x1c\xd1\xfa\x29\
+\xe0\x67\x0d\x10\x6b\x6d\xba\xcf\x39\xc1\x69\x3d\x17\xd6\xe2\x86\
+\x87\x91\xcb\x97\xe7\x16\xec\xf9\xe7\x43\xb8\x88\xe0\xbf\xf9\x26\
+\xd5\x53\xa7\x82\xc4\xb4\xc6\x79\x1e\x62\x0c\x26\x12\xa5\x2e\x11\
+\xdf\x00\xa0\xdf\x52\x2a\x51\xb7\x28\xde\xee\xbc\xc8\x7c\x01\xad\
+\x71\xce\xe1\xbf\xf2\x0a\x32\x3a\x3a\x57\x2e\xe7\xf0\x7d\x9f\xe2\
+\xa1\x43\x64\x07\x06\x28\x5f\xb9\x82\x2d\x95\x70\xc6\x04\xdf\x18\
+\x83\xad\x5a\xe2\x4b\x9a\x97\x87\x25\xc2\x98\x60\xe1\x6a\x7b\x7b\
+\x06\x28\xc6\x20\x4a\x81\x13\x9c\xb8\xf0\x1f\x70\x80\xb5\x16\x3f\
+\x9d\x46\xd2\x69\x74\x4b\x0b\x5e\x32\x89\x6e\x6e\xc6\x96\xcb\x88\
+\xb8\xa0\x8b\xb6\x89\xe4\xbe\x79\x68\xd5\x75\xb9\xd7\xb4\x8b\x95\
+\x00\xae\x54\x00\x8f\x44\x88\xec\xdb\x07\xa9\x14\xd6\xd9\xda\xf6\
+\x26\xd6\xdf\x8f\x03\x0a\x87\x0f\x07\x5d\x37\x31\x41\x65\x7c\x1c\
+\xe2\x71\xbc\xce\x26\xf2\x13\x53\x7f\x84\xff\x41\x71\x32\xf3\xa7\
+\x3f\x39\x81\xd3\x26\xb4\xea\x97\xca\xe8\x5d\xbb\x90\x55\xab\x70\
+\x2e\xc8\xa6\xf0\xea\xab\x54\x3f\xfa\x28\x14\xa9\xef\xef\x27\xb6\
+\x77\x2f\x56\x24\x08\xc0\x2f\x14\xa8\x7a\x9a\x72\xae\xf8\x71\x58\
+\xa2\xc2\x44\x66\xb4\x9c\xcd\x3c\xeb\x46\x7f\x40\x25\x12\x38\xdf\
+\x27\x3e\x32\x82\x5e\xbb\x36\x84\xe5\x87\x87\x29\x1c\x3b\x86\x44\
+\xa3\x34\x1e\x3f\x4e\xdd\xc6\x8d\x00\x24\xfa\xfb\x71\xbe\x4f\xf6\
+\xc8\x91\xb0\xb4\x3e\x4c\x39\xe7\x2e\x84\x0e\xf0\xfd\xe3\x19\xe7\
+\xc6\x2d\x8e\xca\xf8\x38\xd5\x4c\x06\xff\xd2\xa5\x10\x9e\x3b\x78\
+\x90\xdc\xd1\xa3\x41\x96\xe5\x32\x93\x3b\x77\x52\xba\x78\x31\x28\
+\x63\xb9\x4c\xe9\xdb\x6f\xf1\x67\x5c\xe8\xf6\x26\xa6\x27\x32\x57\
+\x37\xc3\x15\xa8\x19\x76\x17\xea\xeb\x87\xbb\xfa\x52\xfd\xf6\xf2\
+\x2f\x46\x24\x58\x8b\xc4\xe0\x20\xf6\xd6\x2d\x72\x23\x23\x41\x8f\
+\xd7\x36\x41\x34\x4a\xeb\xc8\x08\xd9\x13\x27\x28\x5c\xbc\x18\x0c\
+\xc4\xa8\x87\xac\xec\x18\xff\xfb\xab\x9f\x36\x6f\x16\xf9\x7c\x9e\
+\xc0\x01\xa5\xf4\xa6\x64\xfb\x67\x1d\x0d\x0d\xeb\x2a\x63\xd7\x91\
+\x05\xc0\xd9\xde\x9f\x27\x52\x73\x0d\x50\x97\xba\xcf\xa5\xbf\xff\
+\x7d\xe4\xe9\x42\x69\xcf\xac\xfb\x79\xe3\xfa\xac\x52\xa9\xb6\x54\
+\xd7\xfb\x8d\xc6\x24\x2b\x7f\xa4\xff\x05\x58\x08\x0d\x47\xb7\x67\
+\x88\x76\x77\xca\x74\xfa\xe6\x17\x92\x9e\x9c\x37\xae\x6f\xbb\xe1\
+\xc4\xda\x9a\xdf\x58\xb2\x62\xe9\x23\xfe\x6f\xd7\xa2\x36\x5f\x9c\
+\x2f\xb0\x40\xc4\x6b\x6b\x44\x9a\x13\x99\xc9\x1f\xc7\xde\x89\x15\
+\x4a\x77\xdf\x70\x00\x94\x52\x0d\x31\x68\x39\x12\x31\x03\xa9\x55\
+\xcb\xb7\x35\x18\xd3\xaa\x7c\x8b\x9d\xce\x53\x9d\xce\x83\xd1\x98\
+\xc5\x09\x74\x43\x0c\xab\x95\x4c\x66\x72\x63\x17\xae\x5e\x3f\x78\
+\x08\xde\x05\xa6\x80\x29\xb9\x93\x03\xa5\x54\x04\x68\x01\x16\x03\
+\x4d\xbd\xd0\xb5\x49\xeb\x8d\x5d\x9e\xe9\x6b\x69\x69\x68\x6b\x6c\
+\x5c\xd4\x58\xf5\xad\xbd\x99\xc9\x4d\xdf\x98\xce\x5f\x1b\xf5\xed\
+\x97\xe7\xe0\x93\x71\x48\xcf\xc2\x81\x8c\x88\x64\xef\xe8\x60\x46\
+\xa8\x1e\xb8\x67\x26\xea\x81\x38\x50\x07\x18\x40\x11\xec\x3b\x15\
+\xa0\x0c\xe4\x81\x1c\x70\x0b\xc8\xca\xec\x8c\x98\x39\xfe\x01\x76\
+\x95\xba\xf1\x06\x3a\xff\x81\x00\x00\x00\x00\x49\x45\x4e\x44\xae\
+\x42\x60\x82\
"
qt_resource_name = "\
@@ -1502,6 +1757,11 @@ qt_resource_name = "\
\x05\xcd\xf4\xe7\
\x00\x63\
\x00\x6f\x00\x6e\x00\x6e\x00\x5f\x00\x65\x00\x72\x00\x72\x00\x6f\x00\x72\x00\x2e\x00\x70\x00\x6e\x00\x67\
+\x00\x13\
+\x09\xd2\x6c\x67\
+\x00\x45\
+\x00\x6d\x00\x62\x00\x6c\x00\x65\x00\x6d\x00\x2d\x00\x71\x00\x75\x00\x65\x00\x73\x00\x74\x00\x69\x00\x6f\x00\x6e\x00\x2e\x00\x70\
+\x00\x6e\x00\x67\
\x00\x12\
\x04\xe4\x91\x47\
\x00\x63\
@@ -1521,16 +1781,28 @@ qt_resource_name = "\
\x00\x6c\
\x00\x65\x00\x61\x00\x70\x00\x2d\x00\x63\x00\x6f\x00\x6c\x00\x6f\x00\x72\x00\x2d\x00\x73\x00\x6d\x00\x61\x00\x6c\x00\x6c\x00\x2e\
\x00\x70\x00\x6e\x00\x67\
+\x00\x11\
+\x06\x1a\x44\xa7\
+\x00\x44\
+\x00\x69\x00\x61\x00\x6c\x00\x6f\x00\x67\x00\x2d\x00\x61\x00\x63\x00\x63\x00\x65\x00\x70\x00\x74\x00\x2e\x00\x70\x00\x6e\x00\x67\
+\
+\x00\x10\
+\x0f\xc3\x90\x67\
+\x00\x44\
+\x00\x69\x00\x61\x00\x6c\x00\x6f\x00\x67\x00\x2d\x00\x65\x00\x72\x00\x72\x00\x6f\x00\x72\x00\x2e\x00\x70\x00\x6e\x00\x67\
"
qt_resource_struct = "\
\x00\x00\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x01\
-\x00\x00\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x02\
-\x00\x00\x00\xa8\x00\x00\x00\x00\x00\x01\x00\x00\x2d\x4e\
-\x00\x00\x00\x34\x00\x00\x00\x00\x00\x01\x00\x00\x0d\xf7\
+\x00\x00\x00\x00\x00\x02\x00\x00\x00\x08\x00\x00\x00\x02\
+\x00\x00\x00\xd4\x00\x00\x00\x00\x00\x01\x00\x00\x32\x3e\
+\x00\x00\x00\x60\x00\x00\x00\x00\x00\x01\x00\x00\x12\xe7\
\x00\x00\x00\x12\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\
-\x00\x00\x00\x5e\x00\x00\x00\x00\x00\x01\x00\x00\x19\xd2\
-\x00\x00\x00\x7c\x00\x00\x00\x00\x00\x01\x00\x00\x20\xbd\
+\x00\x00\x01\x02\x00\x00\x00\x00\x00\x01\x00\x00\x60\xc7\
+\x00\x00\x00\x8a\x00\x00\x00\x00\x00\x01\x00\x00\x1e\xc2\
+\x00\x00\x00\x34\x00\x00\x00\x00\x00\x01\x00\x00\x0d\xf7\
+\x00\x00\x00\xa8\x00\x00\x00\x00\x00\x01\x00\x00\x25\xad\
+\x00\x00\x01\x2a\x00\x00\x00\x00\x00\x01\x00\x00\x65\xef\
"
def qInitResources():
diff --git a/src/leap/gui/progress.py b/src/leap/gui/progress.py
new file mode 100644
index 00000000..ca4f6cc3
--- /dev/null
+++ b/src/leap/gui/progress.py
@@ -0,0 +1,488 @@
+"""
+classes used in progress pages
+from first run wizard
+"""
+try:
+ from collections import OrderedDict
+except ImportError: # pragma: no cover
+ # We must be in 2.6
+ from leap.util.dicts import OrderedDict
+
+import logging
+
+from PyQt4 import QtCore
+from PyQt4 import QtGui
+
+from leap.gui.threads import FunThread
+
+from leap.gui import mainwindow_rc
+
+ICON_CHECKMARK = ":/images/Dialog-accept.png"
+ICON_FAILED = ":/images/Dialog-error.png"
+ICON_WAITING = ":/images/Emblem-question.png"
+
+logger = logging.getLogger(__name__)
+
+
+class ImgWidget(QtGui.QWidget):
+
+ # XXX move to widgets
+
+ def __init__(self, parent=None, img=None):
+ super(ImgWidget, self).__init__(parent)
+ self.pic = QtGui.QPixmap(img)
+
+ def paintEvent(self, event):
+ painter = QtGui.QPainter(self)
+ painter.drawPixmap(0, 0, self.pic)
+
+
+class ProgressStep(object):
+ """
+ Data model for sequential steps
+ to be used in a progress page in
+ connection wizard
+ """
+ NAME = 0
+ DONE = 1
+
+ def __init__(self, stepname, done, index=None):
+ """
+ @param step: the name of the step
+ @type step: str
+ @param done: whether is completed or not
+ @type done: bool
+ """
+ self.index = int(index) if index else 0
+ self.name = unicode(stepname)
+ self.done = bool(done)
+
+ @classmethod
+ def columns(self):
+ return ('name', 'done')
+
+
+class ProgressStepContainer(object):
+ """
+ a container for ProgressSteps objects
+ access data in the internal dict
+ """
+
+ def __init__(self):
+ self.dirty = False
+ self.steps = {}
+
+ def step(self, identity):
+ return self.steps.get(identity, None)
+
+ def addStep(self, step):
+ self.steps[step.index] = step
+
+ def removeStep(self, step):
+ if step and self.steps.get(step.index, None):
+ del self.steps[step.index]
+ del step
+ self.dirty = True
+
+ def removeAllSteps(self):
+ for item in iter(self):
+ self.removeStep(item)
+
+ @property
+ def columns(self):
+ return ProgressStep.columns()
+
+ def __len__(self):
+ return len(self.steps)
+
+ def __iter__(self):
+ for step in self.steps.values():
+ yield step
+
+
+class StepsTableWidget(QtGui.QTableWidget):
+ """
+ initializes a TableWidget
+ suitable for our display purposes, like removing
+ header info and grid display
+ """
+
+ def __init__(self, parent=None):
+ super(StepsTableWidget, self).__init__(parent=parent)
+
+ # remove headers and all edit/select behavior
+ self.horizontalHeader().hide()
+ self.verticalHeader().hide()
+ self.setEditTriggers(
+ QtGui.QAbstractItemView.NoEditTriggers)
+ self.setSelectionMode(
+ QtGui.QAbstractItemView.NoSelection)
+ width = self.width()
+
+ # WTF? Here init width is 100...
+ # but on populating is 456... :(
+ #logger.debug('init table. width=%s' % width)
+
+ # XXX do we need this initial?
+ self.horizontalHeader().resizeSection(0, width * 0.7)
+
+ # this disables the table grid.
+ # we should add alignment to the ImgWidget (it's top-left now)
+ self.setShowGrid(False)
+ self.setFocusPolicy(QtCore.Qt.NoFocus)
+ #self.setStyleSheet("QTableView{outline: 0;}")
+
+ # XXX change image for done to rc
+
+ # Note about the "done" status painting:
+ #
+ # XXX currently we are setting the CellWidget
+ # for the whole table on a per-row basis
+ # (on add_status_line method on ValidationPage).
+ # However, a more generic solution might be
+ # to implement a custom Delegate that overwrites
+ # the paint method (so it paints a checked tickmark if
+ # done is True and some other thing if checking or false).
+ # What we have now is quick and works because
+ # I'm supposing that on first fail we will
+ # go back to previous wizard page to signal the failure.
+ # A more generic solution could be used for
+ # some failing tests if they are not critical.
+
+
+class WithStepsMixIn(object):
+ """
+ This Class is a mixin that can be inherited
+ by InlineValidation pages (which will display
+ a progress steps widget in the same page as the form)
+ or by Validation Pages (which will only display
+ the progress steps in the page, below a progress bar widget)
+ """
+ STEPS_TIMER_MS = 100
+
+ #
+ # methods related to worker threads
+ # launched for individual checks
+ #
+
+ def setupStepsProcessingQueue(self):
+ """
+ should be called from the init method
+ of the derived classes
+ """
+ self.steps_queue = Queue.Queue()
+ self.stepscheck_timer = QtCore.QTimer()
+ self.stepscheck_timer.timeout.connect(self.processStepsQueue)
+ self.stepscheck_timer.start(self.STEPS_TIMER_MS)
+ # we need to keep a reference to child threads
+ self.threads = []
+
+ def do_checks(self):
+ """
+ main entry point for checks.
+ it calls _do_checks in derived classes,
+ and it expects it to be a generator
+ yielding a tuple in the form (("message", progress_int), checkfunction)
+ """
+
+ # yo dawg, I heard you like checks
+ # so I put a __do_checks in your do_checks
+ # for calling others' _do_checks
+
+ def __do_checks(fun=None, queue=None):
+
+ for checkcase in fun(): # pragma: no cover
+ checkmsg, checkfun = checkcase
+
+ queue.put(checkmsg)
+ if checkfun() is False:
+ queue.put("failed")
+ break
+
+ t = FunThread(fun=partial(
+ __do_checks,
+ fun=self._do_checks,
+ queue=self.steps_queue))
+ if hasattr(self, 'on_checks_validation_ready'):
+ t.finished.connect(self.on_checks_validation_ready)
+ t.begin()
+ self.threads.append(t)
+
+ def processStepsQueue(self):
+ """
+ consume steps queue
+ and pass messages
+ to the ui updater functions
+ """
+ while self.steps_queue.qsize():
+ try:
+ status = self.steps_queue.get(0)
+ if status == "failed":
+ self.set_failed_icon()
+ else:
+ self.onStepStatusChanged(*status)
+ except Queue.Empty: # pragma: no cover
+ pass
+
+ def fail(self, err=None):
+ """
+ return failed state
+ and send error notification as
+ a nice side effect. this function is called from
+ the _do_checks check functions returned in the
+ generator.
+ """
+ wizard = self.wizard()
+ senderr = lambda err: wizard.set_validation_error(
+ self.current_page, err)
+ self.set_undone()
+ if err:
+ senderr(err)
+ return False
+
+ @QtCore.pyqtSlot()
+ def launch_checks(self):
+ self.do_checks()
+
+ # (gui) presentation stuff begins #####################
+
+ # slot
+ #@QtCore.pyqtSlot(str, int)
+ def onStepStatusChanged(self, status, progress=None):
+ status = unicode(status)
+ if status not in ("head_sentinel", "end_sentinel"):
+ self.add_status_line(status)
+ if status in ("end_sentinel"):
+ #self.checks_finished = True
+ self.set_checked_icon()
+ if progress and hasattr(self, 'progress'):
+ self.progress.setValue(progress)
+ self.progress.update()
+
+ def setupSteps(self):
+ self.steps = ProgressStepContainer()
+ # steps table widget
+ if isinstance(self, QtCore.QObject):
+ parent = self
+ else:
+ parent = None
+ self.stepsTableWidget = StepsTableWidget(parent=parent)
+ zeros = (0, 0, 0, 0)
+ self.stepsTableWidget.setContentsMargins(*zeros)
+ self.errors = OrderedDict()
+
+ def set_error(self, name, error):
+ self.errors[name] = error
+
+ def pop_first_error(self):
+ errkey, errval = list(reversed(self.errors.items())).pop()
+ del self.errors[errkey]
+ return errkey, errval
+
+ def clean_errors(self):
+ self.errors = OrderedDict()
+
+ def clean_wizard_errors(self, pagename=None):
+ if pagename is None: # pragma: no cover
+ pagename = getattr(self, 'prev_page', None)
+ if pagename is None: # pragma: no cover
+ return
+ #logger.debug('cleaning wizard errors for %s' % pagename)
+ self.wizard().set_validation_error(pagename, None)
+
+ def populateStepsTable(self):
+ # from examples,
+ # but I guess it's not needed to re-populate
+ # the whole table.
+ table = self.stepsTableWidget
+ table.setRowCount(len(self.steps))
+ columns = self.steps.columns
+ table.setColumnCount(len(columns))
+
+ for row, step in enumerate(self.steps):
+ item = QtGui.QTableWidgetItem(step.name)
+ item.setData(QtCore.Qt.UserRole,
+ long(id(step)))
+ table.setItem(row, columns.index('name'), item)
+ table.setItem(row, columns.index('done'),
+ QtGui.QTableWidgetItem(step.done))
+ self.resizeTable()
+ self.update()
+
+ def clearTable(self):
+ # ??? -- not sure what's the difference
+ #self.stepsTableWidget.clear()
+ self.stepsTableWidget.clearContents()
+
+ def resizeTable(self):
+ # resize first column to ~80%
+ table = self.stepsTableWidget
+ FIRST_COLUMN_PERCENT = 0.70
+ width = table.width()
+ #logger.debug('populate table. width=%s' % width)
+ table.horizontalHeader().resizeSection(0, width * FIRST_COLUMN_PERCENT)
+
+ def set_item_icon(self, img=ICON_CHECKMARK, current=True):
+ """
+ mark the last item
+ as done
+ """
+ # setting cell widget.
+ # see note on StepsTableWidget about plans to
+ # change this for a better solution.
+ if not hasattr(self, 'steps'):
+ return
+ index = len(self.steps)
+ table = self.stepsTableWidget
+ _index = index - 1 if current else index - 2
+ table.setCellWidget(
+ _index,
+ ProgressStep.DONE,
+ ImgWidget(img=img))
+ table.update()
+
+ def set_failed_icon(self):
+ self.set_item_icon(img=ICON_FAILED, current=True)
+
+ def set_checking_icon(self):
+ self.set_item_icon(img=ICON_WAITING, current=True)
+
+ def set_checked_icon(self, current=True):
+ self.set_item_icon(current=current)
+
+ def add_status_line(self, message):
+ """
+ adds a new status line
+ and mark the next-to-last item
+ as done
+ """
+ index = len(self.steps)
+ step = ProgressStep(message, False, index=index)
+ self.steps.addStep(step)
+ self.populateStepsTable()
+ self.set_checking_icon()
+ self.set_checked_icon(current=False)
+
+ # Sets/unsets done flag
+ # for isComplete checks
+
+ def set_done(self):
+ self.done = True
+ self.completeChanged.emit()
+
+ def set_undone(self):
+ self.done = False
+ self.completeChanged.emit()
+
+ def is_done(self):
+ return self.done
+
+ # convenience for going back and forth
+ # in the wizard pages.
+
+ def go_back(self):
+ self.wizard().back()
+
+ def go_next(self):
+ self.wizard().next()
+
+
+"""
+We will use one base class for the intermediate pages
+and another one for the in-page validations, both sharing the creation
+of the tablewidgets.
+The logic of this split comes from where I was trying to solve
+the ui update using signals, but now that it's working well with
+queues I could join them again.
+"""
+
+import Queue
+from functools import partial
+
+
+class InlineValidationPage(QtGui.QWizardPage, WithStepsMixIn):
+
+ def __init__(self, parent=None):
+ super(InlineValidationPage, self).__init__(parent)
+ self.setupStepsProcessingQueue()
+ self.done = False
+
+ # slot
+
+ @QtCore.pyqtSlot()
+ def showStepsFrame(self):
+ self.valFrame.show()
+ self.update()
+
+ # progress frame
+
+ def setupValidationFrame(self):
+ qframe = QtGui.QFrame
+ valFrame = qframe()
+ valFrame.setFrameStyle(qframe.NoFrame)
+ valframeLayout = QtGui.QVBoxLayout()
+ zeros = (0, 0, 0, 0)
+ valframeLayout.setContentsMargins(*zeros)
+
+ valframeLayout.addWidget(self.stepsTableWidget)
+ valFrame.setLayout(valframeLayout)
+ self.valFrame = valFrame
+
+
+class ValidationPage(QtGui.QWizardPage, WithStepsMixIn):
+ """
+ class to be used as an intermediate
+ between two pages in a wizard.
+ shows feedback to the user and goes back if errors,
+ goes forward if ok.
+ initializePage triggers a one shot timer
+ that calls do_checks.
+ Derived classes should implement
+ _do_checks and
+ _do_validation
+ """
+
+ # signals
+ stepChanged = QtCore.pyqtSignal([str, int])
+
+ def __init__(self, parent=None):
+ super(ValidationPage, self).__init__(parent)
+ self.setupSteps()
+ #self.connect_step_status()
+
+ layout = QtGui.QVBoxLayout()
+ self.progress = QtGui.QProgressBar(self)
+ layout.addWidget(self.progress)
+ layout.addWidget(self.stepsTableWidget)
+
+ self.setLayout(layout)
+ self.layout = layout
+
+ self.timer = QtCore.QTimer()
+ self.done = False
+
+ self.setupStepsProcessingQueue()
+
+ def isComplete(self):
+ return self.is_done()
+
+ ########################
+
+ def show_progress(self):
+ self.progress.show()
+ self.stepsTableWidget.show()
+
+ def hide_progress(self):
+ self.progress.hide()
+ self.stepsTableWidget.hide()
+
+ # pagewizard methods.
+ # if overriden, child classes should call super.
+
+ def initializePage(self):
+ self.clean_errors()
+ self.clean_wizard_errors()
+ self.steps.removeAllSteps()
+ self.clearTable()
+ self.resizeTable()
+ self.timer.singleShot(0, self.do_checks)
diff --git a/src/leap/gui/styles.py b/src/leap/gui/styles.py
new file mode 100644
index 00000000..b482922e
--- /dev/null
+++ b/src/leap/gui/styles.py
@@ -0,0 +1,16 @@
+GreenLineEdit = "QLabel {color: green; font-weight: bold}"
+ErrorLabelStyleSheet = """QLabel { color: red; font-weight: bold }"""
+ErrorLineEdit = """QLineEdit { border: 1px solid red; }"""
+
+
+# XXX this is bad.
+# and you should feel bad for it.
+# The original style has a sort of box color
+# white/beige left-top/right-bottom or something like
+# that.
+
+RegularLineEdit = """
+QLineEdit {
+ border: 1px solid black;
+}
+"""
diff --git a/src/leap/gui/tests/__init__.py b/src/leap/gui/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/gui/tests/__init__.py
diff --git a/src/leap/gui/tests/integration/fake_user_signup.py b/src/leap/gui/tests/integration/fake_user_signup.py
index 12f18966..78873749 100644
--- a/src/leap/gui/tests/integration/fake_user_signup.py
+++ b/src/leap/gui/tests/integration/fake_user_signup.py
@@ -12,6 +12,7 @@ curl -d login=python_test_user -d password_salt=54321\
from BaseHTTPServer import HTTPServer
from BaseHTTPServer import BaseHTTPRequestHandler
import cgi
+import json
import urlparse
HOST = "localhost"
@@ -19,12 +20,15 @@ PORT = 8000
LOGIN_ERROR = """{"errors":{"login":["has already been taken"]}}"""
+from leap.base.tests.test_providers import EXPECTED_DEFAULT_CONFIG
+
class request_handler(BaseHTTPRequestHandler):
responses = {
'/': ['ok\n'],
'/users.json': ['ok\n'],
- '/timeout': ['ok\n']
+ '/timeout': ['ok\n'],
+ '/provider.json': ['%s\n' % json.dumps(EXPECTED_DEFAULT_CONFIG)]
}
def do_GET(self):
diff --git a/src/leap/gui/tests/test_firstrun_login.py b/src/leap/gui/tests/test_firstrun_login.py
new file mode 100644
index 00000000..6c45b8ef
--- /dev/null
+++ b/src/leap/gui/tests/test_firstrun_login.py
@@ -0,0 +1,212 @@
+import sys
+import unittest
+
+import mock
+
+from leap.testing import qunittest
+#from leap.testing import pyqt
+
+from PyQt4 import QtGui
+#from PyQt4 import QtCore
+#import PyQt4.QtCore # some weirdness with mock module
+
+from PyQt4.QtTest import QTest
+from PyQt4.QtCore import Qt
+
+from leap.gui import firstrun
+
+try:
+ from collections import OrderedDict
+except ImportError:
+ # We must be in 2.6
+ from leap.util.dicts import OrderedDict
+
+
+class TestPage(firstrun.login.LogInPage):
+ pass
+
+
+class LogInPageLogicTestCase(qunittest.TestCase):
+
+ # XXX can spy on signal connections
+ __name__ = "register user page logic tests"
+
+ def setUp(self):
+ self.app = QtGui.QApplication(sys.argv)
+ QtGui.qApp = self.app
+ self.page = TestPage(None)
+ self.page.wizard = mock.MagicMock()
+
+ def tearDown(self):
+ QtGui.qApp = None
+ self.app = None
+ self.page = None
+
+ def test__do_checks(self):
+ eq = self.assertEqual
+
+ self.page.userNameLineEdit.setText('testuser@domain')
+ self.page.userPasswordLineEdit.setText('testpassword')
+
+ # fake register process
+ with mock.patch('leap.base.auth.LeapSRPRegister') as mockAuth:
+ mockSignup = mock.MagicMock()
+
+ reqMockup = mock.Mock()
+ # XXX should inject bad json to get error
+ reqMockup.content = '{"errors": null}'
+ mockSignup.register_user.return_value = (True, reqMockup)
+ mockAuth.return_value = mockSignup
+ checks = [x for x in self.page._do_checks()]
+
+ eq(len(checks), 4)
+ labels = [str(x) for (x, y), z in checks]
+ eq(labels, ['head_sentinel',
+ 'Resolving domain name',
+ 'Validating credentials',
+ 'end_sentinel'])
+ progress = [y for (x, y), z in checks]
+ eq(progress, [0, 20, 60, 100])
+
+ # normal run, ie, no exceptions
+
+ checkfuns = [z for (x, y), z in checks]
+ checkusername, resolvedomain, valcreds = checkfuns[:-1]
+
+ self.assertTrue(checkusername())
+ #self.mocknetchecker.check_name_resolution.assert_called_with(
+ #'test_provider1')
+
+ self.assertTrue(resolvedomain())
+ #self.mockpcertchecker.is_https_working.assert_called_with(
+ #"https://test_provider1", verify=True)
+
+ self.assertTrue(valcreds())
+
+ # XXX missing: inject failing exceptions
+ # XXX TODO make it break
+
+
+class RegisterUserPageUITestCase(qunittest.TestCase):
+
+ # XXX can spy on signal connections
+ __name__ = "Register User Page UI tests"
+
+ def setUp(self):
+ self.app = QtGui.QApplication(sys.argv)
+ QtGui.qApp = self.app
+
+ self.pagename = "signup"
+ pages = OrderedDict((
+ (self.pagename, TestPage),
+ ('providersetupvalidation',
+ firstrun.connect.ConnectionPage)))
+ self.wizard = firstrun.wizard.FirstRunWizard(None, pages_dict=pages)
+ self.page = self.wizard.page(self.wizard.get_page_index(self.pagename))
+
+ self.page.do_checks = mock.Mock()
+
+ # wizard would do this for us
+ self.page.initializePage()
+
+ def tearDown(self):
+ QtGui.qApp = None
+ self.app = None
+ self.wizard = None
+
+ # XXX refactor out
+ def fill_field(self, field, text):
+ """
+ fills a field (line edit) that is passed along
+ :param field: the qLineEdit
+ :param text: the text to be filled
+ :type field: QLineEdit widget
+ :type text: str
+ """
+ keyp = QTest.keyPress
+ field.setFocus(True)
+ for c in text:
+ keyp(field, c)
+ self.assertEqual(field.text(), text)
+
+ def del_field(self, field):
+ """
+ deletes entried text in
+ field line edit
+ :param field: the QLineEdit
+ :type field: QLineEdit widget
+ """
+ keyp = QTest.keyPress
+ for c in range(len(field.text())):
+ keyp(field, Qt.Key_Backspace)
+ self.assertEqual(field.text(), "")
+
+ def test_buttons_disabled_until_textentry(self):
+ # it's a commit button this time
+ nextbutton = self.wizard.button(QtGui.QWizard.CommitButton)
+
+ self.assertFalse(nextbutton.isEnabled())
+
+ f_username = self.page.userNameLineEdit
+ f_password = self.page.userPasswordLineEdit
+
+ self.fill_field(f_username, "testuser")
+ self.fill_field(f_password, "testpassword")
+
+ # commit should be enabled
+ # XXX Need a workaround here
+ # because the isComplete is not being evaluated...
+ # (no event loop running??)
+ #import ipdb;ipdb.set_trace()
+ #self.assertTrue(nextbutton.isEnabled())
+ self.assertTrue(self.page.isComplete())
+
+ self.del_field(f_username)
+ self.del_field(f_password)
+
+ # after rm fields commit button
+ # should be disabled again
+ #self.assertFalse(nextbutton.isEnabled())
+ self.assertFalse(self.page.isComplete())
+
+ def test_validate_page(self):
+ self.assertFalse(self.page.validatePage())
+ # XXX TODO MOAR CASES...
+ # add errors, False
+ # change done, False
+ # not done, do_checks called
+ # click confirm, True
+ # done and do_confirm, True
+
+ def test_next_id(self):
+ self.assertEqual(self.page.nextId(), 1)
+
+ def test_paint_event(self):
+ self.page.populateErrors = mock.Mock()
+ self.page.paintEvent(None)
+ self.page.populateErrors.assert_called_with()
+
+ def test_validation_ready(self):
+ f_username = self.page.userNameLineEdit
+ f_password = self.page.userPasswordLineEdit
+
+ self.fill_field(f_username, "testuser")
+ self.fill_field(f_password, "testpassword")
+
+ self.page.done = True
+ self.page.on_checks_validation_ready()
+ self.assertFalse(f_username.isEnabled())
+ self.assertFalse(f_password.isEnabled())
+
+ self.assertEqual(self.page.validationMsg.text(),
+ "Credentials validated.")
+ self.assertEqual(self.page.do_confirm_next, True)
+
+ def test_regex(self):
+ # XXX enter invalid username with key presses
+ # check text is not updated
+ pass
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/gui/tests/test_firstrun_providerselect.py b/src/leap/gui/tests/test_firstrun_providerselect.py
new file mode 100644
index 00000000..18d89010
--- /dev/null
+++ b/src/leap/gui/tests/test_firstrun_providerselect.py
@@ -0,0 +1,203 @@
+import sys
+import unittest
+
+import mock
+
+from leap.testing import qunittest
+#from leap.testing import pyqt
+
+from PyQt4 import QtGui
+#from PyQt4 import QtCore
+#import PyQt4.QtCore # some weirdness with mock module
+
+from PyQt4.QtTest import QTest
+from PyQt4.QtCore import Qt
+
+from leap.gui import firstrun
+
+try:
+ from collections import OrderedDict
+except ImportError:
+ # We must be in 2.6
+ from leap.util.dicts import OrderedDict
+
+
+class TestPage(firstrun.providerselect.SelectProviderPage):
+ pass
+
+
+class SelectProviderPageLogicTestCase(qunittest.TestCase):
+
+ # XXX can spy on signal connections
+
+ def setUp(self):
+ self.app = QtGui.QApplication(sys.argv)
+ QtGui.qApp = self.app
+ self.page = TestPage(None)
+ self.page.wizard = mock.MagicMock()
+
+ mocknetchecker = mock.Mock()
+ self.page.wizard().netchecker.return_value = mocknetchecker
+ self.mocknetchecker = mocknetchecker
+
+ mockpcertchecker = mock.Mock()
+ self.page.wizard().providercertchecker.return_value = mockpcertchecker
+ self.mockpcertchecker = mockpcertchecker
+
+ mockeipconfchecker = mock.Mock()
+ self.page.wizard().eipconfigchecker.return_value = mockeipconfchecker
+ self.mockeipconfchecker = mockeipconfchecker
+
+ def tearDown(self):
+ QtGui.qApp = None
+ self.app = None
+ self.page = None
+
+ def test__do_checks(self):
+ eq = self.assertEqual
+
+ self.page.providerNameEdit.setText('test_provider1')
+
+ checks = [x for x in self.page._do_checks()]
+ eq(len(checks), 5)
+ labels = [str(x) for (x, y), z in checks]
+ eq(labels, ['head_sentinel',
+ 'Checking if it is a valid provider',
+ 'Checking for a secure connection',
+ 'Getting info from the provider',
+ 'end_sentinel'])
+ progress = [y for (x, y), z in checks]
+ eq(progress, [0, 20, 40, 80, 100])
+
+ # normal run, ie, no exceptions
+
+ checkfuns = [z for (x, y), z in checks]
+ namecheck, httpscheck, fetchinfo = checkfuns[1:-1]
+
+ self.assertTrue(namecheck())
+ self.mocknetchecker.check_name_resolution.assert_called_with(
+ 'test_provider1')
+
+ self.assertTrue(httpscheck())
+ self.mockpcertchecker.is_https_working.assert_called_with(
+ "https://test_provider1", verify=True)
+
+ self.assertTrue(fetchinfo())
+ self.mockeipconfchecker.fetch_definition.assert_called_with(
+ domain="test_provider1")
+
+ # XXX missing: inject failing exceptions
+ # XXX TODO make it break
+
+
+class SelectProviderPageUITestCase(qunittest.TestCase):
+
+ # XXX can spy on signal connections
+ __name__ = "Select Provider Page UI tests"
+
+ def setUp(self):
+ self.app = QtGui.QApplication(sys.argv)
+ QtGui.qApp = self.app
+
+ self.pagename = "providerselection"
+ pages = OrderedDict((
+ (self.pagename, TestPage),
+ ('providerinfo',
+ firstrun.providerinfo.ProviderInfoPage)))
+ self.wizard = firstrun.wizard.FirstRunWizard(None, pages_dict=pages)
+ self.page = self.wizard.page(self.wizard.get_page_index(self.pagename))
+
+ self.page.do_checks = mock.Mock()
+
+ # wizard would do this for us
+ self.page.initializePage()
+
+ def tearDown(self):
+ QtGui.qApp = None
+ self.app = None
+ self.wizard = None
+
+ def fill_provider(self):
+ """
+ fills provider line edit
+ """
+ keyp = QTest.keyPress
+ pedit = self.page.providerNameEdit
+ pedit.setFocus(True)
+ for c in "testprovider":
+ keyp(pedit, c)
+ self.assertEqual(pedit.text(), "testprovider")
+
+ def del_provider(self):
+ """
+ deletes entried provider in
+ line edit
+ """
+ keyp = QTest.keyPress
+ pedit = self.page.providerNameEdit
+ for c in range(len("testprovider")):
+ keyp(pedit, Qt.Key_Backspace)
+ self.assertEqual(pedit.text(), "")
+
+ def test_buttons_disabled_until_textentry(self):
+ nextbutton = self.wizard.button(QtGui.QWizard.NextButton)
+ checkbutton = self.page.providerCheckButton
+
+ self.assertFalse(nextbutton.isEnabled())
+ self.assertFalse(checkbutton.isEnabled())
+
+ self.fill_provider()
+ # checkbutton should be enabled
+ self.assertTrue(checkbutton.isEnabled())
+ self.assertFalse(nextbutton.isEnabled())
+
+ self.del_provider()
+ # after rm provider checkbutton disabled again
+ self.assertFalse(checkbutton.isEnabled())
+ self.assertFalse(nextbutton.isEnabled())
+
+ def test_check_button_triggers_tests(self):
+ checkbutton = self.page.providerCheckButton
+ self.assertFalse(checkbutton.isEnabled())
+ self.assertFalse(self.page.do_checks.called)
+
+ self.fill_provider()
+
+ self.assertTrue(checkbutton.isEnabled())
+ mclick = QTest.mouseClick
+ # click!
+ mclick(checkbutton, Qt.LeftButton)
+ self.waitFor(seconds=0.1)
+ self.assertTrue(self.page.do_checks.called)
+
+ # XXX
+ # can play with different side_effects for do_checks mock...
+ # so we can see what happens with errors and so on
+
+ def test_page_completed_after_checks(self):
+ nextbutton = self.wizard.button(QtGui.QWizard.NextButton)
+ self.assertFalse(nextbutton.isEnabled())
+
+ self.assertFalse(self.page.isComplete())
+ self.fill_provider()
+ # simulate checks done
+ self.page.done = True
+ self.page.on_checks_validation_ready()
+ self.assertTrue(self.page.isComplete())
+ # cannot test for nexbutton enabled
+ # cause it's the the wizard loop
+ # that would do that I think
+
+ def test_validate_page(self):
+ self.assertTrue(self.page.validatePage())
+
+ def test_next_id(self):
+ self.assertEqual(self.page.nextId(), 1)
+
+ def test_paint_event(self):
+ self.page.populateErrors = mock.Mock()
+ self.page.paintEvent(None)
+ self.page.populateErrors.assert_called_with()
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/gui/tests/test_firstrun_register.py b/src/leap/gui/tests/test_firstrun_register.py
new file mode 100644
index 00000000..9d62f808
--- /dev/null
+++ b/src/leap/gui/tests/test_firstrun_register.py
@@ -0,0 +1,244 @@
+import sys
+import unittest
+
+import mock
+
+from leap.testing import qunittest
+#from leap.testing import pyqt
+
+from PyQt4 import QtGui
+#from PyQt4 import QtCore
+#import PyQt4.QtCore # some weirdness with mock module
+
+from PyQt4.QtTest import QTest
+from PyQt4.QtCore import Qt
+
+from leap.gui import firstrun
+
+try:
+ from collections import OrderedDict
+except ImportError:
+ # We must be in 2.6
+ from leap.util.dicts import OrderedDict
+
+
+class TestPage(firstrun.register.RegisterUserPage):
+
+ def field(self, field):
+ if field == "provider_domain":
+ return "testprovider"
+
+
+class RegisterUserPageLogicTestCase(qunittest.TestCase):
+
+ # XXX can spy on signal connections
+ __name__ = "register user page logic tests"
+
+ def setUp(self):
+ self.app = QtGui.QApplication(sys.argv)
+ QtGui.qApp = self.app
+ self.page = TestPage(None)
+ self.page.wizard = mock.MagicMock()
+
+ #mocknetchecker = mock.Mock()
+ #self.page.wizard().netchecker.return_value = mocknetchecker
+ #self.mocknetchecker = mocknetchecker
+#
+ #mockpcertchecker = mock.Mock()
+ #self.page.wizard().providercertchecker.return_value = mockpcertchecker
+ #self.mockpcertchecker = mockpcertchecker
+#
+ #mockeipconfchecker = mock.Mock()
+ #self.page.wizard().eipconfigchecker.return_value = mockeipconfchecker
+ #self.mockeipconfchecker = mockeipconfchecker
+
+ def tearDown(self):
+ QtGui.qApp = None
+ self.app = None
+ self.page = None
+
+ def test__do_checks(self):
+ eq = self.assertEqual
+
+ self.page.userNameLineEdit.setText('testuser')
+ self.page.userPasswordLineEdit.setText('testpassword')
+ self.page.userPassword2LineEdit.setText('testpassword')
+
+ # fake register process
+ with mock.patch('leap.base.auth.LeapSRPRegister') as mockAuth:
+ mockSignup = mock.MagicMock()
+
+ reqMockup = mock.Mock()
+ # XXX should inject bad json to get error
+ reqMockup.content = '{"errors": null}'
+ mockSignup.register_user.return_value = (True, reqMockup)
+ mockAuth.return_value = mockSignup
+ checks = [x for x in self.page._do_checks()]
+
+ eq(len(checks), 3)
+ labels = [str(x) for (x, y), z in checks]
+ eq(labels, ['head_sentinel',
+ 'Registering username',
+ 'end_sentinel'])
+ progress = [y for (x, y), z in checks]
+ eq(progress, [0, 40, 100])
+
+ # normal run, ie, no exceptions
+
+ checkfuns = [z for (x, y), z in checks]
+ passcheck, register = checkfuns[:-1]
+
+ self.assertTrue(passcheck())
+ #self.mocknetchecker.check_name_resolution.assert_called_with(
+ #'test_provider1')
+
+ self.assertTrue(register())
+ #self.mockpcertchecker.is_https_working.assert_called_with(
+ #"https://test_provider1", verify=True)
+
+ # XXX missing: inject failing exceptions
+ # XXX TODO make it break
+
+
+class RegisterUserPageUITestCase(qunittest.TestCase):
+
+ # XXX can spy on signal connections
+ __name__ = "Register User Page UI tests"
+
+ def setUp(self):
+ self.app = QtGui.QApplication(sys.argv)
+ QtGui.qApp = self.app
+
+ self.pagename = "signup"
+ pages = OrderedDict((
+ (self.pagename, TestPage),
+ ('connect',
+ firstrun.connect.ConnectionPage)))
+ self.wizard = firstrun.wizard.FirstRunWizard(None, pages_dict=pages)
+ self.page = self.wizard.page(self.wizard.get_page_index(self.pagename))
+
+ self.page.do_checks = mock.Mock()
+
+ # wizard would do this for us
+ self.page.initializePage()
+
+ def tearDown(self):
+ QtGui.qApp = None
+ self.app = None
+ self.wizard = None
+
+ def fill_field(self, field, text):
+ """
+ fills a field (line edit) that is passed along
+ :param field: the qLineEdit
+ :param text: the text to be filled
+ :type field: QLineEdit widget
+ :type text: str
+ """
+ keyp = QTest.keyPress
+ field.setFocus(True)
+ for c in text:
+ keyp(field, c)
+ self.assertEqual(field.text(), text)
+
+ def del_field(self, field):
+ """
+ deletes entried text in
+ field line edit
+ :param field: the QLineEdit
+ :type field: QLineEdit widget
+ """
+ keyp = QTest.keyPress
+ for c in range(len(field.text())):
+ keyp(field, Qt.Key_Backspace)
+ self.assertEqual(field.text(), "")
+
+ def test_buttons_disabled_until_textentry(self):
+ # it's a commit button this time
+ nextbutton = self.wizard.button(QtGui.QWizard.CommitButton)
+
+ self.assertFalse(nextbutton.isEnabled())
+
+ f_username = self.page.userNameLineEdit
+ f_password = self.page.userPasswordLineEdit
+ f_passwor2 = self.page.userPassword2LineEdit
+
+ self.fill_field(f_username, "testuser")
+ self.fill_field(f_password, "testpassword")
+ self.fill_field(f_passwor2, "testpassword")
+
+ # commit should be enabled
+ # XXX Need a workaround here
+ # because the isComplete is not being evaluated...
+ # (no event loop running??)
+ #import ipdb;ipdb.set_trace()
+ #self.assertTrue(nextbutton.isEnabled())
+ self.assertTrue(self.page.isComplete())
+
+ self.del_field(f_username)
+ self.del_field(f_password)
+ self.del_field(f_passwor2)
+
+ # after rm fields commit button
+ # should be disabled again
+ #self.assertFalse(nextbutton.isEnabled())
+ self.assertFalse(self.page.isComplete())
+
+ @unittest.skip
+ def test_check_button_triggers_tests(self):
+ checkbutton = self.page.providerCheckButton
+ self.assertFalse(checkbutton.isEnabled())
+ self.assertFalse(self.page.do_checks.called)
+
+ self.fill_provider()
+
+ self.assertTrue(checkbutton.isEnabled())
+ mclick = QTest.mouseClick
+ # click!
+ mclick(checkbutton, Qt.LeftButton)
+ self.waitFor(seconds=0.1)
+ self.assertTrue(self.page.do_checks.called)
+
+ # XXX
+ # can play with different side_effects for do_checks mock...
+ # so we can see what happens with errors and so on
+
+ def test_validate_page(self):
+ self.assertFalse(self.page.validatePage())
+ # XXX TODO MOAR CASES...
+ # add errors, False
+ # change done, False
+ # not done, do_checks called
+ # click confirm, True
+ # done and do_confirm, True
+
+ def test_next_id(self):
+ self.assertEqual(self.page.nextId(), 1)
+
+ def test_paint_event(self):
+ self.page.populateErrors = mock.Mock()
+ self.page.paintEvent(None)
+ self.page.populateErrors.assert_called_with()
+
+ def test_validation_ready(self):
+ f_username = self.page.userNameLineEdit
+ f_password = self.page.userPasswordLineEdit
+ f_passwor2 = self.page.userPassword2LineEdit
+
+ self.fill_field(f_username, "testuser")
+ self.fill_field(f_password, "testpassword")
+ self.fill_field(f_passwor2, "testpassword")
+
+ self.page.done = True
+ self.page.on_checks_validation_ready()
+ self.assertFalse(f_username.isEnabled())
+ self.assertFalse(f_password.isEnabled())
+ self.assertFalse(f_passwor2.isEnabled())
+
+ self.assertEqual(self.page.validationMsg.text(),
+ "Registration succeeded!")
+ self.assertEqual(self.page.do_confirm_next, True)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/gui/tests/test_firstrun_wizard.py b/src/leap/gui/tests/test_firstrun_wizard.py
new file mode 100644
index 00000000..395604d3
--- /dev/null
+++ b/src/leap/gui/tests/test_firstrun_wizard.py
@@ -0,0 +1,137 @@
+import sys
+import unittest
+
+import mock
+
+from leap.testing import qunittest
+from leap.testing import pyqt
+
+from PyQt4 import QtGui
+#from PyQt4 import QtCore
+import PyQt4.QtCore # some weirdness with mock module
+
+from PyQt4.QtTest import QTest
+#from PyQt4.QtCore import Qt
+
+from leap.gui import firstrun
+
+
+class TestWizard(firstrun.wizard.FirstRunWizard):
+ pass
+
+
+PAGES_DICT = dict((
+ ('intro', firstrun.intro.IntroPage),
+ ('providerselection',
+ firstrun.providerselect.SelectProviderPage),
+ ('login', firstrun.login.LogInPage),
+ ('providerinfo', firstrun.providerinfo.ProviderInfoPage),
+ ('providersetupvalidation',
+ firstrun.providersetup.ProviderSetupValidationPage),
+ ('signup', firstrun.register.RegisterUserPage),
+ ('connect',
+ firstrun.connect.ConnectionPage),
+ ('lastpage', firstrun.last.LastPage)
+))
+
+
+mockQSettings = mock.MagicMock()
+mockQSettings().setValue.return_value = True
+
+#PyQt4.QtCore.QSettings = mockQSettings
+
+
+class FirstRunWizardTestCase(qunittest.TestCase):
+
+ # XXX can spy on signal connections
+
+ def setUp(self):
+ self.app = QtGui.QApplication(sys.argv)
+ QtGui.qApp = self.app
+ self.wizard = TestWizard(None)
+
+ def tearDown(self):
+ QtGui.qApp = None
+ self.app = None
+ self.wizard = None
+
+ def test_defaults(self):
+ self.assertEqual(self.wizard.pages_dict, PAGES_DICT)
+
+ @mock.patch('PyQt4.QtCore.QSettings', mockQSettings)
+ def test_accept(self):
+ """
+ test the main accept method
+ that gets called when user has gone
+ thru all the wizard and click on finish button
+ """
+
+ self.wizard.success_cb = mock.Mock()
+ self.wizard.success_cb.return_value = True
+
+ # dummy values; we inject them in the field
+ # mocks (where wizard gets them) and then
+ # we check that they are passed to QSettings.setValue
+ field_returns = ["testuser", "1234", "testprovider", True]
+
+ def field_side_effects(*args):
+ return field_returns.pop(0)
+
+ self.wizard.field = mock.Mock(side_effect=field_side_effects)
+ self.wizard.get_random_str = mock.Mock()
+ RANDOMSTR = "thisisarandomstringTM"
+ self.wizard.get_random_str.return_value = RANDOMSTR
+
+ # mocked settings (see decorator on this method)
+ mqs = PyQt4.QtCore.QSettings
+
+ # go! call accept...
+ self.wizard.accept()
+
+ # did settings().setValue get called with the proper
+ # arguments?
+ call = mock.call
+ calls = [call("FirstRunWizardDone", True),
+ call("provider_domain", "testprovider"),
+ call("remember_user_and_pass", True),
+ call("username", "testuser@testprovider"),
+ call("testprovider_seed", RANDOMSTR)]
+ mqs().setValue.assert_has_calls(calls, any_order=True)
+
+ # assert success callback is success oh boy
+ self.wizard.success_cb.assert_called_with()
+
+ def test_random_str(self):
+ r = self.wizard.get_random_str(42)
+ self.assertTrue(len(r) == 42)
+
+ def test_page_index(self):
+ """
+ we test both the get_page_index function
+ and the correct ordering of names
+ """
+ # remember it's implemented as an ordered dict
+
+ pagenames = ('intro', 'providerselection', 'login', 'providerinfo',
+ 'providersetupvalidation', 'signup', 'connect',
+ 'lastpage')
+ eq = self.assertEqual
+ w = self.wizard
+ for index, name in enumerate(pagenames):
+ eq(w.get_page_index(name), index)
+
+ def test_validation_errors(self):
+ """
+ tests getters and setters for validation errors
+ """
+ page = "testpage"
+ eq = self.assertEqual
+ w = self.wizard
+ eq(w.get_validation_error(page), None)
+ w.set_validation_error(page, "error")
+ eq(w.get_validation_error(page), "error")
+ w.clean_validation_error(page)
+ eq(w.get_validation_error(page), None)
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/gui/test_mainwindow_rc.py b/src/leap/gui/tests/test_mainwindow_rc.py
index 88ae5854..67b9fae0 100644
--- a/src/leap/gui/test_mainwindow_rc.py
+++ b/src/leap/gui/tests/test_mainwindow_rc.py
@@ -1,8 +1,11 @@
import unittest
import hashlib
-import sip
-sip.setapi('QVariant', 2)
+try:
+ import sip
+ sip.setapi('QVariant', 2)
+except ValueError:
+ pass
from leap.gui import mainwindow_rc
@@ -23,4 +26,7 @@ class MainWindowResourcesTest(unittest.TestCase):
def test_mainwindow_resources_hash(self):
self.assertEqual(
hashlib.md5(mainwindow_rc.qt_resource_data).hexdigest(),
- 'd74eb99247b9d5cd2f00b2f695ca6b59')
+ '53e196f29061d8f08f112e5a2e64eb53')
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/gui/tests/test_progress.py b/src/leap/gui/tests/test_progress.py
new file mode 100644
index 00000000..1f9f9e38
--- /dev/null
+++ b/src/leap/gui/tests/test_progress.py
@@ -0,0 +1,449 @@
+from collections import namedtuple
+import sys
+import unittest
+import Queue
+
+import mock
+
+from leap.testing import qunittest
+from leap.testing import pyqt
+
+from PyQt4 import QtGui
+from PyQt4 import QtCore
+from PyQt4.QtTest import QTest
+from PyQt4.QtCore import Qt
+
+from leap.gui import progress
+
+
+class ProgressStepTestCase(unittest.TestCase):
+
+ def test_step_attrs(self):
+ ps = progress.ProgressStep
+ step = ps('test', False, 1)
+ # instance
+ self.assertEqual(step.index, 1)
+ self.assertEqual(step.name, "test")
+ self.assertEqual(step.done, False)
+ step = ps('test2', True, 2)
+ self.assertEqual(step.index, 2)
+ self.assertEqual(step.name, "test2")
+ self.assertEqual(step.done, True)
+
+ # class methods and attrs
+ self.assertEqual(ps.columns(), ('name', 'done'))
+ self.assertEqual(ps.NAME, 0)
+ self.assertEqual(ps.DONE, 1)
+
+
+class ProgressStepContainerTestCase(unittest.TestCase):
+ def setUp(self):
+ self.psc = progress.ProgressStepContainer()
+
+ def addSteps(self, number):
+ Step = progress.ProgressStep
+ for n in range(number):
+ self.psc.addStep(Step("%s" % n, False, n))
+
+ def test_attrs(self):
+ self.assertEqual(self.psc.columns,
+ ('name', 'done'))
+
+ def test_add_steps(self):
+ Step = progress.ProgressStep
+ self.assertTrue(len(self.psc) == 0)
+ self.psc.addStep(Step('one', False, 0))
+ self.assertTrue(len(self.psc) == 1)
+ self.psc.addStep(Step('two', False, 1))
+ self.assertTrue(len(self.psc) == 2)
+
+ def test_del_all_steps(self):
+ self.assertTrue(len(self.psc) == 0)
+ self.addSteps(5)
+ self.assertTrue(len(self.psc) == 5)
+ self.psc.removeAllSteps()
+ self.assertTrue(len(self.psc) == 0)
+
+ def test_del_step(self):
+ Step = progress.ProgressStep
+ self.addSteps(5)
+ self.assertTrue(len(self.psc) == 5)
+ self.psc.removeStep(self.psc.step(4))
+ self.assertTrue(len(self.psc) == 4)
+ self.psc.removeStep(self.psc.step(4))
+ self.psc.removeStep(Step('none', False, 5))
+ self.psc.removeStep(self.psc.step(4))
+
+ def test_iter(self):
+ self.addSteps(10)
+ self.assertEqual(
+ [x.index for x in self.psc],
+ [x for x in range(10)])
+
+
+class StepsTableWidgetTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.app = QtGui.QApplication(sys.argv)
+ QtGui.qApp = self.app
+ self.stw = progress.StepsTableWidget()
+
+ def tearDown(self):
+ QtGui.qApp = None
+ self.app = None
+
+ def test_defaults(self):
+ self.assertTrue(isinstance(self.stw, QtGui.QTableWidget))
+ self.assertEqual(self.stw.focusPolicy(), 0)
+
+
+class TestWithStepsClass(QtGui.QWidget, progress.WithStepsMixIn):
+
+ def __init__(self, parent=None):
+ super(TestWithStepsClass, self).__init__(parent=parent)
+ self.setupStepsProcessingQueue()
+ self.statuses = []
+ self.current_page = "testpage"
+
+ def onStepStatusChanged(self, *args):
+ """
+ blank out this gui method
+ that will add status lines
+ """
+ self.statuses.append(args)
+
+
+class WithStepsMixInTestCase(qunittest.TestCase):
+
+ TIMER_WAIT = 2 * progress.WithStepsMixIn.STEPS_TIMER_MS / 1000.0
+
+ # XXX can spy on signal connections
+
+ def setUp(self):
+ self.app = QtGui.QApplication(sys.argv)
+ QtGui.qApp = self.app
+ self.stepy = TestWithStepsClass()
+ #self.connects = []
+ #pyqt.enableSignalDebugging(
+ #connectCall=lambda *args: self.connects.append(args))
+ #self.assertEqual(self.connects, [])
+ #self.stepy.stepscheck_timer.timeout.disconnect(
+ #self.stepy.processStepsQueue)
+
+ def tearDown(self):
+ QtGui.qApp = None
+ self.app = None
+
+ def test_has_queue(self):
+ s = self.stepy
+ self.assertTrue(hasattr(s, 'steps_queue'))
+ self.assertTrue(isinstance(s.steps_queue, Queue.Queue))
+ self.assertTrue(isinstance(s.stepscheck_timer, QtCore.QTimer))
+
+ def test_do_checks_delegation(self):
+ s = self.stepy
+
+ _do_checks = mock.Mock()
+ _do_checks.return_value = (
+ (("test", 0), lambda: None),
+ (("test", 0), lambda: None))
+ s._do_checks = _do_checks
+ s.do_checks()
+ self.waitFor(seconds=self.TIMER_WAIT)
+ _do_checks.assert_called_with()
+ self.assertEqual(len(s.statuses), 2)
+
+ # test that a failed test interrupts the run
+
+ s.statuses = []
+ _do_checks = mock.Mock()
+ _do_checks.return_value = (
+ (("test", 0), lambda: None),
+ (("test", 0), lambda: False),
+ (("test", 0), lambda: None))
+ s._do_checks = _do_checks
+ s.do_checks()
+ self.waitFor(seconds=self.TIMER_WAIT)
+ _do_checks.assert_called_with()
+ self.assertEqual(len(s.statuses), 2)
+
+ def test_process_queue(self):
+ s = self.stepy
+ q = s.steps_queue
+ s.set_failed_icon = mock.MagicMock()
+ with self.assertRaises(AssertionError):
+ q.put('foo')
+ self.waitFor(seconds=self.TIMER_WAIT)
+ s.set_failed_icon.assert_called_with()
+ q.put("failed")
+ self.waitFor(seconds=self.TIMER_WAIT)
+ s.set_failed_icon.assert_called_with()
+
+ def test_on_checks_validation_ready_called(self):
+ s = self.stepy
+ s.on_checks_validation_ready = mock.MagicMock()
+
+ _do_checks = mock.Mock()
+ _do_checks.return_value = (
+ (("test", 0), lambda: None),)
+ s._do_checks = _do_checks
+ s.do_checks()
+
+ self.waitFor(seconds=self.TIMER_WAIT)
+ s.on_checks_validation_ready.assert_called_with()
+
+ def test_fail(self):
+ s = self.stepy
+
+ s.wizard = mock.Mock()
+ wizard = s.wizard.return_value
+ wizard.set_validation_error.return_value = True
+ s.completeChanged = mock.Mock()
+ s.completeChanged.emit.return_value = True
+
+ self.assertFalse(s.fail(err="foo"))
+ self.waitFor(seconds=self.TIMER_WAIT)
+ wizard.set_validation_error.assert_called_with('testpage', 'foo')
+ s.completeChanged.emit.assert_called_with()
+
+ # with no args
+ s.wizard = mock.Mock()
+ wizard = s.wizard.return_value
+ wizard.set_validation_error.return_value = True
+ s.completeChanged = mock.Mock()
+ s.completeChanged.emit.return_value = True
+
+ self.assertFalse(s.fail())
+ self.waitFor(seconds=self.TIMER_WAIT)
+ with self.assertRaises(AssertionError):
+ wizard.set_validation_error.assert_called_with()
+ s.completeChanged.emit.assert_called_with()
+
+ def test_done(self):
+ s = self.stepy
+ s.done = False
+
+ s.completeChanged = mock.Mock()
+ s.completeChanged.emit.return_value = True
+
+ self.assertFalse(s.is_done())
+ s.set_done()
+ self.assertTrue(s.is_done())
+ s.completeChanged.emit.assert_called_with()
+
+ s.completeChanged = mock.Mock()
+ s.completeChanged.emit.return_value = True
+ s.set_undone()
+ self.assertFalse(s.is_done())
+
+ def test_back_and_next(self):
+ s = self.stepy
+ s.wizard = mock.Mock()
+ wizard = s.wizard.return_value
+ wizard.back.return_value = True
+ wizard.next.return_value = True
+ s.go_back()
+ wizard.back.assert_called_with()
+ s.go_next()
+ wizard.next.assert_called_with()
+
+ def test_on_step_statuschanged_slot(self):
+ s = self.stepy
+ s.onStepStatusChanged = progress.WithStepsMixIn.onStepStatusChanged
+ s.add_status_line = mock.Mock()
+ s.set_checked_icon = mock.Mock()
+ s.progress = mock.Mock()
+ s.progress.setValue.return_value = True
+ s.progress.update.return_value = True
+
+ s.onStepStatusChanged(s, "end_sentinel")
+ s.set_checked_icon.assert_called_with()
+
+ s.onStepStatusChanged(s, "foo")
+ s.add_status_line.assert_called_with("foo")
+
+ s.onStepStatusChanged(s, "bar", 42)
+ s.progress.setValue.assert_called_with(42)
+ s.progress.update.assert_called_with()
+
+ def test_steps_and_errors(self):
+ s = self.stepy
+ s.setupSteps()
+ self.assertTrue(isinstance(s.steps, progress.ProgressStepContainer))
+ self.assertEqual(s.errors, {})
+ s.set_error('fooerror', 'barerror')
+ self.assertEqual(s.errors, {'fooerror': 'barerror'})
+ s.set_error('2', 42)
+ self.assertEqual(s.errors, {'fooerror': 'barerror', '2': 42})
+ fe = s.pop_first_error()
+ self.assertEqual(fe, ('fooerror', 'barerror'))
+ self.assertEqual(s.errors, {'2': 42})
+ s.clean_errors()
+ self.assertEqual(s.errors, {})
+
+ def test_launch_chechs_slot(self):
+ s = self.stepy
+ s.do_checks = mock.Mock()
+ s.launch_checks()
+ s.do_checks.assert_called_with()
+
+ def test_clean_wizard_errors(self):
+ s = self.stepy
+ s.wizard = mock.Mock()
+ wizard = s.wizard.return_value
+ wizard.set_validation_error.return_value = True
+ s.clean_wizard_errors(pagename="foopage")
+ wizard.set_validation_error.assert_called_with("foopage", None)
+
+ def test_clear_table(self):
+ s = self.stepy
+ s.stepsTableWidget = mock.Mock()
+ s.stepsTableWidget.clearContents.return_value = True
+ s.clearTable()
+ s.stepsTableWidget.clearContents.assert_called_with()
+
+ def test_populate_steps_table(self):
+ s = self.stepy
+ Step = namedtuple('Step', ['name', 'done'])
+
+ class Steps(object):
+ columns = ("name", "done")
+ _items = (Step('step1', False), Step('step2', False))
+
+ def __len__(self):
+ return 2
+
+ def __iter__(self):
+ for i in self._items:
+ yield i
+
+ s.steps = Steps()
+
+ s.stepsTableWidget = mock.Mock()
+ s.stepsTableWidget.setItem.return_value = True
+ s.resizeTable = mock.Mock()
+ s.update = mock.Mock()
+ s.populateStepsTable()
+ s.update.assert_called_with()
+ s.resizeTable.assert_called_with()
+
+ # assert stepsTableWidget.setItem called ...
+ # we do not want to get into the actual
+ # <QTableWidgetItem object at 0x92a565c>
+ call_list = s.stepsTableWidget.setItem.call_args_list
+ indexes = [(y, z) for y, z, xx in [x[0] for x in call_list]]
+ self.assertEqual(indexes,
+ [(0, 0), (0, 1), (1, 0), (1, 1)])
+
+ def test_add_status_line(self):
+ s = self.stepy
+ s.steps = progress.ProgressStepContainer()
+ s.stepsTableWidget = mock.Mock()
+ s.stepsTableWidget.width.return_value = 100
+ s.set_item = mock.Mock()
+ s.set_item_icon = mock.Mock()
+ s.add_status_line("new status")
+ s.set_item_icon.assert_called_with(current=False)
+
+ def test_set_item_icon(self):
+ s = self.stepy
+ s.steps = progress.ProgressStepContainer()
+ s.stepsTableWidget = mock.Mock()
+ s.stepsTableWidget.setCellWidget.return_value = True
+ s.stepsTableWidget.width.return_value = 100
+ #s.set_item = mock.Mock()
+ #s.set_item_icon = mock.Mock()
+ s.add_status_line("new status")
+ s.add_status_line("new 2 status")
+ s.add_status_line("new 3 status")
+ call_list = s.stepsTableWidget.setCellWidget.call_args_list
+ indexes = [(y, z) for y, z, xx in [x[0] for x in call_list]]
+ self.assertEqual(
+ indexes,
+ [(0, 1), (-1, 1), (1, 1), (0, 1), (2, 1), (1, 1)])
+
+
+class TestInlineValidationPage(progress.InlineValidationPage):
+ pass
+
+
+class InlineValidationPageTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.app = QtGui.QApplication(sys.argv)
+ QtGui.qApp = self.app
+ self.page = TestInlineValidationPage()
+
+ def tearDown(self):
+ QtGui.qApp = None
+ self.app = None
+
+ def test_defaults(self):
+ self.assertFalse(self.page.done)
+ # if setupProcessingQueue was called
+ self.assertTrue(isinstance(self.page.stepscheck_timer, QtCore.QTimer))
+ self.assertTrue(isinstance(self.page.steps_queue, Queue.Queue))
+
+ def test_validation_frame(self):
+ # test frame creation
+ self.page.stepsTableWidget = progress.StepsTableWidget(
+ parent=self.page)
+ self.page.setupValidationFrame()
+ self.assertTrue(isinstance(self.page.valFrame, QtGui.QFrame))
+
+ # test show steps calls frame.show
+ self.page.valFrame = mock.Mock()
+ self.page.valFrame.show.return_value = True
+ self.page.showStepsFrame()
+ self.page.valFrame.show.assert_called_with()
+
+
+class TestValidationPage(progress.ValidationPage):
+ pass
+
+
+class ValidationPageTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.app = QtGui.QApplication(sys.argv)
+ QtGui.qApp = self.app
+ self.page = TestValidationPage()
+
+ def tearDown(self):
+ QtGui.qApp = None
+ self.app = None
+
+ def test_defaults(self):
+ self.assertFalse(self.page.done)
+ # if setupProcessingQueue was called
+ self.assertTrue(isinstance(self.page.timer, QtCore.QTimer))
+ self.assertTrue(isinstance(self.page.stepscheck_timer, QtCore.QTimer))
+ self.assertTrue(isinstance(self.page.steps_queue, Queue.Queue))
+
+ def test_is_complete(self):
+ self.assertFalse(self.page.isComplete())
+ self.page.done = True
+ self.assertTrue(self.page.isComplete())
+ self.page.done = False
+ self.assertFalse(self.page.isComplete())
+
+ def test_show_hide_progress(self):
+ p = self.page
+ p.progress = mock.Mock()
+ p.progress.show.return_code = True
+ p.show_progress()
+ p.progress.show.assert_called_with()
+ p.progress.hide.return_code = True
+ p.hide_progress()
+ p.progress.hide.assert_called_with()
+
+ def test_initialize_page(self):
+ p = self.page
+ p.timer = mock.Mock()
+ p.timer.singleShot.return_code = True
+ p.initializePage()
+ p.timer.singleShot.assert_called_with(0, p.do_checks)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/gui/tests/test_threads.py b/src/leap/gui/tests/test_threads.py
new file mode 100644
index 00000000..06c19606
--- /dev/null
+++ b/src/leap/gui/tests/test_threads.py
@@ -0,0 +1,27 @@
+import unittest
+
+import mock
+from leap.gui import threads
+
+
+class FunThreadTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.fun = mock.MagicMock()
+ self.fun.return_value = "foo"
+ self.t = threads.FunThread(fun=self.fun)
+
+ def test_thread(self):
+ self.t.begin()
+ self.t.wait()
+ self.fun.assert_called()
+ del self.t
+
+ def test_run(self):
+ # this is called by PyQt
+ self.t.run()
+ del self.t
+ self.fun.assert_called()
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/gui/threads.py b/src/leap/gui/threads.py
new file mode 100644
index 00000000..8aad8866
--- /dev/null
+++ b/src/leap/gui/threads.py
@@ -0,0 +1,21 @@
+from PyQt4 import QtCore
+
+
+class FunThread(QtCore.QThread):
+
+ def __init__(self, fun=None, parent=None):
+
+ QtCore.QThread.__init__(self, parent)
+ self.exiting = False
+ self.fun = fun
+
+ def __del__(self):
+ self.exiting = True
+ self.wait()
+
+ def run(self):
+ if self.fun:
+ self.fun()
+
+ def begin(self):
+ self.start()
diff --git a/src/leap/gui/utils.py b/src/leap/gui/utils.py
new file mode 100644
index 00000000..f91ac3ef
--- /dev/null
+++ b/src/leap/gui/utils.py
@@ -0,0 +1,34 @@
+"""
+utility functions to work with gui objects
+"""
+from PyQt4 import QtCore
+
+
+def layout_widgets(layout):
+ """
+ return a generator with all widgets in a layout
+ """
+ return (layout.itemAt(i) for i in range(layout.count()))
+
+
+DELAY_MSECS = 50
+
+
+def delay(obj, method_str=None, call_args=None):
+ """
+ Triggers a function or slot with a small delay.
+ this is a mainly a hack to get responsiveness in the ui
+ in cases in which the event loop freezes and the task
+ is not heavy enough to setup a processing queue.
+ """
+ if callable(obj) and not method_str:
+ fun = lambda: obj()
+
+ if method_str:
+ invoke = QtCore.QMetaObject.invokeMethod
+ if call_args:
+ fun = lambda: invoke(obj, method_str, call_args)
+ else:
+ fun = lambda: invoke(obj, method_str)
+
+ QtCore.QTimer().singleShot(DELAY_MSECS, fun)
diff --git a/src/leap/soledad/README b/src/leap/soledad/README
new file mode 100644
index 00000000..3bf62494
--- /dev/null
+++ b/src/leap/soledad/README
@@ -0,0 +1,37 @@
+Soledad -- Synchronization Of Locally Encrypted Data Among Devices
+==================================================================
+
+This software is under development.
+
+Dependencies
+------------
+
+Soledad depends on the following python libraries:
+
+ * u1db 0.1.4 [1]
+ * python-swiftclient 1.2.0 [2]
+ * python-gnupg 0.3.1 [3]
+ * CouchDB 0.8 [4]
+ * hmac 20101005 [5]
+
+[1] http://pypi.python.org/pypi/u1db/0.1.4
+[2] http://pypi.python.org/pypi/python-swiftclient/1.2.0
+[3] http://pypi.python.org/pypi/python-gnupg/0.3.1
+[4] http://pypi.python.org/pypi/CouchDB/0.8
+[5] http://pypi.python.org/pypi/hmac/20101005
+
+
+Tests
+-----
+
+Soledad's tests should be run with nose2, like this:
+
+ nose2 leap.soledad.tests
+
+Right now, there are 3 conditions that have to be met for all Soledad tests to
+pass without problems:
+
+ 1. Use nose2.
+ 2. Have an http CouchDB instance running on `localhost:5984`.
+ 3. Have sqlcipher configured (using LD_PRELOAD or LD_LIBRARY_CONFIG to point
+ to the place where libsqlite3.so.0 is located).
diff --git a/src/leap/soledad/__init__.py b/src/leap/soledad/__init__.py
new file mode 100644
index 00000000..c83627f0
--- /dev/null
+++ b/src/leap/soledad/__init__.py
@@ -0,0 +1,212 @@
+# License?
+
+"""A U1DB implementation for using Object Stores as its persistence layer."""
+
+import os
+import string
+import random
+import hmac
+from leap.soledad.backends import sqlcipher
+from leap.soledad.util import GPGWrapper
+import util
+
+
+class Soledad(object):
+
+ # paths
+ PREFIX = os.environ['HOME'] + '/.config/leap/soledad'
+ SECRET_PATH = PREFIX + '/secret.gpg'
+ GNUPG_HOME = PREFIX + '/gnupg'
+ LOCAL_DB_PATH = PREFIX + '/soledad.u1db'
+
+ # other configs
+ SECRET_LENGTH = 50
+
+ def __init__(self, user_email, gpghome=None):
+ self._user_email = user_email
+ if not os.path.isdir(self.PREFIX):
+ os.makedirs(self.PREFIX)
+ if not gpghome:
+ gpghome = self.GNUPG_HOME
+ self._gpg = util.GPGWrapper(gpghome=gpghome)
+ # load/generate OpenPGP keypair
+ if not self._has_openpgp_keypair():
+ self._gen_openpgp_keypair()
+ self._load_openpgp_keypair()
+ # load/generate secret
+ if not self._has_secret():
+ self._gen_secret()
+ self._load_secret()
+ # instantiate u1db
+ # TODO: verify if secret for sqlcipher should be the same as the one
+ # for symmetric encryption.
+ self._db = sqlcipher.open(self.LOCAL_DB_PATH, True, self._secret)
+
+ #-------------------------------------------------------------------------
+ # Management of secret for symmetric encryption
+ #-------------------------------------------------------------------------
+
+ def _has_secret(self):
+ """
+ Verify if secret for symmetric encryption exists on local encrypted
+ file.
+ """
+ # TODO: verify if file is a GPG-encrypted file and if we have the
+ # corresponding private key for decryption.
+ if os.path.isfile(self.SECRET_PATH):
+ return True
+ return False
+
+ def _load_secret(self):
+ """
+ Load secret for symmetric encryption from local encrypted file.
+ """
+ try:
+ with open(self.SECRET_PATH) as f:
+ self._secret = str(self._gpg.decrypt(f.read()))
+ except IOError as e:
+ raise IOError('Failed to open secret file %s.' % self.SECRET_PATH)
+
+ def _gen_secret(self):
+ """
+ Generate a secret for symmetric encryption and store in a local
+ encrypted file.
+ """
+ self._secret = ''.join(random.choice(string.ascii_uppercase +
+ string.digits) for x in
+ range(self.SECRET_LENGTH))
+ ciphertext = self._gpg.encrypt(self._secret, self._fingerprint,
+ self._fingerprint)
+ f = open(self.SECRET_PATH, 'w')
+ f.write(str(ciphertext))
+ f.close()
+
+ #-------------------------------------------------------------------------
+ # Management of OpenPGP keypair
+ #-------------------------------------------------------------------------
+
+ def _has_openpgp_keypair(self):
+ """
+ Verify if there exists an OpenPGP keypair for this user.
+ """
+ # TODO: verify if we have the corresponding private key.
+ try:
+ self._gpg.find_key(self._user_email)
+ return True
+ except LookupError:
+ return False
+
+ def _gen_openpgp_keypair(self):
+ """
+ Generate an OpenPGP keypair for this user.
+ """
+ params = self._gpg.gen_key_input(
+ key_type='RSA',
+ key_length=4096,
+ name_real=self._user_email,
+ name_email=self._user_email,
+ name_comment='Generated by LEAP Soledad.')
+ self._gpg.gen_key(params)
+
+ def _load_openpgp_keypair(self):
+ """
+ Find fingerprint for this user's OpenPGP keypair.
+ """
+ self._fingerprint = self._gpg.find_key(self._user_email)['fingerprint']
+
+ def publish_pubkey(self, keyserver):
+ """
+ Publish OpenPGP public key to a keyserver.
+ """
+ # TODO: this has to talk to LEAP's Nickserver.
+ pass
+
+ #-------------------------------------------------------------------------
+ # Data encryption and decryption
+ #-------------------------------------------------------------------------
+
+ def encrypt(self, data, sign=None, passphrase=None, symmetric=False):
+ """
+ Encrypt data.
+ """
+ return str(self._gpg.encrypt(data, self._fingerprint, sign=sign,
+ passphrase=passphrase,
+ symmetric=symmetric))
+
+ def encrypt_symmetric(self, doc_id, data, sign=None):
+ """
+ Encrypt data using symmetric secret.
+ """
+ h = hmac.new(self._secret, doc_id).hexdigest()
+ return self.encrypt(data, sign=sign, passphrase=h, symmetric=True)
+
+ def decrypt(self, data, passphrase=None, symmetric=False):
+ """
+ Decrypt data.
+ """
+ return str(self._gpg.decrypt(data, passphrase=passphrase))
+
+ def decrypt_symmetric(self, doc_id, data):
+ """
+ Decrypt data using symmetric secret.
+ """
+ h = hmac.new(self._secret, doc_id).hexdigest()
+ return self.decrypt(data, passphrase=h)
+
+ #-------------------------------------------------------------------------
+ # Document storage, retrieval and sync
+ #-------------------------------------------------------------------------
+
+ def put_doc(self, doc):
+ """
+ Update a document in the local encrypted database.
+ """
+ return self._db.put_doc(doc)
+
+ def delete_doc(self, doc):
+ """
+ Delete a document from the local encrypted database.
+ """
+ return self._db.delete_doc(doc)
+
+ def get_doc(self, doc_id, include_deleted=False):
+ """
+ Retrieve a document from the local encrypted database.
+ """
+ return self._db.get_doc(doc_id, include_deleted=include_deleted)
+
+ def get_docs(self, doc_ids, check_for_conflicts=True,
+ include_deleted=False):
+ """
+ Get the content for many documents.
+ """
+ return self._db.get_docs(doc_ids,
+ check_for_conflicts=check_for_conflicts,
+ include_deleted=include_deleted)
+
+ def create_doc(self, content, doc_id=None):
+ """
+ Create a new document in the local encrypted database.
+ """
+ return self._db.create_doc(content, doc_id=doc_id)
+
+ def get_doc_conflicts(self, doc_id):
+ """
+ Get the list of conflicts for the given document.
+ """
+ return self._db.get_doc_conflicts(doc_id)
+
+ def resolve_doc(self, doc, conflicted_doc_revs):
+ """
+ Mark a document as no longer conflicted.
+ """
+ return self._db.resolve_doc(doc, conflicted_doc_revs)
+
+ def sync(self, url):
+ """
+ Synchronize the local encrypted database with LEAP server.
+ """
+ # TODO: create authentication scheme for sync with server.
+ return self._db.sync(url, creds=None, autocreate=True, soledad=self)
+
+__all__ = ['util']
diff --git a/src/leap/soledad/backends/__init__.py b/src/leap/soledad/backends/__init__.py
new file mode 100644
index 00000000..72907f37
--- /dev/null
+++ b/src/leap/soledad/backends/__init__.py
@@ -0,0 +1,5 @@
+import objectstore
+
+
+__all__ = [
+ 'objectstore']
diff --git a/src/leap/soledad/backends/couch.py b/src/leap/soledad/backends/couch.py
new file mode 100644
index 00000000..c8dadfa8
--- /dev/null
+++ b/src/leap/soledad/backends/couch.py
@@ -0,0 +1,217 @@
+import uuid
+from base64 import b64encode, b64decode
+from u1db import errors
+from u1db.sync import LocalSyncTarget
+from u1db.backends.inmemory import InMemoryIndex
+from couchdb.client import Server, Document as CouchDocument
+from couchdb.http import ResourceNotFound
+from leap.soledad.backends.objectstore import ObjectStore
+from leap.soledad.backends.leap_backend import LeapDocument
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+
+class CouchDatabase(ObjectStore):
+ """A U1DB implementation that uses Couch as its persistence layer."""
+
+ def __init__(self, url, database, replica_uid=None, full_commit=True,
+ session=None):
+ """Create a new Couch data container."""
+ self._url = url
+ self._full_commit = full_commit
+ self._session = session
+ self._server = Server(url=self._url,
+ full_commit=self._full_commit,
+ session=self._session)
+ self._dbname = database
+ # this will ensure that transaction and sync logs exist and are
+ # up-to-date.
+ self.set_document_factory(LeapDocument)
+ try:
+ self._database = self._server[database]
+ except ResourceNotFound:
+ self._server.create(database)
+ self._database = self._server[database]
+ super(CouchDatabase, self).__init__(replica_uid=replica_uid)
+
+ #-------------------------------------------------------------------------
+ # methods from Database
+ #-------------------------------------------------------------------------
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """
+ Get just the document content, without fancy handling.
+ """
+ cdoc = self._database.get(doc_id)
+ if cdoc is None:
+ return None
+ has_conflicts = False
+ if check_for_conflicts:
+ has_conflicts = self._has_conflicts(doc_id)
+ doc = self._factory(
+ doc_id=doc_id,
+ rev=cdoc['u1db_rev'],
+ has_conflicts=has_conflicts)
+ contents = self._database.get_attachment(cdoc, 'u1db_json')
+ if contents:
+ doc.content = json.loads(contents.getvalue())
+ else:
+ doc.make_tombstone()
+ return doc
+
+ def get_all_docs(self, include_deleted=False):
+ """Get all documents from the database."""
+ generation = self._get_generation()
+ results = []
+ for doc_id in self._database:
+ if doc_id == self.U1DB_DATA_DOC_ID:
+ continue
+ doc = self._get_doc(doc_id, check_for_conflicts=True)
+ if doc.content is None and not include_deleted:
+ continue
+ results.append(doc)
+ return (generation, results)
+
+ def _put_doc(self, doc):
+ # prepare couch's Document
+ cdoc = CouchDocument()
+ cdoc['_id'] = doc.doc_id
+ # we have to guarantee that couch's _rev is cosistent
+ old_cdoc = self._database.get(doc.doc_id)
+ if old_cdoc is not None:
+ cdoc['_rev'] = old_cdoc['_rev']
+ # store u1db's rev
+ cdoc['u1db_rev'] = doc.rev
+ # save doc in db
+ self._database.save(cdoc)
+ # store u1db's content as json string
+ if not doc.is_tombstone():
+ self._database.put_attachment(cdoc, doc.get_json(),
+ filename='u1db_json')
+ else:
+ self._database.delete_attachment(cdoc, 'u1db_json')
+
+ def get_sync_target(self):
+ return CouchSyncTarget(self)
+
+ def create_index(self, index_name, *index_expressions):
+ if index_name in self._indexes:
+ if self._indexes[index_name]._definition == list(
+ index_expressions):
+ return
+ raise errors.IndexNameTakenError
+ index = InMemoryIndex(index_name, list(index_expressions))
+ for doc_id in self._database:
+ if doc_id == self.U1DB_DATA_DOC_ID:
+ continue
+ doc = self._get_doc(doc_id)
+ if doc.content is not None:
+ index.add_json(doc_id, doc.get_json())
+ self._indexes[index_name] = index
+ # save data in object store
+ self._set_u1db_data()
+
+ def close(self):
+ # TODO: fix this method so the connection is properly closed and
+ # test_close (+tearDown, which deletes the db) works without problems.
+ self._url = None
+ self._full_commit = None
+ self._session = None
+ #self._server = None
+ self._database = None
+ return True
+
+ def sync(self, url, creds=None, autocreate=True):
+ from u1db.sync import Synchronizer
+ return Synchronizer(self, CouchSyncTarget(url, creds=creds)).sync(
+ autocreate=autocreate)
+
+ #-------------------------------------------------------------------------
+ # methods from ObjectStore
+ #-------------------------------------------------------------------------
+
+ def _init_u1db_data(self):
+ if self._replica_uid is None:
+ self._replica_uid = uuid.uuid4().hex
+ doc = self._factory(doc_id=self.U1DB_DATA_DOC_ID)
+ doc.content = {'transaction_log': [],
+ 'conflicts': b64encode(json.dumps({})),
+ 'other_generations': {},
+ 'indexes': b64encode(json.dumps({})),
+ 'replica_uid': self._replica_uid}
+ self._put_doc(doc)
+
+ def _get_u1db_data(self):
+ # retrieve u1db data from couch db
+ cdoc = self._database.get(self.U1DB_DATA_DOC_ID)
+ jsonstr = self._database.get_attachment(cdoc, 'u1db_json').getvalue()
+ content = json.loads(jsonstr)
+ # set u1db database info
+ #self._sync_log = content['sync_log']
+ self._transaction_log = content['transaction_log']
+ self._conflicts = json.loads(b64decode(content['conflicts']))
+ self._other_generations = content['other_generations']
+ self._indexes = self._load_indexes_from_json(
+ b64decode(content['indexes']))
+ self._replica_uid = content['replica_uid']
+ # save couch _rev
+ self._couch_rev = cdoc['_rev']
+
+ def _set_u1db_data(self):
+ doc = self._factory(doc_id=self.U1DB_DATA_DOC_ID)
+ doc.content = {
+ 'transaction_log': self._transaction_log,
+ # Here, the b64 encode ensures that document content
+ # does not cause strange behaviour in couchdb because
+ # of encoding.
+ 'conflicts': b64encode(json.dumps(self._conflicts)),
+ 'other_generations': self._other_generations,
+ 'indexes': b64encode(self._dump_indexes_as_json()),
+ 'replica_uid': self._replica_uid,
+ '_rev': self._couch_rev}
+ self._put_doc(doc)
+
+ #-------------------------------------------------------------------------
+ # Couch specific methods
+ #-------------------------------------------------------------------------
+
+ def delete_database(self):
+ del(self._server[self._dbname])
+
+ def _dump_indexes_as_json(self):
+ indexes = {}
+ for name, idx in self._indexes.iteritems():
+ indexes[name] = {}
+ for attr in ['name', 'definition', 'values']:
+ indexes[name][attr] = getattr(idx, '_' + attr)
+ return json.dumps(indexes)
+
+ def _load_indexes_from_json(self, indexes):
+ dict = {}
+ for name, idx_dict in json.loads(indexes).iteritems():
+ idx = InMemoryIndex(name, idx_dict['definition'])
+ idx._values = idx_dict['values']
+ dict[name] = idx
+ return dict
+
+
+class CouchSyncTarget(LocalSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_replica_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_replica_transaction_id)
diff --git a/src/leap/soledad/backends/leap_backend.py b/src/leap/soledad/backends/leap_backend.py
new file mode 100644
index 00000000..ec26dca4
--- /dev/null
+++ b/src/leap/soledad/backends/leap_backend.py
@@ -0,0 +1,198 @@
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import Document
+from u1db.remote import utils
+from u1db.remote.http_target import HTTPSyncTarget
+from u1db.remote.http_database import HTTPDatabase
+from u1db.errors import BrokenSyncStream
+
+import uuid
+
+
+class NoDefaultKey(Exception):
+ pass
+
+
+class NoSoledadInstance(Exception):
+ pass
+
+
+class LeapDocument(Document):
+ """
+ LEAP Documents are standard u1db documents with cabability of returning an
+ encrypted version of the document json string as well as setting document
+ content based on an encrypted version of json string.
+ """
+
+ def __init__(self, doc_id=None, rev=None, json='{}', has_conflicts=False,
+ encrypted_json=None, soledad=None, syncable=True):
+ super(LeapDocument, self).__init__(doc_id, rev, json, has_conflicts)
+ self._soledad = soledad
+ self._syncable = syncable
+ if encrypted_json:
+ self.set_encrypted_json(encrypted_json)
+
+ def get_encrypted_json(self):
+ """
+ Returns document's json serialization encrypted with user's public key.
+ """
+ if not self._soledad:
+ raise NoSoledadInstance()
+ ciphertext = self._soledad.encrypt_symmetric(self.doc_id,
+ self.get_json())
+ return json.dumps({'_encrypted_json': ciphertext})
+
+ def set_encrypted_json(self, encrypted_json):
+ """
+ Set document's content based on encrypted version of json string.
+ """
+ if not self._soledad:
+ raise NoSoledadInstance()
+ ciphertext = json.loads(encrypted_json)['_encrypted_json']
+ plaintext = self._soledad.decrypt_symmetric(self.doc_id, ciphertext)
+ return self.set_json(plaintext)
+
+ def _get_syncable(self):
+ return self._syncable
+
+ def _set_syncable(self, syncable=True):
+ self._syncable = syncable
+
+ syncable = property(
+ _get_syncable,
+ _set_syncable,
+ doc="Determine if document should be synced with server."
+ )
+
+
+class LeapDatabase(HTTPDatabase):
+ """Implement the HTTP remote database API to a Leap server."""
+
+ def __init__(self, url, document_factory=None, creds=None, soledad=None):
+ super(LeapDatabase, self).__init__(url, creds=creds)
+ self._soledad = soledad
+ self._factory = LeapDocument
+
+ @staticmethod
+ def open_database(url, create):
+ db = LeapDatabase(url)
+ db.open(create)
+ return db
+
+ @staticmethod
+ def delete_database(url):
+ db = LeapDatabase(url)
+ db._delete()
+ db.close()
+
+ def _allocate_doc_id(self):
+ """Generate a unique identifier for this document."""
+ return 'D-' + uuid.uuid4().hex # 'D-' stands for document
+
+ def get_sync_target(self):
+ st = LeapSyncTarget(self._url.geturl())
+ st._creds = self._creds
+ return st
+
+ def create_doc_from_json(self, content, doc_id=None):
+ if doc_id is None:
+ doc_id = self._allocate_doc_id()
+ res, headers = self._request_json('PUT', ['doc', doc_id], {},
+ content, 'application/json')
+ new_doc = self._factory(doc_id, res['rev'], content,
+ soledad=self._soledad)
+ return new_doc
+
+
+class LeapSyncTarget(HTTPSyncTarget):
+
+ def __init__(self, url, creds=None, soledad=None):
+ super(LeapSyncTarget, self).__init__(url, creds)
+ self._soledad = soledad
+
+ def _parse_sync_stream(self, data, return_doc_cb, ensure_callback=None):
+ """
+ Does the same as parent's method but ensures incoming content will be
+ decrypted.
+ """
+ parts = data.splitlines() # one at a time
+ if not parts or parts[0] != '[':
+ raise BrokenSyncStream
+ data = parts[1:-1]
+ comma = False
+ if data:
+ line, comma = utils.check_and_strip_comma(data[0])
+ res = json.loads(line)
+ if ensure_callback and 'replica_uid' in res:
+ ensure_callback(res['replica_uid'])
+ for entry in data[1:]:
+ if not comma: # missing in between comma
+ raise BrokenSyncStream
+ line, comma = utils.check_and_strip_comma(entry)
+ entry = json.loads(line)
+ # decrypt after receiving from server.
+ doc = LeapDocument(entry['id'], entry['rev'],
+ encrypted_json=entry['content'],
+ soledad=self._soledad)
+ return_doc_cb(doc, entry['gen'], entry['trans_id'])
+ if parts[-1] != ']':
+ try:
+ partdic = json.loads(parts[-1])
+ except ValueError:
+ pass
+ else:
+ if isinstance(partdic, dict):
+ self._error(partdic)
+ raise BrokenSyncStream
+ if not data or comma: # no entries or bad extra comma
+ raise BrokenSyncStream
+ return res
+
+ def sync_exchange(self, docs_by_generations, source_replica_uid,
+ last_known_generation, last_known_trans_id,
+ return_doc_cb, ensure_callback=None):
+ """
+ Does the same as parent's method but encrypts content before syncing.
+ """
+ self._ensure_connection()
+ if self._trace_hook: # for tests
+ self._trace_hook('sync_exchange')
+ url = '%s/sync-from/%s' % (self._url.path, source_replica_uid)
+ self._conn.putrequest('POST', url)
+ self._conn.putheader('content-type', 'application/x-u1db-sync-stream')
+ for header_name, header_value in self._sign_request('POST', url, {}):
+ self._conn.putheader(header_name, header_value)
+ entries = ['[']
+ size = 1
+
+ def prepare(**dic):
+ entry = comma + '\r\n' + json.dumps(dic)
+ entries.append(entry)
+ return len(entry)
+
+ comma = ''
+ size += prepare(
+ last_known_generation=last_known_generation,
+ last_known_trans_id=last_known_trans_id,
+ ensure=ensure_callback is not None)
+ comma = ','
+ for doc, gen, trans_id in docs_by_generations:
+ if doc.syncable:
+ # encrypt before sending to server.
+ size += prepare(id=doc.doc_id, rev=doc.rev,
+ content=doc.get_encrypted_json(),
+ gen=gen, trans_id=trans_id)
+ entries.append('\r\n]')
+ size += len(entries[-1])
+ self._conn.putheader('content-length', str(size))
+ self._conn.endheaders()
+ for entry in entries:
+ self._conn.send(entry)
+ entries = None
+ data, _ = self._response()
+ res = self._parse_sync_stream(data, return_doc_cb, ensure_callback)
+ data = None
+ return res['new_generation'], res['new_transaction_id']
diff --git a/src/leap/soledad/backends/objectstore.py b/src/leap/soledad/backends/objectstore.py
new file mode 100644
index 00000000..588fc7a1
--- /dev/null
+++ b/src/leap/soledad/backends/objectstore.py
@@ -0,0 +1,109 @@
+from u1db.backends.inmemory import InMemoryDatabase
+from u1db import errors
+
+
+class ObjectStore(InMemoryDatabase):
+ """
+ A backend for storing u1db data in an object store.
+ """
+
+ def __init__(self, replica_uid=None):
+ super(ObjectStore, self).__init__(replica_uid)
+ # sync data in memory with data in object store
+ if not self._get_doc(self.U1DB_DATA_DOC_ID):
+ self._init_u1db_data()
+ self._get_u1db_data()
+
+ #-------------------------------------------------------------------------
+ # methods from Database
+ #-------------------------------------------------------------------------
+
+ def _set_replica_uid(self, replica_uid):
+ super(ObjectStore, self)._set_replica_uid(replica_uid)
+ self._set_u1db_data()
+
+ def _put_doc(self, doc):
+ raise NotImplementedError(self._put_doc)
+
+ def _get_doc(self, doc):
+ raise NotImplementedError(self._get_doc)
+
+ def get_all_docs(self, include_deleted=False):
+ raise NotImplementedError(self.get_all_docs)
+
+ def delete_doc(self, doc):
+ old_doc = self._get_doc(doc.doc_id, check_for_conflicts=True)
+ if old_doc is None:
+ raise errors.DocumentDoesNotExist
+ if old_doc.rev != doc.rev:
+ raise errors.RevisionConflict()
+ if old_doc.is_tombstone():
+ raise errors.DocumentAlreadyDeleted
+ if old_doc.has_conflicts:
+ raise errors.ConflictedDoc()
+ new_rev = self._allocate_doc_rev(doc.rev)
+ doc.rev = new_rev
+ doc.make_tombstone()
+ self._put_and_update_indexes(old_doc, doc)
+ return new_rev
+
+ # index-related methods
+
+ def create_index(self, index_name, *index_expressions):
+ raise NotImplementedError(self.create_index)
+
+ def delete_index(self, index_name):
+ super(ObjectStore, self).delete_index(index_name)
+ self._set_u1db_data()
+
+ def _replace_conflicts(self, doc, conflicts):
+ super(ObjectStore, self)._replace_conflicts(doc, conflicts)
+ self._set_u1db_data()
+
+ def _do_set_replica_gen_and_trans_id(self, other_replica_uid,
+ other_generation,
+ other_transaction_id):
+ super(ObjectStore, self)._do_set_replica_gen_and_trans_id(
+ other_replica_uid,
+ other_generation,
+ other_transaction_id)
+ self._set_u1db_data()
+
+ #-------------------------------------------------------------------------
+ # implemented methods from CommonBackend
+ #-------------------------------------------------------------------------
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ for index in self._indexes.itervalues():
+ if old_doc is not None and not old_doc.is_tombstone():
+ index.remove_json(old_doc.doc_id, old_doc.get_json())
+ if not doc.is_tombstone():
+ index.add_json(doc.doc_id, doc.get_json())
+ trans_id = self._allocate_transaction_id()
+ self._put_doc(doc)
+ self._transaction_log.append((doc.doc_id, trans_id))
+ self._set_u1db_data()
+
+ #-------------------------------------------------------------------------
+ # methods specific for object stores
+ #-------------------------------------------------------------------------
+
+ U1DB_DATA_DOC_ID = 'u1db_data'
+
+ def _get_u1db_data(self):
+ """
+ Fetch u1db configuration data from backend storage.
+ """
+ NotImplementedError(self._get_u1db_data)
+
+ def _set_u1db_data(self):
+ """
+ Save u1db configuration data on backend storage.
+ """
+ NotImplementedError(self._set_u1db_data)
+
+ def _init_u1db_data(self):
+ """
+ Initialize u1db configuration data on backend storage.
+ """
+ NotImplementedError(self._init_u1db_data)
diff --git a/src/leap/soledad/backends/openstack.py b/src/leap/soledad/backends/openstack.py
new file mode 100644
index 00000000..a9615736
--- /dev/null
+++ b/src/leap/soledad/backends/openstack.py
@@ -0,0 +1,98 @@
+# TODO: this backend is not tested yet.
+from u1db.remote.http_target import HTTPSyncTarget
+import swiftclient
+from soledad.backends.objectstore import ObjectStore
+
+
+class OpenStackDatabase(ObjectStore):
+ """A U1DB implementation that uses OpenStack as its persistence layer."""
+
+ def __init__(self, auth_url, user, auth_key, container):
+ """Create a new OpenStack data container."""
+ self._auth_url = auth_url
+ self._user = user
+ self._auth_key = auth_key
+ self._container = container
+ self._connection = swiftclient.Connection(self._auth_url, self._user,
+ self._auth_key)
+ self._get_auth()
+ # this will ensure transaction and sync logs exist and are up-to-date.
+ super(OpenStackDatabase, self).__init__()
+
+ #-------------------------------------------------------------------------
+ # implemented methods from Database
+ #-------------------------------------------------------------------------
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ """Get just the document content, without fancy handling.
+
+ Conflicts do not happen on server side, so there's no need to check
+ for them.
+ """
+ try:
+ response, contents = self._connection.get_object(self._container,
+ doc_id)
+ # TODO: change revision to be a dictionary element?
+ rev = response['x-object-meta-rev']
+ return self._factory(doc_id, rev, contents)
+ except swiftclient.ClientException:
+ return None
+
+ def get_all_docs(self, include_deleted=False):
+ """Get all documents from the database."""
+ generation = self._get_generation()
+ results = []
+ _, doc_ids = self._connection.get_container(self._container,
+ full_listing=True)
+ for doc_id in doc_ids:
+ doc = self._get_doc(doc_id)
+ if doc.content is None and not include_deleted:
+ continue
+ results.append(doc)
+ return (generation, results)
+
+ def _put_doc(self, doc, new_rev):
+ new_rev = self._allocate_doc_rev(doc.rev)
+ # TODO: change revision to be a dictionary element?
+ headers = {'X-Object-Meta-Rev': new_rev}
+ self._connection.put_object(self._container, doc_id, doc.get_json(),
+ headers=headers)
+
+ def get_sync_target(self):
+ return OpenStackSyncTarget(self)
+
+ def close(self):
+ raise NotImplementedError(self.close)
+
+ def sync(self, url, creds=None, autocreate=True):
+ from u1db.sync import Synchronizer
+ from u1db.remote.http_target import OpenStackSyncTarget
+ return Synchronizer(self, OpenStackSyncTarget(url, creds=creds)).sync(
+ autocreate=autocreate)
+
+ #-------------------------------------------------------------------------
+ # OpenStack specific methods
+ #-------------------------------------------------------------------------
+
+ def _get_auth(self):
+ self._url, self._auth_token = self._connection.get_auth()
+ return self._url, self.auth_token
+
+
+class OpenStackSyncTarget(HTTPSyncTarget):
+
+ def get_sync_info(self, source_replica_uid):
+ source_gen, source_trans_id = self._db._get_replica_gen_and_trans_id(
+ source_replica_uid)
+ my_gen, my_trans_id = self._db._get_generation_info()
+ return (
+ self._db._replica_uid, my_gen, my_trans_id, source_gen,
+ source_trans_id)
+
+ def record_sync_info(self, source_replica_uid, source_replica_generation,
+ source_replica_transaction_id):
+ if self._trace_hook:
+ self._trace_hook('record_sync_info')
+ self._db._set_replica_gen_and_trans_id(
+ source_replica_uid, source_replica_generation,
+ source_replica_transaction_id)
diff --git a/src/leap/soledad/backends/sqlcipher.py b/src/leap/soledad/backends/sqlcipher.py
new file mode 100644
index 00000000..6cebcf7d
--- /dev/null
+++ b/src/leap/soledad/backends/sqlcipher.py
@@ -0,0 +1,159 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""A U1DB implementation that uses SQLCipher as its persistence layer."""
+
+import os
+from sqlite3 import dbapi2, DatabaseError
+import time
+
+from u1db.backends.sqlite_backend import (
+ SQLiteDatabase,
+ SQLitePartialExpandDatabase,
+)
+from u1db import (
+ errors,
+)
+
+from leap.soledad.backends.leap_backend import LeapDocument
+
+
+def open(path, password, create=True, document_factory=None):
+ """Open a database at the given location.
+
+ Will raise u1db.errors.DatabaseDoesNotExist if create=False and the
+ database does not already exist.
+
+ :param path: The filesystem path for the database to open.
+ :param create: True/False, should the database be created if it doesn't
+ already exist?
+ :param document_factory: A function that will be called with the same
+ parameters as Document.__init__.
+ :return: An instance of Database.
+ """
+ return SQLCipherDatabase.open_database(
+ path, password, create=create, document_factory=document_factory)
+
+
+class DatabaseIsNotEncrypted(Exception):
+ """
+ Exception raised when trying to open non-encrypted databases.
+ """
+ pass
+
+
+class SQLCipherDatabase(SQLitePartialExpandDatabase):
+ """A U1DB implementation that uses SQLCipher as its persistence layer."""
+
+ _index_storage_value = 'expand referenced encrypted'
+
+ @classmethod
+ def set_pragma_key(cls, db_handle, key):
+ db_handle.cursor().execute("PRAGMA key = '%s'" % key)
+
+ def __init__(self, sqlite_file, password, document_factory=None):
+ """Create a new sqlcipher file."""
+ self._check_if_db_is_encrypted(sqlite_file)
+ self._db_handle = dbapi2.connect(sqlite_file)
+ SQLCipherDatabase.set_pragma_key(self._db_handle, password)
+ self._real_replica_uid = None
+ self._ensure_schema()
+ self._factory = document_factory or LeapDocument
+
+ def _check_if_db_is_encrypted(self, sqlite_file):
+ if not os.path.exists(sqlite_file):
+ return
+ else:
+ try:
+ # try to open an encrypted database with the regular u1db
+ # backend should raise a DatabaseError exception.
+ SQLitePartialExpandDatabase(sqlite_file)
+ raise DatabaseIsNotEncrypted()
+ except DatabaseError:
+ pass
+
+ @classmethod
+ def _open_database(cls, sqlite_file, password, document_factory=None):
+ if not os.path.isfile(sqlite_file):
+ raise errors.DatabaseDoesNotExist()
+ tries = 2
+ while True:
+ # Note: There seems to be a bug in sqlite 3.5.9 (with python2.6)
+ # where without re-opening the database on Windows, it
+ # doesn't see the transaction that was just committed
+ db_handle = dbapi2.connect(sqlite_file)
+ SQLCipherDatabase.set_pragma_key(db_handle, password)
+ c = db_handle.cursor()
+ v, err = cls._which_index_storage(c)
+ db_handle.close()
+ if v is not None:
+ break
+ # possibly another process is initializing it, wait for it to be
+ # done
+ if tries == 0:
+ raise err # go for the richest error?
+ tries -= 1
+ time.sleep(cls.WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL)
+ return SQLCipherDatabase._sqlite_registry[v](
+ sqlite_file, password, document_factory=document_factory)
+
+ @classmethod
+ def open_database(cls, sqlite_file, password, create, backend_cls=None,
+ document_factory=None):
+ try:
+ return cls._open_database(sqlite_file, password,
+ document_factory=document_factory)
+ except errors.DatabaseDoesNotExist:
+ if not create:
+ raise
+ if backend_cls is None:
+ # default is SQLCipherPartialExpandDatabase
+ backend_cls = SQLCipherDatabase
+ return backend_cls(sqlite_file, password,
+ document_factory=document_factory)
+
+ def sync(self, url, creds=None, autocreate=True, soledad=None):
+ """
+ Synchronize encrypted documents with remote replica exposed at url.
+ """
+ from u1db.sync import Synchronizer
+ from leap.soledad.backends.leap_backend import LeapSyncTarget
+ return Synchronizer(self, LeapSyncTarget(url, creds=creds),
+ soledad=self._soledad).sync(autocreate=autocreate)
+
+ def _extra_schema_init(self, c):
+ c.execute(
+ 'ALTER TABLE document '
+ 'ADD COLUMN syncable BOOL NOT NULL DEFAULT TRUE')
+
+ def _put_and_update_indexes(self, old_doc, doc):
+ super(SQLCipherDatabase, self)._put_and_update_indexes(old_doc, doc)
+ c = self._db_handle.cursor()
+ c.execute('UPDATE document SET syncable=? WHERE doc_id=?',
+ (doc.syncable, doc.doc_id))
+
+ def _get_doc(self, doc_id, check_for_conflicts=False):
+ doc = super(SQLCipherDatabase, self)._get_doc(doc_id,
+ check_for_conflicts)
+ if doc:
+ c = self._db_handle.cursor()
+ c.execute('SELECT syncable FROM document WHERE doc_id=?',
+ (doc.doc_id,))
+ doc.syncable = bool(c.fetchone()[0])
+ return doc
+
+
+SQLiteDatabase.register_implementation(SQLCipherDatabase)
diff --git a/src/leap/soledad/tests/__init__.py b/src/leap/soledad/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/soledad/tests/__init__.py
diff --git a/src/leap/soledad/tests/test_couch.py b/src/leap/soledad/tests/test_couch.py
new file mode 100644
index 00000000..6c3d7daf
--- /dev/null
+++ b/src/leap/soledad/tests/test_couch.py
@@ -0,0 +1,201 @@
+"""Test ObjectStore backend bits.
+
+For these tests to run, a couch server has to be running on (default) port
+5984.
+"""
+
+import copy
+from leap.soledad.backends import couch
+from leap.soledad.tests import u1db_tests as tests
+from leap.soledad.tests.u1db_tests import test_backends
+from leap.soledad.tests.u1db_tests import test_sync
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_common_backend`.
+#-----------------------------------------------------------------------------
+
+class TestCouchBackendImpl(tests.TestCase):
+
+ def test__allocate_doc_id(self):
+ db = couch.CouchDatabase('http://localhost:5984', 'u1db_tests')
+ doc_id1 = db._allocate_doc_id()
+ self.assertTrue(doc_id1.startswith('D-'))
+ self.assertEqual(34, len(doc_id1))
+ int(doc_id1[len('D-'):], 16)
+ self.assertNotEqual(doc_id1, db._allocate_doc_id())
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_backends`.
+#-----------------------------------------------------------------------------
+
+def make_couch_database_for_test(test, replica_uid):
+ return couch.CouchDatabase('http://localhost:5984', replica_uid,
+ replica_uid=replica_uid or 'test')
+
+
+def copy_couch_database_for_test(test, db):
+ new_db = couch.CouchDatabase('http://localhost:5984',
+ db._replica_uid + '_copy',
+ replica_uid=db._replica_uid or 'test')
+ gen, docs = db.get_all_docs(include_deleted=True)
+ for doc in docs:
+ new_db._put_doc(doc)
+ new_db._transaction_log = copy.deepcopy(db._transaction_log)
+ new_db._conflicts = copy.deepcopy(db._conflicts)
+ new_db._other_generations = copy.deepcopy(db._other_generations)
+ new_db._indexes = copy.deepcopy(db._indexes)
+ new_db._set_u1db_data()
+ return new_db
+
+
+COUCH_SCENARIOS = [
+ ('couch', {'make_database_for_test': make_couch_database_for_test,
+ 'copy_database_for_test': copy_couch_database_for_test,
+ 'make_document_for_test': tests.make_document_for_test, }),
+]
+
+
+class CouchTests(test_backends.AllDatabaseTests):
+
+ scenarios = COUCH_SCENARIOS
+
+ def tearDown(self):
+ self.db.delete_database()
+ super(CouchTests, self).tearDown()
+
+
+class CouchDatabaseTests(test_backends.LocalDatabaseTests):
+
+ scenarios = COUCH_SCENARIOS
+
+ def tearDown(self):
+ self.db.delete_database()
+ super(CouchDatabaseTests, self).tearDown()
+
+
+class CouchValidateGenNTransIdTests(
+ test_backends.LocalDatabaseValidateGenNTransIdTests):
+
+ scenarios = COUCH_SCENARIOS
+
+ def tearDown(self):
+ self.db.delete_database()
+ super(CouchValidateGenNTransIdTests, self).tearDown()
+
+
+class CouchValidateSourceGenTests(
+ test_backends.LocalDatabaseValidateSourceGenTests):
+
+ scenarios = COUCH_SCENARIOS
+
+ def tearDown(self):
+ self.db.delete_database()
+ super(CouchValidateSourceGenTests, self).tearDown()
+
+
+class CouchWithConflictsTests(
+ test_backends.LocalDatabaseWithConflictsTests):
+
+ scenarios = COUCH_SCENARIOS
+
+ def tearDown(self):
+ self.db.delete_database()
+ super(CouchWithConflictsTests, self).tearDown()
+
+
+# Notice: the CouchDB backend is currently used for storing encrypted data in
+# the server, so indexing makes no sense. Thus, we ignore index testing for
+# now.
+
+class CouchIndexTests(test_backends.DatabaseIndexTests):
+
+ scenarios = COUCH_SCENARIOS
+
+ def tearDown(self):
+ self.db.delete_database()
+ super(CouchIndexTests, self).tearDown()
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_sync`.
+#-----------------------------------------------------------------------------
+
+target_scenarios = [
+ ('local', {'create_db_and_target': test_sync._make_local_db_and_target}), ]
+
+
+simple_doc = tests.simple_doc
+nested_doc = tests.nested_doc
+
+
+class CouchDatabaseSyncTargetTests(test_sync.DatabaseSyncTargetTests):
+
+ scenarios = (tests.multiply_scenarios(COUCH_SCENARIOS, target_scenarios))
+
+ def tearDown(self):
+ self.db.delete_database()
+ super(CouchDatabaseSyncTargetTests, self).tearDown()
+
+ def test_sync_exchange_returns_many_new_docs(self):
+ # This test was replicated to allow dictionaries to be compared after
+ # JSON expansion (because one dictionary may have many different
+ # serialized representations).
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db)
+ new_gen, _ = self.st.sync_exchange(
+ [], 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db)
+ self.assertEqual(2, new_gen)
+ self.assertEqual(
+ [(doc.doc_id, doc.rev, json.loads(simple_doc), 1),
+ (doc2.doc_id, doc2.rev, json.loads(nested_doc), 2)],
+ [c[:-3] + (json.loads(c[-3]), c[-2]) for c in self.other_changes])
+ if self.whitebox:
+ self.assertEqual(
+ self.db._last_exchange_log['return'],
+ {'last_gen': 2, 'docs':
+ [(doc.doc_id, doc.rev), (doc2.doc_id, doc2.rev)]})
+
+
+sync_scenarios = []
+for name, scenario in COUCH_SCENARIOS:
+ scenario = dict(scenario)
+ scenario['do_sync'] = test_sync.sync_via_synchronizer
+ sync_scenarios.append((name, scenario))
+ scenario = dict(scenario)
+
+
+class CouchDatabaseSyncTests(test_sync.DatabaseSyncTests):
+
+ scenarios = sync_scenarios
+
+ def setUp(self):
+ self.db = None
+ self.db1 = None
+ self.db2 = None
+ self.db3 = None
+ super(CouchDatabaseSyncTests, self).setUp()
+
+ def tearDown(self):
+ self.db and self.db.delete_database()
+ self.db1 and self.db1.delete_database()
+ self.db2 and self.db2.delete_database()
+ self.db3 and self.db3.delete_database()
+ db = self.create_database('test1_copy', 'source')
+ db.delete_database()
+ db = self.create_database('test2_copy', 'target')
+ db.delete_database()
+ db = self.create_database('test3', 'target')
+ db.delete_database()
+ super(CouchDatabaseSyncTests, self).tearDown()
+
+
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/tests/test_encrypted.py b/src/leap/soledad/tests/test_encrypted.py
new file mode 100644
index 00000000..af5f0fa4
--- /dev/null
+++ b/src/leap/soledad/tests/test_encrypted.py
@@ -0,0 +1,205 @@
+import unittest2 as unittest
+import os
+
+import u1db
+from leap.soledad import Soledad
+from leap.soledad.backends.leap_backend import LeapDocument
+
+
+class EncryptedSyncTestCase(unittest.TestCase):
+
+ PREFIX = "/var/tmp"
+ GNUPG_HOME = "%s/gnupg" % PREFIX
+ DB1_FILE = "%s/db1.u1db" % PREFIX
+ DB2_FILE = "%s/db2.u1db" % PREFIX
+ EMAIL = 'leap@leap.se'
+
+ def setUp(self):
+ self.db1 = u1db.open(self.DB1_FILE, create=True,
+ document_factory=LeapDocument)
+ self.db2 = u1db.open(self.DB2_FILE, create=True,
+ document_factory=LeapDocument)
+ self.soledad = Soledad(self.EMAIL, gpghome=self.GNUPG_HOME)
+ self.soledad._gpg.import_keys(PUBLIC_KEY)
+ self.soledad._gpg.import_keys(PRIVATE_KEY)
+
+ def tearDown(self):
+ os.unlink(self.DB1_FILE)
+ os.unlink(self.DB2_FILE)
+
+ def test_get_set_encrypted(self):
+ doc1 = LeapDocument(soledad=self.soledad)
+ doc1.content = {'key': 'val'}
+ doc2 = LeapDocument(doc_id=doc1.doc_id,
+ encrypted_json=doc1.get_encrypted_json(),
+ soledad=self.soledad)
+ res1 = doc1.get_json()
+ res2 = doc2.get_json()
+ self.assertEqual(res1, res2, 'incorrect document encryption')
+
+
+# Key material for testing
+KEY_FINGERPRINT = "E36E738D69173C13D709E44F2F455E2824D18DDF"
+PUBLIC_KEY = """
+-----BEGIN PGP PUBLIC KEY BLOCK-----
+Version: GnuPG v1.4.10 (GNU/Linux)
+
+mQINBFC9+dkBEADNRfwV23TWEoGc/x0wWH1P7PlXt8MnC2Z1kKaKKmfnglVrpOiz
+iLWoiU58sfZ0L5vHkzXHXCBf6Eiy/EtUIvdiWAn+yASJ1mk5jZTBKO/WMAHD8wTO
+zpMsFmWyg3xc4DkmFa9KQ5EVU0o/nqPeyQxNMQN7px5pPwrJtJFmPxnxm+aDkPYx
+irDmz/4DeDNqXliazGJKw7efqBdlwTHkl9Akw2gwy178pmsKwHHEMOBOFFvX61AT
+huKqHYmlCGSliwbrJppTG7jc1/ls3itrK+CWTg4txREkSpEVmfcASvw/ZqLbjgfs
+d/INMwXnR9U81O8+7LT6yw/ca4ppcFoJD7/XJbkRiML6+bJ4Dakiy6i727BzV17g
+wI1zqNvm5rAhtALKfACha6YO43aJzairO4II1wxVHvRDHZn2IuKDDephQ3Ii7/vb
+hUOf6XCSmchkAcpKXUOvbxm1yfB1LRa64mMc2RcZxf4mW7KQkulBsdV5QG2276lv
+U2UUy2IutXcGP5nXC+f6sJJGJeEToKJ57yiO/VWJFjKN8SvP+7AYsQSqINUuEf6H
+T5gCPCraGMkTUTPXrREvu7NOohU78q6zZNaL3GW8ai7eSeANSuQ8Vzffx7Wd8Y7i
+Pw9sYj0SMFs1UgjbuL6pO5ueHh+qyumbtAq2K0Bci0kqOcU4E9fNtdiovQARAQAB
+tBxMZWFwIFRlc3QgS2V5IDxsZWFwQGxlYXAuc2U+iQI3BBMBCAAhBQJQvfnZAhsD
+BQsJCAcDBRUKCQgLBRYCAwEAAh4BAheAAAoJEC9FXigk0Y3fT7EQAKH3IuRniOpb
+T/DDIgwwjz3oxB/W0DDMyPXowlhSOuM0rgGfntBpBb3boezEXwL86NPQxNGGruF5
+hkmecSiuPSvOmQlqlS95NGQp6hNG0YaKColh+Q5NTspFXCAkFch9oqUje0LdxfSP
+QfV9UpeEvGyPmk1I9EJV/YDmZ4+Djge1d7qhVZInz4Rx1NrSyF/Tc2EC0VpjQFsU
+Y9Kb2YBBR7ivG6DBc8ty0jJXi7B4WjkFcUEJviQpMF2dCLdonCehYs1PqsN1N7j+
+eFjQd+hqVMJgYuSGKjvuAEfClM6MQw7+FmFwMyLgK/Ew/DttHEDCri77SPSkOGSI
+txCzhTg6798f6mJr7WcXmHX1w1Vcib5FfZ8vTDFVhz/XgAgArdhPo9V6/1dgSSiB
+KPQ/spsco6u5imdOhckERE0lnAYvVT6KE81TKuhF/b23u7x+Wdew6kK0EQhYA7wy
+7LmlaNXc7rMBQJ9Z60CJ4JDtatBWZ0kNrt2VfdDHVdqBTOpl0CraNUjWE5YMDasr
+K2dF5IX8D3uuYtpZnxqg0KzyLg0tzL0tvOL1C2iudgZUISZNPKbS0z0v+afuAAnx
+2pTC3uezbh2Jt8SWTLhll4i0P4Ps5kZ6HQUO56O+/Z1cWovX+mQekYFmERySDR9n
+3k1uAwLilJmRmepGmvYbB8HloV8HqwgguQINBFC9+dkBEAC0I/xn1uborMgDvBtf
+H0sEhwnXBC849/32zic6udB6/3Efk9nzbSpL3FSOuXITZsZgCHPkKarnoQ2ztMcS
+sh1ke1C5gQGms75UVmM/nS+2YI4vY8OX/GC/on2vUyncqdH+bR6xH5hx4NbWpfTs
+iQHmz5C6zzS/kuabGdZyKRaZHt23WQ7JX/4zpjqbC99DjHcP9BSk7tJ8wI4bkMYD
+uFVQdT9O6HwyKGYwUU4sAQRAj7XCTGvVbT0dpgJwH4RmrEtJoHAx4Whg8mJ710E0
+GCmzf2jqkNuOw76ivgk27Kge+Hw00jmJjQhHY0yVbiaoJwcRrPKzaSjEVNgrpgP3
+lXPRGQArgESsIOTeVVHQ8fhK2YtTeCY9rIiO+L0OX2xo9HK7hfHZZWL6rqymXdyS
+fhzh/f6IPyHFWnvj7Brl7DR8heMikygcJqv+ed2yx7iLyCUJ10g12I48+aEj1aLe
+dP7lna32iY8/Z0SHQLNH6PXO9SlPcq2aFUgKqE75A/0FMk7CunzU1OWr2ZtTLNO1
+WT/13LfOhhuEq9jTyTosn0WxBjJKq18lnhzCXlaw6EAtbA7CUwsD3CTPR56aAXFK
+3I7KXOVAqggrvMe5Tpdg5drfYpI8hZovL5aAgb+7Y5ta10TcJdUhS5K3kFAWe/td
+U0cmWUMDP1UMSQ5Jg6JIQVWhSwARAQABiQIfBBgBCAAJBQJQvfnZAhsMAAoJEC9F
+Xigk0Y3fRwsP/i0ElYCyxeLpWJTwo1iCLkMKz2yX1lFVa9nT1BVTPOQwr/IAc5OX
+NdtbJ14fUsKL5pWgW8OmrXtwZm1y4euI1RPWWubG01ouzwnGzv26UcuHeqC5orZj
+cOnKtL40y8VGMm8LoicVkRJH8blPORCnaLjdOtmA3rx/v2EXrJpSa3AhOy0ZSRXk
+ZSrK68AVNwamHRoBSYyo0AtaXnkPX4+tmO8X8BPfj125IljubvwZPIW9VWR9UqCE
+VPfDR1XKegVb6VStIywF7kmrknM1C5qUY28rdZYWgKorw01hBGV4jTW0cqde3N51
+XT1jnIAa+NoXUM9uQoGYMiwrL7vNsLlyyiW5ayDyV92H/rIuiqhFgbJsHTlsm7I8
+oGheR784BagAA1NIKD1qEO9T6Kz9lzlDaeWS5AUKeXrb7ZJLI1TTCIZx5/DxjLqM
+Tt/RFBpVo9geZQrvLUqLAMwdaUvDXC2c6DaCPXTh65oCZj/hqzlJHH+RoTWWzKI+
+BjXxgUWF9EmZUBrg68DSmI+9wuDFsjZ51BcqvJwxyfxtTaWhdoYqH/UQS+D1FP3/
+diZHHlzwVwPICzM9ooNTgbrcDzyxRkIVqsVwBq7EtzcvgYUyX53yG25Giy6YQaQ2
+ZtQ/VymwFL3XdUWV6B/hU4PVAFvO3qlOtdJ6TpE+nEWgcWjCv5g7RjXX
+=MuOY
+-----END PGP PUBLIC KEY BLOCK-----
+"""
+PRIVATE_KEY = """
+-----BEGIN PGP PRIVATE KEY BLOCK-----
+Version: GnuPG v1.4.10 (GNU/Linux)
+
+lQcYBFC9+dkBEADNRfwV23TWEoGc/x0wWH1P7PlXt8MnC2Z1kKaKKmfnglVrpOiz
+iLWoiU58sfZ0L5vHkzXHXCBf6Eiy/EtUIvdiWAn+yASJ1mk5jZTBKO/WMAHD8wTO
+zpMsFmWyg3xc4DkmFa9KQ5EVU0o/nqPeyQxNMQN7px5pPwrJtJFmPxnxm+aDkPYx
+irDmz/4DeDNqXliazGJKw7efqBdlwTHkl9Akw2gwy178pmsKwHHEMOBOFFvX61AT
+huKqHYmlCGSliwbrJppTG7jc1/ls3itrK+CWTg4txREkSpEVmfcASvw/ZqLbjgfs
+d/INMwXnR9U81O8+7LT6yw/ca4ppcFoJD7/XJbkRiML6+bJ4Dakiy6i727BzV17g
+wI1zqNvm5rAhtALKfACha6YO43aJzairO4II1wxVHvRDHZn2IuKDDephQ3Ii7/vb
+hUOf6XCSmchkAcpKXUOvbxm1yfB1LRa64mMc2RcZxf4mW7KQkulBsdV5QG2276lv
+U2UUy2IutXcGP5nXC+f6sJJGJeEToKJ57yiO/VWJFjKN8SvP+7AYsQSqINUuEf6H
+T5gCPCraGMkTUTPXrREvu7NOohU78q6zZNaL3GW8ai7eSeANSuQ8Vzffx7Wd8Y7i
+Pw9sYj0SMFs1UgjbuL6pO5ueHh+qyumbtAq2K0Bci0kqOcU4E9fNtdiovQARAQAB
+AA/+JHtlL39G1wsH9R6UEfUQJGXR9MiIiwZoKcnRB2o8+DS+OLjg0JOh8XehtuCs
+E/8oGQKtQqa5bEIstX7IZoYmYFiUQi9LOzIblmp2vxOm+HKkxa4JszWci2/ZmC3t
+KtaA4adl9XVnshoQ7pijuCMUKB3naBEOAxd8s9d/JeReGIYkJErdrnVfNk5N71Ds
+FmH5Ll3XtEDvgBUQP3nkA6QFjpsaB94FHjL3gDwum/cxzj6pCglcvHOzEhfY0Ddb
+J967FozQTaf2JW3O+w3LOqtcKWpq87B7+O61tVidQPSSuzPjCtFF0D2LC9R/Hpky
+KTMQ6CaKja4MPhjwywd4QPcHGYSqjMpflvJqi+kYIt8psUK/YswWjnr3r4fbuqVY
+VhtiHvnBHQjz135lUqWvEz4hM3Xpnxydx7aRlv5NlevK8+YIO5oFbWbGNTWsPZI5
+jpoFBpSsnR1Q5tnvtNHauvoWV+XN2qAOBTG+/nEbDYH6Ak3aaE9jrpTdYh0CotYF
+q7csANsDy3JvkAzeU6WnYpsHHaAjqOGyiZGsLej1UcXPFMosE/aUo4WQhiS8Zx2c
+zOVKOi/X5vQ2GdNT9Qolz8AriwzsvFR+bxPzyd8V6ALwDsoXvwEYinYBKK8j0OPv
+OOihSR6HVsuP9NUZNU9ewiGzte/+/r6pNXHvR7wTQ8EWLcEIAN6Zyrb0bHZTIlxt
+VWur/Ht2mIZrBaO50qmM5RD3T5oXzWXi/pjLrIpBMfeZR9DWfwQwjYzwqi7pxtYx
+nJvbMuY505rfnMoYxb4J+cpRXV8MS7Dr1vjjLVUC9KiwSbM3gg6emfd2yuA93ihv
+Pe3mffzLIiQa4mRE3wtGcioC43nWuV2K2e1KjxeFg07JhrezA/1Cak505ab/tmvP
+4YmjR5c44+yL/YcQ3HdFgs4mV+nVbptRXvRcPpolJsgxPccGNdvHhsoR4gwXMS3F
+RRPD2z6x8xeN73Q4KH3bm01swQdwFBZbWVfmUGLxvN7leCdfs9+iFJyqHiCIB6Iv
+mQfp8F0IAOwSo8JhWN+V1dwML4EkIrM8wUb4yecNLkyR6TpPH/qXx4PxVMC+vy6x
+sCtjeHIwKE+9vqnlhd5zOYh7qYXEJtYwdeDDmDbL8oks1LFfd+FyAuZXY33DLwn0
+cRYsr2OEZmaajqUB3NVmj3H4uJBN9+paFHyFSXrH68K1Fk2o3n+RSf2EiX+eICwI
+L6rqoF5sSVUghBWdNegV7qfy4anwTQwrIMGjgU5S6PKW0Dr/3iO5z3qQpGPAj5OW
+ATqPWkDICLbObPxD5cJlyyNE2wCA9VVc6/1d6w4EVwSq9h3/WTpATEreXXxTGptd
+LNiTA1nmakBYNO2Iyo3djhaqBdWjk+EIAKtVEnJH9FAVwWOvaj1RoZMA5DnDMo7e
+SnhrCXl8AL7Z1WInEaybasTJXn1uQ8xY52Ua4b8cbuEKRKzw/70NesFRoMLYoHTO
+dyeszvhoDHberpGRTciVmpMu7Hyi33rM31K9epA4ib6QbbCHnxkWOZB+Bhgj1hJ8
+xb4RBYWiWpAYcg0+DAC3w9gfxQhtUlZPIbmbrBmrVkO2GVGUj8kH6k4UV6kUHEGY
+HQWQR0HcbKcXW81ZXCCD0l7ROuEWQtTe5Jw7dJ4/QFuqZnPutXVRNOZqpl6eRShw
+7X2/a29VXBpmHA95a88rSQsL+qm7Fb3prqRmuMCtrUZgFz7HLSTuUMR867QcTGVh
+cCBUZXN0IEtleSA8bGVhcEBsZWFwLnNlPokCNwQTAQgAIQUCUL352QIbAwULCQgH
+AwUVCgkICwUWAgMBAAIeAQIXgAAKCRAvRV4oJNGN30+xEACh9yLkZ4jqW0/wwyIM
+MI896MQf1tAwzMj16MJYUjrjNK4Bn57QaQW926HsxF8C/OjT0MTRhq7heYZJnnEo
+rj0rzpkJapUveTRkKeoTRtGGigqJYfkOTU7KRVwgJBXIfaKlI3tC3cX0j0H1fVKX
+hLxsj5pNSPRCVf2A5mePg44HtXe6oVWSJ8+EcdTa0shf03NhAtFaY0BbFGPSm9mA
+QUe4rxugwXPLctIyV4uweFo5BXFBCb4kKTBdnQi3aJwnoWLNT6rDdTe4/nhY0Hfo
+alTCYGLkhio77gBHwpTOjEMO/hZhcDMi4CvxMPw7bRxAwq4u+0j0pDhkiLcQs4U4
+Ou/fH+pia+1nF5h19cNVXIm+RX2fL0wxVYc/14AIAK3YT6PVev9XYEkogSj0P7Kb
+HKOruYpnToXJBERNJZwGL1U+ihPNUyroRf29t7u8flnXsOpCtBEIWAO8Muy5pWjV
+3O6zAUCfWetAieCQ7WrQVmdJDa7dlX3Qx1XagUzqZdAq2jVI1hOWDA2rKytnReSF
+/A97rmLaWZ8aoNCs8i4NLcy9Lbzi9QtornYGVCEmTTym0tM9L/mn7gAJ8dqUwt7n
+s24dibfElky4ZZeItD+D7OZGeh0FDuejvv2dXFqL1/pkHpGBZhEckg0fZ95NbgMC
+4pSZkZnqRpr2GwfB5aFfB6sIIJ0HGARQvfnZARAAtCP8Z9bm6KzIA7wbXx9LBIcJ
+1wQvOPf99s4nOrnQev9xH5PZ820qS9xUjrlyE2bGYAhz5Cmq56ENs7THErIdZHtQ
+uYEBprO+VFZjP50vtmCOL2PDl/xgv6J9r1Mp3KnR/m0esR+YceDW1qX07IkB5s+Q
+us80v5LmmxnWcikWmR7dt1kOyV/+M6Y6mwvfQ4x3D/QUpO7SfMCOG5DGA7hVUHU/
+Tuh8MihmMFFOLAEEQI+1wkxr1W09HaYCcB+EZqxLSaBwMeFoYPJie9dBNBgps39o
+6pDbjsO+or4JNuyoHvh8NNI5iY0IR2NMlW4mqCcHEazys2koxFTYK6YD95Vz0RkA
+K4BErCDk3lVR0PH4StmLU3gmPayIjvi9Dl9saPRyu4Xx2WVi+q6spl3ckn4c4f3+
+iD8hxVp74+wa5ew0fIXjIpMoHCar/nndsse4i8glCddINdiOPPmhI9Wi3nT+5Z2t
+9omPP2dEh0CzR+j1zvUpT3KtmhVICqhO+QP9BTJOwrp81NTlq9mbUyzTtVk/9dy3
+zoYbhKvY08k6LJ9FsQYySqtfJZ4cwl5WsOhALWwOwlMLA9wkz0eemgFxStyOylzl
+QKoIK7zHuU6XYOXa32KSPIWaLy+WgIG/u2ObWtdE3CXVIUuSt5BQFnv7XVNHJllD
+Az9VDEkOSYOiSEFVoUsAEQEAAQAP/1AagnZQZyzHDEgw4QELAspYHCWLXE5aZInX
+wTUJhK31IgIXNn9bJ0hFiSpQR2xeMs9oYtRuPOu0P8oOFMn4/z374fkjZy8QVY3e
+PlL+3EUeqYtkMwlGNmVw5a/NbNuNfm5Darb7pEfbYd1gPcni4MAYw7R2SG/57GbC
+9gucvspHIfOSfBNLBthDzmK8xEKe1yD2eimfc2T7IRYb6hmkYfeds5GsqvGI6mwI
+85h4uUHWRc5JOlhVM6yX8hSWx0L60Z3DZLChmc8maWnFXd7C8eQ6P1azJJbW71Ih
+7CoK0XW4LE82vlQurSRFgTwfl7wFYszW2bOzCuhHDDtYnwH86Nsu0DC78ZVRnvxn
+E8Ke/AJgrdhIOo4UAyR+aZD2+2mKd7/waOUTUrUtTzc7i8N3YXGi/EIaNReBXaq+
+ZNOp24BlFzRp+FCF/pptDW9HjPdiV09x0DgICmeZS4Gq/4vFFIahWctg52NGebT0
+Idxngjj+xDtLaZlLQoOz0n5ByjO/Wi0ANmMv1sMKCHhGvdaSws2/PbMR2r4caj8m
+KXpIgdinM/wUzHJ5pZyF2U/qejsRj8Kw8KH/tfX4JCLhiaP/mgeTuWGDHeZQERAT
+xPmRFHaLP9/ZhvGNh6okIYtrKjWTLGoXvKLHcrKNisBLSq+P2WeFrlme1vjvJMo/
+jPwLT5o9CADQmcbKZ+QQ1ZM9v99iDZol7SAMZX43JC019sx6GK0u6xouJBcLfeB4
+OXacTgmSYdTa9RM9fbfVpti01tJ84LV2SyL/VJq/enJF4XQPSynT/tFTn1PAor6o
+tEAAd8fjKdJ6LnD5wb92SPHfQfXqI84rFEO8rUNIE/1ErT6DYifDzVCbfD2KZdoF
+cOSp7TpD77sY1bs74ocBX5ejKtd+aH99D78bJSMM4pSDZsIEwnomkBHTziubPwJb
+OwnATy0LmSMAWOw5rKbsh5nfwCiUTM20xp0t5JeXd+wPVWbpWqI2EnkCEN+RJr9i
+7dp/ymDQ+Yt5wrsN3NwoyiexPOG91WQVCADdErHsnglVZZq9Z8Wx7KwecGCUurJ2
+H6lKudv5YOxPnAzqZS5HbpZd/nRTMZh2rdXCr5m2YOuewyYjvM757AkmUpM09zJX
+MQ1S67/UX2y8/74TcRF97Ncx9HeELs92innBRXoFitnNguvcO6Esx4BTe1OdU6qR
+ER3zAmVf22Le9ciXbu24DN4mleOH+OmBx7X2PqJSYW9GAMTsRB081R6EWKH7romQ
+waxFrZ4DJzZ9ltyosEJn5F32StyLrFxpcrdLUoEaclZCv2qka7sZvi0EvovDVEBU
+e10jOx9AOwf8Gj2ufhquQ6qgVYCzbP+YrodtkFrXRS3IsljIchj1M2ffB/0bfoUs
+rtER9pLvYzCjBPg8IfGLw0o754Qbhh/ReplCRTusP/fQMybvCvfxreS3oyEriu/G
+GufRomjewZ8EMHDIgUsLcYo2UHZsfF7tcazgxMGmMvazp4r8vpgrvW/8fIN/6Adu
+tF+WjWDTvJLFJCe6O+BFJOWrssNrrra1zGtLC1s8s+Wfpe+bGPL5zpHeebGTwH1U
+22eqgJArlEKxrfarz7W5+uHZJHSjF/K9ZvunLGD0n9GOPMpji3UO3zeM8IYoWn7E
+/EWK1XbjnssNemeeTZ+sDh+qrD7BOi+vCX1IyBxbfqnQfJZvmcPWpruy1UsO+aIC
+0GY8Jr3OL69dDQ21jueJAh8EGAEIAAkFAlC9+dkCGwwACgkQL0VeKCTRjd9HCw/+
+LQSVgLLF4ulYlPCjWIIuQwrPbJfWUVVr2dPUFVM85DCv8gBzk5c121snXh9Swovm
+laBbw6ate3BmbXLh64jVE9Za5sbTWi7PCcbO/bpRy4d6oLmitmNw6cq0vjTLxUYy
+bwuiJxWREkfxuU85EKdouN062YDevH+/YResmlJrcCE7LRlJFeRlKsrrwBU3BqYd
+GgFJjKjQC1peeQ9fj62Y7xfwE9+PXbkiWO5u/Bk8hb1VZH1SoIRU98NHVcp6BVvp
+VK0jLAXuSauSczULmpRjbyt1lhaAqivDTWEEZXiNNbRyp17c3nVdPWOcgBr42hdQ
+z25CgZgyLCsvu82wuXLKJblrIPJX3Yf+si6KqEWBsmwdOWybsjygaF5HvzgFqAAD
+U0goPWoQ71PorP2XOUNp5ZLkBQp5etvtkksjVNMIhnHn8PGMuoxO39EUGlWj2B5l
+Cu8tSosAzB1pS8NcLZzoNoI9dOHrmgJmP+GrOUkcf5GhNZbMoj4GNfGBRYX0SZlQ
+GuDrwNKYj73C4MWyNnnUFyq8nDHJ/G1NpaF2hiof9RBL4PUU/f92JkceXPBXA8gL
+Mz2ig1OButwPPLFGQhWqxXAGrsS3Ny+BhTJfnfIbbkaLLphBpDZm1D9XKbAUvdd1
+RZXoH+FTg9UAW87eqU610npOkT6cRaBxaMK/mDtGNdc=
+=JTFu
+-----END PGP PRIVATE KEY BLOCK-----
+"""
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/src/leap/soledad/tests/test_leap_backend.py b/src/leap/soledad/tests/test_leap_backend.py
new file mode 100644
index 00000000..c19ca666
--- /dev/null
+++ b/src/leap/soledad/tests/test_leap_backend.py
@@ -0,0 +1,379 @@
+"""Test ObjectStore backend bits.
+
+For these tests to run, a leap server has to be running on (default) port
+5984.
+"""
+
+import os
+import unittest2 as unittest
+import u1db
+from leap.soledad import Soledad
+from leap.soledad.backends import leap_backend
+from leap.soledad.tests import u1db_tests as tests
+from leap.soledad.tests.u1db_tests.test_remote_sync_target import (
+ make_http_app,
+ make_oauth_http_app,
+)
+from leap.soledad.tests.u1db_tests import test_backends
+from leap.soledad.tests.u1db_tests import test_http_database
+from leap.soledad.tests.u1db_tests import test_http_client
+from leap.soledad.tests.u1db_tests import test_document
+from leap.soledad.tests.u1db_tests import test_remote_sync_target
+from leap.soledad.tests.u1db_tests import test_https
+from leap.soledad.tests.test_encrypted import (
+ PUBLIC_KEY,
+ PRIVATE_KEY,
+)
+
+
+#-----------------------------------------------------------------------------
+# The EncryptedSyncTest is used with multiple inheritance to guarantee that we
+# have a working Soledad instance in each test.
+#-----------------------------------------------------------------------------
+
+class SoledadTest(unittest.TestCase):
+
+ PREFIX = "/var/tmp"
+ GNUPG_HOME = "%s/gnupg" % PREFIX
+ DB1_FILE = "%s/db1.u1db" % PREFIX
+ DB2_FILE = "%s/db2.u1db" % PREFIX
+ EMAIL = 'leap@leap.se'
+
+ def setUp(self):
+ super(SoledadTest, self).setUp()
+ self._db1 = u1db.open(self.DB1_FILE, create=True,
+ document_factory=leap_backend.LeapDocument)
+ self._db2 = u1db.open(self.DB2_FILE, create=True,
+ document_factory=leap_backend.LeapDocument)
+ self._soledad = Soledad(self.EMAIL, gpghome=self.GNUPG_HOME)
+ self._soledad._gpg.import_keys(PUBLIC_KEY)
+ self._soledad._gpg.import_keys(PRIVATE_KEY)
+
+ def tearDown(self):
+ super(SoledadTest, self).tearDown()
+ os.unlink(self.DB1_FILE)
+ os.unlink(self.DB2_FILE)
+ #rmtree(self.GNUPG_HOME)
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_common_backend`.
+#-----------------------------------------------------------------------------
+
+class TestLeapBackendImpl(tests.TestCase):
+
+ def test__allocate_doc_id(self):
+ db = leap_backend.LeapDatabase('test')
+ doc_id1 = db._allocate_doc_id()
+ self.assertTrue(doc_id1.startswith('D-'))
+ self.assertEqual(34, len(doc_id1))
+ int(doc_id1[len('D-'):], 16)
+ self.assertNotEqual(doc_id1, db._allocate_doc_id())
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_backends`.
+#-----------------------------------------------------------------------------
+
+def make_leap_database_for_test(test, replica_uid, path='test'):
+ test.startServer()
+ test.request_state._create_database(replica_uid)
+ return leap_backend.LeapDatabase(test.getURL(path))
+
+
+def copy_leap_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ return test.request_state._copy_database(db)
+
+
+def make_oauth_leap_database_for_test(test, replica_uid):
+ http_db = make_leap_database_for_test(test, replica_uid, '~/test')
+ http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return http_db
+
+
+def make_document_for_test(test, doc_id, rev, content, has_conflicts=False):
+ return leap_backend.LeapDocument(
+ doc_id, rev, content, has_conflicts=has_conflicts)
+
+
+def make_leap_document_for_test(test, doc_id, rev, content,
+ has_conflicts=False):
+ return leap_backend.LeapDocument(
+ doc_id, rev, content, has_conflicts=has_conflicts,
+ soledad=test._soledad)
+
+
+def make_leap_encrypted_document_for_test(test, doc_id, rev, encrypted_content,
+ has_conflicts=False):
+ return leap_backend.LeapDocument(
+ doc_id, rev, encrypted_json=encrypted_content,
+ has_conflicts=has_conflicts,
+ soledad=test._soledad)
+
+
+LEAP_SCENARIOS = [
+ ('http', {'make_database_for_test': make_leap_database_for_test,
+ 'copy_database_for_test': copy_leap_database_for_test,
+ 'make_document_for_test': make_leap_document_for_test,
+ 'make_app_with_state': make_http_app}),
+]
+
+
+class LeapTests(test_backends.AllDatabaseTests, SoledadTest):
+
+ scenarios = LEAP_SCENARIOS
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_http_database`.
+#-----------------------------------------------------------------------------
+
+class TestLeapDatabaseSimpleOperations(
+ test_http_database.TestHTTPDatabaseSimpleOperations):
+
+ def setUp(self):
+ super(test_http_database.TestHTTPDatabaseSimpleOperations,
+ self).setUp()
+ self.db = leap_backend.LeapDatabase('dbase')
+ self.db._conn = object() # crash if used
+ self.got = None
+ self.response_val = None
+
+ def _request(method, url_parts, params=None, body=None,
+ content_type=None):
+ self.got = method, url_parts, params, body, content_type
+ if isinstance(self.response_val, Exception):
+ raise self.response_val
+ return self.response_val
+
+ def _request_json(method, url_parts, params=None, body=None,
+ content_type=None):
+ self.got = method, url_parts, params, body, content_type
+ if isinstance(self.response_val, Exception):
+ raise self.response_val
+ return self.response_val
+
+ self.db._request = _request
+ self.db._request_json = _request_json
+
+ def test_get_sync_target(self):
+ st = self.db.get_sync_target()
+ self.assertIsInstance(st, leap_backend.LeapSyncTarget)
+ self.assertEqual(st._url, self.db._url)
+
+
+class TestLeapDatabaseCtrWithCreds(
+ test_http_database.TestHTTPDatabaseCtrWithCreds):
+ pass
+
+
+class TestLeapDatabaseIntegration(
+ test_http_database.TestHTTPDatabaseIntegration):
+
+ def test_non_existing_db(self):
+ db = leap_backend.LeapDatabase(self.getURL('not-there'))
+ self.assertRaises(u1db.errors.DatabaseDoesNotExist, db.get_doc, 'doc1')
+
+ def test__ensure(self):
+ db = leap_backend.LeapDatabase(self.getURL('new'))
+ db._ensure()
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test__delete(self):
+ self.request_state._create_database('db0')
+ db = leap_backend.LeapDatabase(self.getURL('db0'))
+ db._delete()
+ self.assertRaises(u1db.errors.DatabaseDoesNotExist,
+ self.request_state.check_database, 'db0')
+
+ def test_open_database_existing(self):
+ self.request_state._create_database('db0')
+ db = leap_backend.LeapDatabase.open_database(self.getURL('db0'),
+ create=False)
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test_open_database_non_existing(self):
+ self.assertRaises(u1db.errors.DatabaseDoesNotExist,
+ leap_backend.LeapDatabase.open_database,
+ self.getURL('not-there'),
+ create=False)
+
+ def test_open_database_create(self):
+ db = leap_backend.LeapDatabase.open_database(self.getURL('new'),
+ create=True)
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test_delete_database_existing(self):
+ self.request_state._create_database('db0')
+ leap_backend.LeapDatabase.delete_database(self.getURL('db0'))
+ self.assertRaises(u1db.errors.DatabaseDoesNotExist,
+ self.request_state.check_database, 'db0')
+
+ def test_doc_ids_needing_quoting(self):
+ db0 = self.request_state._create_database('db0')
+ db = leap_backend.LeapDatabase.open_database(self.getURL('db0'),
+ create=False)
+ doc = leap_backend.LeapDocument('%fff', None, '{}')
+ db.put_doc(doc)
+ self.assertGetDoc(db0, '%fff', doc.rev, '{}', False)
+ self.assertGetDoc(db, '%fff', doc.rev, '{}', False)
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_http_client`.
+#-----------------------------------------------------------------------------
+
+class TestLeapClientBase(test_http_client.TestHTTPClientBase):
+ pass
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_document`.
+#-----------------------------------------------------------------------------
+
+class TestLeapDocument(test_document.TestDocument, SoledadTest):
+
+ scenarios = ([(
+ 'leap', {'make_document_for_test': make_leap_document_for_test})])
+
+
+class TestLeapPyDocument(test_document.TestPyDocument, SoledadTest):
+
+ scenarios = ([(
+ 'leap', {'make_document_for_test': make_leap_document_for_test})])
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_remote_sync_target`.
+#-----------------------------------------------------------------------------
+
+class TestLeapSyncTargetBasics(
+ test_remote_sync_target.TestHTTPSyncTargetBasics):
+
+ def test_parse_url(self):
+ remote_target = leap_backend.LeapSyncTarget('http://127.0.0.1:12345/')
+ self.assertEqual('http', remote_target._url.scheme)
+ self.assertEqual('127.0.0.1', remote_target._url.hostname)
+ self.assertEqual(12345, remote_target._url.port)
+ self.assertEqual('/', remote_target._url.path)
+
+
+class TestLeapParsingSyncStream(test_remote_sync_target.TestParsingSyncStream):
+
+ def test_wrong_start(self):
+ tgt = leap_backend.LeapSyncTarget("http://foo/foo")
+
+ self.assertRaises(u1db.errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "{}\r\n]", None)
+
+ self.assertRaises(u1db.errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "\r\n{}\r\n]", None)
+
+ self.assertRaises(u1db.errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "", None)
+
+ def test_wrong_end(self):
+ tgt = leap_backend.LeapSyncTarget("http://foo/foo")
+
+ self.assertRaises(u1db.errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n{}", None)
+
+ self.assertRaises(u1db.errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n", None)
+
+ def test_missing_comma(self):
+ tgt = leap_backend.LeapSyncTarget("http://foo/foo")
+
+ self.assertRaises(u1db.errors.BrokenSyncStream,
+ tgt._parse_sync_stream,
+ '[\r\n{}\r\n{"id": "i", "rev": "r", '
+ '"content": "c", "gen": 3}\r\n]', None)
+
+ def test_no_entries(self):
+ tgt = leap_backend.LeapSyncTarget("http://foo/foo")
+
+ self.assertRaises(u1db.errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n]", None)
+
+ def test_extra_comma(self):
+ tgt = leap_backend.LeapSyncTarget("http://foo/foo")
+
+ self.assertRaises(u1db.errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n{},\r\n]", None)
+
+ self.assertRaises(leap_backend.NoSoledadInstance,
+ tgt._parse_sync_stream,
+ '[\r\n{},\r\n{"id": "i", "rev": "r", '
+ '"content": "{}", "gen": 3, "trans_id": "T-sid"}'
+ ',\r\n]',
+ lambda doc, gen, trans_id: None)
+
+ def test_error_in_stream(self):
+ tgt = leap_backend.LeapSyncTarget("http://foo/foo")
+
+ self.assertRaises(u1db.errors.Unavailable,
+ tgt._parse_sync_stream,
+ '[\r\n{"new_generation": 0},'
+ '\r\n{"error": "unavailable"}\r\n', None)
+
+ self.assertRaises(u1db.errors.Unavailable,
+ tgt._parse_sync_stream,
+ '[\r\n{"error": "unavailable"}\r\n', None)
+
+ self.assertRaises(u1db.errors.BrokenSyncStream,
+ tgt._parse_sync_stream,
+ '[\r\n{"error": "?"}\r\n', None)
+
+
+def leap_sync_target(test, path):
+ return leap_backend.LeapSyncTarget(test.getURL(path))
+
+
+def oauth_leap_sync_target(test, path):
+ st = leap_sync_target(test, '~/' + path)
+ st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return st
+
+
+class TestRemoteSyncTargets(tests.TestCaseWithServer):
+
+ scenarios = [
+ ('http', {'make_app_with_state': make_http_app,
+ 'make_document_for_test': make_leap_document_for_test,
+ 'sync_target': leap_sync_target}),
+ ('oauth_http', {'make_app_with_state': make_oauth_http_app,
+ 'make_document_for_test': make_leap_document_for_test,
+ 'sync_target': oauth_leap_sync_target}),
+ ]
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_https`.
+#-----------------------------------------------------------------------------
+
+def oauth_https_sync_target(test, host, path):
+ _, port = test.server.server_address
+ st = leap_backend.LeapSyncTarget('https://%s:%d/~/%s' % (host, port, path))
+ st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return st
+
+
+class TestLeapSyncTargetHttpsSupport(test_https.TestHttpSyncTargetHttpsSupport,
+ SoledadTest):
+
+ scenarios = [
+ ('oauth_https', {'server_def': test_https.https_server_def,
+ 'make_app_with_state': make_oauth_http_app,
+ 'make_document_for_test': make_leap_document_for_test,
+ 'sync_target': oauth_https_sync_target,
+ }), ]
+
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/tests/test_logs.py b/src/leap/soledad/tests/test_logs.py
new file mode 100644
index 00000000..3dfeff75
--- /dev/null
+++ b/src/leap/soledad/tests/test_logs.py
@@ -0,0 +1,96 @@
+import unittest2 as unittest
+from leap.soledad.backends.objectstore import (
+ TransactionLog,
+ SyncLog,
+ ConflictLog
+)
+
+
+class LogTestCase(unittest.TestCase):
+
+ def test_transaction_log(self):
+ data = [
+ (2, "doc_3", "tran_3"),
+ (3, "doc_2", "tran_2"),
+ (1, "doc_1", "tran_1")
+ ]
+ log = TransactionLog()
+ log.log = data
+ self.assertEqual(log.get_generation(), 3, 'error getting generation')
+ self.assertEqual(log.get_generation_info(), (3, 'tran_2'),
+ 'error getting generation info')
+ self.assertEqual(log.get_trans_id_for_gen(1), 'tran_1',
+ 'error getting trans_id for gen')
+ self.assertEqual(log.get_trans_id_for_gen(2), 'tran_3',
+ 'error getting trans_id for gen')
+ self.assertEqual(log.get_trans_id_for_gen(3), 'tran_2',
+ 'error getting trans_id for gen')
+
+ def test_sync_log(self):
+ data = [
+ ("replica_3", 3, "tran_3"),
+ ("replica_2", 2, "tran_2"),
+ ("replica_1", 1, "tran_1")
+ ]
+ log = SyncLog()
+ log.log = data
+ # test getting
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'),
+ (3, 'tran_3'),
+ 'error getting replica gen and trans id')
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_2'),
+ (2, 'tran_2'),
+ 'error getting replica gen and trans id')
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_1'),
+ (1, 'tran_1'),
+ 'error getting replica gen and trans id')
+ # test setting
+ log.set_replica_gen_and_trans_id('replica_1', 2, 'tran_12')
+ self.assertEqual(len(log._data), 3, 'error in log size after setting')
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_1'),
+ (2, 'tran_12'),
+ 'error setting replica gen and trans id')
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_2'),
+ (2, 'tran_2'),
+ 'error setting replica gen and trans id')
+ self.assertEqual(log.get_replica_gen_and_trans_id('replica_3'),
+ (3, 'tran_3'),
+ 'error setting replica gen and trans id')
+
+ def test_whats_changed(self):
+ data = [
+ (1, "doc_1", "tran_1"),
+ (2, "doc_2", "tran_2"),
+ (3, "doc_3", "tran_3")
+ ]
+ log = TransactionLog()
+ log.log = data
+ self.assertEqual(
+ log.whats_changed(3),
+ (3, "tran_3", []),
+ 'error getting whats changed.')
+ self.assertEqual(
+ log.whats_changed(2),
+ (3, "tran_3", [("doc_3", 3, "tran_3")]),
+ 'error getting whats changed.')
+ self.assertEqual(
+ log.whats_changed(1),
+ (3, "tran_3", [("doc_2", 2, "tran_2"), ("doc_3", 3, "tran_3")]),
+ 'error getting whats changed.')
+
+ def test_conflict_log(self):
+ # TODO: include tests for `get_conflicts` and `has_conflicts`.
+ data = [('1', 'my:1', 'irrelevant'),
+ ('2', 'my:1', 'irrelevant'),
+ ('3', 'my:1', 'irrelevant')]
+ log = ConflictLog(None)
+ log.log = data
+ log.delete_conflicts([('1', 'my:1'), ('2', 'my:1')])
+ self.assertEqual(
+ log.log,
+ [('3', 'my:1', 'irrelevant')],
+ 'error deleting conflicts.')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/src/leap/soledad/tests/test_sqlcipher.py b/src/leap/soledad/tests/test_sqlcipher.py
new file mode 100644
index 00000000..a3ab35b6
--- /dev/null
+++ b/src/leap/soledad/tests/test_sqlcipher.py
@@ -0,0 +1,374 @@
+"""Test sqlcipher backend internals."""
+
+import os
+import time
+from sqlite3 import dbapi2, DatabaseError
+import unittest2 as unittest
+from StringIO import StringIO
+import threading
+
+# u1db stuff.
+from u1db import (
+ errors,
+ query_parser,
+)
+from u1db.backends.sqlite_backend import SQLitePartialExpandDatabase
+
+# soledad stuff.
+from leap.soledad.backends.sqlcipher import (
+ SQLCipherDatabase,
+ DatabaseIsNotEncrypted,
+)
+from leap.soledad.backends.sqlcipher import open as u1db_open
+from leap.soledad.backends.leap_backend import LeapDocument
+
+# u1db tests stuff.
+from leap.soledad.tests import u1db_tests as tests
+from leap.soledad.tests.u1db_tests import test_sqlite_backend
+from leap.soledad.tests.u1db_tests import test_backends
+from leap.soledad.tests.u1db_tests import test_open
+
+PASSWORD = '123456'
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_common_backend`.
+#-----------------------------------------------------------------------------
+
+class TestSQLCipherBackendImpl(tests.TestCase):
+
+ def test__allocate_doc_id(self):
+ db = SQLCipherDatabase(':memory:', PASSWORD)
+ doc_id1 = db._allocate_doc_id()
+ self.assertTrue(doc_id1.startswith('D-'))
+ self.assertEqual(34, len(doc_id1))
+ int(doc_id1[len('D-'):], 16)
+ self.assertNotEqual(doc_id1, db._allocate_doc_id())
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_backends`.
+#-----------------------------------------------------------------------------
+
+def make_sqlcipher_database_for_test(test, replica_uid):
+ db = SQLCipherDatabase(':memory:', PASSWORD)
+ db._set_replica_uid(replica_uid)
+ return db
+
+
+def copy_sqlcipher_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ new_db = SQLCipherDatabase(':memory:', PASSWORD)
+ tmpfile = StringIO()
+ for line in db._db_handle.iterdump():
+ if not 'sqlite_sequence' in line: # work around bug in iterdump
+ tmpfile.write('%s\n' % line)
+ tmpfile.seek(0)
+ new_db._db_handle = dbapi2.connect(':memory:')
+ new_db._db_handle.cursor().executescript(tmpfile.read())
+ new_db._db_handle.commit()
+ new_db._set_replica_uid(db._replica_uid)
+ new_db._factory = db._factory
+ return new_db
+
+
+def make_document_for_test(test, doc_id, rev, content, has_conflicts=False):
+ return LeapDocument(doc_id, rev, content, has_conflicts=has_conflicts)
+
+
+SQLCIPHER_SCENARIOS = [
+ ('sqlcipher', {'make_database_for_test': make_sqlcipher_database_for_test,
+ 'copy_database_for_test': copy_sqlcipher_database_for_test,
+ 'make_document_for_test': make_document_for_test, }),
+]
+
+
+class SQLCipherTests(test_backends.AllDatabaseTests):
+ scenarios = SQLCIPHER_SCENARIOS
+
+
+class SQLCipherDatabaseTests(test_backends.LocalDatabaseTests):
+ scenarios = SQLCIPHER_SCENARIOS
+
+
+class SQLCipherValidateGenNTransIdTests(
+ test_backends.LocalDatabaseValidateGenNTransIdTests):
+ scenarios = SQLCIPHER_SCENARIOS
+
+
+class SQLCipherValidateSourceGenTests(
+ test_backends.LocalDatabaseValidateSourceGenTests):
+ scenarios = SQLCIPHER_SCENARIOS
+
+
+class SQLCipherWithConflictsTests(
+ test_backends.LocalDatabaseWithConflictsTests):
+ scenarios = SQLCIPHER_SCENARIOS
+
+
+class SQLCipherIndexTests(test_backends.DatabaseIndexTests):
+ scenarios = SQLCIPHER_SCENARIOS
+
+
+load_tests = tests.load_with_scenarios
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_sqlite_backend`.
+#-----------------------------------------------------------------------------
+
+class TestSQLCipherDatabase(test_sqlite_backend.TestSQLiteDatabase):
+
+ def test_atomic_initialize(self):
+ tmpdir = self.createTempDir()
+ dbname = os.path.join(tmpdir, 'atomic.db')
+
+ t2 = None # will be a thread
+
+ class SQLCipherDatabaseTesting(SQLitePartialExpandDatabase):
+ _index_storage_value = "testing"
+
+ def __init__(self, dbname, ntry):
+ self._try = ntry
+ self._is_initialized_invocations = 0
+ super(SQLCipherDatabaseTesting, self).__init__(dbname)
+
+ def _is_initialized(self, c):
+ res = super(SQLCipherDatabaseTesting, self)._is_initialized(c)
+ if self._try == 1:
+ self._is_initialized_invocations += 1
+ if self._is_initialized_invocations == 2:
+ t2.start()
+ # hard to do better and have a generic test
+ time.sleep(0.05)
+ return res
+
+ outcome2 = []
+
+ def second_try():
+ try:
+ db2 = SQLCipherDatabaseTesting(dbname, 2)
+ except Exception, e:
+ outcome2.append(e)
+ else:
+ outcome2.append(db2)
+
+ t2 = threading.Thread(target=second_try)
+ db1 = SQLCipherDatabaseTesting(dbname, 1)
+ t2.join()
+
+ self.assertIsInstance(outcome2[0], SQLCipherDatabaseTesting)
+ db2 = outcome2[0]
+ self.assertTrue(db2._is_initialized(db1._get_sqlite_handle().cursor()))
+
+
+class TestAlternativeDocument(LeapDocument):
+ """A (not very) alternative implementation of Document."""
+
+
+class TestSQLCipherPartialExpandDatabase(
+ test_sqlite_backend.TestSQLitePartialExpandDatabase):
+
+ # The following tests had to be cloned from u1db because they all
+ # instantiate the backend directly, so we need to change that in order to
+ # our backend be instantiated in place.
+
+ def setUp(self):
+ super(test_sqlite_backend.TestSQLitePartialExpandDatabase,
+ self).setUp()
+ self.db = SQLCipherDatabase(':memory:', PASSWORD)
+ self.db._set_replica_uid('test')
+
+ def test_default_replica_uid(self):
+ self.db = SQLCipherDatabase(':memory:', PASSWORD)
+ self.assertIsNot(None, self.db._replica_uid)
+ self.assertEqual(32, len(self.db._replica_uid))
+ int(self.db._replica_uid, 16)
+
+ def test__parse_index(self):
+ self.db = SQLCipherDatabase(':memory:', PASSWORD)
+ g = self.db._parse_index_definition('fieldname')
+ self.assertIsInstance(g, query_parser.ExtractField)
+ self.assertEqual(['fieldname'], g.field)
+
+ def test__update_indexes(self):
+ self.db = SQLCipherDatabase(':memory:', PASSWORD)
+ g = self.db._parse_index_definition('fieldname')
+ c = self.db._get_sqlite_handle().cursor()
+ self.db._update_indexes('doc-id', {'fieldname': 'val'},
+ [('fieldname', g)], c)
+ c.execute('SELECT doc_id, field_name, value FROM document_fields')
+ self.assertEqual([('doc-id', 'fieldname', 'val')],
+ c.fetchall())
+
+ def test__set_replica_uid(self):
+ # Start from scratch, so that replica_uid isn't set.
+ self.db = SQLCipherDatabase(':memory:', PASSWORD)
+ self.assertIsNot(None, self.db._real_replica_uid)
+ self.assertIsNot(None, self.db._replica_uid)
+ self.db._set_replica_uid('foo')
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT value FROM u1db_config WHERE name='replica_uid'")
+ self.assertEqual(('foo',), c.fetchone())
+ self.assertEqual('foo', self.db._real_replica_uid)
+ self.assertEqual('foo', self.db._replica_uid)
+ self.db._close_sqlite_handle()
+ self.assertEqual('foo', self.db._replica_uid)
+
+ def test__open_database(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/test.sqlite'
+ SQLCipherDatabase(path, PASSWORD)
+ db2 = SQLCipherDatabase._open_database(path, PASSWORD)
+ self.assertIsInstance(db2, SQLCipherDatabase)
+
+ def test__open_database_with_factory(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/test.sqlite'
+ SQLCipherDatabase(path, PASSWORD)
+ db2 = SQLCipherDatabase._open_database(
+ path, PASSWORD,
+ document_factory=TestAlternativeDocument)
+ self.assertEqual(TestAlternativeDocument, db2._factory)
+
+ def test_open_database_existing(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/existing.sqlite'
+ SQLCipherDatabase(path, PASSWORD)
+ db2 = SQLCipherDatabase.open_database(path, PASSWORD, create=False)
+ self.assertIsInstance(db2, SQLCipherDatabase)
+
+ def test_open_database_with_factory(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/existing.sqlite'
+ SQLCipherDatabase(path, PASSWORD)
+ db2 = SQLCipherDatabase.open_database(
+ path, PASSWORD, create=False,
+ document_factory=TestAlternativeDocument)
+ self.assertEqual(TestAlternativeDocument, db2._factory)
+
+ def test_create_database_initializes_schema(self):
+ # This test had to be cloned because our implementation of SQLCipher
+ # backend is referenced with an index_storage_value that includes the
+ # word "encrypted". See u1db's sqlite_backend and our
+ # sqlcipher_backend for reference.
+ raw_db = self.db._get_sqlite_handle()
+ c = raw_db.cursor()
+ c.execute("SELECT * FROM u1db_config")
+ config = dict([(r[0], r[1]) for r in c.fetchall()])
+ self.assertEqual({'sql_schema': '0', 'replica_uid': 'test',
+ 'index_storage': 'expand referenced encrypted'},
+ config)
+
+ def test_store_syncable(self):
+ doc = self.db.create_doc_from_json(tests.simple_doc)
+ # assert that docs are syncable by default
+ self.assertEqual(True, doc.syncable)
+ # assert that we can store syncable = False
+ doc.syncable = False
+ self.db.put_doc(doc)
+ self.assertEqual(False, self.db.get_doc(doc.doc_id).syncable)
+ # assert that we can store syncable = True
+ doc.syncable = True
+ self.db.put_doc(doc)
+ self.assertEqual(True, self.db.get_doc(doc.doc_id).syncable)
+
+
+#-----------------------------------------------------------------------------
+# The following tests come from `u1db.tests.test_open`.
+#-----------------------------------------------------------------------------
+
+class SQLCipherOpen(test_open.TestU1DBOpen):
+
+ def test_open_no_create(self):
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ u1db_open, self.db_path,
+ password=PASSWORD,
+ create=False)
+ self.assertFalse(os.path.exists(self.db_path))
+
+ def test_open_create(self):
+ db = u1db_open(self.db_path, password=PASSWORD, create=True)
+ self.addCleanup(db.close)
+ self.assertTrue(os.path.exists(self.db_path))
+ self.assertIsInstance(db, SQLCipherDatabase)
+
+ def test_open_with_factory(self):
+ db = u1db_open(self.db_path, password=PASSWORD, create=True,
+ document_factory=TestAlternativeDocument)
+ self.addCleanup(db.close)
+ self.assertEqual(TestAlternativeDocument, db._factory)
+
+ def test_open_existing(self):
+ db = SQLCipherDatabase(self.db_path, PASSWORD)
+ self.addCleanup(db.close)
+ doc = db.create_doc_from_json(tests.simple_doc)
+ # Even though create=True, we shouldn't wipe the db
+ db2 = u1db_open(self.db_path, password=PASSWORD, create=True)
+ self.addCleanup(db2.close)
+ doc2 = db2.get_doc(doc.doc_id)
+ self.assertEqual(doc, doc2)
+
+ def test_open_existing_no_create(self):
+ db = SQLCipherDatabase(self.db_path, PASSWORD)
+ self.addCleanup(db.close)
+ db2 = u1db_open(self.db_path, password=PASSWORD, create=False)
+ self.addCleanup(db2.close)
+ self.assertIsInstance(db2, SQLCipherDatabase)
+
+
+#-----------------------------------------------------------------------------
+# Tests for actual encryption of the database
+#-----------------------------------------------------------------------------
+
+class SQLCipherEncryptionTest(unittest.TestCase):
+
+ DB_FILE = '/tmp/test.db'
+
+ def delete_dbfiles(self):
+ for dbfile in [self.DB_FILE]:
+ if os.path.exists(dbfile):
+ os.unlink(dbfile)
+
+ def setUp(self):
+ self.delete_dbfiles()
+
+ def tearDown(self):
+ self.delete_dbfiles()
+
+ def test_try_to_open_encrypted_db_with_sqlite_backend(self):
+ db = SQLCipherDatabase(self.DB_FILE, PASSWORD)
+ doc = db.create_doc_from_json(tests.simple_doc)
+ db.close()
+ try:
+ # trying to open an encrypted database with the regular u1db
+ # backend should raise a DatabaseError exception.
+ SQLitePartialExpandDatabase(self.DB_FILE,
+ document_factory=LeapDocument)
+ raise DatabaseIsNotEncrypted()
+ except DatabaseError:
+ # at this point we know that the regular U1DB sqlcipher backend
+ # did not succeed on opening the database, so it was indeed
+ # encrypted.
+ db = SQLCipherDatabase(self.DB_FILE, PASSWORD)
+ doc = db.get_doc(doc.doc_id)
+ self.assertEqual(tests.simple_doc, doc.get_json(),
+ 'decrypted content mismatch')
+
+ def test_try_to_open_raw_db_with_sqlcipher_backend(self):
+ db = SQLitePartialExpandDatabase(self.DB_FILE,
+ document_factory=LeapDocument)
+ db.create_doc_from_json(tests.simple_doc)
+ db.close()
+ try:
+ # trying to open the a non-encrypted database with sqlcipher
+ # backend should raise a DatabaseIsNotEncrypted exception.
+ SQLCipherDatabase(self.DB_FILE, PASSWORD)
+ raise DatabaseError("SQLCipher backend should not be able to open "
+ "non-encrypted dbs.")
+ except DatabaseIsNotEncrypted:
+ pass
diff --git a/src/leap/soledad/tests/u1db_tests/README b/src/leap/soledad/tests/u1db_tests/README
new file mode 100644
index 00000000..605f01fa
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/README
@@ -0,0 +1,34 @@
+General info
+------------
+
+Test files in this directory are derived from u1db-0.1.4 tests. The main
+difference is that:
+
+ (1) they include the test infrastructure packed with soledad; and
+ (2) they do not include c_backend_wrapper testing.
+
+Dependencies
+------------
+
+u1db tests depend on the following python packages:
+
+ nose2
+ unittest2
+ mercurial
+ hgtools
+ testtools
+ discover
+ oauth
+ testscenarios
+ dirspec
+ paste
+ routes
+ simplejson
+ cython
+
+Running tests
+-------------
+
+Use nose2 to run tests:
+
+ nose2 leap.soledad.tests.u1db_tests
diff --git a/src/leap/soledad/tests/u1db_tests/__init__.py b/src/leap/soledad/tests/u1db_tests/__init__.py
new file mode 100644
index 00000000..27aa4d79
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/__init__.py
@@ -0,0 +1,421 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test infrastructure for U1DB"""
+
+import copy
+import shutil
+import socket
+import tempfile
+import threading
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from wsgiref import simple_server
+
+from oauth import oauth
+from sqlite3 import dbapi2
+from StringIO import StringIO
+
+import testscenarios
+import testtools
+
+from u1db import (
+ errors,
+ Document,
+)
+from u1db.backends import (
+ inmemory,
+ sqlite_backend,
+)
+from u1db.remote import (
+ server_state,
+)
+
+
+class TestCase(testtools.TestCase):
+
+ def createTempDir(self, prefix='u1db-tmp-'):
+ """Create a temporary directory to do some work in.
+
+ This directory will be scheduled for cleanup when the test ends.
+ """
+ tempdir = tempfile.mkdtemp(prefix=prefix)
+ self.addCleanup(shutil.rmtree, tempdir)
+ return tempdir
+
+ def make_document(self, doc_id, doc_rev, content, has_conflicts=False):
+ return self.make_document_for_test(
+ self, doc_id, doc_rev, content, has_conflicts)
+
+ def make_document_for_test(self, test, doc_id, doc_rev, content,
+ has_conflicts):
+ return make_document_for_test(
+ test, doc_id, doc_rev, content, has_conflicts)
+
+ def assertGetDoc(self, db, doc_id, doc_rev, content, has_conflicts):
+ """Assert that the document in the database looks correct."""
+ exp_doc = self.make_document(doc_id, doc_rev, content,
+ has_conflicts=has_conflicts)
+ self.assertEqual(exp_doc, db.get_doc(doc_id))
+
+ def assertGetDocIncludeDeleted(self, db, doc_id, doc_rev, content,
+ has_conflicts):
+ """Assert that the document in the database looks correct."""
+ exp_doc = self.make_document(doc_id, doc_rev, content,
+ has_conflicts=has_conflicts)
+ self.assertEqual(exp_doc, db.get_doc(doc_id, include_deleted=True))
+
+ def assertGetDocConflicts(self, db, doc_id, conflicts):
+ """Assert what conflicts are stored for a given doc_id.
+
+ :param conflicts: A list of (doc_rev, content) pairs.
+ The first item must match the first item returned from the
+ database, however the rest can be returned in any order.
+ """
+ if conflicts:
+ conflicts = [(rev,
+ (json.loads(cont) if isinstance(cont, basestring)
+ else cont)) for (rev, cont) in conflicts]
+ conflicts = conflicts[:1] + sorted(conflicts[1:])
+ actual = db.get_doc_conflicts(doc_id)
+ if actual:
+ actual = [
+ (doc.rev, (json.loads(doc.get_json())
+ if doc.get_json() is not None else None))
+ for doc in actual]
+ actual = actual[:1] + sorted(actual[1:])
+ self.assertEqual(conflicts, actual)
+
+
+def multiply_scenarios(a_scenarios, b_scenarios):
+ """Create the cross-product of scenarios."""
+
+ all_scenarios = []
+ for a_name, a_attrs in a_scenarios:
+ for b_name, b_attrs in b_scenarios:
+ name = '%s,%s' % (a_name, b_name)
+ attrs = dict(a_attrs)
+ attrs.update(b_attrs)
+ all_scenarios.append((name, attrs))
+ return all_scenarios
+
+
+simple_doc = '{"key": "value"}'
+nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}'
+
+
+def make_memory_database_for_test(test, replica_uid):
+ return inmemory.InMemoryDatabase(replica_uid)
+
+
+def copy_memory_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ new_db = inmemory.InMemoryDatabase(db._replica_uid)
+ new_db._transaction_log = db._transaction_log[:]
+ new_db._docs = copy.deepcopy(db._docs)
+ new_db._conflicts = copy.deepcopy(db._conflicts)
+ new_db._indexes = copy.deepcopy(db._indexes)
+ new_db._factory = db._factory
+ return new_db
+
+
+def make_sqlite_partial_expanded_for_test(test, replica_uid):
+ db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ db._set_replica_uid(replica_uid)
+ return db
+
+
+def copy_sqlite_partial_expanded_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ new_db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ tmpfile = StringIO()
+ for line in db._db_handle.iterdump():
+ if not 'sqlite_sequence' in line: # work around bug in iterdump
+ tmpfile.write('%s\n' % line)
+ tmpfile.seek(0)
+ new_db._db_handle = dbapi2.connect(':memory:')
+ new_db._db_handle.cursor().executescript(tmpfile.read())
+ new_db._db_handle.commit()
+ new_db._set_replica_uid(db._replica_uid)
+ new_db._factory = db._factory
+ return new_db
+
+
+def make_document_for_test(test, doc_id, rev, content, has_conflicts=False):
+ return Document(doc_id, rev, content, has_conflicts=has_conflicts)
+
+
+LOCAL_DATABASES_SCENARIOS = [
+ ('mem', {'make_database_for_test': make_memory_database_for_test,
+ 'copy_database_for_test': copy_memory_database_for_test,
+ 'make_document_for_test': make_document_for_test}),
+ ('sql', {'make_database_for_test':
+ make_sqlite_partial_expanded_for_test,
+ 'copy_database_for_test':
+ copy_sqlite_partial_expanded_for_test,
+ 'make_document_for_test': make_document_for_test}),
+]
+
+
+class DatabaseBaseTests(TestCase):
+
+ accept_fixed_trans_id = False # set to True assertTransactionLog
+ # is happy with all trans ids = ''
+
+ scenarios = LOCAL_DATABASES_SCENARIOS
+
+ def create_database(self, replica_uid):
+ return self.make_database_for_test(self, replica_uid)
+
+ def copy_database(self, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
+ # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
+ # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
+ # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
+ # NINJA TO YOUR HOUSE.
+ return self.copy_database_for_test(self, db)
+
+ def setUp(self):
+ super(DatabaseBaseTests, self).setUp()
+ self.db = self.create_database('test')
+
+ def tearDown(self):
+ # TODO: Add close_database parameterization
+ # self.close_database(self.db)
+ super(DatabaseBaseTests, self).tearDown()
+
+ def assertTransactionLog(self, doc_ids, db):
+ """Assert that the given docs are in the transaction log."""
+ log = db._get_transaction_log()
+ just_ids = []
+ seen_transactions = set()
+ for doc_id, transaction_id in log:
+ just_ids.append(doc_id)
+ self.assertIsNot(None, transaction_id,
+ "Transaction id should not be None")
+ if transaction_id == '' and self.accept_fixed_trans_id:
+ continue
+ self.assertNotEqual('', transaction_id,
+ "Transaction id should be a unique string")
+ self.assertTrue(transaction_id.startswith('T-'))
+ self.assertNotIn(transaction_id, seen_transactions)
+ seen_transactions.add(transaction_id)
+ self.assertEqual(doc_ids, just_ids)
+
+ def getLastTransId(self, db):
+ """Return the transaction id for the last database update."""
+ return self.db._get_transaction_log()[-1][-1]
+
+
+class ServerStateForTests(server_state.ServerState):
+ """Used in the test suite, so we don't have to touch disk, etc."""
+
+ def __init__(self):
+ super(ServerStateForTests, self).__init__()
+ self._dbs = {}
+
+ def open_database(self, path):
+ try:
+ return self._dbs[path]
+ except KeyError:
+ raise errors.DatabaseDoesNotExist
+
+ def check_database(self, path):
+ # cares only about the possible exception
+ self.open_database(path)
+
+ def ensure_database(self, path):
+ try:
+ db = self.open_database(path)
+ except errors.DatabaseDoesNotExist:
+ db = self._create_database(path)
+ return db, db._replica_uid
+
+ def _copy_database(self, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
+ # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
+ # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
+ # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
+ # NINJA TO YOUR HOUSE.
+ new_db = copy_memory_database_for_test(None, db)
+ path = db._replica_uid
+ while path in self._dbs:
+ path += 'copy'
+ self._dbs[path] = new_db
+ return new_db
+
+ def _create_database(self, path):
+ db = inmemory.InMemoryDatabase(path)
+ self._dbs[path] = db
+ return db
+
+ def delete_database(self, path):
+ del self._dbs[path]
+
+
+class ResponderForTests(object):
+ """Responder for tests."""
+ _started = False
+ sent_response = False
+ status = None
+
+ def start_response(self, status='success', **kwargs):
+ self._started = True
+ self.status = status
+ self.kwargs = kwargs
+
+ def send_response(self, status='success', **kwargs):
+ self.start_response(status, **kwargs)
+ self.finish_response()
+
+ def finish_response(self):
+ self.sent_response = True
+
+
+class TestCaseWithServer(TestCase):
+
+ @staticmethod
+ def server_def():
+ # hook point
+ # should return (ServerClass, "shutdown method name", "url_scheme")
+ class _RequestHandler(simple_server.WSGIRequestHandler):
+ def log_request(*args):
+ pass # suppress
+
+ def make_server(host_port, application):
+ assert application, "forgot to override make_app(_with_state)?"
+ srv = simple_server.WSGIServer(host_port, _RequestHandler)
+ # patch the value in if it's None
+ if getattr(application, 'base_url', 1) is None:
+ application.base_url = "http://%s:%s" % srv.server_address
+ srv.set_app(application)
+ return srv
+
+ return make_server, "shutdown", "http"
+
+ @staticmethod
+ def make_app_with_state(state):
+ # hook point
+ return None
+
+ def make_app(self):
+ # potential hook point
+ self.request_state = ServerStateForTests()
+ return self.make_app_with_state(self.request_state)
+
+ def setUp(self):
+ super(TestCaseWithServer, self).setUp()
+ self.server = self.server_thread = None
+
+ @property
+ def url_scheme(self):
+ return self.server_def()[-1]
+
+ def startServer(self):
+ server_def = self.server_def()
+ server_class, shutdown_meth, _ = server_def
+ application = self.make_app()
+ self.server = server_class(('127.0.0.1', 0), application)
+ self.server_thread = threading.Thread(target=self.server.serve_forever,
+ kwargs=dict(poll_interval=0.01))
+ self.server_thread.start()
+ self.addCleanup(self.server_thread.join)
+ self.addCleanup(getattr(self.server, shutdown_meth))
+
+ def getURL(self, path=None):
+ host, port = self.server.server_address
+ if path is None:
+ path = ''
+ return '%s://%s:%s/%s' % (self.url_scheme, host, port, path)
+
+
+def socket_pair():
+ """Return a pair of TCP sockets connected to each other.
+
+ Unlike socket.socketpair, this should work on Windows.
+ """
+ sock_pair = getattr(socket, 'socket_pair', None)
+ if sock_pair:
+ return sock_pair(socket.AF_INET, socket.SOCK_STREAM)
+ listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ listen_sock.bind(('127.0.0.1', 0))
+ listen_sock.listen(1)
+ client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ client_sock.connect(listen_sock.getsockname())
+ server_sock, addr = listen_sock.accept()
+ listen_sock.close()
+ return server_sock, client_sock
+
+
+# OAuth related testing
+
+consumer1 = oauth.OAuthConsumer('K1', 'S1')
+token1 = oauth.OAuthToken('kkkk1', 'XYZ')
+consumer2 = oauth.OAuthConsumer('K2', 'S2')
+token2 = oauth.OAuthToken('kkkk2', 'ZYX')
+token3 = oauth.OAuthToken('kkkk3', 'ZYX')
+
+
+class TestingOAuthDataStore(oauth.OAuthDataStore):
+ """In memory predefined OAuthDataStore for testing."""
+
+ consumers = {
+ consumer1.key: consumer1,
+ consumer2.key: consumer2,
+ }
+
+ tokens = {
+ token1.key: token1,
+ token2.key: token2
+ }
+
+ def lookup_consumer(self, key):
+ return self.consumers.get(key)
+
+ def lookup_token(self, token_type, token_token):
+ return self.tokens.get(token_token)
+
+ def lookup_nonce(self, oauth_consumer, oauth_token, nonce):
+ return None
+
+testingOAuthStore = TestingOAuthDataStore()
+
+sign_meth_HMAC_SHA1 = oauth.OAuthSignatureMethod_HMAC_SHA1()
+sign_meth_PLAINTEXT = oauth.OAuthSignatureMethod_PLAINTEXT()
+
+
+def load_with_scenarios(loader, standard_tests, pattern):
+ """Load the tests in a given module.
+
+ This just applies testscenarios.generate_scenarios to all the tests that
+ are present. We do it at load time rather than at run time, because it
+ plays nicer with various tools.
+ """
+ suite = loader.suiteClass()
+ suite.addTests(testscenarios.generate_scenarios(standard_tests))
+ return suite
diff --git a/src/leap/soledad/tests/u1db_tests/test_backends.py b/src/leap/soledad/tests/u1db_tests/test_backends.py
new file mode 100644
index 00000000..a53b01ba
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/test_backends.py
@@ -0,0 +1,1907 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The backend class for U1DB. This deals with hiding storage details."""
+
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+from u1db import (
+ DocumentBase,
+ errors,
+ vectorclock,
+)
+
+from leap.soledad.tests import u1db_tests as tests
+
+simple_doc = tests.simple_doc
+nested_doc = tests.nested_doc
+
+from leap.soledad.tests.u1db_tests.test_remote_sync_target import (
+ make_http_app,
+ make_oauth_http_app,
+)
+
+from u1db.remote import (
+ http_database,
+)
+
+
+def make_http_database_for_test(test, replica_uid, path='test'):
+ test.startServer()
+ test.request_state._create_database(replica_uid)
+ return http_database.HTTPDatabase(test.getURL(path))
+
+
+def copy_http_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ return test.request_state._copy_database(db)
+
+
+def make_oauth_http_database_for_test(test, replica_uid):
+ http_db = make_http_database_for_test(test, replica_uid, '~/test')
+ http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return http_db
+
+
+def copy_oauth_http_database_for_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR
+ # HOUSE.
+ http_db = test.request_state._copy_database(db)
+ http_db.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return http_db
+
+
+class TestAlternativeDocument(DocumentBase):
+ """A (not very) alternative implementation of Document."""
+
+
+class AllDatabaseTests(tests.DatabaseBaseTests, tests.TestCaseWithServer):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS + [
+ ('http', {'make_database_for_test': make_http_database_for_test,
+ 'copy_database_for_test': copy_http_database_for_test,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'make_app_with_state': make_http_app}),
+ ('oauth_http', {'make_database_for_test':
+ make_oauth_http_database_for_test,
+ 'copy_database_for_test':
+ copy_oauth_http_database_for_test,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'make_app_with_state': make_oauth_http_app})
+ ]
+
+ def test_close(self):
+ self.db.close()
+
+ def test_create_doc_allocating_doc_id(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertNotEqual(None, doc.doc_id)
+ self.assertNotEqual(None, doc.rev)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+
+ def test_create_doc_different_ids_same_db(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertNotEqual(doc1.doc_id, doc2.doc_id)
+
+ def test_create_doc_with_id(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my-id')
+ self.assertEqual('my-id', doc.doc_id)
+ self.assertNotEqual(None, doc.rev)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+
+ def test_create_doc_existing_id(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ new_content = '{"something": "else"}'
+ self.assertRaises(
+ errors.RevisionConflict, self.db.create_doc_from_json,
+ new_content, doc.doc_id)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+
+ def test_put_doc_creating_initial(self):
+ doc = self.make_document('my_doc_id', None, simple_doc)
+ new_rev = self.db.put_doc(doc)
+ self.assertIsNot(None, new_rev)
+ self.assertGetDoc(self.db, 'my_doc_id', new_rev, simple_doc, False)
+
+ def test_put_doc_space_in_id(self):
+ doc = self.make_document('my doc id', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_doc_update(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ orig_rev = doc.rev
+ doc.set_json('{"updated": "stuff"}')
+ new_rev = self.db.put_doc(doc)
+ self.assertNotEqual(new_rev, orig_rev)
+ self.assertGetDoc(self.db, 'my_doc_id', new_rev,
+ '{"updated": "stuff"}', False)
+ self.assertEqual(doc.rev, new_rev)
+
+ def test_put_non_ascii_key(self):
+ content = json.dumps({u'key\xe5': u'val'})
+ doc = self.db.create_doc_from_json(content, doc_id='my_doc')
+ self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False)
+
+ def test_put_non_ascii_value(self):
+ content = json.dumps({'key': u'\xe5'})
+ doc = self.db.create_doc_from_json(content, doc_id='my_doc')
+ self.assertGetDoc(self.db, 'my_doc', doc.rev, content, False)
+
+ def test_put_doc_refuses_no_id(self):
+ doc = self.make_document(None, None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+ doc = self.make_document("", None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_doc_refuses_slashes(self):
+ doc = self.make_document('a/b', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+ doc = self.make_document(r'\b', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_doc_url_quoting_is_fine(self):
+ doc_id = "%2F%2Ffoo%2Fbar"
+ doc = self.make_document(doc_id, None, simple_doc)
+ new_rev = self.db.put_doc(doc)
+ self.assertGetDoc(self.db, doc_id, new_rev, simple_doc, False)
+
+ def test_put_doc_refuses_non_existing_old_rev(self):
+ doc = self.make_document('doc-id', 'test:4', simple_doc)
+ self.assertRaises(errors.RevisionConflict, self.db.put_doc, doc)
+
+ def test_put_doc_refuses_non_ascii_doc_id(self):
+ doc = self.make_document('d\xc3\xa5c-id', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_put_fails_with_bad_old_rev(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ old_rev = doc.rev
+ bad_doc = self.make_document(doc.doc_id, 'other:1',
+ '{"something": "else"}')
+ self.assertRaises(errors.RevisionConflict, self.db.put_doc, bad_doc)
+ self.assertGetDoc(self.db, 'my_doc_id', old_rev, simple_doc, False)
+
+ def test_create_succeeds_after_delete(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True)
+ deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev)
+ new_doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.assertGetDoc(self.db, 'my_doc_id', new_doc.rev, simple_doc, False)
+ new_vc = vectorclock.VectorClockRev(new_doc.rev)
+ self.assertTrue(
+ new_vc.is_newer(deleted_vc),
+ "%s does not supersede %s" % (new_doc.rev, deleted_doc.rev))
+
+ def test_put_succeeds_after_delete(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ deleted_doc = self.db.get_doc('my_doc_id', include_deleted=True)
+ deleted_vc = vectorclock.VectorClockRev(deleted_doc.rev)
+ doc2 = self.make_document('my_doc_id', None, simple_doc)
+ self.db.put_doc(doc2)
+ self.assertGetDoc(self.db, 'my_doc_id', doc2.rev, simple_doc, False)
+ new_vc = vectorclock.VectorClockRev(doc2.rev)
+ self.assertTrue(
+ new_vc.is_newer(deleted_vc),
+ "%s does not supersede %s" % (doc2.rev, deleted_doc.rev))
+
+ def test_get_doc_after_put(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.assertGetDoc(self.db, 'my_doc_id', doc.rev, simple_doc, False)
+
+ def test_get_doc_nonexisting(self):
+ self.assertIs(None, self.db.get_doc('non-existing'))
+
+ def test_get_doc_deleted(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ self.assertIs(None, self.db.get_doc('my_doc_id'))
+
+ def test_get_doc_include_deleted(self):
+ doc = self.db.create_doc_from_json(simple_doc, doc_id='my_doc_id')
+ self.db.delete_doc(doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+
+ def test_get_docs(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual([doc1, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id])))
+
+ def test_get_docs_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc1)
+ self.assertEqual([doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id])))
+
+ def test_get_docs_include_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc1)
+ self.assertEqual(
+ [doc1, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id],
+ include_deleted=True)))
+
+ def test_get_docs_request_ordered(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual([doc1, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id])))
+ self.assertEqual([doc2, doc1],
+ list(self.db.get_docs([doc2.doc_id, doc1.doc_id])))
+
+ def test_get_docs_empty_list(self):
+ self.assertEqual([], list(self.db.get_docs([])))
+
+ def test_handles_nested_content(self):
+ doc = self.db.create_doc_from_json(nested_doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False)
+
+ def test_handles_doc_with_null(self):
+ doc = self.db.create_doc_from_json('{"key": null}')
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, '{"key": null}', False)
+
+ def test_delete_doc(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+ orig_rev = doc.rev
+ self.db.delete_doc(doc)
+ self.assertNotEqual(orig_rev, doc.rev)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+ self.assertIs(None, self.db.get_doc(doc.doc_id))
+
+ def test_delete_doc_non_existent(self):
+ doc = self.make_document('non-existing', 'other:1', simple_doc)
+ self.assertRaises(errors.DocumentDoesNotExist, self.db.delete_doc, doc)
+
+ def test_delete_doc_already_deleted(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertRaises(errors.DocumentAlreadyDeleted,
+ self.db.delete_doc, doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+
+ def test_delete_doc_bad_rev(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+ doc2 = self.make_document(doc1.doc_id, 'other:1', simple_doc)
+ self.assertRaises(errors.RevisionConflict, self.db.delete_doc, doc2)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+
+ def test_delete_doc_sets_content_to_None(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertIs(None, doc.get_json())
+
+ def test_delete_doc_rev_supersedes(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc.set_json(nested_doc)
+ self.db.put_doc(doc)
+ doc.set_json('{"fishy": "content"}')
+ self.db.put_doc(doc)
+ old_rev = doc.rev
+ self.db.delete_doc(doc)
+ cur_vc = vectorclock.VectorClockRev(old_rev)
+ deleted_vc = vectorclock.VectorClockRev(doc.rev)
+ self.assertTrue(deleted_vc.is_newer(cur_vc),
+ "%s does not supersede %s" % (doc.rev, old_rev))
+
+ def test_delete_then_put(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, doc.rev, None, False)
+ doc.set_json(nested_doc)
+ self.db.put_doc(doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, nested_doc, False)
+
+
+class DocumentSizeTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS
+
+ def test_put_doc_refuses_oversized_documents(self):
+ self.db.set_document_size_limit(1)
+ doc = self.make_document('doc-id', None, simple_doc)
+ self.assertRaises(errors.DocumentTooBig, self.db.put_doc, doc)
+
+ def test_create_doc_refuses_oversized_documents(self):
+ self.db.set_document_size_limit(1)
+ self.assertRaises(
+ errors.DocumentTooBig, self.db.create_doc_from_json, simple_doc,
+ doc_id='my_doc_id')
+
+ def test_set_document_size_limit_zero(self):
+ self.db.set_document_size_limit(0)
+ self.assertEqual(0, self.db.document_size_limit)
+
+ def test_set_document_size_limit(self):
+ self.db.set_document_size_limit(1000000)
+ self.assertEqual(1000000, self.db.document_size_limit)
+
+
+class LocalDatabaseTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS
+
+ def test_create_doc_different_ids_diff_db(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ db2 = self.create_database('other-uid')
+ doc2 = db2.create_doc_from_json(simple_doc)
+ self.assertNotEqual(doc1.doc_id, doc2.doc_id)
+
+ def test_put_doc_refuses_slashes_picky(self):
+ doc = self.make_document('/a', None, simple_doc)
+ self.assertRaises(errors.InvalidDocId, self.db.put_doc, doc)
+
+ def test_get_all_docs_empty(self):
+ self.assertEqual([], list(self.db.get_all_docs()[1]))
+
+ def test_get_all_docs(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual(
+ sorted([doc1, doc2]), sorted(list(self.db.get_all_docs()[1])))
+
+ def test_get_all_docs_exclude_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc2)
+ self.assertEqual([doc1], list(self.db.get_all_docs()[1]))
+
+ def test_get_all_docs_include_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.db.delete_doc(doc2)
+ self.assertEqual(
+ sorted([doc1, doc2]),
+ sorted(list(self.db.get_all_docs(include_deleted=True)[1])))
+
+ def test_get_all_docs_generation(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_doc_from_json(nested_doc)
+ self.assertEqual(2, self.db.get_all_docs()[0])
+
+ def test_simple_put_doc_if_newer(self):
+ doc = self.make_document('my-doc-id', 'test:1', simple_doc)
+ state_at_gen = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(('inserted', 1), state_at_gen)
+ self.assertGetDoc(self.db, 'my-doc-id', 'test:1', simple_doc, False)
+
+ def test_simple_put_doc_if_newer_deleted(self):
+ self.db.create_doc_from_json('{}', doc_id='my-doc-id')
+ doc = self.make_document('my-doc-id', 'test:2', None)
+ state_at_gen = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(('inserted', 2), state_at_gen)
+ self.assertGetDocIncludeDeleted(
+ self.db, 'my-doc-id', 'test:2', None, False)
+
+ def test_put_doc_if_newer_already_superseded(self):
+ orig_doc = '{"new": "doc"}'
+ doc1 = self.db.create_doc_from_json(orig_doc)
+ doc1_rev1 = doc1.rev
+ doc1.set_json(simple_doc)
+ self.db.put_doc(doc1)
+ doc1_rev2 = doc1.rev
+ # Nothing is inserted, because the document is already superseded
+ doc = self.make_document(doc1.doc_id, doc1_rev1, orig_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('superseded', state)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1_rev2, simple_doc, False)
+
+ def test_put_doc_if_newer_autoresolve(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ rev = doc1.rev
+ doc = self.make_document(doc1.doc_id, "whatever:1", doc1.get_json())
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('superseded', state)
+ doc2 = self.db.get_doc(doc1.doc_id)
+ v2 = vectorclock.VectorClockRev(doc2.rev)
+ self.assertTrue(v2.is_newer(vectorclock.VectorClockRev("whatever:1")))
+ self.assertTrue(v2.is_newer(vectorclock.VectorClockRev(rev)))
+ # strictly newer locally
+ self.assertTrue(rev not in doc2.rev)
+
+ def test_put_doc_if_newer_already_converged(self):
+ orig_doc = '{"new": "doc"}'
+ doc1 = self.db.create_doc_from_json(orig_doc)
+ state_at_gen = self.db._put_doc_if_newer(
+ doc1, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(('converged', 1), state_at_gen)
+
+ def test_put_doc_if_newer_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ # Nothing is inserted, the document id is returned as would-conflict
+ alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ state, _ = self.db._put_doc_if_newer(
+ alt_doc, save_conflict=False, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('conflicted', state)
+ # The database wasn't altered
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+
+ def test_put_doc_if_newer_newer_generation(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc = self.make_document('doc_id', 'other:2', simple_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='other', replica_gen=2,
+ replica_trans_id='T-irrelevant')
+ self.assertEqual('inserted', state)
+
+ def test_put_doc_if_newer_same_generation_same_txid(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.make_document(doc.doc_id, 'other:1', simple_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc, save_conflict=False, replica_uid='other', replica_gen=1,
+ replica_trans_id='T-sid')
+ self.assertEqual('converged', state)
+
+ def test_put_doc_if_newer_wrong_transaction_id(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc = self.make_document('doc_id', 'other:1', simple_doc)
+ self.assertRaises(
+ errors.InvalidTransactionId,
+ self.db._put_doc_if_newer, doc, save_conflict=False,
+ replica_uid='other', replica_gen=1, replica_trans_id='T-sad')
+
+ def test_put_doc_if_newer_old_generation_older_doc(self):
+ orig_doc = '{"new": "doc"}'
+ doc = self.db.create_doc_from_json(orig_doc)
+ doc_rev1 = doc.rev
+ doc.set_json(simple_doc)
+ self.db.put_doc(doc)
+ self.db._set_replica_gen_and_trans_id('other', 3, 'T-sid')
+ older_doc = self.make_document(doc.doc_id, doc_rev1, simple_doc)
+ state, _ = self.db._put_doc_if_newer(
+ older_doc, save_conflict=False, replica_uid='other', replica_gen=8,
+ replica_trans_id='T-irrelevant')
+ self.assertEqual('superseded', state)
+
+ def test_put_doc_if_newer_old_generation_newer_doc(self):
+ self.db._set_replica_gen_and_trans_id('other', 5, 'T-sid')
+ doc = self.make_document('doc_id', 'other:1', simple_doc)
+ self.assertRaises(
+ errors.InvalidGeneration,
+ self.db._put_doc_if_newer, doc, save_conflict=False,
+ replica_uid='other', replica_gen=1, replica_trans_id='T-sad')
+
+ def test_put_doc_if_newer_replica_uid(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1',
+ nested_doc)
+ self.assertEqual('inserted',
+ self.db._put_doc_if_newer(
+ doc2,
+ save_conflict=False,
+ replica_uid='other',
+ replica_gen=2,
+ replica_trans_id='T-id2')[0])
+ self.assertEqual((2, 'T-id2'), self.db._get_replica_gen_and_trans_id(
+ 'other'))
+ # Compare to the old rev, should be superseded
+ doc2 = self.make_document(doc1.doc_id, doc1.rev, nested_doc)
+ self.assertEqual('superseded',
+ self.db._put_doc_if_newer(
+ doc2,
+ save_conflict=False,
+ replica_uid='other',
+ replica_gen=3,
+ replica_trans_id='T-id3')[0])
+ self.assertEqual(
+ (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other'))
+ # A conflict that isn't saved still records the sync gen, because we
+ # don't need to see it again
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|fourth:1',
+ '{}')
+ self.assertEqual('conflicted',
+ self.db._put_doc_if_newer(
+ doc2,
+ save_conflict=False,
+ replica_uid='other',
+ replica_gen=4,
+ replica_trans_id='T-id4')[0])
+ self.assertEqual(
+ (4, 'T-id4'), self.db._get_replica_gen_and_trans_id('other'))
+
+ def test__get_replica_gen_and_trans_id(self):
+ self.assertEqual(
+ (0, ''), self.db._get_replica_gen_and_trans_id('other-db'))
+ self.db._set_replica_gen_and_trans_id('other-db', 2, 'T-transaction')
+ self.assertEqual(
+ (2, 'T-transaction'),
+ self.db._get_replica_gen_and_trans_id('other-db'))
+
+ def test_put_updates_transaction_log(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ doc.set_json('{"something": "else"}')
+ self.db.put_doc(doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]),
+ self.db.whats_changed())
+
+ def test_delete_updates_transaction_log(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ db_gen, _, _ = self.db.whats_changed()
+ self.db.delete_doc(doc)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]),
+ self.db.whats_changed(db_gen))
+
+ def test_whats_changed_initial_database(self):
+ self.assertEqual((0, '', []), self.db.whats_changed())
+
+ def test_whats_changed_returns_one_id_for_multiple_changes(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc.set_json('{"new": "contents"}')
+ self.db.put_doc(doc)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((2, last_trans_id, [(doc.doc_id, 2, last_trans_id)]),
+ self.db.whats_changed())
+ self.assertEqual((2, last_trans_id, []), self.db.whats_changed(2))
+
+ def test_whats_changed_returns_last_edits_ascending(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc.set_json('{"new": "contents"}')
+ self.db.delete_doc(doc1)
+ delete_trans_id = self.getLastTransId(self.db)
+ self.db.put_doc(doc)
+ put_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((4, put_trans_id,
+ [(doc1.doc_id, 3, delete_trans_id),
+ (doc.doc_id, 4, put_trans_id)]),
+ self.db.whats_changed())
+
+ def test_whats_changed_doesnt_include_old_gen(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual((3, last_trans_id, [(doc2.doc_id, 3, last_trans_id)]),
+ self.db.whats_changed(2))
+
+
+class LocalDatabaseValidateGenNTransIdTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS
+
+ def test_validate_gen_and_trans_id(self):
+ self.db.create_doc_from_json(simple_doc)
+ gen, trans_id = self.db._get_generation_info()
+ self.db.validate_gen_and_trans_id(gen, trans_id)
+
+ def test_validate_gen_and_trans_id_invalid_txid(self):
+ self.db.create_doc_from_json(simple_doc)
+ gen, _ = self.db._get_generation_info()
+ self.assertRaises(
+ errors.InvalidTransactionId,
+ self.db.validate_gen_and_trans_id, gen, 'wrong')
+
+ def test_validate_gen_and_trans_id_invalid_gen(self):
+ self.db.create_doc_from_json(simple_doc)
+ gen, trans_id = self.db._get_generation_info()
+ self.assertRaises(
+ errors.InvalidGeneration,
+ self.db.validate_gen_and_trans_id, gen + 1, trans_id)
+
+
+class LocalDatabaseValidateSourceGenTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS
+
+ def test_validate_source_gen_and_trans_id_same(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ self.db._validate_source('other', 1, 'T-sid')
+
+ def test_validate_source_gen_newer(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ self.db._validate_source('other', 2, 'T-whatevs')
+
+ def test_validate_source_wrong_txid(self):
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-sid')
+ self.assertRaises(
+ errors.InvalidTransactionId,
+ self.db._validate_source, 'other', 1, 'T-sad')
+
+
+class LocalDatabaseWithConflictsTests(tests.DatabaseBaseTests):
+ # test supporting/functionality around storing conflicts
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS
+
+ def test_get_docs_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual([doc2], list(self.db.get_docs([doc1.doc_id])))
+
+ def test_get_docs_conflicts_ignored(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ alt_doc = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ no_conflict_doc = self.make_document(doc1.doc_id, 'alternate:1',
+ nested_doc)
+ self.assertEqual([no_conflict_doc, doc2],
+ list(self.db.get_docs([doc1.doc_id, doc2.doc_id],
+ check_for_conflicts=False)))
+
+ def test_get_doc_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual([alt_doc, doc],
+ self.db.get_doc_conflicts(doc.doc_id))
+
+ def test_get_all_docs_sees_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ _, docs = self.db.get_all_docs()
+ self.assertTrue(list(docs)[0].has_conflicts)
+
+ def test_get_doc_conflicts_unconflicted(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertEqual([], self.db.get_doc_conflicts(doc.doc_id))
+
+ def test_get_doc_conflicts_no_such_id(self):
+ self.assertEqual([], self.db.get_doc_conflicts('doc-id'))
+
+ def test_resolve_doc(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ alt_doc = self.make_document(doc.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc.doc_id,
+ [('alternate:1', nested_doc),
+ (doc.rev, simple_doc)])
+ orig_rev = doc.rev
+ self.db.resolve_doc(doc, [alt_doc.rev, doc.rev])
+ self.assertNotEqual(orig_rev, doc.rev)
+ self.assertFalse(doc.has_conflicts)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertGetDocConflicts(self.db, doc.doc_id, [])
+
+ def test_resolve_doc_picks_biggest_vcr(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, simple_doc)])
+ orig_doc1_rev = doc1.rev
+ self.db.resolve_doc(doc1, [doc2.rev, doc1.rev])
+ self.assertFalse(doc1.has_conflicts)
+ self.assertNotEqual(orig_doc1_rev, doc1.rev)
+ self.assertGetDoc(self.db, doc1.doc_id, doc1.rev, simple_doc, False)
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ vcr_1 = vectorclock.VectorClockRev(orig_doc1_rev)
+ vcr_2 = vectorclock.VectorClockRev(doc2.rev)
+ vcr_new = vectorclock.VectorClockRev(doc1.rev)
+ self.assertTrue(vcr_new.is_newer(vcr_1))
+ self.assertTrue(vcr_new.is_newer(vcr_2))
+
+ def test_resolve_doc_partial_not_winning(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, simple_doc)])
+ content3 = '{"key": "valin3"}'
+ doc3 = self.make_document(doc1.doc_id, 'third:1', content3)
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='bar')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc3.rev, content3),
+ (doc1.rev, simple_doc),
+ (doc2.rev, nested_doc)])
+ self.db.resolve_doc(doc1, [doc2.rev, doc1.rev])
+ self.assertTrue(doc1.has_conflicts)
+ self.assertGetDoc(self.db, doc1.doc_id, doc3.rev, content3, True)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc3.rev, content3),
+ (doc1.rev, simple_doc)])
+
+ def test_resolve_doc_partial_winning(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ content3 = '{"key": "valin3"}'
+ doc3 = self.make_document(doc1.doc_id, 'third:1', content3)
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='bar')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc3.rev, content3),
+ (doc1.rev, simple_doc),
+ (doc2.rev, nested_doc)])
+ self.db.resolve_doc(doc1, [doc3.rev, doc1.rev])
+ self.assertTrue(doc1.has_conflicts)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc1.rev, simple_doc),
+ (doc2.rev, nested_doc)])
+
+ def test_resolve_doc_with_delete_conflict(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc1)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, None)])
+ self.db.resolve_doc(doc2, [doc1.rev, doc2.rev])
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, False)
+
+ def test_resolve_doc_with_delete_to_delete(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc1)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [(doc2.rev, nested_doc),
+ (doc1.rev, None)])
+ self.db.resolve_doc(doc1, [doc1.rev, doc2.rev])
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ self.assertGetDocIncludeDeleted(
+ self.db, doc1.doc_id, doc1.rev, None, False)
+
+ def test_put_doc_if_newer_save_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ # Document is inserted as a conflict
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ state, _ = self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual('conflicted', state)
+ # The database was updated
+ self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, nested_doc, True)
+
+ def test_force_doc_conflict_supersedes_properly(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', '{"b": 1}')
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ doc3 = self.make_document(doc1.doc_id, 'altalt:1', '{"c": 1}')
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='bar')
+ doc22 = self.make_document(doc1.doc_id, 'alternate:2', '{"b": 2}')
+ self.db._put_doc_if_newer(
+ doc22, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='zed')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:2', doc22.get_json()),
+ ('altalt:1', doc3.get_json()),
+ (doc1.rev, simple_doc)])
+
+ def test_put_doc_if_newer_save_conflict_was_deleted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc1)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertTrue(doc2.has_conflicts)
+ self.assertGetDoc(
+ self.db, doc1.doc_id, 'alternate:1', nested_doc, True)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:1', nested_doc),
+ (doc1.rev, None)])
+
+ def test_put_doc_if_newer_propagates_full_resolution(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ resolved_vcr = vectorclock.VectorClockRev(doc1.rev)
+ vcr_2 = vectorclock.VectorClockRev(doc2.rev)
+ resolved_vcr.maximize(vcr_2)
+ resolved_vcr.increment('alternate')
+ doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(),
+ '{"good": 1}')
+ state, _ = self.db._put_doc_if_newer(
+ doc_resolved, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual('inserted', state)
+ self.assertFalse(doc_resolved.has_conflicts)
+ self.assertGetDocConflicts(self.db, doc1.doc_id, [])
+ doc3 = self.db.get_doc(doc1.doc_id)
+ self.assertFalse(doc3.has_conflicts)
+
+ def test_put_doc_if_newer_propagates_partial_resolution(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'altalt:1', '{}')
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ doc3 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc3, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:1', nested_doc),
+ ('test:1', simple_doc),
+ ('altalt:1', '{}')])
+ resolved_vcr = vectorclock.VectorClockRev(doc1.rev)
+ vcr_3 = vectorclock.VectorClockRev(doc3.rev)
+ resolved_vcr.maximize(vcr_3)
+ resolved_vcr.increment('alternate')
+ doc_resolved = self.make_document(doc1.doc_id, resolved_vcr.as_str(),
+ '{"good": 1}')
+ state, _ = self.db._put_doc_if_newer(
+ doc_resolved, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual('inserted', state)
+ self.assertTrue(doc_resolved.has_conflicts)
+ doc4 = self.db.get_doc(doc1.doc_id)
+ self.assertTrue(doc4.has_conflicts)
+ self.assertGetDocConflicts(self.db, doc1.doc_id,
+ [('alternate:2|test:1', '{"good": 1}'),
+ ('altalt:1', '{}')])
+
+ def test_put_doc_if_newer_replica_uid(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db._set_replica_gen_and_trans_id('other', 1, 'T-id')
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|other:1',
+ nested_doc)
+ self.db._put_doc_if_newer(doc2, save_conflict=True,
+ replica_uid='other', replica_gen=2,
+ replica_trans_id='T-id2')
+ # Conflict vs the current update
+ doc2 = self.make_document(doc1.doc_id, doc1.rev + '|third:3',
+ '{}')
+ self.assertEqual('conflicted',
+ self.db._put_doc_if_newer(
+ doc2,
+ save_conflict=True,
+ replica_uid='other',
+ replica_gen=3,
+ replica_trans_id='T-id3')[0])
+ self.assertEqual(
+ (3, 'T-id3'), self.db._get_replica_gen_and_trans_id('other'))
+
+ def test_put_doc_if_newer_autoresolve_2(self):
+ # this is an ordering variant of _3, but that already works
+ # adding the test explicitly to catch the regression easily
+ doc_a1 = self.db.create_doc_from_json(simple_doc)
+ doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', "{}")
+ doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1',
+ '{"a":"42"}')
+ doc_a3 = self.make_document(doc_a1.doc_id, 'test:2|other:1', "{}")
+ state, _ = self.db._put_doc_if_newer(
+ doc_a2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(state, 'inserted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual(state, 'conflicted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a3, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual(state, 'inserted')
+ self.assertFalse(self.db.get_doc(doc_a1.doc_id).has_conflicts)
+
+ def test_put_doc_if_newer_autoresolve_3(self):
+ doc_a1 = self.db.create_doc_from_json(simple_doc)
+ doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', "{}")
+ doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}')
+ doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', "{}")
+ state, _ = self.db._put_doc_if_newer(
+ doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(state, 'inserted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a2, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual(state, 'conflicted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a3, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual(state, 'superseded')
+ doc = self.db.get_doc(doc_a1.doc_id, True)
+ self.assertFalse(doc.has_conflicts)
+ rev = vectorclock.VectorClockRev(doc.rev)
+ rev_a3 = vectorclock.VectorClockRev('test:3')
+ rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1')
+ self.assertTrue(rev.is_newer(rev_a3))
+ self.assertTrue('test:4' in doc.rev) # locally increased
+ self.assertTrue(rev.is_newer(rev_a1b1))
+
+ def test_put_doc_if_newer_autoresolve_4(self):
+ doc_a1 = self.db.create_doc_from_json(simple_doc)
+ doc_a1b1 = self.make_document(doc_a1.doc_id, 'test:1|other:1', None)
+ doc_a2 = self.make_document(doc_a1.doc_id, 'test:2', '{"a":"42"}')
+ doc_a3 = self.make_document(doc_a1.doc_id, 'test:3', None)
+ state, _ = self.db._put_doc_if_newer(
+ doc_a1b1, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertEqual(state, 'inserted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a2, save_conflict=True, replica_uid='r', replica_gen=2,
+ replica_trans_id='foo2')
+ self.assertEqual(state, 'conflicted')
+ state, _ = self.db._put_doc_if_newer(
+ doc_a3, save_conflict=True, replica_uid='r', replica_gen=3,
+ replica_trans_id='foo3')
+ self.assertEqual(state, 'superseded')
+ doc = self.db.get_doc(doc_a1.doc_id, True)
+ self.assertFalse(doc.has_conflicts)
+ rev = vectorclock.VectorClockRev(doc.rev)
+ rev_a3 = vectorclock.VectorClockRev('test:3')
+ rev_a1b1 = vectorclock.VectorClockRev('test:1|other:1')
+ self.assertTrue(rev.is_newer(rev_a3))
+ self.assertTrue('test:4' in doc.rev) # locally increased
+ self.assertTrue(rev.is_newer(rev_a1b1))
+
+ def test_put_refuses_to_update_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ content2 = '{"key": "altval"}'
+ doc2 = self.make_document(doc1.doc_id, 'altrev:1', content2)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDoc(self.db, doc1.doc_id, doc2.rev, content2, True)
+ content3 = '{"key": "local"}'
+ doc2.set_json(content3)
+ self.assertRaises(errors.ConflictedDoc, self.db.put_doc, doc2)
+
+ def test_delete_refuses_for_conflicted(self):
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'altrev:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, nested_doc, True)
+ self.assertRaises(errors.ConflictedDoc, self.db.delete_doc, doc2)
+
+
+class DatabaseIndexTests(tests.DatabaseBaseTests):
+
+ scenarios = tests.LOCAL_DATABASES_SCENARIOS
+
+ def assertParseError(self, definition):
+ self.db.create_doc_from_json(nested_doc)
+ self.assertRaises(
+ errors.IndexDefinitionParseError, self.db.create_index, 'idx',
+ definition)
+
+ def assertIndexCreatable(self, definition):
+ name = "idx"
+ self.db.create_doc_from_json(nested_doc)
+ self.db.create_index(name, definition)
+ self.assertEqual(
+ [(name, [definition])], self.db.list_indexes())
+
+ def test_create_index(self):
+ self.db.create_index('test-idx', 'name')
+ self.assertEqual([('test-idx', ['name'])],
+ self.db.list_indexes())
+
+ def test_create_index_on_non_ascii_field_name(self):
+ doc = self.db.create_doc_from_json(json.dumps({u'\xe5': 'value'}))
+ self.db.create_index('test-idx', u'\xe5')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_list_indexes_with_non_ascii_field_names(self):
+ self.db.create_index('test-idx', u'\xe5')
+ self.assertEqual(
+ [('test-idx', [u'\xe5'])], self.db.list_indexes())
+
+ def test_create_index_evaluates_it(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_wildcard_matches_unicode_value(self):
+ doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"}))
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', '*'))
+
+ def test_retrieve_unicode_value_from_index(self):
+ doc = self.db.create_doc_from_json(json.dumps({"key": u"valu\xe5"}))
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', u"valu\xe5"))
+
+ def test_create_index_fails_if_name_taken(self):
+ self.db.create_index('test-idx', 'key')
+ self.assertRaises(errors.IndexNameTakenError,
+ self.db.create_index,
+ 'test-idx', 'stuff')
+
+ def test_create_index_does_not_fail_if_name_taken_with_same_index(self):
+ self.db.create_index('test-idx', 'key')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([('test-idx', ['key'])], self.db.list_indexes())
+
+ def test_create_index_does_not_duplicate_indexed_fields(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.db.delete_index('test-idx')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(1, len(self.db.get_from_index('test-idx', 'value')))
+
+ def test_delete_index_does_not_remove_fields_from_other_indexes(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.db.create_index('test-idx2', 'key')
+ self.db.delete_index('test-idx')
+ self.assertEqual(1, len(self.db.get_from_index('test-idx2', 'value')))
+
+ def test_create_index_after_deleting_document(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc2)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_delete_index(self):
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([('test-idx', ['key'])], self.db.list_indexes())
+ self.db.delete_index('test-idx')
+ self.assertEqual([], self.db.list_indexes())
+
+ def test_create_adds_to_index(self):
+ self.db.create_index('test-idx', 'key')
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_get_from_index_unmatched(self):
+ self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([], self.db.get_from_index('test-idx', 'novalue'))
+
+ def test_create_index_multiple_exact_matches(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ sorted([doc, doc2]),
+ sorted(self.db.get_from_index('test-idx', 'value')))
+
+ def test_get_from_index(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'value'))
+
+ def test_get_from_index_multi(self):
+ content = '{"key": "value", "key2": "value2"}'
+ doc = self.db.create_doc_from_json(content)
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2'))
+
+ def test_get_from_index_multi_list(self):
+ doc = self.db.create_doc_from_json(
+ '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2-1'))
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2-2'))
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'value', 'value2-3'))
+ self.assertEqual(
+ [('value', 'value2-1'), ('value', 'value2-2'),
+ ('value', 'value2-3')],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_get_from_index_sees_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key', 'key2')
+ alt_doc = self.make_document(
+ doc.doc_id, 'alternate:1',
+ '{"key": "value", "key2": ["value2-1", "value2-2", "value2-3"]}')
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ docs = self.db.get_from_index('test-idx', 'value', 'value2-1')
+ self.assertTrue(docs[0].has_conflicts)
+
+ def test_get_index_keys_multi_list_list(self):
+ self.db.create_doc_from_json(
+ '{"key": "value1-1 value1-2 value1-3", '
+ '"key2": ["value2-1", "value2-2", "value2-3"]}')
+ self.db.create_index('test-idx', 'split_words(key)', 'key2')
+ self.assertEqual(
+ [(u'value1-1', u'value2-1'), (u'value1-1', u'value2-2'),
+ (u'value1-1', u'value2-3'), (u'value1-2', u'value2-1'),
+ (u'value1-2', u'value2-2'), (u'value1-2', u'value2-3'),
+ (u'value1-3', u'value2-1'), (u'value1-3', u'value2-2'),
+ (u'value1-3', u'value2-3')],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_get_from_index_multi_ordered(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2, doc1],
+ self.db.get_from_index('test-idx', 'v*', '*'))
+
+ def test_get_range_from_index_start_end(self):
+ doc1 = self.db.create_doc_from_json('{"key": "value3"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value2"}')
+ self.db.create_doc_from_json('{"key": "value4"}')
+ self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc2, doc1],
+ self.db.get_range_from_index('test-idx', 'value2', 'value3'))
+
+ def test_get_range_from_index_start(self):
+ doc1 = self.db.create_doc_from_json('{"key": "value3"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value2"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value4"}')
+ self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc2, doc1, doc3],
+ self.db.get_range_from_index('test-idx', 'value2'))
+
+ def test_get_range_from_index_sees_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ alt_doc = self.make_document(
+ doc.doc_id, 'alternate:1', '{"key": "valuedepalue"}')
+ self.db._put_doc_if_newer(
+ alt_doc, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ docs = self.db.get_range_from_index('test-idx', 'a')
+ self.assertTrue(docs[0].has_conflicts)
+
+ def test_get_range_from_index_end(self):
+ self.db.create_doc_from_json('{"key": "value3"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value2"}')
+ self.db.create_doc_from_json('{"key": "value4"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc4, doc2],
+ self.db.get_range_from_index('test-idx', None, 'value2'))
+
+ def test_get_wildcard_range_from_index_start(self):
+ doc1 = self.db.create_doc_from_json('{"key": "value4"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value23"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value2"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value22"}')
+ self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc3, doc4, doc2, doc1],
+ self.db.get_range_from_index('test-idx', 'value2*'))
+
+ def test_get_wildcard_range_from_index_end(self):
+ self.db.create_doc_from_json('{"key": "value4"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value23"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value2"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value22"}')
+ doc5 = self.db.create_doc_from_json('{"key": "value1"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc5, doc3, doc4, doc2],
+ self.db.get_range_from_index('test-idx', None, 'value2*'))
+
+ def test_get_wildcard_range_from_index_start_end(self):
+ self.db.create_doc_from_json('{"key": "a"}')
+ self.db.create_doc_from_json('{"key": "boo3"}')
+ doc3 = self.db.create_doc_from_json('{"key": "catalyst"}')
+ doc4 = self.db.create_doc_from_json('{"key": "whaever"}')
+ self.db.create_doc_from_json('{"key": "zerg"}')
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ [doc3, doc4],
+ self.db.get_range_from_index('test-idx', 'cat*', 'zap*'))
+
+ def test_get_range_from_index_multi_column_start_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc3, doc2],
+ self.db.get_range_from_index(
+ 'test-idx', ('value2', 'value2'), ('value2', 'value3')))
+
+ def test_get_range_from_index_multi_column_start(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ self.db.create_doc_from_json('{"key": "value2", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc2, doc1],
+ self.db.get_range_from_index('test-idx', ('value2', 'value3')))
+
+ def test_get_range_from_index_multi_column_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value3"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2],
+ self.db.get_range_from_index(
+ 'test-idx', None, ('value2', 'value3')))
+
+ def test_get_wildcard_range_from_index_multi_column_start(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc3, doc2, doc1],
+ self.db.get_range_from_index('test-idx', ('value2', 'value2*')))
+
+ def test_get_wildcard_range_from_index_multi_column_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2],
+ self.db.get_range_from_index(
+ 'test-idx', None, ('value2', 'value2*')))
+
+ def test_get_glob_range_from_index_multi_column_start(self):
+ doc1 = self.db.create_doc_from_json(
+ '{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value2"}')
+ self.db.create_doc_from_json('{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc2, doc1],
+ self.db.get_range_from_index('test-idx', ('value2', '*')))
+
+ def test_get_glob_range_from_index_multi_column_end(self):
+ self.db.create_doc_from_json('{"key": "value3", "key2": "value4"}')
+ doc2 = self.db.create_doc_from_json(
+ '{"key": "value2", "key2": "value23"}')
+ doc3 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value2"}')
+ doc4 = self.db.create_doc_from_json(
+ '{"key": "value1", "key2": "value1"}')
+ self.db.create_index('test-idx', 'key', 'key2')
+ self.assertEqual(
+ [doc4, doc3, doc2],
+ self.db.get_range_from_index('test-idx', None, ('value2', '*')))
+
+ def test_get_range_from_index_illegal_wildcard_order(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', ('*', 'v2'))
+
+ def test_get_range_from_index_illegal_glob_after_wildcard(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', ('*', 'v*'))
+
+ def test_get_range_from_index_illegal_wildcard_order_end(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', None, ('*', 'v2'))
+
+ def test_get_range_from_index_illegal_glob_after_wildcard_end(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_range_from_index, 'test-idx', None, ('*', 'v*'))
+
+ def test_get_from_index_fails_if_no_index(self):
+ self.assertRaises(
+ errors.IndexDoesNotExist, self.db.get_from_index, 'foo')
+
+ def test_get_index_keys_fails_if_no_index(self):
+ self.assertRaises(errors.IndexDoesNotExist,
+ self.db.get_index_keys,
+ 'foo')
+
+ def test_get_index_keys_works_if_no_docs(self):
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([], self.db.get_index_keys('test-idx'))
+
+ def test_put_updates_index(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ new_content = '{"key": "altval"}'
+ doc.set_json(new_content)
+ self.db.put_doc(doc)
+ self.assertEqual([], self.db.get_from_index('test-idx', 'value'))
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'altval'))
+
+ def test_delete_updates_index(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual(
+ sorted([doc, doc2]),
+ sorted(self.db.get_from_index('test-idx', 'value')))
+ self.db.delete_doc(doc)
+ self.assertEqual([doc2], self.db.get_from_index('test-idx', 'value'))
+
+ def test_get_from_index_illegal_number_of_entries(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidValueForIndex, self.db.get_from_index, 'test-idx')
+ self.assertRaises(
+ errors.InvalidValueForIndex,
+ self.db.get_from_index, 'test-idx', 'v1')
+ self.assertRaises(
+ errors.InvalidValueForIndex,
+ self.db.get_from_index, 'test-idx', 'v1', 'v2', 'v3')
+
+ def test_get_from_index_illegal_wildcard_order(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', '*', 'v2')
+
+ def test_get_from_index_illegal_glob_after_wildcard(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', '*', 'v*')
+
+ def test_get_all_from_index(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ # This one should not be in the index
+ self.db.create_doc_from_json('{"no": "key"}')
+ diff_value_doc = '{"key": "diff value"}'
+ doc4 = self.db.create_doc_from_json(diff_value_doc)
+ # This is essentially a 'prefix' match, but we match every entry.
+ self.assertEqual(
+ sorted([doc1, doc2, doc4]),
+ sorted(self.db.get_from_index('test-idx', '*')))
+
+ def test_get_all_from_index_ordered(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json('{"key": "value x"}')
+ doc2 = self.db.create_doc_from_json('{"key": "value b"}')
+ doc3 = self.db.create_doc_from_json('{"key": "value a"}')
+ doc4 = self.db.create_doc_from_json('{"key": "value m"}')
+ # This is essentially a 'prefix' match, but we match every entry.
+ self.assertEqual(
+ [doc3, doc2, doc4, doc1], self.db.get_from_index('test-idx', '*'))
+
+ def test_put_updates_when_adding_key(self):
+ doc = self.db.create_doc_from_json("{}")
+ self.db.create_index('test-idx', 'key')
+ self.assertEqual([], self.db.get_from_index('test-idx', '*'))
+ doc.set_json(simple_doc)
+ self.db.put_doc(doc)
+ self.assertEqual([doc], self.db.get_from_index('test-idx', '*'))
+
+ def test_get_from_index_empty_string(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ content2 = '{"key": ""}'
+ doc2 = self.db.create_doc_from_json(content2)
+ self.assertEqual([doc2], self.db.get_from_index('test-idx', ''))
+ # Empty string matches the wildcard.
+ self.assertEqual(
+ sorted([doc1, doc2]),
+ sorted(self.db.get_from_index('test-idx', '*')))
+
+ def test_get_from_index_not_null(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.db.create_doc_from_json('{"key": null}')
+ self.assertEqual([doc1], self.db.get_from_index('test-idx', '*'))
+
+ def test_get_partial_from_index(self):
+ content1 = '{"k1": "v1", "k2": "v2"}'
+ content2 = '{"k1": "v1", "k2": "x2"}'
+ content3 = '{"k1": "v1", "k2": "y2"}'
+ # doc4 has a different k1 value, so it doesn't match the prefix.
+ content4 = '{"k1": "NN", "k2": "v2"}'
+ doc1 = self.db.create_doc_from_json(content1)
+ doc2 = self.db.create_doc_from_json(content2)
+ doc3 = self.db.create_doc_from_json(content3)
+ self.db.create_doc_from_json(content4)
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertEqual(
+ sorted([doc1, doc2, doc3]),
+ sorted(self.db.get_from_index('test-idx', "v1", "*")))
+
+ def test_get_glob_match(self):
+ # Note: the exact glob syntax is probably subject to change
+ content1 = '{"k1": "v1", "k2": "v1"}'
+ content2 = '{"k1": "v1", "k2": "v2"}'
+ content3 = '{"k1": "v1", "k2": "v3"}'
+ # doc4 has a different k2 prefix value, so it doesn't match
+ content4 = '{"k1": "v1", "k2": "ZZ"}'
+ self.db.create_index('test-idx', 'k1', 'k2')
+ doc1 = self.db.create_doc_from_json(content1)
+ doc2 = self.db.create_doc_from_json(content2)
+ doc3 = self.db.create_doc_from_json(content3)
+ self.db.create_doc_from_json(content4)
+ self.assertEqual(
+ sorted([doc1, doc2, doc3]),
+ sorted(self.db.get_from_index('test-idx', "v1", "v*")))
+
+ def test_nested_index(self):
+ doc = self.db.create_doc_from_json(nested_doc)
+ self.db.create_index('test-idx', 'sub.doc')
+ self.assertEqual(
+ [doc], self.db.get_from_index('test-idx', 'underneath'))
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertEqual(
+ sorted([doc, doc2]),
+ sorted(self.db.get_from_index('test-idx', 'underneath')))
+
+ def test_nested_nonexistent(self):
+ self.db.create_doc_from_json(nested_doc)
+ # sub exists, but sub.foo does not:
+ self.db.create_index('test-idx', 'sub.foo')
+ self.assertEqual([], self.db.get_from_index('test-idx', '*'))
+
+ def test_nested_nonexistent2(self):
+ self.db.create_doc_from_json(nested_doc)
+ self.db.create_index('test-idx', 'sub.foo.bar.baz.qux.fnord')
+ self.assertEqual([], self.db.get_from_index('test-idx', '*'))
+
+ def test_nested_traverses_lists(self):
+ # subpath finds dicts in list
+ doc = self.db.create_doc_from_json(
+ '{"foo": [{"zap": "bar"}, {"zap": "baz"}]}')
+ # subpath only finds dicts in list
+ self.db.create_doc_from_json('{"foo": ["zap", "baz"]}')
+ self.db.create_index('test-idx', 'foo.zap')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'bar'))
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'baz'))
+
+ def test_nested_list_traversal(self):
+ # subpath finds dicts in list
+ doc = self.db.create_doc_from_json(
+ '{"foo": [{"zap": [{"qux": "fnord"}, {"qux": "zombo"}]},'
+ '{"zap": "baz"}]}')
+ # subpath only finds dicts in list
+ self.db.create_index('test-idx', 'foo.zap.qux')
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'fnord'))
+ self.assertEqual([doc], self.db.get_from_index('test-idx', 'zombo'))
+
+ def test_index_list1(self):
+ self.db.create_index("index", "name")
+ content = '{"name": ["foo", "bar"]}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_index_list2(self):
+ self.db.create_index("index", "name")
+ content = '{"name": ["foo", "bar"]}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_case_sensitive(self):
+ self.db.create_index('test-idx', 'key')
+ doc1 = self.db.create_doc_from_json(simple_doc)
+ self.assertEqual([], self.db.get_from_index('test-idx', 'V*'))
+ self.assertEqual([doc1], self.db.get_from_index('test-idx', 'v*'))
+
+ def test_get_from_index_illegal_glob_before_value(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', 'v*', 'v2')
+
+ def test_get_from_index_illegal_glob_after_glob(self):
+ self.db.create_index('test-idx', 'k1', 'k2')
+ self.assertRaises(
+ errors.InvalidGlobbing,
+ self.db.get_from_index, 'test-idx', 'v*', 'v*')
+
+ def test_get_from_index_with_sql_wildcards(self):
+ self.db.create_index('test-idx', 'key')
+ content1 = '{"key": "va%lue"}'
+ content2 = '{"key": "value"}'
+ content3 = '{"key": "va_lue"}'
+ doc1 = self.db.create_doc_from_json(content1)
+ self.db.create_doc_from_json(content2)
+ doc3 = self.db.create_doc_from_json(content3)
+ # The '%' in the search should be treated literally, not as a sql
+ # globbing character.
+ self.assertEqual([doc1], self.db.get_from_index('test-idx', 'va%*'))
+ # Same for '_'
+ self.assertEqual([doc3], self.db.get_from_index('test-idx', 'va_*'))
+
+ def test_get_from_index_with_lower(self):
+ self.db.create_index("index", "lower(name)")
+ content = '{"name": "Foo"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_lower_matches_same_case(self):
+ self.db.create_index("index", "lower(name)")
+ content = '{"name": "foo"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_lower_doesnt_match_different_case(self):
+ self.db.create_index("index", "lower(name)")
+ content = '{"name": "Foo"}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "Foo")
+ self.assertEqual([], rows)
+
+ def test_index_lower_doesnt_match_other_index(self):
+ self.db.create_index("index", "lower(name)")
+ self.db.create_index("other_index", "name")
+ content = '{"name": "Foo"}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "Foo")
+ self.assertEqual(0, len(rows))
+
+ def test_index_split_words_match_first(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_match_second(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_match_both(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo foo"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_double_space(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_leading_space(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": " foo bar"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "foo")
+ self.assertEqual([doc], rows)
+
+ def test_index_split_words_trailing_space(self):
+ self.db.create_index("index", "split_words(name)")
+ content = '{"name": "foo bar "}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "bar")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_number(self):
+ self.db.create_index("index", "number(foo, 5)")
+ content = '{"foo": 12}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "00012")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_number_bigger_than_padding(self):
+ self.db.create_index("index", "number(foo, 5)")
+ content = '{"foo": 123456}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "123456")
+ self.assertEqual([doc], rows)
+
+ def test_number_mapping_ignores_non_numbers(self):
+ self.db.create_index("index", "number(foo, 5)")
+ content = '{"foo": 56}'
+ doc1 = self.db.create_doc_from_json(content)
+ content = '{"foo": "this is not a maigret painting"}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "*")
+ self.assertEqual([doc1], rows)
+
+ def test_get_from_index_with_bool(self):
+ self.db.create_index("index", "bool(foo)")
+ content = '{"foo": true}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "1")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_bool_false(self):
+ self.db.create_index("index", "bool(foo)")
+ content = '{"foo": false}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "0")
+ self.assertEqual([doc], rows)
+
+ def test_get_from_index_with_non_bool(self):
+ self.db.create_index("index", "bool(foo)")
+ content = '{"foo": 42}'
+ self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "*")
+ self.assertEqual([], rows)
+
+ def test_get_from_index_with_combine(self):
+ self.db.create_index("index", "combine(foo, bar)")
+ content = '{"foo": "value1", "bar": "value2"}'
+ doc = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "value1")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "value2")
+ self.assertEqual([doc], rows)
+
+ def test_get_complex_combine(self):
+ self.db.create_index(
+ "index", "combine(number(foo, 5), lower(bar), split_words(baz))")
+ content = '{"foo": 12, "bar": "ALLCAPS", "baz": "qux nox"}'
+ doc = self.db.create_doc_from_json(content)
+ content = '{"foo": "not a number", "bar": "something"}'
+ doc2 = self.db.create_doc_from_json(content)
+ rows = self.db.get_from_index("index", "00012")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "allcaps")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "nox")
+ self.assertEqual([doc], rows)
+ rows = self.db.get_from_index("index", "something")
+ self.assertEqual([doc2], rows)
+
+ def test_get_index_keys_from_index(self):
+ self.db.create_index('test-idx', 'key')
+ content1 = '{"key": "value1"}'
+ content2 = '{"key": "value2"}'
+ content3 = '{"key": "value2"}'
+ self.db.create_doc_from_json(content1)
+ self.db.create_doc_from_json(content2)
+ self.db.create_doc_from_json(content3)
+ self.assertEqual(
+ [('value1',), ('value2',)],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_get_index_keys_from_multicolumn_index(self):
+ self.db.create_index('test-idx', 'key1', 'key2')
+ content1 = '{"key1": "value1", "key2": "val2-1"}'
+ content2 = '{"key1": "value2", "key2": "val2-2"}'
+ content3 = '{"key1": "value2", "key2": "val2-2"}'
+ content4 = '{"key1": "value2", "key2": "val3"}'
+ self.db.create_doc_from_json(content1)
+ self.db.create_doc_from_json(content2)
+ self.db.create_doc_from_json(content3)
+ self.db.create_doc_from_json(content4)
+ self.assertEqual([
+ ('value1', 'val2-1'),
+ ('value2', 'val2-2'),
+ ('value2', 'val3')],
+ sorted(self.db.get_index_keys('test-idx')))
+
+ def test_empty_expr(self):
+ self.assertParseError('')
+
+ def test_nested_unknown_operation(self):
+ self.assertParseError('unknown_operation(field1)')
+
+ def test_parse_missing_close_paren(self):
+ self.assertParseError("lower(a")
+
+ def test_parse_trailing_close_paren(self):
+ self.assertParseError("lower(ab))")
+
+ def test_parse_trailing_chars(self):
+ self.assertParseError("lower(ab)adsf")
+
+ def test_parse_empty_op(self):
+ self.assertParseError("(ab)")
+
+ def test_parse_top_level_commas(self):
+ self.assertParseError("a, b")
+
+ def test_invalid_field_name(self):
+ self.assertParseError("a.")
+
+ def test_invalid_inner_field_name(self):
+ self.assertParseError("lower(a.)")
+
+ def test_gobbledigook(self):
+ self.assertParseError("(@#@cc @#!*DFJSXV(()jccd")
+
+ def test_leading_space(self):
+ self.assertIndexCreatable(" lower(a)")
+
+ def test_trailing_space(self):
+ self.assertIndexCreatable("lower(a) ")
+
+ def test_spaces_before_open_paren(self):
+ self.assertIndexCreatable("lower (a)")
+
+ def test_spaces_after_open_paren(self):
+ self.assertIndexCreatable("lower( a)")
+
+ def test_spaces_before_close_paren(self):
+ self.assertIndexCreatable("lower(a )")
+
+ def test_spaces_before_comma(self):
+ self.assertIndexCreatable("combine(a , b , c)")
+
+ def test_spaces_after_comma(self):
+ self.assertIndexCreatable("combine(a, b, c)")
+
+ def test_all_together_now(self):
+ self.assertParseError(' (a) ')
+
+ def test_all_together_now2(self):
+ self.assertParseError('combine(lower(x)x,foo)')
+
+
+class PythonBackendTests(tests.DatabaseBaseTests):
+
+ def setUp(self):
+ super(PythonBackendTests, self).setUp()
+ self.simple_doc = json.loads(simple_doc)
+
+ def test_create_doc_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id')
+ self.assertTrue(isinstance(doc, TestAlternativeDocument))
+
+ def test_get_doc_after_put_with_factory(self):
+ doc = self.db.create_doc(self.simple_doc, doc_id='my_doc_id')
+ self.db.set_document_factory(TestAlternativeDocument)
+ result = self.db.get_doc('my_doc_id')
+ self.assertTrue(isinstance(result, TestAlternativeDocument))
+ self.assertEqual(doc.doc_id, result.doc_id)
+ self.assertEqual(doc.rev, result.rev)
+ self.assertEqual(doc.get_json(), result.get_json())
+ self.assertEqual(False, result.has_conflicts)
+
+ def test_get_doc_nonexisting_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ self.assertIs(None, self.db.get_doc('non-existing'))
+
+ def test_get_all_docs_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ self.db.create_doc(self.simple_doc)
+ self.assertTrue(isinstance(
+ list(self.db.get_all_docs()[1])[0], TestAlternativeDocument))
+
+ def test_get_docs_conflicted_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ doc1 = self.db.create_doc(self.simple_doc)
+ doc2 = self.make_document(doc1.doc_id, 'alternate:1', nested_doc)
+ self.db._put_doc_if_newer(
+ doc2, save_conflict=True, replica_uid='r', replica_gen=1,
+ replica_trans_id='foo')
+ self.assertTrue(
+ isinstance(
+ list(self.db.get_docs([doc1.doc_id]))[0],
+ TestAlternativeDocument))
+
+ def test_get_from_index_with_factory(self):
+ self.db.set_document_factory(TestAlternativeDocument)
+ self.db.create_doc(self.simple_doc)
+ self.db.create_index('test-idx', 'key')
+ self.assertTrue(
+ isinstance(
+ self.db.get_from_index('test-idx', 'value')[0],
+ TestAlternativeDocument))
+
+ def test_sync_exchange_updates_indexes(self):
+ doc = self.db.create_doc(self.simple_doc)
+ self.db.create_index('test-idx', 'key')
+ new_content = '{"key": "altval"}'
+ other_rev = 'test:1|z:2'
+ st = self.db.get_sync_target()
+
+ def ignore(doc_id, doc_rev, doc):
+ pass
+
+ doc_other = self.make_document(doc.doc_id, other_rev, new_content)
+ docs_by_gen = [(doc_other, 10, 'T-sid')]
+ st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=ignore)
+ self.assertGetDoc(self.db, doc.doc_id, other_rev, new_content, False)
+ self.assertEqual(
+ [doc_other], self.db.get_from_index('test-idx', 'altval'))
+ self.assertEqual([], self.db.get_from_index('test-idx', 'value'))
+
+
+# Use a custom loader to apply the scenarios at load time.
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/tests/u1db_tests/test_document.py b/src/leap/soledad/tests/u1db_tests/test_document.py
new file mode 100644
index 00000000..e706e1a9
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/test_document.py
@@ -0,0 +1,150 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+
+from u1db import errors
+
+from leap.soledad.tests import u1db_tests as tests
+
+
+class TestDocument(tests.TestCase):
+
+ scenarios = ([(
+ 'py', {'make_document_for_test': tests.make_document_for_test})]) # +
+ #tests.C_DATABASE_SCENARIOS)
+
+ def test_create_doc(self):
+ doc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ self.assertEqual('doc-id', doc.doc_id)
+ self.assertEqual('uid:1', doc.rev)
+ self.assertEqual(tests.simple_doc, doc.get_json())
+ self.assertFalse(doc.has_conflicts)
+
+ def test__repr__(self):
+ doc = self.make_document('doc-id', 'uid:1', tests.simple_doc)
+ self.assertEqual(
+ '%s(doc-id, uid:1, \'{"key": "value"}\')'
+ % (doc.__class__.__name__,),
+ repr(doc))
+
+ def test__repr__conflicted(self):
+ doc = self.make_document('doc-id', 'uid:1', tests.simple_doc,
+ has_conflicts=True)
+ self.assertEqual(
+ '%s(doc-id, uid:1, conflicted, \'{"key": "value"}\')'
+ % (doc.__class__.__name__,),
+ repr(doc))
+
+ def test__lt__(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ doc_b = self.make_document('b', 'b', '{}')
+ self.assertTrue(doc_a < doc_b)
+ self.assertTrue(doc_b > doc_a)
+ doc_aa = self.make_document('a', 'a', '{}')
+ self.assertTrue(doc_aa < doc_a)
+
+ def test__eq__(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ doc_b = self.make_document('a', 'b', '{}')
+ self.assertTrue(doc_a == doc_b)
+ doc_b = self.make_document('a', 'b', '{}', has_conflicts=True)
+ self.assertFalse(doc_a == doc_b)
+
+ def test_non_json_dict(self):
+ self.assertRaises(
+ errors.InvalidJSON, self.make_document, 'id', 'uid:1',
+ '"not a json dictionary"')
+
+ def test_non_json(self):
+ self.assertRaises(
+ errors.InvalidJSON, self.make_document, 'id', 'uid:1',
+ 'not a json dictionary')
+
+ def test_get_size(self):
+ doc_a = self.make_document('a', 'b', '{"some": "content"}')
+ self.assertEqual(
+ len('a' + 'b' + '{"some": "content"}'), doc_a.get_size())
+
+ def test_get_size_empty_document(self):
+ doc_a = self.make_document('a', 'b', None)
+ self.assertEqual(len('a' + 'b'), doc_a.get_size())
+
+
+class TestPyDocument(tests.TestCase):
+
+ scenarios = ([(
+ 'py', {'make_document_for_test': tests.make_document_for_test})])
+
+ def test_get_content(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertEqual({"content": ""}, doc.content)
+ doc.set_json('{"content": "new"}')
+ self.assertEqual({"content": "new"}, doc.content)
+
+ def test_set_content(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ doc.content = {"content": "new"}
+ self.assertEqual('{"content": "new"}', doc.get_json())
+
+ def test_set_bad_content(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertRaises(
+ errors.InvalidContent, setattr, doc, 'content',
+ '{"content": "new"}')
+
+ def test_is_tombstone(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ self.assertFalse(doc_a.is_tombstone())
+ doc_a.set_json(None)
+ self.assertTrue(doc_a.is_tombstone())
+
+ def test_make_tombstone(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ self.assertFalse(doc_a.is_tombstone())
+ doc_a.make_tombstone()
+ self.assertTrue(doc_a.is_tombstone())
+
+ def test_same_content_as(self):
+ doc_a = self.make_document('a', 'b', '{}')
+ doc_b = self.make_document('d', 'e', '{}')
+ self.assertTrue(doc_a.same_content_as(doc_b))
+ doc_b = self.make_document('p', 'q', '{}', has_conflicts=True)
+ self.assertTrue(doc_a.same_content_as(doc_b))
+ doc_b.content['key'] = 'value'
+ self.assertFalse(doc_a.same_content_as(doc_b))
+
+ def test_same_content_as_json_order(self):
+ doc_a = self.make_document(
+ 'a', 'b', '{"key1": "val1", "key2": "val2"}')
+ doc_b = self.make_document(
+ 'c', 'd', '{"key2": "val2", "key1": "val1"}')
+ self.assertTrue(doc_a.same_content_as(doc_b))
+
+ def test_set_json(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ doc.set_json('{"content": "new"}')
+ self.assertEqual('{"content": "new"}', doc.get_json())
+
+ def test_set_json_non_dict(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertRaises(errors.InvalidJSON, doc.set_json, '"is not a dict"')
+
+ def test_set_json_error(self):
+ doc = self.make_document('id', 'rev', '{"content":""}')
+ self.assertRaises(errors.InvalidJSON, doc.set_json, 'is not json')
+
+
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/tests/u1db_tests/test_http_app.py b/src/leap/soledad/tests/u1db_tests/test_http_app.py
new file mode 100644
index 00000000..e0729aa2
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/test_http_app.py
@@ -0,0 +1,1135 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test the WSGI app."""
+
+import paste.fixture
+import sys
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import StringIO
+
+from u1db import (
+ __version__ as _u1db_version,
+ errors,
+ sync,
+)
+
+from leap.soledad.tests import u1db_tests as tests
+
+from u1db.remote import (
+ http_app,
+ http_errors,
+)
+
+
+class TestFencedReader(tests.TestCase):
+
+ def test_init(self):
+ reader = http_app._FencedReader(StringIO.StringIO(""), 25, 100)
+ self.assertEqual(25, reader.remaining)
+
+ def test_read_chunk(self):
+ inp = StringIO.StringIO("abcdef")
+ reader = http_app._FencedReader(inp, 5, 10)
+ data = reader.read_chunk(2)
+ self.assertEqual("ab", data)
+ self.assertEqual(2, inp.tell())
+ self.assertEqual(3, reader.remaining)
+
+ def test_read_chunk_remaining(self):
+ inp = StringIO.StringIO("abcdef")
+ reader = http_app._FencedReader(inp, 4, 10)
+ data = reader.read_chunk(9999)
+ self.assertEqual("abcd", data)
+ self.assertEqual(4, inp.tell())
+ self.assertEqual(0, reader.remaining)
+
+ def test_read_chunk_nothing_left(self):
+ inp = StringIO.StringIO("abc")
+ reader = http_app._FencedReader(inp, 2, 10)
+ reader.read_chunk(2)
+ self.assertEqual(2, inp.tell())
+ self.assertEqual(0, reader.remaining)
+ data = reader.read_chunk(2)
+ self.assertEqual("", data)
+ self.assertEqual(2, inp.tell())
+ self.assertEqual(0, reader.remaining)
+
+ def test_read_chunk_kept(self):
+ inp = StringIO.StringIO("abcde")
+ reader = http_app._FencedReader(inp, 4, 10)
+ reader._kept = "xyz"
+ data = reader.read_chunk(2) # atmost ignored
+ self.assertEqual("xyz", data)
+ self.assertEqual(0, inp.tell())
+ self.assertEqual(4, reader.remaining)
+ self.assertIsNone(reader._kept)
+
+ def test_getline(self):
+ inp = StringIO.StringIO("abc\r\nde")
+ reader = http_app._FencedReader(inp, 6, 10)
+ reader.MAXCHUNK = 6
+ line = reader.getline()
+ self.assertEqual("abc\r\n", line)
+ self.assertEqual("d", reader._kept)
+
+ def test_getline_exact(self):
+ inp = StringIO.StringIO("abcd\r\nef")
+ reader = http_app._FencedReader(inp, 6, 10)
+ reader.MAXCHUNK = 6
+ line = reader.getline()
+ self.assertEqual("abcd\r\n", line)
+ self.assertIs(None, reader._kept)
+
+ def test_getline_no_newline(self):
+ inp = StringIO.StringIO("abcd")
+ reader = http_app._FencedReader(inp, 4, 10)
+ reader.MAXCHUNK = 6
+ line = reader.getline()
+ self.assertEqual("abcd", line)
+
+ def test_getline_many_chunks(self):
+ inp = StringIO.StringIO("abcde\r\nf")
+ reader = http_app._FencedReader(inp, 8, 10)
+ reader.MAXCHUNK = 4
+ line = reader.getline()
+ self.assertEqual("abcde\r\n", line)
+ self.assertEqual("f", reader._kept)
+ line = reader.getline()
+ self.assertEqual("f", line)
+
+ def test_getline_empty(self):
+ inp = StringIO.StringIO("")
+ reader = http_app._FencedReader(inp, 0, 10)
+ reader.MAXCHUNK = 4
+ line = reader.getline()
+ self.assertEqual("", line)
+ line = reader.getline()
+ self.assertEqual("", line)
+
+ def test_getline_just_newline(self):
+ inp = StringIO.StringIO("\r\n")
+ reader = http_app._FencedReader(inp, 2, 10)
+ reader.MAXCHUNK = 4
+ line = reader.getline()
+ self.assertEqual("\r\n", line)
+ line = reader.getline()
+ self.assertEqual("", line)
+
+ def test_getline_too_large(self):
+ inp = StringIO.StringIO("x" * 50)
+ reader = http_app._FencedReader(inp, 50, 25)
+ reader.MAXCHUNK = 4
+ self.assertRaises(http_app.BadRequest, reader.getline)
+
+ def test_getline_too_large_complete(self):
+ inp = StringIO.StringIO("x" * 25 + "\r\n")
+ reader = http_app._FencedReader(inp, 50, 25)
+ reader.MAXCHUNK = 4
+ self.assertRaises(http_app.BadRequest, reader.getline)
+
+
+class TestHTTPMethodDecorator(tests.TestCase):
+
+ def test_args(self):
+ @http_app.http_method()
+ def f(self, a, b):
+ return self, a, b
+ res = f("self", {"a": "x", "b": "y"}, None)
+ self.assertEqual(("self", "x", "y"), res)
+
+ def test_args_missing(self):
+ @http_app.http_method()
+ def f(self, a, b):
+ return a, b
+ self.assertRaises(http_app.BadRequest, f, "self", {"a": "x"}, None)
+
+ def test_args_unexpected(self):
+ @http_app.http_method()
+ def f(self, a):
+ return a
+ self.assertRaises(http_app.BadRequest, f, "self",
+ {"a": "x", "c": "z"}, None)
+
+ def test_args_default(self):
+ @http_app.http_method()
+ def f(self, a, b="z"):
+ return a, b
+ res = f("self", {"a": "x"}, None)
+ self.assertEqual(("x", "z"), res)
+
+ def test_args_conversion(self):
+ @http_app.http_method(b=int)
+ def f(self, a, b):
+ return self, a, b
+ res = f("self", {"a": "x", "b": "2"}, None)
+ self.assertEqual(("self", "x", 2), res)
+
+ self.assertRaises(http_app.BadRequest, f, "self",
+ {"a": "x", "b": "foo"}, None)
+
+ def test_args_conversion_with_default(self):
+ @http_app.http_method(b=str)
+ def f(self, a, b=None):
+ return self, a, b
+ res = f("self", {"a": "x"}, None)
+ self.assertEqual(("self", "x", None), res)
+
+ def test_args_content(self):
+ @http_app.http_method()
+ def f(self, a, content):
+ return a, content
+ res = f(self, {"a": "x"}, "CONTENT")
+ self.assertEqual(("x", "CONTENT"), res)
+
+ def test_args_content_as_args(self):
+ @http_app.http_method(b=int, content_as_args=True)
+ def f(self, a, b):
+ return self, a, b
+ res = f("self", {"a": "x"}, '{"b": "2"}')
+ self.assertEqual(("self", "x", 2), res)
+
+ self.assertRaises(http_app.BadRequest, f, "self", {}, 'not-json')
+
+ def test_args_content_no_query(self):
+ @http_app.http_method(no_query=True,
+ content_as_args=True)
+ def f(self, a='a', b='b'):
+ return a, b
+ res = f("self", {}, '{"b": "y"}')
+ self.assertEqual(('a', 'y'), res)
+
+ self.assertRaises(http_app.BadRequest, f, "self", {'a': 'x'},
+ '{"b": "y"}')
+
+
+class TestResource(object):
+
+ @http_app.http_method()
+ def get(self, a, b):
+ self.args = dict(a=a, b=b)
+ return 'Get'
+
+ @http_app.http_method()
+ def put(self, a, content):
+ self.args = dict(a=a)
+ self.content = content
+ return 'Put'
+
+ @http_app.http_method(content_as_args=True)
+ def put_args(self, a, b):
+ self.args = dict(a=a, b=b)
+ self.order = ['a']
+ self.entries = []
+
+ @http_app.http_method()
+ def put_stream_entry(self, content):
+ self.entries.append(content)
+ self.order.append('s')
+
+ def put_end(self):
+ self.order.append('e')
+ return "Put/end"
+
+
+class parameters:
+ max_request_size = 200000
+ max_entry_size = 100000
+
+
+class TestHTTPInvocationByMethodWithBody(tests.TestCase):
+
+ def test_get(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'GET'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ res = invoke()
+ self.assertEqual('Get', res)
+ self.assertEqual({'a': '1', 'b': '2'}, resource.args)
+
+ def test_put_json(self):
+ resource = TestResource()
+ body = '{"body": true}'
+ environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ res = invoke()
+ self.assertEqual('Put', res)
+ self.assertEqual({'a': '1'}, resource.args)
+ self.assertEqual('{"body": true}', resource.content)
+
+ def test_put_sync_stream(self):
+ resource = TestResource()
+ body = (
+ '[\r\n'
+ '{"b": 2},\r\n' # args
+ '{"entry": "x"},\r\n' # stream entry
+ '{"entry": "y"}\r\n' # stream entry
+ ']'
+ )
+ environ = {'QUERY_STRING': 'a=1', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/x-u1db-sync-stream'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ res = invoke()
+ self.assertEqual('Put/end', res)
+ self.assertEqual({'a': '1', 'b': 2}, resource.args)
+ self.assertEqual(
+ ['{"entry": "x"}', '{"entry": "y"}'], resource.entries)
+ self.assertEqual(['a', 's', 's', 'e'], resource.order)
+
+ def _put_sync_stream(self, body):
+ resource = TestResource()
+ environ = {'QUERY_STRING': 'a=1&b=2', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/x-u1db-sync-stream'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ invoke()
+
+ def test_put_sync_stream_wrong_start(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "{}\r\n]")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "\r\n{}\r\n]")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "")
+
+ def test_put_sync_stream_wrong_end(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{}")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{}\r\n]\r\n...")
+
+ def test_put_sync_stream_missing_comma(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{}\r\n{}\r\n]")
+
+ def test_put_sync_stream_extra_comma(self):
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{},\r\n]")
+
+ self.assertRaises(http_app.BadRequest,
+ self._put_sync_stream, "[\r\n{},\r\n{},\r\n]")
+
+ def test_bad_request_decode_failure(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': 'a=\xff', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '2',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_content_type(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '2',
+ 'CONTENT_TYPE': 'text/plain'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_content_length_too_large(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '10000',
+ 'CONTENT_TYPE': 'text/plain'}
+
+ resource.max_request_size = 5000
+ resource.max_entry_size = sys.maxint # we don't get to use this
+
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_no_content_length(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('a'),
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_invalid_content_length(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('abc'),
+ 'CONTENT_LENGTH': '1unk',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_empty_body(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO(''),
+ 'CONTENT_LENGTH': '0',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_method_get_like(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'DELETE'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_method_put_like(self):
+ resource = TestResource()
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'PUT',
+ 'wsgi.input': StringIO.StringIO('{}'),
+ 'CONTENT_LENGTH': '2',
+ 'CONTENT_TYPE': 'application/json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+ def test_bad_request_unsupported_method_put_like_multi_json(self):
+ resource = TestResource()
+ body = '{}\r\n{}\r\n'
+ environ = {'QUERY_STRING': '', 'REQUEST_METHOD': 'POST',
+ 'wsgi.input': StringIO.StringIO(body),
+ 'CONTENT_LENGTH': str(len(body)),
+ 'CONTENT_TYPE': 'application/x-u1db-multi-json'}
+ invoke = http_app.HTTPInvocationByMethodWithBody(resource, environ,
+ parameters)
+ self.assertRaises(http_app.BadRequest, invoke)
+
+
+class TestHTTPResponder(tests.TestCase):
+
+ def start_response(self, status, headers):
+ self.status = status
+ self.headers = dict(headers)
+ self.response_body = []
+
+ def write(data):
+ self.response_body.append(data)
+
+ return write
+
+ def test_send_response_content_w_headers(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.send_response_content('foo', headers={'x-a': '1'})
+ self.assertEqual('200 OK', self.status)
+ self.assertEqual({'content-type': 'application/json',
+ 'cache-control': 'no-cache',
+ 'x-a': '1', 'content-length': '3'}, self.headers)
+ self.assertEqual([], self.response_body)
+ self.assertEqual(['foo'], responder.content)
+
+ def test_send_response_json(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.send_response_json(value='success')
+ self.assertEqual('200 OK', self.status)
+ expected_body = '{"value": "success"}\r\n'
+ self.assertEqual({'content-type': 'application/json',
+ 'content-length': str(len(expected_body)),
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual([], self.response_body)
+ self.assertEqual([expected_body], responder.content)
+
+ def test_send_response_json_status_fail(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.send_response_json(400)
+ self.assertEqual('400 Bad Request', self.status)
+ expected_body = '{}\r\n'
+ self.assertEqual({'content-type': 'application/json',
+ 'content-length': str(len(expected_body)),
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual([], self.response_body)
+ self.assertEqual([expected_body], responder.content)
+
+ def test_start_finish_response_status_fail(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.start_response(404, {'error': 'not found'})
+ responder.finish_response()
+ self.assertEqual('404 Not Found', self.status)
+ self.assertEqual({'content-type': 'application/json',
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual(['{"error": "not found"}\r\n'], self.response_body)
+ self.assertEqual([], responder.content)
+
+ def test_send_stream_entry(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.content_type = "application/x-u1db-multi-json"
+ responder.start_response(200)
+ responder.start_stream()
+ responder.stream_entry({'entry': 1})
+ responder.stream_entry({'entry': 2})
+ responder.end_stream()
+ responder.finish_response()
+ self.assertEqual('200 OK', self.status)
+ self.assertEqual({'content-type': 'application/x-u1db-multi-json',
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual(['[',
+ '\r\n', '{"entry": 1}',
+ ',\r\n', '{"entry": 2}',
+ '\r\n]\r\n'], self.response_body)
+ self.assertEqual([], responder.content)
+
+ def test_send_stream_w_error(self):
+ responder = http_app.HTTPResponder(self.start_response)
+ responder.content_type = "application/x-u1db-multi-json"
+ responder.start_response(200)
+ responder.start_stream()
+ responder.stream_entry({'entry': 1})
+ responder.send_response_json(503, error="unavailable")
+ self.assertEqual('200 OK', self.status)
+ self.assertEqual({'content-type': 'application/x-u1db-multi-json',
+ 'cache-control': 'no-cache'}, self.headers)
+ self.assertEqual(['[',
+ '\r\n', '{"entry": 1}'], self.response_body)
+ self.assertEqual([',\r\n', '{"error": "unavailable"}\r\n'],
+ responder.content)
+
+
+class TestHTTPApp(tests.TestCase):
+
+ def setUp(self):
+ super(TestHTTPApp, self).setUp()
+ self.state = tests.ServerStateForTests()
+ self.http_app = http_app.HTTPApp(self.state)
+ self.app = paste.fixture.TestApp(self.http_app)
+ self.db0 = self.state._create_database('db0')
+
+ def test_bad_request_broken(self):
+ resp = self.app.put('/db0/doc/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/foo'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_bad_request_dispatch(self):
+ resp = self.app.put('/db0/foo/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_version(self):
+ resp = self.app.get('/')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"version": _u1db_version}, json.loads(resp.body))
+
+ def test_create_database(self):
+ resp = self.app.put('/db1', params='{}',
+ headers={'content-type': 'application/json'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'ok': True}, json.loads(resp.body))
+
+ resp = self.app.put('/db1', params='{}',
+ headers={'content-type': 'application/json'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'ok': True}, json.loads(resp.body))
+
+ def test_delete_database(self):
+ resp = self.app.delete('/db0')
+ self.assertEqual(200, resp.status)
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.state.check_database, 'db0')
+
+ def test_get_database(self):
+ resp = self.app.get('/db0')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({}, json.loads(resp.body))
+
+ def test_valid_database_names(self):
+ resp = self.app.get('/a-database', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/db1', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/0', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/0-0', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ resp = self.app.get('/org.future', expect_errors=True)
+ self.assertEqual(404, resp.status)
+
+ def test_invalid_database_names(self):
+ resp = self.app.get('/.a', expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ resp = self.app.get('/-a', expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ resp = self.app.get('/_a', expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_put_doc_create(self):
+ resp = self.app.put('/db0/doc/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/json'})
+ doc = self.db0.get_doc('doc1')
+ self.assertEqual(201, resp.status) # created
+ self.assertEqual('{"x": 1}', doc.get_json())
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'rev': doc.rev}, json.loads(resp.body))
+
+ def test_put_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev,
+ params='{"x": 2}',
+ headers={'content-type': 'application/json'})
+ doc = self.db0.get_doc('doc1')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('{"x": 2}', doc.get_json())
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'rev': doc.rev}, json.loads(resp.body))
+
+ def test_put_doc_too_large(self):
+ self.http_app.max_request_size = 15000
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.put('/db0/doc/doc1?old_rev=%s' % doc.rev,
+ params='{"%s": 2}' % ('z' * 16000),
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_delete_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.delete('/db0/doc/doc1?old_rev=%s' % doc.rev)
+ doc = self.db0.get_doc('doc1', include_deleted=True)
+ self.assertEqual(None, doc.content)
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'rev': doc.rev}, json.loads(resp.body))
+
+ def test_get_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ resp = self.app.get('/db0/doc/%s' % doc.doc_id)
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual('{"x": 1}', resp.body)
+ self.assertEqual(doc.rev, resp.header('x-u1db-rev'))
+ self.assertEqual('false', resp.header('x-u1db-has-conflicts'))
+
+ def test_get_doc_non_existing(self):
+ resp = self.app.get('/db0/doc/not-there', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "document does not exist"}, json.loads(resp.body))
+ self.assertEqual('', resp.header('x-u1db-rev'))
+ self.assertEqual('false', resp.header('x-u1db-has-conflicts'))
+
+ def test_get_doc_deleted(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.db0.delete_doc(doc)
+ resp = self.app.get('/db0/doc/doc1', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": errors.DocumentDoesNotExist.wire_description},
+ json.loads(resp.body))
+
+ def test_get_doc_deleted_explicit_exclude(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.db0.delete_doc(doc)
+ resp = self.app.get(
+ '/db0/doc/doc1?include_deleted=false', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": errors.DocumentDoesNotExist.wire_description},
+ json.loads(resp.body))
+
+ def test_get_deleted_doc(self):
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.db0.delete_doc(doc)
+ resp = self.app.get(
+ '/db0/doc/doc1?include_deleted=true', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": errors.DOCUMENT_DELETED}, json.loads(resp.body))
+ self.assertEqual(doc.rev, resp.header('x-u1db-rev'))
+ self.assertEqual('false', resp.header('x-u1db-has-conflicts'))
+
+ def test_get_doc_non_existing_dabase(self):
+ resp = self.app.get('/not-there/doc/doc1', expect_errors=True)
+ self.assertEqual(404, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "database does not exist"}, json.loads(resp.body))
+
+ def test_get_docs(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1",
+ "has_conflicts": False},
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_docs_missing_doc_ids(self):
+ resp = self.app.get('/db0/docs', expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "missing document ids"}, json.loads(resp.body))
+
+ def test_get_docs_empty_doc_ids(self):
+ resp = self.app.get('/db0/docs?doc_ids=', expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(
+ {"error": "missing document ids"}, json.loads(resp.body))
+
+ def test_get_docs_percent(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc%1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc%1",
+ "has_conflicts": False},
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc2",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_docs_deleted(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ self.db0.delete_doc(doc2)
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_docs_include_deleted(self):
+ doc1 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ doc2 = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc2')
+ self.db0.delete_doc(doc2)
+ ids = ','.join([doc1.doc_id, doc2.doc_id])
+ resp = self.app.get('/db0/docs?doc_ids=%s&include_deleted=true' % ids)
+ self.assertEqual(200, resp.status)
+ self.assertEqual(
+ 'application/json', resp.header('content-type'))
+ expected = [
+ {"content": '{"x": 1}', "doc_rev": "db0:1", "doc_id": "doc1",
+ "has_conflicts": False},
+ {"content": None, "doc_rev": "db0:2", "doc_id": "doc2",
+ "has_conflicts": False}]
+ self.assertEqual(expected, json.loads(resp.body))
+
+ def test_get_sync_info(self):
+ self.db0._set_replica_gen_and_trans_id('other-id', 1, 'T-transid')
+ resp = self.app.get('/db0/sync-from/other-id')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual(dict(target_replica_uid='db0',
+ target_replica_generation=0,
+ target_replica_transaction_id='',
+ source_replica_uid='other-id',
+ source_replica_generation=1,
+ source_transaction_id='T-transid'),
+ json.loads(resp.body))
+
+ def test_record_sync_info(self):
+ resp = self.app.put('/db0/sync-from/other-id',
+ params='{"generation": 2, "transaction_id": '
+ '"T-transid"}',
+ headers={'content-type': 'application/json'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({'ok': True}, json.loads(resp.body))
+ self.assertEqual(
+ (2, 'T-transid'),
+ self.db0._get_replica_gen_and_trans_id('other-id'))
+
+ def test_sync_exchange_send(self):
+ entries = {
+ 10: {'id': 'doc-here', 'rev': 'replica:1', 'content':
+ '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'},
+ 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content':
+ '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'}
+ }
+
+ gens = []
+ _do_set_replica_gen_and_trans_id = \
+ self.db0._do_set_replica_gen_and_trans_id
+
+ def set_sync_generation_witness(other_uid, other_gen, other_trans_id):
+ gens.append((other_uid, other_gen))
+ _do_set_replica_gen_and_trans_id(
+ other_uid, other_gen, other_trans_id)
+ self.assertGetDoc(self.db0, entries[other_gen]['id'],
+ entries[other_gen]['rev'],
+ entries[other_gen]['content'], False)
+
+ self.patch(
+ self.db0, '_do_set_replica_gen_and_trans_id',
+ set_sync_generation_witness)
+
+ args = dict(last_known_generation=0)
+ body = ("[\r\n" +
+ "%s,\r\n" % json.dumps(args) +
+ "%s,\r\n" % json.dumps(entries[10]) +
+ "%s\r\n" % json.dumps(entries[11]) +
+ "]\r\n")
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ bits = resp.body.split('\r\n')
+ self.assertEqual('[', bits[0])
+ last_trans_id = self.db0._get_transaction_log()[-1][1]
+ self.assertEqual({'new_generation': 2,
+ 'new_transaction_id': last_trans_id},
+ json.loads(bits[1]))
+ self.assertEqual(']', bits[2])
+ self.assertEqual('', bits[3])
+ self.assertEqual([('replica', 10), ('replica', 11)], gens)
+
+ def test_sync_exchange_send_ensure(self):
+ entries = {
+ 10: {'id': 'doc-here', 'rev': 'replica:1', 'content':
+ '{"value": "here"}', 'gen': 10, 'trans_id': 'T-sid'},
+ 11: {'id': 'doc-here2', 'rev': 'replica:1', 'content':
+ '{"value": "here2"}', 'gen': 11, 'trans_id': 'T-sed'}
+ }
+
+ args = dict(last_known_generation=0, ensure=True)
+ body = ("[\r\n" +
+ "%s,\r\n" % json.dumps(args) +
+ "%s,\r\n" % json.dumps(entries[10]) +
+ "%s\r\n" % json.dumps(entries[11]) +
+ "]\r\n")
+ resp = self.app.post('/dbnew/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ bits = resp.body.split('\r\n')
+ self.assertEqual('[', bits[0])
+ dbnew = self.state.open_database("dbnew")
+ last_trans_id = dbnew._get_transaction_log()[-1][1]
+ self.assertEqual({'new_generation': 2,
+ 'new_transaction_id': last_trans_id,
+ 'replica_uid': dbnew._replica_uid},
+ json.loads(bits[1]))
+ self.assertEqual(']', bits[2])
+ self.assertEqual('', bits[3])
+
+ def test_sync_exchange_send_entry_too_large(self):
+ self.patch(http_app.SyncResource, 'max_request_size', 20000)
+ self.patch(http_app.SyncResource, 'max_entry_size', 10000)
+ entries = {
+ 10: {'id': 'doc-here', 'rev': 'replica:1', 'content':
+ '{"value": "%s"}' % ('H' * 11000), 'gen': 10},
+ }
+ args = dict(last_known_generation=0)
+ body = ("[\r\n" +
+ "%s,\r\n" % json.dumps(args) +
+ "%s\r\n" % json.dumps(entries[10]) +
+ "]\r\n")
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+
+ def test_sync_exchange_receive(self):
+ doc = self.db0.create_doc_from_json('{"value": "there"}')
+ doc2 = self.db0.create_doc_from_json('{"value": "there2"}')
+ args = dict(last_known_generation=0)
+ body = "[\r\n%s\r\n]" % json.dumps(args)
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ parts = resp.body.splitlines()
+ self.assertEqual(5, len(parts))
+ self.assertEqual('[', parts[0])
+ last_trans_id = self.db0._get_transaction_log()[-1][1]
+ self.assertEqual({'new_generation': 2,
+ 'new_transaction_id': last_trans_id},
+ json.loads(parts[1].rstrip(",")))
+ part2 = json.loads(parts[2].rstrip(","))
+ self.assertTrue(part2['trans_id'].startswith('T-'))
+ self.assertEqual('{"value": "there"}', part2['content'])
+ self.assertEqual(doc.rev, part2['rev'])
+ self.assertEqual(doc.doc_id, part2['id'])
+ self.assertEqual(1, part2['gen'])
+ part3 = json.loads(parts[3].rstrip(","))
+ self.assertTrue(part3['trans_id'].startswith('T-'))
+ self.assertEqual('{"value": "there2"}', part3['content'])
+ self.assertEqual(doc2.rev, part3['rev'])
+ self.assertEqual(doc2.doc_id, part3['id'])
+ self.assertEqual(2, part3['gen'])
+ self.assertEqual(']', parts[4])
+
+ def test_sync_exchange_error_in_stream(self):
+ args = dict(last_known_generation=0)
+ body = "[\r\n%s\r\n]" % json.dumps(args)
+
+ def boom(self, return_doc_cb):
+ raise errors.Unavailable
+
+ self.patch(sync.SyncExchange, 'return_docs',
+ boom)
+ resp = self.app.post('/db0/sync-from/replica',
+ params=body,
+ headers={'content-type':
+ 'application/x-u1db-sync-stream'})
+ self.assertEqual(200, resp.status)
+ self.assertEqual('application/x-u1db-sync-stream',
+ resp.header('content-type'))
+ parts = resp.body.splitlines()
+ self.assertEqual(3, len(parts))
+ self.assertEqual('[', parts[0])
+ self.assertEqual({'new_generation': 0, 'new_transaction_id': ''},
+ json.loads(parts[1].rstrip(",")))
+ self.assertEqual({'error': 'unavailable'}, json.loads(parts[2]))
+
+
+class TestRequestHooks(tests.TestCase):
+
+ def setUp(self):
+ super(TestRequestHooks, self).setUp()
+ self.state = tests.ServerStateForTests()
+ self.http_app = http_app.HTTPApp(self.state)
+ self.app = paste.fixture.TestApp(self.http_app)
+ self.db0 = self.state._create_database('db0')
+
+ def test_begin_and_done(self):
+ calls = []
+
+ def begin(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('begin')
+
+ def done(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('done')
+
+ self.http_app.request_begin = begin
+ self.http_app.request_done = done
+
+ doc = self.db0.create_doc_from_json('{"x": 1}', doc_id='doc1')
+ self.app.get('/db0/doc/%s' % doc.doc_id)
+
+ self.assertEqual(['begin', 'done'], calls)
+
+ def test_bad_request(self):
+ calls = []
+
+ def begin(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('begin')
+
+ def bad_request(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('bad-request')
+
+ self.http_app.request_begin = begin
+ self.http_app.request_bad_request = bad_request
+ # shouldn't be called
+ self.http_app.request_done = lambda env: 1 / 0
+
+ resp = self.app.put('/db0/foo/doc1', params='{"x": 1}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(400, resp.status)
+ self.assertEqual(['begin', 'bad-request'], calls)
+
+
+class TestHTTPErrors(tests.TestCase):
+
+ def test_wire_description_to_status(self):
+ self.assertNotIn("error", http_errors.wire_description_to_status)
+
+
+class TestHTTPAppErrorHandling(tests.TestCase):
+
+ def setUp(self):
+ super(TestHTTPAppErrorHandling, self).setUp()
+ self.exc = None
+ self.state = tests.ServerStateForTests()
+
+ class ErroringResource(object):
+
+ def post(_, args, content):
+ raise self.exc
+
+ def lookup_resource(environ, responder):
+ return ErroringResource()
+
+ self.http_app = http_app.HTTPApp(self.state)
+ self.http_app._lookup_resource = lookup_resource
+ self.app = paste.fixture.TestApp(self.http_app)
+
+ def test_RevisionConflict_etc(self):
+ self.exc = errors.RevisionConflict()
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(409, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"error": "revision conflict"},
+ json.loads(resp.body))
+
+ def test_Unavailable(self):
+ self.exc = errors.Unavailable
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(503, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"error": "unavailable"},
+ json.loads(resp.body))
+
+ def test_generic_u1db_errors(self):
+ self.exc = errors.U1DBError()
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(500, resp.status)
+ self.assertEqual('application/json', resp.header('content-type'))
+ self.assertEqual({"error": "error"},
+ json.loads(resp.body))
+
+ def test_generic_u1db_errors_hooks(self):
+ calls = []
+
+ def begin(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append('begin')
+
+ def u1db_error(environ, exc):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append(('error', exc))
+
+ self.http_app.request_begin = begin
+ self.http_app.request_u1db_error = u1db_error
+ # shouldn't be called
+ self.http_app.request_done = lambda env: 1 / 0
+
+ self.exc = errors.U1DBError()
+ resp = self.app.post('/req', params='{}',
+ headers={'content-type': 'application/json'},
+ expect_errors=True)
+ self.assertEqual(500, resp.status)
+ self.assertEqual(['begin', ('error', self.exc)], calls)
+
+ def test_failure(self):
+ class Failure(Exception):
+ pass
+ self.exc = Failure()
+ self.assertRaises(Failure, self.app.post, '/req', params='{}',
+ headers={'content-type': 'application/json'})
+
+ def test_failure_hooks(self):
+ class Failure(Exception):
+ pass
+ calls = []
+
+ def begin(environ):
+ calls.append('begin')
+
+ def failed(environ):
+ self.assertTrue('PATH_INFO' in environ)
+ calls.append(('failed', sys.exc_info()))
+
+ self.http_app.request_begin = begin
+ self.http_app.request_failed = failed
+ # shouldn't be called
+ self.http_app.request_done = lambda env: 1 / 0
+
+ self.exc = Failure()
+ self.assertRaises(Failure, self.app.post, '/req', params='{}',
+ headers={'content-type': 'application/json'})
+
+ self.assertEqual(2, len(calls))
+ self.assertEqual('begin', calls[0])
+ marker, (exc_type, exc, tb) = calls[1]
+ self.assertEqual('failed', marker)
+ self.assertEqual(self.exc, exc)
+
+
+class TestPluggableSyncExchange(tests.TestCase):
+
+ def setUp(self):
+ super(TestPluggableSyncExchange, self).setUp()
+ self.state = tests.ServerStateForTests()
+ self.state.ensure_database('foo')
+
+ def test_plugging(self):
+
+ class MySyncExchange(object):
+ def __init__(self, db, source_replica_uid, last_known_generation):
+ pass
+
+ class MySyncResource(http_app.SyncResource):
+ sync_exchange_class = MySyncExchange
+
+ sync_res = MySyncResource('foo', 'src', self.state, None)
+ sync_res.post_args(
+ {'last_known_generation': 0, 'last_known_trans_id': None}, '{}')
+ self.assertIsInstance(sync_res.sync_exch, MySyncExchange)
diff --git a/src/leap/soledad/tests/u1db_tests/test_http_client.py b/src/leap/soledad/tests/u1db_tests/test_http_client.py
new file mode 100644
index 00000000..42e98461
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/test_http_client.py
@@ -0,0 +1,363 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for HTTPDatabase"""
+
+from oauth import oauth
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ errors,
+)
+
+from leap.soledad.tests import u1db_tests as tests
+
+from u1db.remote import (
+ http_client,
+)
+
+
+class TestEncoder(tests.TestCase):
+
+ def test_encode_string(self):
+ self.assertEqual("foo", http_client._encode_query_parameter("foo"))
+
+ def test_encode_true(self):
+ self.assertEqual("true", http_client._encode_query_parameter(True))
+
+ def test_encode_false(self):
+ self.assertEqual("false", http_client._encode_query_parameter(False))
+
+
+class TestHTTPClientBase(tests.TestCaseWithServer):
+
+ def setUp(self):
+ super(TestHTTPClientBase, self).setUp()
+ self.errors = 0
+
+ def app(self, environ, start_response):
+ if environ['PATH_INFO'].endswith('echo'):
+ start_response("200 OK", [('Content-Type', 'application/json')])
+ ret = {}
+ for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'):
+ ret[name] = environ[name]
+ if environ['REQUEST_METHOD'] in ('PUT', 'POST'):
+ ret['CONTENT_TYPE'] = environ['CONTENT_TYPE']
+ content_length = int(environ['CONTENT_LENGTH'])
+ ret['body'] = environ['wsgi.input'].read(content_length)
+ return [json.dumps(ret)]
+ elif environ['PATH_INFO'].endswith('error_then_accept'):
+ if self.errors >= 3:
+ start_response(
+ "200 OK", [('Content-Type', 'application/json')])
+ ret = {}
+ for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'):
+ ret[name] = environ[name]
+ if environ['REQUEST_METHOD'] in ('PUT', 'POST'):
+ ret['CONTENT_TYPE'] = environ['CONTENT_TYPE']
+ content_length = int(environ['CONTENT_LENGTH'])
+ ret['body'] = '{"oki": "doki"}'
+ return [json.dumps(ret)]
+ self.errors += 1
+ content_length = int(environ['CONTENT_LENGTH'])
+ error = json.loads(
+ environ['wsgi.input'].read(content_length))
+ response = error['response']
+ # In debug mode, wsgiref has an assertion that the status parameter
+ # is a 'str' object. However error['status'] returns a unicode
+ # object.
+ status = str(error['status'])
+ if isinstance(response, unicode):
+ response = str(response)
+ if isinstance(response, str):
+ start_response(status, [('Content-Type', 'text/plain')])
+ return [str(response)]
+ else:
+ start_response(status, [('Content-Type', 'application/json')])
+ return [json.dumps(response)]
+ elif environ['PATH_INFO'].endswith('error'):
+ self.errors += 1
+ content_length = int(environ['CONTENT_LENGTH'])
+ error = json.loads(
+ environ['wsgi.input'].read(content_length))
+ response = error['response']
+ # In debug mode, wsgiref has an assertion that the status parameter
+ # is a 'str' object. However error['status'] returns a unicode
+ # object.
+ status = str(error['status'])
+ if isinstance(response, unicode):
+ response = str(response)
+ if isinstance(response, str):
+ start_response(status, [('Content-Type', 'text/plain')])
+ return [str(response)]
+ else:
+ start_response(status, [('Content-Type', 'application/json')])
+ return [json.dumps(response)]
+ elif '/oauth' in environ['PATH_INFO']:
+ base_url = self.getURL('').rstrip('/')
+ oauth_req = oauth.OAuthRequest.from_request(
+ http_method=environ['REQUEST_METHOD'],
+ http_url=base_url + environ['PATH_INFO'],
+ headers={'Authorization': environ['HTTP_AUTHORIZATION']},
+ query_string=environ['QUERY_STRING']
+ )
+ oauth_server = oauth.OAuthServer(tests.testingOAuthStore)
+ oauth_server.add_signature_method(tests.sign_meth_HMAC_SHA1)
+ try:
+ consumer, token, params = oauth_server.verify_request(
+ oauth_req)
+ except oauth.OAuthError, e:
+ start_response("401 Unauthorized",
+ [('Content-Type', 'application/json')])
+ return [json.dumps({"error": "unauthorized",
+ "message": e.message})]
+ start_response("200 OK", [('Content-Type', 'application/json')])
+ return [json.dumps([environ['PATH_INFO'], token.key, params])]
+
+ def make_app(self):
+ return self.app
+
+ def getClient(self, **kwds):
+ self.startServer()
+ return http_client.HTTPClientBase(self.getURL('dbase'), **kwds)
+
+ def test_construct(self):
+ self.startServer()
+ url = self.getURL()
+ cli = http_client.HTTPClientBase(url)
+ self.assertEqual(url, cli._url.geturl())
+ self.assertIs(None, cli._conn)
+
+ def test_parse_url(self):
+ cli = http_client.HTTPClientBase(
+ '%s://127.0.0.1:12345/' % self.url_scheme)
+ self.assertEqual(self.url_scheme, cli._url.scheme)
+ self.assertEqual('127.0.0.1', cli._url.hostname)
+ self.assertEqual(12345, cli._url.port)
+ self.assertEqual('/', cli._url.path)
+
+ def test__ensure_connection(self):
+ cli = self.getClient()
+ self.assertIs(None, cli._conn)
+ cli._ensure_connection()
+ self.assertIsNot(None, cli._conn)
+ conn = cli._conn
+ cli._ensure_connection()
+ self.assertIs(conn, cli._conn)
+
+ def test_close(self):
+ cli = self.getClient()
+ cli._ensure_connection()
+ cli.close()
+ self.assertIs(None, cli._conn)
+
+ def test__request(self):
+ cli = self.getClient()
+ res, headers = cli._request('PUT', ['echo'], {}, {})
+ self.assertEqual({'CONTENT_TYPE': 'application/json',
+ 'PATH_INFO': '/dbase/echo',
+ 'QUERY_STRING': '',
+ 'body': '{}',
+ 'REQUEST_METHOD': 'PUT'}, json.loads(res))
+
+ res, headers = cli._request('GET', ['doc', 'echo'], {'a': 1})
+ self.assertEqual({'PATH_INFO': '/dbase/doc/echo',
+ 'QUERY_STRING': 'a=1',
+ 'REQUEST_METHOD': 'GET'}, json.loads(res))
+
+ res, headers = cli._request('GET', ['doc', '%FFFF', 'echo'], {'a': 1})
+ self.assertEqual({'PATH_INFO': '/dbase/doc/%FFFF/echo',
+ 'QUERY_STRING': 'a=1',
+ 'REQUEST_METHOD': 'GET'}, json.loads(res))
+
+ res, headers = cli._request('POST', ['echo'], {'b': 2}, 'Body',
+ 'application/x-test')
+ self.assertEqual({'CONTENT_TYPE': 'application/x-test',
+ 'PATH_INFO': '/dbase/echo',
+ 'QUERY_STRING': 'b=2',
+ 'body': 'Body',
+ 'REQUEST_METHOD': 'POST'}, json.loads(res))
+
+ def test__request_json(self):
+ cli = self.getClient()
+ res, headers = cli._request_json(
+ 'POST', ['echo'], {'b': 2}, {'a': 'x'})
+ self.assertEqual('application/json', headers['content-type'])
+ self.assertEqual({'CONTENT_TYPE': 'application/json',
+ 'PATH_INFO': '/dbase/echo',
+ 'QUERY_STRING': 'b=2',
+ 'body': '{"a": "x"}',
+ 'REQUEST_METHOD': 'POST'}, res)
+
+ def test_unspecified_http_error(self):
+ cli = self.getClient()
+ self.assertRaises(errors.HTTPError,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "500 Internal Error",
+ 'response': "Crash."})
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "500 Internal Error",
+ 'response': "Fail."})
+ except errors.HTTPError, e:
+ pass
+
+ self.assertEqual(500, e.status)
+ self.assertEqual("Fail.", e.message)
+ self.assertTrue("content-type" in e.headers)
+
+ def test_revision_conflict(self):
+ cli = self.getClient()
+ self.assertRaises(errors.RevisionConflict,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "409 Conflict",
+ 'response': {"error": "revision conflict"}})
+
+ def test_unavailable_proper(self):
+ cli = self.getClient()
+ cli._delays = (0, 0, 0, 0, 0)
+ self.assertRaises(errors.Unavailable,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "503 Service Unavailable",
+ 'response': {"error": "unavailable"}})
+ self.assertEqual(5, self.errors)
+
+ def test_unavailable_then_available(self):
+ cli = self.getClient()
+ cli._delays = (0, 0, 0, 0, 0)
+ res, headers = cli._request_json(
+ 'POST', ['error_then_accept'], {'b': 2},
+ {'status': "503 Service Unavailable",
+ 'response': {"error": "unavailable"}})
+ self.assertEqual('application/json', headers['content-type'])
+ self.assertEqual({'CONTENT_TYPE': 'application/json',
+ 'PATH_INFO': '/dbase/error_then_accept',
+ 'QUERY_STRING': 'b=2',
+ 'body': '{"oki": "doki"}',
+ 'REQUEST_METHOD': 'POST'}, res)
+ self.assertEqual(3, self.errors)
+
+ def test_unavailable_random_source(self):
+ cli = self.getClient()
+ cli._delays = (0, 0, 0, 0, 0)
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "503 Service Unavailable",
+ 'response': "random unavailable."})
+ except errors.Unavailable, e:
+ pass
+
+ self.assertEqual(503, e.status)
+ self.assertEqual("random unavailable.", e.message)
+ self.assertTrue("content-type" in e.headers)
+ self.assertEqual(5, self.errors)
+
+ def test_document_too_big(self):
+ cli = self.getClient()
+ self.assertRaises(errors.DocumentTooBig,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "403 Forbidden",
+ 'response': {"error": "document too big"}})
+
+ def test_user_quota_exceeded(self):
+ cli = self.getClient()
+ self.assertRaises(errors.UserQuotaExceeded,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "403 Forbidden",
+ 'response': {"error": "user quota exceeded"}})
+
+ def test_user_needs_subscription(self):
+ cli = self.getClient()
+ self.assertRaises(errors.SubscriptionNeeded,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "403 Forbidden",
+ 'response': {"error": "user needs subscription"}})
+
+ def test_generic_u1db_error(self):
+ cli = self.getClient()
+ self.assertRaises(errors.U1DBError,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': {"error": "error"}})
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': {"error": "error"}})
+ except errors.U1DBError, e:
+ pass
+ self.assertIs(e.__class__, errors.U1DBError)
+
+ def test_unspecified_bad_request(self):
+ cli = self.getClient()
+ self.assertRaises(errors.HTTPError,
+ cli._request_json, 'POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': "<Bad Request>"})
+ try:
+ cli._request_json('POST', ['error'], {},
+ {'status': "400 Bad Request",
+ 'response': "<Bad Request>"})
+ except errors.HTTPError, e:
+ pass
+
+ self.assertEqual(400, e.status)
+ self.assertEqual("<Bad Request>", e.message)
+ self.assertTrue("content-type" in e.headers)
+
+ def test_oauth(self):
+ cli = self.getClient()
+ cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ params = {'x': u'\xf0', 'y': "foo"}
+ res, headers = cli._request('GET', ['doc', 'oauth'], params)
+ self.assertEqual(
+ ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res))
+
+ # oauth does its own internal quoting
+ params = {'x': u'\xf0', 'y': "foo"}
+ res, headers = cli._request('GET', ['doc', 'oauth', 'foo bar'], params)
+ self.assertEqual(
+ ['/dbase/doc/oauth/foo bar', tests.token1.key, params],
+ json.loads(res))
+
+ def test_oauth_ctr_creds(self):
+ cli = self.getClient(creds={'oauth': {
+ 'consumer_key': tests.consumer1.key,
+ 'consumer_secret': tests.consumer1.secret,
+ 'token_key': tests.token1.key,
+ 'token_secret': tests.token1.secret,
+ }})
+ params = {'x': u'\xf0', 'y': "foo"}
+ res, headers = cli._request('GET', ['doc', 'oauth'], params)
+ self.assertEqual(
+ ['/dbase/doc/oauth', tests.token1.key, params], json.loads(res))
+
+ def test_unknown_creds(self):
+ self.assertRaises(errors.UnknownAuthMethod,
+ self.getClient, creds={'foo': {}})
+ self.assertRaises(errors.UnknownAuthMethod,
+ self.getClient, creds={})
+
+ def test_oauth_Unauthorized(self):
+ cli = self.getClient()
+ cli.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, "WRONG")
+ params = {'y': 'foo'}
+ self.assertRaises(errors.Unauthorized, cli._request, 'GET',
+ ['doc', 'oauth'], params)
diff --git a/src/leap/soledad/tests/u1db_tests/test_http_database.py b/src/leap/soledad/tests/u1db_tests/test_http_database.py
new file mode 100644
index 00000000..f21e6da1
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/test_http_database.py
@@ -0,0 +1,260 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for HTTPDatabase"""
+
+import inspect
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+
+from u1db import (
+ errors,
+ Document,
+)
+
+from leap.soledad.tests import u1db_tests as tests
+
+from u1db.remote import (
+ http_database,
+ http_target,
+)
+from leap.soledad.tests.u1db_tests.test_remote_sync_target import (
+ make_http_app,
+)
+
+
+class TestHTTPDatabaseSimpleOperations(tests.TestCase):
+
+ def setUp(self):
+ super(TestHTTPDatabaseSimpleOperations, self).setUp()
+ self.db = http_database.HTTPDatabase('dbase')
+ self.db._conn = object() # crash if used
+ self.got = None
+ self.response_val = None
+
+ def _request(method, url_parts, params=None, body=None,
+ content_type=None):
+ self.got = method, url_parts, params, body, content_type
+ if isinstance(self.response_val, Exception):
+ raise self.response_val
+ return self.response_val
+
+ def _request_json(method, url_parts, params=None, body=None,
+ content_type=None):
+ self.got = method, url_parts, params, body, content_type
+ if isinstance(self.response_val, Exception):
+ raise self.response_val
+ return self.response_val
+
+ self.db._request = _request
+ self.db._request_json = _request_json
+
+ def test__sanity_same_signature(self):
+ my_request_sig = inspect.getargspec(self.db._request)
+ my_request_sig = (['self'] + my_request_sig[0],) + my_request_sig[1:]
+ self.assertEqual(
+ my_request_sig,
+ inspect.getargspec(http_database.HTTPDatabase._request))
+ my_request_json_sig = inspect.getargspec(self.db._request_json)
+ my_request_json_sig = ((['self'] + my_request_json_sig[0],) +
+ my_request_json_sig[1:])
+ self.assertEqual(
+ my_request_json_sig,
+ inspect.getargspec(http_database.HTTPDatabase._request_json))
+
+ def test__ensure(self):
+ self.response_val = {'ok': True}, {}
+ self.db._ensure()
+ self.assertEqual(('PUT', [], {}, {}, None), self.got)
+
+ def test__delete(self):
+ self.response_val = {'ok': True}, {}
+ self.db._delete()
+ self.assertEqual(('DELETE', [], {}, {}, None), self.got)
+
+ def test__check(self):
+ self.response_val = {}, {}
+ res = self.db._check()
+ self.assertEqual({}, res)
+ self.assertEqual(('GET', [], None, None, None), self.got)
+
+ def test_put_doc(self):
+ self.response_val = {'rev': 'doc-rev'}, {}
+ doc = Document('doc-id', None, '{"v": 1}')
+ res = self.db.put_doc(doc)
+ self.assertEqual('doc-rev', res)
+ self.assertEqual('doc-rev', doc.rev)
+ self.assertEqual(('PUT', ['doc', 'doc-id'], {},
+ '{"v": 1}', 'application/json'), self.got)
+
+ self.response_val = {'rev': 'doc-rev-2'}, {}
+ doc.content = {"v": 2}
+ res = self.db.put_doc(doc)
+ self.assertEqual('doc-rev-2', res)
+ self.assertEqual('doc-rev-2', doc.rev)
+ self.assertEqual(('PUT', ['doc', 'doc-id'], {'old_rev': 'doc-rev'},
+ '{"v": 2}', 'application/json'), self.got)
+
+ def test_get_doc(self):
+ self.response_val = '{"v": 2}', {'x-u1db-rev': 'doc-rev',
+ 'x-u1db-has-conflicts': 'false'}
+ self.assertGetDoc(self.db, 'doc-id', 'doc-rev', '{"v": 2}', False)
+ self.assertEqual(
+ ('GET', ['doc', 'doc-id'], {'include_deleted': False}, None, None),
+ self.got)
+
+ def test_get_doc_non_existing(self):
+ self.response_val = errors.DocumentDoesNotExist()
+ self.assertIs(None, self.db.get_doc('not-there'))
+ self.assertEqual(
+ ('GET', ['doc', 'not-there'], {'include_deleted': False}, None,
+ None), self.got)
+
+ def test_get_doc_deleted(self):
+ self.response_val = errors.DocumentDoesNotExist()
+ self.assertIs(None, self.db.get_doc('deleted'))
+ self.assertEqual(
+ ('GET', ['doc', 'deleted'], {'include_deleted': False}, None,
+ None), self.got)
+
+ def test_get_doc_deleted_include_deleted(self):
+ self.response_val = errors.HTTPError(404,
+ json.dumps(
+ {"error": errors.DOCUMENT_DELETED}
+ ),
+ {'x-u1db-rev': 'doc-rev-gone',
+ 'x-u1db-has-conflicts': 'false'})
+ doc = self.db.get_doc('deleted', include_deleted=True)
+ self.assertEqual('deleted', doc.doc_id)
+ self.assertEqual('doc-rev-gone', doc.rev)
+ self.assertIs(None, doc.content)
+ self.assertEqual(
+ ('GET', ['doc', 'deleted'], {'include_deleted': True}, None, None),
+ self.got)
+
+ def test_get_doc_pass_through_errors(self):
+ self.response_val = errors.HTTPError(500, 'Crash.')
+ self.assertRaises(errors.HTTPError,
+ self.db.get_doc, 'something-something')
+
+ def test_create_doc_with_id(self):
+ self.response_val = {'rev': 'doc-rev'}, {}
+ new_doc = self.db.create_doc_from_json('{"v": 1}', doc_id='doc-id')
+ self.assertEqual('doc-rev', new_doc.rev)
+ self.assertEqual('doc-id', new_doc.doc_id)
+ self.assertEqual('{"v": 1}', new_doc.get_json())
+ self.assertEqual(('PUT', ['doc', 'doc-id'], {},
+ '{"v": 1}', 'application/json'), self.got)
+
+ def test_create_doc_without_id(self):
+ self.response_val = {'rev': 'doc-rev-2'}, {}
+ new_doc = self.db.create_doc_from_json('{"v": 3}')
+ self.assertEqual('D-', new_doc.doc_id[:2])
+ self.assertEqual('doc-rev-2', new_doc.rev)
+ self.assertEqual('{"v": 3}', new_doc.get_json())
+ self.assertEqual(('PUT', ['doc', new_doc.doc_id], {},
+ '{"v": 3}', 'application/json'), self.got)
+
+ def test_delete_doc(self):
+ self.response_val = {'rev': 'doc-rev-gone'}, {}
+ doc = Document('doc-id', 'doc-rev', None)
+ self.db.delete_doc(doc)
+ self.assertEqual('doc-rev-gone', doc.rev)
+ self.assertEqual(('DELETE', ['doc', 'doc-id'], {'old_rev': 'doc-rev'},
+ None, None), self.got)
+
+ def test_get_sync_target(self):
+ st = self.db.get_sync_target()
+ self.assertIsInstance(st, http_target.HTTPSyncTarget)
+ self.assertEqual(st._url, self.db._url)
+
+ def test_get_sync_target_inherits_oauth_credentials(self):
+ self.db.set_oauth_credentials(tests.consumer1.key,
+ tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ st = self.db.get_sync_target()
+ self.assertEqual(self.db._creds, st._creds)
+
+
+class TestHTTPDatabaseCtrWithCreds(tests.TestCase):
+
+ def test_ctr_with_creds(self):
+ db1 = http_database.HTTPDatabase('http://dbs/db', creds={'oauth': {
+ 'consumer_key': tests.consumer1.key,
+ 'consumer_secret': tests.consumer1.secret,
+ 'token_key': tests.token1.key,
+ 'token_secret': tests.token1.secret
+ }})
+ self.assertIn('oauth', db1._creds)
+
+
+class TestHTTPDatabaseIntegration(tests.TestCaseWithServer):
+
+ make_app_with_state = staticmethod(make_http_app)
+
+ def setUp(self):
+ super(TestHTTPDatabaseIntegration, self).setUp()
+ self.startServer()
+
+ def test_non_existing_db(self):
+ db = http_database.HTTPDatabase(self.getURL('not-there'))
+ self.assertRaises(errors.DatabaseDoesNotExist, db.get_doc, 'doc1')
+
+ def test__ensure(self):
+ db = http_database.HTTPDatabase(self.getURL('new'))
+ db._ensure()
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test__delete(self):
+ self.request_state._create_database('db0')
+ db = http_database.HTTPDatabase(self.getURL('db0'))
+ db._delete()
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.request_state.check_database, 'db0')
+
+ def test_open_database_existing(self):
+ self.request_state._create_database('db0')
+ db = http_database.HTTPDatabase.open_database(self.getURL('db0'),
+ create=False)
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test_open_database_non_existing(self):
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ http_database.HTTPDatabase.open_database,
+ self.getURL('not-there'),
+ create=False)
+
+ def test_open_database_create(self):
+ db = http_database.HTTPDatabase.open_database(self.getURL('new'),
+ create=True)
+ self.assertIs(None, db.get_doc('doc1'))
+
+ def test_delete_database_existing(self):
+ self.request_state._create_database('db0')
+ http_database.HTTPDatabase.delete_database(self.getURL('db0'))
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ self.request_state.check_database, 'db0')
+
+ def test_doc_ids_needing_quoting(self):
+ db0 = self.request_state._create_database('db0')
+ db = http_database.HTTPDatabase.open_database(self.getURL('db0'),
+ create=False)
+ doc = Document('%fff', None, '{}')
+ db.put_doc(doc)
+ self.assertGetDoc(db0, '%fff', doc.rev, '{}', False)
+ self.assertGetDoc(db, '%fff', doc.rev, '{}', False)
diff --git a/src/leap/soledad/tests/u1db_tests/test_https.py b/src/leap/soledad/tests/u1db_tests/test_https.py
new file mode 100644
index 00000000..3f8797d8
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/test_https.py
@@ -0,0 +1,117 @@
+"""Test support for client-side https support."""
+
+import os
+import ssl
+import sys
+
+from paste import httpserver
+
+from leap.soledad.tests import u1db_tests as tests
+
+from u1db.remote import (
+ http_client,
+ http_target,
+)
+
+from leap.soledad.tests.u1db_tests.test_remote_sync_target import (
+ make_oauth_http_app,
+)
+
+
+def https_server_def():
+ def make_server(host_port, application):
+ from OpenSSL import SSL
+ cert_file = os.path.join(os.path.dirname(__file__), 'testing-certs',
+ 'testing.cert')
+ key_file = os.path.join(os.path.dirname(__file__), 'testing-certs',
+ 'testing.key')
+ ssl_context = SSL.Context(SSL.SSLv23_METHOD)
+ ssl_context.use_privatekey_file(key_file)
+ ssl_context.use_certificate_chain_file(cert_file)
+ srv = httpserver.WSGIServerBase(application, host_port,
+ httpserver.WSGIHandler,
+ ssl_context=ssl_context
+ )
+
+ def shutdown_request(req):
+ req.shutdown()
+ srv.close_request(req)
+
+ srv.shutdown_request = shutdown_request
+ application.base_url = "https://localhost:%s" % srv.server_address[1]
+ return srv
+ return make_server, "shutdown", "https"
+
+
+def oauth_https_sync_target(test, host, path):
+ _, port = test.server.server_address
+ st = http_target.HTTPSyncTarget('https://%s:%d/~/%s' % (host, port, path))
+ st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return st
+
+
+class TestHttpSyncTargetHttpsSupport(tests.TestCaseWithServer):
+
+ scenarios = [
+ ('oauth_https', {'server_def': https_server_def,
+ 'make_app_with_state': make_oauth_http_app,
+ 'make_document_for_test':
+ tests.make_document_for_test,
+ 'sync_target': oauth_https_sync_target
+ }),
+ ]
+
+ def setUp(self):
+ try:
+ import OpenSSL # noqa
+ except ImportError:
+ self.skipTest("Requires pyOpenSSL")
+ self.cacert_pem = os.path.join(os.path.dirname(__file__),
+ 'testing-certs', 'cacert.pem')
+ super(TestHttpSyncTargetHttpsSupport, self).setUp()
+
+ def getSyncTarget(self, host, path=None):
+ if self.server is None:
+ self.startServer()
+ return self.sync_target(self, host, path)
+
+ def test_working(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ self.patch(http_client, 'CA_CERTS', self.cacert_pem)
+ remote_target = self.getSyncTarget('localhost', 'test')
+ remote_target.record_sync_info('other-id', 2, 'T-id')
+ self.assertEqual(
+ (2, 'T-id'), db._get_replica_gen_and_trans_id('other-id'))
+
+ def test_cannot_verify_cert(self):
+ if not sys.platform.startswith('linux'):
+ self.skipTest(
+ "XXX certificate verification happens on linux only for now")
+ self.startServer()
+ # don't print expected traceback server-side
+ self.server.handle_error = lambda req, cli_addr: None
+ self.request_state._create_database('test')
+ remote_target = self.getSyncTarget('localhost', 'test')
+ try:
+ remote_target.record_sync_info('other-id', 2, 'T-id')
+ except ssl.SSLError, e:
+ self.assertIn("certificate verify failed", str(e))
+ else:
+ self.fail("certificate verification should have failed.")
+
+ def test_host_mismatch(self):
+ if not sys.platform.startswith('linux'):
+ self.skipTest(
+ "XXX certificate verification happens on linux only for now")
+ self.startServer()
+ self.request_state._create_database('test')
+ self.patch(http_client, 'CA_CERTS', self.cacert_pem)
+ remote_target = self.getSyncTarget('127.0.0.1', 'test')
+ self.assertRaises(
+ http_client.CertificateError, remote_target.record_sync_info,
+ 'other-id', 2, 'T-id')
+
+
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/tests/u1db_tests/test_open.py b/src/leap/soledad/tests/u1db_tests/test_open.py
new file mode 100644
index 00000000..0ff307e8
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/test_open.py
@@ -0,0 +1,69 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test u1db.open"""
+
+import os
+
+from u1db import (
+ errors,
+ open as u1db_open,
+)
+from leap.soledad.tests import u1db_tests as tests
+from u1db.backends import sqlite_backend
+from leap.soledad.tests.u1db_tests.test_backends import TestAlternativeDocument
+
+
+class TestU1DBOpen(tests.TestCase):
+
+ def setUp(self):
+ super(TestU1DBOpen, self).setUp()
+ tmpdir = self.createTempDir()
+ self.db_path = tmpdir + '/test.db'
+
+ def test_open_no_create(self):
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ u1db_open, self.db_path, create=False)
+ self.assertFalse(os.path.exists(self.db_path))
+
+ def test_open_create(self):
+ db = u1db_open(self.db_path, create=True)
+ self.addCleanup(db.close)
+ self.assertTrue(os.path.exists(self.db_path))
+ self.assertIsInstance(db, sqlite_backend.SQLiteDatabase)
+
+ def test_open_with_factory(self):
+ db = u1db_open(self.db_path, create=True,
+ document_factory=TestAlternativeDocument)
+ self.addCleanup(db.close)
+ self.assertEqual(TestAlternativeDocument, db._factory)
+
+ def test_open_existing(self):
+ db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path)
+ self.addCleanup(db.close)
+ doc = db.create_doc_from_json(tests.simple_doc)
+ # Even though create=True, we shouldn't wipe the db
+ db2 = u1db_open(self.db_path, create=True)
+ self.addCleanup(db2.close)
+ doc2 = db2.get_doc(doc.doc_id)
+ self.assertEqual(doc, doc2)
+
+ def test_open_existing_no_create(self):
+ db = sqlite_backend.SQLitePartialExpandDatabase(self.db_path)
+ self.addCleanup(db.close)
+ db2 = u1db_open(self.db_path, create=False)
+ self.addCleanup(db2.close)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
diff --git a/src/leap/soledad/tests/u1db_tests/test_remote_sync_target.py b/src/leap/soledad/tests/u1db_tests/test_remote_sync_target.py
new file mode 100644
index 00000000..66d404d2
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/test_remote_sync_target.py
@@ -0,0 +1,317 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for the remote sync targets"""
+
+import cStringIO
+
+from u1db import (
+ errors,
+)
+
+from leap.soledad.tests import u1db_tests as tests
+
+from u1db.remote import (
+ http_app,
+ http_target,
+ oauth_middleware,
+)
+
+
+class TestHTTPSyncTargetBasics(tests.TestCase):
+
+ def test_parse_url(self):
+ remote_target = http_target.HTTPSyncTarget('http://127.0.0.1:12345/')
+ self.assertEqual('http', remote_target._url.scheme)
+ self.assertEqual('127.0.0.1', remote_target._url.hostname)
+ self.assertEqual(12345, remote_target._url.port)
+ self.assertEqual('/', remote_target._url.path)
+
+
+class TestParsingSyncStream(tests.TestCase):
+
+ def test_wrong_start(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "{}\r\n]", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "\r\n{}\r\n]", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "", None)
+
+ def test_wrong_end(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n{}", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n", None)
+
+ def test_missing_comma(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream,
+ '[\r\n{}\r\n{"id": "i", "rev": "r", '
+ '"content": "c", "gen": 3}\r\n]', None)
+
+ def test_no_entries(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n]", None)
+
+ def test_extra_comma(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream, "[\r\n{},\r\n]", None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream,
+ '[\r\n{},\r\n{"id": "i", "rev": "r", '
+ '"content": "{}", "gen": 3, "trans_id": "T-sid"}'
+ ',\r\n]',
+ lambda doc, gen, trans_id: None)
+
+ def test_error_in_stream(self):
+ tgt = http_target.HTTPSyncTarget("http://foo/foo")
+
+ self.assertRaises(errors.Unavailable,
+ tgt._parse_sync_stream,
+ '[\r\n{"new_generation": 0},'
+ '\r\n{"error": "unavailable"}\r\n', None)
+
+ self.assertRaises(errors.Unavailable,
+ tgt._parse_sync_stream,
+ '[\r\n{"error": "unavailable"}\r\n', None)
+
+ self.assertRaises(errors.BrokenSyncStream,
+ tgt._parse_sync_stream,
+ '[\r\n{"error": "?"}\r\n', None)
+
+
+def make_http_app(state):
+ return http_app.HTTPApp(state)
+
+
+def http_sync_target(test, path):
+ return http_target.HTTPSyncTarget(test.getURL(path))
+
+
+def make_oauth_http_app(state):
+ app = http_app.HTTPApp(state)
+ application = oauth_middleware.OAuthMiddleware(app, None, prefix='/~/')
+ application.get_oauth_data_store = lambda: tests.testingOAuthStore
+ return application
+
+
+def oauth_http_sync_target(test, path):
+ st = http_sync_target(test, '~/' + path)
+ st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return st
+
+
+class TestRemoteSyncTargets(tests.TestCaseWithServer):
+
+ scenarios = [
+ ('http', {'make_app_with_state': make_http_app,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'sync_target': http_sync_target}),
+ ('oauth_http', {'make_app_with_state': make_oauth_http_app,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'sync_target': oauth_http_sync_target}),
+ ]
+
+ def getSyncTarget(self, path=None):
+ if self.server is None:
+ self.startServer()
+ return self.sync_target(self, path)
+
+ def test_get_sync_info(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ db._set_replica_gen_and_trans_id('other-id', 1, 'T-transid')
+ remote_target = self.getSyncTarget('test')
+ self.assertEqual(('test', 0, '', 1, 'T-transid'),
+ remote_target.get_sync_info('other-id'))
+
+ def test_record_sync_info(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ remote_target = self.getSyncTarget('test')
+ remote_target.record_sync_info('other-id', 2, 'T-transid')
+ self.assertEqual(
+ (2, 'T-transid'), db._get_replica_gen_and_trans_id('other-id'))
+
+ def test_sync_exchange_send(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ remote_target = self.getSyncTarget('test')
+ other_docs = []
+
+ def receive_doc(doc):
+ other_docs.append((doc.doc_id, doc.rev, doc.get_json()))
+
+ doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}')
+ new_gen, trans_id = remote_target.sync_exchange(
+ [(doc, 10, 'T-sid')], 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=receive_doc)
+ self.assertEqual(1, new_gen)
+ self.assertGetDoc(
+ db, 'doc-here', 'replica:1', '{"value": "here"}', False)
+
+ def test_sync_exchange_send_failure_and_retry_scenario(self):
+ self.startServer()
+
+ def blackhole_getstderr(inst):
+ return cStringIO.StringIO()
+
+ self.patch(self.server.RequestHandlerClass, 'get_stderr',
+ blackhole_getstderr)
+ db = self.request_state._create_database('test')
+ _put_doc_if_newer = db._put_doc_if_newer
+ trigger_ids = ['doc-here2']
+
+ def bomb_put_doc_if_newer(doc, save_conflict,
+ replica_uid=None, replica_gen=None,
+ replica_trans_id=None):
+ if doc.doc_id in trigger_ids:
+ raise Exception
+ return _put_doc_if_newer(doc, save_conflict=save_conflict,
+ replica_uid=replica_uid,
+ replica_gen=replica_gen,
+ replica_trans_id=replica_trans_id)
+ self.patch(db, '_put_doc_if_newer', bomb_put_doc_if_newer)
+ remote_target = self.getSyncTarget('test')
+ other_changes = []
+
+ def receive_doc(doc, gen, trans_id):
+ other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ doc1 = self.make_document('doc-here', 'replica:1', '{"value": "here"}')
+ doc2 = self.make_document('doc-here2', 'replica:1',
+ '{"value": "here2"}')
+ self.assertRaises(
+ errors.HTTPError,
+ remote_target.sync_exchange,
+ [(doc1, 10, 'T-sid'), (doc2, 11, 'T-sud')],
+ 'replica', last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=receive_doc)
+ self.assertGetDoc(db, 'doc-here', 'replica:1', '{"value": "here"}',
+ False)
+ self.assertEqual(
+ (10, 'T-sid'), db._get_replica_gen_and_trans_id('replica'))
+ self.assertEqual([], other_changes)
+ # retry
+ trigger_ids = []
+ new_gen, trans_id = remote_target.sync_exchange(
+ [(doc2, 11, 'T-sud')], 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=receive_doc)
+ self.assertGetDoc(db, 'doc-here2', 'replica:1', '{"value": "here2"}',
+ False)
+ self.assertEqual(
+ (11, 'T-sud'), db._get_replica_gen_and_trans_id('replica'))
+ self.assertEqual(2, new_gen)
+ # bounced back to us
+ self.assertEqual(
+ ('doc-here', 'replica:1', '{"value": "here"}', 1),
+ other_changes[0][:-1])
+
+ def test_sync_exchange_in_stream_error(self):
+ self.startServer()
+
+ def blackhole_getstderr(inst):
+ return cStringIO.StringIO()
+
+ self.patch(self.server.RequestHandlerClass, 'get_stderr',
+ blackhole_getstderr)
+ db = self.request_state._create_database('test')
+ doc = db.create_doc_from_json('{"value": "there"}')
+
+ def bomb_get_docs(doc_ids, check_for_conflicts=None,
+ include_deleted=False):
+ yield doc
+ # delayed failure case
+ raise errors.Unavailable
+
+ self.patch(db, 'get_docs', bomb_get_docs)
+ remote_target = self.getSyncTarget('test')
+ other_changes = []
+
+ def receive_doc(doc, gen, trans_id):
+ other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ self.assertRaises(
+ errors.Unavailable, remote_target.sync_exchange, [], 'replica',
+ last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=receive_doc)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, '{"value": "there"}', 1),
+ other_changes[0][:-1])
+
+ def test_sync_exchange_receive(self):
+ self.startServer()
+ db = self.request_state._create_database('test')
+ doc = db.create_doc_from_json('{"value": "there"}')
+ remote_target = self.getSyncTarget('test')
+ other_changes = []
+
+ def receive_doc(doc, gen, trans_id):
+ other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ new_gen, trans_id = remote_target.sync_exchange(
+ [], 'replica', last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=receive_doc)
+ self.assertEqual(1, new_gen)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, '{"value": "there"}', 1),
+ other_changes[0][:-1])
+
+ def test_sync_exchange_send_ensure_callback(self):
+ self.startServer()
+ remote_target = self.getSyncTarget('test')
+ other_docs = []
+ replica_uid_box = []
+
+ def receive_doc(doc):
+ other_docs.append((doc.doc_id, doc.rev, doc.get_json()))
+
+ def ensure_cb(replica_uid):
+ replica_uid_box.append(replica_uid)
+
+ doc = self.make_document('doc-here', 'replica:1', '{"value": "here"}')
+ new_gen, trans_id = remote_target.sync_exchange(
+ [(doc, 10, 'T-sid')], 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=receive_doc,
+ ensure_callback=ensure_cb)
+ self.assertEqual(1, new_gen)
+ db = self.request_state.open_database('test')
+ self.assertEqual(1, len(replica_uid_box))
+ self.assertEqual(db._replica_uid, replica_uid_box[0])
+ self.assertGetDoc(
+ db, 'doc-here', 'replica:1', '{"value": "here"}', False)
+
+
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/tests/u1db_tests/test_sqlite_backend.py b/src/leap/soledad/tests/u1db_tests/test_sqlite_backend.py
new file mode 100644
index 00000000..2003da03
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/test_sqlite_backend.py
@@ -0,0 +1,494 @@
+# Copyright 2011 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Test sqlite backend internals."""
+
+import os
+import time
+import threading
+
+from sqlite3 import dbapi2
+
+from u1db import (
+ errors,
+ query_parser,
+)
+
+from leap.soledad.tests import u1db_tests as tests
+
+from u1db.backends import sqlite_backend
+from leap.soledad.tests.u1db_tests.test_backends import TestAlternativeDocument
+
+
+simple_doc = '{"key": "value"}'
+nested_doc = '{"key": "value", "sub": {"doc": "underneath"}}'
+
+
+class TestSQLiteDatabase(tests.TestCase):
+
+ def test_atomic_initialize(self):
+ tmpdir = self.createTempDir()
+ dbname = os.path.join(tmpdir, 'atomic.db')
+
+ t2 = None # will be a thread
+
+ class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ _index_storage_value = "testing"
+
+ def __init__(self, dbname, ntry):
+ self._try = ntry
+ self._is_initialized_invocations = 0
+ super(SQLiteDatabaseTesting, self).__init__(dbname)
+
+ def _is_initialized(self, c):
+ res = super(SQLiteDatabaseTesting, self)._is_initialized(c)
+ if self._try == 1:
+ self._is_initialized_invocations += 1
+ if self._is_initialized_invocations == 2:
+ t2.start()
+ # hard to do better and have a generic test
+ time.sleep(0.05)
+ return res
+
+ outcome2 = []
+
+ def second_try():
+ try:
+ db2 = SQLiteDatabaseTesting(dbname, 2)
+ except Exception, e:
+ outcome2.append(e)
+ else:
+ outcome2.append(db2)
+
+ t2 = threading.Thread(target=second_try)
+ db1 = SQLiteDatabaseTesting(dbname, 1)
+ t2.join()
+
+ self.assertIsInstance(outcome2[0], SQLiteDatabaseTesting)
+ db2 = outcome2[0]
+ self.assertTrue(db2._is_initialized(db1._get_sqlite_handle().cursor()))
+
+
+class TestSQLitePartialExpandDatabase(tests.TestCase):
+
+ def setUp(self):
+ super(TestSQLitePartialExpandDatabase, self).setUp()
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.db._set_replica_uid('test')
+
+ def test_create_database(self):
+ raw_db = self.db._get_sqlite_handle()
+ self.assertNotEqual(None, raw_db)
+
+ def test_default_replica_uid(self):
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.assertIsNot(None, self.db._replica_uid)
+ self.assertEqual(32, len(self.db._replica_uid))
+ int(self.db._replica_uid, 16)
+
+ def test__close_sqlite_handle(self):
+ raw_db = self.db._get_sqlite_handle()
+ self.db._close_sqlite_handle()
+ self.assertRaises(dbapi2.ProgrammingError,
+ raw_db.cursor)
+
+ def test_create_database_initializes_schema(self):
+ raw_db = self.db._get_sqlite_handle()
+ c = raw_db.cursor()
+ c.execute("SELECT * FROM u1db_config")
+ config = dict([(r[0], r[1]) for r in c.fetchall()])
+ self.assertEqual({'sql_schema': '0', 'replica_uid': 'test',
+ 'index_storage': 'expand referenced'}, config)
+
+ # These tables must exist, though we don't care what is in them yet
+ c.execute("SELECT * FROM transaction_log")
+ c.execute("SELECT * FROM document")
+ c.execute("SELECT * FROM document_fields")
+ c.execute("SELECT * FROM sync_log")
+ c.execute("SELECT * FROM conflicts")
+ c.execute("SELECT * FROM index_definitions")
+
+ def test__parse_index(self):
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ g = self.db._parse_index_definition('fieldname')
+ self.assertIsInstance(g, query_parser.ExtractField)
+ self.assertEqual(['fieldname'], g.field)
+
+ def test__update_indexes(self):
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ g = self.db._parse_index_definition('fieldname')
+ c = self.db._get_sqlite_handle().cursor()
+ self.db._update_indexes('doc-id', {'fieldname': 'val'},
+ [('fieldname', g)], c)
+ c.execute('SELECT doc_id, field_name, value FROM document_fields')
+ self.assertEqual([('doc-id', 'fieldname', 'val')],
+ c.fetchall())
+
+ def test__set_replica_uid(self):
+ # Start from scratch, so that replica_uid isn't set.
+ self.db = sqlite_backend.SQLitePartialExpandDatabase(':memory:')
+ self.assertIsNot(None, self.db._real_replica_uid)
+ self.assertIsNot(None, self.db._replica_uid)
+ self.db._set_replica_uid('foo')
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT value FROM u1db_config WHERE name='replica_uid'")
+ self.assertEqual(('foo',), c.fetchone())
+ self.assertEqual('foo', self.db._real_replica_uid)
+ self.assertEqual('foo', self.db._replica_uid)
+ self.db._close_sqlite_handle()
+ self.assertEqual('foo', self.db._replica_uid)
+
+ def test__get_generation(self):
+ self.assertEqual(0, self.db._get_generation())
+
+ def test__get_generation_info(self):
+ self.assertEqual((0, ''), self.db._get_generation_info())
+
+ def test_create_index(self):
+ self.db.create_index('test-idx', "key")
+ self.assertEqual([('test-idx', ["key"])], self.db.list_indexes())
+
+ def test_create_index_multiple_fields(self):
+ self.db.create_index('test-idx', "key", "key2")
+ self.assertEqual([('test-idx', ["key", "key2"])],
+ self.db.list_indexes())
+
+ def test__get_index_definition(self):
+ self.db.create_index('test-idx', "key", "key2")
+ # TODO: How would you test that an index is getting used for an SQL
+ # request?
+ self.assertEqual(["key", "key2"],
+ self.db._get_index_definition('test-idx'))
+
+ def test_list_index_mixed(self):
+ # Make sure that we properly order the output
+ c = self.db._get_sqlite_handle().cursor()
+ # We intentionally insert the data in weird ordering, to make sure the
+ # query still gets it back correctly.
+ c.executemany("INSERT INTO index_definitions VALUES (?, ?, ?)",
+ [('idx-1', 0, 'key10'),
+ ('idx-2', 2, 'key22'),
+ ('idx-1', 1, 'key11'),
+ ('idx-2', 0, 'key20'),
+ ('idx-2', 1, 'key21')])
+ self.assertEqual([('idx-1', ['key10', 'key11']),
+ ('idx-2', ['key20', 'key21', 'key22'])],
+ self.db.list_indexes())
+
+ def test_no_indexes_no_document_fields(self):
+ self.db.create_doc_from_json(
+ '{"key1": "val1", "key2": "val2"}')
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([], c.fetchall())
+
+ def test_create_extracts_fields(self):
+ doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}')
+ doc2 = self.db.create_doc_from_json('{"key1": "valx", "key2": "valy"}')
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([], c.fetchall())
+ self.db.create_index('test', 'key1', 'key2')
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual(sorted(
+ [(doc1.doc_id, "key1", "val1"),
+ (doc1.doc_id, "key2", "val2"),
+ (doc2.doc_id, "key1", "valx"),
+ (doc2.doc_id, "key2", "valy"), ]), sorted(c.fetchall()))
+
+ def test_put_updates_fields(self):
+ self.db.create_index('test', 'key1', 'key2')
+ doc1 = self.db.create_doc_from_json(
+ '{"key1": "val1", "key2": "val2"}')
+ doc1.content = {"key1": "val1", "key2": "valy"}
+ self.db.put_doc(doc1)
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, "key1", "val1"),
+ (doc1.doc_id, "key2", "valy"), ], c.fetchall())
+
+ def test_put_updates_nested_fields(self):
+ self.db.create_index('test', 'key', 'sub.doc')
+ doc1 = self.db.create_doc_from_json(nested_doc)
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, "key", "value"),
+ (doc1.doc_id, "sub.doc", "underneath"), ],
+ c.fetchall())
+
+ def test__ensure_schema_rollback(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/rollback.db'
+
+ class SQLitePartialExpandDbTesting(
+ sqlite_backend.SQLitePartialExpandDatabase):
+
+ def _set_replica_uid_in_transaction(self, uid):
+ super(SQLitePartialExpandDbTesting,
+ self)._set_replica_uid_in_transaction(uid)
+ if fail:
+ raise Exception()
+
+ db = SQLitePartialExpandDbTesting.__new__(SQLitePartialExpandDbTesting)
+ db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed
+ fail = True
+ self.assertRaises(Exception, db._ensure_schema)
+ fail = False
+ db._initialize(db._db_handle.cursor())
+
+ def test__open_database(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/test.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase._open_database(path)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test__open_database_with_factory(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/test.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase._open_database(
+ path, document_factory=TestAlternativeDocument)
+ self.assertEqual(TestAlternativeDocument, db2._factory)
+
+ def test__open_database_non_existent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/non-existent.sqlite'
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase._open_database, path)
+
+ def test__open_database_during_init(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/initialised.db'
+ db = sqlite_backend.SQLitePartialExpandDatabase.__new__(
+ sqlite_backend.SQLitePartialExpandDatabase)
+ db._db_handle = dbapi2.connect(path) # db is there but not yet init-ed
+ self.addCleanup(db.close)
+ observed = []
+
+ class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1
+
+ @classmethod
+ def _which_index_storage(cls, c):
+ res = super(SQLiteDatabaseTesting, cls)._which_index_storage(c)
+ db._ensure_schema() # init db
+ observed.append(res[0])
+ return res
+
+ db2 = SQLiteDatabaseTesting._open_database(path)
+ self.addCleanup(db2.close)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+ self.assertEqual(
+ [None,
+ sqlite_backend.SQLitePartialExpandDatabase._index_storage_value],
+ observed)
+
+ def test__open_database_invalid(self):
+ class SQLiteDatabaseTesting(sqlite_backend.SQLiteDatabase):
+ WAIT_FOR_PARALLEL_INIT_HALF_INTERVAL = 0.1
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path1 = temp_dir + '/invalid1.db'
+ with open(path1, 'wb') as f:
+ f.write("")
+ self.assertRaises(dbapi2.OperationalError,
+ SQLiteDatabaseTesting._open_database, path1)
+ with open(path1, 'wb') as f:
+ f.write("invalid")
+ self.assertRaises(dbapi2.DatabaseError,
+ SQLiteDatabaseTesting._open_database, path1)
+
+ def test_open_database_existing(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/existing.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test_open_database_with_factory(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/existing.sqlite'
+ sqlite_backend.SQLitePartialExpandDatabase(path)
+ db2 = sqlite_backend.SQLiteDatabase.open_database(
+ path, create=False, document_factory=TestAlternativeDocument)
+ self.assertEqual(TestAlternativeDocument, db2._factory)
+
+ def test_open_database_create(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/new.sqlite'
+ sqlite_backend.SQLiteDatabase.open_database(path, create=True)
+ db2 = sqlite_backend.SQLiteDatabase.open_database(path, create=False)
+ self.assertIsInstance(db2, sqlite_backend.SQLitePartialExpandDatabase)
+
+ def test_open_database_non_existent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/non-existent.sqlite'
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase.open_database, path,
+ create=False)
+
+ def test_delete_database_existent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/new.sqlite'
+ db = sqlite_backend.SQLiteDatabase.open_database(path, create=True)
+ db.close()
+ sqlite_backend.SQLiteDatabase.delete_database(path)
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase.open_database, path,
+ create=False)
+
+ def test_delete_database_nonexistent(self):
+ temp_dir = self.createTempDir(prefix='u1db-test-')
+ path = temp_dir + '/non-existent.sqlite'
+ self.assertRaises(errors.DatabaseDoesNotExist,
+ sqlite_backend.SQLiteDatabase.delete_database, path)
+
+ def test__get_indexed_fields(self):
+ self.db.create_index('idx1', 'a', 'b')
+ self.assertEqual(set(['a', 'b']), self.db._get_indexed_fields())
+ self.db.create_index('idx2', 'b', 'c')
+ self.assertEqual(set(['a', 'b', 'c']), self.db._get_indexed_fields())
+
+ def test_indexed_fields_expanded(self):
+ self.db.create_index('idx1', 'key1')
+ doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}')
+ self.assertEqual(set(['key1']), self.db._get_indexed_fields())
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall())
+
+ def test_create_index_updates_fields(self):
+ doc1 = self.db.create_doc_from_json('{"key1": "val1", "key2": "val2"}')
+ self.db.create_index('idx1', 'key1')
+ self.assertEqual(set(['key1']), self.db._get_indexed_fields())
+ c = self.db._get_sqlite_handle().cursor()
+ c.execute("SELECT doc_id, field_name, value FROM document_fields"
+ " ORDER BY doc_id, field_name, value")
+ self.assertEqual([(doc1.doc_id, 'key1', 'val1')], c.fetchall())
+
+ def assertFormatQueryEquals(self, exp_statement, exp_args, definition,
+ values):
+ statement, args = self.db._format_query(definition, values)
+ self.assertEqual(exp_statement, statement)
+ self.assertEqual(exp_args, args)
+
+ def test__format_query(self):
+ self.assertFormatQueryEquals(
+ "SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM "
+ "document d, document_fields d0 LEFT OUTER JOIN conflicts c ON "
+ "c.doc_id = d.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name "
+ "= ? AND d0.value = ? GROUP BY d.doc_id, d.doc_rev, d.content "
+ "ORDER BY d0.value;", ["key1", "a"],
+ ["key1"], ["a"])
+
+ def test__format_query2(self):
+ self.assertFormatQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value = ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value = ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ["key1", "a", "key2", "b", "key3", "c"],
+ ["key1", "key2", "key3"], ["a", "b", "c"])
+
+ def test__format_query_wildcard(self):
+ self.assertFormatQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value = ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value GLOB ? AND d.doc_id = d2.doc_id AND d2.field_name = ? '
+ 'AND d2.value NOT NULL GROUP BY d.doc_id, d.doc_rev, d.content '
+ 'ORDER BY d0.value, d1.value, d2.value;',
+ ["key1", "a", "key2", "b*", "key3"], ["key1", "key2", "key3"],
+ ["a", "b*", "*"])
+
+ def assertFormatRangeQueryEquals(self, exp_statement, exp_args, definition,
+ start_value, end_value):
+ statement, args = self.db._format_range_query(
+ definition, start_value, end_value)
+ self.assertEqual(exp_statement, statement)
+ self.assertEqual(exp_args, args)
+
+ def test__format_range_query(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value >= ? AND d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'c', 'key1', 'p', 'key2', 'q',
+ 'key3', 'r'],
+ ["key1", "key2", "key3"], ["a", "b", "c"], ["p", "q", "r"])
+
+ def test__format_range_query_no_start(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value <= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value <= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'c'],
+ ["key1", "key2", "key3"], None, ["a", "b", "c"])
+
+ def test__format_range_query_no_end(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value >= ? GROUP BY d.doc_id, d.doc_rev, d.content ORDER BY '
+ 'd0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'c'],
+ ["key1", "key2", "key3"], ["a", "b", "c"], None)
+
+ def test__format_range_query_wildcard(self):
+ self.assertFormatRangeQueryEquals(
+ 'SELECT d.doc_id, d.doc_rev, d.content, count(c.doc_rev) FROM '
+ 'document d, document_fields d0, document_fields d1, '
+ 'document_fields d2 LEFT OUTER JOIN conflicts c ON c.doc_id = '
+ 'd.doc_id WHERE d.doc_id = d0.doc_id AND d0.field_name = ? AND '
+ 'd0.value >= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? AND '
+ 'd1.value >= ? AND d.doc_id = d2.doc_id AND d2.field_name = ? AND '
+ 'd2.value NOT NULL AND d.doc_id = d0.doc_id AND d0.field_name = ? '
+ 'AND d0.value <= ? AND d.doc_id = d1.doc_id AND d1.field_name = ? '
+ 'AND (d1.value < ? OR d1.value GLOB ?) AND d.doc_id = d2.doc_id '
+ 'AND d2.field_name = ? AND d2.value NOT NULL GROUP BY d.doc_id, '
+ 'd.doc_rev, d.content ORDER BY d0.value, d1.value, d2.value;',
+ ['key1', 'a', 'key2', 'b', 'key3', 'key1', 'p', 'key2', 'q', 'q*',
+ 'key3'],
+ ["key1", "key2", "key3"], ["a", "b*", "*"], ["p", "q*", "*"])
diff --git a/src/leap/soledad/tests/u1db_tests/test_sync.py b/src/leap/soledad/tests/u1db_tests/test_sync.py
new file mode 100644
index 00000000..96aa2736
--- /dev/null
+++ b/src/leap/soledad/tests/u1db_tests/test_sync.py
@@ -0,0 +1,1242 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""The Synchronization class for U1DB."""
+
+import os
+from wsgiref import simple_server
+
+from u1db import (
+ errors,
+ sync,
+ vectorclock,
+ SyncTarget,
+)
+
+from leap.soledad.tests import u1db_tests as tests
+
+from u1db.backends import (
+ inmemory,
+)
+from u1db.remote import (
+ http_target,
+)
+
+from leap.soledad.tests.u1db_tests.test_remote_sync_target import (
+ make_http_app,
+ make_oauth_http_app,
+)
+
+simple_doc = tests.simple_doc
+nested_doc = tests.nested_doc
+
+
+def _make_local_db_and_target(test):
+ db = test.create_database('test')
+ st = db.get_sync_target()
+ return db, st
+
+
+def _make_local_db_and_http_target(test, path='test'):
+ test.startServer()
+ db = test.request_state._create_database(os.path.basename(path))
+ st = http_target.HTTPSyncTarget.connect(test.getURL(path))
+ return db, st
+
+
+def _make_local_db_and_oauth_http_target(test):
+ db, st = _make_local_db_and_http_target(test, '~/test')
+ st.set_oauth_credentials(tests.consumer1.key, tests.consumer1.secret,
+ tests.token1.key, tests.token1.secret)
+ return db, st
+
+
+target_scenarios = [
+ ('local', {'create_db_and_target': _make_local_db_and_target}),
+ ('http', {'create_db_and_target': _make_local_db_and_http_target,
+ 'make_app_with_state': make_http_app}),
+ ('oauth_http', {'create_db_and_target':
+ _make_local_db_and_oauth_http_target,
+ 'make_app_with_state': make_oauth_http_app}),
+]
+
+
+class DatabaseSyncTargetTests(tests.DatabaseBaseTests,
+ tests.TestCaseWithServer):
+
+ scenarios = (tests.multiply_scenarios(tests.DatabaseBaseTests.scenarios,
+ target_scenarios))
+ #+ c_db_scenarios)
+ # whitebox true means self.db is the actual local db object
+ # against which the sync is performed
+ whitebox = True
+
+ def setUp(self):
+ super(DatabaseSyncTargetTests, self).setUp()
+ self.db, self.st = self.create_db_and_target(self)
+ self.other_changes = []
+
+ def tearDown(self):
+ # We delete them explicitly, so that connections are cleanly closed
+ del self.st
+ self.db.close()
+ del self.db
+ super(DatabaseSyncTargetTests, self).tearDown()
+
+ def receive_doc(self, doc, gen, trans_id):
+ self.other_changes.append(
+ (doc.doc_id, doc.rev, doc.get_json(), gen, trans_id))
+
+ def set_trace_hook(self, callback, shallow=False):
+ setter = (self.st._set_trace_hook if not shallow else
+ self.st._set_trace_hook_shallow)
+ try:
+ setter(callback)
+ except NotImplementedError:
+ self.skipTest("%s does not implement _set_trace_hook"
+ % (self.st.__class__.__name__,))
+
+ def test_get_sync_target(self):
+ self.assertIsNot(None, self.st)
+
+ def test_get_sync_info(self):
+ self.assertEqual(
+ ('test', 0, '', 0, ''), self.st.get_sync_info('other'))
+
+ def test_create_doc_updates_sync_info(self):
+ self.assertEqual(
+ ('test', 0, '', 0, ''), self.st.get_sync_info('other'))
+ self.db.create_doc_from_json(simple_doc)
+ self.assertEqual(1, self.st.get_sync_info('other')[1])
+
+ def test_record_sync_info(self):
+ self.st.record_sync_info('replica', 10, 'T-transid')
+ self.assertEqual(
+ ('test', 0, '', 10, 'T-transid'), self.st.get_sync_info('replica'))
+
+ def test_sync_exchange(self):
+ docs_by_gen = [
+ (self.make_document('doc-id', 'replica:1', simple_doc), 10,
+ 'T-sid')]
+ new_gen, trans_id = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False)
+ self.assertTransactionLog(['doc-id'], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 1, last_trans_id),
+ (self.other_changes, new_gen, last_trans_id))
+ self.assertEqual(10, self.st.get_sync_info('replica')[3])
+
+ def test_sync_exchange_deleted(self):
+ doc = self.db.create_doc_from_json('{}')
+ edit_rev = 'replica:1|' + doc.rev
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, edit_rev, None), 10, 'T-sid')]
+ new_gen, trans_id = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertGetDocIncludeDeleted(
+ self.db, doc.doc_id, edit_rev, None, False)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 2, last_trans_id),
+ (self.other_changes, new_gen, trans_id))
+ self.assertEqual(10, self.st.get_sync_info('replica')[3])
+
+ def test_sync_exchange_push_many(self):
+ docs_by_gen = [
+ (self.make_document('doc-id', 'replica:1', simple_doc), 10, 'T-1'),
+ (self.make_document('doc-id2', 'replica:1', nested_doc), 11,
+ 'T-2')]
+ new_gen, trans_id = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertGetDoc(self.db, 'doc-id', 'replica:1', simple_doc, False)
+ self.assertGetDoc(self.db, 'doc-id2', 'replica:1', nested_doc, False)
+ self.assertTransactionLog(['doc-id', 'doc-id2'], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 2, last_trans_id),
+ (self.other_changes, new_gen, trans_id))
+ self.assertEqual(11, self.st.get_sync_info('replica')[3])
+
+ def test_sync_exchange_refuses_conflicts(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'replica:1', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1])
+ self.assertEqual(1, new_gen)
+ if self.whitebox:
+ self.assertEqual(self.db._last_exchange_log['return'],
+ {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]})
+
+ def test_sync_exchange_ignores_convergence(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ gen, txid = self.db._get_generation_info()
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, doc.rev, simple_doc), 10, 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'replica', last_known_generation=gen,
+ last_known_trans_id=txid, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertEqual(([], 1), (self.other_changes, new_gen))
+
+ def test_sync_exchange_returns_new_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_gen, _ = self.st.sync_exchange(
+ [], 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, simple_doc, 1), self.other_changes[0][:-1])
+ self.assertEqual(1, new_gen)
+ if self.whitebox:
+ self.assertEqual(self.db._last_exchange_log['return'],
+ {'last_gen': 1, 'docs': [(doc.doc_id, doc.rev)]})
+
+ def test_sync_exchange_returns_deleted_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.db.delete_doc(doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ new_gen, _ = self.st.sync_exchange(
+ [], 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ self.assertEqual(
+ (doc.doc_id, doc.rev, None, 2), self.other_changes[0][:-1])
+ self.assertEqual(2, new_gen)
+ if self.whitebox:
+ self.assertEqual(self.db._last_exchange_log['return'],
+ {'last_gen': 2, 'docs': [(doc.doc_id, doc.rev)]})
+
+ def test_sync_exchange_returns_many_new_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ doc2 = self.db.create_doc_from_json(nested_doc)
+ self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db)
+ new_gen, _ = self.st.sync_exchange(
+ [], 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id, doc2.doc_id], self.db)
+ self.assertEqual(2, new_gen)
+ self.assertEqual(
+ [(doc.doc_id, doc.rev, simple_doc, 1),
+ (doc2.doc_id, doc2.rev, nested_doc, 2)],
+ [c[:-1] for c in self.other_changes])
+ if self.whitebox:
+ self.assertEqual(
+ self.db._last_exchange_log['return'],
+ {'last_gen': 2, 'docs':
+ [(doc.doc_id, doc.rev), (doc2.doc_id, doc2.rev)]})
+
+ def test_sync_exchange_getting_newer_docs(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertTransactionLog([doc.doc_id, doc.doc_id], self.db)
+ self.assertEqual(([], 2), (self.other_changes, new_gen))
+
+ def test_sync_exchange_with_concurrent_updates_of_synced_doc(self):
+ expected = []
+
+ def before_whatschanged_cb(state):
+ if state != 'before whats_changed':
+ return
+ cont = '{"key": "cuncurrent"}'
+ conc_rev = self.db.put_doc(
+ self.make_document(doc.doc_id, 'test:1|z:2', cont))
+ expected.append((doc.doc_id, conc_rev, cont, 3))
+
+ self.set_trace_hook(before_whatschanged_cb)
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertEqual(expected, [c[:-1] for c in self.other_changes])
+ self.assertEqual(3, new_gen)
+
+ def test_sync_exchange_with_concurrent_updates(self):
+
+ def after_whatschanged_cb(state):
+ if state != 'after whats_changed':
+ return
+ self.db.create_doc_from_json('{"new": "doc"}')
+
+ self.set_trace_hook(after_whatschanged_cb)
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ new_doc = '{"key": "altval"}'
+ docs_by_gen = [
+ (self.make_document(doc.doc_id, 'test:1|z:2', new_doc), 10,
+ 'T-sid')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertEqual(([], 2), (self.other_changes, new_gen))
+
+ def test_sync_exchange_converged_handling(self):
+ doc = self.db.create_doc_from_json(simple_doc)
+ docs_by_gen = [
+ (self.make_document('new', 'other:1', '{}'), 4, 'T-foo'),
+ (self.make_document(doc.doc_id, doc.rev, doc.get_json()), 5,
+ 'T-bar')]
+ new_gen, _ = self.st.sync_exchange(
+ docs_by_gen, 'other-replica', last_known_generation=0,
+ last_known_trans_id=None, return_doc_cb=self.receive_doc)
+ self.assertEqual(([], 2), (self.other_changes, new_gen))
+
+ def test_sync_exchange_detect_incomplete_exchange(self):
+ def before_get_docs_explode(state):
+ if state != 'before get_docs':
+ return
+ raise errors.U1DBError("fail")
+ self.set_trace_hook(before_get_docs_explode)
+ # suppress traceback printing in the wsgiref server
+ self.patch(simple_server.ServerHandler,
+ 'log_exception', lambda h, exc_info: None)
+ doc = self.db.create_doc_from_json(simple_doc)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ self.assertRaises(
+ (errors.U1DBError, errors.BrokenSyncStream),
+ self.st.sync_exchange, [], 'other-replica',
+ last_known_generation=0, last_known_trans_id=None,
+ return_doc_cb=self.receive_doc)
+
+ def test_sync_exchange_doc_ids(self):
+ sync_exchange_doc_ids = getattr(self.st, 'sync_exchange_doc_ids', None)
+ if sync_exchange_doc_ids is None:
+ self.skipTest("sync_exchange_doc_ids not implemented")
+ db2 = self.create_database('test2')
+ doc = db2.create_doc_from_json(simple_doc)
+ new_gen, trans_id = sync_exchange_doc_ids(
+ db2, [(doc.doc_id, 10, 'T-sid')], 0, None,
+ return_doc_cb=self.receive_doc)
+ self.assertGetDoc(self.db, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertTransactionLog([doc.doc_id], self.db)
+ last_trans_id = self.getLastTransId(self.db)
+ self.assertEqual(([], 1, last_trans_id),
+ (self.other_changes, new_gen, trans_id))
+ self.assertEqual(10, self.st.get_sync_info(db2._replica_uid)[3])
+
+ def test__set_trace_hook(self):
+ called = []
+
+ def cb(state):
+ called.append(state)
+
+ self.set_trace_hook(cb)
+ self.st.sync_exchange([], 'replica', 0, None, self.receive_doc)
+ self.st.record_sync_info('replica', 0, 'T-sid')
+ self.assertEqual(['before whats_changed',
+ 'after whats_changed',
+ 'before get_docs',
+ 'record_sync_info',
+ ],
+ called)
+
+ def test__set_trace_hook_shallow(self):
+ if (self.st._set_trace_hook_shallow == self.st._set_trace_hook
+ or
+ self.st._set_trace_hook_shallow.im_func ==
+ SyncTarget._set_trace_hook_shallow.im_func):
+ # shallow same as full
+ expected = ['before whats_changed',
+ 'after whats_changed',
+ 'before get_docs',
+ 'record_sync_info',
+ ]
+ else:
+ expected = ['sync_exchange', 'record_sync_info']
+
+ called = []
+
+ def cb(state):
+ called.append(state)
+
+ self.set_trace_hook(cb, shallow=True)
+ self.st.sync_exchange([], 'replica', 0, None, self.receive_doc)
+ self.st.record_sync_info('replica', 0, 'T-sid')
+ self.assertEqual(expected, called)
+
+
+def sync_via_synchronizer(test, db_source, db_target, trace_hook=None,
+ trace_hook_shallow=None):
+ target = db_target.get_sync_target()
+ trace_hook = trace_hook or trace_hook_shallow
+ if trace_hook:
+ target._set_trace_hook(trace_hook)
+ return sync.Synchronizer(db_source, target).sync()
+
+
+sync_scenarios = []
+for name, scenario in tests.LOCAL_DATABASES_SCENARIOS:
+ scenario = dict(scenario)
+ scenario['do_sync'] = sync_via_synchronizer
+ sync_scenarios.append((name, scenario))
+ scenario = dict(scenario)
+
+
+def make_database_for_http_test(test, replica_uid):
+ if test.server is None:
+ test.startServer()
+ db = test.request_state._create_database(replica_uid)
+ try:
+ http_at = test._http_at
+ except AttributeError:
+ http_at = test._http_at = {}
+ http_at[db] = replica_uid
+ return db
+
+
+def copy_database_for_http_test(test, db):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES IS
+ # THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST THAT WE
+ # CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS RATHER THAN
+ # CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND NINJA TO YOUR HOUSE.
+ if test.server is None:
+ test.startServer()
+ new_db = test.request_state._copy_database(db)
+ try:
+ http_at = test._http_at
+ except AttributeError:
+ http_at = test._http_at = {}
+ path = db._replica_uid
+ while path in http_at.values():
+ path += 'copy'
+ http_at[new_db] = path
+ return new_db
+
+
+def sync_via_synchronizer_and_http(test, db_source, db_target,
+ trace_hook=None, trace_hook_shallow=None):
+ if trace_hook:
+ test.skipTest("full trace hook unsupported over http")
+ path = test._http_at[db_target]
+ target = http_target.HTTPSyncTarget.connect(test.getURL(path))
+ if trace_hook_shallow:
+ target._set_trace_hook_shallow(trace_hook_shallow)
+ return sync.Synchronizer(db_source, target).sync()
+
+
+sync_scenarios.append(('pyhttp', {
+ 'make_database_for_test': make_database_for_http_test,
+ 'copy_database_for_test': copy_database_for_http_test,
+ 'make_document_for_test': tests.make_document_for_test,
+ 'make_app_with_state': make_http_app,
+ 'do_sync': sync_via_synchronizer_and_http
+}))
+
+
+class DatabaseSyncTests(tests.DatabaseBaseTests,
+ tests.TestCaseWithServer):
+
+ scenarios = sync_scenarios
+ do_sync = None # set by scenarios
+
+ def create_database(self, replica_uid, sync_role=None):
+ if replica_uid == 'test' and sync_role is None:
+ # created up the chain by base class but unused
+ return None
+ db = self.create_database_for_role(replica_uid, sync_role)
+ if sync_role:
+ self._use_tracking[db] = (replica_uid, sync_role)
+ return db
+
+ def create_database_for_role(self, replica_uid, sync_role):
+ # hook point for reuse
+ return super(DatabaseSyncTests, self).create_database(replica_uid)
+
+ def copy_database(self, db, sync_role=None):
+ # DO NOT COPY OR REUSE THIS CODE OUTSIDE TESTS: COPYING U1DB DATABASES
+ # IS THE WRONG THING TO DO, THE ONLY REASON WE DO SO HERE IS TO TEST
+ # THAT WE CORRECTLY DETECT IT HAPPENING SO THAT WE CAN RAISE ERRORS
+ # RATHER THAN CORRUPT USER DATA. USE SYNC INSTEAD, OR WE WILL SEND
+ # NINJA TO YOUR HOUSE.
+ db_copy = super(DatabaseSyncTests, self).copy_database(db)
+ name, orig_sync_role = self._use_tracking[db]
+ self._use_tracking[db_copy] = (name + '(copy)', sync_role
+ or orig_sync_role)
+ return db_copy
+
+ def sync(self, db_from, db_to, trace_hook=None,
+ trace_hook_shallow=None):
+ from_name, from_sync_role = self._use_tracking[db_from]
+ to_name, to_sync_role = self._use_tracking[db_to]
+ if from_sync_role not in ('source', 'both'):
+ raise Exception("%s marked for %s use but used as source" %
+ (from_name, from_sync_role))
+ if to_sync_role not in ('target', 'both'):
+ raise Exception("%s marked for %s use but used as target" %
+ (to_name, to_sync_role))
+ return self.do_sync(self, db_from, db_to, trace_hook,
+ trace_hook_shallow)
+
+ def setUp(self):
+ self._use_tracking = {}
+ super(DatabaseSyncTests, self).setUp()
+
+ def assertLastExchangeLog(self, db, expected):
+ log = getattr(db, '_last_exchange_log', None)
+ if log is None:
+ return
+ self.assertEqual(expected, log)
+
+ def test_sync_tracks_db_generation_of_other(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.assertEqual(0, self.sync(self.db1, self.db2))
+ self.assertEqual(
+ (0, ''), self.db1._get_replica_gen_and_trans_id('test2'))
+ self.assertEqual(
+ (0, ''), self.db2._get_replica_gen_and_trans_id('test1'))
+ self.assertLastExchangeLog(self.db2,
+ {'receive':
+ {'docs': [], 'last_known_gen': 0},
+ 'return':
+ {'docs': [], 'last_gen': 0}})
+
+ def test_sync_autoresolves(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc1 = self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ rev1 = doc1.rev
+ doc2 = self.db2.create_doc_from_json(simple_doc, doc_id='doc')
+ rev2 = doc2.rev
+ self.sync(self.db1, self.db2)
+ doc = self.db1.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ self.assertEqual(doc.rev, self.db2.get_doc('doc').rev)
+ v = vectorclock.VectorClockRev(doc.rev)
+ self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev1)))
+ self.assertTrue(v.is_newer(vectorclock.VectorClockRev(rev2)))
+
+ def test_sync_autoresolves_moar(self):
+ # here we test that when a database that has a conflicted document is
+ # the source of a sync, and the target database has a revision of the
+ # conflicted document that is newer than the source database's, and
+ # that target's database's document's content is the same as the
+ # source's document's conflict's, the source's document's conflict gets
+ # autoresolved, and the source's document's revision bumped.
+ #
+ # idea is as follows:
+ # A B
+ # a1 -
+ # `------->
+ # a1 a1
+ # v v
+ # a2 a1b1
+ # `------->
+ # a1b1+a2 a1b1
+ # v
+ # a1b1+a2 a1b2 (a1b2 has same content as a2)
+ # `------->
+ # a3b2 a1b2 (autoresolved)
+ # `------->
+ # a3b2 a3b2
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ self.sync(self.db1, self.db2)
+ for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]:
+ doc = db.get_doc('doc')
+ doc.set_json(content)
+ db.put_doc(doc)
+ self.sync(self.db1, self.db2)
+ # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict
+ doc = self.db1.get_doc('doc')
+ rev1 = doc.rev
+ self.assertTrue(doc.has_conflicts)
+ # set db2 to have a doc of {} (same as db1 before the conflict)
+ doc = self.db2.get_doc('doc')
+ doc.set_json('{}')
+ self.db2.put_doc(doc)
+ rev2 = doc.rev
+ # sync it across
+ self.sync(self.db1, self.db2)
+ # tadaa!
+ doc = self.db1.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ vec1 = vectorclock.VectorClockRev(rev1)
+ vec2 = vectorclock.VectorClockRev(rev2)
+ vec3 = vectorclock.VectorClockRev(doc.rev)
+ self.assertTrue(vec3.is_newer(vec1))
+ self.assertTrue(vec3.is_newer(vec2))
+ # because the conflict is on the source, sync it another time
+ self.sync(self.db1, self.db2)
+ # make sure db2 now has the exact same thing
+ self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc'))
+
+ def test_sync_autoresolves_moar_backwards(self):
+ # here we test that when a database that has a conflicted document is
+ # the target of a sync, and the source database has a revision of the
+ # conflicted document that is newer than the target database's, and
+ # that source's database's document's content is the same as the
+ # target's document's conflict's, the target's document's conflict gets
+ # autoresolved, and the document's revision bumped.
+ #
+ # idea is as follows:
+ # A B
+ # a1 -
+ # `------->
+ # a1 a1
+ # v v
+ # a2 a1b1
+ # `------->
+ # a1b1+a2 a1b1
+ # v
+ # a1b1+a2 a1b2 (a1b2 has same content as a2)
+ # <-------'
+ # a3b2 a3b2 (autoresolved and propagated)
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'both')
+ self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ self.sync(self.db1, self.db2)
+ for db, content in [(self.db1, '{}'), (self.db2, '{"hi": 42}')]:
+ doc = db.get_doc('doc')
+ doc.set_json(content)
+ db.put_doc(doc)
+ self.sync(self.db1, self.db2)
+ # db1 and db2 now both have a doc of {hi:42}, but db1 has a conflict
+ doc = self.db1.get_doc('doc')
+ rev1 = doc.rev
+ self.assertTrue(doc.has_conflicts)
+ revc = self.db1.get_doc_conflicts('doc')[-1].rev
+ # set db2 to have a doc of {} (same as db1 before the conflict)
+ doc = self.db2.get_doc('doc')
+ doc.set_json('{}')
+ self.db2.put_doc(doc)
+ rev2 = doc.rev
+ # sync it across
+ self.sync(self.db2, self.db1)
+ # tadaa!
+ doc = self.db1.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ vec1 = vectorclock.VectorClockRev(rev1)
+ vec2 = vectorclock.VectorClockRev(rev2)
+ vec3 = vectorclock.VectorClockRev(doc.rev)
+ vecc = vectorclock.VectorClockRev(revc)
+ self.assertTrue(vec3.is_newer(vec1))
+ self.assertTrue(vec3.is_newer(vec2))
+ self.assertTrue(vec3.is_newer(vecc))
+ # make sure db2 now has the exact same thing
+ self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc'))
+
+ def test_sync_autoresolves_moar_backwards_three(self):
+ # same as autoresolves_moar_backwards, but with three databases (note
+ # all the syncs go in the same direction -- this is a more natural
+ # scenario):
+ #
+ # A B C
+ # a1 - -
+ # `------->
+ # a1 a1 -
+ # `------->
+ # a1 a1 a1
+ # v v
+ # a2 a1b1 a1
+ # `------------------->
+ # a2 a1b1 a2
+ # `------->
+ # a2+a1b1 a2
+ # v
+ # a2 a2+a1b1 a2c1 (same as a1b1)
+ # `------------------->
+ # a2c1 a2+a1b1 a2c1
+ # `------->
+ # a2b2c1 a2b2c1 a2c1
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'both')
+ self.db3 = self.create_database('test3', 'target')
+ self.db1.create_doc_from_json(simple_doc, doc_id='doc')
+ self.sync(self.db1, self.db2)
+ self.sync(self.db2, self.db3)
+ for db, content in [(self.db2, '{"hi": 42}'),
+ (self.db1, '{}'),
+ ]:
+ doc = db.get_doc('doc')
+ doc.set_json(content)
+ db.put_doc(doc)
+ self.sync(self.db1, self.db3)
+ self.sync(self.db2, self.db3)
+ # db2 and db3 now both have a doc of {}, but db2 has a
+ # conflict
+ doc = self.db2.get_doc('doc')
+ self.assertTrue(doc.has_conflicts)
+ revc = self.db2.get_doc_conflicts('doc')[-1].rev
+ self.assertEqual('{}', doc.get_json())
+ self.assertEqual(self.db3.get_doc('doc').get_json(), doc.get_json())
+ self.assertEqual(self.db3.get_doc('doc').rev, doc.rev)
+ # set db3 to have a doc of {hi:42} (same as db2 before the conflict)
+ doc = self.db3.get_doc('doc')
+ doc.set_json('{"hi": 42}')
+ self.db3.put_doc(doc)
+ rev3 = doc.rev
+ # sync it across to db1
+ self.sync(self.db1, self.db3)
+ # db1 now has hi:42, with a rev that is newer than db2's doc
+ doc = self.db1.get_doc('doc')
+ rev1 = doc.rev
+ self.assertFalse(doc.has_conflicts)
+ self.assertEqual('{"hi": 42}', doc.get_json())
+ VCR = vectorclock.VectorClockRev
+ self.assertTrue(VCR(rev1).is_newer(VCR(self.db2.get_doc('doc').rev)))
+ # so sync it to db2
+ self.sync(self.db1, self.db2)
+ # tadaa!
+ doc = self.db2.get_doc('doc')
+ self.assertFalse(doc.has_conflicts)
+ # db2's revision of the document is strictly newer than db1's before
+ # the sync, and db3's before that sync way back when
+ self.assertTrue(VCR(doc.rev).is_newer(VCR(rev1)))
+ self.assertTrue(VCR(doc.rev).is_newer(VCR(rev3)))
+ self.assertTrue(VCR(doc.rev).is_newer(VCR(revc)))
+ # make sure both dbs now have the exact same thing
+ self.assertEqual(self.db1.get_doc('doc'), self.db2.get_doc('doc'))
+
+ def test_sync_puts_changes(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ self.assertEqual(1, self.sync(self.db1, self.db2))
+ self.assertGetDoc(self.db2, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0])
+ self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0])
+ self.assertLastExchangeLog(self.db2,
+ {'receive':
+ {'docs': [(doc.doc_id, doc.rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 1,
+ 'last_known_gen': 0},
+ 'return': {'docs': [], 'last_gen': 1}})
+
+ def test_sync_pulls_changes(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db2.create_doc_from_json(simple_doc)
+ self.db1.create_index('test-idx', 'key')
+ self.assertEqual(0, self.sync(self.db1, self.db2))
+ self.assertGetDoc(self.db1, doc.doc_id, doc.rev, simple_doc, False)
+ self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0])
+ self.assertEqual(1, self.db2._get_replica_gen_and_trans_id('test1')[0])
+ self.assertLastExchangeLog(self.db2,
+ {'receive':
+ {'docs': [], 'last_known_gen': 0},
+ 'return':
+ {'docs': [(doc.doc_id, doc.rev)],
+ 'last_gen': 1}})
+ self.assertEqual([doc], self.db1.get_from_index('test-idx', 'value'))
+
+ def test_sync_pulling_doesnt_update_other_if_changed(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db2.create_doc_from_json(simple_doc)
+ # After the local side has sent its list of docs, before we start
+ # receiving the "targets" response, we update the local database with a
+ # new record.
+ # When we finish synchronizing, we can notice that something locally
+ # was updated, and we cannot tell c2 our new updated generation
+
+ def before_get_docs(state):
+ if state != 'before get_docs':
+ return
+ self.db1.create_doc_from_json(simple_doc)
+
+ self.assertEqual(0, self.sync(self.db1, self.db2,
+ trace_hook=before_get_docs))
+ self.assertLastExchangeLog(self.db2,
+ {'receive':
+ {'docs': [], 'last_known_gen': 0},
+ 'return':
+ {'docs': [(doc.doc_id, doc.rev)],
+ 'last_gen': 1}})
+ self.assertEqual(1, self.db1._get_replica_gen_and_trans_id('test2')[0])
+ # c2 should not have gotten a '_record_sync_info' call, because the
+ # local database had been updated more than just by the messages
+ # returned from c2.
+ self.assertEqual(
+ (0, ''), self.db2._get_replica_gen_and_trans_id('test1'))
+
+ def test_sync_doesnt_update_other_if_nothing_pulled(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(simple_doc)
+
+ def no_record_sync_info(state):
+ if state != 'record_sync_info':
+ return
+ self.fail('SyncTarget.record_sync_info was called')
+ self.assertEqual(1, self.sync(self.db1, self.db2,
+ trace_hook_shallow=no_record_sync_info))
+ self.assertEqual(
+ 1,
+ self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid)[0])
+
+ def test_sync_ignores_convergence(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'both')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ self.db3 = self.create_database('test3', 'target')
+ self.assertEqual(1, self.sync(self.db1, self.db3))
+ self.assertEqual(0, self.sync(self.db2, self.db3))
+ self.assertEqual(1, self.sync(self.db1, self.db2))
+ self.assertLastExchangeLog(self.db2,
+ {'receive':
+ {'docs': [(doc.doc_id, doc.rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 1, 'last_known_gen': 0},
+ 'return': {'docs': [], 'last_gen': 1}})
+
+ def test_sync_ignores_superseded(self):
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'both')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ doc_rev1 = doc.rev
+ self.db3 = self.create_database('test3', 'target')
+ self.sync(self.db1, self.db3)
+ self.sync(self.db2, self.db3)
+ new_content = '{"key": "altval"}'
+ doc.set_json(new_content)
+ self.db1.put_doc(doc)
+ doc_rev2 = doc.rev
+ self.sync(self.db2, self.db1)
+ self.assertLastExchangeLog(self.db1,
+ {'receive':
+ {'docs': [(doc.doc_id, doc_rev1)],
+ 'source_uid': 'test2',
+ 'source_gen': 1, 'last_known_gen': 0},
+ 'return':
+ {'docs': [(doc.doc_id, doc_rev2)],
+ 'last_gen': 2}})
+ self.assertGetDoc(self.db1, doc.doc_id, doc_rev2, new_content, False)
+
+ def test_sync_sees_remote_conflicted(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc1 = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc1.doc_id
+ doc1_rev = doc1.rev
+ self.db1.create_index('test-idx', 'key')
+ new_doc = '{"key": "altval"}'
+ doc2 = self.db2.create_doc_from_json(new_doc, doc_id=doc_id)
+ doc2_rev = doc2.rev
+ self.assertTransactionLog([doc1.doc_id], self.db1)
+ self.sync(self.db1, self.db2)
+ self.assertLastExchangeLog(self.db2,
+ {'receive':
+ {'docs': [(doc_id, doc1_rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 1, 'last_known_gen': 0},
+ 'return':
+ {'docs': [(doc_id, doc2_rev)],
+ 'last_gen': 1}})
+ self.assertTransactionLog([doc_id, doc_id], self.db1)
+ self.assertGetDoc(self.db1, doc_id, doc2_rev, new_doc, True)
+ self.assertGetDoc(self.db2, doc_id, doc2_rev, new_doc, False)
+ from_idx = self.db1.get_from_index('test-idx', 'altval')[0]
+ self.assertEqual(doc2.doc_id, from_idx.doc_id)
+ self.assertEqual(doc2.rev, from_idx.rev)
+ self.assertTrue(from_idx.has_conflicts)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+
+ def test_sync_sees_remote_delete_conflicted(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc1 = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc1.doc_id
+ self.db1.create_index('test-idx', 'key')
+ self.sync(self.db1, self.db2)
+ doc2 = self.make_document(doc1.doc_id, doc1.rev, doc1.get_json())
+ new_doc = '{"key": "altval"}'
+ doc1.set_json(new_doc)
+ self.db1.put_doc(doc1)
+ self.db2.delete_doc(doc2)
+ self.assertTransactionLog([doc_id, doc_id], self.db1)
+ self.sync(self.db1, self.db2)
+ self.assertLastExchangeLog(self.db2,
+ {'receive':
+ {'docs': [(doc_id, doc1.rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 2, 'last_known_gen': 1},
+ 'return': {'docs': [(doc_id, doc2.rev)],
+ 'last_gen': 2}})
+ self.assertTransactionLog([doc_id, doc_id, doc_id], self.db1)
+ self.assertGetDocIncludeDeleted(self.db1, doc_id, doc2.rev, None, True)
+ self.assertGetDocIncludeDeleted(
+ self.db2, doc_id, doc2.rev, None, False)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+
+ def test_sync_local_race_conflicted(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ doc = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc.doc_id
+ doc1_rev = doc.rev
+ self.db1.create_index('test-idx', 'key')
+ self.sync(self.db1, self.db2)
+ content1 = '{"key": "localval"}'
+ content2 = '{"key": "altval"}'
+ doc.set_json(content2)
+ self.db2.put_doc(doc)
+ doc2_rev2 = doc.rev
+ triggered = []
+
+ def after_whatschanged(state):
+ if state != 'after whats_changed':
+ return
+ triggered.append(True)
+ doc = self.make_document(doc_id, doc1_rev, content1)
+ self.db1.put_doc(doc)
+
+ self.sync(self.db1, self.db2, trace_hook=after_whatschanged)
+ self.assertEqual([True], triggered)
+ self.assertGetDoc(self.db1, doc_id, doc2_rev2, content2, True)
+ from_idx = self.db1.get_from_index('test-idx', 'altval')[0]
+ self.assertEqual(doc.doc_id, from_idx.doc_id)
+ self.assertEqual(doc.rev, from_idx.rev)
+ self.assertTrue(from_idx.has_conflicts)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'localval'))
+
+ def test_sync_propagates_deletes(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'both')
+ doc1 = self.db1.create_doc_from_json(simple_doc)
+ doc_id = doc1.doc_id
+ self.db1.create_index('test-idx', 'key')
+ self.sync(self.db1, self.db2)
+ self.db2.create_index('test-idx', 'key')
+ self.db3 = self.create_database('test3', 'target')
+ self.sync(self.db1, self.db3)
+ self.db1.delete_doc(doc1)
+ deleted_rev = doc1.rev
+ self.sync(self.db1, self.db2)
+ self.assertLastExchangeLog(self.db2,
+ {'receive':
+ {'docs': [(doc_id, deleted_rev)],
+ 'source_uid': 'test1',
+ 'source_gen': 2, 'last_known_gen': 1},
+ 'return': {'docs': [], 'last_gen': 2}})
+ self.assertGetDocIncludeDeleted(
+ self.db1, doc_id, deleted_rev, None, False)
+ self.assertGetDocIncludeDeleted(
+ self.db2, doc_id, deleted_rev, None, False)
+ self.assertEqual([], self.db1.get_from_index('test-idx', 'value'))
+ self.assertEqual([], self.db2.get_from_index('test-idx', 'value'))
+ self.sync(self.db2, self.db3)
+ self.assertLastExchangeLog(self.db3,
+ {'receive':
+ {'docs': [(doc_id, deleted_rev)],
+ 'source_uid': 'test2',
+ 'source_gen': 2,
+ 'last_known_gen': 0},
+ 'return':
+ {'docs': [], 'last_gen': 2}})
+ self.assertGetDocIncludeDeleted(
+ self.db3, doc_id, deleted_rev, None, False)
+
+ def test_sync_propagates_resolution(self):
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'both')
+ doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc')
+ db3 = self.create_database('test3', 'both')
+ self.sync(self.db2, self.db1)
+ self.assertEqual(
+ self.db1._get_generation_info(),
+ self.db2._get_replica_gen_and_trans_id(self.db1._replica_uid))
+ self.assertEqual(
+ self.db2._get_generation_info(),
+ self.db1._get_replica_gen_and_trans_id(self.db2._replica_uid))
+ self.sync(db3, self.db1)
+ # update on 2
+ doc2 = self.make_document('the-doc', doc1.rev, '{"a": 2}')
+ self.db2.put_doc(doc2)
+ self.sync(self.db2, db3)
+ self.assertEqual(db3.get_doc('the-doc').rev, doc2.rev)
+ # update on 1
+ doc1.set_json('{"a": 3}')
+ self.db1.put_doc(doc1)
+ # conflicts
+ self.sync(self.db2, self.db1)
+ self.sync(db3, self.db1)
+ self.assertTrue(self.db2.get_doc('the-doc').has_conflicts)
+ self.assertTrue(db3.get_doc('the-doc').has_conflicts)
+ # resolve
+ conflicts = self.db2.get_doc_conflicts('the-doc')
+ doc4 = self.make_document('the-doc', None, '{"a": 4}')
+ revs = [doc.rev for doc in conflicts]
+ self.db2.resolve_doc(doc4, revs)
+ doc2 = self.db2.get_doc('the-doc')
+ self.assertEqual(doc4.get_json(), doc2.get_json())
+ self.assertFalse(doc2.has_conflicts)
+ self.sync(self.db2, db3)
+ doc3 = db3.get_doc('the-doc')
+ self.assertEqual(doc4.get_json(), doc3.get_json())
+ self.assertFalse(doc3.has_conflicts)
+
+ def test_sync_supersedes_conflicts(self):
+ self.db1 = self.create_database('test1', 'both')
+ self.db2 = self.create_database('test2', 'target')
+ db3 = self.create_database('test3', 'both')
+ doc1 = self.db1.create_doc_from_json('{"a": 1}', doc_id='the-doc')
+ self.db2.create_doc_from_json('{"b": 1}', doc_id='the-doc')
+ db3.create_doc_from_json('{"c": 1}', doc_id='the-doc')
+ self.sync(db3, self.db1)
+ self.assertEqual(
+ self.db1._get_generation_info(),
+ db3._get_replica_gen_and_trans_id(self.db1._replica_uid))
+ self.assertEqual(
+ db3._get_generation_info(),
+ self.db1._get_replica_gen_and_trans_id(db3._replica_uid))
+ self.sync(db3, self.db2)
+ self.assertEqual(
+ self.db2._get_generation_info(),
+ db3._get_replica_gen_and_trans_id(self.db2._replica_uid))
+ self.assertEqual(
+ db3._get_generation_info(),
+ self.db2._get_replica_gen_and_trans_id(db3._replica_uid))
+ self.assertEqual(3, len(db3.get_doc_conflicts('the-doc')))
+ doc1.set_json('{"a": 2}')
+ self.db1.put_doc(doc1)
+ self.sync(db3, self.db1)
+ # original doc1 should have been removed from conflicts
+ self.assertEqual(3, len(db3.get_doc_conflicts('the-doc')))
+
+ def test_sync_stops_after_get_sync_info(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc)
+ self.sync(self.db1, self.db2)
+
+ def put_hook(state):
+ self.fail("Tracehook triggered for %s" % (state,))
+
+ self.sync(self.db1, self.db2, trace_hook_shallow=put_hook)
+
+ def test_sync_detects_rollback_in_source(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1')
+ self.sync(self.db1, self.db2)
+ db1_copy = self.copy_database(self.db1)
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidGeneration, self.sync, db1_copy, self.db2)
+
+ def test_sync_detects_rollback_in_target(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ db2_copy = self.copy_database(self.db2)
+ self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidGeneration, self.sync, self.db1, db2_copy)
+
+ def test_sync_detects_diverged_source(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ db3 = self.copy_database(self.db1)
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ db3.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, db3, self.db2)
+
+ def test_sync_detects_diverged_target(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ db3 = self.copy_database(self.db2)
+ db3.create_doc_from_json(tests.nested_doc, doc_id="divergent")
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, self.db1, db3)
+
+ def test_sync_detects_rollback_and_divergence_in_source(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc1')
+ self.sync(self.db1, self.db2)
+ db1_copy = self.copy_database(self.db1)
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.sync(self.db1, self.db2)
+ db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ db1_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, db1_copy, self.db2)
+
+ def test_sync_detects_rollback_and_divergence_in_target(self):
+ self.db1 = self.create_database('test1', 'source')
+ self.db2 = self.create_database('test2', 'target')
+ self.db1.create_doc_from_json(tests.simple_doc, doc_id="divergent")
+ self.sync(self.db1, self.db2)
+ db2_copy = self.copy_database(self.db2)
+ self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ self.db2.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.sync(self.db1, self.db2)
+ db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc2')
+ db2_copy.create_doc_from_json(tests.simple_doc, doc_id='doc3')
+ self.assertRaises(
+ errors.InvalidTransactionId, self.sync, self.db1, db2_copy)
+
+
+class TestDbSync(tests.TestCaseWithServer):
+ """Test db.sync remote sync shortcut"""
+
+ scenarios = [
+ ('py-http', {
+ 'make_app_with_state': make_http_app,
+ 'make_database_for_test': tests.make_memory_database_for_test,
+ }),
+ ('py-oauth-http', {
+ 'make_app_with_state': make_oauth_http_app,
+ 'make_database_for_test': tests.make_memory_database_for_test,
+ 'oauth': True
+ }),
+ ]
+
+ oauth = False
+
+ def do_sync(self, target_name):
+ if self.oauth:
+ path = '~/' + target_name
+ extra = dict(creds={'oauth': {
+ 'consumer_key': tests.consumer1.key,
+ 'consumer_secret': tests.consumer1.secret,
+ 'token_key': tests.token1.key,
+ 'token_secret': tests.token1.secret,
+ }})
+ else:
+ path = target_name
+ extra = {}
+ target_url = self.getURL(path)
+ return self.db.sync(target_url, **extra)
+
+ def setUp(self):
+ super(TestDbSync, self).setUp()
+ self.startServer()
+ self.db = self.make_database_for_test(self, 'test1')
+ self.db2 = self.request_state._create_database('test2.db')
+
+ def test_db_sync(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ doc2 = self.db2.create_doc_from_json(tests.nested_doc)
+ local_gen_before_sync = self.do_sync('test2.db')
+ gen, _, changes = self.db.whats_changed(local_gen_before_sync)
+ self.assertEqual(1, len(changes))
+ self.assertEqual(doc2.doc_id, changes[0][0])
+ self.assertEqual(1, gen - local_gen_before_sync)
+ self.assertGetDoc(self.db2, doc1.doc_id, doc1.rev, tests.simple_doc,
+ False)
+ self.assertGetDoc(self.db, doc2.doc_id, doc2.rev, tests.nested_doc,
+ False)
+
+ def test_db_sync_autocreate(self):
+ doc1 = self.db.create_doc_from_json(tests.simple_doc)
+ local_gen_before_sync = self.do_sync('test3.db')
+ gen, _, changes = self.db.whats_changed(local_gen_before_sync)
+ self.assertEqual(0, gen - local_gen_before_sync)
+ db3 = self.request_state.open_database('test3.db')
+ gen, _, changes = db3.whats_changed()
+ self.assertEqual(1, len(changes))
+ self.assertEqual(doc1.doc_id, changes[0][0])
+ self.assertGetDoc(db3, doc1.doc_id, doc1.rev, tests.simple_doc,
+ False)
+ t_gen, _ = self.db._get_replica_gen_and_trans_id('test3.db')
+ s_gen, _ = db3._get_replica_gen_and_trans_id('test1')
+ self.assertEqual(1, t_gen)
+ self.assertEqual(1, s_gen)
+
+
+class TestRemoteSyncIntegration(tests.TestCaseWithServer):
+ """Integration tests for the most common sync scenario local -> remote"""
+
+ make_app_with_state = staticmethod(make_http_app)
+
+ def setUp(self):
+ super(TestRemoteSyncIntegration, self).setUp()
+ self.startServer()
+ self.db1 = inmemory.InMemoryDatabase('test1')
+ self.db2 = self.request_state._create_database('test2')
+
+ def test_sync_tracks_generations_incrementally(self):
+ doc11 = self.db1.create_doc_from_json('{"a": 1}')
+ doc12 = self.db1.create_doc_from_json('{"a": 2}')
+ doc21 = self.db2.create_doc_from_json('{"b": 1}')
+ doc22 = self.db2.create_doc_from_json('{"b": 2}')
+ #sanity
+ self.assertEqual(2, len(self.db1._get_transaction_log()))
+ self.assertEqual(2, len(self.db2._get_transaction_log()))
+ progress1 = []
+ progress2 = []
+ _do_set_replica_gen_and_trans_id = \
+ self.db1._do_set_replica_gen_and_trans_id
+
+ def set_sync_generation_witness1(other_uid, other_gen, trans_id):
+ progress1.append((other_uid, other_gen,
+ [d for d, t in
+ self.db1._get_transaction_log()[2:]]))
+ _do_set_replica_gen_and_trans_id(other_uid, other_gen, trans_id)
+ self.patch(self.db1, '_do_set_replica_gen_and_trans_id',
+ set_sync_generation_witness1)
+ _do_set_replica_gen_and_trans_id2 = \
+ self.db2._do_set_replica_gen_and_trans_id
+
+ def set_sync_generation_witness2(other_uid, other_gen, trans_id):
+ progress2.append((other_uid, other_gen,
+ [d for d, t in
+ self.db2._get_transaction_log()[2:]]))
+ _do_set_replica_gen_and_trans_id2(other_uid, other_gen, trans_id)
+ self.patch(self.db2, '_do_set_replica_gen_and_trans_id',
+ set_sync_generation_witness2)
+
+ db2_url = self.getURL('test2')
+ self.db1.sync(db2_url)
+
+ self.assertEqual([('test2', 1, [doc21.doc_id]),
+ ('test2', 2, [doc21.doc_id, doc22.doc_id]),
+ ('test2', 4, [doc21.doc_id, doc22.doc_id])],
+ progress1)
+ self.assertEqual([('test1', 1, [doc11.doc_id]),
+ ('test1', 2, [doc11.doc_id, doc12.doc_id]),
+ ('test1', 4, [doc11.doc_id, doc12.doc_id])],
+ progress2)
+
+
+load_tests = tests.load_with_scenarios
diff --git a/src/leap/soledad/util.py b/src/leap/soledad/util.py
new file mode 100644
index 00000000..4bc4d2c9
--- /dev/null
+++ b/src/leap/soledad/util.py
@@ -0,0 +1,55 @@
+import os
+import gnupg
+import re
+
+
+class GPGWrapper(gnupg.GPG):
+ """
+ This is a temporary class for handling GPG requests, and should be
+ replaced by a more general class used throughout the project.
+ """
+
+ GNUPG_HOME = os.environ['HOME'] + "/.config/leap/gnupg"
+ GNUPG_BINARY = "/usr/bin/gpg" # this has to be changed based on OS
+
+ def __init__(self, gpghome=GNUPG_HOME, gpgbinary=GNUPG_BINARY):
+ super(GPGWrapper, self).__init__(gnupghome=gpghome,
+ gpgbinary=gpgbinary)
+
+ def find_key(self, email):
+ """
+ Find user's key based on their email.
+ """
+ for key in self.list_keys():
+ for uid in key['uids']:
+ if re.search(email, uid):
+ return key
+ raise LookupError("GnuPG public key for %s not found!" % email)
+
+ def encrypt(self, data, recipient, sign=None, always_trust=True,
+ passphrase=None, symmetric=False):
+ # TODO: devise a way so we don't need to "always trust".
+ return super(GPGWrapper, self).encrypt(data, recipient, sign=sign,
+ always_trust=always_trust,
+ passphrase=passphrase,
+ symmetric=symmetric)
+
+ def decrypt(self, data, always_trust=True, passphrase=None):
+ # TODO: devise a way so we don't need to "always trust".
+ return super(GPGWrapper, self).decrypt(data,
+ always_trust=always_trust,
+ passphrase=passphrase)
+
+ def send_keys(self, keyserver, *keyids):
+ """
+ Send keys to a keyserver
+ """
+ result = self.result_map['list'](self)
+ gnupg.logger.debug('send_keys: %r', keyids)
+ data = gnupg._make_binary_stream("", self.encoding)
+ args = ['--keyserver', keyserver, '--send-keys']
+ args.extend(keyids)
+ self._handle_io(args, data, result, binary=True)
+ gnupg.logger.debug('send_keys result: %r', result.__dict__)
+ data.close()
+ return result
diff --git a/src/leap/testing/pyqt.py b/src/leap/testing/pyqt.py
new file mode 100644
index 00000000..6edaf059
--- /dev/null
+++ b/src/leap/testing/pyqt.py
@@ -0,0 +1,52 @@
+from PyQt4 import QtCore
+
+_oldConnect = QtCore.QObject.connect
+_oldDisconnect = QtCore.QObject.disconnect
+_oldEmit = QtCore.QObject.emit
+
+
+def _wrapConnect(callableObject):
+ """
+ Returns a wrapped call to the old version of QtCore.QObject.connect
+ """
+ @staticmethod
+ def call(*args):
+ callableObject(*args)
+ _oldConnect(*args)
+ return call
+
+
+def _wrapDisconnect(callableObject):
+ """
+ Returns a wrapped call to the old version of QtCore.QObject.disconnect
+ """
+ @staticmethod
+ def call(*args):
+ callableObject(*args)
+ _oldDisconnect(*args)
+ return call
+
+
+def enableSignalDebugging(**kwargs):
+ """
+ Call this to enable Qt Signal debugging. This will trap all
+ connect, and disconnect calls.
+ """
+
+ f = lambda *args: None
+ connectCall = kwargs.get('connectCall', f)
+ disconnectCall = kwargs.get('disconnectCall', f)
+ emitCall = kwargs.get('emitCall', f)
+
+ def printIt(msg):
+ def call(*args):
+ print msg, args
+ return call
+ QtCore.QObject.connect = _wrapConnect(connectCall)
+ QtCore.QObject.disconnect = _wrapDisconnect(disconnectCall)
+
+ def new_emit(self, *args):
+ emitCall(self, *args)
+ _oldEmit(self, *args)
+
+ QtCore.QObject.emit = new_emit
diff --git a/src/leap/testing/qunittest.py b/src/leap/testing/qunittest.py
new file mode 100644
index 00000000..b89ccec3
--- /dev/null
+++ b/src/leap/testing/qunittest.py
@@ -0,0 +1,302 @@
+# -*- coding: utf-8 -*-
+
+# **qunittest** is an standard Python `unittest` enhancement for PyQt4,
+# allowing
+# you to test asynchronous code using standard synchronous testing facility.
+#
+# The source for `qunittest` is available on [GitHub][gh], and released under
+# the MIT license.
+#
+# Slightly modified by The Leap Project.
+
+### Prerequisites
+
+# Import unittest2 or unittest
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+# ... and some standard Python libraries
+import sys
+import functools
+import contextlib
+import re
+
+# ... and several PyQt classes
+from PyQt4.QtCore import QTimer
+from PyQt4.QtTest import QTest
+from PyQt4 import QtGui
+
+### The code
+
+
+# Override standard main method, by invoking it inside PyQt event loop
+
+def main(*args, **kwargs):
+ qapplication = QtGui.QApplication(sys.argv)
+
+ QTimer.singleShot(0, unittest.main(*args, **kwargs))
+ qapplication.exec_()
+
+"""
+This main substitute does not integrate with unittest.
+
+Note about mixing the event loop and unittests:
+
+Unittest will fail if we keep more than one reference to a QApplication.
+(pyqt expects to be and only one).
+So, for the things that need a QApplication to exist, do something like:
+
+ self.app = QApplication()
+ QtGui.qApp = self.app
+
+in the class setUp, and::
+
+ QtGui.qApp = None
+ self.app = None
+
+in the class tearDown.
+
+For some explanation about this, see
+ http://stuvel.eu/blog/127/multiple-instances-of-qapplication-in-one-process
+and
+ http://www.riverbankcomputing.com/pipermail/pyqt/2010-September/027705.html
+"""
+
+
+# Helper returning the name of a given signal
+
+def _signal_name(signal):
+ s = repr(signal)
+ name_re = "signal (\w+) of (\w+)"
+ match = re.search(name_re, s, re.I)
+ if not match:
+ return "??"
+ return "%s#%s" % (match.group(2), match.group(1))
+
+
+class _SignalConnector(object):
+ """ Encapsulates signal assertion testing """
+ def __init__(self, test, signal, callable_):
+ self.test = test
+ self.callable_ = callable_
+ self.called_with = None
+ self.emited = False
+ self.signal = signal
+ self._asserted = False
+
+ signal.connect(self.on_signal_emited)
+
+ # Store given parameters and mark signal as `emited`
+ def on_signal_emited(self, *args, **kwargs):
+ self.called_with = (args, kwargs)
+ self.emited = True
+
+ def assertEmission(self):
+ # Assert once wheter signal was emited or not
+ was_asserted = self._asserted
+ self._asserted = True
+
+ if not was_asserted:
+ if not self.emited:
+ self.test.fail(
+ "signal %s not emited" % (_signal_name(self.signal)))
+
+ # Call given callable is necessary
+ if self.callable_:
+ args, kwargs = self.called_with
+ self.callable_(*args, **kwargs)
+
+ def __enter__(self):
+ # Assert emission when context is entered
+ self.assertEmission()
+ return self.called_with
+
+ def __exit__(self, *_):
+ return False
+
+### Unit Testing
+
+# `qunittest` does not force much abould how test should look - it just adds
+# several helpers for asynchronous code testing.
+#
+# Common test case may look like this:
+#
+# import qunittest
+# from calculator import Calculator
+#
+# class TestCalculator(qunittest.TestCase):
+# def setUp(self):
+# self.calc = Calculator()
+#
+# def test_should_add_two_numbers_synchronously(self):
+# # given
+# a, b = 2, 3
+#
+# # when
+# r = self.calc.add(a, b)
+#
+# # then
+# self.assertEqual(5, r)
+#
+# def test_should_calculate_factorial_in_background(self):
+# # given
+#
+# # when
+# self.calc.factorial(20)
+#
+# # then
+# self.assertEmited(self.calc.done) with (args, kwargs):
+# self.assertEqual([2432902008176640000], args)
+#
+# if __name__ == "__main__":
+# main()
+#
+# Test can be run by typing:
+#
+# python test_calculator.py
+#
+# Automatic test discovery is not supported now, because testing PyQt needs
+# an instance of `QApplication` and its `exec_` method is blocking.
+#
+
+
+### TestCase class
+
+class TestCase(unittest.TestCase):
+ """
+ Extends standard `unittest.TestCase` with several PyQt4 testing features
+ useful for asynchronous testing.
+ """
+ def __init__(self, *args, **kwargs):
+ super(TestCase, self).__init__(*args, **kwargs)
+
+ self._clearSignalConnectors()
+ self._succeeded = False
+ self.addCleanup(self._clearSignalConnectors)
+ self.tearDown = self._decorateTearDown(self.tearDown)
+
+ ### Protected methods
+
+ def _clearSignalConnectors(self):
+ self._connectedSignals = []
+
+ def _decorateTearDown(self, tearDown):
+ @functools.wraps(tearDown)
+ def decorator():
+ self._ensureEmitedSignals()
+ return tearDown()
+ return decorator
+
+ def _ensureEmitedSignals(self):
+ """
+ Checks if signals were acually emited. Raises AssertionError if no.
+ """
+ # TODO: add information about line
+ for signal in self._connectedSignals:
+ signal.assertEmission()
+
+ ### Assertions
+
+ def assertEmited(self, signal, callable_=None, timeout=1):
+ """
+ Asserts if given `signal` was emited. Waits 1 second by default,
+ before asserts signal emission.
+
+ If `callable_` is given, it should be a function which takes two
+ arguments: `args` and `kwargs`. It will be called after blocking
+ operation or when assertion about signal emission is made and
+ signal was emited.
+
+ When timeout is not `False`, method call is blocking, and ends
+ after `timeout` seconds. After that time, it validates wether
+ signal was emited.
+
+ When timeout is `False`, method is non blocking, and test should wait
+ for signals afterwards. Otherwise, at the end of the test, all
+ signal emissions are checked if appeared.
+
+ Function returns context, which yields to list of parameters given
+ to signal. It can be useful for testing given parameters. Following
+ code:
+
+ with self.assertEmited(widget.signal) as (args, kwargs):
+ self.assertEqual(1, len(args))
+ self.assertEqual("Hello World!", args[0])
+
+ will wait 1 second and test for correct parameters, is signal was
+ emtied.
+
+ Note that code:
+
+ with self.assertEmited(widget.signal, timeout=False) as (a, k):
+ # Will not be invoked
+
+ will always fail since signal cannot be emited in the time of its
+ connection - code inside the context will not be invoked at all.
+ """
+
+ connector = _SignalConnector(self, signal, callable_)
+ self._connectedSignals.append(connector)
+ if timeout:
+ self.waitFor(timeout)
+ connector.assertEmission()
+
+ return connector
+
+ ### Helper methods
+
+ @contextlib.contextmanager
+ def invokeAfter(self, seconds, callable_=None):
+ """
+ Waits given amount of time and executes the context.
+
+ If `callable_` is given, executes it, instead of context.
+ """
+ self.waitFor(seconds)
+ if callable_:
+ callable_()
+ else:
+ yield
+
+ def waitFor(self, seconds):
+ """
+ Waits given amount of time.
+
+ self.widget.loadImage(url)
+ self.waitFor(seconds=10)
+ """
+ QTest.qWait(seconds * 1000)
+
+ def succeed(self, bool_=True):
+ """ Marks test as suceeded for next `failAfter()` invocation. """
+ self._succeeded = self._succeeded or bool_
+
+ def failAfter(self, seconds, message=None):
+ """
+ Waits given amount of time, and fails the test if `succeed(bool)`
+ is not called - in most common case, `succeed(bool)` should be called
+ asynchronously (in signal handler):
+
+ self.widget.signal.connect(lambda: self.succeed())
+ self.failAfter(1, "signal not emited?")
+
+ After invocation, test is no longer consider as succeeded.
+ """
+ self.waitFor(seconds)
+ if not self._succeeded:
+ self.fail(message)
+
+ self._succeeded = False
+
+### Credits
+#
+# * **Who is responsible:** [Dawid Fatyga][df]
+# * **Source:** [GitHub][gh]
+# * **Doc. generator:** [rocco][ro]
+#
+# [gh]: https://www.github.com/dejw/qunittest
+# [df]: https://github.com/dejw
+# [ro]: http://rtomayko.github.com/rocco/
+#
diff --git a/src/leap/util/__init__.py b/src/leap/util/__init__.py
index e69de29b..a70a9a8b 100644
--- a/src/leap/util/__init__.py
+++ b/src/leap/util/__init__.py
@@ -0,0 +1,9 @@
+import logging
+logger = logging.getLogger(__name__)
+
+try:
+ import pygeoip
+ HAS_GEOIP = True
+except ImportError:
+ logger.debug('PyGeoIP not found. Disabled Geo support.')
+ HAS_GEOIP = False
diff --git a/src/leap/util/certs.py b/src/leap/util/certs.py
new file mode 100644
index 00000000..f0f790e9
--- /dev/null
+++ b/src/leap/util/certs.py
@@ -0,0 +1,18 @@
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def get_mac_cabundle():
+ # hackaround bundle error
+ # XXX this needs a better fix!
+ f = os.path.split(__file__)[0]
+ sep = os.path.sep
+ f_ = sep.join(f.split(sep)[:-2])
+ verify = os.path.join(f_, 'cacert.pem')
+ #logger.error('VERIFY PATH = %s' % verify)
+ exists = os.path.isfile(verify)
+ #logger.error('do exist? %s', exists)
+ if exists:
+ return verify
diff --git a/src/leap/util/dicts.py b/src/leap/util/dicts.py
new file mode 100644
index 00000000..001ca96b
--- /dev/null
+++ b/src/leap/util/dicts.py
@@ -0,0 +1,268 @@
+# Backport of OrderedDict() class that runs
+# on Python 2.4, 2.5, 2.6, 2.7 and pypy.
+# Passes Python2.7's test suite and incorporates all the latest updates.
+
+try:
+ from thread import get_ident as _get_ident
+except ImportError:
+ from dummy_thread import get_ident as _get_ident
+
+try:
+ from _abcoll import KeysView, ValuesView, ItemsView
+except ImportError:
+ pass
+
+
+class OrderedDict(dict):
+ 'Dictionary that remembers insertion order'
+ # An inherited dict maps keys to values.
+ # The inherited dict provides __getitem__, __len__, __contains__, and get.
+ # The remaining methods are order-aware.
+ # Big-O running times for all methods are the same as for regular
+ # dictionaries.
+
+ # The internal self.__map dictionary maps keys to links in a doubly
+ # linked list.
+ # The circular doubly linked list starts and ends with a sentinel element.
+ # The sentinel element never gets deleted (this simplifies the algorithm).
+ # Each link is stored as a list of length three: [PREV, NEXT, KEY].
+
+ def __init__(self, *args, **kwds):
+ '''Initialize an ordered dictionary. Signature is the same as for
+ regular dictionaries, but keyword arguments are not recommended
+ because their insertion order is arbitrary.
+
+ '''
+ if len(args) > 1:
+ raise TypeError('expected at most 1 arguments, got %d' % len(args))
+ try:
+ self.__root
+ except AttributeError:
+ self.__root = root = [] # sentinel node
+ root[:] = [root, root, None]
+ self.__map = {}
+ self.__update(*args, **kwds)
+
+ def __setitem__(self, key, value, dict_setitem=dict.__setitem__):
+ 'od.__setitem__(i, y) <==> od[i]=y'
+ # Setting a new item creates a new link which goes at the end
+ # of the linked list, and the inherited dictionary is updated
+ # with the new key/value pair.
+ if key not in self:
+ root = self.__root
+ last = root[0]
+ last[1] = root[0] = self.__map[key] = [last, root, key]
+ dict_setitem(self, key, value)
+
+ def __delitem__(self, key, dict_delitem=dict.__delitem__):
+ 'od.__delitem__(y) <==> del od[y]'
+ # Deleting an existing item uses self.__map to find the link which is
+ # then removed by updating the links in the predecessor and successor
+ # nodes.
+ dict_delitem(self, key)
+ link_prev, link_next, key = self.__map.pop(key)
+ link_prev[1] = link_next
+ link_next[0] = link_prev
+
+ def __iter__(self):
+ 'od.__iter__() <==> iter(od)'
+ root = self.__root
+ curr = root[1]
+ while curr is not root:
+ yield curr[2]
+ curr = curr[1]
+
+ def __reversed__(self):
+ 'od.__reversed__() <==> reversed(od)'
+ root = self.__root
+ curr = root[0]
+ while curr is not root:
+ yield curr[2]
+ curr = curr[0]
+
+ def clear(self):
+ 'od.clear() -> None. Remove all items from od.'
+ try:
+ for node in self.__map.itervalues():
+ del node[:]
+ root = self.__root
+ root[:] = [root, root, None]
+ self.__map.clear()
+ except AttributeError:
+ pass
+ dict.clear(self)
+
+ def popitem(self, last=True):
+ '''od.popitem() -> (k, v), return and remove a (key, value) pair.
+ Pairs are returned in LIFO order if last is true or FIFO order if
+ false.
+ '''
+ if not self:
+ raise KeyError('dictionary is empty')
+ root = self.__root
+ if last:
+ link = root[0]
+ link_prev = link[0]
+ link_prev[1] = root
+ root[0] = link_prev
+ else:
+ link = root[1]
+ link_next = link[1]
+ root[1] = link_next
+ link_next[0] = root
+ key = link[2]
+ del self.__map[key]
+ value = dict.pop(self, key)
+ return key, value
+
+ # -- the following methods do not depend on the internal structure --
+
+ def keys(self):
+ 'od.keys() -> list of keys in od'
+ return list(self)
+
+ def values(self):
+ 'od.values() -> list of values in od'
+ return [self[key] for key in self]
+
+ def items(self):
+ 'od.items() -> list of (key, value) pairs in od'
+ return [(key, self[key]) for key in self]
+
+ def iterkeys(self):
+ 'od.iterkeys() -> an iterator over the keys in od'
+ return iter(self)
+
+ def itervalues(self):
+ 'od.itervalues -> an iterator over the values in od'
+ for k in self:
+ yield self[k]
+
+ def iteritems(self):
+ 'od.iteritems -> an iterator over the (key, value) items in od'
+ for k in self:
+ yield (k, self[k])
+
+ def update(*args, **kwds):
+ '''od.update(E, **F) -> None. Update od from dict/iterable E and F.
+
+ If E is a dict instance, does: for k in E: od[k] = E[k]
+ If E has a .keys() method, does: for k in E.keys():
+ od[k] = E[k]
+ Or if E is an iterable of items, does: for k, v in E: od[k] = v
+ In either case, this is followed by: for k, v in F.items():
+ od[k] = v
+ '''
+
+ if len(args) > 2:
+ raise TypeError('update() takes at most 2 positional '
+ 'arguments (%d given)' % (len(args),))
+ elif not args:
+ raise TypeError('update() takes at least 1 argument (0 given)')
+ self = args[0]
+ # Make progressively weaker assumptions about "other"
+ other = ()
+ if len(args) == 2:
+ other = args[1]
+ if isinstance(other, dict):
+ for key in other:
+ self[key] = other[key]
+ elif hasattr(other, 'keys'):
+ for key in other.keys():
+ self[key] = other[key]
+ else:
+ for key, value in other:
+ self[key] = value
+ for key, value in kwds.items():
+ self[key] = value
+
+ __update = update # let subclasses override update
+ # without breaking __init__
+
+ __marker = object()
+
+ def pop(self, key, default=__marker):
+ '''od.pop(k[,d]) -> v
+ remove specified key and return the corresponding value.
+ If key is not found, d is returned if given,
+ otherwise KeyError is raised.
+
+ '''
+ if key in self:
+ result = self[key]
+ del self[key]
+ return result
+ if default is self.__marker:
+ raise KeyError(key)
+ return default
+
+ def setdefault(self, key, default=None):
+ 'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od'
+ if key in self:
+ return self[key]
+ self[key] = default
+ return default
+
+ def __repr__(self, _repr_running={}):
+ 'od.__repr__() <==> repr(od)'
+ call_key = id(self), _get_ident()
+ if call_key in _repr_running:
+ return '...'
+ _repr_running[call_key] = 1
+ try:
+ if not self:
+ return '%s()' % (self.__class__.__name__,)
+ return '%s(%r)' % (self.__class__.__name__, self.items())
+ finally:
+ del _repr_running[call_key]
+
+ def __reduce__(self):
+ 'Return state information for pickling'
+ items = [[k, self[k]] for k in self]
+ inst_dict = vars(self).copy()
+ for k in vars(OrderedDict()):
+ inst_dict.pop(k, None)
+ if inst_dict:
+ return (self.__class__, (items,), inst_dict)
+ return self.__class__, (items,)
+
+ def copy(self):
+ 'od.copy() -> a shallow copy of od'
+ return self.__class__(self)
+
+ @classmethod
+ def fromkeys(cls, iterable, value=None):
+ '''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S
+ and values equal to v (which defaults to None).
+
+ '''
+ d = cls()
+ for key in iterable:
+ d[key] = value
+ return d
+
+ def __eq__(self, other):
+ '''od.__eq__(y) <==> od==y.
+ Comparison to another OD is order-sensitive
+ while comparison to a regular mapping is order-insensitive.
+ '''
+ if isinstance(other, OrderedDict):
+ return len(self) == len(other) and self.items() == other.items()
+ return dict.__eq__(self, other)
+
+ def __ne__(self, other):
+ return not self == other
+
+ # -- the following methods are only used in Python 2.7 --
+
+ def viewkeys(self):
+ "od.viewkeys() -> a set-like object providing a view on od's keys"
+ return KeysView(self)
+
+ def viewvalues(self):
+ "od.viewvalues() -> an object providing a view on od's values"
+ return ValuesView(self)
+
+ def viewitems(self):
+ "od.viewitems() -> a set-like object providing a view on od's items"
+ return ItemsView(self)
diff --git a/src/leap/util/fileutil.py b/src/leap/util/fileutil.py
index aef4cfe0..820ffe46 100644
--- a/src/leap/util/fileutil.py
+++ b/src/leap/util/fileutil.py
@@ -93,6 +93,11 @@ def mkdir_p(path):
raise
+def mkdir_f(path):
+ folder, fname = os.path.split(path)
+ mkdir_p(folder)
+
+
def check_and_fix_urw_only(_file):
"""
test for 600 mode and try
diff --git a/src/leap/util/geo.py b/src/leap/util/geo.py
new file mode 100644
index 00000000..54b29596
--- /dev/null
+++ b/src/leap/util/geo.py
@@ -0,0 +1,32 @@
+"""
+experimental geo support.
+not yet a feature.
+in debian, we rely on the (optional) geoip-database
+"""
+import os
+import platform
+
+from leap.util import HAS_GEOIP
+
+GEOIP = None
+
+if HAS_GEOIP:
+ import pygeoip # we know we can :)
+
+ GEOIP_PATH = None
+
+ if platform.system() == "Linux":
+ PATH = "/usr/share/GeoIP/GeoIP.dat"
+ if os.path.isfile(PATH):
+ GEOIP_PATH = PATH
+ GEOIP = pygeoip.GeoIP(GEOIP_PATH, pygeoip.MEMORY_CACHE)
+
+
+def get_country_name(ip):
+ if not GEOIP:
+ return
+ try:
+ country = GEOIP.country_name_by_addr(ip)
+ except pygeoip.GeoIPError:
+ country = None
+ return country if country else "-"
diff --git a/src/leap/util/leap_argparse.py b/src/leap/util/leap_argparse.py
index 2f996a31..5b0775cc 100644
--- a/src/leap/util/leap_argparse.py
+++ b/src/leap/util/leap_argparse.py
@@ -37,5 +37,5 @@ Launches main LEAP Client""", epilog=epilog)
def init_leapc_args():
parser = build_parser()
- opts = parser.parse_args()
+ opts, unknown = parser.parse_known_args()
return parser, opts
diff --git a/src/leap/util/misc.py b/src/leap/util/misc.py
new file mode 100644
index 00000000..d869a1ba
--- /dev/null
+++ b/src/leap/util/misc.py
@@ -0,0 +1,37 @@
+"""
+misc utils
+"""
+import psutil
+
+from leap.base.constants import OPENVPN_BIN
+
+
+class ImproperlyConfigured(Exception):
+ """
+ """
+
+
+def null_check(value, value_name):
+ try:
+ assert value is not None
+ except AssertionError:
+ raise ImproperlyConfigured(
+ "%s parameter cannot be None" % value_name)
+
+
+def get_openvpn_pids():
+ # binary name might change
+
+ openvpn_pids = []
+ for p in psutil.process_iter():
+ try:
+ # XXX Not exact!
+ # Will give false positives.
+ # we should check that cmdline BEGINS
+ # with openvpn or with our wrapper
+ # (pkexec / osascript / whatever)
+ if OPENVPN_BIN in ' '.join(p.cmdline):
+ openvpn_pids.append(p.pid)
+ except psutil.error.AccessDenied:
+ pass
+ return openvpn_pids
diff --git a/src/leap/util/tests/test_translations.py b/src/leap/util/tests/test_translations.py
new file mode 100644
index 00000000..794daeba
--- /dev/null
+++ b/src/leap/util/tests/test_translations.py
@@ -0,0 +1,22 @@
+import unittest
+
+from leap.util import translations
+
+
+class TrasnlationsTestCase(unittest.TestCase):
+ """
+ tests for translation functions and classes
+ """
+
+ def setUp(self):
+ self.trClass = translations.LEAPTranslatable
+
+ def test_trasnlatable(self):
+ tr = self.trClass({"en": "house", "es": "casa"})
+ eq = self.assertEqual
+ eq(tr.tr(to="es"), "casa")
+ eq(tr.tr(to="en"), "house")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/util/translations.py b/src/leap/util/translations.py
new file mode 100644
index 00000000..f55c8fba
--- /dev/null
+++ b/src/leap/util/translations.py
@@ -0,0 +1,82 @@
+import inspect
+import logging
+
+from PyQt4.QtCore import QCoreApplication
+from PyQt4.QtCore import QLocale
+
+logger = logging.getLogger(__name__)
+
+"""
+here I could not do all that I wanted.
+the context is not getting passed to the xml file.
+Looks like pylupdate4 is somehow a hack that does not
+parse too well the python ast.
+I guess we could generate the xml for ourselves as a last recourse.
+"""
+
+# XXX BIG NOTE:
+# RESIST the temptation to get the translate function
+# more compact, or have the Context argument passed as a variable
+# Its name HAS to be explicit due to how the pylupdate parser
+# works.
+
+
+qtTranslate = QCoreApplication.translate
+
+
+def translate(*args, **kwargs):
+ """
+ our magic function.
+ translate(Context, text, comment)
+ """
+ if len(args) == 1:
+ obj = args[0]
+ if isinstance(obj, LEAPTranslatable) and hasattr(obj, 'tr'):
+ return obj.tr()
+
+ klsname = None
+ try:
+ # get class value from instance
+ # using live object inspection
+ prev_frame = inspect.stack()[1][0]
+ locals_ = inspect.getargvalues(prev_frame).locals
+ self = locals_.get('self')
+ if self:
+
+ # Trying to get the class name
+ # but this is useless, the parser
+ # has already got the context.
+ klsname = self.__class__.__name__
+ #print 'KLSNAME -- ', klsname
+ except:
+ logger.error('error getting stack frame')
+
+ if klsname and len(args) == 1:
+ nargs = (klsname,) + args
+ return qtTranslate(*nargs)
+
+ else:
+ return qtTranslate(*args)
+
+
+class LEAPTranslatable(dict):
+ """
+ An extended dict that implements a .tr method
+ so it can be translated on the fly by our
+ magic translate method
+ """
+
+ try:
+ locale = str(QLocale.system().name()).split('_')[0]
+ except:
+ logger.warning("could not get system locale!")
+ print "could not get system locale!"
+ locale = "en"
+
+ def tr(self, to=None):
+ if not to:
+ to = self.locale
+ _tr = self.get(to, None)
+ if not _tr:
+ _tr = self.get("en", None)
+ return _tr
diff --git a/src/leap/util/web.py b/src/leap/util/web.py
new file mode 100644
index 00000000..15de0561
--- /dev/null
+++ b/src/leap/util/web.py
@@ -0,0 +1,40 @@
+"""
+web related utilities
+"""
+
+
+class UsageError(Exception):
+ """ """
+
+
+def get_https_domain_and_port(full_domain):
+ """
+ returns a tuple with domain and port
+ from a full_domain string that can
+ contain a colon
+ """
+ full_domain = unicode(full_domain)
+ if full_domain is None:
+ return None, None
+
+ https_sch = "https://"
+ http_sch = "http://"
+
+ if full_domain.startswith(https_sch):
+ full_domain = full_domain.lstrip(https_sch)
+ elif full_domain.startswith(http_sch):
+ raise UsageError(
+ "cannot be called with a domain "
+ "that begins with 'http://'")
+
+ domain_split = full_domain.split(':')
+ _len = len(domain_split)
+ if _len == 1:
+ domain, port = full_domain, 443
+ elif _len == 2:
+ domain, port = domain_split
+ else:
+ raise UsageError(
+ "must be called with one only parameter"
+ "in the form domain[:port]")
+ return domain, port