diff options
Diffstat (limited to 'src/leap/eip/checks.py')
| -rw-r--r-- | src/leap/eip/checks.py | 52 | 
1 files changed, 35 insertions, 17 deletions
| diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py index b55f5827..cf758314 100644 --- a/src/leap/eip/checks.py +++ b/src/leap/eip/checks.py @@ -147,9 +147,10 @@ class ProviderCertChecker(object):          # For MVS          checker.is_there_provider_ca() -        checker.is_https_working() -        checker.check_new_cert_needed() -        #checker.download_new_client_cert() + +        # XXX FAKE IT!!! +        checker.is_https_working(verify=False) +        checker.check_new_cert_needed(verify=False)      def download_ca_cert(self):          # MVS+ @@ -184,7 +185,6 @@ class ProviderCertChecker(object):          # XXX raise InsecureURI or something better          logger.debug('is https working?')          logger.debug('uri: %s', uri) -        #import ipdb;ipdb.set_trace()          assert uri.startswith('https')          if verify is True and self.cacert is not None:              logger.debug('verify cert: %s', self.cacert) @@ -192,19 +192,26 @@ class ProviderCertChecker(object):          try:              self.fetcher.get(uri, verify=verify)          except requests.exceptions.SSLError: +            logger.debug('False!')              raise eipexceptions.EIPBadCertError          else:              logger.debug('True')              return True -    def check_new_cert_needed(self, skip_download=False): +    def check_new_cert_needed(self, skip_download=False, verify=True): +        logger.debug('is new cert needed?')          if not self.is_cert_valid(do_raise=False): -            self.download_new_client_cert(skip_download=skip_download) +            logger.debug('True') +            self.download_new_client_cert( +                skip_download=skip_download, +                verify=verify)              return True +        logger.debug('False')          return False      def download_new_client_cert(self, uri=None, verify=True,                                   skip_download=False): +        logger.debug('download new client cert')          if skip_download:              return True          if uri is None: @@ -213,20 +220,28 @@ class ProviderCertChecker(object):          assert uri.startswith('https')          if verify is True and self.cacert is not None:              verify = self.cacert -        req = self.fetcher.get(uri, verify=verify) -        pemfile_content = req.content -        self.is_valid_pemfile(pemfile_content) -        cert_path = self._get_client_cert_path() -        self.write_cert(pemfile_content, to=cert_path) +        try: +            req = self.fetcher.get(uri, verify=verify) +            req.raise_for_status() +        except requests.exceptions.SSLError: +            logger.warning('SSLError while fetching cert. ' +                           'Look below for stack trace.') +            # XXX raise better exception +            raise +        try: +            pemfile_content = req.content +            self.is_valid_pemfile(pemfile_content) +            cert_path = self._get_client_cert_path() +            self.write_cert(pemfile_content, to=cert_path) +        except: +            logger.warning('Error while validating cert') +            raise          return True      def is_cert_valid(self, cert_path=None, do_raise=True):          exists = lambda: self.is_certificate_exists()          valid_pemfile = lambda: self.is_valid_pemfile()          not_expired = lambda: self.is_cert_not_expired() -        #print 'exists?', exists -        #print 'valid', valid_pemfile -        #print 'not expired', not_expired          valid = exists() and valid_pemfile() and not_expired()          if not valid: @@ -268,6 +283,11 @@ class ProviderCertChecker(object):              # XXX use gnutls for get proper              # validation.              # crypto.X509Certificate(cert_s) +            sep = "-" * 5 + "BEGIN CERTIFICATE" + "-" * 5 +            # we might have private key and cert in the same file +            certparts = cert_s.split(sep) +            if len(certparts) > 1: +                cert_s = sep + certparts[1]              ssl.PEM_cert_to_DER_cert(cert_s)          except:              # XXX raise proper exception @@ -279,11 +299,10 @@ class ProviderCertChecker(object):      def _get_client_cert_uri(self):          # XXX get the whole thing from constants -        return "https://%s/cert/get" % (baseconstants.DEFAULT_PROVIDER) +        return "https://%s/1/cert" % (baseconstants.DEFAULT_PROVIDER)      def _get_client_cert_path(self):          # MVS+ : get provider path -        #import ipdb;ipdb.set_trace()          return eipspecs.client_cert_path()      def write_cert(self, pemfile_content, to=None): @@ -397,7 +416,6 @@ class EIPConfigChecker(object):              from_uri=uri,              fetcher=self.fetcher,              verify=False) -        #import ipdb;ipdb.set_trace()          self.defaultprovider.save()      def fetch_eip_service_config(self, skip_download=False, | 
