diff options
author | drebs <drebs@leap.se> | 2012-12-24 10:14:58 -0200 |
---|---|---|
committer | drebs <drebs@leap.se> | 2012-12-24 10:14:58 -0200 |
commit | 319e279b59ac080779d0a3375ae4d6582f5ee6a3 (patch) | |
tree | 118dd0f495c0d54f2b2c66ea235e4e4e6b8cefd5 /src/leap/base | |
parent | ca5fb41a55e1292005ed186baf3710831d9ad678 (diff) | |
parent | a7b091a0553e6120f3e0eb6d4e73a89732c589b2 (diff) |
Merge branch 'develop' of ssh://code.leap.se/leap_client into develop
Diffstat (limited to 'src/leap/base')
-rw-r--r-- | src/leap/base/auth.py | 45 | ||||
-rw-r--r-- | src/leap/base/checks.py | 11 | ||||
-rw-r--r-- | src/leap/base/config.py | 99 | ||||
-rw-r--r-- | src/leap/base/constants.py | 33 | ||||
-rw-r--r-- | src/leap/base/network.py | 20 | ||||
-rw-r--r-- | src/leap/base/pluggableconfig.py | 20 | ||||
-rw-r--r-- | src/leap/base/specs.py | 16 | ||||
-rw-r--r-- | src/leap/base/tests/test_auth.py | 58 | ||||
-rw-r--r-- | src/leap/base/tests/test_checks.py | 16 | ||||
-rw-r--r-- | src/leap/base/tests/test_providers.py | 17 |
10 files changed, 250 insertions, 85 deletions
diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py index 50533278..ecc24179 100644 --- a/src/leap/base/auth.py +++ b/src/leap/base/auth.py @@ -10,6 +10,7 @@ from PyQt4 import QtCore from leap.base import constants as baseconstants from leap.crypto import leapkeyring +from leap.util.misc import null_check from leap.util.web import get_https_domain_and_port logger = logging.getLogger(__name__) @@ -26,11 +27,6 @@ one if not. """ -class ImproperlyConfigured(Exception): - """ - """ - - class SRPAuthenticationError(Exception): """ exception raised @@ -38,14 +34,6 @@ class SRPAuthenticationError(Exception): """ -def null_check(value, value_name): - try: - assert value is not None - except AssertionError: - raise ImproperlyConfigured( - "%s parameter cannot be None" % value_name) - - safe_unhexlify = lambda x: binascii.unhexlify(x) \ if (len(x) % 2 == 0) else binascii.unhexlify('0' + x) @@ -55,7 +43,7 @@ class LeapSRPRegister(object): def __init__(self, schema="https", provider=None, - port=None, + #port=None, verify=True, register_path="1/users.json", method="POST", @@ -64,13 +52,13 @@ class LeapSRPRegister(object): hashfun=srp.SHA256, ng_constant=srp.NG_1024): - null_check(provider, provider) + null_check(provider, "provider") self.schema = schema # XXX FIXME - self.provider = provider - self.port = port + #self.provider = provider + #self.port = port # XXX splitting server,port # deprecate port call. domain, port = get_https_domain_and_port(provider) @@ -154,9 +142,6 @@ class SRPAuth(requests.auth.AuthBase): self.init_srp() - def get_json_data(self, response): - return json.loads(response.content) - def init_srp(self): usr = srp.User( self.username, @@ -187,8 +172,7 @@ class SRPAuth(requests.auth.AuthBase): raise SRPAuthenticationError( "No valid response (salt).") - # XXX should get auth_result.json instead - self.init_data = self.get_json_data(init_session) + self.init_data = init_session.json return self.init_data def get_server_proof_data(self): @@ -206,13 +190,7 @@ class SRPAuth(requests.auth.AuthBase): raise SRPAuthenticationError( "No valid response (HAMK).") - # XXX should get auth_result.json instead - try: - self.auth_data = self.get_json_data(auth_result) - except ValueError: - raise SRPAuthenticationError( - "No valid data sent (HAMK)") - + self.auth_data = auth_result.json return self.auth_data def authenticate(self): @@ -267,13 +245,14 @@ class SRPAuth(requests.auth.AuthBase): try: assert self.srp_usr.authenticated() logger.debug('user is authenticated!') + print 'user is authenticated!' except (AssertionError): raise SRPAuthenticationError( "Auth verification failed.") def __call__(self, req): self.authenticate() - req.session = self.session + req.cookies = self.session.cookies return req @@ -367,8 +346,10 @@ if __name__ == "__main__": req.raise_for_status return req - req = test_srp_protected_get('https://localhost:8443/1/cert') - print 'cert :', req.content[:200] + "..." + #req = test_srp_protected_get('https://localhost:8443/1/cert') + req = test_srp_protected_get('%s/1/cert' % SERVER) + #print 'cert :', req.content[:200] + "..." + print req.content sys.exit(0) if action == "add": diff --git a/src/leap/base/checks.py b/src/leap/base/checks.py index 23446f4a..dc2602c2 100644 --- a/src/leap/base/checks.py +++ b/src/leap/base/checks.py @@ -39,9 +39,6 @@ class LeapNetworkChecker(object): # XXX remove this hardcoded random ip # ping leap.se or eip provider instead...? requests.get('http://216.172.161.165') - - except (requests.HTTPError, requests.RequestException) as e: - raise exceptions.NoInternetConnection(e.message) except requests.ConnectionError as e: error = "Unidentified Connection Error" if e.message == "[Errno 113] No route to host": @@ -51,11 +48,17 @@ class LeapNetworkChecker(object): error = "Provider server appears to be down." logger.error(error) raise exceptions.NoInternetConnection(error) + except (requests.HTTPError, requests.RequestException) as e: + raise exceptions.NoInternetConnection(e.message) logger.debug('Network appears to be up.') def is_internet_up(self): iface, gateway = self.get_default_interface_gateway() - self.ping_gateway(self.provider_gateway) + try: + self.ping_gateway(self.provider_gateway) + except exceptions.NoConnectionToGateway: + return False + return True def check_tunnel_default_interface(self): """ diff --git a/src/leap/base/config.py b/src/leap/base/config.py index 0255fbab..438d1993 100644 --- a/src/leap/base/config.py +++ b/src/leap/base/config.py @@ -5,11 +5,12 @@ import grp import json import logging import socket -import tempfile +import time import os logger = logging.getLogger(name=__name__) +from dateutil import parser as dateparser import requests from leap.base import exceptions @@ -125,17 +126,43 @@ class JSONLeapConfig(BaseLeapConfig): # mandatory baseconfig interface - def save(self, to=None): - if to is None: - to = self.filename - folder, filename = os.path.split(to) - if folder and not os.path.isdir(folder): - mkdir_p(folder) - self._config.serialize(to) + def save(self, to=None, force=False): + """ + force param will skip the dirty check. + :type force: bool + """ + # XXX this force=True does not feel to right + # but still have to look for a better way + # of dealing with dirtiness and the + # trick of loading remote config only + # when newer. + + if force: + do_save = True + else: + do_save = self._config.is_dirty() + + if do_save: + if to is None: + to = self.filename + folder, filename = os.path.split(to) + if folder and not os.path.isdir(folder): + mkdir_p(folder) + self._config.serialize(to) + return True + + else: + return False + + def load(self, fromfile=None, from_uri=None, fetcher=None, + force_download=False, verify=False): - def load(self, fromfile=None, from_uri=None, fetcher=None, verify=False): if from_uri is not None: - fetched = self.fetch(from_uri, fetcher=fetcher, verify=verify) + fetched = self.fetch( + from_uri, + fetcher=fetcher, + verify=verify, + force_dl=force_download) if fetched: return if fromfile is None: @@ -146,33 +173,69 @@ class JSONLeapConfig(BaseLeapConfig): logger.error('tried to load config from non-existent path') logger.error('Not Found: %s', fromfile) - def fetch(self, uri, fetcher=None, verify=True): + def fetch(self, uri, fetcher=None, verify=True, force_dl=False): if not fetcher: fetcher = self.fetcher + logger.debug('verify: %s', verify) logger.debug('uri: %s', uri) - request = fetcher.get(uri, verify=verify) - # XXX should send a if-modified-since header - # XXX get 404, ... - # and raise a UnableToFetch... + rargs = (uri, ) + rkwargs = {'verify': verify} + headers = {} + + curmtime = self.get_mtime() if not force_dl else None + if curmtime: + logger.debug('requesting with if-modified-since %s' % curmtime) + headers['if-modified-since'] = curmtime + rkwargs['headers'] = headers + + #request = fetcher.get(uri, verify=verify) + request = fetcher.get(*rargs, **rkwargs) request.raise_for_status() - fd, fname = tempfile.mkstemp(suffix=".json") - if request.json: - self._config.load(json.dumps(request.json)) + if request.status_code == 304: + logger.debug('...304 Not Changed') + # On this point, we have to assume that + # we HAD the filename. If that filename is corruct, + # we should enforce a force_download in the load + # method above. + self._config.load(fromfile=self.filename) + return True + if request.json: + mtime = None + last_modified = request.headers.get('last-modified', None) + if last_modified: + _mtime = dateparser.parse(last_modified) + mtime = int(_mtime.strftime("%s")) + if callable(request.json): + _json = request.json() + else: + # back-compat + _json = request.json + self._config.load(json.dumps(_json), mtime=mtime) + self._config.set_dirty() else: # not request.json # might be server did not announce content properly, # let's try deserializing all the same. try: self._config.load(request.content) + self._config.set_dirty() except ValueError: raise eipexceptions.LeapBadConfigFetchedError return True + def get_mtime(self): + try: + _mtime = os.stat(self.filename)[8] + mtime = time.strftime("%c GMT", time.gmtime(_mtime)) + return mtime + except OSError: + return None + def get_config(self): return self._config.config diff --git a/src/leap/base/constants.py b/src/leap/base/constants.py index f7be8d98..b38723be 100644 --- a/src/leap/base/constants.py +++ b/src/leap/base/constants.py @@ -14,18 +14,27 @@ DEFAULT_PROVIDER = __branding.get( DEFINITION_EXPECTED_PATH = "provider.json" DEFAULT_PROVIDER_DEFINITION = { - u'api_uri': u'https://api.%s/' % DEFAULT_PROVIDER, - u'api_version': u'0.1.0', - u'ca_cert_fingerprint': u'8aab80ae4326fd30721689db813733783fe0bd7e', - u'ca_cert_uri': u'https://%s/cacert.pem' % DEFAULT_PROVIDER, - u'description': {u'en': u'This is a test provider'}, - u'display_name': {u'en': u'Test Provider'}, - u'domain': u'%s' % DEFAULT_PROVIDER, - u'enrollment_policy': u'open', - u'public_key': u'cb7dbd679f911e85bc2e51bd44afd7308ee19c21', - u'serial': 1, - u'services': [u'eip'], - u'version': u'0.1.0'} + u"api_uri": "https://api.%s/" % DEFAULT_PROVIDER, + u"api_version": u"1", + u"ca_cert_fingerprint": "SHA256: fff", + u"ca_cert_uri": u"https://%s/ca.crt" % DEFAULT_PROVIDER, + u"default_language": u"en", + u"description": { + u"en": u"A demonstration service provider using the LEAP platform" + }, + u"domain": "%s" % DEFAULT_PROVIDER, + u"enrollment_policy": u"open", + u"languages": [ + u"en" + ], + u"name": { + u"en": u"Test Provider" + }, + u"services": [ + "openvpn" + ] +} + MAX_ICMP_PACKET_LOSS = 10 diff --git a/src/leap/base/network.py b/src/leap/base/network.py index 3aba3f61..765d8ea0 100644 --- a/src/leap/base/network.py +++ b/src/leap/base/network.py @@ -3,10 +3,11 @@ from __future__ import (print_function) import logging import threading -from leap.eip.config import get_eip_gateway +from leap.eip import config as eipconfig from leap.base.checks import LeapNetworkChecker from leap.base.constants import ROUTE_CHECK_INTERVAL from leap.base.exceptions import TunnelNotDefaultRouteError +from leap.util.misc import null_check from leap.util.coroutines import (launch_thread, process_events) from time import sleep @@ -27,11 +28,20 @@ class NetworkCheckerThread(object): lambda exc: logger.error("%s", exc.message)) self.shutdown = threading.Event() - # XXX get provider_gateway and pass it to checker - # see in eip.config for function - # #718 + # XXX get provider passed here + provider = kwargs.pop('provider', None) + null_check(provider, 'provider') + + eipconf = eipconfig.EIPConfig(domain=provider) + eipconf.load() + eipserviceconf = eipconfig.EIPServiceConfig(domain=provider) + eipserviceconf.load() + + gw = eipconfig.get_eip_gateway( + eipconfig=eipconf, + eipserviceconfig=eipserviceconf) self.checker = LeapNetworkChecker( - provider_gw=get_eip_gateway()) + provider_gw=gw) def start(self): self.process_handle = self._launch_recurrent_network_checks( diff --git a/src/leap/base/pluggableconfig.py b/src/leap/base/pluggableconfig.py index b8615ad8..0ca985ea 100644 --- a/src/leap/base/pluggableconfig.py +++ b/src/leap/base/pluggableconfig.py @@ -180,6 +180,8 @@ class PluggableConfig(object): self.adaptors = adaptors self.types = types self._format = format + self.mtime = None + self.dirty = False @property def option_dict(self): @@ -319,6 +321,13 @@ class PluggableConfig(object): serializable = self.prep_value(config) adaptor.write(serializable, filename) + if self.mtime: + self.touch_mtime(filename) + + def touch_mtime(self, filename): + mtime = self.mtime + os.utime(filename, (mtime, mtime)) + def deserialize(self, string=None, fromfile=None, format=None): """ load configuration from a file or string @@ -364,6 +373,12 @@ class PluggableConfig(object): content = _try_deserialize() return content + def set_dirty(self): + self.dirty = True + + def is_dirty(self): + return self.dirty + def load(self, *args, **kwargs): """ load from string or file @@ -373,6 +388,8 @@ class PluggableConfig(object): """ string = args[0] if args else None fromfile = kwargs.get("fromfile", None) + mtime = kwargs.pop("mtime", None) + self.mtime = mtime content = None # start with defaults, so we can @@ -402,7 +419,8 @@ class PluggableConfig(object): return True -def testmain(): +def testmain(): # pragma: no cover + from tests import test_validation as t import pprint diff --git a/src/leap/base/specs.py b/src/leap/base/specs.py index b4bb8dcf..962aa07d 100644 --- a/src/leap/base/specs.py +++ b/src/leap/base/specs.py @@ -2,22 +2,26 @@ leap_provider_spec = { 'description': 'provider definition', 'type': 'object', 'properties': { - 'serial': { - 'type': int, - 'default': 1, - 'required': True, - }, + #'serial': { + #'type': int, + #'default': 1, + #'required': True, + #}, 'version': { 'type': unicode, 'default': '0.1.0' #'required': True }, + "default_language": { + 'type': unicode, + 'default': 'en' + }, 'domain': { 'type': unicode, # XXX define uri type 'default': 'testprovider.example.org' #'required': True, }, - 'display_name': { + 'name': { 'type': dict, # XXX multilingual object? 'default': {u'en': u'Test Provider'} #'required': True diff --git a/src/leap/base/tests/test_auth.py b/src/leap/base/tests/test_auth.py new file mode 100644 index 00000000..17b84b52 --- /dev/null +++ b/src/leap/base/tests/test_auth.py @@ -0,0 +1,58 @@ +from BaseHTTPServer import BaseHTTPRequestHandler +import urlparse +try: + import unittest2 as unittest +except ImportError: + import unittest + +import requests +#from mock import Mock + +from leap.base import auth +#from leap.base import exceptions +from leap.eip.tests.test_checks import NoLogRequestHandler +from leap.testing.basetest import BaseLeapTest +from leap.testing.https_server import BaseHTTPSServerTestCase + + +class LeapSRPRegisterTests(BaseHTTPSServerTestCase, BaseLeapTest): + __name__ = "leap_srp_register_test" + provider = "testprovider.example.org" + + class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler): + responses = { + '/': ['OK', '']} + + def do_GET(self): + path = urlparse.urlparse(self.path) + message = '\n'.join(self.responses.get( + path.path, None)) + self.send_response(200) + self.end_headers() + self.wfile.write(message) + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_srp_auth_should_implement_check_methods(self): + SERVER = "https://localhost:8443" + srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False) + + self.assertTrue(hasattr(srp_auth, "init_session"), + "missing meth") + self.assertTrue(hasattr(srp_auth, "get_registration_uri"), + "missing meth") + self.assertTrue(hasattr(srp_auth, "register_user"), + "missing meth") + + def test_srp_auth_basic_functionality(self): + SERVER = "https://localhost:8443" + srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False) + + self.assertIsInstance(srp_auth.session, requests.sessions.Session) + self.assertEqual( + srp_auth.get_registration_uri(), + "https://localhost:8443/1/users.json") diff --git a/src/leap/base/tests/test_checks.py b/src/leap/base/tests/test_checks.py index 8d573b1e..7a694f89 100644 --- a/src/leap/base/tests/test_checks.py +++ b/src/leap/base/tests/test_checks.py @@ -118,6 +118,22 @@ class LeapNetworkCheckTest(BaseLeapTest): with self.assertRaises(exceptions.NoInternetConnection): checker.check_internet_connection() + with patch.object(requests, "get") as mocked_get: + mocked_get.side_effect = requests.ConnectionError( + "[Errno 113] No route to host") + with self.assertRaises(exceptions.NoInternetConnection): + with patch.object(checker, "ping_gateway") as mock_ping: + mock_ping.return_value = True + checker.check_internet_connection() + + with patch.object(requests, "get") as mocked_get: + mocked_get.side_effect = requests.ConnectionError( + "[Errno 113] No route to host") + with self.assertRaises(exceptions.NoInternetConnection): + with patch.object(checker, "ping_gateway") as mock_ping: + mock_ping.side_effect = exceptions.NoConnectionToGateway + checker.check_internet_connection() + @unittest.skipUnless(_uid == 0, "root only") def test_ping_gateway(self): checker = checks.LeapNetworkChecker() diff --git a/src/leap/base/tests/test_providers.py b/src/leap/base/tests/test_providers.py index 15c4ed58..9c11f270 100644 --- a/src/leap/base/tests/test_providers.py +++ b/src/leap/base/tests/test_providers.py @@ -8,7 +8,7 @@ import os import jsonschema -from leap import __branding as BRANDING +#from leap import __branding as BRANDING from leap.testing.basetest import BaseLeapTest from leap.base import providers @@ -16,10 +16,12 @@ from leap.base import providers EXPECTED_DEFAULT_CONFIG = { u"api_version": u"0.1.0", u"description": {u'en': u"Test provider"}, - u"display_name": {u'en': u"Test Provider"}, + u"default_language": u"en", + #u"display_name": {u'en': u"Test Provider"}, u"domain": u"testprovider.example.org", + u'name': {u'en': u'Test Provider'}, u"enrollment_policy": u"open", - u"serial": 1, + #u"serial": 1, u"services": [ u"eip" ], @@ -33,8 +35,8 @@ class TestLeapProviderDefinition(BaseLeapTest): self.domain = "testprovider.example.org" self.definition = providers.LeapProviderDefinition( domain=self.domain) - self.definition.save() - self.definition.load() + self.definition.save(force=True) + self.definition.load() # why have to load after save?? self.config = self.definition.config def tearDown(self): @@ -61,7 +63,7 @@ class TestLeapProviderDefinition(BaseLeapTest): def test_provider_dump(self): # check a good provider definition is dumped to disk self.testfile = self.get_tempfile('test.json') - self.definition.save(to=self.testfile) + self.definition.save(to=self.testfile, force=True) deserialized = json.load(open(self.testfile, 'rb')) self.maxDiff = None self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG) @@ -88,7 +90,8 @@ class TestLeapProviderDefinition(BaseLeapTest): def test_provider_validation(self): self.definition.validate(self.config) _config = copy.deepcopy(self.config) - _config['serial'] = 'aaa' + # bad type, raise validation error + _config['domain'] = 111 with self.assertRaises(jsonschema.ValidationError): self.definition.validate(_config) |