diff options
Diffstat (limited to 'src/leap/eip/checks.py')
-rw-r--r-- | src/leap/eip/checks.py | 58 |
1 files changed, 50 insertions, 8 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py index b0fd6323..51a7e219 100644 --- a/src/leap/eip/checks.py +++ b/src/leap/eip/checks.py @@ -1,7 +1,7 @@ #import json import logging import ssl -#import os +import os logging.basicConfig() logger = logging.getLogger(name=__name__) @@ -64,7 +64,8 @@ class ProviderCertChecker(object): # For MVS checker.is_there_provider_ca() checker.is_https_working() - checker.download_new_client_cert() + checker.check_new_cert_needed() + #checker.download_new_client_cert() def download_ca_cert(self): # MVS+ @@ -103,7 +104,16 @@ class ProviderCertChecker(object): self.fetcher.get(uri, verify=verify) return True - def download_new_client_cert(self, uri=None, verify=True): + def check_new_cert_needed(self, skip_download=False): + if not self.is_cert_valid(do_raise=False): + self.download_new_client_cert(skip_download=skip_download) + return True + return False + + def download_new_client_cert(self, uri=None, verify=True, + skip_download=False): + if skip_download: + return True if uri is None: uri = self._get_client_cert_uri() # XXX raise InsecureURI or something better @@ -112,12 +122,39 @@ class ProviderCertChecker(object): verify = self.cacert req = self.fetcher.get(uri, verify=verify) pemfile_content = req.content - self.validate_pemfile(pemfile_content) + self.is_valid_pemfile(pemfile_content) cert_path = self._get_client_cert_path() self.write_cert(pemfile_content, to=cert_path) return True - def validate_pemfile(self, cert_s): + def is_cert_valid(self, cert_path=None, do_raise=True): + exists = lambda: self.is_certificate_exists() + valid_pemfile = lambda: self.is_valid_pemfile() + not_expired = lambda: self.is_cert_not_expired() + print 'exists?', exists + print 'valid', valid_pemfile + print 'not expired', not_expired + + valid = exists() and valid_pemfile() and not_expired() + if not valid: + if do_raise: + raise Exception('missing cert') + else: + return False + return True + + def is_certificate_exists(self, certfile=None): + if certfile is None: + certfile = self._get_client_cert_path() + return os.path.isfile(certfile) + + def is_cert_not_expired(self): + return True + # XXX TODO + # waiting on #507. If we're not using PyOpenSSL or anything alike + # we will have to roll our own x509 parsing to extract time info. + + def is_valid_pemfile(self, cert_s=None): """ checks that the passed string is a valid pem certificate @@ -125,6 +162,10 @@ class ProviderCertChecker(object): @type cert_s: string @rtype: bool """ + if cert_s is None: + certfile = self._get_client_cert_path() + with open(certfile) as cf: + cert_s = cf.read() try: # XXX get a real cert validation # so far this is only checking begin/end @@ -136,14 +177,15 @@ class ProviderCertChecker(object): return True def _get_client_cert_uri(self): - # XXX TODO - # get from provider definition? - pass + return "https://%s/cert/get" % (baseconstants.DEFAULT_TEST_PROVIDER) def _get_client_cert_path(self): # MVS+ : get provider path return eipspecs.client_cert_path() + def is_cert_still_valid(self): + raise NotImplementedError + def write_cert(self, pemfile_content, to=None): with open(to, 'w') as cert_f: cert_f.write(pemfile_content) |