diff options
Diffstat (limited to 'src/leap/eip')
-rw-r--r-- | src/leap/eip/checks.py | 58 | ||||
-rw-r--r-- | src/leap/eip/tests/test_checks.py | 24 |
2 files changed, 71 insertions, 11 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py index b0fd6323..51a7e219 100644 --- a/src/leap/eip/checks.py +++ b/src/leap/eip/checks.py @@ -1,7 +1,7 @@ #import json import logging import ssl -#import os +import os logging.basicConfig() logger = logging.getLogger(name=__name__) @@ -64,7 +64,8 @@ class ProviderCertChecker(object): # For MVS checker.is_there_provider_ca() checker.is_https_working() - checker.download_new_client_cert() + checker.check_new_cert_needed() + #checker.download_new_client_cert() def download_ca_cert(self): # MVS+ @@ -103,7 +104,16 @@ class ProviderCertChecker(object): self.fetcher.get(uri, verify=verify) return True - def download_new_client_cert(self, uri=None, verify=True): + def check_new_cert_needed(self, skip_download=False): + if not self.is_cert_valid(do_raise=False): + self.download_new_client_cert(skip_download=skip_download) + return True + return False + + def download_new_client_cert(self, uri=None, verify=True, + skip_download=False): + if skip_download: + return True if uri is None: uri = self._get_client_cert_uri() # XXX raise InsecureURI or something better @@ -112,12 +122,39 @@ class ProviderCertChecker(object): verify = self.cacert req = self.fetcher.get(uri, verify=verify) pemfile_content = req.content - self.validate_pemfile(pemfile_content) + self.is_valid_pemfile(pemfile_content) cert_path = self._get_client_cert_path() self.write_cert(pemfile_content, to=cert_path) return True - def validate_pemfile(self, cert_s): + def is_cert_valid(self, cert_path=None, do_raise=True): + exists = lambda: self.is_certificate_exists() + valid_pemfile = lambda: self.is_valid_pemfile() + not_expired = lambda: self.is_cert_not_expired() + print 'exists?', exists + print 'valid', valid_pemfile + print 'not expired', not_expired + + valid = exists() and valid_pemfile() and not_expired() + if not valid: + if do_raise: + raise Exception('missing cert') + else: + return False + return True + + def is_certificate_exists(self, certfile=None): + if certfile is None: + certfile = self._get_client_cert_path() + return os.path.isfile(certfile) + + def is_cert_not_expired(self): + return True + # XXX TODO + # waiting on #507. If we're not using PyOpenSSL or anything alike + # we will have to roll our own x509 parsing to extract time info. + + def is_valid_pemfile(self, cert_s=None): """ checks that the passed string is a valid pem certificate @@ -125,6 +162,10 @@ class ProviderCertChecker(object): @type cert_s: string @rtype: bool """ + if cert_s is None: + certfile = self._get_client_cert_path() + with open(certfile) as cf: + cert_s = cf.read() try: # XXX get a real cert validation # so far this is only checking begin/end @@ -136,14 +177,15 @@ class ProviderCertChecker(object): return True def _get_client_cert_uri(self): - # XXX TODO - # get from provider definition? - pass + return "https://%s/cert/get" % (baseconstants.DEFAULT_TEST_PROVIDER) def _get_client_cert_path(self): # MVS+ : get provider path return eipspecs.client_cert_path() + def is_cert_still_valid(self): + raise NotImplementedError + def write_cert(self, pemfile_content, to=None): with open(to, 'w') as cert_f: cert_f.write(pemfile_content) diff --git a/src/leap/eip/tests/test_checks.py b/src/leap/eip/tests/test_checks.py index 541b884b..09fdaabf 100644 --- a/src/leap/eip/tests/test_checks.py +++ b/src/leap/eip/tests/test_checks.py @@ -199,7 +199,7 @@ class ProviderCertCheckerTest(BaseLeapTest): self.assertTrue(hasattr(checker, "is_there_provider_ca"), "missing meth") self.assertTrue(hasattr(checker, "is_https_working"), "missing meth") - self.assertTrue(hasattr(checker, "download_new_client_cert"), + self.assertTrue(hasattr(checker, "check_new_cert_needed"), "missing meth") def test_checker_should_actually_call_all_tests(self): @@ -217,7 +217,7 @@ class ProviderCertCheckerTest(BaseLeapTest): self.assertTrue(mc.is_there_provider_ca.called, "not called") self.assertTrue(mc.is_https_working.called, "not called") - self.assertTrue(mc.download_new_client_cert.called, + self.assertTrue(mc.check_new_cert_needed.called, "not called") # test individual check methods @@ -233,6 +233,7 @@ class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase): responses = { '/': ['OK', ''], '/client.cert': [ + # XXX get sample cert '-----BEGIN CERTIFICATE-----', '-----END CERTIFICATE-----'], '/badclient.cert': [ @@ -301,13 +302,30 @@ class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase): uri=uri, verify=cacert)) # did we write cert to its path? - self.assertTrue(os.path.isfile(eipspecs.client_cert_path())) + clientcertfile = eipspecs.client_cert_path() + self.assertTrue(os.path.isfile(clientcertfile)) certfile = eipspecs.client_cert_path() with open(certfile, 'r') as cf: certcontent = cf.read() self.assertEqual(certcontent, '\n'.join( self.request_handler.responses['/client.cert'])) + os.remove(clientcertfile) + + def test_is_cert_valid(self): + checker = eipchecks.ProviderCertChecker() + # TODO: better exception catching + with self.assertRaises(Exception) as exc: + self.assertFalse(checker.is_cert_valid()) + exc.message = "missing cert" + + def test_check_new_cert_needed(self): + # check: missing cert + checker = eipchecks.ProviderCertChecker() + self.assertTrue(checker.check_new_cert_needed(skip_download=True)) + # TODO check: malformed cert + # TODO check: expired cert + # TODO check: pass test server uri instead of skip if __name__ == "__main__": |