diff options
author | kali <kali@leap.se> | 2012-11-14 00:38:20 +0900 |
---|---|---|
committer | kali <kali@leap.se> | 2012-11-14 00:38:20 +0900 |
commit | 21875404282522a9c83bfb9c85d6a24fa59d20f8 (patch) | |
tree | ae0409bd742ce3a6f994ae9bb31fc5ab7225f1c6 /src/leap/eip/checks.py | |
parent | f6e900f024074435349eb778a2d89baed55e1e6c (diff) | |
parent | d24c7328fa845737dbb83d512e4b3f287634c4cc (diff) |
Merge branch 'feature/generic-wizard' into develop
The generic wizard (big) branch is now stabilised.
A bunch of refactors have gone together with this topic branch:
- client does not have any info included for default service providers.
- user has to run the first-run wizard and manually entry domain for sample provider.
- remove all remains of the older branding strategy for default provider.
- srp registration + authentication are integrated with the signup process.
Diffstat (limited to 'src/leap/eip/checks.py')
-rw-r--r-- | src/leap/eip/checks.py | 198 |
1 files changed, 150 insertions, 48 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py index f739c3e8..116c535e 100644 --- a/src/leap/eip/checks.py +++ b/src/leap/eip/checks.py @@ -4,15 +4,18 @@ import ssl import time import os -from gnutls import crypto +import gnutls.crypto #import netifaces #import ping import requests from leap import __branding as BRANDING -from leap import certs +from leap import certs as leapcerts +from leap.base.auth import srpauth_protected, magick_srpauth +from leap.base import config as baseconfig from leap.base import constants as baseconstants from leap.base import providers +from leap.crypto import certs from leap.eip import config as eipconfig from leap.eip import constants as eipconstants from leap.eip import exceptions as eipexceptions @@ -42,10 +45,11 @@ reachable and testable as a whole. """ -def get_ca_cert(): +def get_branding_ca_cert(domain): + # XXX deprecated ca_file = BRANDING.get('provider_ca_file') if ca_file: - return certs.where(ca_file) + return leapcerts.where(ca_file) class ProviderCertChecker(object): @@ -54,18 +58,25 @@ class ProviderCertChecker(object): client certs and checking tls connection with provider. """ - def __init__(self, fetcher=requests): + def __init__(self, fetcher=requests, + domain=None): + self.fetcher = fetcher - self.cacert = get_ca_cert() + self.domain = domain + self.cacert = eipspecs.provider_ca_path(domain) + + def run_all( + self, checker=None, + skip_download=False, skip_verify=False): - def run_all(self, checker=None, skip_download=False, skip_verify=False): if not checker: checker = self do_verify = not skip_verify logger.debug('do_verify: %s', do_verify) - # For MVS+ # checker.download_ca_cert() + + # For MVS+ # checker.download_ca_signature() # checker.get_ca_signatures() # checker.is_there_trust_path() @@ -74,12 +85,44 @@ class ProviderCertChecker(object): checker.is_there_provider_ca() # XXX FAKE IT!!! - checker.is_https_working(verify=do_verify) + checker.is_https_working(verify=do_verify, autocacert=True) checker.check_new_cert_needed(verify=do_verify) - def download_ca_cert(self): - # MVS+ - raise NotImplementedError + def download_ca_cert(self, uri=None, verify=True): + req = self.fetcher.get(uri, verify=verify) + req.raise_for_status() + + # should check domain exists + capath = self._get_ca_cert_path(self.domain) + with open(capath, 'w') as f: + f.write(req.content) + + def check_ca_cert_fingerprint( + self, hash_type="SHA256", + fingerprint=None): + """ + compares the fingerprint in + the ca cert with a string + we are passed + returns True if they are equal, False if not. + @param hash_type: digest function + @type hash_type: str + @param fingerprint: the fingerprint to compare with. + @type fingerprint: str (with : separator) + @rtype bool + """ + ca_cert_path = self.ca_cert_path + ca_cert_fpr = certs.get_cert_fingerprint( + filepath=ca_cert_path) + return ca_cert_fpr == fingerprint + + def verify_api_https(self, uri): + assert uri.startswith('https://') + cacert = self.ca_cert_path + verify = cacert and cacert or True + req = self.fetcher.get(uri, verify=verify) + req.raise_for_status() + return True def download_ca_signature(self): # MVS+ @@ -94,36 +137,47 @@ class ProviderCertChecker(object): raise NotImplementedError def is_there_provider_ca(self): - from leap import certs - logger.debug('do we have provider_ca?') - cacert_path = BRANDING.get('provider_ca_file', None) - if not cacert_path: - logger.debug('False') + if not self.cacert: return False - self.cacert = certs.where(cacert_path) - logger.debug('True') - return True + cacert_exists = os.path.isfile(self.cacert) + if cacert_exists: + logger.debug('True') + return True + logger.debug('False!') + return False - def is_https_working(self, uri=None, verify=True): + def is_https_working( + self, uri=None, verify=True, + autocacert=False): if uri is None: uri = self._get_root_uri() # XXX raise InsecureURI or something better - assert uri.startswith('https') - if verify is True and self.cacert is not None: + try: + assert uri.startswith('https') + except AssertionError: + raise AssertionError( + "uri passed should start with https") + if autocacert and verify is True and self.cacert is not None: logger.debug('verify cert: %s', self.cacert) verify = self.cacert + #import pdb4qt; pdb4qt.set_trace() logger.debug('is https working?') logger.debug('uri: %s (verify:%s)', uri, verify) try: self.fetcher.get(uri, verify=verify) + except requests.exceptions.SSLError as exc: - logger.warning('False! CERT VERIFICATION FAILED! ' + logger.error("SSLError") + # XXX RAISE! See #638 + #raise eipexceptions.HttpsBadCertError + logger.warning('BUG #638 CERT VERIFICATION FAILED! ' '(this should be CRITICAL)') logger.warning('SSLError: %s', exc.message) - # XXX RAISE! See #638 - #raise eipexceptions.EIPBadCertError - # XXX get requests.exceptions.ConnectionError Errno 110 - # Connection timed out, and raise ours. + + except requests.exceptions.ConnectionError: + logger.error('ConnectionError') + raise eipexceptions.HttpsNotSupported + else: logger.debug('True') return True @@ -140,7 +194,8 @@ class ProviderCertChecker(object): return False def download_new_client_cert(self, uri=None, verify=True, - skip_download=False): + skip_download=False, + credentials=None): logger.debug('download new client cert') if skip_download: return True @@ -148,18 +203,38 @@ class ProviderCertChecker(object): uri = self._get_client_cert_uri() # XXX raise InsecureURI or something better assert uri.startswith('https') + if verify is True and self.cacert is not None: verify = self.cacert + + fgetfn = self.fetcher.get + + if credentials: + user, passwd = credentials + + logger.debug('domain = %s', self.domain) + + @srpauth_protected(user, passwd, + server="https://%s" % self.domain, + verify=verify) + def getfn(*args, **kwargs): + return fgetfn(*args, **kwargs) + + else: + # XXX FIXME fix decorated args + @magick_srpauth(verify) + def getfn(*args, **kwargs): + return fgetfn(*args, **kwargs) try: + # XXX FIXME!!!! # verify=verify # Workaround for #638. return to verification # when That's done!!! - - # XXX HOOK SRP here... - # will have to be more generic in the future. - req = self.fetcher.get(uri, verify=False) + #req = self.fetcher.get(uri, verify=False) + req = getfn(uri, verify=False) req.raise_for_status() + except requests.exceptions.SSLError: logger.warning('SSLError while fetching cert. ' 'Look below for stack trace.') @@ -198,7 +273,7 @@ class ProviderCertChecker(object): certfile = self._get_client_cert_path() with open(certfile) as cf: cert_s = cf.read() - cert = crypto.X509Certificate(cert_s) + cert = gnutls.crypto.X509Certificate(cert_s) from_ = time.gmtime(cert.activation_time) to_ = time.gmtime(cert.expiration_time) return from_ < now() < to_ @@ -233,16 +308,34 @@ class ProviderCertChecker(object): raise return True + @property + def ca_cert_path(self): + return self._get_ca_cert_path(self.domain) + def _get_root_uri(self): - return u"https://%s/" % baseconstants.DEFAULT_PROVIDER + return u"https://%s/" % self.domain def _get_client_cert_uri(self): # XXX get the whole thing from constants - return "https://%s/1/cert" % (baseconstants.DEFAULT_PROVIDER) + return "https://%s/1/cert" % self.domain def _get_client_cert_path(self): - # MVS+ : get provider path - return eipspecs.client_cert_path() + return eipspecs.client_cert_path(domain=self.domain) + + def _get_ca_cert_path(self, domain): + # XXX this folder path will be broken for win + # and this should be moved to eipspecs.ca_path + + # XXX use baseconfig.get_provider_path(folder=Foo) + # !!! + + capath = baseconfig.get_config_file( + 'cacert.pem', + folder='providers/%s/keys/ca' % domain) + folder, fname = os.path.split(capath) + if not os.path.isdir(folder): + mkdir_p(folder) + return capath def write_cert(self, pemfile_content, to=None): folder, filename = os.path.split(to) @@ -260,16 +353,20 @@ class EIPConfigChecker(object): use run_all to run all checks. """ - def __init__(self, fetcher=requests): + def __init__(self, fetcher=requests, domain=None): # we do not want to accept too many # argument on init. # we want tests # to be explicitely run. + self.fetcher = fetcher - self.eipconfig = eipconfig.EIPConfig() - self.defaultprovider = providers.LeapProviderDefinition() - self.eipserviceconfig = eipconfig.EIPServiceConfig() + # if not domain, get from config + self.domain = domain + + self.eipconfig = eipconfig.EIPConfig(domain=domain) + self.defaultprovider = providers.LeapProviderDefinition(domain=domain) + self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain) def run_all(self, checker=None, skip_download=False): """ @@ -330,7 +427,8 @@ class EIPConfigChecker(object): return True def fetch_definition(self, skip_download=False, - config=None, uri=None): + config=None, uri=None, + domain=None): """ fetches a definition file from server """ @@ -347,10 +445,13 @@ class EIPConfigChecker(object): if config is None: config = self.defaultprovider.config if uri is None: - domain = config.get('provider', None) + if not domain: + domain = config.get('provider', None) uri = self._get_provider_definition_uri(domain=domain) # FIXME! Pass ca path verify!!! + # BUG #638 + # FIXME FIXME FIXME self.defaultprovider.load( from_uri=uri, fetcher=self.fetcher, @@ -358,13 +459,14 @@ class EIPConfigChecker(object): self.defaultprovider.save() def fetch_eip_service_config(self, skip_download=False, - config=None, uri=None): + config=None, uri=None, domain=None): if skip_download: return True if config is None: config = self.eipserviceconfig.config if uri is None: - domain = config.get('provider', None) + if not domain: + domain = self.domain or config.get('provider', None) uri = self._get_eip_service_uri(domain=domain) self.eipserviceconfig.load(from_uri=uri, fetcher=self.fetcher) @@ -399,7 +501,7 @@ class EIPConfigChecker(object): def _get_provider_definition_uri(self, domain=None, path=None): if domain is None: - domain = baseconstants.DEFAULT_PROVIDER + domain = self.domain or baseconstants.DEFAULT_PROVIDER if path is None: path = baseconstants.DEFINITION_EXPECTED_PATH uri = u"https://%s/%s" % (domain, path) @@ -408,7 +510,7 @@ class EIPConfigChecker(object): def _get_eip_service_uri(self, domain=None, path=None): if domain is None: - domain = baseconstants.DEFAULT_PROVIDER + domain = self.domain or baseconstants.DEFAULT_PROVIDER if path is None: path = eipconstants.EIP_SERVICE_EXPECTED_PATH uri = "https://%s/%s" % (domain, path) |