diff options
Diffstat (limited to 'src/leap/base/config.py')
-rw-r--r-- | src/leap/base/config.py | 136 |
1 files changed, 106 insertions, 30 deletions
diff --git a/src/leap/base/config.py b/src/leap/base/config.py index cf01d1aa..e2f0beba 100644 --- a/src/leap/base/config.py +++ b/src/leap/base/config.py @@ -4,12 +4,15 @@ Configuration Base Class import grp import json import logging +import re import socket -import tempfile +import time import os logger = logging.getLogger(name=__name__) +from dateutil import parser as dateparser +from dirspec import basedir import requests from leap.base import exceptions @@ -118,23 +121,50 @@ class JSONLeapConfig(BaseLeapConfig): " derived class") assert issubclass(self.spec, PluggableConfig) + self.domain = kwargs.pop('domain', None) self._config = self.spec(format="json") self._config.load() self.fetcher = kwargs.pop('fetcher', requests) # mandatory baseconfig interface - def save(self, to=None): - if to is None: - to = self.filename - folder, filename = os.path.split(to) - if folder and not os.path.isdir(folder): - mkdir_p(folder) - self._config.serialize(to) + def save(self, to=None, force=False): + """ + force param will skip the dirty check. + :type force: bool + """ + # XXX this force=True does not feel to right + # but still have to look for a better way + # of dealing with dirtiness and the + # trick of loading remote config only + # when newer. + + if force: + do_save = True + else: + do_save = self._config.is_dirty() + + if do_save: + if to is None: + to = self.filename + folder, filename = os.path.split(to) + if folder and not os.path.isdir(folder): + mkdir_p(folder) + self._config.serialize(to) + return True + + else: + return False + + def load(self, fromfile=None, from_uri=None, fetcher=None, + force_download=False, verify=True): - def load(self, fromfile=None, from_uri=None, fetcher=None, verify=False): if from_uri is not None: - fetched = self.fetch(from_uri, fetcher=fetcher, verify=verify) + fetched = self.fetch( + from_uri, + fetcher=fetcher, + verify=verify, + force_dl=force_download) if fetched: return if fromfile is None: @@ -145,33 +175,68 @@ class JSONLeapConfig(BaseLeapConfig): logger.error('tried to load config from non-existent path') logger.error('Not Found: %s', fromfile) - def fetch(self, uri, fetcher=None, verify=True): + def fetch(self, uri, fetcher=None, verify=True, force_dl=False): if not fetcher: fetcher = self.fetcher - logger.debug('verify: %s', verify) - logger.debug('uri: %s', uri) - request = fetcher.get(uri, verify=verify) - # XXX should send a if-modified-since header - # XXX get 404, ... - # and raise a UnableToFetch... + logger.debug('uri: %s (verify: %s)' % (uri, verify)) + + rargs = (uri, ) + rkwargs = {'verify': verify} + headers = {} + + curmtime = self.get_mtime() if not force_dl else None + if curmtime: + logger.debug('requesting with if-modified-since %s' % curmtime) + headers['if-modified-since'] = curmtime + rkwargs['headers'] = headers + + #request = fetcher.get(uri, verify=verify) + request = fetcher.get(*rargs, **rkwargs) request.raise_for_status() - fd, fname = tempfile.mkstemp(suffix=".json") - if request.json: - self._config.load(json.dumps(request.json)) + if request.status_code == 304: + logger.debug('...304 Not Changed') + # On this point, we have to assume that + # we HAD the filename. If that filename is corruct, + # we should enforce a force_download in the load + # method above. + self._config.load(fromfile=self.filename) + return True + if request.json: + mtime = None + last_modified = request.headers.get('last-modified', None) + if last_modified: + _mtime = dateparser.parse(last_modified) + mtime = int(_mtime.strftime("%s")) + if callable(request.json): + _json = request.json() + else: + # back-compat + _json = request.json + self._config.load(json.dumps(_json), mtime=mtime) + self._config.set_dirty() else: # not request.json # might be server did not announce content properly, # let's try deserializing all the same. try: self._config.load(request.content) + self._config.set_dirty() except ValueError: raise eipexceptions.LeapBadConfigFetchedError return True + def get_mtime(self): + try: + _mtime = os.stat(self.filename)[8] + mtime = time.strftime("%c GMT", time.gmtime(_mtime)) + return mtime + except OSError: + return None + def get_config(self): return self._config.config @@ -216,15 +281,13 @@ def get_config_dir(): @rparam: config path @rtype: string """ - # TODO - # check for $XDG_CONFIG_HOME var? - # get a more sensible path for win/mac - # kclair: opinion? ^^ - - return os.path.expanduser( - os.path.join('~', - '.config', - 'leap')) + home = os.path.expanduser("~") + if re.findall("leap_tests-[a-zA-Z0-9]{6}", home): + # we're inside a test! :) + return os.path.join(home, ".config/leap") + else: + return os.path.join(basedir.default_config_home, + 'leap') def get_config_file(filename, folder=None): @@ -252,6 +315,15 @@ def get_default_provider_path(): return default_provider_path +def get_provider_path(domain): + # XXX if not domain, return get_default_provider_path + default_subpath = os.path.join("providers", domain) + provider_path = get_config_file( + '', + folder=default_subpath) + return provider_path + + def validate_ip(ip_str): """ raises exception if the ip_str is @@ -261,7 +333,11 @@ def validate_ip(ip_str): def get_username(): - return os.getlogin() + try: + return os.getlogin() + except OSError as e: + import pwd + return pwd.getpwuid(os.getuid())[0] def get_groupname(): |