diff options
Diffstat (limited to 'src/leap/eip')
-rw-r--r-- | src/leap/eip/checks.py | 154 | ||||
-rw-r--r-- | src/leap/eip/tests/test_checks.py | 171 |
2 files changed, 322 insertions, 3 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py index c6a7ca72..51a7e219 100644 --- a/src/leap/eip/checks.py +++ b/src/leap/eip/checks.py @@ -1,5 +1,6 @@ -import json +#import json import logging +import ssl import os logging.basicConfig() @@ -13,6 +14,7 @@ from leap.base import providers from leap.eip import config as eipconfig from leap.eip import constants as eipconstants from leap.eip import exceptions as eipexceptions +from leap.eip import specs as eipspecs """ EIPConfigChecker @@ -40,12 +42,158 @@ class LeapNetworkChecker(object): class ProviderCertChecker(object): - pass + """ + Several checks needed for getting + client certs and checking tls connection + with provider. + """ + def __init__(self, fetcher=requests): + self.fetcher = fetcher + self.cacert = None + + def run_all(self, checker=None, skip_download=False): + if not checker: + checker = self + + # For MVS+ + # checker.download_ca_cert() + # checker.download_ca_signature() + # checker.get_ca_signatures() + # checker.is_there_trust_path() + + # For MVS + checker.is_there_provider_ca() + checker.is_https_working() + checker.check_new_cert_needed() + #checker.download_new_client_cert() + + def download_ca_cert(self): + # MVS+ + raise NotImplementedError + + def download_ca_signature(self): + # MVS+ + raise NotImplementedError + + def get_ca_signatures(self): + # MVS+ + raise NotImplementedError + + def is_there_trust_path(self): + # MVS+ + raise NotImplementedError + + def is_there_provider_ca(self): + # XXX fake it till you make it! :P + return True + + # enable this when we have + # a custom "branded" bundle + # certs package. + try: + from leap.custom import certs + except ImportError: + raise + self.cacert = certs.where('cacert.pem') + + def is_https_working(self, uri=None, verify=True): + # XXX raise InsecureURI or something better + assert uri.startswith('https') + if verify is True and self.cacert is not None: + verify = self.cacert + self.fetcher.get(uri, verify=verify) + return True + + def check_new_cert_needed(self, skip_download=False): + if not self.is_cert_valid(do_raise=False): + self.download_new_client_cert(skip_download=skip_download) + return True + return False + + def download_new_client_cert(self, uri=None, verify=True, + skip_download=False): + if skip_download: + return True + if uri is None: + uri = self._get_client_cert_uri() + # XXX raise InsecureURI or something better + assert uri.startswith('https') + if verify is True and self.cacert is not None: + verify = self.cacert + req = self.fetcher.get(uri, verify=verify) + pemfile_content = req.content + self.is_valid_pemfile(pemfile_content) + cert_path = self._get_client_cert_path() + self.write_cert(pemfile_content, to=cert_path) + return True + + def is_cert_valid(self, cert_path=None, do_raise=True): + exists = lambda: self.is_certificate_exists() + valid_pemfile = lambda: self.is_valid_pemfile() + not_expired = lambda: self.is_cert_not_expired() + print 'exists?', exists + print 'valid', valid_pemfile + print 'not expired', not_expired + + valid = exists() and valid_pemfile() and not_expired() + if not valid: + if do_raise: + raise Exception('missing cert') + else: + return False + return True + + def is_certificate_exists(self, certfile=None): + if certfile is None: + certfile = self._get_client_cert_path() + return os.path.isfile(certfile) + + def is_cert_not_expired(self): + return True + # XXX TODO + # waiting on #507. If we're not using PyOpenSSL or anything alike + # we will have to roll our own x509 parsing to extract time info. + + def is_valid_pemfile(self, cert_s=None): + """ + checks that the passed string + is a valid pem certificate + @param cert_s: string containing pem content + @type cert_s: string + @rtype: bool + """ + if cert_s is None: + certfile = self._get_client_cert_path() + with open(certfile) as cf: + cert_s = cf.read() + try: + # XXX get a real cert validation + # so far this is only checking begin/end + # delimiters :) + ssl.PEM_cert_to_DER_cert(cert_s) + except: + # XXX raise proper exception + raise + return True + + def _get_client_cert_uri(self): + return "https://%s/cert/get" % (baseconstants.DEFAULT_TEST_PROVIDER) + + def _get_client_cert_path(self): + # MVS+ : get provider path + return eipspecs.client_cert_path() + + def is_cert_still_valid(self): + raise NotImplementedError + + def write_cert(self, pemfile_content, to=None): + with open(to, 'w') as cert_f: + cert_f.write(pemfile_content) class EIPConfigChecker(object): """ - Several tests needed + Several checks needed to ensure a EIPConnection can be sucessfully established. use run_all to run all checks. diff --git a/src/leap/eip/tests/test_checks.py b/src/leap/eip/tests/test_checks.py index 1e629203..09fdaabf 100644 --- a/src/leap/eip/tests/test_checks.py +++ b/src/leap/eip/tests/test_checks.py @@ -1,3 +1,4 @@ +from BaseHTTPServer import BaseHTTPRequestHandler import copy import json try: @@ -5,6 +6,7 @@ try: except ImportError: import unittest import os +import urlparse from mock import patch, Mock @@ -18,6 +20,17 @@ from leap.eip import specs as eipspecs from leap.eip import exceptions as eipexceptions from leap.eip.tests import data as testdata from leap.testing.basetest import BaseLeapTest +from leap.testing.https_server import BaseHTTPSServerTestCase +from leap.testing.https_server import where as where_cert + + +class NoLogRequestHandler: + def log_message(self, *args): + # don't write log msg to stderr + pass + + def read(self, n=None): + return '' class EIPCheckTest(BaseLeapTest): @@ -157,5 +170,163 @@ class EIPCheckTest(BaseLeapTest): sampleconfig = copy.copy(testdata.EIP_SAMPLE_JSON) checker.check_complete_eip_config(config=sampleconfig) + +class ProviderCertCheckerTest(BaseLeapTest): + + __name__ = "provider_cert_checker_tests" + + def setUp(self): + pass + + def tearDown(self): + pass + + # test methods are there, and can be called from run_all + + def test_checker_should_implement_check_methods(self): + checker = eipchecks.ProviderCertChecker() + + # For MVS+ + self.assertTrue(hasattr(checker, "download_ca_cert"), + "missing meth") + self.assertTrue(hasattr(checker, "download_ca_signature"), + "missing meth") + self.assertTrue(hasattr(checker, "get_ca_signatures"), "missing meth") + self.assertTrue(hasattr(checker, "is_there_trust_path"), + "missing meth") + + # For MVS + self.assertTrue(hasattr(checker, "is_there_provider_ca"), + "missing meth") + self.assertTrue(hasattr(checker, "is_https_working"), "missing meth") + self.assertTrue(hasattr(checker, "check_new_cert_needed"), + "missing meth") + + def test_checker_should_actually_call_all_tests(self): + checker = eipchecks.ProviderCertChecker() + + mc = Mock() + checker.run_all(checker=mc) + # XXX MVS+ + #self.assertTrue(mc.download_ca_cert.called, "not called") + #self.assertTrue(mc.download_ca_signature.called, "not called") + #self.assertTrue(mc.get_ca_signatures.called, "not called") + #self.assertTrue(mc.is_there_trust_path.called, "not called") + + # For MVS + self.assertTrue(mc.is_there_provider_ca.called, "not called") + self.assertTrue(mc.is_https_working.called, + "not called") + self.assertTrue(mc.check_new_cert_needed.called, + "not called") + + # test individual check methods + + def test_is_there_provider_ca(self): + checker = eipchecks.ProviderCertChecker() + self.assertTrue( + checker.is_there_provider_ca()) + + +class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase): + class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler): + responses = { + '/': ['OK', ''], + '/client.cert': [ + # XXX get sample cert + '-----BEGIN CERTIFICATE-----', + '-----END CERTIFICATE-----'], + '/badclient.cert': [ + 'BADCERT']} + + def do_GET(self): + path = urlparse.urlparse(self.path) + message = '\n'.join(self.responses.get( + path.path, None)) + self.send_response(200) + self.end_headers() + self.wfile.write(message) + + def test_is_https_working(self): + fetcher = requests + uri = "https://%s/" % (self.get_server()) + # bare requests call. this should just pass (if there is + # an https service there). + fetcher.get(uri, verify=False) + checker = eipchecks.ProviderCertChecker(fetcher=fetcher) + self.assertTrue(checker.is_https_working(uri=uri, verify=False)) + + # for local debugs, when in doubt + #self.assertTrue(checker.is_https_working(uri="https://github.com", + #verify=True)) + + # for the two checks below, I know they fail because no ca + # cert is passed to them, and I know that's the error that + # requests return with our implementation. + # We're receiving this because our + # server is dying prematurely when the handshake is interrupted on the + # client side. + # Since we have access to the server, we could check that + # the error raised has been: + # SSL23_READ_BYTES: alert bad certificate + with self.assertRaises(requests.exceptions.SSLError) as exc: + fetcher.get(uri, verify=True) + self.assertTrue( + "SSL23_GET_SERVER_HELLO:unknown protocol" in exc.message) + with self.assertRaises(requests.exceptions.SSLError) as exc: + checker.is_https_working(uri=uri, verify=True) + self.assertTrue( + "SSL23_GET_SERVER_HELLO:unknown protocol" in exc.message) + + # get cacert from testing.https_server + cacert = where_cert('cacert.pem') + fetcher.get(uri, verify=cacert) + self.assertTrue(checker.is_https_working(uri=uri, verify=cacert)) + + # same, but get cacert from leap.custom + # XXX TODO! + + def test_download_new_client_cert(self): + uri = "https://%s/client.cert" % (self.get_server()) + cacert = where_cert('cacert.pem') + checker = eipchecks.ProviderCertChecker() + self.assertTrue(checker.download_new_client_cert( + uri=uri, verify=cacert)) + + # now download a malformed cert + uri = "https://%s/badclient.cert" % (self.get_server()) + cacert = where_cert('cacert.pem') + checker = eipchecks.ProviderCertChecker() + with self.assertRaises(ValueError): + self.assertTrue(checker.download_new_client_cert( + uri=uri, verify=cacert)) + + # did we write cert to its path? + clientcertfile = eipspecs.client_cert_path() + self.assertTrue(os.path.isfile(clientcertfile)) + certfile = eipspecs.client_cert_path() + with open(certfile, 'r') as cf: + certcontent = cf.read() + self.assertEqual(certcontent, + '\n'.join( + self.request_handler.responses['/client.cert'])) + os.remove(clientcertfile) + + def test_is_cert_valid(self): + checker = eipchecks.ProviderCertChecker() + # TODO: better exception catching + with self.assertRaises(Exception) as exc: + self.assertFalse(checker.is_cert_valid()) + exc.message = "missing cert" + + def test_check_new_cert_needed(self): + # check: missing cert + checker = eipchecks.ProviderCertChecker() + self.assertTrue(checker.check_new_cert_needed(skip_download=True)) + # TODO check: malformed cert + # TODO check: expired cert + # TODO check: pass test server uri instead of skip + + if __name__ == "__main__": unittest.main() |