summaryrefslogtreecommitdiff
path: root/src/leap/eip/checks.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/eip/checks.py')
-rw-r--r--src/leap/eip/checks.py198
1 files changed, 150 insertions, 48 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py
index f739c3e8..116c535e 100644
--- a/src/leap/eip/checks.py
+++ b/src/leap/eip/checks.py
@@ -4,15 +4,18 @@ import ssl
import time
import os
-from gnutls import crypto
+import gnutls.crypto
#import netifaces
#import ping
import requests
from leap import __branding as BRANDING
-from leap import certs
+from leap import certs as leapcerts
+from leap.base.auth import srpauth_protected, magick_srpauth
+from leap.base import config as baseconfig
from leap.base import constants as baseconstants
from leap.base import providers
+from leap.crypto import certs
from leap.eip import config as eipconfig
from leap.eip import constants as eipconstants
from leap.eip import exceptions as eipexceptions
@@ -42,10 +45,11 @@ reachable and testable as a whole.
"""
-def get_ca_cert():
+def get_branding_ca_cert(domain):
+ # XXX deprecated
ca_file = BRANDING.get('provider_ca_file')
if ca_file:
- return certs.where(ca_file)
+ return leapcerts.where(ca_file)
class ProviderCertChecker(object):
@@ -54,18 +58,25 @@ class ProviderCertChecker(object):
client certs and checking tls connection
with provider.
"""
- def __init__(self, fetcher=requests):
+ def __init__(self, fetcher=requests,
+ domain=None):
+
self.fetcher = fetcher
- self.cacert = get_ca_cert()
+ self.domain = domain
+ self.cacert = eipspecs.provider_ca_path(domain)
+
+ def run_all(
+ self, checker=None,
+ skip_download=False, skip_verify=False):
- def run_all(self, checker=None, skip_download=False, skip_verify=False):
if not checker:
checker = self
do_verify = not skip_verify
logger.debug('do_verify: %s', do_verify)
- # For MVS+
# checker.download_ca_cert()
+
+ # For MVS+
# checker.download_ca_signature()
# checker.get_ca_signatures()
# checker.is_there_trust_path()
@@ -74,12 +85,44 @@ class ProviderCertChecker(object):
checker.is_there_provider_ca()
# XXX FAKE IT!!!
- checker.is_https_working(verify=do_verify)
+ checker.is_https_working(verify=do_verify, autocacert=True)
checker.check_new_cert_needed(verify=do_verify)
- def download_ca_cert(self):
- # MVS+
- raise NotImplementedError
+ def download_ca_cert(self, uri=None, verify=True):
+ req = self.fetcher.get(uri, verify=verify)
+ req.raise_for_status()
+
+ # should check domain exists
+ capath = self._get_ca_cert_path(self.domain)
+ with open(capath, 'w') as f:
+ f.write(req.content)
+
+ def check_ca_cert_fingerprint(
+ self, hash_type="SHA256",
+ fingerprint=None):
+ """
+ compares the fingerprint in
+ the ca cert with a string
+ we are passed
+ returns True if they are equal, False if not.
+ @param hash_type: digest function
+ @type hash_type: str
+ @param fingerprint: the fingerprint to compare with.
+ @type fingerprint: str (with : separator)
+ @rtype bool
+ """
+ ca_cert_path = self.ca_cert_path
+ ca_cert_fpr = certs.get_cert_fingerprint(
+ filepath=ca_cert_path)
+ return ca_cert_fpr == fingerprint
+
+ def verify_api_https(self, uri):
+ assert uri.startswith('https://')
+ cacert = self.ca_cert_path
+ verify = cacert and cacert or True
+ req = self.fetcher.get(uri, verify=verify)
+ req.raise_for_status()
+ return True
def download_ca_signature(self):
# MVS+
@@ -94,36 +137,47 @@ class ProviderCertChecker(object):
raise NotImplementedError
def is_there_provider_ca(self):
- from leap import certs
- logger.debug('do we have provider_ca?')
- cacert_path = BRANDING.get('provider_ca_file', None)
- if not cacert_path:
- logger.debug('False')
+ if not self.cacert:
return False
- self.cacert = certs.where(cacert_path)
- logger.debug('True')
- return True
+ cacert_exists = os.path.isfile(self.cacert)
+ if cacert_exists:
+ logger.debug('True')
+ return True
+ logger.debug('False!')
+ return False
- def is_https_working(self, uri=None, verify=True):
+ def is_https_working(
+ self, uri=None, verify=True,
+ autocacert=False):
if uri is None:
uri = self._get_root_uri()
# XXX raise InsecureURI or something better
- assert uri.startswith('https')
- if verify is True and self.cacert is not None:
+ try:
+ assert uri.startswith('https')
+ except AssertionError:
+ raise AssertionError(
+ "uri passed should start with https")
+ if autocacert and verify is True and self.cacert is not None:
logger.debug('verify cert: %s', self.cacert)
verify = self.cacert
+ #import pdb4qt; pdb4qt.set_trace()
logger.debug('is https working?')
logger.debug('uri: %s (verify:%s)', uri, verify)
try:
self.fetcher.get(uri, verify=verify)
+
except requests.exceptions.SSLError as exc:
- logger.warning('False! CERT VERIFICATION FAILED! '
+ logger.error("SSLError")
+ # XXX RAISE! See #638
+ #raise eipexceptions.HttpsBadCertError
+ logger.warning('BUG #638 CERT VERIFICATION FAILED! '
'(this should be CRITICAL)')
logger.warning('SSLError: %s', exc.message)
- # XXX RAISE! See #638
- #raise eipexceptions.EIPBadCertError
- # XXX get requests.exceptions.ConnectionError Errno 110
- # Connection timed out, and raise ours.
+
+ except requests.exceptions.ConnectionError:
+ logger.error('ConnectionError')
+ raise eipexceptions.HttpsNotSupported
+
else:
logger.debug('True')
return True
@@ -140,7 +194,8 @@ class ProviderCertChecker(object):
return False
def download_new_client_cert(self, uri=None, verify=True,
- skip_download=False):
+ skip_download=False,
+ credentials=None):
logger.debug('download new client cert')
if skip_download:
return True
@@ -148,18 +203,38 @@ class ProviderCertChecker(object):
uri = self._get_client_cert_uri()
# XXX raise InsecureURI or something better
assert uri.startswith('https')
+
if verify is True and self.cacert is not None:
verify = self.cacert
+
+ fgetfn = self.fetcher.get
+
+ if credentials:
+ user, passwd = credentials
+
+ logger.debug('domain = %s', self.domain)
+
+ @srpauth_protected(user, passwd,
+ server="https://%s" % self.domain,
+ verify=verify)
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
+
+ else:
+ # XXX FIXME fix decorated args
+ @magick_srpauth(verify)
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
try:
+
# XXX FIXME!!!!
# verify=verify
# Workaround for #638. return to verification
# when That's done!!!
-
- # XXX HOOK SRP here...
- # will have to be more generic in the future.
- req = self.fetcher.get(uri, verify=False)
+ #req = self.fetcher.get(uri, verify=False)
+ req = getfn(uri, verify=False)
req.raise_for_status()
+
except requests.exceptions.SSLError:
logger.warning('SSLError while fetching cert. '
'Look below for stack trace.')
@@ -198,7 +273,7 @@ class ProviderCertChecker(object):
certfile = self._get_client_cert_path()
with open(certfile) as cf:
cert_s = cf.read()
- cert = crypto.X509Certificate(cert_s)
+ cert = gnutls.crypto.X509Certificate(cert_s)
from_ = time.gmtime(cert.activation_time)
to_ = time.gmtime(cert.expiration_time)
return from_ < now() < to_
@@ -233,16 +308,34 @@ class ProviderCertChecker(object):
raise
return True
+ @property
+ def ca_cert_path(self):
+ return self._get_ca_cert_path(self.domain)
+
def _get_root_uri(self):
- return u"https://%s/" % baseconstants.DEFAULT_PROVIDER
+ return u"https://%s/" % self.domain
def _get_client_cert_uri(self):
# XXX get the whole thing from constants
- return "https://%s/1/cert" % (baseconstants.DEFAULT_PROVIDER)
+ return "https://%s/1/cert" % self.domain
def _get_client_cert_path(self):
- # MVS+ : get provider path
- return eipspecs.client_cert_path()
+ return eipspecs.client_cert_path(domain=self.domain)
+
+ def _get_ca_cert_path(self, domain):
+ # XXX this folder path will be broken for win
+ # and this should be moved to eipspecs.ca_path
+
+ # XXX use baseconfig.get_provider_path(folder=Foo)
+ # !!!
+
+ capath = baseconfig.get_config_file(
+ 'cacert.pem',
+ folder='providers/%s/keys/ca' % domain)
+ folder, fname = os.path.split(capath)
+ if not os.path.isdir(folder):
+ mkdir_p(folder)
+ return capath
def write_cert(self, pemfile_content, to=None):
folder, filename = os.path.split(to)
@@ -260,16 +353,20 @@ class EIPConfigChecker(object):
use run_all to run all checks.
"""
- def __init__(self, fetcher=requests):
+ def __init__(self, fetcher=requests, domain=None):
# we do not want to accept too many
# argument on init.
# we want tests
# to be explicitely run.
+
self.fetcher = fetcher
- self.eipconfig = eipconfig.EIPConfig()
- self.defaultprovider = providers.LeapProviderDefinition()
- self.eipserviceconfig = eipconfig.EIPServiceConfig()
+ # if not domain, get from config
+ self.domain = domain
+
+ self.eipconfig = eipconfig.EIPConfig(domain=domain)
+ self.defaultprovider = providers.LeapProviderDefinition(domain=domain)
+ self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain)
def run_all(self, checker=None, skip_download=False):
"""
@@ -330,7 +427,8 @@ class EIPConfigChecker(object):
return True
def fetch_definition(self, skip_download=False,
- config=None, uri=None):
+ config=None, uri=None,
+ domain=None):
"""
fetches a definition file from server
"""
@@ -347,10 +445,13 @@ class EIPConfigChecker(object):
if config is None:
config = self.defaultprovider.config
if uri is None:
- domain = config.get('provider', None)
+ if not domain:
+ domain = config.get('provider', None)
uri = self._get_provider_definition_uri(domain=domain)
# FIXME! Pass ca path verify!!!
+ # BUG #638
+ # FIXME FIXME FIXME
self.defaultprovider.load(
from_uri=uri,
fetcher=self.fetcher,
@@ -358,13 +459,14 @@ class EIPConfigChecker(object):
self.defaultprovider.save()
def fetch_eip_service_config(self, skip_download=False,
- config=None, uri=None):
+ config=None, uri=None, domain=None):
if skip_download:
return True
if config is None:
config = self.eipserviceconfig.config
if uri is None:
- domain = config.get('provider', None)
+ if not domain:
+ domain = self.domain or config.get('provider', None)
uri = self._get_eip_service_uri(domain=domain)
self.eipserviceconfig.load(from_uri=uri, fetcher=self.fetcher)
@@ -399,7 +501,7 @@ class EIPConfigChecker(object):
def _get_provider_definition_uri(self, domain=None, path=None):
if domain is None:
- domain = baseconstants.DEFAULT_PROVIDER
+ domain = self.domain or baseconstants.DEFAULT_PROVIDER
if path is None:
path = baseconstants.DEFINITION_EXPECTED_PATH
uri = u"https://%s/%s" % (domain, path)
@@ -408,7 +510,7 @@ class EIPConfigChecker(object):
def _get_eip_service_uri(self, domain=None, path=None):
if domain is None:
- domain = baseconstants.DEFAULT_PROVIDER
+ domain = self.domain or baseconstants.DEFAULT_PROVIDER
if path is None:
path = eipconstants.EIP_SERVICE_EXPECTED_PATH
uri = "https://%s/%s" % (domain, path)