summaryrefslogtreecommitdiff
path: root/src/leap/eip/checks.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/eip/checks.py')
-rw-r--r--src/leap/eip/checks.py58
1 files changed, 50 insertions, 8 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py
index b0fd6323..51a7e219 100644
--- a/src/leap/eip/checks.py
+++ b/src/leap/eip/checks.py
@@ -1,7 +1,7 @@
#import json
import logging
import ssl
-#import os
+import os
logging.basicConfig()
logger = logging.getLogger(name=__name__)
@@ -64,7 +64,8 @@ class ProviderCertChecker(object):
# For MVS
checker.is_there_provider_ca()
checker.is_https_working()
- checker.download_new_client_cert()
+ checker.check_new_cert_needed()
+ #checker.download_new_client_cert()
def download_ca_cert(self):
# MVS+
@@ -103,7 +104,16 @@ class ProviderCertChecker(object):
self.fetcher.get(uri, verify=verify)
return True
- def download_new_client_cert(self, uri=None, verify=True):
+ def check_new_cert_needed(self, skip_download=False):
+ if not self.is_cert_valid(do_raise=False):
+ self.download_new_client_cert(skip_download=skip_download)
+ return True
+ return False
+
+ def download_new_client_cert(self, uri=None, verify=True,
+ skip_download=False):
+ if skip_download:
+ return True
if uri is None:
uri = self._get_client_cert_uri()
# XXX raise InsecureURI or something better
@@ -112,12 +122,39 @@ class ProviderCertChecker(object):
verify = self.cacert
req = self.fetcher.get(uri, verify=verify)
pemfile_content = req.content
- self.validate_pemfile(pemfile_content)
+ self.is_valid_pemfile(pemfile_content)
cert_path = self._get_client_cert_path()
self.write_cert(pemfile_content, to=cert_path)
return True
- def validate_pemfile(self, cert_s):
+ def is_cert_valid(self, cert_path=None, do_raise=True):
+ exists = lambda: self.is_certificate_exists()
+ valid_pemfile = lambda: self.is_valid_pemfile()
+ not_expired = lambda: self.is_cert_not_expired()
+ print 'exists?', exists
+ print 'valid', valid_pemfile
+ print 'not expired', not_expired
+
+ valid = exists() and valid_pemfile() and not_expired()
+ if not valid:
+ if do_raise:
+ raise Exception('missing cert')
+ else:
+ return False
+ return True
+
+ def is_certificate_exists(self, certfile=None):
+ if certfile is None:
+ certfile = self._get_client_cert_path()
+ return os.path.isfile(certfile)
+
+ def is_cert_not_expired(self):
+ return True
+ # XXX TODO
+ # waiting on #507. If we're not using PyOpenSSL or anything alike
+ # we will have to roll our own x509 parsing to extract time info.
+
+ def is_valid_pemfile(self, cert_s=None):
"""
checks that the passed string
is a valid pem certificate
@@ -125,6 +162,10 @@ class ProviderCertChecker(object):
@type cert_s: string
@rtype: bool
"""
+ if cert_s is None:
+ certfile = self._get_client_cert_path()
+ with open(certfile) as cf:
+ cert_s = cf.read()
try:
# XXX get a real cert validation
# so far this is only checking begin/end
@@ -136,14 +177,15 @@ class ProviderCertChecker(object):
return True
def _get_client_cert_uri(self):
- # XXX TODO
- # get from provider definition?
- pass
+ return "https://%s/cert/get" % (baseconstants.DEFAULT_TEST_PROVIDER)
def _get_client_cert_path(self):
# MVS+ : get provider path
return eipspecs.client_cert_path()
+ def is_cert_still_valid(self):
+ raise NotImplementedError
+
def write_cert(self, pemfile_content, to=None):
with open(to, 'w') as cert_f:
cert_f.write(pemfile_content)