diff options
Diffstat (limited to 'src/leap/base')
-rw-r--r-- | src/leap/base/auth.py | 376 | ||||
-rw-r--r-- | src/leap/base/checks.py | 17 | ||||
-rw-r--r-- | src/leap/base/config.py | 10 | ||||
-rw-r--r-- | src/leap/base/connection.py | 10 | ||||
-rw-r--r-- | src/leap/base/exceptions.py | 5 | ||||
-rw-r--r-- | src/leap/base/network.py | 2 | ||||
-rw-r--r-- | src/leap/base/providers.py | 14 | ||||
-rw-r--r-- | src/leap/base/tests/__init__.py | 0 | ||||
-rw-r--r-- | src/leap/base/tests/test_checks.py | 7 | ||||
-rw-r--r-- | src/leap/base/tests/test_providers.py | 6 |
10 files changed, 423 insertions, 24 deletions
diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py new file mode 100644 index 00000000..50533278 --- /dev/null +++ b/src/leap/base/auth.py @@ -0,0 +1,376 @@ +import binascii +import json +import logging +#import urlparse + +import requests +import srp + +from PyQt4 import QtCore + +from leap.base import constants as baseconstants +from leap.crypto import leapkeyring +from leap.util.web import get_https_domain_and_port + +logger = logging.getLogger(__name__) + +SIGNUP_TIMEOUT = getattr(baseconstants, 'SIGNUP_TIMEOUT', 5) + +""" +Registration and authentication classes for the +SRP auth mechanism used in the leap platform. + +We're using the srp library which uses a c-based implementation +of the protocol if the c extension is available, and a python-based +one if not. +""" + + +class ImproperlyConfigured(Exception): + """ + """ + + +class SRPAuthenticationError(Exception): + """ + exception raised + for authentication errors + """ + + +def null_check(value, value_name): + try: + assert value is not None + except AssertionError: + raise ImproperlyConfigured( + "%s parameter cannot be None" % value_name) + + +safe_unhexlify = lambda x: binascii.unhexlify(x) \ + if (len(x) % 2 == 0) else binascii.unhexlify('0' + x) + + +class LeapSRPRegister(object): + + def __init__(self, + schema="https", + provider=None, + port=None, + verify=True, + register_path="1/users.json", + method="POST", + fetcher=requests, + srp=srp, + hashfun=srp.SHA256, + ng_constant=srp.NG_1024): + + null_check(provider, provider) + + self.schema = schema + + # XXX FIXME + self.provider = provider + self.port = port + # XXX splitting server,port + # deprecate port call. + domain, port = get_https_domain_and_port(provider) + self.provider = domain + self.port = port + + self.verify = verify + self.register_path = register_path + self.method = method + self.fetcher = fetcher + self.srp = srp + self.HASHFUN = hashfun + self.NG = ng_constant + + self.init_session() + + def init_session(self): + self.session = self.fetcher.session() + + def get_registration_uri(self): + # XXX assert is https! + # use urlparse + if self.port: + uri = "%s://%s:%s/%s" % ( + self.schema, + self.provider, + self.port, + self.register_path) + else: + uri = "%s://%s/%s" % ( + self.schema, + self.provider, + self.register_path) + + return uri + + def register_user(self, username, password, keep=False): + """ + @rtype: tuple + @rparam: (ok, request) + """ + salt, vkey = self.srp.create_salted_verification_key( + username, + password, + self.HASHFUN, + self.NG) + + user_data = { + 'user[login]': username, + 'user[password_verifier]': binascii.hexlify(vkey), + 'user[password_salt]': binascii.hexlify(salt)} + + uri = self.get_registration_uri() + logger.debug('post to uri: %s' % uri) + + # XXX get self.method + req = self.session.post( + uri, data=user_data, + timeout=SIGNUP_TIMEOUT, + verify=self.verify) + logger.debug(req) + logger.debug('user_data: %s', user_data) + #logger.debug('response: %s', req.text) + # we catch it in the form + #req.raise_for_status() + return (req.ok, req) + + +class SRPAuth(requests.auth.AuthBase): + + def __init__(self, username, password, server=None, verify=None): + # sanity check + null_check(server, 'server') + self.username = username + self.password = password + self.server = server + self.verify = verify + + self.init_data = None + self.session = requests.session() + + self.init_srp() + + def get_json_data(self, response): + return json.loads(response.content) + + def init_srp(self): + usr = srp.User( + self.username, + self.password, + srp.SHA256, + srp.NG_1024) + uname, A = usr.start_authentication() + + self.srp_usr = usr + self.A = A + + def get_auth_data(self): + return { + 'login': self.username, + 'A': binascii.hexlify(self.A) + } + + def get_init_data(self): + try: + init_session = self.session.post( + self.server + '/1/sessions.json/', + data=self.get_auth_data(), + verify=self.verify) + except requests.exceptions.ConnectionError: + raise SRPAuthenticationError( + "No connection made (salt).") + if init_session.status_code not in (200, ): + raise SRPAuthenticationError( + "No valid response (salt).") + + # XXX should get auth_result.json instead + self.init_data = self.get_json_data(init_session) + return self.init_data + + def get_server_proof_data(self): + try: + auth_result = self.session.put( + #self.server + '/1/sessions.json/' + self.username, + self.server + '/1/sessions/' + self.username, + data={'client_auth': binascii.hexlify(self.M)}, + verify=self.verify) + except requests.exceptions.ConnectionError: + raise SRPAuthenticationError( + "No connection made (HAMK).") + + if auth_result.status_code not in (200, ): + raise SRPAuthenticationError( + "No valid response (HAMK).") + + # XXX should get auth_result.json instead + try: + self.auth_data = self.get_json_data(auth_result) + except ValueError: + raise SRPAuthenticationError( + "No valid data sent (HAMK)") + + return self.auth_data + + def authenticate(self): + logger.debug('start authentication...') + + init_data = self.get_init_data() + salt = init_data.get('salt', None) + B = init_data.get('B', None) + + # XXX refactor this function + # move checks and un-hex + # to routines + + if not salt or not B: + raise SRPAuthenticationError( + "Server did not send initial data.") + + try: + unhex_salt = safe_unhexlify(salt) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (salt)") + try: + unhex_B = safe_unhexlify(B) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (B)") + + self.M = self.srp_usr.process_challenge( + unhex_salt, + unhex_B + ) + + proof_data = self.get_server_proof_data() + + HAMK = proof_data.get("M2", None) + if not HAMK: + errors = proof_data.get('errors', None) + if errors: + logger.error(errors) + raise SRPAuthenticationError("Server did not send HAMK.") + + try: + unhex_HAMK = safe_unhexlify(HAMK) + except TypeError: + raise SRPAuthenticationError( + "Bad data from server (HAMK)") + + self.srp_usr.verify_session( + unhex_HAMK) + + try: + assert self.srp_usr.authenticated() + logger.debug('user is authenticated!') + except (AssertionError): + raise SRPAuthenticationError( + "Auth verification failed.") + + def __call__(self, req): + self.authenticate() + req.session = self.session + return req + + +def srpauth_protected(user=None, passwd=None, server=None, verify=True): + """ + decorator factory that accepts + user and password keyword arguments + and add those to the decorated request + """ + def srpauth(fn): + def wrapper(*args, **kwargs): + if user and passwd: + auth = SRPAuth(user, passwd, server, verify) + kwargs['auth'] = auth + kwargs['verify'] = verify + return fn(*args, **kwargs) + return wrapper + return srpauth + + +def get_leap_credentials(): + settings = QtCore.QSettings() + full_username = settings.value('eip_username') + username, domain = full_username.split('@') + seed = settings.value('%s_seed' % domain, None) + password = leapkeyring.leap_get_password(full_username, seed=seed) + return (username, password) + + +# XXX TODO +# Pass verify as single argument, +# in srpauth_protected style + +def magick_srpauth(fn): + """ + decorator that gets user and password + from the config file and adds those to + the decorated request + """ + logger.debug('magick srp auth decorator called') + + def wrapper(*args, **kwargs): + #uri = args[0] + # XXX Ugh! + # Problem with this approach. + # This won't work when we're using + # api.foo.bar + # Unless we keep a table with the + # equivalencies... + user, passwd = get_leap_credentials() + + # XXX pass verify and server too + # (pop) + auth = SRPAuth(user, passwd) + kwargs['auth'] = auth + return fn(*args, **kwargs) + return wrapper + + +if __name__ == "__main__": + """ + To test against test_provider (twisted version) + Register an user: (will be valid during the session) + >>> python auth.py add test password + + Test login with that user: + >>> python auth.py login test password + """ + + import sys + + if len(sys.argv) not in (4, 5): + print 'Usage: auth <add|login> <user> <pass> [server]' + sys.exit(0) + + action = sys.argv[1] + user = sys.argv[2] + passwd = sys.argv[3] + + if len(sys.argv) == 5: + SERVER = sys.argv[4] + else: + SERVER = "https://localhost:8443" + + if action == "login": + + @srpauth_protected( + user=user, passwd=passwd, server=SERVER, verify=False) + def test_srp_protected_get(*args, **kwargs): + req = requests.get(*args, **kwargs) + req.raise_for_status + return req + + req = test_srp_protected_get('https://localhost:8443/1/cert') + print 'cert :', req.content[:200] + "..." + sys.exit(0) + + if action == "add": + auth = LeapSRPRegister(provider=SERVER, verify=False) + auth.register_user(user, passwd) diff --git a/src/leap/base/checks.py b/src/leap/base/checks.py index 7285e74f..23446f4a 100644 --- a/src/leap/base/checks.py +++ b/src/leap/base/checks.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import logging import platform +import socket import netifaces import ping @@ -23,7 +24,7 @@ class LeapNetworkChecker(object): def run_all(self, checker=None): if not checker: checker = self - self.error = None # ? + #self.error = None # ? # for MVS checker.check_tunnel_default_interface() @@ -118,11 +119,9 @@ class LeapNetworkChecker(object): if packet_loss > constants.MAX_ICMP_PACKET_LOSS: raise exceptions.NoConnectionToGateway - # XXX check for name resolution servers - # dunno what's the best way to do this... - # check for etc/resolv entries or similar? - # just try to resolve? - # is there something in psutil? - - # def check_name_resolution(self): - # pass + def check_name_resolution(self, domain_name): + try: + socket.gethostbyname(domain_name) + return True + except socket.gaierror: + raise exceptions.CannotResolveDomainError diff --git a/src/leap/base/config.py b/src/leap/base/config.py index cf01d1aa..0255fbab 100644 --- a/src/leap/base/config.py +++ b/src/leap/base/config.py @@ -118,6 +118,7 @@ class JSONLeapConfig(BaseLeapConfig): " derived class") assert issubclass(self.spec, PluggableConfig) + self.domain = kwargs.pop('domain', None) self._config = self.spec(format="json") self._config.load() self.fetcher = kwargs.pop('fetcher', requests) @@ -252,6 +253,15 @@ def get_default_provider_path(): return default_provider_path +def get_provider_path(domain): + # XXX if not domain, return get_default_provider_path + default_subpath = os.path.join("providers", domain) + provider_path = get_config_file( + '', + folder=default_subpath) + return provider_path + + def validate_ip(ip_str): """ raises exception if the ip_str is diff --git a/src/leap/base/connection.py b/src/leap/base/connection.py index e478538d..41d13935 100644 --- a/src/leap/base/connection.py +++ b/src/leap/base/connection.py @@ -37,11 +37,11 @@ class Connection(Authentication): """ pass - def shutdown(self): - """ - shutdown and quit - """ - self.desired_con_state = self.status.DISCONNECTED + #def shutdown(self): + #""" + #shutdown and quit + #""" + #self.desired_con_state = self.status.DISCONNECTED def connection_state(self): """ diff --git a/src/leap/base/exceptions.py b/src/leap/base/exceptions.py index f12a49d5..227da953 100644 --- a/src/leap/base/exceptions.py +++ b/src/leap/base/exceptions.py @@ -67,6 +67,11 @@ class NoInternetConnection(CriticalError): # and now we try to connect to our web to troubleshoot LOL :P +class CannotResolveDomainError(LeapException): + message = "Cannot resolve domain" + usermessage = "Domain cannot be found" + + class TunnelNotDefaultRouteError(CriticalError): message = "Tunnel connection dissapeared. VPN down?" usermessage = "The Encrypted Connection was lost. Shutting down..." diff --git a/src/leap/base/network.py b/src/leap/base/network.py index 3891b00a..3aba3f61 100644 --- a/src/leap/base/network.py +++ b/src/leap/base/network.py @@ -31,7 +31,7 @@ class NetworkCheckerThread(object): # see in eip.config for function # #718 self.checker = LeapNetworkChecker( - provider_gw = get_eip_gateway()) + provider_gw=get_eip_gateway()) def start(self): self.process_handle = self._launch_recurrent_network_checks( diff --git a/src/leap/base/providers.py b/src/leap/base/providers.py index 7b219cc7..d41f3695 100644 --- a/src/leap/base/providers.py +++ b/src/leap/base/providers.py @@ -7,20 +7,20 @@ class LeapProviderDefinition(baseconfig.JSONLeapConfig): spec = specs.leap_provider_spec def _get_slug(self): - provider_path = baseconfig.get_default_provider_path() + domain = getattr(self, 'domain', None) + if domain: + path = baseconfig.get_provider_path(domain) + else: + path = baseconfig.get_default_provider_path() + return baseconfig.get_config_file( - 'provider.json', - folder=provider_path) + 'provider.json', folder=path) def _set_slug(self, *args, **kwargs): raise AttributeError("you cannot set slug") slug = property(_get_slug, _set_slug) - # TODO (MVS+) - # we will construct slug from providers/%s/definition.json - # where %s is domain name. we can get that on __init__ - class LeapProviderSet(object): # we gather them from the filesystem diff --git a/src/leap/base/tests/__init__.py b/src/leap/base/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/base/tests/__init__.py diff --git a/src/leap/base/tests/test_checks.py b/src/leap/base/tests/test_checks.py index bec09ce6..8d573b1e 100644 --- a/src/leap/base/tests/test_checks.py +++ b/src/leap/base/tests/test_checks.py @@ -40,7 +40,14 @@ class LeapNetworkCheckTest(BaseLeapTest): def test_checker_should_actually_call_all_tests(self): checker = checks.LeapNetworkChecker() + mc = Mock() + checker.run_all(checker=mc) + self.assertTrue(mc.check_internet_connection.called, "not called") + self.assertTrue(mc.check_tunnel_default_interface.called, "not called") + self.assertTrue(mc.is_internet_up.called, "not called") + # ping gateway only called if we pass provider_gw + checker = checks.LeapNetworkChecker(provider_gw="0.0.0.0") mc = Mock() checker.run_all(checker=mc) self.assertTrue(mc.check_internet_connection.called, "not called") diff --git a/src/leap/base/tests/test_providers.py b/src/leap/base/tests/test_providers.py index 8d3b8847..15c4ed58 100644 --- a/src/leap/base/tests/test_providers.py +++ b/src/leap/base/tests/test_providers.py @@ -30,7 +30,9 @@ EXPECTED_DEFAULT_CONFIG = { class TestLeapProviderDefinition(BaseLeapTest): def setUp(self): - self.definition = providers.LeapProviderDefinition() + self.domain = "testprovider.example.org" + self.definition = providers.LeapProviderDefinition( + domain=self.domain) self.definition.save() self.definition.load() self.config = self.definition.config @@ -51,7 +53,7 @@ class TestLeapProviderDefinition(BaseLeapTest): os.path.join( self.home, '.config', 'leap', 'providers', - '%s' % BRANDING.get('provider_domain'), + '%s' % self.domain, 'provider.json')) with self.assertRaises(AttributeError): self.definition.slug = 23 |