summaryrefslogtreecommitdiff
path: root/src/leap/base
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/base')
-rw-r--r--src/leap/base/auth.py45
-rw-r--r--src/leap/base/checks.py11
-rw-r--r--src/leap/base/config.py99
-rw-r--r--src/leap/base/constants.py33
-rw-r--r--src/leap/base/network.py20
-rw-r--r--src/leap/base/pluggableconfig.py20
-rw-r--r--src/leap/base/specs.py16
-rw-r--r--src/leap/base/tests/test_auth.py58
-rw-r--r--src/leap/base/tests/test_checks.py16
-rw-r--r--src/leap/base/tests/test_providers.py17
10 files changed, 250 insertions, 85 deletions
diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py
index 50533278..ecc24179 100644
--- a/src/leap/base/auth.py
+++ b/src/leap/base/auth.py
@@ -10,6 +10,7 @@ from PyQt4 import QtCore
from leap.base import constants as baseconstants
from leap.crypto import leapkeyring
+from leap.util.misc import null_check
from leap.util.web import get_https_domain_and_port
logger = logging.getLogger(__name__)
@@ -26,11 +27,6 @@ one if not.
"""
-class ImproperlyConfigured(Exception):
- """
- """
-
-
class SRPAuthenticationError(Exception):
"""
exception raised
@@ -38,14 +34,6 @@ class SRPAuthenticationError(Exception):
"""
-def null_check(value, value_name):
- try:
- assert value is not None
- except AssertionError:
- raise ImproperlyConfigured(
- "%s parameter cannot be None" % value_name)
-
-
safe_unhexlify = lambda x: binascii.unhexlify(x) \
if (len(x) % 2 == 0) else binascii.unhexlify('0' + x)
@@ -55,7 +43,7 @@ class LeapSRPRegister(object):
def __init__(self,
schema="https",
provider=None,
- port=None,
+ #port=None,
verify=True,
register_path="1/users.json",
method="POST",
@@ -64,13 +52,13 @@ class LeapSRPRegister(object):
hashfun=srp.SHA256,
ng_constant=srp.NG_1024):
- null_check(provider, provider)
+ null_check(provider, "provider")
self.schema = schema
# XXX FIXME
- self.provider = provider
- self.port = port
+ #self.provider = provider
+ #self.port = port
# XXX splitting server,port
# deprecate port call.
domain, port = get_https_domain_and_port(provider)
@@ -154,9 +142,6 @@ class SRPAuth(requests.auth.AuthBase):
self.init_srp()
- def get_json_data(self, response):
- return json.loads(response.content)
-
def init_srp(self):
usr = srp.User(
self.username,
@@ -187,8 +172,7 @@ class SRPAuth(requests.auth.AuthBase):
raise SRPAuthenticationError(
"No valid response (salt).")
- # XXX should get auth_result.json instead
- self.init_data = self.get_json_data(init_session)
+ self.init_data = init_session.json
return self.init_data
def get_server_proof_data(self):
@@ -206,13 +190,7 @@ class SRPAuth(requests.auth.AuthBase):
raise SRPAuthenticationError(
"No valid response (HAMK).")
- # XXX should get auth_result.json instead
- try:
- self.auth_data = self.get_json_data(auth_result)
- except ValueError:
- raise SRPAuthenticationError(
- "No valid data sent (HAMK)")
-
+ self.auth_data = auth_result.json
return self.auth_data
def authenticate(self):
@@ -267,13 +245,14 @@ class SRPAuth(requests.auth.AuthBase):
try:
assert self.srp_usr.authenticated()
logger.debug('user is authenticated!')
+ print 'user is authenticated!'
except (AssertionError):
raise SRPAuthenticationError(
"Auth verification failed.")
def __call__(self, req):
self.authenticate()
- req.session = self.session
+ req.cookies = self.session.cookies
return req
@@ -367,8 +346,10 @@ if __name__ == "__main__":
req.raise_for_status
return req
- req = test_srp_protected_get('https://localhost:8443/1/cert')
- print 'cert :', req.content[:200] + "..."
+ #req = test_srp_protected_get('https://localhost:8443/1/cert')
+ req = test_srp_protected_get('%s/1/cert' % SERVER)
+ #print 'cert :', req.content[:200] + "..."
+ print req.content
sys.exit(0)
if action == "add":
diff --git a/src/leap/base/checks.py b/src/leap/base/checks.py
index 23446f4a..dc2602c2 100644
--- a/src/leap/base/checks.py
+++ b/src/leap/base/checks.py
@@ -39,9 +39,6 @@ class LeapNetworkChecker(object):
# XXX remove this hardcoded random ip
# ping leap.se or eip provider instead...?
requests.get('http://216.172.161.165')
-
- except (requests.HTTPError, requests.RequestException) as e:
- raise exceptions.NoInternetConnection(e.message)
except requests.ConnectionError as e:
error = "Unidentified Connection Error"
if e.message == "[Errno 113] No route to host":
@@ -51,11 +48,17 @@ class LeapNetworkChecker(object):
error = "Provider server appears to be down."
logger.error(error)
raise exceptions.NoInternetConnection(error)
+ except (requests.HTTPError, requests.RequestException) as e:
+ raise exceptions.NoInternetConnection(e.message)
logger.debug('Network appears to be up.')
def is_internet_up(self):
iface, gateway = self.get_default_interface_gateway()
- self.ping_gateway(self.provider_gateway)
+ try:
+ self.ping_gateway(self.provider_gateway)
+ except exceptions.NoConnectionToGateway:
+ return False
+ return True
def check_tunnel_default_interface(self):
"""
diff --git a/src/leap/base/config.py b/src/leap/base/config.py
index 0255fbab..438d1993 100644
--- a/src/leap/base/config.py
+++ b/src/leap/base/config.py
@@ -5,11 +5,12 @@ import grp
import json
import logging
import socket
-import tempfile
+import time
import os
logger = logging.getLogger(name=__name__)
+from dateutil import parser as dateparser
import requests
from leap.base import exceptions
@@ -125,17 +126,43 @@ class JSONLeapConfig(BaseLeapConfig):
# mandatory baseconfig interface
- def save(self, to=None):
- if to is None:
- to = self.filename
- folder, filename = os.path.split(to)
- if folder and not os.path.isdir(folder):
- mkdir_p(folder)
- self._config.serialize(to)
+ def save(self, to=None, force=False):
+ """
+ force param will skip the dirty check.
+ :type force: bool
+ """
+ # XXX this force=True does not feel to right
+ # but still have to look for a better way
+ # of dealing with dirtiness and the
+ # trick of loading remote config only
+ # when newer.
+
+ if force:
+ do_save = True
+ else:
+ do_save = self._config.is_dirty()
+
+ if do_save:
+ if to is None:
+ to = self.filename
+ folder, filename = os.path.split(to)
+ if folder and not os.path.isdir(folder):
+ mkdir_p(folder)
+ self._config.serialize(to)
+ return True
+
+ else:
+ return False
+
+ def load(self, fromfile=None, from_uri=None, fetcher=None,
+ force_download=False, verify=False):
- def load(self, fromfile=None, from_uri=None, fetcher=None, verify=False):
if from_uri is not None:
- fetched = self.fetch(from_uri, fetcher=fetcher, verify=verify)
+ fetched = self.fetch(
+ from_uri,
+ fetcher=fetcher,
+ verify=verify,
+ force_dl=force_download)
if fetched:
return
if fromfile is None:
@@ -146,33 +173,69 @@ class JSONLeapConfig(BaseLeapConfig):
logger.error('tried to load config from non-existent path')
logger.error('Not Found: %s', fromfile)
- def fetch(self, uri, fetcher=None, verify=True):
+ def fetch(self, uri, fetcher=None, verify=True, force_dl=False):
if not fetcher:
fetcher = self.fetcher
+
logger.debug('verify: %s', verify)
logger.debug('uri: %s', uri)
- request = fetcher.get(uri, verify=verify)
- # XXX should send a if-modified-since header
- # XXX get 404, ...
- # and raise a UnableToFetch...
+ rargs = (uri, )
+ rkwargs = {'verify': verify}
+ headers = {}
+
+ curmtime = self.get_mtime() if not force_dl else None
+ if curmtime:
+ logger.debug('requesting with if-modified-since %s' % curmtime)
+ headers['if-modified-since'] = curmtime
+ rkwargs['headers'] = headers
+
+ #request = fetcher.get(uri, verify=verify)
+ request = fetcher.get(*rargs, **rkwargs)
request.raise_for_status()
- fd, fname = tempfile.mkstemp(suffix=".json")
- if request.json:
- self._config.load(json.dumps(request.json))
+ if request.status_code == 304:
+ logger.debug('...304 Not Changed')
+ # On this point, we have to assume that
+ # we HAD the filename. If that filename is corruct,
+ # we should enforce a force_download in the load
+ # method above.
+ self._config.load(fromfile=self.filename)
+ return True
+ if request.json:
+ mtime = None
+ last_modified = request.headers.get('last-modified', None)
+ if last_modified:
+ _mtime = dateparser.parse(last_modified)
+ mtime = int(_mtime.strftime("%s"))
+ if callable(request.json):
+ _json = request.json()
+ else:
+ # back-compat
+ _json = request.json
+ self._config.load(json.dumps(_json), mtime=mtime)
+ self._config.set_dirty()
else:
# not request.json
# might be server did not announce content properly,
# let's try deserializing all the same.
try:
self._config.load(request.content)
+ self._config.set_dirty()
except ValueError:
raise eipexceptions.LeapBadConfigFetchedError
return True
+ def get_mtime(self):
+ try:
+ _mtime = os.stat(self.filename)[8]
+ mtime = time.strftime("%c GMT", time.gmtime(_mtime))
+ return mtime
+ except OSError:
+ return None
+
def get_config(self):
return self._config.config
diff --git a/src/leap/base/constants.py b/src/leap/base/constants.py
index f7be8d98..b38723be 100644
--- a/src/leap/base/constants.py
+++ b/src/leap/base/constants.py
@@ -14,18 +14,27 @@ DEFAULT_PROVIDER = __branding.get(
DEFINITION_EXPECTED_PATH = "provider.json"
DEFAULT_PROVIDER_DEFINITION = {
- u'api_uri': u'https://api.%s/' % DEFAULT_PROVIDER,
- u'api_version': u'0.1.0',
- u'ca_cert_fingerprint': u'8aab80ae4326fd30721689db813733783fe0bd7e',
- u'ca_cert_uri': u'https://%s/cacert.pem' % DEFAULT_PROVIDER,
- u'description': {u'en': u'This is a test provider'},
- u'display_name': {u'en': u'Test Provider'},
- u'domain': u'%s' % DEFAULT_PROVIDER,
- u'enrollment_policy': u'open',
- u'public_key': u'cb7dbd679f911e85bc2e51bd44afd7308ee19c21',
- u'serial': 1,
- u'services': [u'eip'],
- u'version': u'0.1.0'}
+ u"api_uri": "https://api.%s/" % DEFAULT_PROVIDER,
+ u"api_version": u"1",
+ u"ca_cert_fingerprint": "SHA256: fff",
+ u"ca_cert_uri": u"https://%s/ca.crt" % DEFAULT_PROVIDER,
+ u"default_language": u"en",
+ u"description": {
+ u"en": u"A demonstration service provider using the LEAP platform"
+ },
+ u"domain": "%s" % DEFAULT_PROVIDER,
+ u"enrollment_policy": u"open",
+ u"languages": [
+ u"en"
+ ],
+ u"name": {
+ u"en": u"Test Provider"
+ },
+ u"services": [
+ "openvpn"
+ ]
+}
+
MAX_ICMP_PACKET_LOSS = 10
diff --git a/src/leap/base/network.py b/src/leap/base/network.py
index 3aba3f61..765d8ea0 100644
--- a/src/leap/base/network.py
+++ b/src/leap/base/network.py
@@ -3,10 +3,11 @@ from __future__ import (print_function)
import logging
import threading
-from leap.eip.config import get_eip_gateway
+from leap.eip import config as eipconfig
from leap.base.checks import LeapNetworkChecker
from leap.base.constants import ROUTE_CHECK_INTERVAL
from leap.base.exceptions import TunnelNotDefaultRouteError
+from leap.util.misc import null_check
from leap.util.coroutines import (launch_thread, process_events)
from time import sleep
@@ -27,11 +28,20 @@ class NetworkCheckerThread(object):
lambda exc: logger.error("%s", exc.message))
self.shutdown = threading.Event()
- # XXX get provider_gateway and pass it to checker
- # see in eip.config for function
- # #718
+ # XXX get provider passed here
+ provider = kwargs.pop('provider', None)
+ null_check(provider, 'provider')
+
+ eipconf = eipconfig.EIPConfig(domain=provider)
+ eipconf.load()
+ eipserviceconf = eipconfig.EIPServiceConfig(domain=provider)
+ eipserviceconf.load()
+
+ gw = eipconfig.get_eip_gateway(
+ eipconfig=eipconf,
+ eipserviceconfig=eipserviceconf)
self.checker = LeapNetworkChecker(
- provider_gw=get_eip_gateway())
+ provider_gw=gw)
def start(self):
self.process_handle = self._launch_recurrent_network_checks(
diff --git a/src/leap/base/pluggableconfig.py b/src/leap/base/pluggableconfig.py
index b8615ad8..0ca985ea 100644
--- a/src/leap/base/pluggableconfig.py
+++ b/src/leap/base/pluggableconfig.py
@@ -180,6 +180,8 @@ class PluggableConfig(object):
self.adaptors = adaptors
self.types = types
self._format = format
+ self.mtime = None
+ self.dirty = False
@property
def option_dict(self):
@@ -319,6 +321,13 @@ class PluggableConfig(object):
serializable = self.prep_value(config)
adaptor.write(serializable, filename)
+ if self.mtime:
+ self.touch_mtime(filename)
+
+ def touch_mtime(self, filename):
+ mtime = self.mtime
+ os.utime(filename, (mtime, mtime))
+
def deserialize(self, string=None, fromfile=None, format=None):
"""
load configuration from a file or string
@@ -364,6 +373,12 @@ class PluggableConfig(object):
content = _try_deserialize()
return content
+ def set_dirty(self):
+ self.dirty = True
+
+ def is_dirty(self):
+ return self.dirty
+
def load(self, *args, **kwargs):
"""
load from string or file
@@ -373,6 +388,8 @@ class PluggableConfig(object):
"""
string = args[0] if args else None
fromfile = kwargs.get("fromfile", None)
+ mtime = kwargs.pop("mtime", None)
+ self.mtime = mtime
content = None
# start with defaults, so we can
@@ -402,7 +419,8 @@ class PluggableConfig(object):
return True
-def testmain():
+def testmain(): # pragma: no cover
+
from tests import test_validation as t
import pprint
diff --git a/src/leap/base/specs.py b/src/leap/base/specs.py
index b4bb8dcf..962aa07d 100644
--- a/src/leap/base/specs.py
+++ b/src/leap/base/specs.py
@@ -2,22 +2,26 @@ leap_provider_spec = {
'description': 'provider definition',
'type': 'object',
'properties': {
- 'serial': {
- 'type': int,
- 'default': 1,
- 'required': True,
- },
+ #'serial': {
+ #'type': int,
+ #'default': 1,
+ #'required': True,
+ #},
'version': {
'type': unicode,
'default': '0.1.0'
#'required': True
},
+ "default_language": {
+ 'type': unicode,
+ 'default': 'en'
+ },
'domain': {
'type': unicode, # XXX define uri type
'default': 'testprovider.example.org'
#'required': True,
},
- 'display_name': {
+ 'name': {
'type': dict, # XXX multilingual object?
'default': {u'en': u'Test Provider'}
#'required': True
diff --git a/src/leap/base/tests/test_auth.py b/src/leap/base/tests/test_auth.py
new file mode 100644
index 00000000..17b84b52
--- /dev/null
+++ b/src/leap/base/tests/test_auth.py
@@ -0,0 +1,58 @@
+from BaseHTTPServer import BaseHTTPRequestHandler
+import urlparse
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+import requests
+#from mock import Mock
+
+from leap.base import auth
+#from leap.base import exceptions
+from leap.eip.tests.test_checks import NoLogRequestHandler
+from leap.testing.basetest import BaseLeapTest
+from leap.testing.https_server import BaseHTTPSServerTestCase
+
+
+class LeapSRPRegisterTests(BaseHTTPSServerTestCase, BaseLeapTest):
+ __name__ = "leap_srp_register_test"
+ provider = "testprovider.example.org"
+
+ class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler):
+ responses = {
+ '/': ['OK', '']}
+
+ def do_GET(self):
+ path = urlparse.urlparse(self.path)
+ message = '\n'.join(self.responses.get(
+ path.path, None))
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(message)
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def test_srp_auth_should_implement_check_methods(self):
+ SERVER = "https://localhost:8443"
+ srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False)
+
+ self.assertTrue(hasattr(srp_auth, "init_session"),
+ "missing meth")
+ self.assertTrue(hasattr(srp_auth, "get_registration_uri"),
+ "missing meth")
+ self.assertTrue(hasattr(srp_auth, "register_user"),
+ "missing meth")
+
+ def test_srp_auth_basic_functionality(self):
+ SERVER = "https://localhost:8443"
+ srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False)
+
+ self.assertIsInstance(srp_auth.session, requests.sessions.Session)
+ self.assertEqual(
+ srp_auth.get_registration_uri(),
+ "https://localhost:8443/1/users.json")
diff --git a/src/leap/base/tests/test_checks.py b/src/leap/base/tests/test_checks.py
index 8d573b1e..7a694f89 100644
--- a/src/leap/base/tests/test_checks.py
+++ b/src/leap/base/tests/test_checks.py
@@ -118,6 +118,22 @@ class LeapNetworkCheckTest(BaseLeapTest):
with self.assertRaises(exceptions.NoInternetConnection):
checker.check_internet_connection()
+ with patch.object(requests, "get") as mocked_get:
+ mocked_get.side_effect = requests.ConnectionError(
+ "[Errno 113] No route to host")
+ with self.assertRaises(exceptions.NoInternetConnection):
+ with patch.object(checker, "ping_gateway") as mock_ping:
+ mock_ping.return_value = True
+ checker.check_internet_connection()
+
+ with patch.object(requests, "get") as mocked_get:
+ mocked_get.side_effect = requests.ConnectionError(
+ "[Errno 113] No route to host")
+ with self.assertRaises(exceptions.NoInternetConnection):
+ with patch.object(checker, "ping_gateway") as mock_ping:
+ mock_ping.side_effect = exceptions.NoConnectionToGateway
+ checker.check_internet_connection()
+
@unittest.skipUnless(_uid == 0, "root only")
def test_ping_gateway(self):
checker = checks.LeapNetworkChecker()
diff --git a/src/leap/base/tests/test_providers.py b/src/leap/base/tests/test_providers.py
index 15c4ed58..9c11f270 100644
--- a/src/leap/base/tests/test_providers.py
+++ b/src/leap/base/tests/test_providers.py
@@ -8,7 +8,7 @@ import os
import jsonschema
-from leap import __branding as BRANDING
+#from leap import __branding as BRANDING
from leap.testing.basetest import BaseLeapTest
from leap.base import providers
@@ -16,10 +16,12 @@ from leap.base import providers
EXPECTED_DEFAULT_CONFIG = {
u"api_version": u"0.1.0",
u"description": {u'en': u"Test provider"},
- u"display_name": {u'en': u"Test Provider"},
+ u"default_language": u"en",
+ #u"display_name": {u'en': u"Test Provider"},
u"domain": u"testprovider.example.org",
+ u'name': {u'en': u'Test Provider'},
u"enrollment_policy": u"open",
- u"serial": 1,
+ #u"serial": 1,
u"services": [
u"eip"
],
@@ -33,8 +35,8 @@ class TestLeapProviderDefinition(BaseLeapTest):
self.domain = "testprovider.example.org"
self.definition = providers.LeapProviderDefinition(
domain=self.domain)
- self.definition.save()
- self.definition.load()
+ self.definition.save(force=True)
+ self.definition.load() # why have to load after save??
self.config = self.definition.config
def tearDown(self):
@@ -61,7 +63,7 @@ class TestLeapProviderDefinition(BaseLeapTest):
def test_provider_dump(self):
# check a good provider definition is dumped to disk
self.testfile = self.get_tempfile('test.json')
- self.definition.save(to=self.testfile)
+ self.definition.save(to=self.testfile, force=True)
deserialized = json.load(open(self.testfile, 'rb'))
self.maxDiff = None
self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG)
@@ -88,7 +90,8 @@ class TestLeapProviderDefinition(BaseLeapTest):
def test_provider_validation(self):
self.definition.validate(self.config)
_config = copy.deepcopy(self.config)
- _config['serial'] = 'aaa'
+ # bad type, raise validation error
+ _config['domain'] = 111
with self.assertRaises(jsonschema.ValidationError):
self.definition.validate(_config)