diff options
Diffstat (limited to 'src/leap/base')
-rw-r--r-- | src/leap/base/__init__.py | 0 | ||||
-rw-r--r-- | src/leap/base/auth.py | 355 | ||||
-rw-r--r-- | src/leap/base/authentication.py | 11 | ||||
-rw-r--r-- | src/leap/base/checks.py | 213 | ||||
-rw-r--r-- | src/leap/base/config.py | 348 | ||||
-rw-r--r-- | src/leap/base/connection.py | 115 | ||||
-rw-r--r-- | src/leap/base/constants.py | 42 | ||||
-rw-r--r-- | src/leap/base/exceptions.py | 97 | ||||
-rw-r--r-- | src/leap/base/jsonschema.py | 791 | ||||
-rw-r--r-- | src/leap/base/network.py | 107 | ||||
-rw-r--r-- | src/leap/base/pluggableconfig.py | 462 | ||||
-rw-r--r-- | src/leap/base/providers.py | 29 | ||||
-rw-r--r-- | src/leap/base/specs.py | 62 | ||||
-rw-r--r-- | src/leap/base/tests/__init__.py | 0 | ||||
-rw-r--r-- | src/leap/base/tests/test_auth.py | 58 | ||||
-rw-r--r-- | src/leap/base/tests/test_checks.py | 177 | ||||
-rw-r--r-- | src/leap/base/tests/test_config.py | 247 | ||||
-rw-r--r-- | src/leap/base/tests/test_providers.py | 148 | ||||
-rw-r--r-- | src/leap/base/tests/test_validation.py | 93 |
19 files changed, 3355 insertions, 0 deletions
diff --git a/src/leap/base/__init__.py b/src/leap/base/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/base/__init__.py diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py new file mode 100644 index 00000000..c2d3f424 --- /dev/null +++ b/src/leap/base/auth.py @@ -0,0 +1,355 @@ +import binascii +import json +import logging +#import urlparse + +import requests +import srp + +from PyQt4 import QtCore + +from leap.base import constants as baseconstants +from leap.crypto import leapkeyring +from leap.util.misc import null_check +from leap.util.web import get_https_domain_and_port + +logger = logging.getLogger(__name__) + +SIGNUP_TIMEOUT = getattr(baseconstants, 'SIGNUP_TIMEOUT', 5) + +""" +Registration and authentication classes for the +SRP auth mechanism used in the leap platform. + +We're using the srp library which uses a c-based implementation +of the protocol if the c extension is available, and a python-based +one if not. +""" + + +class SRPAuthenticationError(Exception): + """ + exception raised + for authentication errors + """ + + +safe_unhexlify = lambda x: binascii.unhexlify(x) \ + if (len(x) % 2 == 0) else binascii.unhexlify('0' + x) + + +class LeapSRPRegister(object): + + def __init__(self, + schema="https", + provider=None, + verify=True, + register_path="1/users", + method="POST", + fetcher=requests, + srp=srp, + hashfun=srp.SHA256, + ng_constant=srp.NG_1024): + + null_check(provider, "provider") + + self.schema = schema + + domain, port = get_https_domain_and_port(provider) + self.provider = domain + self.port = port + + self.verify = verify + self.register_path = register_path + self.method = method + self.fetcher = fetcher + self.srp = srp + self.HASHFUN = hashfun + self.NG = ng_constant + + self.init_session() + + def init_session(self): + self.session = self.fetcher.session() + + def get_registration_uri(self): + # XXX assert is https! + # use urlparse + if self.port: + uri = "%s://%s:%s/%s" % ( + self.schema, + self.provider, + self.port, + self.register_path) + else: + uri = "%s://%s/%s" % ( + self.schema, + self.provider, + self.register_path) + + return uri + + def register_user(self, username, password, keep=False): + """ + @rtype: tuple + @rparam: (ok, request) + """ + salt, vkey = self.srp.create_salted_verification_key( + username, + password, + self.HASHFUN, + self.NG) + + user_data = { + 'user[login]': username, + 'user[password_verifier]': binascii.hexlify(vkey), + 'user[password_salt]': binascii.hexlify(salt)} + + uri = self.get_registration_uri() + logger.debug('post to uri: %s' % uri) + + # XXX get self.method + req = self.session.post( + uri, data=user_data, + timeout=SIGNUP_TIMEOUT, + verify=self.verify) + # we catch it in the form + #req.raise_for_status() + return (req.ok, req) + + +class SRPAuth(requests.auth.AuthBase): + + def __init__(self, username, password, server=None, verify=None): + # sanity check + null_check(server, 'server') + self.username = username + self.password = password + self.server = server + self.verify = verify + + logger.debug('SRPAuth. verify=%s' % verify) + logger.debug('server: %s. username=%s' % (server, username)) + + self.init_data = None + self.session = requests.session() + + self.init_srp() + + def init_srp(self): + usr = srp.User( + self.username, + self.password, + srp.SHA256, + srp.NG_1024) + uname, A = usr.start_authentication() + + self.srp_usr = usr + self.A = A + + def get_auth_data(self): + return { + 'login': self.username, + 'A': binascii.hexlify(self.A) + } + + def get_init_data(self): + try: + init_session = self.session.post( + self.server + '/1/sessions/', + data=self.get_auth_data(), + verify=self.verify) + except requests.exceptions.ConnectionError: + raise SRPAuthenticationError( + "No connection made (salt).") + except: + raise SRPAuthenticationError( + "Unknown error (salt).") + if init_session.status_code not in (200, ): + raise SRPAuthenticationError( + "No valid response (salt).") + + self.init_data = init_session.json + return self.init_data + + def get_server_proof_data(self): + try: + auth_result = self.session.put( + #self.server + '/1/sessions.json/' + self.username, + self.server + '/1/sessions/' + self.username, + data={'client_auth': binascii.hexlify(self.M)}, + verify=self.verify) + except requests.exceptions.ConnectionError: + raise SRPAuthenticationError( + "No connection made (HAMK).") + + if auth_result.status_code not in (200, ): + raise SRPAuthenticationError( + "No valid response (HAMK).") + + self.auth_data = auth_result.json + return self.auth_data + + def authenticate(self): + logger.debug('start authentication...') + + init_data = self.get_init_data() + salt = init_data.get('salt', None) + B = init_data.get('B', None) + + # XXX refactor this function + # move checks and un-hex + # to routines + + if not salt or not B: + raise SRPAuthenticationError( + "Server did not send initial data.") + + try: + unhex_salt = safe_unhexlify(salt) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (salt)") + try: + unhex_B = safe_unhexlify(B) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (B)") + + self.M = self.srp_usr.process_challenge( + unhex_salt, + unhex_B + ) + + proof_data = self.get_server_proof_data() + + HAMK = proof_data.get("M2", None) + if not HAMK: + errors = proof_data.get('errors', None) + if errors: + logger.error(errors) + raise SRPAuthenticationError("Server did not send HAMK.") + + try: + unhex_HAMK = safe_unhexlify(HAMK) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (HAMK)") + + self.srp_usr.verify_session( + unhex_HAMK) + + try: + assert self.srp_usr.authenticated() + logger.debug('user is authenticated!') + except (AssertionError): + raise SRPAuthenticationError( + "Auth verification failed.") + + def __call__(self, req): + self.authenticate() + req.cookies = self.session.cookies + return req + + +def srpauth_protected(user=None, passwd=None, server=None, verify=True): + """ + decorator factory that accepts + user and password keyword arguments + and add those to the decorated request + """ + def srpauth(fn): + def wrapper(*args, **kwargs): + if user and passwd: + auth = SRPAuth(user, passwd, server, verify) + kwargs['auth'] = auth + kwargs['verify'] = verify + if not args: + logger.warning('attempting to get from empty uri!') + return fn(*args, **kwargs) + return wrapper + return srpauth + + +def get_leap_credentials(): + settings = QtCore.QSettings() + full_username = settings.value('username') + username, domain = full_username.split('@') + seed = settings.value('%s_seed' % domain, None) + password = leapkeyring.leap_get_password(full_username, seed=seed) + return (username, password) + + +# XXX TODO +# Pass verify as single argument, +# in srpauth_protected style + +def magick_srpauth(fn): + """ + decorator that gets user and password + from the config file and adds those to + the decorated request + """ + logger.debug('magick srp auth decorator called') + + def wrapper(*args, **kwargs): + #uri = args[0] + # XXX Ugh! + # Problem with this approach. + # This won't work when we're using + # api.foo.bar + # Unless we keep a table with the + # equivalencies... + user, passwd = get_leap_credentials() + + # XXX pass verify and server too + # (pop) + auth = SRPAuth(user, passwd) + kwargs['auth'] = auth + return fn(*args, **kwargs) + return wrapper + + +if __name__ == "__main__": + """ + To test against test_provider (twisted version) + Register an user: (will be valid during the session) + >>> python auth.py add test password + + Test login with that user: + >>> python auth.py login test password + """ + + import sys + + if len(sys.argv) not in (4, 5): + print 'Usage: auth <add|login> <user> <pass> [server]' + sys.exit(0) + + action = sys.argv[1] + user = sys.argv[2] + passwd = sys.argv[3] + + if len(sys.argv) == 5: + SERVER = sys.argv[4] + else: + SERVER = "https://localhost:8443" + + if action == "login": + + @srpauth_protected( + user=user, passwd=passwd, server=SERVER, verify=False) + def test_srp_protected_get(*args, **kwargs): + req = requests.get(*args, **kwargs) + req.raise_for_status + return req + + #req = test_srp_protected_get('https://localhost:8443/1/cert') + req = test_srp_protected_get('%s/1/cert' % SERVER) + #print 'cert :', req.content[:200] + "..." + print req.content + sys.exit(0) + + if action == "add": + auth = LeapSRPRegister(provider=SERVER, verify=False) + auth.register_user(user, passwd) diff --git a/src/leap/base/authentication.py b/src/leap/base/authentication.py new file mode 100644 index 00000000..09ff1d07 --- /dev/null +++ b/src/leap/base/authentication.py @@ -0,0 +1,11 @@ +""" +Authentication Base Class +""" + + +class Authentication(object): + """ + I have no idea how Authentication (certs,?) + will be done, but stub it here. + """ + pass diff --git a/src/leap/base/checks.py b/src/leap/base/checks.py new file mode 100644 index 00000000..0bf44f59 --- /dev/null +++ b/src/leap/base/checks.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +import logging +import platform +import re +import socket + +import netifaces +import sh + +from leap.base import constants +from leap.base import exceptions + +logger = logging.getLogger(name=__name__) +_platform = platform.system() + +#EVENTS OF NOTE +EVENT_CONNECT_REFUSED = "[ECONNREFUSED]: Connection refused (code=111)" + +ICMP_TARGET = "8.8.8.8" + + +class LeapNetworkChecker(object): + """ + all network related checks + """ + def __init__(self, *args, **kwargs): + provider_gw = kwargs.pop('provider_gw', None) + self.provider_gateway = provider_gw + + def run_all(self, checker=None): + if not checker: + checker = self + #self.error = None # ? + + # for MVS + checker.check_tunnel_default_interface() + checker.check_internet_connection() + checker.is_internet_up() + + if self.provider_gateway: + checker.ping_gateway(self.provider_gateway) + + checker.parse_log_and_react([], ()) + + def check_internet_connection(self): + if _platform == "Linux": + try: + output = sh.ping("-c", "5", "-w", "5", ICMP_TARGET) + # XXX should redirect this to netcheck logger. + # and don't clutter main log. + logger.debug('Network appears to be up.') + except sh.ErrorReturnCode_1 as e: + packet_loss = re.findall("\d+% packet loss", e.message)[0] + logger.debug("Unidentified Connection Error: " + packet_loss) + if not self.is_internet_up(): + error = "No valid internet connection found." + else: + error = "Provider server appears to be down." + + logger.error(error) + raise exceptions.NoInternetConnection(error) + + else: + raise NotImplementedError + + def is_internet_up(self): + iface, gateway = self.get_default_interface_gateway() + try: + self.ping_gateway(self.provider_gateway) + except exceptions.NoConnectionToGateway: + return False + return True + + def _get_route_table_linux(self): + # do not use context manager, tests pass a StringIO + f = open("/proc/net/route") + route_table = f.readlines() + f.close() + #toss out header + route_table.pop(0) + if not route_table: + raise exceptions.NoDefaultInterfaceFoundError + return route_table + + def _get_def_iface_osx(self): + default_iface = None + #gateway = None + routes = list(sh.route('-n', 'get', ICMP_TARGET, _iter=True)) + iface = filter(lambda l: "interface" in l, routes) + if not iface: + return None, None + def_ifacel = re.findall('\w+\d', iface[0]) + default_iface = def_ifacel[0] if def_ifacel else None + if not default_iface: + return None, None + _gw = filter(lambda l: "gateway" in l, routes) + gw = re.findall('\d+\.\d+\.\d+\.\d+', _gw[0])[0] + return default_iface, gw + + def _get_tunnel_iface_linux(self): + # XXX review. + # valid also when local router has a default entry? + route_table = self._get_route_table_linux() + line = route_table.pop(0) + iface, destination = line.split('\t')[0:2] + if not destination == '00000000' or not iface == 'tun0': + raise exceptions.TunnelNotDefaultRouteError() + return True + + def check_tunnel_default_interface(self): + """ + Raises an TunnelNotDefaultRouteError + if tun0 is not the chosen default route + (including when no routes are present) + """ + #logger.debug('checking tunnel default interface...') + + if _platform == "Linux": + valid = self._get_tunnel_iface_linux() + return valid + elif _platform == "Darwin": + default_iface, gw = self._get_def_iface_osx() + #logger.debug('iface: %s', default_iface) + if default_iface != "tun0": + logger.debug('tunnel not default route! gw: %s', default_iface) + # XXX should catch this and act accordingly... + # but rather, this test should only be launched + # when we have successfully completed a connection + # ... TRIGGER: Connection stablished (or whatever it is) + # in the logs + raise exceptions.TunnelNotDefaultRouteError + else: + #logger.debug('PLATFORM !!! %s', _platform) + raise NotImplementedError + + def _get_def_iface_linux(self): + default_iface = None + gateway = None + + route_table = self._get_route_table_linux() + while route_table: + line = route_table.pop(0) + iface, destination, gateway = line.split('\t')[0:3] + if destination == '00000000': + default_iface = iface + break + return default_iface, gateway + + def get_default_interface_gateway(self): + """ + gets the interface we are going thru. + (this should be merged with check tunnel default interface, + imo...) + """ + if _platform == "Linux": + default_iface, gw = self._get_def_iface_linux() + elif _platform == "Darwin": + default_iface, gw = self._get_def_iface_osx() + else: + raise NotImplementedError + + if not default_iface: + raise exceptions.NoDefaultInterfaceFoundError + + if default_iface not in netifaces.interfaces(): + raise exceptions.InterfaceNotFoundError + logger.debug('-- default iface %s', default_iface) + return default_iface, gw + + def ping_gateway(self, gateway): + # TODO: Discuss how much packet loss (%) is acceptable. + + # XXX -- validate gateway + # -- is it a valid ip? (there's something in util) + # -- is it a domain? + # -- can we resolve? -- raise NoDNSError if not. + + # XXX -- sh.ping implemtation needs review! + try: + output = sh.ping("-c", "10", gateway).stdout + except sh.ErrorReturnCode_1 as e: + output = e.message + finally: + packet_loss = int(re.findall("(\d+)% packet loss", output)[0]) + + logger.debug('packet loss %s%%' % packet_loss) + if packet_loss > constants.MAX_ICMP_PACKET_LOSS: + raise exceptions.NoConnectionToGateway + + def check_name_resolution(self, domain_name): + try: + socket.gethostbyname(domain_name) + return True + except socket.gaierror: + raise exceptions.CannotResolveDomainError + + def parse_log_and_react(self, log, error_matrix=None): + """ + compares the recent openvpn status log to + strings passed in and executes the callbacks passed in. + @param log: openvpn log + @type log: list of strings + @param error_matrix: tuples of strings and tuples of callbacks + @type error_matrix: tuples strings and call backs + """ + for line in log: + # we could compile a regex here to save some cycles up -- kali + for each in error_matrix: + error, callbacks = each + if error in line: + for cb in callbacks: + if callable(cb): + cb() diff --git a/src/leap/base/config.py b/src/leap/base/config.py new file mode 100644 index 00000000..85bb3d66 --- /dev/null +++ b/src/leap/base/config.py @@ -0,0 +1,348 @@ +""" +Configuration Base Class +""" +import grp +import json +import logging +import re +import socket +import time +import os + +logger = logging.getLogger(name=__name__) + +from dateutil import parser as dateparser +from xdg import BaseDirectory +import requests + +from leap.base import exceptions +from leap.base import constants +from leap.base.pluggableconfig import PluggableConfig +from leap.util.fileutil import (mkdir_p) + +# move to base! +from leap.eip import exceptions as eipexceptions + + +class BaseLeapConfig(object): + slug = None + + # XXX we have to enforce that every derived class + # has a slug (via interface) + # get property getter that raises NI.. + + def save(self): + raise NotImplementedError("abstract base class") + + def load(self): + raise NotImplementedError("abstract base class") + + def get_config(self, *kwargs): + raise NotImplementedError("abstract base class") + + @property + def config(self): + return self.get_config() + + def get_value(self, *kwargs): + raise NotImplementedError("abstract base class") + + +class MetaConfigWithSpec(type): + """ + metaclass for JSONLeapConfig classes. + It creates a configuration spec out of + the `spec` dictionary. The `properties` attribute + of the spec dict is turn into the `schema` attribute + of the new class (which will be used to validate against). + """ + # XXX in the near future, this is the + # place where we want to enforce + # singletons, read-only and similar stuff. + + def __new__(meta, classname, bases, classDict): + schema_obj = classDict.get('spec', None) + + # not quite happy with this workaround. + # I want to raise if missing spec dict, but only + # for grand-children of this metaclass. + # maybe should use abc module for this. + abcderived = ("JSONLeapConfig",) + if schema_obj is None and classname not in abcderived: + raise exceptions.ImproperlyConfigured( + "missing spec dict on your derived class (%s)" % classname) + + # we create a configuration spec attribute + # from the spec dict + config_class = type( + classname + "Spec", + (PluggableConfig, object), + {'options': schema_obj}) + classDict['spec'] = config_class + + return type.__new__(meta, classname, bases, classDict) + +########################################################## +# some hacking still in progress: + +# Configs have: + +# - a slug (from where a filename/folder is derived) +# - a spec (for validation and defaults). +# this spec is conformant to the json-schema. +# basically a dict that will be used +# for type casting and validation, and defaults settings. + +# all config objects, since they are derived from BaseConfig, implement basic +# useful methods: +# - save +# - load + +########################################################## + + +class JSONLeapConfig(BaseLeapConfig): + + __metaclass__ = MetaConfigWithSpec + + def __init__(self, *args, **kwargs): + # sanity check + try: + assert self.slug is not None + except AssertionError: + raise exceptions.ImproperlyConfigured( + "missing slug on JSONLeapConfig" + " derived class") + try: + assert self.spec is not None + except AssertionError: + raise exceptions.ImproperlyConfigured( + "missing spec on JSONLeapConfig" + " derived class") + assert issubclass(self.spec, PluggableConfig) + + self.domain = kwargs.pop('domain', None) + self._config = self.spec(format="json") + self._config.load() + self.fetcher = kwargs.pop('fetcher', requests) + + # mandatory baseconfig interface + + def save(self, to=None, force=False): + """ + force param will skip the dirty check. + :type force: bool + """ + # XXX this force=True does not feel to right + # but still have to look for a better way + # of dealing with dirtiness and the + # trick of loading remote config only + # when newer. + + if force: + do_save = True + else: + do_save = self._config.is_dirty() + + if do_save: + if to is None: + to = self.filename + folder, filename = os.path.split(to) + if folder and not os.path.isdir(folder): + mkdir_p(folder) + self._config.serialize(to) + return True + + else: + return False + + def load(self, fromfile=None, from_uri=None, fetcher=None, + force_download=False, verify=True): + + if from_uri is not None: + fetched = self.fetch( + from_uri, + fetcher=fetcher, + verify=verify, + force_dl=force_download) + if fetched: + return + if fromfile is None: + fromfile = self.filename + if os.path.isfile(fromfile): + self._config.load(fromfile=fromfile) + else: + logger.warning('tried to load config from non-existent path') + logger.warning('Not Found: %s', fromfile) + + def fetch(self, uri, fetcher=None, verify=True, force_dl=False): + if not fetcher: + fetcher = self.fetcher + + logger.debug('uri: %s (verify: %s)' % (uri, verify)) + + rargs = (uri, ) + rkwargs = {'verify': verify} + headers = {} + + curmtime = self.get_mtime() if not force_dl else None + if curmtime: + logger.debug('requesting with if-modified-since %s' % curmtime) + headers['if-modified-since'] = curmtime + rkwargs['headers'] = headers + + #request = fetcher.get(uri, verify=verify) + request = fetcher.get(*rargs, **rkwargs) + request.raise_for_status() + + if request.status_code == 304: + logger.debug('...304 Not Changed') + # On this point, we have to assume that + # we HAD the filename. If that filename is corruct, + # we should enforce a force_download in the load + # method above. + self._config.load(fromfile=self.filename) + return True + + if request.json: + mtime = None + last_modified = request.headers.get('last-modified', None) + if last_modified: + _mtime = dateparser.parse(last_modified) + mtime = int(_mtime.strftime("%s")) + if callable(request.json): + _json = request.json() + else: + # back-compat + _json = request.json + self._config.load(json.dumps(_json), mtime=mtime) + self._config.set_dirty() + else: + # not request.json + # might be server did not announce content properly, + # let's try deserializing all the same. + try: + self._config.load(request.content) + self._config.set_dirty() + except ValueError: + raise eipexceptions.LeapBadConfigFetchedError + + return True + + def get_mtime(self): + try: + _mtime = os.stat(self.filename)[8] + mtime = time.strftime("%c GMT", time.gmtime(_mtime)) + return mtime + except OSError: + return None + + def get_config(self): + return self._config.config + + # public methods + + def get_filename(self): + return self._slug_to_filename() + + @property + def filename(self): + return self.get_filename() + + def validate(self, data): + logger.debug('validating schema') + self._config.validate(data) + return True + + # private + + def _slug_to_filename(self): + # is this going to work in winland if slug is "foo/bar" ? + folder, filename = os.path.split(self.slug) + config_file = get_config_file(filename, folder) + return config_file + + def exists(self): + return os.path.isfile(self.filename) + + +# +# utility functions +# +# (might be moved to some class as we see fit, but +# let's remain functional for a while) +# maybe base.config.util ?? +# + + +def get_config_dir(): + """ + get the base dir for all leap config + @rparam: config path + @rtype: string + """ + home = os.path.expanduser("~") + if re.findall("leap_tests-[_a-zA-Z0-9]{6}", home): + # we're inside a test! :) + return os.path.join(home, ".config/leap") + else: + # XXX dirspec is cross-platform, + # we should borrow some of those + # routines for osx/win and wrap this call. + return os.path.join(BaseDirectory.xdg_config_home, + 'leap') + + +def get_config_file(filename, folder=None): + """ + concatenates the given filename + with leap config dir. + @param filename: name of the file + @type filename: string + @rparam: full path to config file + """ + path = [] + path.append(get_config_dir()) + if folder is not None: + path.append(folder) + path.append(filename) + return os.path.join(*path) + + +def get_default_provider_path(): + default_subpath = os.path.join("providers", + constants.DEFAULT_PROVIDER) + default_provider_path = get_config_file( + '', + folder=default_subpath) + return default_provider_path + + +def get_provider_path(domain): + # XXX if not domain, return get_default_provider_path + default_subpath = os.path.join("providers", domain) + provider_path = get_config_file( + '', + folder=default_subpath) + return provider_path + + +def validate_ip(ip_str): + """ + raises exception if the ip_str is + not a valid representation of an ip + """ + socket.inet_aton(ip_str) + + +def get_username(): + try: + return os.getlogin() + except OSError as e: + import pwd + return pwd.getpwuid(os.getuid())[0] + + +def get_groupname(): + gid = os.getgroups()[-1] + return grp.getgrgid(gid).gr_name diff --git a/src/leap/base/connection.py b/src/leap/base/connection.py new file mode 100644 index 00000000..41d13935 --- /dev/null +++ b/src/leap/base/connection.py @@ -0,0 +1,115 @@ +""" +Base Connection Classs +""" +from __future__ import (division, unicode_literals, print_function) + +import logging + +from leap.base.authentication import Authentication + +logger = logging.getLogger(name=__name__) + + +class Connection(Authentication): + # JSONLeapConfig + #spec = {} + + def __init__(self, *args, **kwargs): + self.connection_state = None + self.desired_connection_state = None + #XXX FIXME diamond inheritance gotcha.. + #If you inherit from >1 class, + #super is only initializing one + #of the bases..!! + # I think we better pass config as a constructor + # parameter -- kali 2012-08-30 04:33 + super(Connection, self).__init__(*args, **kwargs) + + def connect(self): + """ + entry point for connection process + """ + pass + + def disconnect(self): + """ + disconnects client + """ + pass + + #def shutdown(self): + #""" + #shutdown and quit + #""" + #self.desired_con_state = self.status.DISCONNECTED + + def connection_state(self): + """ + returns the current connection state + """ + return self.status.current + + def desired_connection_state(self): + """ + returns the desired_connection state + """ + return self.desired_connection_state + + def get_icon_name(self): + """ + get icon name from status object + """ + return self.status.get_state_icon() + + # + # private methods + # + + def _disconnect(self): + """ + private method for disconnecting + """ + if self.subp is not None: + self.subp.terminate() + self.subp = None + # XXX signal state changes! :) + + def _is_alive(self): + """ + don't know yet + """ + pass + + def _connect(self): + """ + entry point for connection cascade methods. + """ + #conn_result = ConState.DISCONNECTED + try: + conn_result = self._try_connection() + except UnrecoverableError as except_msg: + logger.error("FATAL: %s" % unicode(except_msg)) + conn_result = self.status.UNRECOVERABLE + except Exception as except_msg: + self.error_queue.append(except_msg) + logger.error("Failed Connection: %s" % + unicode(except_msg)) + return conn_result + + +class ConnectionError(Exception): + """ + generic connection error + """ + def __str__(self): + if len(self.args) >= 1: + return repr(self.args[0]) + else: + raise self() + + +class UnrecoverableError(ConnectionError): + """ + we cannot do anything about it, sorry + """ + pass diff --git a/src/leap/base/constants.py b/src/leap/base/constants.py new file mode 100644 index 00000000..f5665e5f --- /dev/null +++ b/src/leap/base/constants.py @@ -0,0 +1,42 @@ +"""constants to be used in base module""" +from leap import __branding +APP_NAME = __branding.get("short_name", "leap-client") +OPENVPN_BIN = "openvpn" + +# default provider placeholder +# using `example.org` we make sure that this +# is not going to be resolved during the tests phases +# (we expect testers to add it to their /etc/hosts + +DEFAULT_PROVIDER = __branding.get( + "provider_domain", + "testprovider.example.org") + +DEFINITION_EXPECTED_PATH = "provider.json" + +DEFAULT_PROVIDER_DEFINITION = { + u"api_uri": "https://api.%s/" % DEFAULT_PROVIDER, + u"api_version": u"1", + u"ca_cert_fingerprint": "SHA256: fff", + u"ca_cert_uri": u"https://%s/ca.crt" % DEFAULT_PROVIDER, + u"default_language": u"en", + u"description": { + u"en": u"A demonstration service provider using the LEAP platform" + }, + u"domain": "%s" % DEFAULT_PROVIDER, + u"enrollment_policy": u"open", + u"languages": [ + u"en" + ], + u"name": { + u"en": u"Test Provider" + }, + u"services": [ + "openvpn" + ] +} + + +MAX_ICMP_PACKET_LOSS = 10 + +ROUTE_CHECK_INTERVAL = 10 diff --git a/src/leap/base/exceptions.py b/src/leap/base/exceptions.py new file mode 100644 index 00000000..2e31b33b --- /dev/null +++ b/src/leap/base/exceptions.py @@ -0,0 +1,97 @@ +""" +Exception attributes and their meaning/uses +------------------------------------------- + +* critical: if True, will abort execution prematurely, + after attempting any cleaning + action. + +* failfirst: breaks any error_check loop that is examining + the error queue. + +* message: the message that will be used in the __repr__ of the exception. + +* usermessage: the message that will be passed to user in ErrorDialogs + in Qt-land. +""" +from leap.util.translations import translate + + +class LeapException(Exception): + """ + base LeapClient exception + sets some parameters that we will check + during error checking routines + """ + + critical = False + failfirst = False + warning = False + + +class CriticalError(LeapException): + """ + we cannot do anything about it + """ + critical = True + failfirst = True + + +# In use ??? +# don't thing so. purge if not... + +class MissingConfigFileError(Exception): + pass + + +class ImproperlyConfigured(Exception): + pass + + +# NOTE: "Errors" (context) has to be a explicit string! + + +class InterfaceNotFoundError(LeapException): + # XXX should take iface arg on init maybe? + message = "interface not found" + usermessage = translate( + "Errors", + "Interface not found") + + +class NoDefaultInterfaceFoundError(LeapException): + message = "no default interface found" + usermessage = translate( + "Errors", + "Looks like your computer " + "is not connected to the internet") + + +class NoConnectionToGateway(CriticalError): + message = "no connection to gateway" + usermessage = translate( + "Errors", + "Looks like there are problems " + "with your internet connection") + + +class NoInternetConnection(CriticalError): + message = "No Internet connection found" + usermessage = translate( + "Errors", + "It looks like there is no internet connection.") + # and now we try to connect to our web to troubleshoot LOL :P + + +class CannotResolveDomainError(LeapException): + message = "Cannot resolve domain" + usermessage = translate( + "Errors", + "Domain cannot be found") + + +class TunnelNotDefaultRouteError(LeapException): + message = "Tunnel connection dissapeared. VPN down?" + usermessage = translate( + "Errors", + "The Encrypted Connection was lost.") diff --git a/src/leap/base/jsonschema.py b/src/leap/base/jsonschema.py new file mode 100644 index 00000000..56689b08 --- /dev/null +++ b/src/leap/base/jsonschema.py @@ -0,0 +1,791 @@ +""" +An implementation of JSON Schema for Python + +The main functionality is provided by the validator classes for each of the +supported JSON Schema versions. + +Most commonly, :func:`validate` is the quickest way to simply validate a given +instance under a schema, and will create a validator for you. + +""" + +from __future__ import division, unicode_literals + +import collections +import json +import itertools +import operator +import re +import sys + + +__version__ = "0.8.0" + +PY3 = sys.version_info[0] >= 3 + +if PY3: + from urllib import parse as urlparse + from urllib.parse import unquote + from urllib.request import urlopen + basestring = unicode = str + iteritems = operator.methodcaller("items") +else: + from itertools import izip as zip + from urllib import unquote + from urllib2 import urlopen + import urlparse + iteritems = operator.methodcaller("iteritems") + + +FLOAT_TOLERANCE = 10 ** -15 +validators = {} + + +def validates(version): + """ + Register the decorated validator for a ``version`` of the specification. + + Registered validators and their meta schemas will be considered when + parsing ``$schema`` properties' URIs. + + :argument str version: an identifier to use as the version's name + :returns: a class decorator to decorate the validator with the version + + """ + + def _validates(cls): + validators[version] = cls + return cls + return _validates + + +class UnknownType(Exception): + """ + An attempt was made to check if an instance was of an unknown type. + + """ + + +class RefResolutionError(Exception): + """ + A JSON reference failed to resolve. + + """ + + +class SchemaError(Exception): + """ + The provided schema is malformed. + + The same attributes are present as for :exc:`ValidationError`\s. + + """ + + def __init__(self, message, validator=None, path=()): + super(SchemaError, self).__init__(message, validator, path) + self.message = message + self.path = list(path) + self.validator = validator + + def __str__(self): + return self.message + + +class ValidationError(Exception): + """ + The instance didn't properly validate under the provided schema. + + Relevant attributes are: + * ``message`` : a human readable message explaining the error + * ``path`` : a list containing the path to the offending element (or [] + if the error happened globally) in *reverse* order (i.e. + deepest index first). + + """ + + def __init__(self, message, validator=None, path=()): + # Any validator that recurses (e.g. properties and items) must append + # to the ValidationError's path to properly maintain where in the + # instance the error occurred + super(ValidationError, self).__init__(message, validator, path) + self.message = message + self.path = list(path) + self.validator = validator + + def __str__(self): + return self.message + + +@validates("draft3") +class Draft3Validator(object): + """ + A validator for JSON Schema draft 3. + + """ + + DEFAULT_TYPES = { + "array": list, "boolean": bool, "integer": int, "null": type(None), + "number": (int, float), "object": dict, "string": basestring, + } + + def __init__(self, schema, types=(), resolver=None): + self._types = dict(self.DEFAULT_TYPES) + self._types.update(types) + + if resolver is None: + resolver = RefResolver.from_schema(schema) + + self.resolver = resolver + self.schema = schema + + def is_type(self, instance, type): + if type == "any": + return True + elif type not in self._types: + raise UnknownType(type) + type = self._types[type] + + # bool inherits from int, so ensure bools aren't reported as integers + if isinstance(instance, bool): + type = _flatten(type) + if int in type and bool not in type: + return False + return isinstance(instance, type) + + def is_valid(self, instance, _schema=None): + error = next(self.iter_errors(instance, _schema), None) + return error is None + + @classmethod + def check_schema(cls, schema): + for error in cls(cls.META_SCHEMA).iter_errors(schema): + raise SchemaError( + error.message, validator=error.validator, path=error.path, + ) + + def iter_errors(self, instance, _schema=None): + if _schema is None: + _schema = self.schema + + for k, v in iteritems(_schema): + validator = getattr(self, "validate_%s" % (k.lstrip("$"),), None) + + if validator is None: + continue + + errors = validator(v, instance, _schema) or () + for error in errors: + # set the validator if it wasn't already set by the called fn + if error.validator is None: + error.validator = k + yield error + + def validate(self, *args, **kwargs): + for error in self.iter_errors(*args, **kwargs): + raise error + + def validate_type(self, types, instance, schema): + types = _list(types) + + for type in types: + if self.is_type(type, "object"): + if self.is_valid(instance, type): + return + elif self.is_type(type, "string"): + if self.is_type(instance, type): + return + else: + yield ValidationError(_types_msg(instance, types)) + + def validate_properties(self, properties, instance, schema): + if not self.is_type(instance, "object"): + return + + for property, subschema in iteritems(properties): + if property in instance: + for error in self.iter_errors(instance[property], subschema): + error.path.append(property) + yield error + elif subschema.get("required", False): + yield ValidationError( + "%r is a required property" % (property,), + validator="required", + path=[property], + ) + + def validate_patternProperties(self, patternProperties, instance, schema): + if not self.is_type(instance, "object"): + return + + for pattern, subschema in iteritems(patternProperties): + for k, v in iteritems(instance): + if re.match(pattern, k): + for error in self.iter_errors(v, subschema): + yield error + + def validate_additionalProperties(self, aP, instance, schema): + if not self.is_type(instance, "object"): + return + + extras = set(_find_additional_properties(instance, schema)) + + if self.is_type(aP, "object"): + for extra in extras: + for error in self.iter_errors(instance[extra], aP): + yield error + elif not aP and extras: + error = "Additional properties are not allowed (%s %s unexpected)" + yield ValidationError(error % _extras_msg(extras)) + + def validate_dependencies(self, dependencies, instance, schema): + if not self.is_type(instance, "object"): + return + + for property, dependency in iteritems(dependencies): + if property not in instance: + continue + + if self.is_type(dependency, "object"): + for error in self.iter_errors(instance, dependency): + yield error + else: + dependencies = _list(dependency) + for dependency in dependencies: + if dependency not in instance: + yield ValidationError( + "%r is a dependency of %r" % (dependency, property) + ) + + def validate_items(self, items, instance, schema): + if not self.is_type(instance, "array"): + return + + if self.is_type(items, "object"): + for index, item in enumerate(instance): + for error in self.iter_errors(item, items): + error.path.append(index) + yield error + else: + for (index, item), subschema in zip(enumerate(instance), items): + for error in self.iter_errors(item, subschema): + error.path.append(index) + yield error + + def validate_additionalItems(self, aI, instance, schema): + if ( + not self.is_type(instance, "array") or + not self.is_type(schema.get("items"), "array") + ): + return + + if self.is_type(aI, "object"): + for item in instance[len(schema):]: + for error in self.iter_errors(item, aI): + yield error + elif not aI and len(instance) > len(schema.get("items", [])): + error = "Additional items are not allowed (%s %s unexpected)" + yield ValidationError( + error % _extras_msg(instance[len(schema.get("items", [])):]) + ) + + def validate_minimum(self, minimum, instance, schema): + if not self.is_type(instance, "number"): + return + + instance = float(instance) + if schema.get("exclusiveMinimum", False): + failed = instance <= minimum + cmp = "less than or equal to" + else: + failed = instance < minimum + cmp = "less than" + + if failed: + yield ValidationError( + "%r is %s the minimum of %r" % (instance, cmp, minimum) + ) + + def validate_maximum(self, maximum, instance, schema): + if not self.is_type(instance, "number"): + return + + instance = float(instance) + if schema.get("exclusiveMaximum", False): + failed = instance >= maximum + cmp = "greater than or equal to" + else: + failed = instance > maximum + cmp = "greater than" + + if failed: + yield ValidationError( + "%r is %s the maximum of %r" % (instance, cmp, maximum) + ) + + def validate_minItems(self, mI, instance, schema): + if self.is_type(instance, "array") and len(instance) < mI: + yield ValidationError("%r is too short" % (instance,)) + + def validate_maxItems(self, mI, instance, schema): + if self.is_type(instance, "array") and len(instance) > mI: + yield ValidationError("%r is too long" % (instance,)) + + def validate_uniqueItems(self, uI, instance, schema): + if uI and self.is_type(instance, "array") and not _uniq(instance): + yield ValidationError("%r has non-unique elements" % instance) + + def validate_pattern(self, patrn, instance, schema): + if self.is_type(instance, "string") and not re.match(patrn, instance): + yield ValidationError("%r does not match %r" % (instance, patrn)) + + def validate_minLength(self, mL, instance, schema): + if self.is_type(instance, "string") and len(instance) < mL: + yield ValidationError("%r is too short" % (instance,)) + + def validate_maxLength(self, mL, instance, schema): + if self.is_type(instance, "string") and len(instance) > mL: + yield ValidationError("%r is too long" % (instance,)) + + def validate_enum(self, enums, instance, schema): + if instance not in enums: + yield ValidationError("%r is not one of %r" % (instance, enums)) + + def validate_divisibleBy(self, dB, instance, schema): + if not self.is_type(instance, "number"): + return + + if isinstance(dB, float): + mod = instance % dB + failed = (mod > FLOAT_TOLERANCE) and (dB - mod) > FLOAT_TOLERANCE + else: + failed = instance % dB + + if failed: + yield ValidationError("%r is not divisible by %r" % (instance, dB)) + + def validate_disallow(self, disallow, instance, schema): + for disallowed in _list(disallow): + if self.is_valid(instance, {"type": [disallowed]}): + yield ValidationError( + "%r is disallowed for %r" % (disallowed, instance) + ) + + def validate_extends(self, extends, instance, schema): + if self.is_type(extends, "object"): + extends = [extends] + for subschema in extends: + for error in self.iter_errors(instance, subschema): + yield error + + def validate_ref(self, ref, instance, schema): + resolved = self.resolver.resolve(ref) + for error in self.iter_errors(instance, resolved): + yield error + + +Draft3Validator.META_SCHEMA = { + "$schema": "http://json-schema.org/draft-03/schema#", + "id": "http://json-schema.org/draft-03/schema#", + "type": "object", + + "properties": { + "type": { + "type": ["string", "array"], + "items": {"type": ["string", {"$ref": "#"}]}, + "uniqueItems": True, + "default": "any" + }, + "properties": { + "type": "object", + "additionalProperties": {"$ref": "#", "type": "object"}, + "default": {} + }, + "patternProperties": { + "type": "object", + "additionalProperties": {"$ref": "#"}, + "default": {} + }, + "additionalProperties": { + "type": [{"$ref": "#"}, "boolean"], "default": {} + }, + "items": { + "type": [{"$ref": "#"}, "array"], + "items": {"$ref": "#"}, + "default": {} + }, + "additionalItems": { + "type": [{"$ref": "#"}, "boolean"], "default": {} + }, + "required": {"type": "boolean", "default": False}, + "dependencies": { + "type": ["string", "array", "object"], + "additionalProperties": { + "type": ["string", "array", {"$ref": "#"}], + "items": {"type": "string"} + }, + "default": {} + }, + "minimum": {"type": "number"}, + "maximum": {"type": "number"}, + "exclusiveMinimum": {"type": "boolean", "default": False}, + "exclusiveMaximum": {"type": "boolean", "default": False}, + "minItems": {"type": "integer", "minimum": 0, "default": 0}, + "maxItems": {"type": "integer", "minimum": 0}, + "uniqueItems": {"type": "boolean", "default": False}, + "pattern": {"type": "string", "format": "regex"}, + "minLength": {"type": "integer", "minimum": 0, "default": 0}, + "maxLength": {"type": "integer"}, + "enum": {"type": "array", "minItems": 1, "uniqueItems": True}, + "default": {"type": "any"}, + "title": {"type": "string"}, + "description": {"type": "string"}, + "format": {"type": "string"}, + "maxDecimal": {"type": "number", "minimum": 0}, + "divisibleBy": { + "type": "number", + "minimum": 0, + "exclusiveMinimum": True, + "default": 1 + }, + "disallow": { + "type": ["string", "array"], + "items": {"type": ["string", {"$ref": "#"}]}, + "uniqueItems": True + }, + "extends": { + "type": [{"$ref": "#"}, "array"], + "items": {"$ref": "#"}, + "default": {} + }, + "id": {"type": "string", "format": "uri"}, + "$ref": {"type": "string", "format": "uri"}, + "$schema": {"type": "string", "format": "uri"}, + }, + "dependencies": { + "exclusiveMinimum": "minimum", "exclusiveMaximum": "maximum" + }, +} + + +class RefResolver(object): + """ + Resolve JSON References. + + :argument str base_uri: URI of the referring document + :argument referrer: the actual referring document + :argument dict store: a mapping from URIs to documents to cache + + """ + + def __init__(self, base_uri, referrer, store=()): + self.base_uri = base_uri + self.referrer = referrer + self.store = dict(store, **_meta_schemas()) + + @classmethod + def from_schema(cls, schema, *args, **kwargs): + """ + Construct a resolver from a JSON schema object. + + :argument schema schema: the referring schema + :rtype: :class:`RefResolver` + + """ + + return cls(schema.get("id", ""), schema, *args, **kwargs) + + def resolve(self, ref): + """ + Resolve a JSON ``ref``. + + :argument str ref: reference to resolve + :returns: the referrant document + + """ + + base_uri = self.base_uri + uri, fragment = urlparse.urldefrag(urlparse.urljoin(base_uri, ref)) + + if uri in self.store: + document = self.store[uri] + elif not uri or uri == self.base_uri: + document = self.referrer + else: + document = self.resolve_remote(uri) + + return self.resolve_fragment(document, fragment.lstrip("/")) + + def resolve_fragment(self, document, fragment): + """ + Resolve a ``fragment`` within the referenced ``document``. + + :argument document: the referrant document + :argument str fragment: a URI fragment to resolve within it + + """ + + parts = unquote(fragment).split("/") if fragment else [] + + for part in parts: + part = part.replace("~1", "/").replace("~0", "~") + + if part not in document: + raise RefResolutionError( + "Unresolvable JSON pointer: %r" % fragment + ) + + document = document[part] + + return document + + def resolve_remote(self, uri): + """ + Resolve a remote ``uri``. + + Does not check the store first. + + :argument str uri: the URI to resolve + :returns: the retrieved document + + """ + + return json.load(urlopen(uri)) + + +class ErrorTree(object): + """ + ErrorTrees make it easier to check which validations failed. + + """ + + def __init__(self, errors=()): + self.errors = {} + self._contents = collections.defaultdict(self.__class__) + + for error in errors: + container = self + for element in reversed(error.path): + container = container[element] + container.errors[error.validator] = error + + def __contains__(self, k): + return k in self._contents + + def __getitem__(self, k): + """ + Retrieve the child tree with key ``k``. + + """ + + return self._contents[k] + + def __setitem__(self, k, v): + self._contents[k] = v + + def __iter__(self): + return iter(self._contents) + + def __len__(self): + return self.total_errors + + def __repr__(self): + return "<%s (%s total errors)>" % (self.__class__.__name__, len(self)) + + @property + def total_errors(self): + """ + The total number of errors in the entire tree, including children. + + """ + + child_errors = sum(len(tree) for _, tree in iteritems(self._contents)) + return len(self.errors) + child_errors + + +def _meta_schemas(): + """ + Collect the urls and meta schemas from each known validator. + + """ + + meta_schemas = (v.META_SCHEMA for v in validators.values()) + return dict((urlparse.urldefrag(m["id"])[0], m) for m in meta_schemas) + + +def _find_additional_properties(instance, schema): + """ + Return the set of additional properties for the given ``instance``. + + Weeds out properties that should have been validated by ``properties`` and + / or ``patternProperties``. + + Assumes ``instance`` is dict-like already. + + """ + + properties = schema.get("properties", {}) + patterns = "|".join(schema.get("patternProperties", {})) + for property in instance: + if property not in properties: + if patterns and re.search(patterns, property): + continue + yield property + + +def _extras_msg(extras): + """ + Create an error message for extra items or properties. + + """ + + if len(extras) == 1: + verb = "was" + else: + verb = "were" + return ", ".join(repr(extra) for extra in extras), verb + + +def _types_msg(instance, types): + """ + Create an error message for a failure to match the given types. + + If the ``instance`` is an object and contains a ``name`` property, it will + be considered to be a description of that object and used as its type. + + Otherwise the message is simply the reprs of the given ``types``. + + """ + + reprs = [] + for type in types: + try: + reprs.append(repr(type["name"])) + except Exception: + reprs.append(repr(type)) + return "%r is not of type %s" % (instance, ", ".join(reprs)) + + +def _flatten(suitable_for_isinstance): + """ + isinstance() can accept a bunch of really annoying different types: + * a single type + * a tuple of types + * an arbitrary nested tree of tuples + + Return a flattened tuple of the given argument. + + """ + + types = set() + + if not isinstance(suitable_for_isinstance, tuple): + suitable_for_isinstance = (suitable_for_isinstance,) + for thing in suitable_for_isinstance: + if isinstance(thing, tuple): + types.update(_flatten(thing)) + else: + types.add(thing) + return tuple(types) + + +def _list(thing): + """ + Wrap ``thing`` in a list if it's a single str. + + Otherwise, return it unchanged. + + """ + + if isinstance(thing, basestring): + return [thing] + return thing + + +def _delist(thing): + """ + Unwrap ``thing`` to a single element if its a single str in a list. + + Otherwise, return it unchanged. + + """ + + if ( + isinstance(thing, list) and + len(thing) == 1 + and isinstance(thing[0], basestring) + ): + return thing[0] + return thing + + +def _unbool(element, true=object(), false=object()): + """ + A hack to make True and 1 and False and 0 unique for _uniq. + + """ + + if element is True: + return true + elif element is False: + return false + return element + + +def _uniq(container): + """ + Check if all of a container's elements are unique. + + Successively tries first to rely that the elements are hashable, then + falls back on them being sortable, and finally falls back on brute + force. + + """ + + try: + return len(set(_unbool(i) for i in container)) == len(container) + except TypeError: + try: + sort = sorted(_unbool(i) for i in container) + sliced = itertools.islice(sort, 1, None) + for i, j in zip(sort, sliced): + if i == j: + return False + except (NotImplementedError, TypeError): + seen = [] + for e in container: + e = _unbool(e) + if e in seen: + return False + seen.append(e) + return True + + +def validate(instance, schema, cls=Draft3Validator, *args, **kwargs): + """ + Validate an ``instance`` under the given ``schema``. + + >>> validate([2, 3, 4], {"maxItems" : 2}) + Traceback (most recent call last): + ... + ValidationError: [2, 3, 4] is too long + + :func:`validate` will first verify that the provided schema is itself + valid, since not doing so can lead to less obvious error messages and fail + in less obvious or consistent ways. If you know you have a valid schema + already or don't care, you might prefer using the ``validate`` method + directly on a specific validator (e.g. :meth:`Draft3Validator.validate`). + + ``cls`` is a validator class that will be used to validate the instance. + By default this is a draft 3 validator. Any other provided positional and + keyword arguments will be provided to this class when constructing a + validator. + + :raises: + :exc:`ValidationError` if the instance is invalid + + :exc:`SchemaError` if the schema itself is invalid + + """ + + cls.check_schema(schema) + cls(schema, *args, **kwargs).validate(instance) diff --git a/src/leap/base/network.py b/src/leap/base/network.py new file mode 100644 index 00000000..d841e692 --- /dev/null +++ b/src/leap/base/network.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +from __future__ import (print_function) +import logging +import threading + +from leap.eip import config as eipconfig +from leap.base.checks import LeapNetworkChecker +from leap.base.constants import ROUTE_CHECK_INTERVAL +from leap.base.exceptions import TunnelNotDefaultRouteError +from leap.util.misc import null_check +from leap.util.coroutines import (launch_thread, process_events) + +from time import sleep + +logger = logging.getLogger(name=__name__) + + +class NetworkCheckerThread(object): + """ + Manages network checking thread that makes sure we have a working network + connection. + """ + def __init__(self, *args, **kwargs): + + self.status_signals = kwargs.pop('status_signals', None) + self.error_cb = kwargs.pop( + 'error_cb', + lambda exc: logger.error("%s", exc.message)) + self.shutdown = threading.Event() + + # XXX get provider passed here + provider = kwargs.pop('provider', None) + null_check(provider, 'provider') + + eipconf = eipconfig.EIPConfig(domain=provider) + eipconf.load() + eipserviceconf = eipconfig.EIPServiceConfig(domain=provider) + eipserviceconf.load() + + gw = eipconfig.get_eip_gateway( + eipconfig=eipconf, + eipserviceconfig=eipserviceconf) + self.checker = LeapNetworkChecker( + provider_gw=gw) + + def start(self): + self.process_handle = self._launch_recurrent_network_checks( + (self.error_cb,)) + + def stop(self): + self.process_handle.join(timeout=0.1) + self.shutdown.set() + logger.debug("network checked stopped.") + + def run_checks(self): + pass + + #private methods + + #here all the observers in fail_callbacks expect one positional argument, + #which is exception so we can try by passing a lambda with logger to + #check it works. + + def _network_checks_thread(self, fail_callbacks): + #TODO: replace this with waiting for a signal from openvpn + while True: + try: + self.checker.check_tunnel_default_interface() + break + except TunnelNotDefaultRouteError: + # XXX ??? why do we sleep here??? + # aa: If the openvpn isn't up and running yet, + # let's give it a moment to breath. + #logger.error('NOT DEFAULT ROUTE!----') + # Instead of this, we should flag when the + # iface IS SUPPOSED to be up imo. -- kali + sleep(1) + + fail_observer_dict = dict((( + observer, + process_events(observer)) for observer in fail_callbacks)) + + while not self.shutdown.is_set(): + try: + self.checker.check_tunnel_default_interface() + self.checker.check_internet_connection() + sleep(ROUTE_CHECK_INTERVAL) + except Exception as exc: + for obs in fail_observer_dict: + fail_observer_dict[obs].send(exc) + sleep(ROUTE_CHECK_INTERVAL) + + #reset event + # I see a problem with this. You cannot stop it, it + # resets itself forever. -- kali + + # XXX use QTimer for the recurrent triggers, + # and ditch the sleeps. + logger.debug('resetting event') + self.shutdown.clear() + + def _launch_recurrent_network_checks(self, fail_callbacks): + # XXX reimplement using QTimer -- kali + watcher = launch_thread( + self._network_checks_thread, + (fail_callbacks,)) + return watcher diff --git a/src/leap/base/pluggableconfig.py b/src/leap/base/pluggableconfig.py new file mode 100644 index 00000000..6f9f3f6f --- /dev/null +++ b/src/leap/base/pluggableconfig.py @@ -0,0 +1,462 @@ +""" +generic configuration handlers +""" +import copy +import json +import logging +import os +import time +import urlparse + +import jsonschema + +from leap.util.translations import LEAPTranslatable + +logger = logging.getLogger(__name__) + + +__all__ = ['PluggableConfig', + 'adaptors', + 'types', + 'UnknownOptionException', + 'MissingValueException', + 'ConfigurationProviderException', + 'TypeCastException'] + +# exceptions + + +class ValidationError(Exception): + pass + + +class UnknownOptionException(Exception): + """exception raised when a non-configuration + value is present in the configuration""" + + +class MissingValueException(Exception): + """exception raised when a required value is missing""" + + +class ConfigurationProviderException(Exception): + """exception raised when a configuration provider is missing, etc""" + + +class TypeCastException(Exception): + """exception raised when a + configuration item cannot be coerced to a type""" + + +class ConfigAdaptor(object): + """ + abstract base class for config adaotors for + serialization/deserialization and custom validation + and type casting. + """ + def read(self, filename): + raise NotImplementedError("abstract base class") + + def write(self, config, filename): + with open(filename, 'w') as f: + self._write(f, config) + + def _write(self, fp, config): + raise NotImplementedError("abstract base class") + + def validate(self, config, schema): + raise NotImplementedError("abstract base class") + + +adaptors = {} + + +class JSONSchemaEncoder(json.JSONEncoder): + """ + custom default encoder that + casts python objects to json objects for + the schema validation + """ + def default(self, obj): + if obj is str: + return 'string' + if obj is unicode: + return 'string' + if obj is int: + return 'integer' + if obj is list: + return 'array' + if obj is dict: + return 'object' + if obj is bool: + return 'boolean' + + +class JSONAdaptor(ConfigAdaptor): + indent = 2 + extensions = ['json'] + + def read(self, _from): + if isinstance(_from, file): + _from_string = _from.read() + if isinstance(_from, str): + _from_string = _from + return json.loads(_from_string) + + def _write(self, fp, config): + fp.write(json.dumps(config, + indent=self.indent, + sort_keys=True)) + + def validate(self, config, schema_obj): + schema_json = JSONSchemaEncoder().encode(schema_obj) + schema = json.loads(schema_json) + try: + jsonschema.validate(config, schema) + except jsonschema.ValidationError: + raise ValidationError + + +adaptors['json'] = JSONAdaptor() + +# +# Adaptors +# +# Allow to apply a predefined set of types to the +# specs, so it checks the validity of formats and cast it +# to proper python types. + +# TODO: +# - HTTPS uri + + +class DateType(object): + fmt = '%Y-%m-%d' + + def to_python(self, data): + return time.strptime(data, self.fmt) + + def get_prep_value(self, data): + return time.strftime(self.fmt, data) + + +class TranslatableType(object): + """ + a type that casts to LEAPTranslatable objects. + Used for labels we get from providers and stuff. + """ + + def to_python(self, data): + return LEAPTranslatable(data) + + # needed? we already have an extended dict... + #def get_prep_value(self, data): + #return dict(data) + + +class URIType(object): + + def to_python(self, data): + parsed = urlparse.urlparse(data) + if not parsed.scheme: + raise TypeCastException("uri %s has no schema" % data) + return parsed + + def get_prep_value(self, data): + return data.geturl() + + +class HTTPSURIType(object): + + def to_python(self, data): + parsed = urlparse.urlparse(data) + if not parsed.scheme: + raise TypeCastException("uri %s has no schema" % data) + if parsed.scheme != "https": + raise TypeCastException( + "uri %s does not has " + "https schema" % data) + return parsed + + def get_prep_value(self, data): + return data.geturl() + + +types = { + 'date': DateType(), + 'uri': URIType(), + 'https-uri': HTTPSURIType(), + 'translatable': TranslatableType(), +} + + +class PluggableConfig(object): + + options = {} + + def __init__(self, + adaptors=adaptors, + types=types, + format=None): + + self.config = {} + self.adaptors = adaptors + self.types = types + self._format = format + self.mtime = None + self.dirty = False + + @property + def option_dict(self): + if hasattr(self, 'options') and isinstance(self.options, dict): + return self.options.get('properties', None) + + def items(self): + """ + act like an iterator + """ + if isinstance(self.option_dict, dict): + return self.option_dict.items() + return self.options + + def validate(self, config, format=None): + """ + validate config + """ + schema = self.options + if format is None: + format = self._format + + if format: + adaptor = self.get_adaptor(self._format) + adaptor.validate(config, schema) + else: + # we really should make format mandatory... + logger.error('no format passed to validate') + + # first round of validation is ok. + # now we proceed to cast types if any specified. + self.to_python(config) + + def to_python(self, config): + """ + cast types following first type and then format indications. + """ + unseen_options = [i for i in config if i not in self.option_dict] + if unseen_options: + raise UnknownOptionException( + "Unknown options: %s" % ', '.join(unseen_options)) + + for key, value in config.items(): + _type = self.option_dict[key].get('type') + if _type is None and 'default' in self.option_dict[key]: + _type = type(self.option_dict[key]['default']) + if _type is not None: + tocast = True + if not callable(_type) and isinstance(value, _type): + tocast = False + if tocast: + try: + config[key] = _type(value) + except BaseException, e: + raise TypeCastException( + "Could not coerce %s, %s, " + "to type %s: %s" % (key, value, _type.__name__, e)) + _format = self.option_dict[key].get('format', None) + _ftype = self.types.get(_format, None) + if _ftype: + try: + config[key] = _ftype.to_python(value) + except BaseException, e: + raise TypeCastException( + "Could not coerce %s, %s, " + "to format %s: %s" % (key, value, + _ftype.__class__.__name__, + e)) + + return config + + def prep_value(self, config): + """ + the inverse of to_python method, + called just before serialization + """ + for key, value in config.items(): + _format = self.option_dict[key].get('format', None) + _ftype = self.types.get(_format, None) + if _ftype and hasattr(_ftype, 'get_prep_value'): + try: + config[key] = _ftype.get_prep_value(value) + except BaseException, e: + raise TypeCastException( + "Could not serialize %s, %s, " + "by format %s: %s" % (key, value, + _ftype.__class__.__name__, + e)) + else: + config[key] = value + return config + + # methods for adding configuration + + def get_default_values(self): + """ + return a config options from configuration defaults + """ + defaults = {} + for key, value in self.items(): + if 'default' in value: + defaults[key] = value['default'] + return copy.deepcopy(defaults) + + def get_adaptor(self, format): + """ + get specified format adaptor or + guess for a given filename + """ + adaptor = self.adaptors.get(format, None) + if adaptor: + return adaptor + + # not registered in adaptors dict, let's try all + for adaptor in self.adaptors.values(): + if format in adaptor.extensions: + return adaptor + + def filename2format(self, filename): + extension = os.path.splitext(filename)[-1] + return extension.lstrip('.') or None + + def serialize(self, filename, format=None, full=False): + if not format: + format = self._format + if not format: + format = self.filename2format(filename) + if not format: + raise Exception('Please specify a format') + # TODO: more specific exception type + + adaptor = self.get_adaptor(format) + if not adaptor: + raise Exception("Adaptor not found for format: %s" % format) + + config = copy.deepcopy(self.config) + serializable = self.prep_value(config) + adaptor.write(serializable, filename) + + if self.mtime: + self.touch_mtime(filename) + + def touch_mtime(self, filename): + mtime = self.mtime + os.utime(filename, (mtime, mtime)) + + def deserialize(self, string=None, fromfile=None, format=None): + """ + load configuration from a file or string + """ + + def _try_deserialize(): + if fromfile: + with open(fromfile, 'r') as f: + content = adaptor.read(f) + elif string: + content = adaptor.read(string) + return content + + # XXX cleanup this! + + if fromfile: + assert os.path.exists(fromfile) + if not format: + format = self.filename2format(fromfile) + + if not format: + format = self._format + if format: + adaptor = self.get_adaptor(format) + else: + adaptor = None + + if adaptor: + content = _try_deserialize() + return content + + # no adaptor, let's try rest of adaptors + + adaptors = self.adaptors[:] + + if format: + adaptors.sort( + key=lambda x: int( + format in x.extensions), + reverse=True) + + for adaptor in adaptors: + content = _try_deserialize() + return content + + def set_dirty(self): + self.dirty = True + + def is_dirty(self): + return self.dirty + + def load(self, *args, **kwargs): + """ + load from string or file + if no string of fromfile option is given, + it will attempt to load from defaults + defined in the schema. + """ + string = args[0] if args else None + fromfile = kwargs.get("fromfile", None) + mtime = kwargs.pop("mtime", None) + self.mtime = mtime + content = None + + # start with defaults, so we can + # have partial values applied. + content = self.get_default_values() + if string and isinstance(string, str): + content = self.deserialize(string) + + if not string and fromfile is not None: + #import ipdb;ipdb.set_trace() + content = self.deserialize(fromfile=fromfile) + + if not content: + logger.error('no content could be loaded') + # XXX raise! + return + + # lazy evaluation until first level of nesting + # to allow lambdas with context-dependant info + # like os.path.expanduser + for k, v in content.iteritems(): + if callable(v): + content[k] = v() + + self.validate(content) + self.config = content + return True + + +def testmain(): # pragma: no cover + + from tests import test_validation as t + import pprint + + config = PluggableConfig(_format="json") + properties = copy.deepcopy(t.sample_spec) + + config.options = properties + config.load(fromfile='data.json') + + print 'config' + pprint.pprint(config.config) + + config.serialize('/tmp/testserial.json') + +if __name__ == "__main__": + testmain() diff --git a/src/leap/base/providers.py b/src/leap/base/providers.py new file mode 100644 index 00000000..d41f3695 --- /dev/null +++ b/src/leap/base/providers.py @@ -0,0 +1,29 @@ +"""all dealing with leap-providers: definition files, updating""" +from leap.base import config as baseconfig +from leap.base import specs + + +class LeapProviderDefinition(baseconfig.JSONLeapConfig): + spec = specs.leap_provider_spec + + def _get_slug(self): + domain = getattr(self, 'domain', None) + if domain: + path = baseconfig.get_provider_path(domain) + else: + path = baseconfig.get_default_provider_path() + + return baseconfig.get_config_file( + 'provider.json', folder=path) + + def _set_slug(self, *args, **kwargs): + raise AttributeError("you cannot set slug") + + slug = property(_get_slug, _set_slug) + + +class LeapProviderSet(object): + # we gather them from the filesystem + # TODO: (MVS+) + def __init__(self): + self.count = 0 diff --git a/src/leap/base/specs.py b/src/leap/base/specs.py new file mode 100644 index 00000000..fbe8a0e9 --- /dev/null +++ b/src/leap/base/specs.py @@ -0,0 +1,62 @@ +leap_provider_spec = { + 'description': 'provider definition', + 'type': 'object', + 'properties': { + 'version': { + 'type': unicode, + 'default': '0.1.0' + #'required': True + }, + "default_language": { + 'type': unicode, + 'default': 'en' + }, + 'domain': { + 'type': unicode, # XXX define uri type + 'default': 'testprovider.example.org' + #'required': True, + }, + 'name': { + #'type': LEAPTranslatable, + 'type': dict, + 'format': 'translatable', + 'default': {u'en': u'Test Provider'} + #'required': True + }, + 'description': { + #'type': LEAPTranslatable, + 'type': dict, + 'format': 'translatable', + 'default': {u'en': u'Test provider'} + }, + 'enrollment_policy': { + 'type': unicode, # oneof ?? + 'default': 'open' + }, + 'services': { + 'type': list, # oneof ?? + 'default': ['eip'] + }, + 'api_version': { + 'type': unicode, + 'default': '0.1.0' # version regexp + }, + 'api_uri': { + 'type': unicode # uri + }, + 'public_key': { + 'type': unicode # fingerprint + }, + 'ca_cert_fingerprint': { + 'type': unicode, + }, + 'ca_cert_uri': { + 'type': unicode, + 'format': 'https-uri' + }, + 'languages': { + 'type': list, + 'default': ['en'] + } + } +} diff --git a/src/leap/base/tests/__init__.py b/src/leap/base/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/base/tests/__init__.py diff --git a/src/leap/base/tests/test_auth.py b/src/leap/base/tests/test_auth.py new file mode 100644 index 00000000..b3009a9b --- /dev/null +++ b/src/leap/base/tests/test_auth.py @@ -0,0 +1,58 @@ +from BaseHTTPServer import BaseHTTPRequestHandler +import urlparse +try: + import unittest2 as unittest +except ImportError: + import unittest + +import requests +#from mock import Mock + +from leap.base import auth +#from leap.base import exceptions +from leap.eip.tests.test_checks import NoLogRequestHandler +from leap.testing.basetest import BaseLeapTest +from leap.testing.https_server import BaseHTTPSServerTestCase + + +class LeapSRPRegisterTests(BaseHTTPSServerTestCase, BaseLeapTest): + __name__ = "leap_srp_register_test" + provider = "testprovider.example.org" + + class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler): + responses = { + '/': ['OK', '']} + + def do_GET(self): + path = urlparse.urlparse(self.path) + message = '\n'.join(self.responses.get( + path.path, None)) + self.send_response(200) + self.end_headers() + self.wfile.write(message) + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_srp_auth_should_implement_check_methods(self): + SERVER = "https://localhost:8443" + srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False) + + self.assertTrue(hasattr(srp_auth, "init_session"), + "missing meth") + self.assertTrue(hasattr(srp_auth, "get_registration_uri"), + "missing meth") + self.assertTrue(hasattr(srp_auth, "register_user"), + "missing meth") + + def test_srp_auth_basic_functionality(self): + SERVER = "https://localhost:8443" + srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False) + + self.assertIsInstance(srp_auth.session, requests.sessions.Session) + self.assertEqual( + srp_auth.get_registration_uri(), + "https://localhost:8443/1/users") diff --git a/src/leap/base/tests/test_checks.py b/src/leap/base/tests/test_checks.py new file mode 100644 index 00000000..8126755b --- /dev/null +++ b/src/leap/base/tests/test_checks.py @@ -0,0 +1,177 @@ +try: + import unittest2 as unittest +except ImportError: + import unittest +import os +import sh + +from mock import (patch, Mock) +from StringIO import StringIO + +from leap.base import checks +from leap.base import exceptions +from leap.testing.basetest import BaseLeapTest + +_uid = os.getuid() + + +class LeapNetworkCheckTest(BaseLeapTest): + __name__ = "leap_network_check_tests" + + def setUp(self): + os.environ['PATH'] += ':/bin' + pass + + def tearDown(self): + pass + + def test_checker_should_implement_check_methods(self): + checker = checks.LeapNetworkChecker() + + self.assertTrue(hasattr(checker, "check_internet_connection"), + "missing meth") + self.assertTrue(hasattr(checker, "check_tunnel_default_interface"), + "missing meth") + self.assertTrue(hasattr(checker, "is_internet_up"), + "missing meth") + self.assertTrue(hasattr(checker, "ping_gateway"), + "missing meth") + self.assertTrue(hasattr(checker, "parse_log_and_react"), + "missing meth") + + def test_checker_should_actually_call_all_tests(self): + checker = checks.LeapNetworkChecker() + mc = Mock() + checker.run_all(checker=mc) + self.assertTrue(mc.check_internet_connection.called, "not called") + self.assertTrue(mc.check_tunnel_default_interface.called, "not called") + self.assertTrue(mc.is_internet_up.called, "not called") + self.assertTrue(mc.parse_log_and_react.called, "not called") + + # ping gateway only called if we pass provider_gw + checker = checks.LeapNetworkChecker(provider_gw="0.0.0.0") + mc = Mock() + checker.run_all(checker=mc) + self.assertTrue(mc.check_internet_connection.called, "not called") + self.assertTrue(mc.check_tunnel_default_interface.called, "not called") + self.assertTrue(mc.ping_gateway.called, "not called") + self.assertTrue(mc.is_internet_up.called, "not called") + self.assertTrue(mc.parse_log_and_react.called, "not called") + + def test_get_default_interface_no_interface(self): + checker = checks.LeapNetworkChecker() + with patch('leap.base.checks.open', create=True) as mock_open: + with self.assertRaises(exceptions.NoDefaultInterfaceFoundError): + mock_open.return_value = StringIO( + "Iface\tDestination Gateway\t" + "Flags\tRefCntd\tUse\tMetric\t" + "Mask\tMTU\tWindow\tIRTT") + checker.get_default_interface_gateway() + + def test_check_tunnel_default_interface(self): + checker = checks.LeapNetworkChecker() + with patch('leap.base.checks.open', create=True) as mock_open: + with self.assertRaises(exceptions.TunnelNotDefaultRouteError): + mock_open.return_value = StringIO( + "Iface\tDestination Gateway\t" + "Flags\tRefCntd\tUse\tMetric\t" + "Mask\tMTU\tWindow\tIRTT\n" + "wlan0\t00000000\t0102A8C0\t" + "0003\t0\t0\t0\t00000000\t0\t0\t0") + checker.check_tunnel_default_interface() + + with patch('leap.base.checks.open', create=True) as mock_open: + mock_open.return_value = StringIO( + "Iface\tDestination Gateway\t" + "Flags\tRefCntd\tUse\tMetric\t" + "Mask\tMTU\tWindow\tIRTT\n" + "tun0\t00000000\t01002A0A\t0003\t0\t0\t0\t00000080\t0\t0\t0") + checker.check_tunnel_default_interface() + + def test_ping_gateway_fail(self): + checker = checks.LeapNetworkChecker() + with patch.object(sh, "ping") as mocked_ping: + with self.assertRaises(exceptions.NoConnectionToGateway): + mocked_ping.return_value = Mock + mocked_ping.return_value.stdout = "11% packet loss" + checker.ping_gateway("4.2.2.2") + + def test_ping_gateway(self): + checker = checks.LeapNetworkChecker() + with patch.object(sh, "ping") as mocked_ping: + mocked_ping.return_value = Mock + mocked_ping.return_value.stdout = """ +PING 4.2.2.2 (4.2.2.2) 56(84) bytes of data. +64 bytes from 4.2.2.2: icmp_req=1 ttl=54 time=33.8 ms +64 bytes from 4.2.2.2: icmp_req=2 ttl=54 time=30.6 ms +64 bytes from 4.2.2.2: icmp_req=3 ttl=54 time=31.4 ms +64 bytes from 4.2.2.2: icmp_req=4 ttl=54 time=36.1 ms +64 bytes from 4.2.2.2: icmp_req=5 ttl=54 time=30.8 ms +64 bytes from 4.2.2.2: icmp_req=6 ttl=54 time=30.4 ms +64 bytes from 4.2.2.2: icmp_req=7 ttl=54 time=30.7 ms +64 bytes from 4.2.2.2: icmp_req=8 ttl=54 time=32.7 ms +64 bytes from 4.2.2.2: icmp_req=9 ttl=54 time=31.4 ms +64 bytes from 4.2.2.2: icmp_req=10 ttl=54 time=33.3 ms + +--- 4.2.2.2 ping statistics --- +10 packets transmitted, 10 received, 0% packet loss, time 9016ms +rtt min/avg/max/mdev = 30.497/32.172/36.161/1.755 ms""" + checker.ping_gateway("4.2.2.2") + + def test_check_internet_connection_failures(self): + checker = checks.LeapNetworkChecker() + TimeoutError = get_ping_timeout_error() + with patch.object(sh, "ping") as mocked_ping: + mocked_ping.side_effect = TimeoutError + with self.assertRaises(exceptions.NoInternetConnection): + with patch.object(checker, "ping_gateway") as mock_gateway: + mock_gateway.side_effect = exceptions.NoConnectionToGateway + checker.check_internet_connection() + + with patch.object(sh, "ping") as mocked_ping: + mocked_ping.side_effect = TimeoutError + with self.assertRaises(exceptions.NoInternetConnection): + with patch.object(checker, "ping_gateway") as mock_gateway: + mock_gateway.return_value = True + checker.check_internet_connection() + + def test_parse_log_and_react(self): + checker = checks.LeapNetworkChecker() + to_call = Mock() + log = [("leap.openvpn - INFO - Mon Nov 19 13:36:24 2012 " + "read UDPv4 [ECONNREFUSED]: Connection refused (code=111)")] + err_matrix = [(checks.EVENT_CONNECT_REFUSED, (to_call, ))] + checker.parse_log_and_react(log, err_matrix) + self.assertTrue(to_call.called) + + log = [("2012-11-19 13:36:26,177 - leap.openvpn - INFO - " + "Mon Nov 19 13:36:24 2012 ERROR: Linux route delete command " + "failed: external program exited"), + ("2012-11-19 13:36:26,178 - leap.openvpn - INFO - " + "Mon Nov 19 13:36:24 2012 ERROR: Linux route delete command " + "failed: external program exited"), + ("2012-11-19 13:36:26,180 - leap.openvpn - INFO - " + "Mon Nov 19 13:36:24 2012 ERROR: Linux route delete command " + "failed: external program exited"), + ("2012-11-19 13:36:26,181 - leap.openvpn - INFO - " + "Mon Nov 19 13:36:24 2012 /sbin/ifconfig tun0 0.0.0.0"), + ("2012-11-19 13:36:26,182 - leap.openvpn - INFO - " + "Mon Nov 19 13:36:24 2012 Linux ip addr del failed: external " + "program exited with error stat"), + ("2012-11-19 13:36:26,183 - leap.openvpn - INFO - " + "Mon Nov 19 13:36:26 2012 SIGTERM[hard,] received, process" + "exiting"), ] + to_call.reset_mock() + checker.parse_log_and_react(log, err_matrix) + self.assertFalse(to_call.called) + + to_call.reset_mock() + checker.parse_log_and_react([], err_matrix) + self.assertFalse(to_call.called) + + +def get_ping_timeout_error(): + try: + sh.ping("-c", "1", "-w", "1", "8.8.7.7") + except Exception as e: + return e diff --git a/src/leap/base/tests/test_config.py b/src/leap/base/tests/test_config.py new file mode 100644 index 00000000..d03149b2 --- /dev/null +++ b/src/leap/base/tests/test_config.py @@ -0,0 +1,247 @@ +import json +import os +import platform +import socket +#import tempfile + +import mock +import requests + +from leap.base import config +from leap.base import constants +from leap.base import exceptions +from leap.eip import constants as eipconstants +from leap.util.fileutil import mkdir_p +from leap.testing.basetest import BaseLeapTest + + +try: + import unittest2 as unittest +except ImportError: + import unittest + +_system = platform.system() + + +class JSONLeapConfigTest(BaseLeapTest): + def setUp(self): + pass + + def tearDown(self): + pass + + def test_metaclass(self): + with self.assertRaises(exceptions.ImproperlyConfigured) as exc: + class DummyTestConfig(config.JSONLeapConfig): + __metaclass__ = config.MetaConfigWithSpec + exc.startswith("missing spec dict") + + class DummyTestConfig(config.JSONLeapConfig): + __metaclass__ = config.MetaConfigWithSpec + spec = {'properties': {}} + with self.assertRaises(exceptions.ImproperlyConfigured) as exc: + DummyTestConfig() + exc.startswith("missing slug") + + class DummyTestConfig(config.JSONLeapConfig): + __metaclass__ = config.MetaConfigWithSpec + spec = {'properties': {}} + slug = "foo" + DummyTestConfig() + +######################################3 +# +# provider fetch tests block +# + + +class ProviderTest(BaseLeapTest): + # override per test fixtures + + def setUp(self): + pass + + def tearDown(self): + pass + + +# XXX depreacated. similar test in eip.checks + +#class BareHomeTestCase(ProviderTest): +# + #__name__ = "provider_config_tests_bare_home" +# + #def test_should_raise_if_missing_eip_json(self): + #with self.assertRaises(exceptions.MissingConfigFileError): + #config.get_config_json(os.path.join(self.home, 'eip.json')) + + +class ProviderDefinitionTestCase(ProviderTest): + # XXX MOVE TO eip.test_checks + # -- kali 2012-08-24 00:38 + + __name__ = "provider_config_tests" + + def setUp(self): + # dump a sample eip file + # XXX Move to Use EIP Spec Instead!!! + # XXX tests to be moved to eip.checks and eip.providers + # XXX can use eipconfig.dump_default_eipconfig + + path = os.path.join(self.home, '.config', 'leap') + mkdir_p(path) + with open(os.path.join(path, 'eip.json'), 'w') as fp: + json.dump(eipconstants.EIP_SAMPLE_JSON, fp) + + +# these tests below should move to +# eip.checks +# config.Configuration has been deprecated + +# TODO: +# - We're instantiating a ProviderTest because we're doing the home wipeoff +# on setUpClass instead of the setUp (for speedup of the general cases). + +# We really should be testing all of them in the same testCase, and +# doing an extra wipe of the tempdir... but be careful!!!! do not mess with +# os.environ home more than needed... that could potentially bite! + +# XXX actually, another thing to fix here is separating tests: +# - test that requests has been called. +# - check deeper for error types/msgs + +# we SHOULD inject requests dep in the constructor +# (so we can pass mock easily). + + +#class ProviderFetchConError(ProviderTest): + #def test_connection_error(self): + #with mock.patch.object(requests, "get") as mock_method: + #mock_method.side_effect = requests.ConnectionError + #cf = config.Configuration() + #self.assertIsInstance(cf.error, str) +# +# +#class ProviderFetchHttpError(ProviderTest): + #def test_file_not_found(self): + #with mock.patch.object(requests, "get") as mock_method: + #mock_method.side_effect = requests.HTTPError + #cf = config.Configuration() + #self.assertIsInstance(cf.error, str) +# +# +#class ProviderFetchInvalidUrl(ProviderTest): + #def test_invalid_url(self): + #cf = config.Configuration("ht") + #self.assertTrue(cf.error) + + +# end provider fetch tests +########################################### + + +class ConfigHelperFunctions(BaseLeapTest): + + __name__ = "config_helper_tests" + + def setUp(self): + pass + + def tearDown(self): + pass + + # tests + + @unittest.skipUnless(_system == "Linux", "linux only") + def test_lin_get_config_file(self): + """ + config file path where expected? (linux) + """ + self.assertEqual( + config.get_config_file( + 'test', folder="foo/bar"), + os.path.expanduser( + '~/.config/leap/foo/bar/test') + ) + + @unittest.skipUnless(_system == "Darwin", "mac only") + def test_mac_get_config_file(self): + """ + config file path where expected? (mac) + """ + self._missing_test_for_plat(do_raise=True) + + @unittest.skipUnless(_system == "Windows", "win only") + def test_win_get_config_file(self): + """ + config file path where expected? + """ + self._missing_test_for_plat(do_raise=True) + + # + # XXX hey, I'm raising exceptions here + # on purpose. just wanted to make sure + # that the skip stuff is doing it right. + # If you're working on win/macos tests, + # feel free to remove tests that you see + # are too redundant. + + @unittest.skipUnless(_system == "Linux", "linux only") + def test_lin_get_config_dir(self): + """ + nice config dir? (linux) + """ + self.assertEqual( + config.get_config_dir(), + os.path.expanduser('~/.config/leap')) + + @unittest.skipUnless(_system == "Darwin", "mac only") + def test_mac_get_config_dir(self): + """ + nice config dir? (mac) + """ + self._missing_test_for_plat(do_raise=True) + + @unittest.skipUnless(_system == "Windows", "win only") + def test_win_get_config_dir(self): + """ + nice config dir? (win) + """ + self._missing_test_for_plat(do_raise=True) + + # provider paths + + @unittest.skipUnless(_system == "Linux", "linux only") + def test_get_default_provider_path(self): + """ + is default provider path ok? + """ + self.assertEqual( + config.get_default_provider_path(), + os.path.expanduser( + '~/.config/leap/providers/%s/' % + constants.DEFAULT_PROVIDER) + ) + + # validate ip + + def test_validate_ip(self): + """ + check our ip validation + """ + config.validate_ip('3.3.3.3') + with self.assertRaises(socket.error): + config.validate_ip('255.255.255.256') + with self.assertRaises(socket.error): + config.validate_ip('foobar') + + @unittest.skip + def test_validate_domain(self): + """ + code to be written yet + """ + raise NotImplementedError + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/base/tests/test_providers.py b/src/leap/base/tests/test_providers.py new file mode 100644 index 00000000..92bc1f2f --- /dev/null +++ b/src/leap/base/tests/test_providers.py @@ -0,0 +1,148 @@ +import copy +import json +try: + import unittest2 as unittest +except ImportError: + import unittest +import os + +from leap.base.pluggableconfig import ValidationError +from leap.testing.basetest import BaseLeapTest +from leap.base import providers + + +EXPECTED_DEFAULT_CONFIG = { + u"api_version": u"0.1.0", + #u"description": "LEAPTranslatable<{u'en': u'Test provider'}>", + u"description": {u'en': u'Test provider'}, + u"default_language": u"en", + #u"display_name": {u'en': u"Test Provider"}, + u"domain": u"testprovider.example.org", + #u'name': "LEAPTranslatable<{u'en': u'Test Provider'}>", + u'name': {u'en': u'Test Provider'}, + u"enrollment_policy": u"open", + #u"serial": 1, + u"services": [ + u"eip" + ], + u"languages": [u"en"], + u"version": u"0.1.0" +} + + +class TestLeapProviderDefinition(BaseLeapTest): + def setUp(self): + self.domain = "testprovider.example.org" + self.definition = providers.LeapProviderDefinition( + domain=self.domain) + self.definition.save(force=True) + self.definition.load() # why have to load after save?? + self.config = self.definition.config + + def tearDown(self): + if hasattr(self, 'testfile') and os.path.isfile(self.testfile): + os.remove(self.testfile) + + # tests + + # XXX most of these tests can be made more abstract + # and moved to test_baseconfig *triangulate!* + + def test_provider_slug_property(self): + slug = self.definition.slug + self.assertEquals( + slug, + os.path.join( + self.home, + '.config', 'leap', 'providers', + '%s' % self.domain, + 'provider.json')) + with self.assertRaises(AttributeError): + self.definition.slug = 23 + + def test_provider_dump(self): + # check a good provider definition is dumped to disk + self.testfile = self.get_tempfile('test.json') + self.definition.save(to=self.testfile, force=True) + deserialized = json.load(open(self.testfile, 'rb')) + self.maxDiff = None + #import ipdb;ipdb.set_trace() + self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG) + + def test_provider_dump_to_slug(self): + # same as above, but we test the ability to save to a + # file generated from the slug. + # XXX THIS TEST SHOULD MOVE TO test_baseconfig + self.definition.save() + filename = self.definition.filename + self.assertTrue(os.path.isfile(filename)) + deserialized = json.load(open(filename, 'rb')) + self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG) + + def test_provider_load(self): + # check loading provider from disk file + self.testfile = self.get_tempfile('test_load.json') + with open(self.testfile, 'w') as wf: + wf.write(json.dumps(EXPECTED_DEFAULT_CONFIG)) + self.definition.load(fromfile=self.testfile) + #self.assertDictEqual(self.config, + #EXPECTED_DEFAULT_CONFIG) + self.assertItemsEqual(self.config, EXPECTED_DEFAULT_CONFIG) + + def test_provider_validation(self): + self.definition.validate(self.config) + _config = copy.deepcopy(self.config) + # bad type, raise validation error + _config['domain'] = 111 + with self.assertRaises(ValidationError): + self.definition.validate(_config) + + @unittest.skip + def test_load_malformed_json_definition(self): + raise NotImplementedError + + @unittest.skip + def test_type_validation(self): + # check various type validation + # type cast + raise NotImplementedError + + +class TestLeapProviderSet(BaseLeapTest): + + def setUp(self): + self.providers = providers.LeapProviderSet() + + def tearDown(self): + pass + ### + + def test_get_zero_count(self): + self.assertEqual(self.providers.count, 0) + + @unittest.skip + def test_count_defined_providers(self): + # check the method used for making + # the list of providers + raise NotImplementedError + + @unittest.skip + def test_get_default_provider(self): + raise NotImplementedError + + @unittest.skip + def test_should_be_at_least_one_provider_after_init(self): + # when we init an empty environment, + # there should be at least one provider, + # that will be a dump of the default provider definition + # somehow a high level test + raise NotImplementedError + + @unittest.skip + def test_get_eip_remote_from_default_provider(self): + # from: default provider + # expect: remote eip domain + raise NotImplementedError + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/base/tests/test_validation.py b/src/leap/base/tests/test_validation.py new file mode 100644 index 00000000..b45fbe3a --- /dev/null +++ b/src/leap/base/tests/test_validation.py @@ -0,0 +1,93 @@ +import copy +import datetime +from functools import partial +#import json +try: + import unittest2 as unittest +except ImportError: + import unittest +import os + +from leap.base.config import JSONLeapConfig +from leap.base import pluggableconfig +from leap.testing.basetest import BaseLeapTest + +SAMPLE_CONFIG_DICT = { + 'prop_one': 1, + 'prop_uri': "http://example.org", + 'prop_date': '2012-12-12', +} + +EXPECTED_CONFIG = { + 'prop_one': 1, + 'prop_uri': "http://example.org", + 'prop_date': datetime.datetime(2012, 12, 12) +} + +sample_spec = { + 'description': 'sample schema definition', + 'type': 'object', + 'properties': { + 'prop_one': { + 'type': int, + 'default': 1, + 'required': True + }, + 'prop_uri': { + 'type': str, + 'default': 'http://example.org', + 'required': True, + 'format': 'uri' + }, + 'prop_date': { + 'type': str, + 'default': '2012-12-12', + 'format': 'date' + } + } +} + + +class SampleConfig(JSONLeapConfig): + spec = sample_spec + + @property + def slug(self): + return os.path.expanduser('~/sampleconfig.json') + + +class TestJSONLeapConfigValidation(BaseLeapTest): + def setUp(self): + self.sampleconfig = SampleConfig() + self.sampleconfig.save() + self.sampleconfig.load() + self.config = self.sampleconfig.config + + def tearDown(self): + if hasattr(self, 'testfile') and os.path.isfile(self.testfile): + os.remove(self.testfile) + + # tests + + def test_good_validation(self): + self.sampleconfig.validate(SAMPLE_CONFIG_DICT) + + def test_broken_int(self): + _config = copy.deepcopy(SAMPLE_CONFIG_DICT) + _config['prop_one'] = '1' + self.assertRaises( + pluggableconfig.ValidationError, + partial(self.sampleconfig.validate, _config)) + + def test_format_property(self): + # JsonSchema Validator does not check the format property. + # We should have to extend the Configuration class + blah = copy.deepcopy(SAMPLE_CONFIG_DICT) + blah['prop_uri'] = 'xxx' + self.assertRaises( + pluggableconfig.TypeCastException, + partial(self.sampleconfig.validate, blah)) + + +if __name__ == "__main__": + unittest.main() |