summaryrefslogtreecommitdiff
path: root/src/leap/eip/checks.py
diff options
context:
space:
mode:
authorkali <kali@leap.se>2013-01-30 06:52:59 +0900
committerkali <kali@leap.se>2013-01-30 06:52:59 +0900
commit570f84756b3d1f689a172a3ff0c55abf6a60b9dd (patch)
tree0f262acacf35743664c6408edbafe6ba6f119d14 /src/leap/eip/checks.py
parente60abdcb796ad9e2c44e9b37d3a68e7f159c035c (diff)
parent10a2303fe2d21999bce56940daecb78576f5b741 (diff)
Merge branch 'pre-release-0.2.0' into release-0.2.0
This merge contains today's state of develop branch, minus soledad and email components.
Diffstat (limited to 'src/leap/eip/checks.py')
-rw-r--r--src/leap/eip/checks.py309
1 files changed, 215 insertions, 94 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py
index f739c3e8..9a34a428 100644
--- a/src/leap/eip/checks.py
+++ b/src/leap/eip/checks.py
@@ -1,23 +1,24 @@
import logging
-import ssl
-#import platform
import time
import os
+import sys
-from gnutls import crypto
-#import netifaces
-#import ping
import requests
from leap import __branding as BRANDING
-from leap import certs
+from leap import certs as leapcerts
+from leap.base.auth import srpauth_protected, magick_srpauth
+from leap.base import config as baseconfig
from leap.base import constants as baseconstants
from leap.base import providers
+from leap.crypto import certs
from leap.eip import config as eipconfig
from leap.eip import constants as eipconstants
from leap.eip import exceptions as eipexceptions
from leap.eip import specs as eipspecs
+from leap.util.certs import get_mac_cabundle
from leap.util.fileutil import mkdir_p
+from leap.util.web import get_https_domain_and_port
logger = logging.getLogger(name=__name__)
@@ -42,10 +43,11 @@ reachable and testable as a whole.
"""
-def get_ca_cert():
+def get_branding_ca_cert(domain):
+ # deprecated
ca_file = BRANDING.get('provider_ca_file')
if ca_file:
- return certs.where(ca_file)
+ return leapcerts.where(ca_file)
class ProviderCertChecker(object):
@@ -54,18 +56,29 @@ class ProviderCertChecker(object):
client certs and checking tls connection
with provider.
"""
- def __init__(self, fetcher=requests):
+ def __init__(self, fetcher=requests,
+ domain=None):
+
self.fetcher = fetcher
- self.cacert = get_ca_cert()
+ self.domain = domain
+ #XXX needs some kind of autoinit
+ #right now we set by hand
+ #by loading and reading provider config
+ self.apidomain = None
+ self.cacert = eipspecs.provider_ca_path(domain)
+
+ def run_all(
+ self, checker=None,
+ skip_download=False, skip_verify=False):
- def run_all(self, checker=None, skip_download=False, skip_verify=False):
if not checker:
checker = self
do_verify = not skip_verify
logger.debug('do_verify: %s', do_verify)
- # For MVS+
# checker.download_ca_cert()
+
+ # For MVS+
# checker.download_ca_signature()
# checker.get_ca_signatures()
# checker.is_there_trust_path()
@@ -73,13 +86,44 @@ class ProviderCertChecker(object):
# For MVS
checker.is_there_provider_ca()
- # XXX FAKE IT!!!
- checker.is_https_working(verify=do_verify)
+ checker.is_https_working(verify=do_verify, autocacert=False)
checker.check_new_cert_needed(verify=do_verify)
- def download_ca_cert(self):
- # MVS+
- raise NotImplementedError
+ def download_ca_cert(self, uri=None, verify=True):
+ req = self.fetcher.get(uri, verify=verify)
+ req.raise_for_status()
+
+ # should check domain exists
+ capath = self._get_ca_cert_path(self.domain)
+ with open(capath, 'w') as f:
+ f.write(req.content)
+
+ def check_ca_cert_fingerprint(
+ self, hash_type="SHA256",
+ fingerprint=None):
+ """
+ compares the fingerprint in
+ the ca cert with a string
+ we are passed
+ returns True if they are equal, False if not.
+ @param hash_type: digest function
+ @type hash_type: str
+ @param fingerprint: the fingerprint to compare with.
+ @type fingerprint: str (with : separator)
+ @rtype bool
+ """
+ ca_cert_path = self.ca_cert_path
+ ca_cert_fpr = certs.get_cert_fingerprint(
+ filepath=ca_cert_path)
+ return ca_cert_fpr == fingerprint
+
+ def verify_api_https(self, uri):
+ assert uri.startswith('https://')
+ cacert = self.ca_cert_path
+ verify = cacert and cacert or True
+ req = self.fetcher.get(uri, verify=verify)
+ req.raise_for_status()
+ return True
def download_ca_signature(self):
# MVS+
@@ -94,80 +138,110 @@ class ProviderCertChecker(object):
raise NotImplementedError
def is_there_provider_ca(self):
- from leap import certs
- logger.debug('do we have provider_ca?')
- cacert_path = BRANDING.get('provider_ca_file', None)
- if not cacert_path:
- logger.debug('False')
+ if not self.cacert:
return False
- self.cacert = certs.where(cacert_path)
- logger.debug('True')
- return True
+ cacert_exists = os.path.isfile(self.cacert)
+ if cacert_exists:
+ logger.debug('True')
+ return True
+ logger.debug('False!')
+ return False
- def is_https_working(self, uri=None, verify=True):
+ def is_https_working(
+ self, uri=None, verify=True,
+ autocacert=False):
if uri is None:
uri = self._get_root_uri()
# XXX raise InsecureURI or something better
- assert uri.startswith('https')
- if verify is True and self.cacert is not None:
+ try:
+ assert uri.startswith('https')
+ except AssertionError:
+ raise AssertionError(
+ "uri passed should start with https")
+ if autocacert and verify is True and self.cacert is not None:
logger.debug('verify cert: %s', self.cacert)
verify = self.cacert
- logger.debug('is https working?')
+ if sys.platform == "darwin":
+ verify = get_mac_cabundle()
+ logger.debug('checking https connection')
logger.debug('uri: %s (verify:%s)', uri, verify)
+
try:
self.fetcher.get(uri, verify=verify)
+
except requests.exceptions.SSLError as exc:
- logger.warning('False! CERT VERIFICATION FAILED! '
- '(this should be CRITICAL)')
- logger.warning('SSLError: %s', exc.message)
- # XXX RAISE! See #638
- #raise eipexceptions.EIPBadCertError
- # XXX get requests.exceptions.ConnectionError Errno 110
- # Connection timed out, and raise ours.
+ raise eipexceptions.HttpsBadCertError
+
+ except requests.exceptions.ConnectionError:
+ logger.error('ConnectionError')
+ raise eipexceptions.HttpsNotSupported
+
else:
- logger.debug('True')
return True
def check_new_cert_needed(self, skip_download=False, verify=True):
- logger.debug('is new cert needed?')
+ # XXX add autocacert
if not self.is_cert_valid(do_raise=False):
- logger.debug('True')
+ logger.debug('cert needed: true')
self.download_new_client_cert(
skip_download=skip_download,
verify=verify)
return True
- logger.debug('False')
+ logger.debug('cert needed: false')
return False
def download_new_client_cert(self, uri=None, verify=True,
- skip_download=False):
+ skip_download=False,
+ credentials=None):
logger.debug('download new client cert')
if skip_download:
return True
if uri is None:
uri = self._get_client_cert_uri()
# XXX raise InsecureURI or something better
- assert uri.startswith('https')
+ #assert uri.startswith('https')
+
if verify is True and self.cacert is not None:
verify = self.cacert
+ logger.debug('verify = %s', verify)
+
+ fgetfn = self.fetcher.get
+
+ if credentials:
+ user, passwd = credentials
+ logger.debug('apidomain = %s', self.apidomain)
+
+ @srpauth_protected(user, passwd,
+ server="https://%s" % self.apidomain,
+ verify=verify)
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
+
+ else:
+ # XXX FIXME fix decorated args
+ @magick_srpauth(verify)
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
try:
- # XXX FIXME!!!!
- # verify=verify
- # Workaround for #638. return to verification
- # when That's done!!!
-
- # XXX HOOK SRP here...
- # will have to be more generic in the future.
- req = self.fetcher.get(uri, verify=False)
+
+ req = getfn(uri, verify=verify)
req.raise_for_status()
+
except requests.exceptions.SSLError:
logger.warning('SSLError while fetching cert. '
'Look below for stack trace.')
# XXX raise better exception
- raise
+ return self.fail("SSLError")
+ except Exception as exc:
+ return self.fail(exc.message)
+
try:
+ logger.debug('validating cert...')
pemfile_content = req.content
- self.is_valid_pemfile(pemfile_content)
+ valid = self.is_valid_pemfile(pemfile_content)
+ if not valid:
+ logger.warning('invalid cert')
+ return False
cert_path = self._get_client_cert_path()
self.write_cert(pemfile_content, to=cert_path)
except:
@@ -196,11 +270,8 @@ class ProviderCertChecker(object):
def is_cert_not_expired(self, certfile=None, now=time.gmtime):
if certfile is None:
certfile = self._get_client_cert_path()
- with open(certfile) as cf:
- cert_s = cf.read()
- cert = crypto.X509Certificate(cert_s)
- from_ = time.gmtime(cert.activation_time)
- to_ = time.gmtime(cert.expiration_time)
+ from_, to_ = certs.get_time_boundaries(certfile)
+
return from_ < now() < to_
def is_valid_pemfile(self, cert_s=None):
@@ -216,33 +287,39 @@ class ProviderCertChecker(object):
with open(certfile) as cf:
cert_s = cf.read()
try:
- # XXX get a real cert validation
- # so far this is only checking begin/end
- # delimiters :)
- # XXX use gnutls for get proper
- # validation.
- # crypto.X509Certificate(cert_s)
- sep = "-" * 5 + "BEGIN CERTIFICATE" + "-" * 5
- # we might have private key and cert in the same file
- certparts = cert_s.split(sep)
- if len(certparts) > 1:
- cert_s = sep + certparts[1]
- ssl.PEM_cert_to_DER_cert(cert_s)
- except:
- # XXX raise proper exception
- raise
- return True
+ valid = certs.can_load_cert_and_pkey(cert_s)
+ except certs.BadCertError:
+ logger.warning("Not valid pemfile")
+ valid = False
+ return valid
+
+ @property
+ def ca_cert_path(self):
+ return self._get_ca_cert_path(self.domain)
def _get_root_uri(self):
- return u"https://%s/" % baseconstants.DEFAULT_PROVIDER
+ return u"https://%s/" % self.domain
def _get_client_cert_uri(self):
- # XXX get the whole thing from constants
- return "https://%s/1/cert" % (baseconstants.DEFAULT_PROVIDER)
+ return "https://%s/1/cert" % self.apidomain
def _get_client_cert_path(self):
- # MVS+ : get provider path
- return eipspecs.client_cert_path()
+ return eipspecs.client_cert_path(domain=self.domain)
+
+ def _get_ca_cert_path(self, domain):
+ # XXX this folder path will be broken for win
+ # and this should be moved to eipspecs.ca_path
+
+ # XXX use baseconfig.get_provider_path(folder=Foo)
+ # !!!
+
+ capath = baseconfig.get_config_file(
+ 'cacert.pem',
+ folder='providers/%s/keys/ca' % domain)
+ folder, fname = os.path.split(capath)
+ if not os.path.isdir(folder):
+ mkdir_p(folder)
+ return capath
def write_cert(self, pemfile_content, to=None):
folder, filename = os.path.split(to)
@@ -251,6 +328,9 @@ class ProviderCertChecker(object):
with open(to, 'w') as cert_f:
cert_f.write(pemfile_content)
+ def set_api_domain(self, domain):
+ self.apidomain = domain
+
class EIPConfigChecker(object):
"""
@@ -260,16 +340,25 @@ class EIPConfigChecker(object):
use run_all to run all checks.
"""
- def __init__(self, fetcher=requests):
+ def __init__(self, fetcher=requests, domain=None):
# we do not want to accept too many
# argument on init.
# we want tests
# to be explicitely run.
+
self.fetcher = fetcher
- self.eipconfig = eipconfig.EIPConfig()
- self.defaultprovider = providers.LeapProviderDefinition()
- self.eipserviceconfig = eipconfig.EIPServiceConfig()
+ # if not domain, get from config
+ self.domain = domain
+ self.apidomain = None
+ self.cacert = eipspecs.provider_ca_path(domain)
+
+ self.defaultprovider = providers.LeapProviderDefinition(domain=domain)
+ self.defaultprovider.load()
+ self.eipconfig = eipconfig.EIPConfig(domain=domain)
+ self.set_api_domain()
+ self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain)
+ self.eipserviceconfig.load()
def run_all(self, checker=None, skip_download=False):
"""
@@ -330,7 +419,9 @@ class EIPConfigChecker(object):
return True
def fetch_definition(self, skip_download=False,
- config=None, uri=None):
+ force_download=False,
+ config=None, uri=None,
+ domain=None):
"""
fetches a definition file from server
"""
@@ -347,27 +438,45 @@ class EIPConfigChecker(object):
if config is None:
config = self.defaultprovider.config
if uri is None:
- domain = config.get('provider', None)
+ if not domain:
+ domain = config.get('provider', None)
uri = self._get_provider_definition_uri(domain=domain)
- # FIXME! Pass ca path verify!!!
+ if sys.platform == "darwin":
+ verify = get_mac_cabundle()
+ else:
+ verify = True
+
self.defaultprovider.load(
from_uri=uri,
fetcher=self.fetcher,
- verify=False)
+ verify=verify)
self.defaultprovider.save()
def fetch_eip_service_config(self, skip_download=False,
- config=None, uri=None):
+ force_download=False,
+ config=None, uri=None, # domain=None,
+ autocacert=True, verify=True):
if skip_download:
return True
if config is None:
+ self.eipserviceconfig.load()
config = self.eipserviceconfig.config
if uri is None:
- domain = config.get('provider', None)
- uri = self._get_eip_service_uri(domain=domain)
+ #XXX
+ #if not domain:
+ #domain = self.domain or config.get('provider', None)
+ uri = self._get_eip_service_uri(
+ domain=self.apidomain)
+
+ if autocacert and self.cacert is not None:
+ verify = self.cacert
- self.eipserviceconfig.load(from_uri=uri, fetcher=self.fetcher)
+ self.eipserviceconfig.load(
+ from_uri=uri,
+ fetcher=self.fetcher,
+ force_download=force_download,
+ verify=verify)
self.eipserviceconfig.save()
def check_complete_eip_config(self, config=None):
@@ -375,7 +484,6 @@ class EIPConfigChecker(object):
if config is None:
config = self.eipconfig.config
try:
- 'trying assertions'
assert 'provider' in config
assert config['provider'] is not None
# XXX assert there is gateway !!
@@ -395,11 +503,11 @@ class EIPConfigChecker(object):
return self.eipconfig.exists()
def _dump_default_eipconfig(self):
- self.eipconfig.save()
+ self.eipconfig.save(force=True)
def _get_provider_definition_uri(self, domain=None, path=None):
if domain is None:
- domain = baseconstants.DEFAULT_PROVIDER
+ domain = self.domain or baseconstants.DEFAULT_PROVIDER
if path is None:
path = baseconstants.DEFINITION_EXPECTED_PATH
uri = u"https://%s/%s" % (domain, path)
@@ -408,9 +516,22 @@ class EIPConfigChecker(object):
def _get_eip_service_uri(self, domain=None, path=None):
if domain is None:
- domain = baseconstants.DEFAULT_PROVIDER
+ domain = self.domain or baseconstants.DEFAULT_PROVIDER
if path is None:
path = eipconstants.EIP_SERVICE_EXPECTED_PATH
uri = "https://%s/%s" % (domain, path)
logger.debug('getting eip service file from %s', uri)
return uri
+
+ def set_api_domain(self):
+ """sets api domain from defaultprovider config object"""
+ api = self.defaultprovider.config.get('api_uri', None)
+ # the caller is responsible for having loaded the config
+ # object at this point
+ if api:
+ api_dom = get_https_domain_and_port(api)
+ self.apidomain = "%s:%s" % api_dom
+
+ def get_api_domain(self):
+ """gets api domain"""
+ return self.apidomain