summaryrefslogtreecommitdiff
path: root/src/leap/base
diff options
context:
space:
mode:
authorkali <kali@leap.se>2013-02-15 09:31:51 +0900
committerkali <kali@leap.se>2013-02-15 09:31:51 +0900
commit9cea9c8a34343f8792d65b96f93ae22bd8685878 (patch)
tree9f512367b1d47ced5614702a00f3ff0a8fe746d7 /src/leap/base
parent7159734ec6c0b76fc7f3737134cd22fdaaaa7d58 (diff)
parent1032e07a50c8bb265ff9bd31b3bb00e83ddb451e (diff)
Merge branch 'release/v0.2.0'
Conflicts: README.txt
Diffstat (limited to 'src/leap/base')
-rw-r--r--src/leap/base/__init__.py0
-rw-r--r--src/leap/base/auth.py355
-rw-r--r--src/leap/base/authentication.py11
-rw-r--r--src/leap/base/checks.py213
-rw-r--r--src/leap/base/config.py348
-rw-r--r--src/leap/base/connection.py115
-rw-r--r--src/leap/base/constants.py42
-rw-r--r--src/leap/base/exceptions.py97
-rw-r--r--src/leap/base/jsonschema.py791
-rw-r--r--src/leap/base/network.py107
-rw-r--r--src/leap/base/pluggableconfig.py462
-rw-r--r--src/leap/base/providers.py29
-rw-r--r--src/leap/base/specs.py62
-rw-r--r--src/leap/base/tests/__init__.py0
-rw-r--r--src/leap/base/tests/test_auth.py58
-rw-r--r--src/leap/base/tests/test_checks.py177
-rw-r--r--src/leap/base/tests/test_config.py247
-rw-r--r--src/leap/base/tests/test_providers.py148
-rw-r--r--src/leap/base/tests/test_validation.py93
19 files changed, 3355 insertions, 0 deletions
diff --git a/src/leap/base/__init__.py b/src/leap/base/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/base/__init__.py
diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py
new file mode 100644
index 00000000..c2d3f424
--- /dev/null
+++ b/src/leap/base/auth.py
@@ -0,0 +1,355 @@
+import binascii
+import json
+import logging
+#import urlparse
+
+import requests
+import srp
+
+from PyQt4 import QtCore
+
+from leap.base import constants as baseconstants
+from leap.crypto import leapkeyring
+from leap.util.misc import null_check
+from leap.util.web import get_https_domain_and_port
+
+logger = logging.getLogger(__name__)
+
+SIGNUP_TIMEOUT = getattr(baseconstants, 'SIGNUP_TIMEOUT', 5)
+
+"""
+Registration and authentication classes for the
+SRP auth mechanism used in the leap platform.
+
+We're using the srp library which uses a c-based implementation
+of the protocol if the c extension is available, and a python-based
+one if not.
+"""
+
+
+class SRPAuthenticationError(Exception):
+ """
+ exception raised
+ for authentication errors
+ """
+
+
+safe_unhexlify = lambda x: binascii.unhexlify(x) \
+ if (len(x) % 2 == 0) else binascii.unhexlify('0' + x)
+
+
+class LeapSRPRegister(object):
+
+ def __init__(self,
+ schema="https",
+ provider=None,
+ verify=True,
+ register_path="1/users",
+ method="POST",
+ fetcher=requests,
+ srp=srp,
+ hashfun=srp.SHA256,
+ ng_constant=srp.NG_1024):
+
+ null_check(provider, "provider")
+
+ self.schema = schema
+
+ domain, port = get_https_domain_and_port(provider)
+ self.provider = domain
+ self.port = port
+
+ self.verify = verify
+ self.register_path = register_path
+ self.method = method
+ self.fetcher = fetcher
+ self.srp = srp
+ self.HASHFUN = hashfun
+ self.NG = ng_constant
+
+ self.init_session()
+
+ def init_session(self):
+ self.session = self.fetcher.session()
+
+ def get_registration_uri(self):
+ # XXX assert is https!
+ # use urlparse
+ if self.port:
+ uri = "%s://%s:%s/%s" % (
+ self.schema,
+ self.provider,
+ self.port,
+ self.register_path)
+ else:
+ uri = "%s://%s/%s" % (
+ self.schema,
+ self.provider,
+ self.register_path)
+
+ return uri
+
+ def register_user(self, username, password, keep=False):
+ """
+ @rtype: tuple
+ @rparam: (ok, request)
+ """
+ salt, vkey = self.srp.create_salted_verification_key(
+ username,
+ password,
+ self.HASHFUN,
+ self.NG)
+
+ user_data = {
+ 'user[login]': username,
+ 'user[password_verifier]': binascii.hexlify(vkey),
+ 'user[password_salt]': binascii.hexlify(salt)}
+
+ uri = self.get_registration_uri()
+ logger.debug('post to uri: %s' % uri)
+
+ # XXX get self.method
+ req = self.session.post(
+ uri, data=user_data,
+ timeout=SIGNUP_TIMEOUT,
+ verify=self.verify)
+ # we catch it in the form
+ #req.raise_for_status()
+ return (req.ok, req)
+
+
+class SRPAuth(requests.auth.AuthBase):
+
+ def __init__(self, username, password, server=None, verify=None):
+ # sanity check
+ null_check(server, 'server')
+ self.username = username
+ self.password = password
+ self.server = server
+ self.verify = verify
+
+ logger.debug('SRPAuth. verify=%s' % verify)
+ logger.debug('server: %s. username=%s' % (server, username))
+
+ self.init_data = None
+ self.session = requests.session()
+
+ self.init_srp()
+
+ def init_srp(self):
+ usr = srp.User(
+ self.username,
+ self.password,
+ srp.SHA256,
+ srp.NG_1024)
+ uname, A = usr.start_authentication()
+
+ self.srp_usr = usr
+ self.A = A
+
+ def get_auth_data(self):
+ return {
+ 'login': self.username,
+ 'A': binascii.hexlify(self.A)
+ }
+
+ def get_init_data(self):
+ try:
+ init_session = self.session.post(
+ self.server + '/1/sessions/',
+ data=self.get_auth_data(),
+ verify=self.verify)
+ except requests.exceptions.ConnectionError:
+ raise SRPAuthenticationError(
+ "No connection made (salt).")
+ except:
+ raise SRPAuthenticationError(
+ "Unknown error (salt).")
+ if init_session.status_code not in (200, ):
+ raise SRPAuthenticationError(
+ "No valid response (salt).")
+
+ self.init_data = init_session.json
+ return self.init_data
+
+ def get_server_proof_data(self):
+ try:
+ auth_result = self.session.put(
+ #self.server + '/1/sessions.json/' + self.username,
+ self.server + '/1/sessions/' + self.username,
+ data={'client_auth': binascii.hexlify(self.M)},
+ verify=self.verify)
+ except requests.exceptions.ConnectionError:
+ raise SRPAuthenticationError(
+ "No connection made (HAMK).")
+
+ if auth_result.status_code not in (200, ):
+ raise SRPAuthenticationError(
+ "No valid response (HAMK).")
+
+ self.auth_data = auth_result.json
+ return self.auth_data
+
+ def authenticate(self):
+ logger.debug('start authentication...')
+
+ init_data = self.get_init_data()
+ salt = init_data.get('salt', None)
+ B = init_data.get('B', None)
+
+ # XXX refactor this function
+ # move checks and un-hex
+ # to routines
+
+ if not salt or not B:
+ raise SRPAuthenticationError(
+ "Server did not send initial data.")
+
+ try:
+ unhex_salt = safe_unhexlify(salt)
+ except TypeError:
+ raise SRPAuthenticationError(
+ "Bad data from server (salt)")
+ try:
+ unhex_B = safe_unhexlify(B)
+ except TypeError:
+ raise SRPAuthenticationError(
+ "Bad data from server (B)")
+
+ self.M = self.srp_usr.process_challenge(
+ unhex_salt,
+ unhex_B
+ )
+
+ proof_data = self.get_server_proof_data()
+
+ HAMK = proof_data.get("M2", None)
+ if not HAMK:
+ errors = proof_data.get('errors', None)
+ if errors:
+ logger.error(errors)
+ raise SRPAuthenticationError("Server did not send HAMK.")
+
+ try:
+ unhex_HAMK = safe_unhexlify(HAMK)
+ except TypeError:
+ raise SRPAuthenticationError(
+ "Bad data from server (HAMK)")
+
+ self.srp_usr.verify_session(
+ unhex_HAMK)
+
+ try:
+ assert self.srp_usr.authenticated()
+ logger.debug('user is authenticated!')
+ except (AssertionError):
+ raise SRPAuthenticationError(
+ "Auth verification failed.")
+
+ def __call__(self, req):
+ self.authenticate()
+ req.cookies = self.session.cookies
+ return req
+
+
+def srpauth_protected(user=None, passwd=None, server=None, verify=True):
+ """
+ decorator factory that accepts
+ user and password keyword arguments
+ and add those to the decorated request
+ """
+ def srpauth(fn):
+ def wrapper(*args, **kwargs):
+ if user and passwd:
+ auth = SRPAuth(user, passwd, server, verify)
+ kwargs['auth'] = auth
+ kwargs['verify'] = verify
+ if not args:
+ logger.warning('attempting to get from empty uri!')
+ return fn(*args, **kwargs)
+ return wrapper
+ return srpauth
+
+
+def get_leap_credentials():
+ settings = QtCore.QSettings()
+ full_username = settings.value('username')
+ username, domain = full_username.split('@')
+ seed = settings.value('%s_seed' % domain, None)
+ password = leapkeyring.leap_get_password(full_username, seed=seed)
+ return (username, password)
+
+
+# XXX TODO
+# Pass verify as single argument,
+# in srpauth_protected style
+
+def magick_srpauth(fn):
+ """
+ decorator that gets user and password
+ from the config file and adds those to
+ the decorated request
+ """
+ logger.debug('magick srp auth decorator called')
+
+ def wrapper(*args, **kwargs):
+ #uri = args[0]
+ # XXX Ugh!
+ # Problem with this approach.
+ # This won't work when we're using
+ # api.foo.bar
+ # Unless we keep a table with the
+ # equivalencies...
+ user, passwd = get_leap_credentials()
+
+ # XXX pass verify and server too
+ # (pop)
+ auth = SRPAuth(user, passwd)
+ kwargs['auth'] = auth
+ return fn(*args, **kwargs)
+ return wrapper
+
+
+if __name__ == "__main__":
+ """
+ To test against test_provider (twisted version)
+ Register an user: (will be valid during the session)
+ >>> python auth.py add test password
+
+ Test login with that user:
+ >>> python auth.py login test password
+ """
+
+ import sys
+
+ if len(sys.argv) not in (4, 5):
+ print 'Usage: auth <add|login> <user> <pass> [server]'
+ sys.exit(0)
+
+ action = sys.argv[1]
+ user = sys.argv[2]
+ passwd = sys.argv[3]
+
+ if len(sys.argv) == 5:
+ SERVER = sys.argv[4]
+ else:
+ SERVER = "https://localhost:8443"
+
+ if action == "login":
+
+ @srpauth_protected(
+ user=user, passwd=passwd, server=SERVER, verify=False)
+ def test_srp_protected_get(*args, **kwargs):
+ req = requests.get(*args, **kwargs)
+ req.raise_for_status
+ return req
+
+ #req = test_srp_protected_get('https://localhost:8443/1/cert')
+ req = test_srp_protected_get('%s/1/cert' % SERVER)
+ #print 'cert :', req.content[:200] + "..."
+ print req.content
+ sys.exit(0)
+
+ if action == "add":
+ auth = LeapSRPRegister(provider=SERVER, verify=False)
+ auth.register_user(user, passwd)
diff --git a/src/leap/base/authentication.py b/src/leap/base/authentication.py
new file mode 100644
index 00000000..09ff1d07
--- /dev/null
+++ b/src/leap/base/authentication.py
@@ -0,0 +1,11 @@
+"""
+Authentication Base Class
+"""
+
+
+class Authentication(object):
+ """
+ I have no idea how Authentication (certs,?)
+ will be done, but stub it here.
+ """
+ pass
diff --git a/src/leap/base/checks.py b/src/leap/base/checks.py
new file mode 100644
index 00000000..0bf44f59
--- /dev/null
+++ b/src/leap/base/checks.py
@@ -0,0 +1,213 @@
+# -*- coding: utf-8 -*-
+import logging
+import platform
+import re
+import socket
+
+import netifaces
+import sh
+
+from leap.base import constants
+from leap.base import exceptions
+
+logger = logging.getLogger(name=__name__)
+_platform = platform.system()
+
+#EVENTS OF NOTE
+EVENT_CONNECT_REFUSED = "[ECONNREFUSED]: Connection refused (code=111)"
+
+ICMP_TARGET = "8.8.8.8"
+
+
+class LeapNetworkChecker(object):
+ """
+ all network related checks
+ """
+ def __init__(self, *args, **kwargs):
+ provider_gw = kwargs.pop('provider_gw', None)
+ self.provider_gateway = provider_gw
+
+ def run_all(self, checker=None):
+ if not checker:
+ checker = self
+ #self.error = None # ?
+
+ # for MVS
+ checker.check_tunnel_default_interface()
+ checker.check_internet_connection()
+ checker.is_internet_up()
+
+ if self.provider_gateway:
+ checker.ping_gateway(self.provider_gateway)
+
+ checker.parse_log_and_react([], ())
+
+ def check_internet_connection(self):
+ if _platform == "Linux":
+ try:
+ output = sh.ping("-c", "5", "-w", "5", ICMP_TARGET)
+ # XXX should redirect this to netcheck logger.
+ # and don't clutter main log.
+ logger.debug('Network appears to be up.')
+ except sh.ErrorReturnCode_1 as e:
+ packet_loss = re.findall("\d+% packet loss", e.message)[0]
+ logger.debug("Unidentified Connection Error: " + packet_loss)
+ if not self.is_internet_up():
+ error = "No valid internet connection found."
+ else:
+ error = "Provider server appears to be down."
+
+ logger.error(error)
+ raise exceptions.NoInternetConnection(error)
+
+ else:
+ raise NotImplementedError
+
+ def is_internet_up(self):
+ iface, gateway = self.get_default_interface_gateway()
+ try:
+ self.ping_gateway(self.provider_gateway)
+ except exceptions.NoConnectionToGateway:
+ return False
+ return True
+
+ def _get_route_table_linux(self):
+ # do not use context manager, tests pass a StringIO
+ f = open("/proc/net/route")
+ route_table = f.readlines()
+ f.close()
+ #toss out header
+ route_table.pop(0)
+ if not route_table:
+ raise exceptions.NoDefaultInterfaceFoundError
+ return route_table
+
+ def _get_def_iface_osx(self):
+ default_iface = None
+ #gateway = None
+ routes = list(sh.route('-n', 'get', ICMP_TARGET, _iter=True))
+ iface = filter(lambda l: "interface" in l, routes)
+ if not iface:
+ return None, None
+ def_ifacel = re.findall('\w+\d', iface[0])
+ default_iface = def_ifacel[0] if def_ifacel else None
+ if not default_iface:
+ return None, None
+ _gw = filter(lambda l: "gateway" in l, routes)
+ gw = re.findall('\d+\.\d+\.\d+\.\d+', _gw[0])[0]
+ return default_iface, gw
+
+ def _get_tunnel_iface_linux(self):
+ # XXX review.
+ # valid also when local router has a default entry?
+ route_table = self._get_route_table_linux()
+ line = route_table.pop(0)
+ iface, destination = line.split('\t')[0:2]
+ if not destination == '00000000' or not iface == 'tun0':
+ raise exceptions.TunnelNotDefaultRouteError()
+ return True
+
+ def check_tunnel_default_interface(self):
+ """
+ Raises an TunnelNotDefaultRouteError
+ if tun0 is not the chosen default route
+ (including when no routes are present)
+ """
+ #logger.debug('checking tunnel default interface...')
+
+ if _platform == "Linux":
+ valid = self._get_tunnel_iface_linux()
+ return valid
+ elif _platform == "Darwin":
+ default_iface, gw = self._get_def_iface_osx()
+ #logger.debug('iface: %s', default_iface)
+ if default_iface != "tun0":
+ logger.debug('tunnel not default route! gw: %s', default_iface)
+ # XXX should catch this and act accordingly...
+ # but rather, this test should only be launched
+ # when we have successfully completed a connection
+ # ... TRIGGER: Connection stablished (or whatever it is)
+ # in the logs
+ raise exceptions.TunnelNotDefaultRouteError
+ else:
+ #logger.debug('PLATFORM !!! %s', _platform)
+ raise NotImplementedError
+
+ def _get_def_iface_linux(self):
+ default_iface = None
+ gateway = None
+
+ route_table = self._get_route_table_linux()
+ while route_table:
+ line = route_table.pop(0)
+ iface, destination, gateway = line.split('\t')[0:3]
+ if destination == '00000000':
+ default_iface = iface
+ break
+ return default_iface, gateway
+
+ def get_default_interface_gateway(self):
+ """
+ gets the interface we are going thru.
+ (this should be merged with check tunnel default interface,
+ imo...)
+ """
+ if _platform == "Linux":
+ default_iface, gw = self._get_def_iface_linux()
+ elif _platform == "Darwin":
+ default_iface, gw = self._get_def_iface_osx()
+ else:
+ raise NotImplementedError
+
+ if not default_iface:
+ raise exceptions.NoDefaultInterfaceFoundError
+
+ if default_iface not in netifaces.interfaces():
+ raise exceptions.InterfaceNotFoundError
+ logger.debug('-- default iface %s', default_iface)
+ return default_iface, gw
+
+ def ping_gateway(self, gateway):
+ # TODO: Discuss how much packet loss (%) is acceptable.
+
+ # XXX -- validate gateway
+ # -- is it a valid ip? (there's something in util)
+ # -- is it a domain?
+ # -- can we resolve? -- raise NoDNSError if not.
+
+ # XXX -- sh.ping implemtation needs review!
+ try:
+ output = sh.ping("-c", "10", gateway).stdout
+ except sh.ErrorReturnCode_1 as e:
+ output = e.message
+ finally:
+ packet_loss = int(re.findall("(\d+)% packet loss", output)[0])
+
+ logger.debug('packet loss %s%%' % packet_loss)
+ if packet_loss > constants.MAX_ICMP_PACKET_LOSS:
+ raise exceptions.NoConnectionToGateway
+
+ def check_name_resolution(self, domain_name):
+ try:
+ socket.gethostbyname(domain_name)
+ return True
+ except socket.gaierror:
+ raise exceptions.CannotResolveDomainError
+
+ def parse_log_and_react(self, log, error_matrix=None):
+ """
+ compares the recent openvpn status log to
+ strings passed in and executes the callbacks passed in.
+ @param log: openvpn log
+ @type log: list of strings
+ @param error_matrix: tuples of strings and tuples of callbacks
+ @type error_matrix: tuples strings and call backs
+ """
+ for line in log:
+ # we could compile a regex here to save some cycles up -- kali
+ for each in error_matrix:
+ error, callbacks = each
+ if error in line:
+ for cb in callbacks:
+ if callable(cb):
+ cb()
diff --git a/src/leap/base/config.py b/src/leap/base/config.py
new file mode 100644
index 00000000..85bb3d66
--- /dev/null
+++ b/src/leap/base/config.py
@@ -0,0 +1,348 @@
+"""
+Configuration Base Class
+"""
+import grp
+import json
+import logging
+import re
+import socket
+import time
+import os
+
+logger = logging.getLogger(name=__name__)
+
+from dateutil import parser as dateparser
+from xdg import BaseDirectory
+import requests
+
+from leap.base import exceptions
+from leap.base import constants
+from leap.base.pluggableconfig import PluggableConfig
+from leap.util.fileutil import (mkdir_p)
+
+# move to base!
+from leap.eip import exceptions as eipexceptions
+
+
+class BaseLeapConfig(object):
+ slug = None
+
+ # XXX we have to enforce that every derived class
+ # has a slug (via interface)
+ # get property getter that raises NI..
+
+ def save(self):
+ raise NotImplementedError("abstract base class")
+
+ def load(self):
+ raise NotImplementedError("abstract base class")
+
+ def get_config(self, *kwargs):
+ raise NotImplementedError("abstract base class")
+
+ @property
+ def config(self):
+ return self.get_config()
+
+ def get_value(self, *kwargs):
+ raise NotImplementedError("abstract base class")
+
+
+class MetaConfigWithSpec(type):
+ """
+ metaclass for JSONLeapConfig classes.
+ It creates a configuration spec out of
+ the `spec` dictionary. The `properties` attribute
+ of the spec dict is turn into the `schema` attribute
+ of the new class (which will be used to validate against).
+ """
+ # XXX in the near future, this is the
+ # place where we want to enforce
+ # singletons, read-only and similar stuff.
+
+ def __new__(meta, classname, bases, classDict):
+ schema_obj = classDict.get('spec', None)
+
+ # not quite happy with this workaround.
+ # I want to raise if missing spec dict, but only
+ # for grand-children of this metaclass.
+ # maybe should use abc module for this.
+ abcderived = ("JSONLeapConfig",)
+ if schema_obj is None and classname not in abcderived:
+ raise exceptions.ImproperlyConfigured(
+ "missing spec dict on your derived class (%s)" % classname)
+
+ # we create a configuration spec attribute
+ # from the spec dict
+ config_class = type(
+ classname + "Spec",
+ (PluggableConfig, object),
+ {'options': schema_obj})
+ classDict['spec'] = config_class
+
+ return type.__new__(meta, classname, bases, classDict)
+
+##########################################################
+# some hacking still in progress:
+
+# Configs have:
+
+# - a slug (from where a filename/folder is derived)
+# - a spec (for validation and defaults).
+# this spec is conformant to the json-schema.
+# basically a dict that will be used
+# for type casting and validation, and defaults settings.
+
+# all config objects, since they are derived from BaseConfig, implement basic
+# useful methods:
+# - save
+# - load
+
+##########################################################
+
+
+class JSONLeapConfig(BaseLeapConfig):
+
+ __metaclass__ = MetaConfigWithSpec
+
+ def __init__(self, *args, **kwargs):
+ # sanity check
+ try:
+ assert self.slug is not None
+ except AssertionError:
+ raise exceptions.ImproperlyConfigured(
+ "missing slug on JSONLeapConfig"
+ " derived class")
+ try:
+ assert self.spec is not None
+ except AssertionError:
+ raise exceptions.ImproperlyConfigured(
+ "missing spec on JSONLeapConfig"
+ " derived class")
+ assert issubclass(self.spec, PluggableConfig)
+
+ self.domain = kwargs.pop('domain', None)
+ self._config = self.spec(format="json")
+ self._config.load()
+ self.fetcher = kwargs.pop('fetcher', requests)
+
+ # mandatory baseconfig interface
+
+ def save(self, to=None, force=False):
+ """
+ force param will skip the dirty check.
+ :type force: bool
+ """
+ # XXX this force=True does not feel to right
+ # but still have to look for a better way
+ # of dealing with dirtiness and the
+ # trick of loading remote config only
+ # when newer.
+
+ if force:
+ do_save = True
+ else:
+ do_save = self._config.is_dirty()
+
+ if do_save:
+ if to is None:
+ to = self.filename
+ folder, filename = os.path.split(to)
+ if folder and not os.path.isdir(folder):
+ mkdir_p(folder)
+ self._config.serialize(to)
+ return True
+
+ else:
+ return False
+
+ def load(self, fromfile=None, from_uri=None, fetcher=None,
+ force_download=False, verify=True):
+
+ if from_uri is not None:
+ fetched = self.fetch(
+ from_uri,
+ fetcher=fetcher,
+ verify=verify,
+ force_dl=force_download)
+ if fetched:
+ return
+ if fromfile is None:
+ fromfile = self.filename
+ if os.path.isfile(fromfile):
+ self._config.load(fromfile=fromfile)
+ else:
+ logger.warning('tried to load config from non-existent path')
+ logger.warning('Not Found: %s', fromfile)
+
+ def fetch(self, uri, fetcher=None, verify=True, force_dl=False):
+ if not fetcher:
+ fetcher = self.fetcher
+
+ logger.debug('uri: %s (verify: %s)' % (uri, verify))
+
+ rargs = (uri, )
+ rkwargs = {'verify': verify}
+ headers = {}
+
+ curmtime = self.get_mtime() if not force_dl else None
+ if curmtime:
+ logger.debug('requesting with if-modified-since %s' % curmtime)
+ headers['if-modified-since'] = curmtime
+ rkwargs['headers'] = headers
+
+ #request = fetcher.get(uri, verify=verify)
+ request = fetcher.get(*rargs, **rkwargs)
+ request.raise_for_status()
+
+ if request.status_code == 304:
+ logger.debug('...304 Not Changed')
+ # On this point, we have to assume that
+ # we HAD the filename. If that filename is corruct,
+ # we should enforce a force_download in the load
+ # method above.
+ self._config.load(fromfile=self.filename)
+ return True
+
+ if request.json:
+ mtime = None
+ last_modified = request.headers.get('last-modified', None)
+ if last_modified:
+ _mtime = dateparser.parse(last_modified)
+ mtime = int(_mtime.strftime("%s"))
+ if callable(request.json):
+ _json = request.json()
+ else:
+ # back-compat
+ _json = request.json
+ self._config.load(json.dumps(_json), mtime=mtime)
+ self._config.set_dirty()
+ else:
+ # not request.json
+ # might be server did not announce content properly,
+ # let's try deserializing all the same.
+ try:
+ self._config.load(request.content)
+ self._config.set_dirty()
+ except ValueError:
+ raise eipexceptions.LeapBadConfigFetchedError
+
+ return True
+
+ def get_mtime(self):
+ try:
+ _mtime = os.stat(self.filename)[8]
+ mtime = time.strftime("%c GMT", time.gmtime(_mtime))
+ return mtime
+ except OSError:
+ return None
+
+ def get_config(self):
+ return self._config.config
+
+ # public methods
+
+ def get_filename(self):
+ return self._slug_to_filename()
+
+ @property
+ def filename(self):
+ return self.get_filename()
+
+ def validate(self, data):
+ logger.debug('validating schema')
+ self._config.validate(data)
+ return True
+
+ # private
+
+ def _slug_to_filename(self):
+ # is this going to work in winland if slug is "foo/bar" ?
+ folder, filename = os.path.split(self.slug)
+ config_file = get_config_file(filename, folder)
+ return config_file
+
+ def exists(self):
+ return os.path.isfile(self.filename)
+
+
+#
+# utility functions
+#
+# (might be moved to some class as we see fit, but
+# let's remain functional for a while)
+# maybe base.config.util ??
+#
+
+
+def get_config_dir():
+ """
+ get the base dir for all leap config
+ @rparam: config path
+ @rtype: string
+ """
+ home = os.path.expanduser("~")
+ if re.findall("leap_tests-[_a-zA-Z0-9]{6}", home):
+ # we're inside a test! :)
+ return os.path.join(home, ".config/leap")
+ else:
+ # XXX dirspec is cross-platform,
+ # we should borrow some of those
+ # routines for osx/win and wrap this call.
+ return os.path.join(BaseDirectory.xdg_config_home,
+ 'leap')
+
+
+def get_config_file(filename, folder=None):
+ """
+ concatenates the given filename
+ with leap config dir.
+ @param filename: name of the file
+ @type filename: string
+ @rparam: full path to config file
+ """
+ path = []
+ path.append(get_config_dir())
+ if folder is not None:
+ path.append(folder)
+ path.append(filename)
+ return os.path.join(*path)
+
+
+def get_default_provider_path():
+ default_subpath = os.path.join("providers",
+ constants.DEFAULT_PROVIDER)
+ default_provider_path = get_config_file(
+ '',
+ folder=default_subpath)
+ return default_provider_path
+
+
+def get_provider_path(domain):
+ # XXX if not domain, return get_default_provider_path
+ default_subpath = os.path.join("providers", domain)
+ provider_path = get_config_file(
+ '',
+ folder=default_subpath)
+ return provider_path
+
+
+def validate_ip(ip_str):
+ """
+ raises exception if the ip_str is
+ not a valid representation of an ip
+ """
+ socket.inet_aton(ip_str)
+
+
+def get_username():
+ try:
+ return os.getlogin()
+ except OSError as e:
+ import pwd
+ return pwd.getpwuid(os.getuid())[0]
+
+
+def get_groupname():
+ gid = os.getgroups()[-1]
+ return grp.getgrgid(gid).gr_name
diff --git a/src/leap/base/connection.py b/src/leap/base/connection.py
new file mode 100644
index 00000000..41d13935
--- /dev/null
+++ b/src/leap/base/connection.py
@@ -0,0 +1,115 @@
+"""
+Base Connection Classs
+"""
+from __future__ import (division, unicode_literals, print_function)
+
+import logging
+
+from leap.base.authentication import Authentication
+
+logger = logging.getLogger(name=__name__)
+
+
+class Connection(Authentication):
+ # JSONLeapConfig
+ #spec = {}
+
+ def __init__(self, *args, **kwargs):
+ self.connection_state = None
+ self.desired_connection_state = None
+ #XXX FIXME diamond inheritance gotcha..
+ #If you inherit from >1 class,
+ #super is only initializing one
+ #of the bases..!!
+ # I think we better pass config as a constructor
+ # parameter -- kali 2012-08-30 04:33
+ super(Connection, self).__init__(*args, **kwargs)
+
+ def connect(self):
+ """
+ entry point for connection process
+ """
+ pass
+
+ def disconnect(self):
+ """
+ disconnects client
+ """
+ pass
+
+ #def shutdown(self):
+ #"""
+ #shutdown and quit
+ #"""
+ #self.desired_con_state = self.status.DISCONNECTED
+
+ def connection_state(self):
+ """
+ returns the current connection state
+ """
+ return self.status.current
+
+ def desired_connection_state(self):
+ """
+ returns the desired_connection state
+ """
+ return self.desired_connection_state
+
+ def get_icon_name(self):
+ """
+ get icon name from status object
+ """
+ return self.status.get_state_icon()
+
+ #
+ # private methods
+ #
+
+ def _disconnect(self):
+ """
+ private method for disconnecting
+ """
+ if self.subp is not None:
+ self.subp.terminate()
+ self.subp = None
+ # XXX signal state changes! :)
+
+ def _is_alive(self):
+ """
+ don't know yet
+ """
+ pass
+
+ def _connect(self):
+ """
+ entry point for connection cascade methods.
+ """
+ #conn_result = ConState.DISCONNECTED
+ try:
+ conn_result = self._try_connection()
+ except UnrecoverableError as except_msg:
+ logger.error("FATAL: %s" % unicode(except_msg))
+ conn_result = self.status.UNRECOVERABLE
+ except Exception as except_msg:
+ self.error_queue.append(except_msg)
+ logger.error("Failed Connection: %s" %
+ unicode(except_msg))
+ return conn_result
+
+
+class ConnectionError(Exception):
+ """
+ generic connection error
+ """
+ def __str__(self):
+ if len(self.args) >= 1:
+ return repr(self.args[0])
+ else:
+ raise self()
+
+
+class UnrecoverableError(ConnectionError):
+ """
+ we cannot do anything about it, sorry
+ """
+ pass
diff --git a/src/leap/base/constants.py b/src/leap/base/constants.py
new file mode 100644
index 00000000..f5665e5f
--- /dev/null
+++ b/src/leap/base/constants.py
@@ -0,0 +1,42 @@
+"""constants to be used in base module"""
+from leap import __branding
+APP_NAME = __branding.get("short_name", "leap-client")
+OPENVPN_BIN = "openvpn"
+
+# default provider placeholder
+# using `example.org` we make sure that this
+# is not going to be resolved during the tests phases
+# (we expect testers to add it to their /etc/hosts
+
+DEFAULT_PROVIDER = __branding.get(
+ "provider_domain",
+ "testprovider.example.org")
+
+DEFINITION_EXPECTED_PATH = "provider.json"
+
+DEFAULT_PROVIDER_DEFINITION = {
+ u"api_uri": "https://api.%s/" % DEFAULT_PROVIDER,
+ u"api_version": u"1",
+ u"ca_cert_fingerprint": "SHA256: fff",
+ u"ca_cert_uri": u"https://%s/ca.crt" % DEFAULT_PROVIDER,
+ u"default_language": u"en",
+ u"description": {
+ u"en": u"A demonstration service provider using the LEAP platform"
+ },
+ u"domain": "%s" % DEFAULT_PROVIDER,
+ u"enrollment_policy": u"open",
+ u"languages": [
+ u"en"
+ ],
+ u"name": {
+ u"en": u"Test Provider"
+ },
+ u"services": [
+ "openvpn"
+ ]
+}
+
+
+MAX_ICMP_PACKET_LOSS = 10
+
+ROUTE_CHECK_INTERVAL = 10
diff --git a/src/leap/base/exceptions.py b/src/leap/base/exceptions.py
new file mode 100644
index 00000000..2e31b33b
--- /dev/null
+++ b/src/leap/base/exceptions.py
@@ -0,0 +1,97 @@
+"""
+Exception attributes and their meaning/uses
+-------------------------------------------
+
+* critical: if True, will abort execution prematurely,
+ after attempting any cleaning
+ action.
+
+* failfirst: breaks any error_check loop that is examining
+ the error queue.
+
+* message: the message that will be used in the __repr__ of the exception.
+
+* usermessage: the message that will be passed to user in ErrorDialogs
+ in Qt-land.
+"""
+from leap.util.translations import translate
+
+
+class LeapException(Exception):
+ """
+ base LeapClient exception
+ sets some parameters that we will check
+ during error checking routines
+ """
+
+ critical = False
+ failfirst = False
+ warning = False
+
+
+class CriticalError(LeapException):
+ """
+ we cannot do anything about it
+ """
+ critical = True
+ failfirst = True
+
+
+# In use ???
+# don't thing so. purge if not...
+
+class MissingConfigFileError(Exception):
+ pass
+
+
+class ImproperlyConfigured(Exception):
+ pass
+
+
+# NOTE: "Errors" (context) has to be a explicit string!
+
+
+class InterfaceNotFoundError(LeapException):
+ # XXX should take iface arg on init maybe?
+ message = "interface not found"
+ usermessage = translate(
+ "Errors",
+ "Interface not found")
+
+
+class NoDefaultInterfaceFoundError(LeapException):
+ message = "no default interface found"
+ usermessage = translate(
+ "Errors",
+ "Looks like your computer "
+ "is not connected to the internet")
+
+
+class NoConnectionToGateway(CriticalError):
+ message = "no connection to gateway"
+ usermessage = translate(
+ "Errors",
+ "Looks like there are problems "
+ "with your internet connection")
+
+
+class NoInternetConnection(CriticalError):
+ message = "No Internet connection found"
+ usermessage = translate(
+ "Errors",
+ "It looks like there is no internet connection.")
+ # and now we try to connect to our web to troubleshoot LOL :P
+
+
+class CannotResolveDomainError(LeapException):
+ message = "Cannot resolve domain"
+ usermessage = translate(
+ "Errors",
+ "Domain cannot be found")
+
+
+class TunnelNotDefaultRouteError(LeapException):
+ message = "Tunnel connection dissapeared. VPN down?"
+ usermessage = translate(
+ "Errors",
+ "The Encrypted Connection was lost.")
diff --git a/src/leap/base/jsonschema.py b/src/leap/base/jsonschema.py
new file mode 100644
index 00000000..56689b08
--- /dev/null
+++ b/src/leap/base/jsonschema.py
@@ -0,0 +1,791 @@
+"""
+An implementation of JSON Schema for Python
+
+The main functionality is provided by the validator classes for each of the
+supported JSON Schema versions.
+
+Most commonly, :func:`validate` is the quickest way to simply validate a given
+instance under a schema, and will create a validator for you.
+
+"""
+
+from __future__ import division, unicode_literals
+
+import collections
+import json
+import itertools
+import operator
+import re
+import sys
+
+
+__version__ = "0.8.0"
+
+PY3 = sys.version_info[0] >= 3
+
+if PY3:
+ from urllib import parse as urlparse
+ from urllib.parse import unquote
+ from urllib.request import urlopen
+ basestring = unicode = str
+ iteritems = operator.methodcaller("items")
+else:
+ from itertools import izip as zip
+ from urllib import unquote
+ from urllib2 import urlopen
+ import urlparse
+ iteritems = operator.methodcaller("iteritems")
+
+
+FLOAT_TOLERANCE = 10 ** -15
+validators = {}
+
+
+def validates(version):
+ """
+ Register the decorated validator for a ``version`` of the specification.
+
+ Registered validators and their meta schemas will be considered when
+ parsing ``$schema`` properties' URIs.
+
+ :argument str version: an identifier to use as the version's name
+ :returns: a class decorator to decorate the validator with the version
+
+ """
+
+ def _validates(cls):
+ validators[version] = cls
+ return cls
+ return _validates
+
+
+class UnknownType(Exception):
+ """
+ An attempt was made to check if an instance was of an unknown type.
+
+ """
+
+
+class RefResolutionError(Exception):
+ """
+ A JSON reference failed to resolve.
+
+ """
+
+
+class SchemaError(Exception):
+ """
+ The provided schema is malformed.
+
+ The same attributes are present as for :exc:`ValidationError`\s.
+
+ """
+
+ def __init__(self, message, validator=None, path=()):
+ super(SchemaError, self).__init__(message, validator, path)
+ self.message = message
+ self.path = list(path)
+ self.validator = validator
+
+ def __str__(self):
+ return self.message
+
+
+class ValidationError(Exception):
+ """
+ The instance didn't properly validate under the provided schema.
+
+ Relevant attributes are:
+ * ``message`` : a human readable message explaining the error
+ * ``path`` : a list containing the path to the offending element (or []
+ if the error happened globally) in *reverse* order (i.e.
+ deepest index first).
+
+ """
+
+ def __init__(self, message, validator=None, path=()):
+ # Any validator that recurses (e.g. properties and items) must append
+ # to the ValidationError's path to properly maintain where in the
+ # instance the error occurred
+ super(ValidationError, self).__init__(message, validator, path)
+ self.message = message
+ self.path = list(path)
+ self.validator = validator
+
+ def __str__(self):
+ return self.message
+
+
+@validates("draft3")
+class Draft3Validator(object):
+ """
+ A validator for JSON Schema draft 3.
+
+ """
+
+ DEFAULT_TYPES = {
+ "array": list, "boolean": bool, "integer": int, "null": type(None),
+ "number": (int, float), "object": dict, "string": basestring,
+ }
+
+ def __init__(self, schema, types=(), resolver=None):
+ self._types = dict(self.DEFAULT_TYPES)
+ self._types.update(types)
+
+ if resolver is None:
+ resolver = RefResolver.from_schema(schema)
+
+ self.resolver = resolver
+ self.schema = schema
+
+ def is_type(self, instance, type):
+ if type == "any":
+ return True
+ elif type not in self._types:
+ raise UnknownType(type)
+ type = self._types[type]
+
+ # bool inherits from int, so ensure bools aren't reported as integers
+ if isinstance(instance, bool):
+ type = _flatten(type)
+ if int in type and bool not in type:
+ return False
+ return isinstance(instance, type)
+
+ def is_valid(self, instance, _schema=None):
+ error = next(self.iter_errors(instance, _schema), None)
+ return error is None
+
+ @classmethod
+ def check_schema(cls, schema):
+ for error in cls(cls.META_SCHEMA).iter_errors(schema):
+ raise SchemaError(
+ error.message, validator=error.validator, path=error.path,
+ )
+
+ def iter_errors(self, instance, _schema=None):
+ if _schema is None:
+ _schema = self.schema
+
+ for k, v in iteritems(_schema):
+ validator = getattr(self, "validate_%s" % (k.lstrip("$"),), None)
+
+ if validator is None:
+ continue
+
+ errors = validator(v, instance, _schema) or ()
+ for error in errors:
+ # set the validator if it wasn't already set by the called fn
+ if error.validator is None:
+ error.validator = k
+ yield error
+
+ def validate(self, *args, **kwargs):
+ for error in self.iter_errors(*args, **kwargs):
+ raise error
+
+ def validate_type(self, types, instance, schema):
+ types = _list(types)
+
+ for type in types:
+ if self.is_type(type, "object"):
+ if self.is_valid(instance, type):
+ return
+ elif self.is_type(type, "string"):
+ if self.is_type(instance, type):
+ return
+ else:
+ yield ValidationError(_types_msg(instance, types))
+
+ def validate_properties(self, properties, instance, schema):
+ if not self.is_type(instance, "object"):
+ return
+
+ for property, subschema in iteritems(properties):
+ if property in instance:
+ for error in self.iter_errors(instance[property], subschema):
+ error.path.append(property)
+ yield error
+ elif subschema.get("required", False):
+ yield ValidationError(
+ "%r is a required property" % (property,),
+ validator="required",
+ path=[property],
+ )
+
+ def validate_patternProperties(self, patternProperties, instance, schema):
+ if not self.is_type(instance, "object"):
+ return
+
+ for pattern, subschema in iteritems(patternProperties):
+ for k, v in iteritems(instance):
+ if re.match(pattern, k):
+ for error in self.iter_errors(v, subschema):
+ yield error
+
+ def validate_additionalProperties(self, aP, instance, schema):
+ if not self.is_type(instance, "object"):
+ return
+
+ extras = set(_find_additional_properties(instance, schema))
+
+ if self.is_type(aP, "object"):
+ for extra in extras:
+ for error in self.iter_errors(instance[extra], aP):
+ yield error
+ elif not aP and extras:
+ error = "Additional properties are not allowed (%s %s unexpected)"
+ yield ValidationError(error % _extras_msg(extras))
+
+ def validate_dependencies(self, dependencies, instance, schema):
+ if not self.is_type(instance, "object"):
+ return
+
+ for property, dependency in iteritems(dependencies):
+ if property not in instance:
+ continue
+
+ if self.is_type(dependency, "object"):
+ for error in self.iter_errors(instance, dependency):
+ yield error
+ else:
+ dependencies = _list(dependency)
+ for dependency in dependencies:
+ if dependency not in instance:
+ yield ValidationError(
+ "%r is a dependency of %r" % (dependency, property)
+ )
+
+ def validate_items(self, items, instance, schema):
+ if not self.is_type(instance, "array"):
+ return
+
+ if self.is_type(items, "object"):
+ for index, item in enumerate(instance):
+ for error in self.iter_errors(item, items):
+ error.path.append(index)
+ yield error
+ else:
+ for (index, item), subschema in zip(enumerate(instance), items):
+ for error in self.iter_errors(item, subschema):
+ error.path.append(index)
+ yield error
+
+ def validate_additionalItems(self, aI, instance, schema):
+ if (
+ not self.is_type(instance, "array") or
+ not self.is_type(schema.get("items"), "array")
+ ):
+ return
+
+ if self.is_type(aI, "object"):
+ for item in instance[len(schema):]:
+ for error in self.iter_errors(item, aI):
+ yield error
+ elif not aI and len(instance) > len(schema.get("items", [])):
+ error = "Additional items are not allowed (%s %s unexpected)"
+ yield ValidationError(
+ error % _extras_msg(instance[len(schema.get("items", [])):])
+ )
+
+ def validate_minimum(self, minimum, instance, schema):
+ if not self.is_type(instance, "number"):
+ return
+
+ instance = float(instance)
+ if schema.get("exclusiveMinimum", False):
+ failed = instance <= minimum
+ cmp = "less than or equal to"
+ else:
+ failed = instance < minimum
+ cmp = "less than"
+
+ if failed:
+ yield ValidationError(
+ "%r is %s the minimum of %r" % (instance, cmp, minimum)
+ )
+
+ def validate_maximum(self, maximum, instance, schema):
+ if not self.is_type(instance, "number"):
+ return
+
+ instance = float(instance)
+ if schema.get("exclusiveMaximum", False):
+ failed = instance >= maximum
+ cmp = "greater than or equal to"
+ else:
+ failed = instance > maximum
+ cmp = "greater than"
+
+ if failed:
+ yield ValidationError(
+ "%r is %s the maximum of %r" % (instance, cmp, maximum)
+ )
+
+ def validate_minItems(self, mI, instance, schema):
+ if self.is_type(instance, "array") and len(instance) < mI:
+ yield ValidationError("%r is too short" % (instance,))
+
+ def validate_maxItems(self, mI, instance, schema):
+ if self.is_type(instance, "array") and len(instance) > mI:
+ yield ValidationError("%r is too long" % (instance,))
+
+ def validate_uniqueItems(self, uI, instance, schema):
+ if uI and self.is_type(instance, "array") and not _uniq(instance):
+ yield ValidationError("%r has non-unique elements" % instance)
+
+ def validate_pattern(self, patrn, instance, schema):
+ if self.is_type(instance, "string") and not re.match(patrn, instance):
+ yield ValidationError("%r does not match %r" % (instance, patrn))
+
+ def validate_minLength(self, mL, instance, schema):
+ if self.is_type(instance, "string") and len(instance) < mL:
+ yield ValidationError("%r is too short" % (instance,))
+
+ def validate_maxLength(self, mL, instance, schema):
+ if self.is_type(instance, "string") and len(instance) > mL:
+ yield ValidationError("%r is too long" % (instance,))
+
+ def validate_enum(self, enums, instance, schema):
+ if instance not in enums:
+ yield ValidationError("%r is not one of %r" % (instance, enums))
+
+ def validate_divisibleBy(self, dB, instance, schema):
+ if not self.is_type(instance, "number"):
+ return
+
+ if isinstance(dB, float):
+ mod = instance % dB
+ failed = (mod > FLOAT_TOLERANCE) and (dB - mod) > FLOAT_TOLERANCE
+ else:
+ failed = instance % dB
+
+ if failed:
+ yield ValidationError("%r is not divisible by %r" % (instance, dB))
+
+ def validate_disallow(self, disallow, instance, schema):
+ for disallowed in _list(disallow):
+ if self.is_valid(instance, {"type": [disallowed]}):
+ yield ValidationError(
+ "%r is disallowed for %r" % (disallowed, instance)
+ )
+
+ def validate_extends(self, extends, instance, schema):
+ if self.is_type(extends, "object"):
+ extends = [extends]
+ for subschema in extends:
+ for error in self.iter_errors(instance, subschema):
+ yield error
+
+ def validate_ref(self, ref, instance, schema):
+ resolved = self.resolver.resolve(ref)
+ for error in self.iter_errors(instance, resolved):
+ yield error
+
+
+Draft3Validator.META_SCHEMA = {
+ "$schema": "http://json-schema.org/draft-03/schema#",
+ "id": "http://json-schema.org/draft-03/schema#",
+ "type": "object",
+
+ "properties": {
+ "type": {
+ "type": ["string", "array"],
+ "items": {"type": ["string", {"$ref": "#"}]},
+ "uniqueItems": True,
+ "default": "any"
+ },
+ "properties": {
+ "type": "object",
+ "additionalProperties": {"$ref": "#", "type": "object"},
+ "default": {}
+ },
+ "patternProperties": {
+ "type": "object",
+ "additionalProperties": {"$ref": "#"},
+ "default": {}
+ },
+ "additionalProperties": {
+ "type": [{"$ref": "#"}, "boolean"], "default": {}
+ },
+ "items": {
+ "type": [{"$ref": "#"}, "array"],
+ "items": {"$ref": "#"},
+ "default": {}
+ },
+ "additionalItems": {
+ "type": [{"$ref": "#"}, "boolean"], "default": {}
+ },
+ "required": {"type": "boolean", "default": False},
+ "dependencies": {
+ "type": ["string", "array", "object"],
+ "additionalProperties": {
+ "type": ["string", "array", {"$ref": "#"}],
+ "items": {"type": "string"}
+ },
+ "default": {}
+ },
+ "minimum": {"type": "number"},
+ "maximum": {"type": "number"},
+ "exclusiveMinimum": {"type": "boolean", "default": False},
+ "exclusiveMaximum": {"type": "boolean", "default": False},
+ "minItems": {"type": "integer", "minimum": 0, "default": 0},
+ "maxItems": {"type": "integer", "minimum": 0},
+ "uniqueItems": {"type": "boolean", "default": False},
+ "pattern": {"type": "string", "format": "regex"},
+ "minLength": {"type": "integer", "minimum": 0, "default": 0},
+ "maxLength": {"type": "integer"},
+ "enum": {"type": "array", "minItems": 1, "uniqueItems": True},
+ "default": {"type": "any"},
+ "title": {"type": "string"},
+ "description": {"type": "string"},
+ "format": {"type": "string"},
+ "maxDecimal": {"type": "number", "minimum": 0},
+ "divisibleBy": {
+ "type": "number",
+ "minimum": 0,
+ "exclusiveMinimum": True,
+ "default": 1
+ },
+ "disallow": {
+ "type": ["string", "array"],
+ "items": {"type": ["string", {"$ref": "#"}]},
+ "uniqueItems": True
+ },
+ "extends": {
+ "type": [{"$ref": "#"}, "array"],
+ "items": {"$ref": "#"},
+ "default": {}
+ },
+ "id": {"type": "string", "format": "uri"},
+ "$ref": {"type": "string", "format": "uri"},
+ "$schema": {"type": "string", "format": "uri"},
+ },
+ "dependencies": {
+ "exclusiveMinimum": "minimum", "exclusiveMaximum": "maximum"
+ },
+}
+
+
+class RefResolver(object):
+ """
+ Resolve JSON References.
+
+ :argument str base_uri: URI of the referring document
+ :argument referrer: the actual referring document
+ :argument dict store: a mapping from URIs to documents to cache
+
+ """
+
+ def __init__(self, base_uri, referrer, store=()):
+ self.base_uri = base_uri
+ self.referrer = referrer
+ self.store = dict(store, **_meta_schemas())
+
+ @classmethod
+ def from_schema(cls, schema, *args, **kwargs):
+ """
+ Construct a resolver from a JSON schema object.
+
+ :argument schema schema: the referring schema
+ :rtype: :class:`RefResolver`
+
+ """
+
+ return cls(schema.get("id", ""), schema, *args, **kwargs)
+
+ def resolve(self, ref):
+ """
+ Resolve a JSON ``ref``.
+
+ :argument str ref: reference to resolve
+ :returns: the referrant document
+
+ """
+
+ base_uri = self.base_uri
+ uri, fragment = urlparse.urldefrag(urlparse.urljoin(base_uri, ref))
+
+ if uri in self.store:
+ document = self.store[uri]
+ elif not uri or uri == self.base_uri:
+ document = self.referrer
+ else:
+ document = self.resolve_remote(uri)
+
+ return self.resolve_fragment(document, fragment.lstrip("/"))
+
+ def resolve_fragment(self, document, fragment):
+ """
+ Resolve a ``fragment`` within the referenced ``document``.
+
+ :argument document: the referrant document
+ :argument str fragment: a URI fragment to resolve within it
+
+ """
+
+ parts = unquote(fragment).split("/") if fragment else []
+
+ for part in parts:
+ part = part.replace("~1", "/").replace("~0", "~")
+
+ if part not in document:
+ raise RefResolutionError(
+ "Unresolvable JSON pointer: %r" % fragment
+ )
+
+ document = document[part]
+
+ return document
+
+ def resolve_remote(self, uri):
+ """
+ Resolve a remote ``uri``.
+
+ Does not check the store first.
+
+ :argument str uri: the URI to resolve
+ :returns: the retrieved document
+
+ """
+
+ return json.load(urlopen(uri))
+
+
+class ErrorTree(object):
+ """
+ ErrorTrees make it easier to check which validations failed.
+
+ """
+
+ def __init__(self, errors=()):
+ self.errors = {}
+ self._contents = collections.defaultdict(self.__class__)
+
+ for error in errors:
+ container = self
+ for element in reversed(error.path):
+ container = container[element]
+ container.errors[error.validator] = error
+
+ def __contains__(self, k):
+ return k in self._contents
+
+ def __getitem__(self, k):
+ """
+ Retrieve the child tree with key ``k``.
+
+ """
+
+ return self._contents[k]
+
+ def __setitem__(self, k, v):
+ self._contents[k] = v
+
+ def __iter__(self):
+ return iter(self._contents)
+
+ def __len__(self):
+ return self.total_errors
+
+ def __repr__(self):
+ return "<%s (%s total errors)>" % (self.__class__.__name__, len(self))
+
+ @property
+ def total_errors(self):
+ """
+ The total number of errors in the entire tree, including children.
+
+ """
+
+ child_errors = sum(len(tree) for _, tree in iteritems(self._contents))
+ return len(self.errors) + child_errors
+
+
+def _meta_schemas():
+ """
+ Collect the urls and meta schemas from each known validator.
+
+ """
+
+ meta_schemas = (v.META_SCHEMA for v in validators.values())
+ return dict((urlparse.urldefrag(m["id"])[0], m) for m in meta_schemas)
+
+
+def _find_additional_properties(instance, schema):
+ """
+ Return the set of additional properties for the given ``instance``.
+
+ Weeds out properties that should have been validated by ``properties`` and
+ / or ``patternProperties``.
+
+ Assumes ``instance`` is dict-like already.
+
+ """
+
+ properties = schema.get("properties", {})
+ patterns = "|".join(schema.get("patternProperties", {}))
+ for property in instance:
+ if property not in properties:
+ if patterns and re.search(patterns, property):
+ continue
+ yield property
+
+
+def _extras_msg(extras):
+ """
+ Create an error message for extra items or properties.
+
+ """
+
+ if len(extras) == 1:
+ verb = "was"
+ else:
+ verb = "were"
+ return ", ".join(repr(extra) for extra in extras), verb
+
+
+def _types_msg(instance, types):
+ """
+ Create an error message for a failure to match the given types.
+
+ If the ``instance`` is an object and contains a ``name`` property, it will
+ be considered to be a description of that object and used as its type.
+
+ Otherwise the message is simply the reprs of the given ``types``.
+
+ """
+
+ reprs = []
+ for type in types:
+ try:
+ reprs.append(repr(type["name"]))
+ except Exception:
+ reprs.append(repr(type))
+ return "%r is not of type %s" % (instance, ", ".join(reprs))
+
+
+def _flatten(suitable_for_isinstance):
+ """
+ isinstance() can accept a bunch of really annoying different types:
+ * a single type
+ * a tuple of types
+ * an arbitrary nested tree of tuples
+
+ Return a flattened tuple of the given argument.
+
+ """
+
+ types = set()
+
+ if not isinstance(suitable_for_isinstance, tuple):
+ suitable_for_isinstance = (suitable_for_isinstance,)
+ for thing in suitable_for_isinstance:
+ if isinstance(thing, tuple):
+ types.update(_flatten(thing))
+ else:
+ types.add(thing)
+ return tuple(types)
+
+
+def _list(thing):
+ """
+ Wrap ``thing`` in a list if it's a single str.
+
+ Otherwise, return it unchanged.
+
+ """
+
+ if isinstance(thing, basestring):
+ return [thing]
+ return thing
+
+
+def _delist(thing):
+ """
+ Unwrap ``thing`` to a single element if its a single str in a list.
+
+ Otherwise, return it unchanged.
+
+ """
+
+ if (
+ isinstance(thing, list) and
+ len(thing) == 1
+ and isinstance(thing[0], basestring)
+ ):
+ return thing[0]
+ return thing
+
+
+def _unbool(element, true=object(), false=object()):
+ """
+ A hack to make True and 1 and False and 0 unique for _uniq.
+
+ """
+
+ if element is True:
+ return true
+ elif element is False:
+ return false
+ return element
+
+
+def _uniq(container):
+ """
+ Check if all of a container's elements are unique.
+
+ Successively tries first to rely that the elements are hashable, then
+ falls back on them being sortable, and finally falls back on brute
+ force.
+
+ """
+
+ try:
+ return len(set(_unbool(i) for i in container)) == len(container)
+ except TypeError:
+ try:
+ sort = sorted(_unbool(i) for i in container)
+ sliced = itertools.islice(sort, 1, None)
+ for i, j in zip(sort, sliced):
+ if i == j:
+ return False
+ except (NotImplementedError, TypeError):
+ seen = []
+ for e in container:
+ e = _unbool(e)
+ if e in seen:
+ return False
+ seen.append(e)
+ return True
+
+
+def validate(instance, schema, cls=Draft3Validator, *args, **kwargs):
+ """
+ Validate an ``instance`` under the given ``schema``.
+
+ >>> validate([2, 3, 4], {"maxItems" : 2})
+ Traceback (most recent call last):
+ ...
+ ValidationError: [2, 3, 4] is too long
+
+ :func:`validate` will first verify that the provided schema is itself
+ valid, since not doing so can lead to less obvious error messages and fail
+ in less obvious or consistent ways. If you know you have a valid schema
+ already or don't care, you might prefer using the ``validate`` method
+ directly on a specific validator (e.g. :meth:`Draft3Validator.validate`).
+
+ ``cls`` is a validator class that will be used to validate the instance.
+ By default this is a draft 3 validator. Any other provided positional and
+ keyword arguments will be provided to this class when constructing a
+ validator.
+
+ :raises:
+ :exc:`ValidationError` if the instance is invalid
+
+ :exc:`SchemaError` if the schema itself is invalid
+
+ """
+
+ cls.check_schema(schema)
+ cls(schema, *args, **kwargs).validate(instance)
diff --git a/src/leap/base/network.py b/src/leap/base/network.py
new file mode 100644
index 00000000..d841e692
--- /dev/null
+++ b/src/leap/base/network.py
@@ -0,0 +1,107 @@
+# -*- coding: utf-8 -*-
+from __future__ import (print_function)
+import logging
+import threading
+
+from leap.eip import config as eipconfig
+from leap.base.checks import LeapNetworkChecker
+from leap.base.constants import ROUTE_CHECK_INTERVAL
+from leap.base.exceptions import TunnelNotDefaultRouteError
+from leap.util.misc import null_check
+from leap.util.coroutines import (launch_thread, process_events)
+
+from time import sleep
+
+logger = logging.getLogger(name=__name__)
+
+
+class NetworkCheckerThread(object):
+ """
+ Manages network checking thread that makes sure we have a working network
+ connection.
+ """
+ def __init__(self, *args, **kwargs):
+
+ self.status_signals = kwargs.pop('status_signals', None)
+ self.error_cb = kwargs.pop(
+ 'error_cb',
+ lambda exc: logger.error("%s", exc.message))
+ self.shutdown = threading.Event()
+
+ # XXX get provider passed here
+ provider = kwargs.pop('provider', None)
+ null_check(provider, 'provider')
+
+ eipconf = eipconfig.EIPConfig(domain=provider)
+ eipconf.load()
+ eipserviceconf = eipconfig.EIPServiceConfig(domain=provider)
+ eipserviceconf.load()
+
+ gw = eipconfig.get_eip_gateway(
+ eipconfig=eipconf,
+ eipserviceconfig=eipserviceconf)
+ self.checker = LeapNetworkChecker(
+ provider_gw=gw)
+
+ def start(self):
+ self.process_handle = self._launch_recurrent_network_checks(
+ (self.error_cb,))
+
+ def stop(self):
+ self.process_handle.join(timeout=0.1)
+ self.shutdown.set()
+ logger.debug("network checked stopped.")
+
+ def run_checks(self):
+ pass
+
+ #private methods
+
+ #here all the observers in fail_callbacks expect one positional argument,
+ #which is exception so we can try by passing a lambda with logger to
+ #check it works.
+
+ def _network_checks_thread(self, fail_callbacks):
+ #TODO: replace this with waiting for a signal from openvpn
+ while True:
+ try:
+ self.checker.check_tunnel_default_interface()
+ break
+ except TunnelNotDefaultRouteError:
+ # XXX ??? why do we sleep here???
+ # aa: If the openvpn isn't up and running yet,
+ # let's give it a moment to breath.
+ #logger.error('NOT DEFAULT ROUTE!----')
+ # Instead of this, we should flag when the
+ # iface IS SUPPOSED to be up imo. -- kali
+ sleep(1)
+
+ fail_observer_dict = dict(((
+ observer,
+ process_events(observer)) for observer in fail_callbacks))
+
+ while not self.shutdown.is_set():
+ try:
+ self.checker.check_tunnel_default_interface()
+ self.checker.check_internet_connection()
+ sleep(ROUTE_CHECK_INTERVAL)
+ except Exception as exc:
+ for obs in fail_observer_dict:
+ fail_observer_dict[obs].send(exc)
+ sleep(ROUTE_CHECK_INTERVAL)
+
+ #reset event
+ # I see a problem with this. You cannot stop it, it
+ # resets itself forever. -- kali
+
+ # XXX use QTimer for the recurrent triggers,
+ # and ditch the sleeps.
+ logger.debug('resetting event')
+ self.shutdown.clear()
+
+ def _launch_recurrent_network_checks(self, fail_callbacks):
+ # XXX reimplement using QTimer -- kali
+ watcher = launch_thread(
+ self._network_checks_thread,
+ (fail_callbacks,))
+ return watcher
diff --git a/src/leap/base/pluggableconfig.py b/src/leap/base/pluggableconfig.py
new file mode 100644
index 00000000..6f9f3f6f
--- /dev/null
+++ b/src/leap/base/pluggableconfig.py
@@ -0,0 +1,462 @@
+"""
+generic configuration handlers
+"""
+import copy
+import json
+import logging
+import os
+import time
+import urlparse
+
+import jsonschema
+
+from leap.util.translations import LEAPTranslatable
+
+logger = logging.getLogger(__name__)
+
+
+__all__ = ['PluggableConfig',
+ 'adaptors',
+ 'types',
+ 'UnknownOptionException',
+ 'MissingValueException',
+ 'ConfigurationProviderException',
+ 'TypeCastException']
+
+# exceptions
+
+
+class ValidationError(Exception):
+ pass
+
+
+class UnknownOptionException(Exception):
+ """exception raised when a non-configuration
+ value is present in the configuration"""
+
+
+class MissingValueException(Exception):
+ """exception raised when a required value is missing"""
+
+
+class ConfigurationProviderException(Exception):
+ """exception raised when a configuration provider is missing, etc"""
+
+
+class TypeCastException(Exception):
+ """exception raised when a
+ configuration item cannot be coerced to a type"""
+
+
+class ConfigAdaptor(object):
+ """
+ abstract base class for config adaotors for
+ serialization/deserialization and custom validation
+ and type casting.
+ """
+ def read(self, filename):
+ raise NotImplementedError("abstract base class")
+
+ def write(self, config, filename):
+ with open(filename, 'w') as f:
+ self._write(f, config)
+
+ def _write(self, fp, config):
+ raise NotImplementedError("abstract base class")
+
+ def validate(self, config, schema):
+ raise NotImplementedError("abstract base class")
+
+
+adaptors = {}
+
+
+class JSONSchemaEncoder(json.JSONEncoder):
+ """
+ custom default encoder that
+ casts python objects to json objects for
+ the schema validation
+ """
+ def default(self, obj):
+ if obj is str:
+ return 'string'
+ if obj is unicode:
+ return 'string'
+ if obj is int:
+ return 'integer'
+ if obj is list:
+ return 'array'
+ if obj is dict:
+ return 'object'
+ if obj is bool:
+ return 'boolean'
+
+
+class JSONAdaptor(ConfigAdaptor):
+ indent = 2
+ extensions = ['json']
+
+ def read(self, _from):
+ if isinstance(_from, file):
+ _from_string = _from.read()
+ if isinstance(_from, str):
+ _from_string = _from
+ return json.loads(_from_string)
+
+ def _write(self, fp, config):
+ fp.write(json.dumps(config,
+ indent=self.indent,
+ sort_keys=True))
+
+ def validate(self, config, schema_obj):
+ schema_json = JSONSchemaEncoder().encode(schema_obj)
+ schema = json.loads(schema_json)
+ try:
+ jsonschema.validate(config, schema)
+ except jsonschema.ValidationError:
+ raise ValidationError
+
+
+adaptors['json'] = JSONAdaptor()
+
+#
+# Adaptors
+#
+# Allow to apply a predefined set of types to the
+# specs, so it checks the validity of formats and cast it
+# to proper python types.
+
+# TODO:
+# - HTTPS uri
+
+
+class DateType(object):
+ fmt = '%Y-%m-%d'
+
+ def to_python(self, data):
+ return time.strptime(data, self.fmt)
+
+ def get_prep_value(self, data):
+ return time.strftime(self.fmt, data)
+
+
+class TranslatableType(object):
+ """
+ a type that casts to LEAPTranslatable objects.
+ Used for labels we get from providers and stuff.
+ """
+
+ def to_python(self, data):
+ return LEAPTranslatable(data)
+
+ # needed? we already have an extended dict...
+ #def get_prep_value(self, data):
+ #return dict(data)
+
+
+class URIType(object):
+
+ def to_python(self, data):
+ parsed = urlparse.urlparse(data)
+ if not parsed.scheme:
+ raise TypeCastException("uri %s has no schema" % data)
+ return parsed
+
+ def get_prep_value(self, data):
+ return data.geturl()
+
+
+class HTTPSURIType(object):
+
+ def to_python(self, data):
+ parsed = urlparse.urlparse(data)
+ if not parsed.scheme:
+ raise TypeCastException("uri %s has no schema" % data)
+ if parsed.scheme != "https":
+ raise TypeCastException(
+ "uri %s does not has "
+ "https schema" % data)
+ return parsed
+
+ def get_prep_value(self, data):
+ return data.geturl()
+
+
+types = {
+ 'date': DateType(),
+ 'uri': URIType(),
+ 'https-uri': HTTPSURIType(),
+ 'translatable': TranslatableType(),
+}
+
+
+class PluggableConfig(object):
+
+ options = {}
+
+ def __init__(self,
+ adaptors=adaptors,
+ types=types,
+ format=None):
+
+ self.config = {}
+ self.adaptors = adaptors
+ self.types = types
+ self._format = format
+ self.mtime = None
+ self.dirty = False
+
+ @property
+ def option_dict(self):
+ if hasattr(self, 'options') and isinstance(self.options, dict):
+ return self.options.get('properties', None)
+
+ def items(self):
+ """
+ act like an iterator
+ """
+ if isinstance(self.option_dict, dict):
+ return self.option_dict.items()
+ return self.options
+
+ def validate(self, config, format=None):
+ """
+ validate config
+ """
+ schema = self.options
+ if format is None:
+ format = self._format
+
+ if format:
+ adaptor = self.get_adaptor(self._format)
+ adaptor.validate(config, schema)
+ else:
+ # we really should make format mandatory...
+ logger.error('no format passed to validate')
+
+ # first round of validation is ok.
+ # now we proceed to cast types if any specified.
+ self.to_python(config)
+
+ def to_python(self, config):
+ """
+ cast types following first type and then format indications.
+ """
+ unseen_options = [i for i in config if i not in self.option_dict]
+ if unseen_options:
+ raise UnknownOptionException(
+ "Unknown options: %s" % ', '.join(unseen_options))
+
+ for key, value in config.items():
+ _type = self.option_dict[key].get('type')
+ if _type is None and 'default' in self.option_dict[key]:
+ _type = type(self.option_dict[key]['default'])
+ if _type is not None:
+ tocast = True
+ if not callable(_type) and isinstance(value, _type):
+ tocast = False
+ if tocast:
+ try:
+ config[key] = _type(value)
+ except BaseException, e:
+ raise TypeCastException(
+ "Could not coerce %s, %s, "
+ "to type %s: %s" % (key, value, _type.__name__, e))
+ _format = self.option_dict[key].get('format', None)
+ _ftype = self.types.get(_format, None)
+ if _ftype:
+ try:
+ config[key] = _ftype.to_python(value)
+ except BaseException, e:
+ raise TypeCastException(
+ "Could not coerce %s, %s, "
+ "to format %s: %s" % (key, value,
+ _ftype.__class__.__name__,
+ e))
+
+ return config
+
+ def prep_value(self, config):
+ """
+ the inverse of to_python method,
+ called just before serialization
+ """
+ for key, value in config.items():
+ _format = self.option_dict[key].get('format', None)
+ _ftype = self.types.get(_format, None)
+ if _ftype and hasattr(_ftype, 'get_prep_value'):
+ try:
+ config[key] = _ftype.get_prep_value(value)
+ except BaseException, e:
+ raise TypeCastException(
+ "Could not serialize %s, %s, "
+ "by format %s: %s" % (key, value,
+ _ftype.__class__.__name__,
+ e))
+ else:
+ config[key] = value
+ return config
+
+ # methods for adding configuration
+
+ def get_default_values(self):
+ """
+ return a config options from configuration defaults
+ """
+ defaults = {}
+ for key, value in self.items():
+ if 'default' in value:
+ defaults[key] = value['default']
+ return copy.deepcopy(defaults)
+
+ def get_adaptor(self, format):
+ """
+ get specified format adaptor or
+ guess for a given filename
+ """
+ adaptor = self.adaptors.get(format, None)
+ if adaptor:
+ return adaptor
+
+ # not registered in adaptors dict, let's try all
+ for adaptor in self.adaptors.values():
+ if format in adaptor.extensions:
+ return adaptor
+
+ def filename2format(self, filename):
+ extension = os.path.splitext(filename)[-1]
+ return extension.lstrip('.') or None
+
+ def serialize(self, filename, format=None, full=False):
+ if not format:
+ format = self._format
+ if not format:
+ format = self.filename2format(filename)
+ if not format:
+ raise Exception('Please specify a format')
+ # TODO: more specific exception type
+
+ adaptor = self.get_adaptor(format)
+ if not adaptor:
+ raise Exception("Adaptor not found for format: %s" % format)
+
+ config = copy.deepcopy(self.config)
+ serializable = self.prep_value(config)
+ adaptor.write(serializable, filename)
+
+ if self.mtime:
+ self.touch_mtime(filename)
+
+ def touch_mtime(self, filename):
+ mtime = self.mtime
+ os.utime(filename, (mtime, mtime))
+
+ def deserialize(self, string=None, fromfile=None, format=None):
+ """
+ load configuration from a file or string
+ """
+
+ def _try_deserialize():
+ if fromfile:
+ with open(fromfile, 'r') as f:
+ content = adaptor.read(f)
+ elif string:
+ content = adaptor.read(string)
+ return content
+
+ # XXX cleanup this!
+
+ if fromfile:
+ assert os.path.exists(fromfile)
+ if not format:
+ format = self.filename2format(fromfile)
+
+ if not format:
+ format = self._format
+ if format:
+ adaptor = self.get_adaptor(format)
+ else:
+ adaptor = None
+
+ if adaptor:
+ content = _try_deserialize()
+ return content
+
+ # no adaptor, let's try rest of adaptors
+
+ adaptors = self.adaptors[:]
+
+ if format:
+ adaptors.sort(
+ key=lambda x: int(
+ format in x.extensions),
+ reverse=True)
+
+ for adaptor in adaptors:
+ content = _try_deserialize()
+ return content
+
+ def set_dirty(self):
+ self.dirty = True
+
+ def is_dirty(self):
+ return self.dirty
+
+ def load(self, *args, **kwargs):
+ """
+ load from string or file
+ if no string of fromfile option is given,
+ it will attempt to load from defaults
+ defined in the schema.
+ """
+ string = args[0] if args else None
+ fromfile = kwargs.get("fromfile", None)
+ mtime = kwargs.pop("mtime", None)
+ self.mtime = mtime
+ content = None
+
+ # start with defaults, so we can
+ # have partial values applied.
+ content = self.get_default_values()
+ if string and isinstance(string, str):
+ content = self.deserialize(string)
+
+ if not string and fromfile is not None:
+ #import ipdb;ipdb.set_trace()
+ content = self.deserialize(fromfile=fromfile)
+
+ if not content:
+ logger.error('no content could be loaded')
+ # XXX raise!
+ return
+
+ # lazy evaluation until first level of nesting
+ # to allow lambdas with context-dependant info
+ # like os.path.expanduser
+ for k, v in content.iteritems():
+ if callable(v):
+ content[k] = v()
+
+ self.validate(content)
+ self.config = content
+ return True
+
+
+def testmain(): # pragma: no cover
+
+ from tests import test_validation as t
+ import pprint
+
+ config = PluggableConfig(_format="json")
+ properties = copy.deepcopy(t.sample_spec)
+
+ config.options = properties
+ config.load(fromfile='data.json')
+
+ print 'config'
+ pprint.pprint(config.config)
+
+ config.serialize('/tmp/testserial.json')
+
+if __name__ == "__main__":
+ testmain()
diff --git a/src/leap/base/providers.py b/src/leap/base/providers.py
new file mode 100644
index 00000000..d41f3695
--- /dev/null
+++ b/src/leap/base/providers.py
@@ -0,0 +1,29 @@
+"""all dealing with leap-providers: definition files, updating"""
+from leap.base import config as baseconfig
+from leap.base import specs
+
+
+class LeapProviderDefinition(baseconfig.JSONLeapConfig):
+ spec = specs.leap_provider_spec
+
+ def _get_slug(self):
+ domain = getattr(self, 'domain', None)
+ if domain:
+ path = baseconfig.get_provider_path(domain)
+ else:
+ path = baseconfig.get_default_provider_path()
+
+ return baseconfig.get_config_file(
+ 'provider.json', folder=path)
+
+ def _set_slug(self, *args, **kwargs):
+ raise AttributeError("you cannot set slug")
+
+ slug = property(_get_slug, _set_slug)
+
+
+class LeapProviderSet(object):
+ # we gather them from the filesystem
+ # TODO: (MVS+)
+ def __init__(self):
+ self.count = 0
diff --git a/src/leap/base/specs.py b/src/leap/base/specs.py
new file mode 100644
index 00000000..fbe8a0e9
--- /dev/null
+++ b/src/leap/base/specs.py
@@ -0,0 +1,62 @@
+leap_provider_spec = {
+ 'description': 'provider definition',
+ 'type': 'object',
+ 'properties': {
+ 'version': {
+ 'type': unicode,
+ 'default': '0.1.0'
+ #'required': True
+ },
+ "default_language": {
+ 'type': unicode,
+ 'default': 'en'
+ },
+ 'domain': {
+ 'type': unicode, # XXX define uri type
+ 'default': 'testprovider.example.org'
+ #'required': True,
+ },
+ 'name': {
+ #'type': LEAPTranslatable,
+ 'type': dict,
+ 'format': 'translatable',
+ 'default': {u'en': u'Test Provider'}
+ #'required': True
+ },
+ 'description': {
+ #'type': LEAPTranslatable,
+ 'type': dict,
+ 'format': 'translatable',
+ 'default': {u'en': u'Test provider'}
+ },
+ 'enrollment_policy': {
+ 'type': unicode, # oneof ??
+ 'default': 'open'
+ },
+ 'services': {
+ 'type': list, # oneof ??
+ 'default': ['eip']
+ },
+ 'api_version': {
+ 'type': unicode,
+ 'default': '0.1.0' # version regexp
+ },
+ 'api_uri': {
+ 'type': unicode # uri
+ },
+ 'public_key': {
+ 'type': unicode # fingerprint
+ },
+ 'ca_cert_fingerprint': {
+ 'type': unicode,
+ },
+ 'ca_cert_uri': {
+ 'type': unicode,
+ 'format': 'https-uri'
+ },
+ 'languages': {
+ 'type': list,
+ 'default': ['en']
+ }
+ }
+}
diff --git a/src/leap/base/tests/__init__.py b/src/leap/base/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/base/tests/__init__.py
diff --git a/src/leap/base/tests/test_auth.py b/src/leap/base/tests/test_auth.py
new file mode 100644
index 00000000..b3009a9b
--- /dev/null
+++ b/src/leap/base/tests/test_auth.py
@@ -0,0 +1,58 @@
+from BaseHTTPServer import BaseHTTPRequestHandler
+import urlparse
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+import requests
+#from mock import Mock
+
+from leap.base import auth
+#from leap.base import exceptions
+from leap.eip.tests.test_checks import NoLogRequestHandler
+from leap.testing.basetest import BaseLeapTest
+from leap.testing.https_server import BaseHTTPSServerTestCase
+
+
+class LeapSRPRegisterTests(BaseHTTPSServerTestCase, BaseLeapTest):
+ __name__ = "leap_srp_register_test"
+ provider = "testprovider.example.org"
+
+ class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler):
+ responses = {
+ '/': ['OK', '']}
+
+ def do_GET(self):
+ path = urlparse.urlparse(self.path)
+ message = '\n'.join(self.responses.get(
+ path.path, None))
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(message)
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def test_srp_auth_should_implement_check_methods(self):
+ SERVER = "https://localhost:8443"
+ srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False)
+
+ self.assertTrue(hasattr(srp_auth, "init_session"),
+ "missing meth")
+ self.assertTrue(hasattr(srp_auth, "get_registration_uri"),
+ "missing meth")
+ self.assertTrue(hasattr(srp_auth, "register_user"),
+ "missing meth")
+
+ def test_srp_auth_basic_functionality(self):
+ SERVER = "https://localhost:8443"
+ srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False)
+
+ self.assertIsInstance(srp_auth.session, requests.sessions.Session)
+ self.assertEqual(
+ srp_auth.get_registration_uri(),
+ "https://localhost:8443/1/users")
diff --git a/src/leap/base/tests/test_checks.py b/src/leap/base/tests/test_checks.py
new file mode 100644
index 00000000..8126755b
--- /dev/null
+++ b/src/leap/base/tests/test_checks.py
@@ -0,0 +1,177 @@
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+import os
+import sh
+
+from mock import (patch, Mock)
+from StringIO import StringIO
+
+from leap.base import checks
+from leap.base import exceptions
+from leap.testing.basetest import BaseLeapTest
+
+_uid = os.getuid()
+
+
+class LeapNetworkCheckTest(BaseLeapTest):
+ __name__ = "leap_network_check_tests"
+
+ def setUp(self):
+ os.environ['PATH'] += ':/bin'
+ pass
+
+ def tearDown(self):
+ pass
+
+ def test_checker_should_implement_check_methods(self):
+ checker = checks.LeapNetworkChecker()
+
+ self.assertTrue(hasattr(checker, "check_internet_connection"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "check_tunnel_default_interface"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "is_internet_up"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "ping_gateway"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "parse_log_and_react"),
+ "missing meth")
+
+ def test_checker_should_actually_call_all_tests(self):
+ checker = checks.LeapNetworkChecker()
+ mc = Mock()
+ checker.run_all(checker=mc)
+ self.assertTrue(mc.check_internet_connection.called, "not called")
+ self.assertTrue(mc.check_tunnel_default_interface.called, "not called")
+ self.assertTrue(mc.is_internet_up.called, "not called")
+ self.assertTrue(mc.parse_log_and_react.called, "not called")
+
+ # ping gateway only called if we pass provider_gw
+ checker = checks.LeapNetworkChecker(provider_gw="0.0.0.0")
+ mc = Mock()
+ checker.run_all(checker=mc)
+ self.assertTrue(mc.check_internet_connection.called, "not called")
+ self.assertTrue(mc.check_tunnel_default_interface.called, "not called")
+ self.assertTrue(mc.ping_gateway.called, "not called")
+ self.assertTrue(mc.is_internet_up.called, "not called")
+ self.assertTrue(mc.parse_log_and_react.called, "not called")
+
+ def test_get_default_interface_no_interface(self):
+ checker = checks.LeapNetworkChecker()
+ with patch('leap.base.checks.open', create=True) as mock_open:
+ with self.assertRaises(exceptions.NoDefaultInterfaceFoundError):
+ mock_open.return_value = StringIO(
+ "Iface\tDestination Gateway\t"
+ "Flags\tRefCntd\tUse\tMetric\t"
+ "Mask\tMTU\tWindow\tIRTT")
+ checker.get_default_interface_gateway()
+
+ def test_check_tunnel_default_interface(self):
+ checker = checks.LeapNetworkChecker()
+ with patch('leap.base.checks.open', create=True) as mock_open:
+ with self.assertRaises(exceptions.TunnelNotDefaultRouteError):
+ mock_open.return_value = StringIO(
+ "Iface\tDestination Gateway\t"
+ "Flags\tRefCntd\tUse\tMetric\t"
+ "Mask\tMTU\tWindow\tIRTT\n"
+ "wlan0\t00000000\t0102A8C0\t"
+ "0003\t0\t0\t0\t00000000\t0\t0\t0")
+ checker.check_tunnel_default_interface()
+
+ with patch('leap.base.checks.open', create=True) as mock_open:
+ mock_open.return_value = StringIO(
+ "Iface\tDestination Gateway\t"
+ "Flags\tRefCntd\tUse\tMetric\t"
+ "Mask\tMTU\tWindow\tIRTT\n"
+ "tun0\t00000000\t01002A0A\t0003\t0\t0\t0\t00000080\t0\t0\t0")
+ checker.check_tunnel_default_interface()
+
+ def test_ping_gateway_fail(self):
+ checker = checks.LeapNetworkChecker()
+ with patch.object(sh, "ping") as mocked_ping:
+ with self.assertRaises(exceptions.NoConnectionToGateway):
+ mocked_ping.return_value = Mock
+ mocked_ping.return_value.stdout = "11% packet loss"
+ checker.ping_gateway("4.2.2.2")
+
+ def test_ping_gateway(self):
+ checker = checks.LeapNetworkChecker()
+ with patch.object(sh, "ping") as mocked_ping:
+ mocked_ping.return_value = Mock
+ mocked_ping.return_value.stdout = """
+PING 4.2.2.2 (4.2.2.2) 56(84) bytes of data.
+64 bytes from 4.2.2.2: icmp_req=1 ttl=54 time=33.8 ms
+64 bytes from 4.2.2.2: icmp_req=2 ttl=54 time=30.6 ms
+64 bytes from 4.2.2.2: icmp_req=3 ttl=54 time=31.4 ms
+64 bytes from 4.2.2.2: icmp_req=4 ttl=54 time=36.1 ms
+64 bytes from 4.2.2.2: icmp_req=5 ttl=54 time=30.8 ms
+64 bytes from 4.2.2.2: icmp_req=6 ttl=54 time=30.4 ms
+64 bytes from 4.2.2.2: icmp_req=7 ttl=54 time=30.7 ms
+64 bytes from 4.2.2.2: icmp_req=8 ttl=54 time=32.7 ms
+64 bytes from 4.2.2.2: icmp_req=9 ttl=54 time=31.4 ms
+64 bytes from 4.2.2.2: icmp_req=10 ttl=54 time=33.3 ms
+
+--- 4.2.2.2 ping statistics ---
+10 packets transmitted, 10 received, 0% packet loss, time 9016ms
+rtt min/avg/max/mdev = 30.497/32.172/36.161/1.755 ms"""
+ checker.ping_gateway("4.2.2.2")
+
+ def test_check_internet_connection_failures(self):
+ checker = checks.LeapNetworkChecker()
+ TimeoutError = get_ping_timeout_error()
+ with patch.object(sh, "ping") as mocked_ping:
+ mocked_ping.side_effect = TimeoutError
+ with self.assertRaises(exceptions.NoInternetConnection):
+ with patch.object(checker, "ping_gateway") as mock_gateway:
+ mock_gateway.side_effect = exceptions.NoConnectionToGateway
+ checker.check_internet_connection()
+
+ with patch.object(sh, "ping") as mocked_ping:
+ mocked_ping.side_effect = TimeoutError
+ with self.assertRaises(exceptions.NoInternetConnection):
+ with patch.object(checker, "ping_gateway") as mock_gateway:
+ mock_gateway.return_value = True
+ checker.check_internet_connection()
+
+ def test_parse_log_and_react(self):
+ checker = checks.LeapNetworkChecker()
+ to_call = Mock()
+ log = [("leap.openvpn - INFO - Mon Nov 19 13:36:24 2012 "
+ "read UDPv4 [ECONNREFUSED]: Connection refused (code=111)")]
+ err_matrix = [(checks.EVENT_CONNECT_REFUSED, (to_call, ))]
+ checker.parse_log_and_react(log, err_matrix)
+ self.assertTrue(to_call.called)
+
+ log = [("2012-11-19 13:36:26,177 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:24 2012 ERROR: Linux route delete command "
+ "failed: external program exited"),
+ ("2012-11-19 13:36:26,178 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:24 2012 ERROR: Linux route delete command "
+ "failed: external program exited"),
+ ("2012-11-19 13:36:26,180 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:24 2012 ERROR: Linux route delete command "
+ "failed: external program exited"),
+ ("2012-11-19 13:36:26,181 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:24 2012 /sbin/ifconfig tun0 0.0.0.0"),
+ ("2012-11-19 13:36:26,182 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:24 2012 Linux ip addr del failed: external "
+ "program exited with error stat"),
+ ("2012-11-19 13:36:26,183 - leap.openvpn - INFO - "
+ "Mon Nov 19 13:36:26 2012 SIGTERM[hard,] received, process"
+ "exiting"), ]
+ to_call.reset_mock()
+ checker.parse_log_and_react(log, err_matrix)
+ self.assertFalse(to_call.called)
+
+ to_call.reset_mock()
+ checker.parse_log_and_react([], err_matrix)
+ self.assertFalse(to_call.called)
+
+
+def get_ping_timeout_error():
+ try:
+ sh.ping("-c", "1", "-w", "1", "8.8.7.7")
+ except Exception as e:
+ return e
diff --git a/src/leap/base/tests/test_config.py b/src/leap/base/tests/test_config.py
new file mode 100644
index 00000000..d03149b2
--- /dev/null
+++ b/src/leap/base/tests/test_config.py
@@ -0,0 +1,247 @@
+import json
+import os
+import platform
+import socket
+#import tempfile
+
+import mock
+import requests
+
+from leap.base import config
+from leap.base import constants
+from leap.base import exceptions
+from leap.eip import constants as eipconstants
+from leap.util.fileutil import mkdir_p
+from leap.testing.basetest import BaseLeapTest
+
+
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+_system = platform.system()
+
+
+class JSONLeapConfigTest(BaseLeapTest):
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def test_metaclass(self):
+ with self.assertRaises(exceptions.ImproperlyConfigured) as exc:
+ class DummyTestConfig(config.JSONLeapConfig):
+ __metaclass__ = config.MetaConfigWithSpec
+ exc.startswith("missing spec dict")
+
+ class DummyTestConfig(config.JSONLeapConfig):
+ __metaclass__ = config.MetaConfigWithSpec
+ spec = {'properties': {}}
+ with self.assertRaises(exceptions.ImproperlyConfigured) as exc:
+ DummyTestConfig()
+ exc.startswith("missing slug")
+
+ class DummyTestConfig(config.JSONLeapConfig):
+ __metaclass__ = config.MetaConfigWithSpec
+ spec = {'properties': {}}
+ slug = "foo"
+ DummyTestConfig()
+
+######################################3
+#
+# provider fetch tests block
+#
+
+
+class ProviderTest(BaseLeapTest):
+ # override per test fixtures
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+
+# XXX depreacated. similar test in eip.checks
+
+#class BareHomeTestCase(ProviderTest):
+#
+ #__name__ = "provider_config_tests_bare_home"
+#
+ #def test_should_raise_if_missing_eip_json(self):
+ #with self.assertRaises(exceptions.MissingConfigFileError):
+ #config.get_config_json(os.path.join(self.home, 'eip.json'))
+
+
+class ProviderDefinitionTestCase(ProviderTest):
+ # XXX MOVE TO eip.test_checks
+ # -- kali 2012-08-24 00:38
+
+ __name__ = "provider_config_tests"
+
+ def setUp(self):
+ # dump a sample eip file
+ # XXX Move to Use EIP Spec Instead!!!
+ # XXX tests to be moved to eip.checks and eip.providers
+ # XXX can use eipconfig.dump_default_eipconfig
+
+ path = os.path.join(self.home, '.config', 'leap')
+ mkdir_p(path)
+ with open(os.path.join(path, 'eip.json'), 'w') as fp:
+ json.dump(eipconstants.EIP_SAMPLE_JSON, fp)
+
+
+# these tests below should move to
+# eip.checks
+# config.Configuration has been deprecated
+
+# TODO:
+# - We're instantiating a ProviderTest because we're doing the home wipeoff
+# on setUpClass instead of the setUp (for speedup of the general cases).
+
+# We really should be testing all of them in the same testCase, and
+# doing an extra wipe of the tempdir... but be careful!!!! do not mess with
+# os.environ home more than needed... that could potentially bite!
+
+# XXX actually, another thing to fix here is separating tests:
+# - test that requests has been called.
+# - check deeper for error types/msgs
+
+# we SHOULD inject requests dep in the constructor
+# (so we can pass mock easily).
+
+
+#class ProviderFetchConError(ProviderTest):
+ #def test_connection_error(self):
+ #with mock.patch.object(requests, "get") as mock_method:
+ #mock_method.side_effect = requests.ConnectionError
+ #cf = config.Configuration()
+ #self.assertIsInstance(cf.error, str)
+#
+#
+#class ProviderFetchHttpError(ProviderTest):
+ #def test_file_not_found(self):
+ #with mock.patch.object(requests, "get") as mock_method:
+ #mock_method.side_effect = requests.HTTPError
+ #cf = config.Configuration()
+ #self.assertIsInstance(cf.error, str)
+#
+#
+#class ProviderFetchInvalidUrl(ProviderTest):
+ #def test_invalid_url(self):
+ #cf = config.Configuration("ht")
+ #self.assertTrue(cf.error)
+
+
+# end provider fetch tests
+###########################################
+
+
+class ConfigHelperFunctions(BaseLeapTest):
+
+ __name__ = "config_helper_tests"
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ # tests
+
+ @unittest.skipUnless(_system == "Linux", "linux only")
+ def test_lin_get_config_file(self):
+ """
+ config file path where expected? (linux)
+ """
+ self.assertEqual(
+ config.get_config_file(
+ 'test', folder="foo/bar"),
+ os.path.expanduser(
+ '~/.config/leap/foo/bar/test')
+ )
+
+ @unittest.skipUnless(_system == "Darwin", "mac only")
+ def test_mac_get_config_file(self):
+ """
+ config file path where expected? (mac)
+ """
+ self._missing_test_for_plat(do_raise=True)
+
+ @unittest.skipUnless(_system == "Windows", "win only")
+ def test_win_get_config_file(self):
+ """
+ config file path where expected?
+ """
+ self._missing_test_for_plat(do_raise=True)
+
+ #
+ # XXX hey, I'm raising exceptions here
+ # on purpose. just wanted to make sure
+ # that the skip stuff is doing it right.
+ # If you're working on win/macos tests,
+ # feel free to remove tests that you see
+ # are too redundant.
+
+ @unittest.skipUnless(_system == "Linux", "linux only")
+ def test_lin_get_config_dir(self):
+ """
+ nice config dir? (linux)
+ """
+ self.assertEqual(
+ config.get_config_dir(),
+ os.path.expanduser('~/.config/leap'))
+
+ @unittest.skipUnless(_system == "Darwin", "mac only")
+ def test_mac_get_config_dir(self):
+ """
+ nice config dir? (mac)
+ """
+ self._missing_test_for_plat(do_raise=True)
+
+ @unittest.skipUnless(_system == "Windows", "win only")
+ def test_win_get_config_dir(self):
+ """
+ nice config dir? (win)
+ """
+ self._missing_test_for_plat(do_raise=True)
+
+ # provider paths
+
+ @unittest.skipUnless(_system == "Linux", "linux only")
+ def test_get_default_provider_path(self):
+ """
+ is default provider path ok?
+ """
+ self.assertEqual(
+ config.get_default_provider_path(),
+ os.path.expanduser(
+ '~/.config/leap/providers/%s/' %
+ constants.DEFAULT_PROVIDER)
+ )
+
+ # validate ip
+
+ def test_validate_ip(self):
+ """
+ check our ip validation
+ """
+ config.validate_ip('3.3.3.3')
+ with self.assertRaises(socket.error):
+ config.validate_ip('255.255.255.256')
+ with self.assertRaises(socket.error):
+ config.validate_ip('foobar')
+
+ @unittest.skip
+ def test_validate_domain(self):
+ """
+ code to be written yet
+ """
+ raise NotImplementedError
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/base/tests/test_providers.py b/src/leap/base/tests/test_providers.py
new file mode 100644
index 00000000..92bc1f2f
--- /dev/null
+++ b/src/leap/base/tests/test_providers.py
@@ -0,0 +1,148 @@
+import copy
+import json
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+import os
+
+from leap.base.pluggableconfig import ValidationError
+from leap.testing.basetest import BaseLeapTest
+from leap.base import providers
+
+
+EXPECTED_DEFAULT_CONFIG = {
+ u"api_version": u"0.1.0",
+ #u"description": "LEAPTranslatable<{u'en': u'Test provider'}>",
+ u"description": {u'en': u'Test provider'},
+ u"default_language": u"en",
+ #u"display_name": {u'en': u"Test Provider"},
+ u"domain": u"testprovider.example.org",
+ #u'name': "LEAPTranslatable<{u'en': u'Test Provider'}>",
+ u'name': {u'en': u'Test Provider'},
+ u"enrollment_policy": u"open",
+ #u"serial": 1,
+ u"services": [
+ u"eip"
+ ],
+ u"languages": [u"en"],
+ u"version": u"0.1.0"
+}
+
+
+class TestLeapProviderDefinition(BaseLeapTest):
+ def setUp(self):
+ self.domain = "testprovider.example.org"
+ self.definition = providers.LeapProviderDefinition(
+ domain=self.domain)
+ self.definition.save(force=True)
+ self.definition.load() # why have to load after save??
+ self.config = self.definition.config
+
+ def tearDown(self):
+ if hasattr(self, 'testfile') and os.path.isfile(self.testfile):
+ os.remove(self.testfile)
+
+ # tests
+
+ # XXX most of these tests can be made more abstract
+ # and moved to test_baseconfig *triangulate!*
+
+ def test_provider_slug_property(self):
+ slug = self.definition.slug
+ self.assertEquals(
+ slug,
+ os.path.join(
+ self.home,
+ '.config', 'leap', 'providers',
+ '%s' % self.domain,
+ 'provider.json'))
+ with self.assertRaises(AttributeError):
+ self.definition.slug = 23
+
+ def test_provider_dump(self):
+ # check a good provider definition is dumped to disk
+ self.testfile = self.get_tempfile('test.json')
+ self.definition.save(to=self.testfile, force=True)
+ deserialized = json.load(open(self.testfile, 'rb'))
+ self.maxDiff = None
+ #import ipdb;ipdb.set_trace()
+ self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG)
+
+ def test_provider_dump_to_slug(self):
+ # same as above, but we test the ability to save to a
+ # file generated from the slug.
+ # XXX THIS TEST SHOULD MOVE TO test_baseconfig
+ self.definition.save()
+ filename = self.definition.filename
+ self.assertTrue(os.path.isfile(filename))
+ deserialized = json.load(open(filename, 'rb'))
+ self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG)
+
+ def test_provider_load(self):
+ # check loading provider from disk file
+ self.testfile = self.get_tempfile('test_load.json')
+ with open(self.testfile, 'w') as wf:
+ wf.write(json.dumps(EXPECTED_DEFAULT_CONFIG))
+ self.definition.load(fromfile=self.testfile)
+ #self.assertDictEqual(self.config,
+ #EXPECTED_DEFAULT_CONFIG)
+ self.assertItemsEqual(self.config, EXPECTED_DEFAULT_CONFIG)
+
+ def test_provider_validation(self):
+ self.definition.validate(self.config)
+ _config = copy.deepcopy(self.config)
+ # bad type, raise validation error
+ _config['domain'] = 111
+ with self.assertRaises(ValidationError):
+ self.definition.validate(_config)
+
+ @unittest.skip
+ def test_load_malformed_json_definition(self):
+ raise NotImplementedError
+
+ @unittest.skip
+ def test_type_validation(self):
+ # check various type validation
+ # type cast
+ raise NotImplementedError
+
+
+class TestLeapProviderSet(BaseLeapTest):
+
+ def setUp(self):
+ self.providers = providers.LeapProviderSet()
+
+ def tearDown(self):
+ pass
+ ###
+
+ def test_get_zero_count(self):
+ self.assertEqual(self.providers.count, 0)
+
+ @unittest.skip
+ def test_count_defined_providers(self):
+ # check the method used for making
+ # the list of providers
+ raise NotImplementedError
+
+ @unittest.skip
+ def test_get_default_provider(self):
+ raise NotImplementedError
+
+ @unittest.skip
+ def test_should_be_at_least_one_provider_after_init(self):
+ # when we init an empty environment,
+ # there should be at least one provider,
+ # that will be a dump of the default provider definition
+ # somehow a high level test
+ raise NotImplementedError
+
+ @unittest.skip
+ def test_get_eip_remote_from_default_provider(self):
+ # from: default provider
+ # expect: remote eip domain
+ raise NotImplementedError
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/leap/base/tests/test_validation.py b/src/leap/base/tests/test_validation.py
new file mode 100644
index 00000000..b45fbe3a
--- /dev/null
+++ b/src/leap/base/tests/test_validation.py
@@ -0,0 +1,93 @@
+import copy
+import datetime
+from functools import partial
+#import json
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+import os
+
+from leap.base.config import JSONLeapConfig
+from leap.base import pluggableconfig
+from leap.testing.basetest import BaseLeapTest
+
+SAMPLE_CONFIG_DICT = {
+ 'prop_one': 1,
+ 'prop_uri': "http://example.org",
+ 'prop_date': '2012-12-12',
+}
+
+EXPECTED_CONFIG = {
+ 'prop_one': 1,
+ 'prop_uri': "http://example.org",
+ 'prop_date': datetime.datetime(2012, 12, 12)
+}
+
+sample_spec = {
+ 'description': 'sample schema definition',
+ 'type': 'object',
+ 'properties': {
+ 'prop_one': {
+ 'type': int,
+ 'default': 1,
+ 'required': True
+ },
+ 'prop_uri': {
+ 'type': str,
+ 'default': 'http://example.org',
+ 'required': True,
+ 'format': 'uri'
+ },
+ 'prop_date': {
+ 'type': str,
+ 'default': '2012-12-12',
+ 'format': 'date'
+ }
+ }
+}
+
+
+class SampleConfig(JSONLeapConfig):
+ spec = sample_spec
+
+ @property
+ def slug(self):
+ return os.path.expanduser('~/sampleconfig.json')
+
+
+class TestJSONLeapConfigValidation(BaseLeapTest):
+ def setUp(self):
+ self.sampleconfig = SampleConfig()
+ self.sampleconfig.save()
+ self.sampleconfig.load()
+ self.config = self.sampleconfig.config
+
+ def tearDown(self):
+ if hasattr(self, 'testfile') and os.path.isfile(self.testfile):
+ os.remove(self.testfile)
+
+ # tests
+
+ def test_good_validation(self):
+ self.sampleconfig.validate(SAMPLE_CONFIG_DICT)
+
+ def test_broken_int(self):
+ _config = copy.deepcopy(SAMPLE_CONFIG_DICT)
+ _config['prop_one'] = '1'
+ self.assertRaises(
+ pluggableconfig.ValidationError,
+ partial(self.sampleconfig.validate, _config))
+
+ def test_format_property(self):
+ # JsonSchema Validator does not check the format property.
+ # We should have to extend the Configuration class
+ blah = copy.deepcopy(SAMPLE_CONFIG_DICT)
+ blah['prop_uri'] = 'xxx'
+ self.assertRaises(
+ pluggableconfig.TypeCastException,
+ partial(self.sampleconfig.validate, blah))
+
+
+if __name__ == "__main__":
+ unittest.main()