diff options
Diffstat (limited to 'src/leap/base')
-rw-r--r-- | src/leap/base/__init__.py | 0 | ||||
-rw-r--r-- | src/leap/base/auth.py | 376 | ||||
-rw-r--r-- | src/leap/base/authentication.py | 11 | ||||
-rw-r--r-- | src/leap/base/checks.py | 127 | ||||
-rw-r--r-- | src/leap/base/config.py | 279 | ||||
-rw-r--r-- | src/leap/base/connection.py | 115 | ||||
-rw-r--r-- | src/leap/base/constants.py | 32 | ||||
-rw-r--r-- | src/leap/base/exceptions.py | 77 | ||||
-rw-r--r-- | src/leap/base/network.py | 84 | ||||
-rw-r--r-- | src/leap/base/pluggableconfig.py | 421 | ||||
-rw-r--r-- | src/leap/base/providers.py | 29 | ||||
-rw-r--r-- | src/leap/base/specs.py | 59 | ||||
-rw-r--r-- | src/leap/base/tests/__init__.py | 0 | ||||
-rw-r--r-- | src/leap/base/tests/test_checks.py | 124 | ||||
-rw-r--r-- | src/leap/base/tests/test_config.py | 247 | ||||
-rw-r--r-- | src/leap/base/tests/test_providers.py | 143 | ||||
-rw-r--r-- | src/leap/base/tests/test_validation.py | 92 |
17 files changed, 2216 insertions, 0 deletions
diff --git a/src/leap/base/__init__.py b/src/leap/base/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/base/__init__.py diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py new file mode 100644 index 00000000..50533278 --- /dev/null +++ b/src/leap/base/auth.py @@ -0,0 +1,376 @@ +import binascii +import json +import logging +#import urlparse + +import requests +import srp + +from PyQt4 import QtCore + +from leap.base import constants as baseconstants +from leap.crypto import leapkeyring +from leap.util.web import get_https_domain_and_port + +logger = logging.getLogger(__name__) + +SIGNUP_TIMEOUT = getattr(baseconstants, 'SIGNUP_TIMEOUT', 5) + +""" +Registration and authentication classes for the +SRP auth mechanism used in the leap platform. + +We're using the srp library which uses a c-based implementation +of the protocol if the c extension is available, and a python-based +one if not. +""" + + +class ImproperlyConfigured(Exception): + """ + """ + + +class SRPAuthenticationError(Exception): + """ + exception raised + for authentication errors + """ + + +def null_check(value, value_name): + try: + assert value is not None + except AssertionError: + raise ImproperlyConfigured( + "%s parameter cannot be None" % value_name) + + +safe_unhexlify = lambda x: binascii.unhexlify(x) \ + if (len(x) % 2 == 0) else binascii.unhexlify('0' + x) + + +class LeapSRPRegister(object): + + def __init__(self, + schema="https", + provider=None, + port=None, + verify=True, + register_path="1/users.json", + method="POST", + fetcher=requests, + srp=srp, + hashfun=srp.SHA256, + ng_constant=srp.NG_1024): + + null_check(provider, provider) + + self.schema = schema + + # XXX FIXME + self.provider = provider + self.port = port + # XXX splitting server,port + # deprecate port call. + domain, port = get_https_domain_and_port(provider) + self.provider = domain + self.port = port + + self.verify = verify + self.register_path = register_path + self.method = method + self.fetcher = fetcher + self.srp = srp + self.HASHFUN = hashfun + self.NG = ng_constant + + self.init_session() + + def init_session(self): + self.session = self.fetcher.session() + + def get_registration_uri(self): + # XXX assert is https! + # use urlparse + if self.port: + uri = "%s://%s:%s/%s" % ( + self.schema, + self.provider, + self.port, + self.register_path) + else: + uri = "%s://%s/%s" % ( + self.schema, + self.provider, + self.register_path) + + return uri + + def register_user(self, username, password, keep=False): + """ + @rtype: tuple + @rparam: (ok, request) + """ + salt, vkey = self.srp.create_salted_verification_key( + username, + password, + self.HASHFUN, + self.NG) + + user_data = { + 'user[login]': username, + 'user[password_verifier]': binascii.hexlify(vkey), + 'user[password_salt]': binascii.hexlify(salt)} + + uri = self.get_registration_uri() + logger.debug('post to uri: %s' % uri) + + # XXX get self.method + req = self.session.post( + uri, data=user_data, + timeout=SIGNUP_TIMEOUT, + verify=self.verify) + logger.debug(req) + logger.debug('user_data: %s', user_data) + #logger.debug('response: %s', req.text) + # we catch it in the form + #req.raise_for_status() + return (req.ok, req) + + +class SRPAuth(requests.auth.AuthBase): + + def __init__(self, username, password, server=None, verify=None): + # sanity check + null_check(server, 'server') + self.username = username + self.password = password + self.server = server + self.verify = verify + + self.init_data = None + self.session = requests.session() + + self.init_srp() + + def get_json_data(self, response): + return json.loads(response.content) + + def init_srp(self): + usr = srp.User( + self.username, + self.password, + srp.SHA256, + srp.NG_1024) + uname, A = usr.start_authentication() + + self.srp_usr = usr + self.A = A + + def get_auth_data(self): + return { + 'login': self.username, + 'A': binascii.hexlify(self.A) + } + + def get_init_data(self): + try: + init_session = self.session.post( + self.server + '/1/sessions.json/', + data=self.get_auth_data(), + verify=self.verify) + except requests.exceptions.ConnectionError: + raise SRPAuthenticationError( + "No connection made (salt).") + if init_session.status_code not in (200, ): + raise SRPAuthenticationError( + "No valid response (salt).") + + # XXX should get auth_result.json instead + self.init_data = self.get_json_data(init_session) + return self.init_data + + def get_server_proof_data(self): + try: + auth_result = self.session.put( + #self.server + '/1/sessions.json/' + self.username, + self.server + '/1/sessions/' + self.username, + data={'client_auth': binascii.hexlify(self.M)}, + verify=self.verify) + except requests.exceptions.ConnectionError: + raise SRPAuthenticationError( + "No connection made (HAMK).") + + if auth_result.status_code not in (200, ): + raise SRPAuthenticationError( + "No valid response (HAMK).") + + # XXX should get auth_result.json instead + try: + self.auth_data = self.get_json_data(auth_result) + except ValueError: + raise SRPAuthenticationError( + "No valid data sent (HAMK)") + + return self.auth_data + + def authenticate(self): + logger.debug('start authentication...') + + init_data = self.get_init_data() + salt = init_data.get('salt', None) + B = init_data.get('B', None) + + # XXX refactor this function + # move checks and un-hex + # to routines + + if not salt or not B: + raise SRPAuthenticationError( + "Server did not send initial data.") + + try: + unhex_salt = safe_unhexlify(salt) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (salt)") + try: + unhex_B = safe_unhexlify(B) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (B)") + + self.M = self.srp_usr.process_challenge( + unhex_salt, + unhex_B + ) + + proof_data = self.get_server_proof_data() + + HAMK = proof_data.get("M2", None) + if not HAMK: + errors = proof_data.get('errors', None) + if errors: + logger.error(errors) + raise SRPAuthenticationError("Server did not send HAMK.") + + try: + unhex_HAMK = safe_unhexlify(HAMK) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (HAMK)") + + self.srp_usr.verify_session( + unhex_HAMK) + + try: + assert self.srp_usr.authenticated() + logger.debug('user is authenticated!') + except (AssertionError): + raise SRPAuthenticationError( + "Auth verification failed.") + + def __call__(self, req): + self.authenticate() + req.session = self.session + return req + + +def srpauth_protected(user=None, passwd=None, server=None, verify=True): + """ + decorator factory that accepts + user and password keyword arguments + and add those to the decorated request + """ + def srpauth(fn): + def wrapper(*args, **kwargs): + if user and passwd: + auth = SRPAuth(user, passwd, server, verify) + kwargs['auth'] = auth + kwargs['verify'] = verify + return fn(*args, **kwargs) + return wrapper + return srpauth + + +def get_leap_credentials(): + settings = QtCore.QSettings() + full_username = settings.value('eip_username') + username, domain = full_username.split('@') + seed = settings.value('%s_seed' % domain, None) + password = leapkeyring.leap_get_password(full_username, seed=seed) + return (username, password) + + +# XXX TODO +# Pass verify as single argument, +# in srpauth_protected style + +def magick_srpauth(fn): + """ + decorator that gets user and password + from the config file and adds those to + the decorated request + """ + logger.debug('magick srp auth decorator called') + + def wrapper(*args, **kwargs): + #uri = args[0] + # XXX Ugh! + # Problem with this approach. + # This won't work when we're using + # api.foo.bar + # Unless we keep a table with the + # equivalencies... + user, passwd = get_leap_credentials() + + # XXX pass verify and server too + # (pop) + auth = SRPAuth(user, passwd) + kwargs['auth'] = auth + return fn(*args, **kwargs) + return wrapper + + +if __name__ == "__main__": + """ + To test against test_provider (twisted version) + Register an user: (will be valid during the session) + >>> python auth.py add test password + + Test login with that user: + >>> python auth.py login test password + """ + + import sys + + if len(sys.argv) not in (4, 5): + print 'Usage: auth <add|login> <user> <pass> [server]' + sys.exit(0) + + action = sys.argv[1] + user = sys.argv[2] + passwd = sys.argv[3] + + if len(sys.argv) == 5: + SERVER = sys.argv[4] + else: + SERVER = "https://localhost:8443" + + if action == "login": + + @srpauth_protected( + user=user, passwd=passwd, server=SERVER, verify=False) + def test_srp_protected_get(*args, **kwargs): + req = requests.get(*args, **kwargs) + req.raise_for_status + return req + + req = test_srp_protected_get('https://localhost:8443/1/cert') + print 'cert :', req.content[:200] + "..." + sys.exit(0) + + if action == "add": + auth = LeapSRPRegister(provider=SERVER, verify=False) + auth.register_user(user, passwd) diff --git a/src/leap/base/authentication.py b/src/leap/base/authentication.py new file mode 100644 index 00000000..09ff1d07 --- /dev/null +++ b/src/leap/base/authentication.py @@ -0,0 +1,11 @@ +""" +Authentication Base Class +""" + + +class Authentication(object): + """ + I have no idea how Authentication (certs,?) + will be done, but stub it here. + """ + pass diff --git a/src/leap/base/checks.py b/src/leap/base/checks.py new file mode 100644 index 00000000..23446f4a --- /dev/null +++ b/src/leap/base/checks.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +import logging +import platform +import socket + +import netifaces +import ping +import requests + +from leap.base import constants +from leap.base import exceptions + +logger = logging.getLogger(name=__name__) + + +class LeapNetworkChecker(object): + """ + all network related checks + """ + def __init__(self, *args, **kwargs): + provider_gw = kwargs.pop('provider_gw', None) + self.provider_gateway = provider_gw + + def run_all(self, checker=None): + if not checker: + checker = self + #self.error = None # ? + + # for MVS + checker.check_tunnel_default_interface() + checker.check_internet_connection() + checker.is_internet_up() + + if self.provider_gateway: + checker.ping_gateway(self.provider_gateway) + + def check_internet_connection(self): + try: + # XXX remove this hardcoded random ip + # ping leap.se or eip provider instead...? + requests.get('http://216.172.161.165') + + except (requests.HTTPError, requests.RequestException) as e: + raise exceptions.NoInternetConnection(e.message) + except requests.ConnectionError as e: + error = "Unidentified Connection Error" + if e.message == "[Errno 113] No route to host": + if not self.is_internet_up(): + error = "No valid internet connection found." + else: + error = "Provider server appears to be down." + logger.error(error) + raise exceptions.NoInternetConnection(error) + logger.debug('Network appears to be up.') + + def is_internet_up(self): + iface, gateway = self.get_default_interface_gateway() + self.ping_gateway(self.provider_gateway) + + def check_tunnel_default_interface(self): + """ + Raises an TunnelNotDefaultRouteError + (including when no routes are present) + """ + if not platform.system() == "Linux": + raise NotImplementedError + + f = open("/proc/net/route") + route_table = f.readlines() + f.close() + #toss out header + route_table.pop(0) + + if not route_table: + raise exceptions.TunnelNotDefaultRouteError() + + line = route_table.pop(0) + iface, destination = line.split('\t')[0:2] + if not destination == '00000000' or not iface == 'tun0': + raise exceptions.TunnelNotDefaultRouteError() + + def get_default_interface_gateway(self): + """only impletemented for linux so far.""" + if not platform.system() == "Linux": + raise NotImplementedError + + # XXX use psutil + f = open("/proc/net/route") + route_table = f.readlines() + f.close() + #toss out header + route_table.pop(0) + + default_iface = None + gateway = None + while route_table: + line = route_table.pop(0) + iface, destination, gateway = line.split('\t')[0:3] + if destination == '00000000': + default_iface = iface + break + + if not default_iface: + raise exceptions.NoDefaultInterfaceFoundError + + if default_iface not in netifaces.interfaces(): + raise exceptions.InterfaceNotFoundError + + return default_iface, gateway + + def ping_gateway(self, gateway): + # TODO: Discuss how much packet loss (%) is acceptable. + + # XXX -- validate gateway + # -- is it a valid ip? (there's something in util) + # -- is it a domain? + # -- can we resolve? -- raise NoDNSError if not. + packet_loss = ping.quiet_ping(gateway)[0] + if packet_loss > constants.MAX_ICMP_PACKET_LOSS: + raise exceptions.NoConnectionToGateway + + def check_name_resolution(self, domain_name): + try: + socket.gethostbyname(domain_name) + return True + except socket.gaierror: + raise exceptions.CannotResolveDomainError diff --git a/src/leap/base/config.py b/src/leap/base/config.py new file mode 100644 index 00000000..0255fbab --- /dev/null +++ b/src/leap/base/config.py @@ -0,0 +1,279 @@ +""" +Configuration Base Class +""" +import grp +import json +import logging +import socket +import tempfile +import os + +logger = logging.getLogger(name=__name__) + +import requests + +from leap.base import exceptions +from leap.base import constants +from leap.base.pluggableconfig import PluggableConfig +from leap.util.fileutil import (mkdir_p) + +# move to base! +from leap.eip import exceptions as eipexceptions + + +class BaseLeapConfig(object): + slug = None + + # XXX we have to enforce that every derived class + # has a slug (via interface) + # get property getter that raises NI.. + + def save(self): + raise NotImplementedError("abstract base class") + + def load(self): + raise NotImplementedError("abstract base class") + + def get_config(self, *kwargs): + raise NotImplementedError("abstract base class") + + @property + def config(self): + return self.get_config() + + def get_value(self, *kwargs): + raise NotImplementedError("abstract base class") + + +class MetaConfigWithSpec(type): + """ + metaclass for JSONLeapConfig classes. + It creates a configuration spec out of + the `spec` dictionary. The `properties` attribute + of the spec dict is turn into the `schema` attribute + of the new class (which will be used to validate against). + """ + # XXX in the near future, this is the + # place where we want to enforce + # singletons, read-only and similar stuff. + + def __new__(meta, classname, bases, classDict): + schema_obj = classDict.get('spec', None) + + # not quite happy with this workaround. + # I want to raise if missing spec dict, but only + # for grand-children of this metaclass. + # maybe should use abc module for this. + abcderived = ("JSONLeapConfig",) + if schema_obj is None and classname not in abcderived: + raise exceptions.ImproperlyConfigured( + "missing spec dict on your derived class (%s)" % classname) + + # we create a configuration spec attribute + # from the spec dict + config_class = type( + classname + "Spec", + (PluggableConfig, object), + {'options': schema_obj}) + classDict['spec'] = config_class + + return type.__new__(meta, classname, bases, classDict) + +########################################################## +# some hacking still in progress: + +# Configs have: + +# - a slug (from where a filename/folder is derived) +# - a spec (for validation and defaults). +# this spec is conformant to the json-schema. +# basically a dict that will be used +# for type casting and validation, and defaults settings. + +# all config objects, since they are derived from BaseConfig, implement basic +# useful methods: +# - save +# - load + +########################################################## + + +class JSONLeapConfig(BaseLeapConfig): + + __metaclass__ = MetaConfigWithSpec + + def __init__(self, *args, **kwargs): + # sanity check + try: + assert self.slug is not None + except AssertionError: + raise exceptions.ImproperlyConfigured( + "missing slug on JSONLeapConfig" + " derived class") + try: + assert self.spec is not None + except AssertionError: + raise exceptions.ImproperlyConfigured( + "missing spec on JSONLeapConfig" + " derived class") + assert issubclass(self.spec, PluggableConfig) + + self.domain = kwargs.pop('domain', None) + self._config = self.spec(format="json") + self._config.load() + self.fetcher = kwargs.pop('fetcher', requests) + + # mandatory baseconfig interface + + def save(self, to=None): + if to is None: + to = self.filename + folder, filename = os.path.split(to) + if folder and not os.path.isdir(folder): + mkdir_p(folder) + self._config.serialize(to) + + def load(self, fromfile=None, from_uri=None, fetcher=None, verify=False): + if from_uri is not None: + fetched = self.fetch(from_uri, fetcher=fetcher, verify=verify) + if fetched: + return + if fromfile is None: + fromfile = self.filename + if os.path.isfile(fromfile): + self._config.load(fromfile=fromfile) + else: + logger.error('tried to load config from non-existent path') + logger.error('Not Found: %s', fromfile) + + def fetch(self, uri, fetcher=None, verify=True): + if not fetcher: + fetcher = self.fetcher + logger.debug('verify: %s', verify) + logger.debug('uri: %s', uri) + request = fetcher.get(uri, verify=verify) + # XXX should send a if-modified-since header + + # XXX get 404, ... + # and raise a UnableToFetch... + request.raise_for_status() + fd, fname = tempfile.mkstemp(suffix=".json") + + if request.json: + self._config.load(json.dumps(request.json)) + + else: + # not request.json + # might be server did not announce content properly, + # let's try deserializing all the same. + try: + self._config.load(request.content) + except ValueError: + raise eipexceptions.LeapBadConfigFetchedError + + return True + + def get_config(self): + return self._config.config + + # public methods + + def get_filename(self): + return self._slug_to_filename() + + @property + def filename(self): + return self.get_filename() + + def validate(self, data): + logger.debug('validating schema') + self._config.validate(data) + return True + + # private + + def _slug_to_filename(self): + # is this going to work in winland if slug is "foo/bar" ? + folder, filename = os.path.split(self.slug) + config_file = get_config_file(filename, folder) + return config_file + + def exists(self): + return os.path.isfile(self.filename) + + +# +# utility functions +# +# (might be moved to some class as we see fit, but +# let's remain functional for a while) +# maybe base.config.util ?? +# + + +def get_config_dir(): + """ + get the base dir for all leap config + @rparam: config path + @rtype: string + """ + # TODO + # check for $XDG_CONFIG_HOME var? + # get a more sensible path for win/mac + # kclair: opinion? ^^ + + return os.path.expanduser( + os.path.join('~', + '.config', + 'leap')) + + +def get_config_file(filename, folder=None): + """ + concatenates the given filename + with leap config dir. + @param filename: name of the file + @type filename: string + @rparam: full path to config file + """ + path = [] + path.append(get_config_dir()) + if folder is not None: + path.append(folder) + path.append(filename) + return os.path.join(*path) + + +def get_default_provider_path(): + default_subpath = os.path.join("providers", + constants.DEFAULT_PROVIDER) + default_provider_path = get_config_file( + '', + folder=default_subpath) + return default_provider_path + + +def get_provider_path(domain): + # XXX if not domain, return get_default_provider_path + default_subpath = os.path.join("providers", domain) + provider_path = get_config_file( + '', + folder=default_subpath) + return provider_path + + +def validate_ip(ip_str): + """ + raises exception if the ip_str is + not a valid representation of an ip + """ + socket.inet_aton(ip_str) + + +def get_username(): + return os.getlogin() + + +def get_groupname(): + gid = os.getgroups()[-1] + return grp.getgrgid(gid).gr_name diff --git a/src/leap/base/connection.py b/src/leap/base/connection.py new file mode 100644 index 00000000..41d13935 --- /dev/null +++ b/src/leap/base/connection.py @@ -0,0 +1,115 @@ +""" +Base Connection Classs +""" +from __future__ import (division, unicode_literals, print_function) + +import logging + +from leap.base.authentication import Authentication + +logger = logging.getLogger(name=__name__) + + +class Connection(Authentication): + # JSONLeapConfig + #spec = {} + + def __init__(self, *args, **kwargs): + self.connection_state = None + self.desired_connection_state = None + #XXX FIXME diamond inheritance gotcha.. + #If you inherit from >1 class, + #super is only initializing one + #of the bases..!! + # I think we better pass config as a constructor + # parameter -- kali 2012-08-30 04:33 + super(Connection, self).__init__(*args, **kwargs) + + def connect(self): + """ + entry point for connection process + """ + pass + + def disconnect(self): + """ + disconnects client + """ + pass + + #def shutdown(self): + #""" + #shutdown and quit + #""" + #self.desired_con_state = self.status.DISCONNECTED + + def connection_state(self): + """ + returns the current connection state + """ + return self.status.current + + def desired_connection_state(self): + """ + returns the desired_connection state + """ + return self.desired_connection_state + + def get_icon_name(self): + """ + get icon name from status object + """ + return self.status.get_state_icon() + + # + # private methods + # + + def _disconnect(self): + """ + private method for disconnecting + """ + if self.subp is not None: + self.subp.terminate() + self.subp = None + # XXX signal state changes! :) + + def _is_alive(self): + """ + don't know yet + """ + pass + + def _connect(self): + """ + entry point for connection cascade methods. + """ + #conn_result = ConState.DISCONNECTED + try: + conn_result = self._try_connection() + except UnrecoverableError as except_msg: + logger.error("FATAL: %s" % unicode(except_msg)) + conn_result = self.status.UNRECOVERABLE + except Exception as except_msg: + self.error_queue.append(except_msg) + logger.error("Failed Connection: %s" % + unicode(except_msg)) + return conn_result + + +class ConnectionError(Exception): + """ + generic connection error + """ + def __str__(self): + if len(self.args) >= 1: + return repr(self.args[0]) + else: + raise self() + + +class UnrecoverableError(ConnectionError): + """ + we cannot do anything about it, sorry + """ + pass diff --git a/src/leap/base/constants.py b/src/leap/base/constants.py new file mode 100644 index 00000000..f7be8d98 --- /dev/null +++ b/src/leap/base/constants.py @@ -0,0 +1,32 @@ +"""constants to be used in base module""" +from leap import __branding +APP_NAME = __branding.get("short_name", "leap") + +# default provider placeholder +# using `example.org` we make sure that this +# is not going to be resolved during the tests phases +# (we expect testers to add it to their /etc/hosts + +DEFAULT_PROVIDER = __branding.get( + "provider_domain", + "testprovider.example.org") + +DEFINITION_EXPECTED_PATH = "provider.json" + +DEFAULT_PROVIDER_DEFINITION = { + u'api_uri': u'https://api.%s/' % DEFAULT_PROVIDER, + u'api_version': u'0.1.0', + u'ca_cert_fingerprint': u'8aab80ae4326fd30721689db813733783fe0bd7e', + u'ca_cert_uri': u'https://%s/cacert.pem' % DEFAULT_PROVIDER, + u'description': {u'en': u'This is a test provider'}, + u'display_name': {u'en': u'Test Provider'}, + u'domain': u'%s' % DEFAULT_PROVIDER, + u'enrollment_policy': u'open', + u'public_key': u'cb7dbd679f911e85bc2e51bd44afd7308ee19c21', + u'serial': 1, + u'services': [u'eip'], + u'version': u'0.1.0'} + +MAX_ICMP_PACKET_LOSS = 10 + +ROUTE_CHECK_INTERVAL = 10 diff --git a/src/leap/base/exceptions.py b/src/leap/base/exceptions.py new file mode 100644 index 00000000..227da953 --- /dev/null +++ b/src/leap/base/exceptions.py @@ -0,0 +1,77 @@ +""" +Exception attributes and their meaning/uses +------------------------------------------- + +* critical: if True, will abort execution prematurely, + after attempting any cleaning + action. + +* failfirst: breaks any error_check loop that is examining + the error queue. + +* message: the message that will be used in the __repr__ of the exception. + +* usermessage: the message that will be passed to user in ErrorDialogs + in Qt-land. +""" + + +class LeapException(Exception): + """ + base LeapClient exception + sets some parameters that we will check + during error checking routines + """ + critical = False + failfirst = False + warning = False + + +class CriticalError(LeapException): + """ + we cannot do anything about it + """ + critical = True + failfirst = True + + +# In use ??? +# don't thing so. purge if not... + +class MissingConfigFileError(Exception): + pass + + +class ImproperlyConfigured(Exception): + pass + + +class NoDefaultInterfaceFoundError(LeapException): + message = "no default interface found" + usermessage = "Looks like your computer is not connected to the internet" + + +class InterfaceNotFoundError(LeapException): + # XXX should take iface arg on init maybe? + message = "interface not found" + + +class NoConnectionToGateway(CriticalError): + message = "no connection to gateway" + usermessage = "Looks like there are problems with your internet connection" + + +class NoInternetConnection(CriticalError): + message = "No Internet connection found" + usermessage = "It looks like there is no internet connection." + # and now we try to connect to our web to troubleshoot LOL :P + + +class CannotResolveDomainError(LeapException): + message = "Cannot resolve domain" + usermessage = "Domain cannot be found" + + +class TunnelNotDefaultRouteError(CriticalError): + message = "Tunnel connection dissapeared. VPN down?" + usermessage = "The Encrypted Connection was lost. Shutting down..." diff --git a/src/leap/base/network.py b/src/leap/base/network.py new file mode 100644 index 00000000..3aba3f61 --- /dev/null +++ b/src/leap/base/network.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +from __future__ import (print_function) +import logging +import threading + +from leap.eip.config import get_eip_gateway +from leap.base.checks import LeapNetworkChecker +from leap.base.constants import ROUTE_CHECK_INTERVAL +from leap.base.exceptions import TunnelNotDefaultRouteError +from leap.util.coroutines import (launch_thread, process_events) + +from time import sleep + +logger = logging.getLogger(name=__name__) + + +class NetworkCheckerThread(object): + """ + Manages network checking thread that makes sure we have a working network + connection. + """ + def __init__(self, *args, **kwargs): + self.status_signals = kwargs.pop('status_signals', None) + #self.watcher_cb = kwargs.pop('status_signals', None) + self.error_cb = kwargs.pop( + 'error_cb', + lambda exc: logger.error("%s", exc.message)) + self.shutdown = threading.Event() + + # XXX get provider_gateway and pass it to checker + # see in eip.config for function + # #718 + self.checker = LeapNetworkChecker( + provider_gw=get_eip_gateway()) + + def start(self): + self.process_handle = self._launch_recurrent_network_checks( + (self.error_cb,)) + + def stop(self): + self.shutdown.set() + logger.debug("network checked stopped.") + + def run_checks(self): + pass + + #private methods + + #here all the observers in fail_callbacks expect one positional argument, + #which is exception so we can try by passing a lambda with logger to + #check it works. + def _network_checks_thread(self, fail_callbacks): + #TODO: replace this with waiting for a signal from openvpn + while True: + try: + self.checker.check_tunnel_default_interface() + break + except TunnelNotDefaultRouteError: + # XXX ??? why do we sleep here??? + # aa: If the openvpn isn't up and running yet, + # let's give it a moment to breath. + sleep(1) + + fail_observer_dict = dict((( + observer, + process_events(observer)) for observer in fail_callbacks)) + while not self.shutdown.is_set(): + try: + self.checker.check_tunnel_default_interface() + self.checker.check_internet_connection() + sleep(ROUTE_CHECK_INTERVAL) + except Exception as exc: + for obs in fail_observer_dict: + fail_observer_dict[obs].send(exc) + sleep(ROUTE_CHECK_INTERVAL) + #reset event + self.shutdown.clear() + + def _launch_recurrent_network_checks(self, fail_callbacks): + #we need to wrap the fail callback in a tuple + watcher = launch_thread( + self._network_checks_thread, + (fail_callbacks,)) + return watcher diff --git a/src/leap/base/pluggableconfig.py b/src/leap/base/pluggableconfig.py new file mode 100644 index 00000000..b8615ad8 --- /dev/null +++ b/src/leap/base/pluggableconfig.py @@ -0,0 +1,421 @@ +""" +generic configuration handlers +""" +import copy +import json +import logging +import os +import time +import urlparse + +import jsonschema + +logger = logging.getLogger(__name__) + + +__all__ = ['PluggableConfig', + 'adaptors', + 'types', + 'UnknownOptionException', + 'MissingValueException', + 'ConfigurationProviderException', + 'TypeCastException'] + +# exceptions + + +class UnknownOptionException(Exception): + """exception raised when a non-configuration + value is present in the configuration""" + + +class MissingValueException(Exception): + """exception raised when a required value is missing""" + + +class ConfigurationProviderException(Exception): + """exception raised when a configuration provider is missing, etc""" + + +class TypeCastException(Exception): + """exception raised when a + configuration item cannot be coerced to a type""" + + +class ConfigAdaptor(object): + """ + abstract base class for config adaotors for + serialization/deserialization and custom validation + and type casting. + """ + def read(self, filename): + raise NotImplementedError("abstract base class") + + def write(self, config, filename): + with open(filename, 'w') as f: + self._write(f, config) + + def _write(self, fp, config): + raise NotImplementedError("abstract base class") + + def validate(self, config, schema): + raise NotImplementedError("abstract base class") + + +adaptors = {} + + +class JSONSchemaEncoder(json.JSONEncoder): + """ + custom default encoder that + casts python objects to json objects for + the schema validation + """ + def default(self, obj): + if obj is str: + return 'string' + if obj is unicode: + return 'string' + if obj is int: + return 'integer' + if obj is list: + return 'array' + if obj is dict: + return 'object' + if obj is bool: + return 'boolean' + + +class JSONAdaptor(ConfigAdaptor): + indent = 2 + extensions = ['json'] + + def read(self, _from): + if isinstance(_from, file): + _from_string = _from.read() + if isinstance(_from, str): + _from_string = _from + return json.loads(_from_string) + + def _write(self, fp, config): + fp.write(json.dumps(config, + indent=self.indent, + sort_keys=True)) + + def validate(self, config, schema_obj): + schema_json = JSONSchemaEncoder().encode(schema_obj) + schema = json.loads(schema_json) + jsonschema.validate(config, schema) + + +adaptors['json'] = JSONAdaptor() + +# +# Adaptors +# +# Allow to apply a predefined set of types to the +# specs, so it checks the validity of formats and cast it +# to proper python types. + +# TODO: +# - multilingual object. +# - HTTPS uri + + +class DateType(object): + fmt = '%Y-%m-%d' + + def to_python(self, data): + return time.strptime(data, self.fmt) + + def get_prep_value(self, data): + return time.strftime(self.fmt, data) + + +class URIType(object): + + def to_python(self, data): + parsed = urlparse.urlparse(data) + if not parsed.scheme: + raise TypeCastException("uri %s has no schema" % data) + return parsed + + def get_prep_value(self, data): + return data.geturl() + + +class HTTPSURIType(object): + + def to_python(self, data): + parsed = urlparse.urlparse(data) + if not parsed.scheme: + raise TypeCastException("uri %s has no schema" % data) + if parsed.scheme != "https": + raise TypeCastException( + "uri %s does not has " + "https schema" % data) + return parsed + + def get_prep_value(self, data): + return data.geturl() + + +types = { + 'date': DateType(), + 'uri': URIType(), + 'https-uri': HTTPSURIType(), +} + + +class PluggableConfig(object): + + options = {} + + def __init__(self, + adaptors=adaptors, + types=types, + format=None): + + self.config = {} + self.adaptors = adaptors + self.types = types + self._format = format + + @property + def option_dict(self): + if hasattr(self, 'options') and isinstance(self.options, dict): + return self.options.get('properties', None) + + def items(self): + """ + act like an iterator + """ + if isinstance(self.option_dict, dict): + return self.option_dict.items() + return self.options + + def validate(self, config, format=None): + """ + validate config + """ + schema = self.options + if format is None: + format = self._format + + if format: + adaptor = self.get_adaptor(self._format) + adaptor.validate(config, schema) + else: + # we really should make format mandatory... + logger.error('no format passed to validate') + + # first round of validation is ok. + # now we proceed to cast types if any specified. + self.to_python(config) + + def to_python(self, config): + """ + cast types following first type and then format indications. + """ + unseen_options = [i for i in config if i not in self.option_dict] + if unseen_options: + raise UnknownOptionException( + "Unknown options: %s" % ', '.join(unseen_options)) + + for key, value in config.items(): + _type = self.option_dict[key].get('type') + if _type is None and 'default' in self.option_dict[key]: + _type = type(self.option_dict[key]['default']) + if _type is not None: + tocast = True + if not callable(_type) and isinstance(value, _type): + tocast = False + if tocast: + try: + config[key] = _type(value) + except BaseException, e: + raise TypeCastException( + "Could not coerce %s, %s, " + "to type %s: %s" % (key, value, _type.__name__, e)) + _format = self.option_dict[key].get('format', None) + _ftype = self.types.get(_format, None) + if _ftype: + try: + config[key] = _ftype.to_python(value) + except BaseException, e: + raise TypeCastException( + "Could not coerce %s, %s, " + "to format %s: %s" % (key, value, + _ftype.__class__.__name__, + e)) + + return config + + def prep_value(self, config): + """ + the inverse of to_python method, + called just before serialization + """ + for key, value in config.items(): + _format = self.option_dict[key].get('format', None) + _ftype = self.types.get(_format, None) + if _ftype and hasattr(_ftype, 'get_prep_value'): + try: + config[key] = _ftype.get_prep_value(value) + except BaseException, e: + raise TypeCastException( + "Could not serialize %s, %s, " + "by format %s: %s" % (key, value, + _ftype.__class__.__name__, + e)) + else: + config[key] = value + return config + + # methods for adding configuration + + def get_default_values(self): + """ + return a config options from configuration defaults + """ + defaults = {} + for key, value in self.items(): + if 'default' in value: + defaults[key] = value['default'] + return copy.deepcopy(defaults) + + def get_adaptor(self, format): + """ + get specified format adaptor or + guess for a given filename + """ + adaptor = self.adaptors.get(format, None) + if adaptor: + return adaptor + + # not registered in adaptors dict, let's try all + for adaptor in self.adaptors.values(): + if format in adaptor.extensions: + return adaptor + + def filename2format(self, filename): + extension = os.path.splitext(filename)[-1] + return extension.lstrip('.') or None + + def serialize(self, filename, format=None, full=False): + if not format: + format = self._format + if not format: + format = self.filename2format(filename) + if not format: + raise Exception('Please specify a format') + # TODO: more specific exception type + + adaptor = self.get_adaptor(format) + if not adaptor: + raise Exception("Adaptor not found for format: %s" % format) + + config = copy.deepcopy(self.config) + serializable = self.prep_value(config) + adaptor.write(serializable, filename) + + def deserialize(self, string=None, fromfile=None, format=None): + """ + load configuration from a file or string + """ + + def _try_deserialize(): + if fromfile: + with open(fromfile, 'r') as f: + content = adaptor.read(f) + elif string: + content = adaptor.read(string) + return content + + # XXX cleanup this! + + if fromfile: + assert os.path.exists(fromfile) + if not format: + format = self.filename2format(fromfile) + + if not format: + format = self._format + if format: + adaptor = self.get_adaptor(format) + else: + adaptor = None + + if adaptor: + content = _try_deserialize() + return content + + # no adaptor, let's try rest of adaptors + + adaptors = self.adaptors[:] + + if format: + adaptors.sort( + key=lambda x: int( + format in x.extensions), + reverse=True) + + for adaptor in adaptors: + content = _try_deserialize() + return content + + def load(self, *args, **kwargs): + """ + load from string or file + if no string of fromfile option is given, + it will attempt to load from defaults + defined in the schema. + """ + string = args[0] if args else None + fromfile = kwargs.get("fromfile", None) + content = None + + # start with defaults, so we can + # have partial values applied. + content = self.get_default_values() + if string and isinstance(string, str): + content = self.deserialize(string) + + if not string and fromfile is not None: + #import ipdb;ipdb.set_trace() + content = self.deserialize(fromfile=fromfile) + + if not content: + logger.error('no content could be loaded') + # XXX raise! + return + + # lazy evaluation until first level of nesting + # to allow lambdas with context-dependant info + # like os.path.expanduser + for k, v in content.iteritems(): + if callable(v): + content[k] = v() + + self.validate(content) + self.config = content + return True + + +def testmain(): + from tests import test_validation as t + import pprint + + config = PluggableConfig(_format="json") + properties = copy.deepcopy(t.sample_spec) + + config.options = properties + config.load(fromfile='data.json') + + print 'config' + pprint.pprint(config.config) + + config.serialize('/tmp/testserial.json') + +if __name__ == "__main__": + testmain() diff --git a/src/leap/base/providers.py b/src/leap/base/providers.py new file mode 100644 index 00000000..d41f3695 --- /dev/null +++ b/src/leap/base/providers.py @@ -0,0 +1,29 @@ +"""all dealing with leap-providers: definition files, updating""" +from leap.base import config as baseconfig +from leap.base import specs + + +class LeapProviderDefinition(baseconfig.JSONLeapConfig): + spec = specs.leap_provider_spec + + def _get_slug(self): + domain = getattr(self, 'domain', None) + if domain: + path = baseconfig.get_provider_path(domain) + else: + path = baseconfig.get_default_provider_path() + + return baseconfig.get_config_file( + 'provider.json', folder=path) + + def _set_slug(self, *args, **kwargs): + raise AttributeError("you cannot set slug") + + slug = property(_get_slug, _set_slug) + + +class LeapProviderSet(object): + # we gather them from the filesystem + # TODO: (MVS+) + def __init__(self): + self.count = 0 diff --git a/src/leap/base/specs.py b/src/leap/base/specs.py new file mode 100644 index 00000000..b4bb8dcf --- /dev/null +++ b/src/leap/base/specs.py @@ -0,0 +1,59 @@ +leap_provider_spec = { + 'description': 'provider definition', + 'type': 'object', + 'properties': { + 'serial': { + 'type': int, + 'default': 1, + 'required': True, + }, + 'version': { + 'type': unicode, + 'default': '0.1.0' + #'required': True + }, + 'domain': { + 'type': unicode, # XXX define uri type + 'default': 'testprovider.example.org' + #'required': True, + }, + 'display_name': { + 'type': dict, # XXX multilingual object? + 'default': {u'en': u'Test Provider'} + #'required': True + }, + 'description': { + 'type': dict, + 'default': {u'en': u'Test provider'} + }, + 'enrollment_policy': { + 'type': unicode, # oneof ?? + 'default': 'open' + }, + 'services': { + 'type': list, # oneof ?? + 'default': ['eip'] + }, + 'api_version': { + 'type': unicode, + 'default': '0.1.0' # version regexp + }, + 'api_uri': { + 'type': unicode # uri + }, + 'public_key': { + 'type': unicode # fingerprint + }, + 'ca_cert_fingerprint': { + 'type': unicode, + }, + 'ca_cert_uri': { + 'type': unicode, + 'format': 'https-uri' + }, + 'languages': { + 'type': list, + 'default': ['en'] + } + } +} diff --git a/src/leap/base/tests/__init__.py b/src/leap/base/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/base/tests/__init__.py diff --git a/src/leap/base/tests/test_checks.py b/src/leap/base/tests/test_checks.py new file mode 100644 index 00000000..8d573b1e --- /dev/null +++ b/src/leap/base/tests/test_checks.py @@ -0,0 +1,124 @@ +try: + import unittest2 as unittest +except ImportError: + import unittest +import os + +from mock import (patch, Mock) +from StringIO import StringIO + +import ping +import requests + +from leap.base import checks +from leap.base import exceptions +from leap.testing.basetest import BaseLeapTest + +_uid = os.getuid() + + +class LeapNetworkCheckTest(BaseLeapTest): + __name__ = "leap_network_check_tests" + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_checker_should_implement_check_methods(self): + checker = checks.LeapNetworkChecker() + + self.assertTrue(hasattr(checker, "check_internet_connection"), + "missing meth") + self.assertTrue(hasattr(checker, "check_tunnel_default_interface"), + "missing meth") + self.assertTrue(hasattr(checker, "is_internet_up"), + "missing meth") + self.assertTrue(hasattr(checker, "ping_gateway"), + "missing meth") + + def test_checker_should_actually_call_all_tests(self): + checker = checks.LeapNetworkChecker() + mc = Mock() + checker.run_all(checker=mc) + self.assertTrue(mc.check_internet_connection.called, "not called") + self.assertTrue(mc.check_tunnel_default_interface.called, "not called") + self.assertTrue(mc.is_internet_up.called, "not called") + + # ping gateway only called if we pass provider_gw + checker = checks.LeapNetworkChecker(provider_gw="0.0.0.0") + mc = Mock() + checker.run_all(checker=mc) + self.assertTrue(mc.check_internet_connection.called, "not called") + self.assertTrue(mc.check_tunnel_default_interface.called, "not called") + self.assertTrue(mc.ping_gateway.called, "not called") + self.assertTrue(mc.is_internet_up.called, "not called") + + def test_get_default_interface_no_interface(self): + checker = checks.LeapNetworkChecker() + with patch('leap.base.checks.open', create=True) as mock_open: + with self.assertRaises(exceptions.NoDefaultInterfaceFoundError): + mock_open.return_value = StringIO( + "Iface\tDestination Gateway\t" + "Flags\tRefCntd\tUse\tMetric\t" + "Mask\tMTU\tWindow\tIRTT") + checker.get_default_interface_gateway() + + def test_check_tunnel_default_interface(self): + checker = checks.LeapNetworkChecker() + with patch('leap.base.checks.open', create=True) as mock_open: + with self.assertRaises(exceptions.TunnelNotDefaultRouteError): + mock_open.return_value = StringIO( + "Iface\tDestination Gateway\t" + "Flags\tRefCntd\tUse\tMetric\t" + "Mask\tMTU\tWindow\tIRTT") + checker.check_tunnel_default_interface() + + with patch('leap.base.checks.open', create=True) as mock_open: + with self.assertRaises(exceptions.TunnelNotDefaultRouteError): + mock_open.return_value = StringIO( + "Iface\tDestination Gateway\t" + "Flags\tRefCntd\tUse\tMetric\t" + "Mask\tMTU\tWindow\tIRTT\n" + "wlan0\t00000000\t0102A8C0\t" + "0003\t0\t0\t0\t00000000\t0\t0\t0") + checker.check_tunnel_default_interface() + + with patch('leap.base.checks.open', create=True) as mock_open: + mock_open.return_value = StringIO( + "Iface\tDestination Gateway\t" + "Flags\tRefCntd\tUse\tMetric\t" + "Mask\tMTU\tWindow\tIRTT\n" + "tun0\t00000000\t01002A0A\t0003\t0\t0\t0\t00000080\t0\t0\t0") + checker.check_tunnel_default_interface() + + def test_ping_gateway_fail(self): + checker = checks.LeapNetworkChecker() + with patch.object(ping, "quiet_ping") as mocked_ping: + with self.assertRaises(exceptions.NoConnectionToGateway): + mocked_ping.return_value = [11, "", ""] + checker.ping_gateway("4.2.2.2") + + def test_check_internet_connection_failures(self): + checker = checks.LeapNetworkChecker() + with patch.object(requests, "get") as mocked_get: + mocked_get.side_effect = requests.HTTPError + with self.assertRaises(exceptions.NoInternetConnection): + checker.check_internet_connection() + + with patch.object(requests, "get") as mocked_get: + mocked_get.side_effect = requests.RequestException + with self.assertRaises(exceptions.NoInternetConnection): + checker.check_internet_connection() + + #TODO: Mock possible errors that can be raised by is_internet_up + with patch.object(requests, "get") as mocked_get: + mocked_get.side_effect = requests.ConnectionError + with self.assertRaises(exceptions.NoInternetConnection): + checker.check_internet_connection() + + @unittest.skipUnless(_uid == 0, "root only") + def test_ping_gateway(self): + checker = checks.LeapNetworkChecker() + checker.ping_gateway("4.2.2.2") diff --git a/src/leap/base/tests/test_config.py b/src/leap/base/tests/test_config.py new file mode 100644 index 00000000..d03149b2 --- /dev/null +++ b/src/leap/base/tests/test_config.py @@ -0,0 +1,247 @@ +import json +import os +import platform +import socket +#import tempfile + +import mock +import requests + +from leap.base import config +from leap.base import constants +from leap.base import exceptions +from leap.eip import constants as eipconstants +from leap.util.fileutil import mkdir_p +from leap.testing.basetest import BaseLeapTest + + +try: + import unittest2 as unittest +except ImportError: + import unittest + +_system = platform.system() + + +class JSONLeapConfigTest(BaseLeapTest): + def setUp(self): + pass + + def tearDown(self): + pass + + def test_metaclass(self): + with self.assertRaises(exceptions.ImproperlyConfigured) as exc: + class DummyTestConfig(config.JSONLeapConfig): + __metaclass__ = config.MetaConfigWithSpec + exc.startswith("missing spec dict") + + class DummyTestConfig(config.JSONLeapConfig): + __metaclass__ = config.MetaConfigWithSpec + spec = {'properties': {}} + with self.assertRaises(exceptions.ImproperlyConfigured) as exc: + DummyTestConfig() + exc.startswith("missing slug") + + class DummyTestConfig(config.JSONLeapConfig): + __metaclass__ = config.MetaConfigWithSpec + spec = {'properties': {}} + slug = "foo" + DummyTestConfig() + +######################################3 +# +# provider fetch tests block +# + + +class ProviderTest(BaseLeapTest): + # override per test fixtures + + def setUp(self): + pass + + def tearDown(self): + pass + + +# XXX depreacated. similar test in eip.checks + +#class BareHomeTestCase(ProviderTest): +# + #__name__ = "provider_config_tests_bare_home" +# + #def test_should_raise_if_missing_eip_json(self): + #with self.assertRaises(exceptions.MissingConfigFileError): + #config.get_config_json(os.path.join(self.home, 'eip.json')) + + +class ProviderDefinitionTestCase(ProviderTest): + # XXX MOVE TO eip.test_checks + # -- kali 2012-08-24 00:38 + + __name__ = "provider_config_tests" + + def setUp(self): + # dump a sample eip file + # XXX Move to Use EIP Spec Instead!!! + # XXX tests to be moved to eip.checks and eip.providers + # XXX can use eipconfig.dump_default_eipconfig + + path = os.path.join(self.home, '.config', 'leap') + mkdir_p(path) + with open(os.path.join(path, 'eip.json'), 'w') as fp: + json.dump(eipconstants.EIP_SAMPLE_JSON, fp) + + +# these tests below should move to +# eip.checks +# config.Configuration has been deprecated + +# TODO: +# - We're instantiating a ProviderTest because we're doing the home wipeoff +# on setUpClass instead of the setUp (for speedup of the general cases). + +# We really should be testing all of them in the same testCase, and +# doing an extra wipe of the tempdir... but be careful!!!! do not mess with +# os.environ home more than needed... that could potentially bite! + +# XXX actually, another thing to fix here is separating tests: +# - test that requests has been called. +# - check deeper for error types/msgs + +# we SHOULD inject requests dep in the constructor +# (so we can pass mock easily). + + +#class ProviderFetchConError(ProviderTest): + #def test_connection_error(self): + #with mock.patch.object(requests, "get") as mock_method: + #mock_method.side_effect = requests.ConnectionError + #cf = config.Configuration() + #self.assertIsInstance(cf.error, str) +# +# +#class ProviderFetchHttpError(ProviderTest): + #def test_file_not_found(self): + #with mock.patch.object(requests, "get") as mock_method: + #mock_method.side_effect = requests.HTTPError + #cf = config.Configuration() + #self.assertIsInstance(cf.error, str) +# +# +#class ProviderFetchInvalidUrl(ProviderTest): + #def test_invalid_url(self): + #cf = config.Configuration("ht") + #self.assertTrue(cf.error) + + +# end provider fetch tests +########################################### + + +class ConfigHelperFunctions(BaseLeapTest): + + __name__ = "config_helper_tests" + + def setUp(self): + pass + + def tearDown(self): + pass + + # tests + + @unittest.skipUnless(_system == "Linux", "linux only") + def test_lin_get_config_file(self): + """ + config file path where expected? (linux) + """ + self.assertEqual( + config.get_config_file( + 'test', folder="foo/bar"), + os.path.expanduser( + '~/.config/leap/foo/bar/test') + ) + + @unittest.skipUnless(_system == "Darwin", "mac only") + def test_mac_get_config_file(self): + """ + config file path where expected? (mac) + """ + self._missing_test_for_plat(do_raise=True) + + @unittest.skipUnless(_system == "Windows", "win only") + def test_win_get_config_file(self): + """ + config file path where expected? + """ + self._missing_test_for_plat(do_raise=True) + + # + # XXX hey, I'm raising exceptions here + # on purpose. just wanted to make sure + # that the skip stuff is doing it right. + # If you're working on win/macos tests, + # feel free to remove tests that you see + # are too redundant. + + @unittest.skipUnless(_system == "Linux", "linux only") + def test_lin_get_config_dir(self): + """ + nice config dir? (linux) + """ + self.assertEqual( + config.get_config_dir(), + os.path.expanduser('~/.config/leap')) + + @unittest.skipUnless(_system == "Darwin", "mac only") + def test_mac_get_config_dir(self): + """ + nice config dir? (mac) + """ + self._missing_test_for_plat(do_raise=True) + + @unittest.skipUnless(_system == "Windows", "win only") + def test_win_get_config_dir(self): + """ + nice config dir? (win) + """ + self._missing_test_for_plat(do_raise=True) + + # provider paths + + @unittest.skipUnless(_system == "Linux", "linux only") + def test_get_default_provider_path(self): + """ + is default provider path ok? + """ + self.assertEqual( + config.get_default_provider_path(), + os.path.expanduser( + '~/.config/leap/providers/%s/' % + constants.DEFAULT_PROVIDER) + ) + + # validate ip + + def test_validate_ip(self): + """ + check our ip validation + """ + config.validate_ip('3.3.3.3') + with self.assertRaises(socket.error): + config.validate_ip('255.255.255.256') + with self.assertRaises(socket.error): + config.validate_ip('foobar') + + @unittest.skip + def test_validate_domain(self): + """ + code to be written yet + """ + raise NotImplementedError + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/base/tests/test_providers.py b/src/leap/base/tests/test_providers.py new file mode 100644 index 00000000..15c4ed58 --- /dev/null +++ b/src/leap/base/tests/test_providers.py @@ -0,0 +1,143 @@ +import copy +import json +try: + import unittest2 as unittest +except ImportError: + import unittest +import os + +import jsonschema + +from leap import __branding as BRANDING +from leap.testing.basetest import BaseLeapTest +from leap.base import providers + + +EXPECTED_DEFAULT_CONFIG = { + u"api_version": u"0.1.0", + u"description": {u'en': u"Test provider"}, + u"display_name": {u'en': u"Test Provider"}, + u"domain": u"testprovider.example.org", + u"enrollment_policy": u"open", + u"serial": 1, + u"services": [ + u"eip" + ], + u"languages": [u"en"], + u"version": u"0.1.0" +} + + +class TestLeapProviderDefinition(BaseLeapTest): + def setUp(self): + self.domain = "testprovider.example.org" + self.definition = providers.LeapProviderDefinition( + domain=self.domain) + self.definition.save() + self.definition.load() + self.config = self.definition.config + + def tearDown(self): + if hasattr(self, 'testfile') and os.path.isfile(self.testfile): + os.remove(self.testfile) + + # tests + + # XXX most of these tests can be made more abstract + # and moved to test_baseconfig *triangulate!* + + def test_provider_slug_property(self): + slug = self.definition.slug + self.assertEquals( + slug, + os.path.join( + self.home, + '.config', 'leap', 'providers', + '%s' % self.domain, + 'provider.json')) + with self.assertRaises(AttributeError): + self.definition.slug = 23 + + def test_provider_dump(self): + # check a good provider definition is dumped to disk + self.testfile = self.get_tempfile('test.json') + self.definition.save(to=self.testfile) + deserialized = json.load(open(self.testfile, 'rb')) + self.maxDiff = None + self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG) + + def test_provider_dump_to_slug(self): + # same as above, but we test the ability to save to a + # file generated from the slug. + # XXX THIS TEST SHOULD MOVE TO test_baseconfig + self.definition.save() + filename = self.definition.filename + self.assertTrue(os.path.isfile(filename)) + deserialized = json.load(open(filename, 'rb')) + self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG) + + def test_provider_load(self): + # check loading provider from disk file + self.testfile = self.get_tempfile('test_load.json') + with open(self.testfile, 'w') as wf: + wf.write(json.dumps(EXPECTED_DEFAULT_CONFIG)) + self.definition.load(fromfile=self.testfile) + self.assertDictEqual(self.config, + EXPECTED_DEFAULT_CONFIG) + + def test_provider_validation(self): + self.definition.validate(self.config) + _config = copy.deepcopy(self.config) + _config['serial'] = 'aaa' + with self.assertRaises(jsonschema.ValidationError): + self.definition.validate(_config) + + @unittest.skip + def test_load_malformed_json_definition(self): + raise NotImplementedError + + @unittest.skip + def test_type_validation(self): + # check various type validation + # type cast + raise NotImplementedError + + +class TestLeapProviderSet(BaseLeapTest): + + def setUp(self): + self.providers = providers.LeapProviderSet() + + def tearDown(self): + pass + ### + + def test_get_zero_count(self): + self.assertEqual(self.providers.count, 0) + + @unittest.skip + def test_count_defined_providers(self): + # check the method used for making + # the list of providers + raise NotImplementedError + + @unittest.skip + def test_get_default_provider(self): + raise NotImplementedError + + @unittest.skip + def test_should_be_at_least_one_provider_after_init(self): + # when we init an empty environment, + # there should be at least one provider, + # that will be a dump of the default provider definition + # somehow a high level test + raise NotImplementedError + + @unittest.skip + def test_get_eip_remote_from_default_provider(self): + # from: default provider + # expect: remote eip domain + raise NotImplementedError + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/base/tests/test_validation.py b/src/leap/base/tests/test_validation.py new file mode 100644 index 00000000..87e99648 --- /dev/null +++ b/src/leap/base/tests/test_validation.py @@ -0,0 +1,92 @@ +import copy +import datetime +#import json +try: + import unittest2 as unittest +except ImportError: + import unittest +import os + +import jsonschema + +from leap.base.config import JSONLeapConfig +from leap.base import pluggableconfig +from leap.testing.basetest import BaseLeapTest + +SAMPLE_CONFIG_DICT = { + 'prop_one': 1, + 'prop_uri': "http://example.org", + 'prop_date': '2012-12-12', +} + +EXPECTED_CONFIG = { + 'prop_one': 1, + 'prop_uri': "http://example.org", + 'prop_date': datetime.datetime(2012, 12, 12) +} + +sample_spec = { + 'description': 'sample schema definition', + 'type': 'object', + 'properties': { + 'prop_one': { + 'type': int, + 'default': 1, + 'required': True + }, + 'prop_uri': { + 'type': str, + 'default': 'http://example.org', + 'required': True, + 'format': 'uri' + }, + 'prop_date': { + 'type': str, + 'default': '2012-12-12', + 'format': 'date' + } + } +} + + +class SampleConfig(JSONLeapConfig): + spec = sample_spec + + @property + def slug(self): + return os.path.expanduser('~/sampleconfig.json') + + +class TestJSONLeapConfigValidation(BaseLeapTest): + def setUp(self): + self.sampleconfig = SampleConfig() + self.sampleconfig.save() + self.sampleconfig.load() + self.config = self.sampleconfig.config + + def tearDown(self): + if hasattr(self, 'testfile') and os.path.isfile(self.testfile): + os.remove(self.testfile) + + # tests + + def test_good_validation(self): + self.sampleconfig.validate(SAMPLE_CONFIG_DICT) + + def test_broken_int(self): + _config = copy.deepcopy(SAMPLE_CONFIG_DICT) + _config['prop_one'] = '1' + with self.assertRaises(jsonschema.ValidationError): + self.sampleconfig.validate(_config) + + def test_format_property(self): + # JsonSchema Validator does not check the format property. + # We should have to extend the Configuration class + blah = copy.deepcopy(SAMPLE_CONFIG_DICT) + blah['prop_uri'] = 'xxx' + with self.assertRaises(pluggableconfig.TypeCastException): + self.sampleconfig.validate(blah) + + +if __name__ == "__main__": + unittest.main() |