summaryrefslogtreecommitdiff
path: root/src/leap/eip
diff options
context:
space:
mode:
authorkali <kali@leap.se>2012-09-04 01:09:53 +0900
committerkali <kali@leap.se>2012-09-04 01:10:45 +0900
commit8b3ad186e947ad962252c5d4b47a52ccb5514d98 (patch)
treedfa70cab48f1ecb85e7fdc0489cbb48ff03e4f66 /src/leap/eip
parent6c4012fc128c5af1b75cf33eef00590cf0e82438 (diff)
parent37d7e272b7f8a649034a0cf60f6c4a1424bf767a (diff)
Merge branch 'feature/provider-cert-check' into develop
For #501: write base.checks.ProviderCertChecks (in eip.checks, though.) Basic functionality is there, merging to tip. Might have to reopen to implement actual cert ts check.
Diffstat (limited to 'src/leap/eip')
-rw-r--r--src/leap/eip/checks.py154
-rw-r--r--src/leap/eip/tests/test_checks.py171
2 files changed, 322 insertions, 3 deletions
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py
index c6a7ca72..51a7e219 100644
--- a/src/leap/eip/checks.py
+++ b/src/leap/eip/checks.py
@@ -1,5 +1,6 @@
-import json
+#import json
import logging
+import ssl
import os
logging.basicConfig()
@@ -13,6 +14,7 @@ from leap.base import providers
from leap.eip import config as eipconfig
from leap.eip import constants as eipconstants
from leap.eip import exceptions as eipexceptions
+from leap.eip import specs as eipspecs
"""
EIPConfigChecker
@@ -40,12 +42,158 @@ class LeapNetworkChecker(object):
class ProviderCertChecker(object):
- pass
+ """
+ Several checks needed for getting
+ client certs and checking tls connection
+ with provider.
+ """
+ def __init__(self, fetcher=requests):
+ self.fetcher = fetcher
+ self.cacert = None
+
+ def run_all(self, checker=None, skip_download=False):
+ if not checker:
+ checker = self
+
+ # For MVS+
+ # checker.download_ca_cert()
+ # checker.download_ca_signature()
+ # checker.get_ca_signatures()
+ # checker.is_there_trust_path()
+
+ # For MVS
+ checker.is_there_provider_ca()
+ checker.is_https_working()
+ checker.check_new_cert_needed()
+ #checker.download_new_client_cert()
+
+ def download_ca_cert(self):
+ # MVS+
+ raise NotImplementedError
+
+ def download_ca_signature(self):
+ # MVS+
+ raise NotImplementedError
+
+ def get_ca_signatures(self):
+ # MVS+
+ raise NotImplementedError
+
+ def is_there_trust_path(self):
+ # MVS+
+ raise NotImplementedError
+
+ def is_there_provider_ca(self):
+ # XXX fake it till you make it! :P
+ return True
+
+ # enable this when we have
+ # a custom "branded" bundle
+ # certs package.
+ try:
+ from leap.custom import certs
+ except ImportError:
+ raise
+ self.cacert = certs.where('cacert.pem')
+
+ def is_https_working(self, uri=None, verify=True):
+ # XXX raise InsecureURI or something better
+ assert uri.startswith('https')
+ if verify is True and self.cacert is not None:
+ verify = self.cacert
+ self.fetcher.get(uri, verify=verify)
+ return True
+
+ def check_new_cert_needed(self, skip_download=False):
+ if not self.is_cert_valid(do_raise=False):
+ self.download_new_client_cert(skip_download=skip_download)
+ return True
+ return False
+
+ def download_new_client_cert(self, uri=None, verify=True,
+ skip_download=False):
+ if skip_download:
+ return True
+ if uri is None:
+ uri = self._get_client_cert_uri()
+ # XXX raise InsecureURI or something better
+ assert uri.startswith('https')
+ if verify is True and self.cacert is not None:
+ verify = self.cacert
+ req = self.fetcher.get(uri, verify=verify)
+ pemfile_content = req.content
+ self.is_valid_pemfile(pemfile_content)
+ cert_path = self._get_client_cert_path()
+ self.write_cert(pemfile_content, to=cert_path)
+ return True
+
+ def is_cert_valid(self, cert_path=None, do_raise=True):
+ exists = lambda: self.is_certificate_exists()
+ valid_pemfile = lambda: self.is_valid_pemfile()
+ not_expired = lambda: self.is_cert_not_expired()
+ print 'exists?', exists
+ print 'valid', valid_pemfile
+ print 'not expired', not_expired
+
+ valid = exists() and valid_pemfile() and not_expired()
+ if not valid:
+ if do_raise:
+ raise Exception('missing cert')
+ else:
+ return False
+ return True
+
+ def is_certificate_exists(self, certfile=None):
+ if certfile is None:
+ certfile = self._get_client_cert_path()
+ return os.path.isfile(certfile)
+
+ def is_cert_not_expired(self):
+ return True
+ # XXX TODO
+ # waiting on #507. If we're not using PyOpenSSL or anything alike
+ # we will have to roll our own x509 parsing to extract time info.
+
+ def is_valid_pemfile(self, cert_s=None):
+ """
+ checks that the passed string
+ is a valid pem certificate
+ @param cert_s: string containing pem content
+ @type cert_s: string
+ @rtype: bool
+ """
+ if cert_s is None:
+ certfile = self._get_client_cert_path()
+ with open(certfile) as cf:
+ cert_s = cf.read()
+ try:
+ # XXX get a real cert validation
+ # so far this is only checking begin/end
+ # delimiters :)
+ ssl.PEM_cert_to_DER_cert(cert_s)
+ except:
+ # XXX raise proper exception
+ raise
+ return True
+
+ def _get_client_cert_uri(self):
+ return "https://%s/cert/get" % (baseconstants.DEFAULT_TEST_PROVIDER)
+
+ def _get_client_cert_path(self):
+ # MVS+ : get provider path
+ return eipspecs.client_cert_path()
+
+ def is_cert_still_valid(self):
+ raise NotImplementedError
+
+ def write_cert(self, pemfile_content, to=None):
+ with open(to, 'w') as cert_f:
+ cert_f.write(pemfile_content)
class EIPConfigChecker(object):
"""
- Several tests needed
+ Several checks needed
to ensure a EIPConnection
can be sucessfully established.
use run_all to run all checks.
diff --git a/src/leap/eip/tests/test_checks.py b/src/leap/eip/tests/test_checks.py
index 1e629203..09fdaabf 100644
--- a/src/leap/eip/tests/test_checks.py
+++ b/src/leap/eip/tests/test_checks.py
@@ -1,3 +1,4 @@
+from BaseHTTPServer import BaseHTTPRequestHandler
import copy
import json
try:
@@ -5,6 +6,7 @@ try:
except ImportError:
import unittest
import os
+import urlparse
from mock import patch, Mock
@@ -18,6 +20,17 @@ from leap.eip import specs as eipspecs
from leap.eip import exceptions as eipexceptions
from leap.eip.tests import data as testdata
from leap.testing.basetest import BaseLeapTest
+from leap.testing.https_server import BaseHTTPSServerTestCase
+from leap.testing.https_server import where as where_cert
+
+
+class NoLogRequestHandler:
+ def log_message(self, *args):
+ # don't write log msg to stderr
+ pass
+
+ def read(self, n=None):
+ return ''
class EIPCheckTest(BaseLeapTest):
@@ -157,5 +170,163 @@ class EIPCheckTest(BaseLeapTest):
sampleconfig = copy.copy(testdata.EIP_SAMPLE_JSON)
checker.check_complete_eip_config(config=sampleconfig)
+
+class ProviderCertCheckerTest(BaseLeapTest):
+
+ __name__ = "provider_cert_checker_tests"
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ # test methods are there, and can be called from run_all
+
+ def test_checker_should_implement_check_methods(self):
+ checker = eipchecks.ProviderCertChecker()
+
+ # For MVS+
+ self.assertTrue(hasattr(checker, "download_ca_cert"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "download_ca_signature"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "get_ca_signatures"), "missing meth")
+ self.assertTrue(hasattr(checker, "is_there_trust_path"),
+ "missing meth")
+
+ # For MVS
+ self.assertTrue(hasattr(checker, "is_there_provider_ca"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "is_https_working"), "missing meth")
+ self.assertTrue(hasattr(checker, "check_new_cert_needed"),
+ "missing meth")
+
+ def test_checker_should_actually_call_all_tests(self):
+ checker = eipchecks.ProviderCertChecker()
+
+ mc = Mock()
+ checker.run_all(checker=mc)
+ # XXX MVS+
+ #self.assertTrue(mc.download_ca_cert.called, "not called")
+ #self.assertTrue(mc.download_ca_signature.called, "not called")
+ #self.assertTrue(mc.get_ca_signatures.called, "not called")
+ #self.assertTrue(mc.is_there_trust_path.called, "not called")
+
+ # For MVS
+ self.assertTrue(mc.is_there_provider_ca.called, "not called")
+ self.assertTrue(mc.is_https_working.called,
+ "not called")
+ self.assertTrue(mc.check_new_cert_needed.called,
+ "not called")
+
+ # test individual check methods
+
+ def test_is_there_provider_ca(self):
+ checker = eipchecks.ProviderCertChecker()
+ self.assertTrue(
+ checker.is_there_provider_ca())
+
+
+class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase):
+ class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler):
+ responses = {
+ '/': ['OK', ''],
+ '/client.cert': [
+ # XXX get sample cert
+ '-----BEGIN CERTIFICATE-----',
+ '-----END CERTIFICATE-----'],
+ '/badclient.cert': [
+ 'BADCERT']}
+
+ def do_GET(self):
+ path = urlparse.urlparse(self.path)
+ message = '\n'.join(self.responses.get(
+ path.path, None))
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(message)
+
+ def test_is_https_working(self):
+ fetcher = requests
+ uri = "https://%s/" % (self.get_server())
+ # bare requests call. this should just pass (if there is
+ # an https service there).
+ fetcher.get(uri, verify=False)
+ checker = eipchecks.ProviderCertChecker(fetcher=fetcher)
+ self.assertTrue(checker.is_https_working(uri=uri, verify=False))
+
+ # for local debugs, when in doubt
+ #self.assertTrue(checker.is_https_working(uri="https://github.com",
+ #verify=True))
+
+ # for the two checks below, I know they fail because no ca
+ # cert is passed to them, and I know that's the error that
+ # requests return with our implementation.
+ # We're receiving this because our
+ # server is dying prematurely when the handshake is interrupted on the
+ # client side.
+ # Since we have access to the server, we could check that
+ # the error raised has been:
+ # SSL23_READ_BYTES: alert bad certificate
+ with self.assertRaises(requests.exceptions.SSLError) as exc:
+ fetcher.get(uri, verify=True)
+ self.assertTrue(
+ "SSL23_GET_SERVER_HELLO:unknown protocol" in exc.message)
+ with self.assertRaises(requests.exceptions.SSLError) as exc:
+ checker.is_https_working(uri=uri, verify=True)
+ self.assertTrue(
+ "SSL23_GET_SERVER_HELLO:unknown protocol" in exc.message)
+
+ # get cacert from testing.https_server
+ cacert = where_cert('cacert.pem')
+ fetcher.get(uri, verify=cacert)
+ self.assertTrue(checker.is_https_working(uri=uri, verify=cacert))
+
+ # same, but get cacert from leap.custom
+ # XXX TODO!
+
+ def test_download_new_client_cert(self):
+ uri = "https://%s/client.cert" % (self.get_server())
+ cacert = where_cert('cacert.pem')
+ checker = eipchecks.ProviderCertChecker()
+ self.assertTrue(checker.download_new_client_cert(
+ uri=uri, verify=cacert))
+
+ # now download a malformed cert
+ uri = "https://%s/badclient.cert" % (self.get_server())
+ cacert = where_cert('cacert.pem')
+ checker = eipchecks.ProviderCertChecker()
+ with self.assertRaises(ValueError):
+ self.assertTrue(checker.download_new_client_cert(
+ uri=uri, verify=cacert))
+
+ # did we write cert to its path?
+ clientcertfile = eipspecs.client_cert_path()
+ self.assertTrue(os.path.isfile(clientcertfile))
+ certfile = eipspecs.client_cert_path()
+ with open(certfile, 'r') as cf:
+ certcontent = cf.read()
+ self.assertEqual(certcontent,
+ '\n'.join(
+ self.request_handler.responses['/client.cert']))
+ os.remove(clientcertfile)
+
+ def test_is_cert_valid(self):
+ checker = eipchecks.ProviderCertChecker()
+ # TODO: better exception catching
+ with self.assertRaises(Exception) as exc:
+ self.assertFalse(checker.is_cert_valid())
+ exc.message = "missing cert"
+
+ def test_check_new_cert_needed(self):
+ # check: missing cert
+ checker = eipchecks.ProviderCertChecker()
+ self.assertTrue(checker.check_new_cert_needed(skip_download=True))
+ # TODO check: malformed cert
+ # TODO check: expired cert
+ # TODO check: pass test server uri instead of skip
+
+
if __name__ == "__main__":
unittest.main()