diff options
Diffstat (limited to 'src/leap/eip/checks.py')
-rw-r--r-- | src/leap/eip/checks.py | 85 |
1 files changed, 58 insertions, 27 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py index b14e5dd3..bd158e1e 100644 --- a/src/leap/eip/checks.py +++ b/src/leap/eip/checks.py @@ -1,5 +1,5 @@ import logging -import ssl +#import ssl #import platform import time import os @@ -21,6 +21,8 @@ from leap.eip import constants as eipconstants from leap.eip import exceptions as eipexceptions from leap.eip import specs as eipspecs from leap.util.fileutil import mkdir_p +from leap.util.web import get_https_domain_and_port +from leap.util.misc import null_check logger = logging.getLogger(name=__name__) @@ -46,7 +48,7 @@ reachable and testable as a whole. def get_branding_ca_cert(domain): - # XXX deprecated + # deprecated ca_file = BRANDING.get('provider_ca_file') if ca_file: return leapcerts.where(ca_file) @@ -63,6 +65,10 @@ class ProviderCertChecker(object): self.fetcher = fetcher self.domain = domain + #XXX needs some kind of autoinit + #right now we set by hand + #by loading and reading provider config + self.apidomain = None self.cacert = eipspecs.provider_ca_path(domain) def run_all( @@ -159,7 +165,7 @@ class ProviderCertChecker(object): if autocacert and verify is True and self.cacert is not None: logger.debug('verify cert: %s', self.cacert) verify = self.cacert - logger.debug('is https working?') + logger.debug('checking https connection') logger.debug('uri: %s (verify:%s)', uri, verify) try: self.fetcher.get(uri, verify=verify) @@ -167,27 +173,24 @@ class ProviderCertChecker(object): except requests.exceptions.SSLError: # as exc: logger.error("SSLError") raise eipexceptions.HttpsBadCertError - #logger.warning('BUG #638 CERT VERIFICATION FAILED! ' - #'(this should be CRITICAL)') - #logger.warning('SSLError: %s', exc.message) except requests.exceptions.ConnectionError: logger.error('ConnectionError') raise eipexceptions.HttpsNotSupported else: - logger.debug('True') return True def check_new_cert_needed(self, skip_download=False, verify=True): + # XXX add autocacert logger.debug('is new cert needed?') if not self.is_cert_valid(do_raise=False): - logger.debug('True') + logger.debug('cert needed: true') self.download_new_client_cert( skip_download=skip_download, verify=verify) return True - logger.debug('False') + logger.debug('cert needed: false') return False def download_new_client_cert(self, uri=None, verify=True, @@ -199,20 +202,20 @@ class ProviderCertChecker(object): if uri is None: uri = self._get_client_cert_uri() # XXX raise InsecureURI or something better - assert uri.startswith('https') + #assert uri.startswith('https') if verify is True and self.cacert is not None: verify = self.cacert + logger.debug('verify = %s', verify) fgetfn = self.fetcher.get if credentials: user, passwd = credentials - - logger.debug('domain = %s', self.domain) + logger.debug('apidomain = %s', self.apidomain) @srpauth_protected(user, passwd, - server="https://%s" % self.domain, + server="https://%s" % self.apidomain, verify=verify) def getfn(*args, **kwargs): return fgetfn(*args, **kwargs) @@ -231,11 +234,16 @@ class ProviderCertChecker(object): logger.warning('SSLError while fetching cert. ' 'Look below for stack trace.') # XXX raise better exception - raise + return self.fail("SSLError") + except Exception as exc: + return self.fail(exc.message) + try: + logger.debug('validating cert...') pemfile_content = req.content valid = self.is_valid_pemfile(pemfile_content) if not valid: + logger.warning('invalid cert') return False cert_path = self._get_client_cert_path() self.write_cert(pemfile_content, to=cert_path) @@ -299,8 +307,7 @@ class ProviderCertChecker(object): return u"https://%s/" % self.domain def _get_client_cert_uri(self): - # XXX get the whole thing from constants - return "https://%s/1/cert" % self.domain + return "https://%s/1/cert" % self.apidomain def _get_client_cert_path(self): return eipspecs.client_cert_path(domain=self.domain) @@ -327,6 +334,9 @@ class ProviderCertChecker(object): with open(to, 'w') as cert_f: cert_f.write(pemfile_content) + def set_api_domain(self, domain): + self.apidomain = domain + class EIPConfigChecker(object): """ @@ -346,10 +356,15 @@ class EIPConfigChecker(object): # if not domain, get from config self.domain = domain + self.apidomain = None + self.cacert = eipspecs.provider_ca_path(domain) - self.eipconfig = eipconfig.EIPConfig(domain=domain) self.defaultprovider = providers.LeapProviderDefinition(domain=domain) + self.defaultprovider.load() + self.eipconfig = eipconfig.EIPConfig(domain=domain) + self.set_api_domain() self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain) + self.eipserviceconfig.load() def run_all(self, checker=None, skip_download=False): """ @@ -433,31 +448,35 @@ class EIPConfigChecker(object): domain = config.get('provider', None) uri = self._get_provider_definition_uri(domain=domain) - # FIXME! Pass ca path verify!!! - # BUG #638 - # FIXME FIXME FIXME self.defaultprovider.load( from_uri=uri, fetcher=self.fetcher) - #verify=False) self.defaultprovider.save() def fetch_eip_service_config(self, skip_download=False, force_download=False, - config=None, uri=None, domain=None): + config=None, uri=None, # domain=None, + autocacert=True): if skip_download: return True if config is None: + self.eipserviceconfig.load() config = self.eipserviceconfig.config if uri is None: - if not domain: - domain = self.domain or config.get('provider', None) - uri = self._get_eip_service_uri(domain=domain) + #XXX + #if not domain: + #domain = self.domain or config.get('provider', None) + uri = self._get_eip_service_uri( + domain=self.apidomain) + + if autocacert and self.cacert is not None: + verify = self.cacert self.eipserviceconfig.load( from_uri=uri, fetcher=self.fetcher, - force_download=force_download) + force_download=force_download, + verify=verify) self.eipserviceconfig.save() def check_complete_eip_config(self, config=None): @@ -465,7 +484,6 @@ class EIPConfigChecker(object): if config is None: config = self.eipconfig.config try: - 'trying assertions' assert 'provider' in config assert config['provider'] is not None # XXX assert there is gateway !! @@ -504,3 +522,16 @@ class EIPConfigChecker(object): uri = "https://%s/%s" % (domain, path) logger.debug('getting eip service file from %s', uri) return uri + + def set_api_domain(self): + """sets api domain from defaultprovider config object""" + api = self.defaultprovider.config.get('api_uri', None) + # the caller is responsible for having loaded the config + # object at this point + if api: + api_dom = get_https_domain_and_port(api) + self.apidomain = "%s:%s" % api_dom + + def get_api_domain(self): + """gets api domain""" + return self.apidomain |