summaryrefslogtreecommitdiff
path: root/src/leap/eip/checks.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/eip/checks.py')
-rw-r--r--src/leap/eip/checks.py125
1 files changed, 87 insertions, 38 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py
index f368c551..cf758314 100644
--- a/src/leap/eip/checks.py
+++ b/src/leap/eip/checks.py
@@ -9,6 +9,8 @@ import netifaces
import ping
import requests
+from leap import __branding as BRANDING
+from leap import certs
from leap.base import constants as baseconstants
from leap.base import providers
from leap.eip import config as eipconfig
@@ -20,6 +22,11 @@ from leap.util.fileutil import mkdir_p
logger = logging.getLogger(name=__name__)
"""
+ProviderCertChecker
+-------------------
+Checks on certificates. To be moved to base.
+docs TBD
+
EIPConfigChecker
----------
It is used from the eip conductor (a instance of EIPConnection that is
@@ -36,14 +43,15 @@ LeapNetworkChecker
------------------
Network checks. To be moved to base.
docs TBD
-
-ProviderCertChecker
--------------------
-Checks on certificates.
-docs TBD
"""
+def get_ca_cert():
+ ca_file = BRANDING.get('provider_ca_file')
+ if ca_file:
+ return certs.where(ca_file)
+
+
class LeapNetworkChecker(object):
"""
all network related checks
@@ -67,6 +75,7 @@ class LeapNetworkChecker(object):
# XXX we probably should raise an exception here?
# unless we use this as smoke test
try:
+ # XXX remove this hardcoded random ip
requests.get('http://216.172.161.165')
except (requests.HTTPError, requests.RequestException) as e:
self.error = e.message
@@ -124,7 +133,7 @@ class ProviderCertChecker(object):
"""
def __init__(self, fetcher=requests):
self.fetcher = fetcher
- self.cacert = None
+ self.cacert = get_ca_cert()
def run_all(self, checker=None, skip_download=False):
if not checker:
@@ -138,9 +147,10 @@ class ProviderCertChecker(object):
# For MVS
checker.is_there_provider_ca()
- checker.is_https_working()
- checker.check_new_cert_needed()
- #checker.download_new_client_cert()
+
+ # XXX FAKE IT!!!
+ checker.is_https_working(verify=False)
+ checker.check_new_cert_needed(verify=False)
def download_ca_cert(self):
# MVS+
@@ -159,34 +169,49 @@ class ProviderCertChecker(object):
raise NotImplementedError
def is_there_provider_ca(self):
- # XXX fake it till you make it! :P
+ from leap import certs
+ logger.debug('do we have provider_ca?')
+ cacert_path = BRANDING.get('provider_ca_file', None)
+ if not cacert_path:
+ logger.debug('False')
+ return False
+ self.cacert = certs.where(cacert_path)
+ logger.debug('True')
return True
- # enable this when we have
- # a custom "branded" bundle
- # certs package.
- try:
- from leap.custom import certs
- except ImportError:
- raise
- self.cacert = certs.where('cacert.pem')
-
def is_https_working(self, uri=None, verify=True):
+ if uri is None:
+ uri = self._get_root_uri()
# XXX raise InsecureURI or something better
+ logger.debug('is https working?')
+ logger.debug('uri: %s', uri)
assert uri.startswith('https')
if verify is True and self.cacert is not None:
+ logger.debug('verify cert: %s', self.cacert)
verify = self.cacert
- self.fetcher.get(uri, verify=verify)
- return True
+ try:
+ self.fetcher.get(uri, verify=verify)
+ except requests.exceptions.SSLError:
+ logger.debug('False!')
+ raise eipexceptions.EIPBadCertError
+ else:
+ logger.debug('True')
+ return True
- def check_new_cert_needed(self, skip_download=False):
+ def check_new_cert_needed(self, skip_download=False, verify=True):
+ logger.debug('is new cert needed?')
if not self.is_cert_valid(do_raise=False):
- self.download_new_client_cert(skip_download=skip_download)
+ logger.debug('True')
+ self.download_new_client_cert(
+ skip_download=skip_download,
+ verify=verify)
return True
+ logger.debug('False')
return False
def download_new_client_cert(self, uri=None, verify=True,
skip_download=False):
+ logger.debug('download new client cert')
if skip_download:
return True
if uri is None:
@@ -195,20 +220,28 @@ class ProviderCertChecker(object):
assert uri.startswith('https')
if verify is True and self.cacert is not None:
verify = self.cacert
- req = self.fetcher.get(uri, verify=verify)
- pemfile_content = req.content
- self.is_valid_pemfile(pemfile_content)
- cert_path = self._get_client_cert_path()
- self.write_cert(pemfile_content, to=cert_path)
+ try:
+ req = self.fetcher.get(uri, verify=verify)
+ req.raise_for_status()
+ except requests.exceptions.SSLError:
+ logger.warning('SSLError while fetching cert. '
+ 'Look below for stack trace.')
+ # XXX raise better exception
+ raise
+ try:
+ pemfile_content = req.content
+ self.is_valid_pemfile(pemfile_content)
+ cert_path = self._get_client_cert_path()
+ self.write_cert(pemfile_content, to=cert_path)
+ except:
+ logger.warning('Error while validating cert')
+ raise
return True
def is_cert_valid(self, cert_path=None, do_raise=True):
exists = lambda: self.is_certificate_exists()
valid_pemfile = lambda: self.is_valid_pemfile()
not_expired = lambda: self.is_cert_not_expired()
- #print 'exists?', exists
- #print 'valid', valid_pemfile
- #print 'not expired', not_expired
valid = exists() and valid_pemfile() and not_expired()
if not valid:
@@ -250,18 +283,26 @@ class ProviderCertChecker(object):
# XXX use gnutls for get proper
# validation.
# crypto.X509Certificate(cert_s)
+ sep = "-" * 5 + "BEGIN CERTIFICATE" + "-" * 5
+ # we might have private key and cert in the same file
+ certparts = cert_s.split(sep)
+ if len(certparts) > 1:
+ cert_s = sep + certparts[1]
ssl.PEM_cert_to_DER_cert(cert_s)
except:
# XXX raise proper exception
raise
return True
+ def _get_root_uri(self):
+ return u"https://%s/" % baseconstants.DEFAULT_PROVIDER
+
def _get_client_cert_uri(self):
- return "https://%s/cert/get" % (baseconstants.DEFAULT_TEST_PROVIDER)
+ # XXX get the whole thing from constants
+ return "https://%s/1/cert" % (baseconstants.DEFAULT_PROVIDER)
def _get_client_cert_path(self):
# MVS+ : get provider path
- #import ipdb;ipdb.set_trace()
return eipspecs.client_cert_path()
def write_cert(self, pemfile_content, to=None):
@@ -370,7 +411,11 @@ class EIPConfigChecker(object):
domain = config.get('provider', None)
uri = self._get_provider_definition_uri(domain=domain)
- self.defaultprovider.load(from_uri=uri, fetcher=self.fetcher)
+ # FIXME! Pass ca path verify!!!
+ self.defaultprovider.load(
+ from_uri=uri,
+ fetcher=self.fetcher,
+ verify=False)
self.defaultprovider.save()
def fetch_eip_service_config(self, skip_download=False,
@@ -414,14 +459,18 @@ class EIPConfigChecker(object):
def _get_provider_definition_uri(self, domain=None, path=None):
if domain is None:
- domain = baseconstants.DEFAULT_TEST_PROVIDER
+ domain = baseconstants.DEFAULT_PROVIDER
if path is None:
path = baseconstants.DEFINITION_EXPECTED_PATH
- return "https://%s/%s" % (domain, path)
+ uri = u"https://%s/%s" % (domain, path)
+ logger.debug('getting provider definition from %s' % uri)
+ return uri
def _get_eip_service_uri(self, domain=None, path=None):
if domain is None:
- domain = baseconstants.DEFAULT_TEST_PROVIDER
+ domain = baseconstants.DEFAULT_PROVIDER
if path is None:
path = eipconstants.EIP_SERVICE_EXPECTED_PATH
- return "https://%s/%s" % (domain, path)
+ uri = "https://%s/%s" % (domain, path)
+ logger.debug('getting eip service file from %s', uri)
+ return uri