diff options
Diffstat (limited to 'src/leap/eip/checks.py')
-rw-r--r-- | src/leap/eip/checks.py | 542 |
1 files changed, 542 insertions, 0 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py new file mode 100644 index 00000000..af824c57 --- /dev/null +++ b/src/leap/eip/checks.py @@ -0,0 +1,542 @@ +import logging +import time +import os +import sys + +import requests + +from leap import __branding as BRANDING +from leap import certs as leapcerts +from leap.base.auth import srpauth_protected, magick_srpauth +from leap.base import config as baseconfig +from leap.base import constants as baseconstants +from leap.base import providers +from leap.crypto import certs +from leap.eip import config as eipconfig +from leap.eip import constants as eipconstants +from leap.eip import exceptions as eipexceptions +from leap.eip import specs as eipspecs +from leap.util.certs import get_mac_cabundle +from leap.util.fileutil import mkdir_p +from leap.util.web import get_https_domain_and_port + +logger = logging.getLogger(name=__name__) + +""" +ProviderCertChecker +------------------- +Checks on certificates. To be moved to base. +docs TBD + +EIPConfigChecker +---------- +It is used from the eip conductor (a instance of EIPConnection that is +managed from the QtApp), running `run_all` method before trying to call +`connect` or any other of the state-changing methods. + +It checks that the needed files are provided or can be discovered over the +net. Much of these tests are not specific to EIP module, and can be splitted +into base.tests to be invoked by the base leap init routines. +However, I'm testing them alltogether for the sake of having the whole unit +reachable and testable as a whole. + +""" + + +def get_branding_ca_cert(domain): + # deprecated + ca_file = BRANDING.get('provider_ca_file') + if ca_file: + return leapcerts.where(ca_file) + + +class ProviderCertChecker(object): + """ + Several checks needed for getting + client certs and checking tls connection + with provider. + """ + def __init__(self, fetcher=requests, + domain=None): + + self.fetcher = fetcher + self.domain = domain + #XXX needs some kind of autoinit + #right now we set by hand + #by loading and reading provider config + self.apidomain = None + self.cacert = eipspecs.provider_ca_path(domain) + + def run_all( + self, checker=None, + skip_download=False, skip_verify=False): + + if not checker: + checker = self + + do_verify = not skip_verify + logger.debug('do_verify: %s', do_verify) + # checker.download_ca_cert() + + # For MVS+ + # checker.download_ca_signature() + # checker.get_ca_signatures() + # checker.is_there_trust_path() + + # For MVS + checker.is_there_provider_ca() + + checker.is_https_working(verify=do_verify, autocacert=False) + checker.check_new_cert_needed(verify=do_verify) + + def download_ca_cert(self, uri=None, verify=True): + req = self.fetcher.get(uri, verify=verify) + req.raise_for_status() + + # should check domain exists + capath = self._get_ca_cert_path(self.domain) + with open(capath, 'w') as f: + f.write(req.content) + + def check_ca_cert_fingerprint( + self, hash_type="SHA256", + fingerprint=None): + """ + compares the fingerprint in + the ca cert with a string + we are passed + returns True if they are equal, False if not. + @param hash_type: digest function + @type hash_type: str + @param fingerprint: the fingerprint to compare with. + @type fingerprint: str (with : separator) + @rtype bool + """ + ca_cert_path = self.ca_cert_path + ca_cert_fpr = certs.get_cert_fingerprint( + filepath=ca_cert_path) + return ca_cert_fpr == fingerprint + + def verify_api_https(self, uri): + assert uri.startswith('https://') + cacert = self.ca_cert_path + verify = cacert or True + + # DEBUG + logger.debug('uri -> %s' % uri) + logger.debug('cacertpath -> %s' % cacert) + + req = self.fetcher.get(uri, verify=verify) + req.raise_for_status() + return True + + def download_ca_signature(self): + # MVS+ + raise NotImplementedError + + def get_ca_signatures(self): + # MVS+ + raise NotImplementedError + + def is_there_trust_path(self): + # MVS+ + raise NotImplementedError + + def is_there_provider_ca(self): + if not self.cacert: + return False + cacert_exists = os.path.isfile(self.cacert) + if cacert_exists: + logger.debug('True') + return True + logger.debug('False!') + return False + + def is_https_working( + self, uri=None, verify=True, + autocacert=False): + if uri is None: + uri = self._get_root_uri() + # XXX raise InsecureURI or something better + try: + assert uri.startswith('https') + except AssertionError: + raise AssertionError( + "uri passed should start with https") + if autocacert and verify is True and self.cacert is not None: + logger.debug('verify cert: %s', self.cacert) + verify = self.cacert + if sys.platform == "darwin": + verify = get_mac_cabundle() + logger.debug('checking https connection') + logger.debug('uri: %s (verify:%s)', uri, verify) + + try: + self.fetcher.get(uri, verify=verify) + + except requests.exceptions.SSLError as exc: + raise eipexceptions.HttpsBadCertError + + except requests.exceptions.ConnectionError: + logger.error('ConnectionError') + raise eipexceptions.HttpsNotSupported + + else: + return True + + def check_new_cert_needed(self, skip_download=False, verify=True): + # XXX add autocacert + if not self.is_cert_valid(do_raise=False): + logger.debug('cert needed: true') + self.download_new_client_cert( + skip_download=skip_download, + verify=verify) + return True + logger.debug('cert needed: false') + return False + + def download_new_client_cert(self, uri=None, verify=True, + skip_download=False, + credentials=None): + logger.debug('download new client cert') + if skip_download: + return True + if uri is None: + uri = self._get_client_cert_uri() + # XXX raise InsecureURI or something better + #assert uri.startswith('https') + + if verify is True and self.cacert is not None: + verify = self.cacert + logger.debug('verify = %s', verify) + + fgetfn = self.fetcher.get + + if credentials: + user, passwd = credentials + logger.debug('apidomain = %s', self.apidomain) + + @srpauth_protected(user, passwd, + server="https://%s" % self.apidomain, + verify=verify) + def getfn(*args, **kwargs): + return fgetfn(*args, **kwargs) + + else: + # XXX FIXME fix decorated args + @magick_srpauth(verify) + def getfn(*args, **kwargs): + return fgetfn(*args, **kwargs) + try: + + req = getfn(uri, verify=verify) + req.raise_for_status() + + except requests.exceptions.SSLError: + logger.warning('SSLError while fetching cert. ' + 'Look below for stack trace.') + # XXX raise better exception + return self.fail("SSLError") + except Exception as exc: + return self.fail(exc.message) + + try: + logger.debug('validating cert...') + pemfile_content = req.content + valid = self.is_valid_pemfile(pemfile_content) + if not valid: + logger.warning('invalid cert') + return False + cert_path = self._get_client_cert_path() + self.write_cert(pemfile_content, to=cert_path) + except: + logger.warning('Error while validating cert') + raise + return True + + def is_cert_valid(self, cert_path=None, do_raise=True): + exists = lambda: self.is_certificate_exists() + valid_pemfile = lambda: self.is_valid_pemfile() + not_expired = lambda: self.is_cert_not_expired() + + valid = exists() and valid_pemfile() and not_expired() + if not valid: + if do_raise: + raise Exception('missing valid cert') + else: + return False + return True + + def is_certificate_exists(self, certfile=None): + if certfile is None: + certfile = self._get_client_cert_path() + return os.path.isfile(certfile) + + def is_cert_not_expired(self, certfile=None, now=time.gmtime): + if certfile is None: + certfile = self._get_client_cert_path() + from_, to_ = certs.get_time_boundaries(certfile) + + return from_ < now() < to_ + + def is_valid_pemfile(self, cert_s=None): + """ + checks that the passed string + is a valid pem certificate + @param cert_s: string containing pem content + @type cert_s: string + @rtype: bool + """ + if cert_s is None: + certfile = self._get_client_cert_path() + with open(certfile) as cf: + cert_s = cf.read() + try: + valid = certs.can_load_cert_and_pkey(cert_s) + except certs.BadCertError: + logger.warning("Not valid pemfile") + valid = False + return valid + + @property + def ca_cert_path(self): + return self._get_ca_cert_path(self.domain) + + def _get_root_uri(self): + return u"https://%s/" % self.domain + + def _get_client_cert_uri(self): + return "https://%s/1/cert" % self.apidomain + + def _get_client_cert_path(self): + return eipspecs.client_cert_path(domain=self.domain) + + def _get_ca_cert_path(self, domain): + # XXX this folder path will be broken for win + # and this should be moved to eipspecs.ca_path + + # XXX use baseconfig.get_provider_path(folder=Foo) + # !!! + + capath = baseconfig.get_config_file( + 'cacert.pem', + folder='providers/%s/keys/ca' % domain) + folder, fname = os.path.split(capath) + if not os.path.isdir(folder): + mkdir_p(folder) + return capath + + def write_cert(self, pemfile_content, to=None): + folder, filename = os.path.split(to) + if not os.path.isdir(folder): + mkdir_p(folder) + with open(to, 'w') as cert_f: + cert_f.write(pemfile_content) + + def set_api_domain(self, domain): + self.apidomain = domain + + +class EIPConfigChecker(object): + """ + Several checks needed + to ensure a EIPConnection + can be sucessfully established. + use run_all to run all checks. + """ + + def __init__(self, fetcher=requests, domain=None): + # we do not want to accept too many + # argument on init. + # we want tests + # to be explicitely run. + + self.fetcher = fetcher + + # if not domain, get from config + self.domain = domain + self.apidomain = None + self.cacert = eipspecs.provider_ca_path(domain) + + self.defaultprovider = providers.LeapProviderDefinition(domain=domain) + self.defaultprovider.load() + self.eipconfig = eipconfig.EIPConfig(domain=domain) + self.set_api_domain() + self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain) + self.eipserviceconfig.load() + + def run_all(self, checker=None, skip_download=False): + """ + runs all checks in a row. + will raise if some error encountered. + catching those exceptions is not + our responsibility at this moment + """ + if not checker: + checker = self + + # let's call all tests + # needed for a sane eip session. + + # TODO: get rid of check_default. + # check_complete should + # be enough. but here to make early tests easier. + checker.check_default_eipconfig() + + checker.check_is_there_default_provider() + checker.fetch_definition(skip_download=skip_download) + checker.fetch_eip_service_config(skip_download=skip_download) + checker.check_complete_eip_config() + #checker.ping_gateway() + + # public checks + + def check_default_eipconfig(self): + """ + checks if default eipconfig exists, + and dumps a default file if not + """ + # XXX ONLY a transient check + # because some old function still checks + # for eip config at the beginning. + + # it *really* does not make sense to + # dump it right now, we can get an in-memory + # config object and dump it to disk in a + # later moment + logger.debug('checking default eip config') + if not self._is_there_default_eipconfig(): + self._dump_default_eipconfig() + + def check_is_there_default_provider(self, config=None): + """ + raises EIPMissingDefaultProvider if no + default provider found on eip config. + This is catched by ui and runs FirstRunWizard (MVS+) + """ + if config is None: + config = self.eipconfig.config + logger.debug('checking default provider') + provider = config.get('provider', None) + if provider is None: + raise eipexceptions.EIPMissingDefaultProvider + # XXX raise also if malformed ProviderDefinition? + return True + + def fetch_definition(self, skip_download=False, + force_download=False, + config=None, uri=None, + domain=None): + """ + fetches a definition file from server + """ + # TODO: + # - Implement diff + # - overwrite only if different. + # (attend to serial field different, for instance) + + logger.debug('fetching definition') + + if skip_download: + logger.debug('(fetching def skipped)') + return True + if config is None: + config = self.defaultprovider.config + if uri is None: + if not domain: + domain = config.get('provider', None) + uri = self._get_provider_definition_uri(domain=domain) + + if sys.platform == "darwin": + verify = get_mac_cabundle() + else: + verify = True + + self.defaultprovider.load( + from_uri=uri, + fetcher=self.fetcher, + verify=verify) + self.defaultprovider.save() + + def fetch_eip_service_config(self, skip_download=False, + force_download=False, + config=None, uri=None, # domain=None, + autocacert=True, verify=True): + if skip_download: + return True + if config is None: + self.eipserviceconfig.load() + config = self.eipserviceconfig.config + if uri is None: + #XXX + #if not domain: + #domain = self.domain or config.get('provider', None) + uri = self._get_eip_service_uri( + domain=self.apidomain) + + if autocacert and self.cacert is not None: + verify = self.cacert + + self.eipserviceconfig.load( + from_uri=uri, + fetcher=self.fetcher, + force_download=force_download, + verify=verify) + self.eipserviceconfig.save() + + def check_complete_eip_config(self, config=None): + # TODO check for gateway + if config is None: + config = self.eipconfig.config + try: + assert 'provider' in config + assert config['provider'] is not None + # XXX assert there is gateway !! + except AssertionError: + raise eipexceptions.EIPConfigurationError + + # XXX TODO: + # We should WRITE eip config if missing or + # incomplete at this point + #self.eipconfig.save() + + # + # private helpers + # + + def _is_there_default_eipconfig(self): + return self.eipconfig.exists() + + def _dump_default_eipconfig(self): + self.eipconfig.save(force=True) + + def _get_provider_definition_uri(self, domain=None, path=None): + if domain is None: + domain = self.domain or baseconstants.DEFAULT_PROVIDER + if path is None: + path = baseconstants.DEFINITION_EXPECTED_PATH + uri = u"https://%s/%s" % (domain, path) + logger.debug('getting provider definition from %s' % uri) + return uri + + def _get_eip_service_uri(self, domain=None, path=None): + if domain is None: + domain = self.domain or baseconstants.DEFAULT_PROVIDER + if path is None: + path = eipconstants.EIP_SERVICE_EXPECTED_PATH + uri = "https://%s/%s" % (domain, path) + logger.debug('getting eip service file from %s', uri) + return uri + + def set_api_domain(self): + """sets api domain from defaultprovider config object""" + api = self.defaultprovider.config.get('api_uri', None) + # the caller is responsible for having loaded the config + # object at this point + if api: + api_dom = get_https_domain_and_port(api) + self.apidomain = "%s:%s" % api_dom + + def get_api_domain(self): + """gets api domain""" + return self.apidomain |