summaryrefslogtreecommitdiff
path: root/src/leap/base
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/base')
-rw-r--r--src/leap/base/auth.py16
-rw-r--r--src/leap/base/config.py94
-rw-r--r--src/leap/base/network.py20
-rw-r--r--src/leap/base/pluggableconfig.py17
-rw-r--r--src/leap/base/tests/test_providers.py8
5 files changed, 114 insertions, 41 deletions
diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py
index 50533278..73856bb0 100644
--- a/src/leap/base/auth.py
+++ b/src/leap/base/auth.py
@@ -10,6 +10,7 @@ from PyQt4 import QtCore
from leap.base import constants as baseconstants
from leap.crypto import leapkeyring
+from leap.util.misc import null_check
from leap.util.web import get_https_domain_and_port
logger = logging.getLogger(__name__)
@@ -26,11 +27,6 @@ one if not.
"""
-class ImproperlyConfigured(Exception):
- """
- """
-
-
class SRPAuthenticationError(Exception):
"""
exception raised
@@ -38,14 +34,6 @@ class SRPAuthenticationError(Exception):
"""
-def null_check(value, value_name):
- try:
- assert value is not None
- except AssertionError:
- raise ImproperlyConfigured(
- "%s parameter cannot be None" % value_name)
-
-
safe_unhexlify = lambda x: binascii.unhexlify(x) \
if (len(x) % 2 == 0) else binascii.unhexlify('0' + x)
@@ -64,7 +52,7 @@ class LeapSRPRegister(object):
hashfun=srp.SHA256,
ng_constant=srp.NG_1024):
- null_check(provider, provider)
+ null_check(provider, "provider")
self.schema = schema
diff --git a/src/leap/base/config.py b/src/leap/base/config.py
index 0255fbab..b307ad05 100644
--- a/src/leap/base/config.py
+++ b/src/leap/base/config.py
@@ -5,11 +5,12 @@ import grp
import json
import logging
import socket
-import tempfile
+import time
import os
logger = logging.getLogger(name=__name__)
+from dateutil import parser as dateparser
import requests
from leap.base import exceptions
@@ -125,17 +126,43 @@ class JSONLeapConfig(BaseLeapConfig):
# mandatory baseconfig interface
- def save(self, to=None):
- if to is None:
- to = self.filename
- folder, filename = os.path.split(to)
- if folder and not os.path.isdir(folder):
- mkdir_p(folder)
- self._config.serialize(to)
+ def save(self, to=None, force=False):
+ """
+ force param will skip the dirty check.
+ :type force: bool
+ """
+ # XXX this force=True does not feel to right
+ # but still have to look for a better way
+ # of dealing with dirtiness and the
+ # trick of loading remote config only
+ # when newer.
+
+ if force:
+ do_save = True
+ else:
+ do_save = self._config.is_dirty()
+
+ if do_save:
+ if to is None:
+ to = self.filename
+ folder, filename = os.path.split(to)
+ if folder and not os.path.isdir(folder):
+ mkdir_p(folder)
+ self._config.serialize(to)
+ return True
+
+ else:
+ return False
+
+ def load(self, fromfile=None, from_uri=None, fetcher=None,
+ force_download=False, verify=False):
- def load(self, fromfile=None, from_uri=None, fetcher=None, verify=False):
if from_uri is not None:
- fetched = self.fetch(from_uri, fetcher=fetcher, verify=verify)
+ fetched = self.fetch(
+ from_uri,
+ fetcher=fetcher,
+ verify=verify,
+ force_dl=force_download)
if fetched:
return
if fromfile is None:
@@ -146,33 +173,64 @@ class JSONLeapConfig(BaseLeapConfig):
logger.error('tried to load config from non-existent path')
logger.error('Not Found: %s', fromfile)
- def fetch(self, uri, fetcher=None, verify=True):
+ def fetch(self, uri, fetcher=None, verify=True, force_dl=False):
if not fetcher:
fetcher = self.fetcher
+
logger.debug('verify: %s', verify)
logger.debug('uri: %s', uri)
- request = fetcher.get(uri, verify=verify)
- # XXX should send a if-modified-since header
- # XXX get 404, ...
- # and raise a UnableToFetch...
+ rargs = (uri, )
+ rkwargs = {'verify': verify}
+ headers = {}
+
+ curmtime = self.get_mtime() if not force_dl else None
+ if curmtime:
+ logger.debug('requesting with if-modified-since %s' % curmtime)
+ headers['if-modified-since'] = curmtime
+ rkwargs['headers'] = headers
+
+ #request = fetcher.get(uri, verify=verify)
+ request = fetcher.get(*rargs, **rkwargs)
request.raise_for_status()
- fd, fname = tempfile.mkstemp(suffix=".json")
- if request.json:
- self._config.load(json.dumps(request.json))
+ if request.status_code == 304:
+ logger.debug('...304 Not Changed')
+ # On this point, we have to assume that
+ # we HAD the filename. If that filename is corruct,
+ # we should enforce a force_download in the load
+ # method above.
+ self._config.load(fromfile=self.filename)
+ return True
+ if request.json:
+ mtime = None
+ last_modified = request.headers.get('last-modified', None)
+ if last_modified:
+ _mtime = dateparser.parse(last_modified)
+ mtime = int(_mtime.strftime("%s"))
+ self._config.load(json.dumps(request.json), mtime=mtime)
+ self._config.set_dirty()
else:
# not request.json
# might be server did not announce content properly,
# let's try deserializing all the same.
try:
self._config.load(request.content)
+ self._config.set_dirty()
except ValueError:
raise eipexceptions.LeapBadConfigFetchedError
return True
+ def get_mtime(self):
+ try:
+ _mtime = os.stat(self.filename)[8]
+ mtime = time.strftime("%c GMT", time.gmtime(_mtime))
+ return mtime
+ except OSError:
+ return None
+
def get_config(self):
return self._config.config
diff --git a/src/leap/base/network.py b/src/leap/base/network.py
index 3aba3f61..765d8ea0 100644
--- a/src/leap/base/network.py
+++ b/src/leap/base/network.py
@@ -3,10 +3,11 @@ from __future__ import (print_function)
import logging
import threading
-from leap.eip.config import get_eip_gateway
+from leap.eip import config as eipconfig
from leap.base.checks import LeapNetworkChecker
from leap.base.constants import ROUTE_CHECK_INTERVAL
from leap.base.exceptions import TunnelNotDefaultRouteError
+from leap.util.misc import null_check
from leap.util.coroutines import (launch_thread, process_events)
from time import sleep
@@ -27,11 +28,20 @@ class NetworkCheckerThread(object):
lambda exc: logger.error("%s", exc.message))
self.shutdown = threading.Event()
- # XXX get provider_gateway and pass it to checker
- # see in eip.config for function
- # #718
+ # XXX get provider passed here
+ provider = kwargs.pop('provider', None)
+ null_check(provider, 'provider')
+
+ eipconf = eipconfig.EIPConfig(domain=provider)
+ eipconf.load()
+ eipserviceconf = eipconfig.EIPServiceConfig(domain=provider)
+ eipserviceconf.load()
+
+ gw = eipconfig.get_eip_gateway(
+ eipconfig=eipconf,
+ eipserviceconfig=eipserviceconf)
self.checker = LeapNetworkChecker(
- provider_gw=get_eip_gateway())
+ provider_gw=gw)
def start(self):
self.process_handle = self._launch_recurrent_network_checks(
diff --git a/src/leap/base/pluggableconfig.py b/src/leap/base/pluggableconfig.py
index b8615ad8..34c1e060 100644
--- a/src/leap/base/pluggableconfig.py
+++ b/src/leap/base/pluggableconfig.py
@@ -180,6 +180,8 @@ class PluggableConfig(object):
self.adaptors = adaptors
self.types = types
self._format = format
+ self.mtime = None
+ self.dirty = False
@property
def option_dict(self):
@@ -319,6 +321,13 @@ class PluggableConfig(object):
serializable = self.prep_value(config)
adaptor.write(serializable, filename)
+ if self.mtime:
+ self.touch_mtime(filename)
+
+ def touch_mtime(self, filename):
+ mtime = self.mtime
+ os.utime(filename, (mtime, mtime))
+
def deserialize(self, string=None, fromfile=None, format=None):
"""
load configuration from a file or string
@@ -364,6 +373,12 @@ class PluggableConfig(object):
content = _try_deserialize()
return content
+ def set_dirty(self):
+ self.dirty = True
+
+ def is_dirty(self):
+ return self.dirty
+
def load(self, *args, **kwargs):
"""
load from string or file
@@ -373,6 +388,8 @@ class PluggableConfig(object):
"""
string = args[0] if args else None
fromfile = kwargs.get("fromfile", None)
+ mtime = kwargs.pop("mtime", None)
+ self.mtime = mtime
content = None
# start with defaults, so we can
diff --git a/src/leap/base/tests/test_providers.py b/src/leap/base/tests/test_providers.py
index 15c4ed58..d9604fab 100644
--- a/src/leap/base/tests/test_providers.py
+++ b/src/leap/base/tests/test_providers.py
@@ -8,7 +8,7 @@ import os
import jsonschema
-from leap import __branding as BRANDING
+#from leap import __branding as BRANDING
from leap.testing.basetest import BaseLeapTest
from leap.base import providers
@@ -33,8 +33,8 @@ class TestLeapProviderDefinition(BaseLeapTest):
self.domain = "testprovider.example.org"
self.definition = providers.LeapProviderDefinition(
domain=self.domain)
- self.definition.save()
- self.definition.load()
+ self.definition.save(force=True)
+ self.definition.load() # why have to load after save??
self.config = self.definition.config
def tearDown(self):
@@ -61,7 +61,7 @@ class TestLeapProviderDefinition(BaseLeapTest):
def test_provider_dump(self):
# check a good provider definition is dumped to disk
self.testfile = self.get_tempfile('test.json')
- self.definition.save(to=self.testfile)
+ self.definition.save(to=self.testfile, force=True)
deserialized = json.load(open(self.testfile, 'rb'))
self.maxDiff = None
self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG)