summaryrefslogtreecommitdiff
path: root/src/leap/eip
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/eip')
-rw-r--r--src/leap/eip/checks.py58
-rw-r--r--src/leap/eip/tests/test_checks.py24
2 files changed, 71 insertions, 11 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py
index b0fd6323..51a7e219 100644
--- a/src/leap/eip/checks.py
+++ b/src/leap/eip/checks.py
@@ -1,7 +1,7 @@
#import json
import logging
import ssl
-#import os
+import os
logging.basicConfig()
logger = logging.getLogger(name=__name__)
@@ -64,7 +64,8 @@ class ProviderCertChecker(object):
# For MVS
checker.is_there_provider_ca()
checker.is_https_working()
- checker.download_new_client_cert()
+ checker.check_new_cert_needed()
+ #checker.download_new_client_cert()
def download_ca_cert(self):
# MVS+
@@ -103,7 +104,16 @@ class ProviderCertChecker(object):
self.fetcher.get(uri, verify=verify)
return True
- def download_new_client_cert(self, uri=None, verify=True):
+ def check_new_cert_needed(self, skip_download=False):
+ if not self.is_cert_valid(do_raise=False):
+ self.download_new_client_cert(skip_download=skip_download)
+ return True
+ return False
+
+ def download_new_client_cert(self, uri=None, verify=True,
+ skip_download=False):
+ if skip_download:
+ return True
if uri is None:
uri = self._get_client_cert_uri()
# XXX raise InsecureURI or something better
@@ -112,12 +122,39 @@ class ProviderCertChecker(object):
verify = self.cacert
req = self.fetcher.get(uri, verify=verify)
pemfile_content = req.content
- self.validate_pemfile(pemfile_content)
+ self.is_valid_pemfile(pemfile_content)
cert_path = self._get_client_cert_path()
self.write_cert(pemfile_content, to=cert_path)
return True
- def validate_pemfile(self, cert_s):
+ def is_cert_valid(self, cert_path=None, do_raise=True):
+ exists = lambda: self.is_certificate_exists()
+ valid_pemfile = lambda: self.is_valid_pemfile()
+ not_expired = lambda: self.is_cert_not_expired()
+ print 'exists?', exists
+ print 'valid', valid_pemfile
+ print 'not expired', not_expired
+
+ valid = exists() and valid_pemfile() and not_expired()
+ if not valid:
+ if do_raise:
+ raise Exception('missing cert')
+ else:
+ return False
+ return True
+
+ def is_certificate_exists(self, certfile=None):
+ if certfile is None:
+ certfile = self._get_client_cert_path()
+ return os.path.isfile(certfile)
+
+ def is_cert_not_expired(self):
+ return True
+ # XXX TODO
+ # waiting on #507. If we're not using PyOpenSSL or anything alike
+ # we will have to roll our own x509 parsing to extract time info.
+
+ def is_valid_pemfile(self, cert_s=None):
"""
checks that the passed string
is a valid pem certificate
@@ -125,6 +162,10 @@ class ProviderCertChecker(object):
@type cert_s: string
@rtype: bool
"""
+ if cert_s is None:
+ certfile = self._get_client_cert_path()
+ with open(certfile) as cf:
+ cert_s = cf.read()
try:
# XXX get a real cert validation
# so far this is only checking begin/end
@@ -136,14 +177,15 @@ class ProviderCertChecker(object):
return True
def _get_client_cert_uri(self):
- # XXX TODO
- # get from provider definition?
- pass
+ return "https://%s/cert/get" % (baseconstants.DEFAULT_TEST_PROVIDER)
def _get_client_cert_path(self):
# MVS+ : get provider path
return eipspecs.client_cert_path()
+ def is_cert_still_valid(self):
+ raise NotImplementedError
+
def write_cert(self, pemfile_content, to=None):
with open(to, 'w') as cert_f:
cert_f.write(pemfile_content)
diff --git a/src/leap/eip/tests/test_checks.py b/src/leap/eip/tests/test_checks.py
index 541b884b..09fdaabf 100644
--- a/src/leap/eip/tests/test_checks.py
+++ b/src/leap/eip/tests/test_checks.py
@@ -199,7 +199,7 @@ class ProviderCertCheckerTest(BaseLeapTest):
self.assertTrue(hasattr(checker, "is_there_provider_ca"),
"missing meth")
self.assertTrue(hasattr(checker, "is_https_working"), "missing meth")
- self.assertTrue(hasattr(checker, "download_new_client_cert"),
+ self.assertTrue(hasattr(checker, "check_new_cert_needed"),
"missing meth")
def test_checker_should_actually_call_all_tests(self):
@@ -217,7 +217,7 @@ class ProviderCertCheckerTest(BaseLeapTest):
self.assertTrue(mc.is_there_provider_ca.called, "not called")
self.assertTrue(mc.is_https_working.called,
"not called")
- self.assertTrue(mc.download_new_client_cert.called,
+ self.assertTrue(mc.check_new_cert_needed.called,
"not called")
# test individual check methods
@@ -233,6 +233,7 @@ class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase):
responses = {
'/': ['OK', ''],
'/client.cert': [
+ # XXX get sample cert
'-----BEGIN CERTIFICATE-----',
'-----END CERTIFICATE-----'],
'/badclient.cert': [
@@ -301,13 +302,30 @@ class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase):
uri=uri, verify=cacert))
# did we write cert to its path?
- self.assertTrue(os.path.isfile(eipspecs.client_cert_path()))
+ clientcertfile = eipspecs.client_cert_path()
+ self.assertTrue(os.path.isfile(clientcertfile))
certfile = eipspecs.client_cert_path()
with open(certfile, 'r') as cf:
certcontent = cf.read()
self.assertEqual(certcontent,
'\n'.join(
self.request_handler.responses['/client.cert']))
+ os.remove(clientcertfile)
+
+ def test_is_cert_valid(self):
+ checker = eipchecks.ProviderCertChecker()
+ # TODO: better exception catching
+ with self.assertRaises(Exception) as exc:
+ self.assertFalse(checker.is_cert_valid())
+ exc.message = "missing cert"
+
+ def test_check_new_cert_needed(self):
+ # check: missing cert
+ checker = eipchecks.ProviderCertChecker()
+ self.assertTrue(checker.check_new_cert_needed(skip_download=True))
+ # TODO check: malformed cert
+ # TODO check: expired cert
+ # TODO check: pass test server uri instead of skip
if __name__ == "__main__":