diff options
author | kali <kali@leap.se> | 2013-01-30 06:52:59 +0900 |
---|---|---|
committer | kali <kali@leap.se> | 2013-01-30 06:52:59 +0900 |
commit | 570f84756b3d1f689a172a3ff0c55abf6a60b9dd (patch) | |
tree | 0f262acacf35743664c6408edbafe6ba6f119d14 /src/leap/eip/checks.py | |
parent | e60abdcb796ad9e2c44e9b37d3a68e7f159c035c (diff) | |
parent | 10a2303fe2d21999bce56940daecb78576f5b741 (diff) |
Merge branch 'pre-release-0.2.0' into release-0.2.0
This merge contains today's state of develop branch,
minus soledad and email components.
Diffstat (limited to 'src/leap/eip/checks.py')
-rw-r--r-- | src/leap/eip/checks.py | 309 |
1 files changed, 215 insertions, 94 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py index f739c3e8..9a34a428 100644 --- a/src/leap/eip/checks.py +++ b/src/leap/eip/checks.py @@ -1,23 +1,24 @@ import logging -import ssl -#import platform import time import os +import sys -from gnutls import crypto -#import netifaces -#import ping import requests from leap import __branding as BRANDING -from leap import certs +from leap import certs as leapcerts +from leap.base.auth import srpauth_protected, magick_srpauth +from leap.base import config as baseconfig from leap.base import constants as baseconstants from leap.base import providers +from leap.crypto import certs from leap.eip import config as eipconfig from leap.eip import constants as eipconstants from leap.eip import exceptions as eipexceptions from leap.eip import specs as eipspecs +from leap.util.certs import get_mac_cabundle from leap.util.fileutil import mkdir_p +from leap.util.web import get_https_domain_and_port logger = logging.getLogger(name=__name__) @@ -42,10 +43,11 @@ reachable and testable as a whole. """ -def get_ca_cert(): +def get_branding_ca_cert(domain): + # deprecated ca_file = BRANDING.get('provider_ca_file') if ca_file: - return certs.where(ca_file) + return leapcerts.where(ca_file) class ProviderCertChecker(object): @@ -54,18 +56,29 @@ class ProviderCertChecker(object): client certs and checking tls connection with provider. """ - def __init__(self, fetcher=requests): + def __init__(self, fetcher=requests, + domain=None): + self.fetcher = fetcher - self.cacert = get_ca_cert() + self.domain = domain + #XXX needs some kind of autoinit + #right now we set by hand + #by loading and reading provider config + self.apidomain = None + self.cacert = eipspecs.provider_ca_path(domain) + + def run_all( + self, checker=None, + skip_download=False, skip_verify=False): - def run_all(self, checker=None, skip_download=False, skip_verify=False): if not checker: checker = self do_verify = not skip_verify logger.debug('do_verify: %s', do_verify) - # For MVS+ # checker.download_ca_cert() + + # For MVS+ # checker.download_ca_signature() # checker.get_ca_signatures() # checker.is_there_trust_path() @@ -73,13 +86,44 @@ class ProviderCertChecker(object): # For MVS checker.is_there_provider_ca() - # XXX FAKE IT!!! - checker.is_https_working(verify=do_verify) + checker.is_https_working(verify=do_verify, autocacert=False) checker.check_new_cert_needed(verify=do_verify) - def download_ca_cert(self): - # MVS+ - raise NotImplementedError + def download_ca_cert(self, uri=None, verify=True): + req = self.fetcher.get(uri, verify=verify) + req.raise_for_status() + + # should check domain exists + capath = self._get_ca_cert_path(self.domain) + with open(capath, 'w') as f: + f.write(req.content) + + def check_ca_cert_fingerprint( + self, hash_type="SHA256", + fingerprint=None): + """ + compares the fingerprint in + the ca cert with a string + we are passed + returns True if they are equal, False if not. + @param hash_type: digest function + @type hash_type: str + @param fingerprint: the fingerprint to compare with. + @type fingerprint: str (with : separator) + @rtype bool + """ + ca_cert_path = self.ca_cert_path + ca_cert_fpr = certs.get_cert_fingerprint( + filepath=ca_cert_path) + return ca_cert_fpr == fingerprint + + def verify_api_https(self, uri): + assert uri.startswith('https://') + cacert = self.ca_cert_path + verify = cacert and cacert or True + req = self.fetcher.get(uri, verify=verify) + req.raise_for_status() + return True def download_ca_signature(self): # MVS+ @@ -94,80 +138,110 @@ class ProviderCertChecker(object): raise NotImplementedError def is_there_provider_ca(self): - from leap import certs - logger.debug('do we have provider_ca?') - cacert_path = BRANDING.get('provider_ca_file', None) - if not cacert_path: - logger.debug('False') + if not self.cacert: return False - self.cacert = certs.where(cacert_path) - logger.debug('True') - return True + cacert_exists = os.path.isfile(self.cacert) + if cacert_exists: + logger.debug('True') + return True + logger.debug('False!') + return False - def is_https_working(self, uri=None, verify=True): + def is_https_working( + self, uri=None, verify=True, + autocacert=False): if uri is None: uri = self._get_root_uri() # XXX raise InsecureURI or something better - assert uri.startswith('https') - if verify is True and self.cacert is not None: + try: + assert uri.startswith('https') + except AssertionError: + raise AssertionError( + "uri passed should start with https") + if autocacert and verify is True and self.cacert is not None: logger.debug('verify cert: %s', self.cacert) verify = self.cacert - logger.debug('is https working?') + if sys.platform == "darwin": + verify = get_mac_cabundle() + logger.debug('checking https connection') logger.debug('uri: %s (verify:%s)', uri, verify) + try: self.fetcher.get(uri, verify=verify) + except requests.exceptions.SSLError as exc: - logger.warning('False! CERT VERIFICATION FAILED! ' - '(this should be CRITICAL)') - logger.warning('SSLError: %s', exc.message) - # XXX RAISE! See #638 - #raise eipexceptions.EIPBadCertError - # XXX get requests.exceptions.ConnectionError Errno 110 - # Connection timed out, and raise ours. + raise eipexceptions.HttpsBadCertError + + except requests.exceptions.ConnectionError: + logger.error('ConnectionError') + raise eipexceptions.HttpsNotSupported + else: - logger.debug('True') return True def check_new_cert_needed(self, skip_download=False, verify=True): - logger.debug('is new cert needed?') + # XXX add autocacert if not self.is_cert_valid(do_raise=False): - logger.debug('True') + logger.debug('cert needed: true') self.download_new_client_cert( skip_download=skip_download, verify=verify) return True - logger.debug('False') + logger.debug('cert needed: false') return False def download_new_client_cert(self, uri=None, verify=True, - skip_download=False): + skip_download=False, + credentials=None): logger.debug('download new client cert') if skip_download: return True if uri is None: uri = self._get_client_cert_uri() # XXX raise InsecureURI or something better - assert uri.startswith('https') + #assert uri.startswith('https') + if verify is True and self.cacert is not None: verify = self.cacert + logger.debug('verify = %s', verify) + + fgetfn = self.fetcher.get + + if credentials: + user, passwd = credentials + logger.debug('apidomain = %s', self.apidomain) + + @srpauth_protected(user, passwd, + server="https://%s" % self.apidomain, + verify=verify) + def getfn(*args, **kwargs): + return fgetfn(*args, **kwargs) + + else: + # XXX FIXME fix decorated args + @magick_srpauth(verify) + def getfn(*args, **kwargs): + return fgetfn(*args, **kwargs) try: - # XXX FIXME!!!! - # verify=verify - # Workaround for #638. return to verification - # when That's done!!! - - # XXX HOOK SRP here... - # will have to be more generic in the future. - req = self.fetcher.get(uri, verify=False) + + req = getfn(uri, verify=verify) req.raise_for_status() + except requests.exceptions.SSLError: logger.warning('SSLError while fetching cert. ' 'Look below for stack trace.') # XXX raise better exception - raise + return self.fail("SSLError") + except Exception as exc: + return self.fail(exc.message) + try: + logger.debug('validating cert...') pemfile_content = req.content - self.is_valid_pemfile(pemfile_content) + valid = self.is_valid_pemfile(pemfile_content) + if not valid: + logger.warning('invalid cert') + return False cert_path = self._get_client_cert_path() self.write_cert(pemfile_content, to=cert_path) except: @@ -196,11 +270,8 @@ class ProviderCertChecker(object): def is_cert_not_expired(self, certfile=None, now=time.gmtime): if certfile is None: certfile = self._get_client_cert_path() - with open(certfile) as cf: - cert_s = cf.read() - cert = crypto.X509Certificate(cert_s) - from_ = time.gmtime(cert.activation_time) - to_ = time.gmtime(cert.expiration_time) + from_, to_ = certs.get_time_boundaries(certfile) + return from_ < now() < to_ def is_valid_pemfile(self, cert_s=None): @@ -216,33 +287,39 @@ class ProviderCertChecker(object): with open(certfile) as cf: cert_s = cf.read() try: - # XXX get a real cert validation - # so far this is only checking begin/end - # delimiters :) - # XXX use gnutls for get proper - # validation. - # crypto.X509Certificate(cert_s) - sep = "-" * 5 + "BEGIN CERTIFICATE" + "-" * 5 - # we might have private key and cert in the same file - certparts = cert_s.split(sep) - if len(certparts) > 1: - cert_s = sep + certparts[1] - ssl.PEM_cert_to_DER_cert(cert_s) - except: - # XXX raise proper exception - raise - return True + valid = certs.can_load_cert_and_pkey(cert_s) + except certs.BadCertError: + logger.warning("Not valid pemfile") + valid = False + return valid + + @property + def ca_cert_path(self): + return self._get_ca_cert_path(self.domain) def _get_root_uri(self): - return u"https://%s/" % baseconstants.DEFAULT_PROVIDER + return u"https://%s/" % self.domain def _get_client_cert_uri(self): - # XXX get the whole thing from constants - return "https://%s/1/cert" % (baseconstants.DEFAULT_PROVIDER) + return "https://%s/1/cert" % self.apidomain def _get_client_cert_path(self): - # MVS+ : get provider path - return eipspecs.client_cert_path() + return eipspecs.client_cert_path(domain=self.domain) + + def _get_ca_cert_path(self, domain): + # XXX this folder path will be broken for win + # and this should be moved to eipspecs.ca_path + + # XXX use baseconfig.get_provider_path(folder=Foo) + # !!! + + capath = baseconfig.get_config_file( + 'cacert.pem', + folder='providers/%s/keys/ca' % domain) + folder, fname = os.path.split(capath) + if not os.path.isdir(folder): + mkdir_p(folder) + return capath def write_cert(self, pemfile_content, to=None): folder, filename = os.path.split(to) @@ -251,6 +328,9 @@ class ProviderCertChecker(object): with open(to, 'w') as cert_f: cert_f.write(pemfile_content) + def set_api_domain(self, domain): + self.apidomain = domain + class EIPConfigChecker(object): """ @@ -260,16 +340,25 @@ class EIPConfigChecker(object): use run_all to run all checks. """ - def __init__(self, fetcher=requests): + def __init__(self, fetcher=requests, domain=None): # we do not want to accept too many # argument on init. # we want tests # to be explicitely run. + self.fetcher = fetcher - self.eipconfig = eipconfig.EIPConfig() - self.defaultprovider = providers.LeapProviderDefinition() - self.eipserviceconfig = eipconfig.EIPServiceConfig() + # if not domain, get from config + self.domain = domain + self.apidomain = None + self.cacert = eipspecs.provider_ca_path(domain) + + self.defaultprovider = providers.LeapProviderDefinition(domain=domain) + self.defaultprovider.load() + self.eipconfig = eipconfig.EIPConfig(domain=domain) + self.set_api_domain() + self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain) + self.eipserviceconfig.load() def run_all(self, checker=None, skip_download=False): """ @@ -330,7 +419,9 @@ class EIPConfigChecker(object): return True def fetch_definition(self, skip_download=False, - config=None, uri=None): + force_download=False, + config=None, uri=None, + domain=None): """ fetches a definition file from server """ @@ -347,27 +438,45 @@ class EIPConfigChecker(object): if config is None: config = self.defaultprovider.config if uri is None: - domain = config.get('provider', None) + if not domain: + domain = config.get('provider', None) uri = self._get_provider_definition_uri(domain=domain) - # FIXME! Pass ca path verify!!! + if sys.platform == "darwin": + verify = get_mac_cabundle() + else: + verify = True + self.defaultprovider.load( from_uri=uri, fetcher=self.fetcher, - verify=False) + verify=verify) self.defaultprovider.save() def fetch_eip_service_config(self, skip_download=False, - config=None, uri=None): + force_download=False, + config=None, uri=None, # domain=None, + autocacert=True, verify=True): if skip_download: return True if config is None: + self.eipserviceconfig.load() config = self.eipserviceconfig.config if uri is None: - domain = config.get('provider', None) - uri = self._get_eip_service_uri(domain=domain) + #XXX + #if not domain: + #domain = self.domain or config.get('provider', None) + uri = self._get_eip_service_uri( + domain=self.apidomain) + + if autocacert and self.cacert is not None: + verify = self.cacert - self.eipserviceconfig.load(from_uri=uri, fetcher=self.fetcher) + self.eipserviceconfig.load( + from_uri=uri, + fetcher=self.fetcher, + force_download=force_download, + verify=verify) self.eipserviceconfig.save() def check_complete_eip_config(self, config=None): @@ -375,7 +484,6 @@ class EIPConfigChecker(object): if config is None: config = self.eipconfig.config try: - 'trying assertions' assert 'provider' in config assert config['provider'] is not None # XXX assert there is gateway !! @@ -395,11 +503,11 @@ class EIPConfigChecker(object): return self.eipconfig.exists() def _dump_default_eipconfig(self): - self.eipconfig.save() + self.eipconfig.save(force=True) def _get_provider_definition_uri(self, domain=None, path=None): if domain is None: - domain = baseconstants.DEFAULT_PROVIDER + domain = self.domain or baseconstants.DEFAULT_PROVIDER if path is None: path = baseconstants.DEFINITION_EXPECTED_PATH uri = u"https://%s/%s" % (domain, path) @@ -408,9 +516,22 @@ class EIPConfigChecker(object): def _get_eip_service_uri(self, domain=None, path=None): if domain is None: - domain = baseconstants.DEFAULT_PROVIDER + domain = self.domain or baseconstants.DEFAULT_PROVIDER if path is None: path = eipconstants.EIP_SERVICE_EXPECTED_PATH uri = "https://%s/%s" % (domain, path) logger.debug('getting eip service file from %s', uri) return uri + + def set_api_domain(self): + """sets api domain from defaultprovider config object""" + api = self.defaultprovider.config.get('api_uri', None) + # the caller is responsible for having loaded the config + # object at this point + if api: + api_dom = get_https_domain_and_port(api) + self.apidomain = "%s:%s" % api_dom + + def get_api_domain(self): + """gets api domain""" + return self.apidomain |