summaryrefslogtreecommitdiff
path: root/src/leap/base/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/base/config.py')
-rw-r--r--src/leap/base/config.py136
1 files changed, 106 insertions, 30 deletions
diff --git a/src/leap/base/config.py b/src/leap/base/config.py
index cf01d1aa..e2f0beba 100644
--- a/src/leap/base/config.py
+++ b/src/leap/base/config.py
@@ -4,12 +4,15 @@ Configuration Base Class
import grp
import json
import logging
+import re
import socket
-import tempfile
+import time
import os
logger = logging.getLogger(name=__name__)
+from dateutil import parser as dateparser
+from dirspec import basedir
import requests
from leap.base import exceptions
@@ -118,23 +121,50 @@ class JSONLeapConfig(BaseLeapConfig):
" derived class")
assert issubclass(self.spec, PluggableConfig)
+ self.domain = kwargs.pop('domain', None)
self._config = self.spec(format="json")
self._config.load()
self.fetcher = kwargs.pop('fetcher', requests)
# mandatory baseconfig interface
- def save(self, to=None):
- if to is None:
- to = self.filename
- folder, filename = os.path.split(to)
- if folder and not os.path.isdir(folder):
- mkdir_p(folder)
- self._config.serialize(to)
+ def save(self, to=None, force=False):
+ """
+ force param will skip the dirty check.
+ :type force: bool
+ """
+ # XXX this force=True does not feel to right
+ # but still have to look for a better way
+ # of dealing with dirtiness and the
+ # trick of loading remote config only
+ # when newer.
+
+ if force:
+ do_save = True
+ else:
+ do_save = self._config.is_dirty()
+
+ if do_save:
+ if to is None:
+ to = self.filename
+ folder, filename = os.path.split(to)
+ if folder and not os.path.isdir(folder):
+ mkdir_p(folder)
+ self._config.serialize(to)
+ return True
+
+ else:
+ return False
+
+ def load(self, fromfile=None, from_uri=None, fetcher=None,
+ force_download=False, verify=True):
- def load(self, fromfile=None, from_uri=None, fetcher=None, verify=False):
if from_uri is not None:
- fetched = self.fetch(from_uri, fetcher=fetcher, verify=verify)
+ fetched = self.fetch(
+ from_uri,
+ fetcher=fetcher,
+ verify=verify,
+ force_dl=force_download)
if fetched:
return
if fromfile is None:
@@ -145,33 +175,68 @@ class JSONLeapConfig(BaseLeapConfig):
logger.error('tried to load config from non-existent path')
logger.error('Not Found: %s', fromfile)
- def fetch(self, uri, fetcher=None, verify=True):
+ def fetch(self, uri, fetcher=None, verify=True, force_dl=False):
if not fetcher:
fetcher = self.fetcher
- logger.debug('verify: %s', verify)
- logger.debug('uri: %s', uri)
- request = fetcher.get(uri, verify=verify)
- # XXX should send a if-modified-since header
- # XXX get 404, ...
- # and raise a UnableToFetch...
+ logger.debug('uri: %s (verify: %s)' % (uri, verify))
+
+ rargs = (uri, )
+ rkwargs = {'verify': verify}
+ headers = {}
+
+ curmtime = self.get_mtime() if not force_dl else None
+ if curmtime:
+ logger.debug('requesting with if-modified-since %s' % curmtime)
+ headers['if-modified-since'] = curmtime
+ rkwargs['headers'] = headers
+
+ #request = fetcher.get(uri, verify=verify)
+ request = fetcher.get(*rargs, **rkwargs)
request.raise_for_status()
- fd, fname = tempfile.mkstemp(suffix=".json")
- if request.json:
- self._config.load(json.dumps(request.json))
+ if request.status_code == 304:
+ logger.debug('...304 Not Changed')
+ # On this point, we have to assume that
+ # we HAD the filename. If that filename is corruct,
+ # we should enforce a force_download in the load
+ # method above.
+ self._config.load(fromfile=self.filename)
+ return True
+ if request.json:
+ mtime = None
+ last_modified = request.headers.get('last-modified', None)
+ if last_modified:
+ _mtime = dateparser.parse(last_modified)
+ mtime = int(_mtime.strftime("%s"))
+ if callable(request.json):
+ _json = request.json()
+ else:
+ # back-compat
+ _json = request.json
+ self._config.load(json.dumps(_json), mtime=mtime)
+ self._config.set_dirty()
else:
# not request.json
# might be server did not announce content properly,
# let's try deserializing all the same.
try:
self._config.load(request.content)
+ self._config.set_dirty()
except ValueError:
raise eipexceptions.LeapBadConfigFetchedError
return True
+ def get_mtime(self):
+ try:
+ _mtime = os.stat(self.filename)[8]
+ mtime = time.strftime("%c GMT", time.gmtime(_mtime))
+ return mtime
+ except OSError:
+ return None
+
def get_config(self):
return self._config.config
@@ -216,15 +281,13 @@ def get_config_dir():
@rparam: config path
@rtype: string
"""
- # TODO
- # check for $XDG_CONFIG_HOME var?
- # get a more sensible path for win/mac
- # kclair: opinion? ^^
-
- return os.path.expanduser(
- os.path.join('~',
- '.config',
- 'leap'))
+ home = os.path.expanduser("~")
+ if re.findall("leap_tests-[a-zA-Z0-9]{6}", home):
+ # we're inside a test! :)
+ return os.path.join(home, ".config/leap")
+ else:
+ return os.path.join(basedir.default_config_home,
+ 'leap')
def get_config_file(filename, folder=None):
@@ -252,6 +315,15 @@ def get_default_provider_path():
return default_provider_path
+def get_provider_path(domain):
+ # XXX if not domain, return get_default_provider_path
+ default_subpath = os.path.join("providers", domain)
+ provider_path = get_config_file(
+ '',
+ folder=default_subpath)
+ return provider_path
+
+
def validate_ip(ip_str):
"""
raises exception if the ip_str is
@@ -261,7 +333,11 @@ def validate_ip(ip_str):
def get_username():
- return os.getlogin()
+ try:
+ return os.getlogin()
+ except OSError as e:
+ import pwd
+ return pwd.getpwuid(os.getuid())[0]
def get_groupname():