summaryrefslogtreecommitdiff
path: root/src/leap/eip/checks.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/eip/checks.py')
-rw-r--r--src/leap/eip/checks.py67
1 files changed, 52 insertions, 15 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py
index 74afd677..635308bb 100644
--- a/src/leap/eip/checks.py
+++ b/src/leap/eip/checks.py
@@ -11,6 +11,7 @@ import requests
from leap import __branding as BRANDING
from leap import certs as leapcerts
+from leap.base.auth import srpauth_protected
from leap.base import config as baseconfig
from leap.base import constants as baseconstants
from leap.base import providers
@@ -98,6 +99,17 @@ class ProviderCertChecker(object):
def check_ca_cert_fingerprint(
self, hash_type="SHA256",
fingerprint=None):
+ """
+ compares the fingerprint in
+ the ca cert with a string
+ we are passed
+ returns True if they are equal, False if not.
+ @param hash_type: digest function
+ @type hash_type: str
+ @param fingerprint: the fingerprint to compare with.
+ @type fingerprint: str (with : separator)
+ @rtype bool
+ """
ca_cert_path = self.ca_cert_path
ca_cert_fpr = certs.get_cert_fingerprint(
filepath=ca_cert_path)
@@ -185,7 +197,8 @@ class ProviderCertChecker(object):
return False
def download_new_client_cert(self, uri=None, verify=True,
- skip_download=False):
+ skip_download=False,
+ credentials=None):
logger.debug('download new client cert')
if skip_download:
return True
@@ -193,18 +206,34 @@ class ProviderCertChecker(object):
uri = self._get_client_cert_uri()
# XXX raise InsecureURI or something better
assert uri.startswith('https')
+
if verify is True and self.cacert is not None:
verify = self.cacert
+
+ fgetfn = self.fetcher.get
+
+ if credentials:
+ user, passwd = credentials
+
+ @srpauth_protected(user, passwd)
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
+
+ else:
+ # XXX use magic_srpauth decorator instead,
+ # merge with the branch above
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
try:
+
# XXX FIXME!!!!
# verify=verify
# Workaround for #638. return to verification
# when That's done!!!
-
- # XXX HOOK SRP here...
- # will have to be more generic in the future.
- req = self.fetcher.get(uri, verify=False)
+ #req = self.fetcher.get(uri, verify=False)
+ req = getfn(uri, verify=False)
req.raise_for_status()
+
except requests.exceptions.SSLError:
logger.warning('SSLError while fetching cert. '
'Look below for stack trace.')
@@ -283,23 +312,26 @@ class ProviderCertChecker(object):
return self._get_ca_cert_path(self.domain)
def _get_root_uri(self):
- return u"https://%s/" % baseconstants.DEFAULT_PROVIDER
+ return u"https://%s/" % self.domain
def _get_client_cert_uri(self):
# XXX get the whole thing from constants
- return "https://%s/1/cert" % (baseconstants.DEFAULT_PROVIDER)
+ return "https://%s/1/cert" % self.domain
def _get_client_cert_path(self):
# MVS+ : get provider path
- return eipspecs.client_cert_path()
+ return eipspecs.client_cert_path(domain=self.domain)
def _get_ca_cert_path(self, domain):
# XXX this folder path will be broken for win
# and this should be moved to eipspecs.ca_path
+ # XXX use baseconfig.get_provider_path(folder=Foo)
+ # !!!
+
capath = baseconfig.get_config_file(
'cacert.pem',
- folder='providers/%s/certs/ca' % domain)
+ folder='providers/%s/keys/ca' % domain)
folder, fname = os.path.split(capath)
if not os.path.isdir(folder):
mkdir_p(folder)
@@ -321,16 +353,20 @@ class EIPConfigChecker(object):
use run_all to run all checks.
"""
- def __init__(self, fetcher=requests):
+ def __init__(self, fetcher=requests, domain=None):
# we do not want to accept too many
# argument on init.
# we want tests
# to be explicitely run.
+
self.fetcher = fetcher
- self.eipconfig = eipconfig.EIPConfig()
- self.defaultprovider = providers.LeapProviderDefinition()
- self.eipserviceconfig = eipconfig.EIPServiceConfig()
+ # if not domain, get from config
+ self.domain = domain
+
+ self.eipconfig = eipconfig.EIPConfig(domain=domain)
+ self.defaultprovider = providers.LeapProviderDefinition(domain=domain)
+ self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain)
def run_all(self, checker=None, skip_download=False):
"""
@@ -421,13 +457,14 @@ class EIPConfigChecker(object):
self.defaultprovider.save()
def fetch_eip_service_config(self, skip_download=False,
- config=None, uri=None):
+ config=None, uri=None, domain=None):
if skip_download:
return True
if config is None:
config = self.eipserviceconfig.config
if uri is None:
- domain = config.get('provider', None)
+ if not domain:
+ domain = config.get('provider', None)
uri = self._get_eip_service_uri(domain=domain)
self.eipserviceconfig.load(from_uri=uri, fetcher=self.fetcher)