summaryrefslogtreecommitdiff
path: root/src/leap/eip/tests/test_checks.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/eip/tests/test_checks.py')
-rw-r--r--src/leap/eip/tests/test_checks.py171
1 files changed, 171 insertions, 0 deletions
diff --git a/src/leap/eip/tests/test_checks.py b/src/leap/eip/tests/test_checks.py
index 1e629203..09fdaabf 100644
--- a/src/leap/eip/tests/test_checks.py
+++ b/src/leap/eip/tests/test_checks.py
@@ -1,3 +1,4 @@
+from BaseHTTPServer import BaseHTTPRequestHandler
import copy
import json
try:
@@ -5,6 +6,7 @@ try:
except ImportError:
import unittest
import os
+import urlparse
from mock import patch, Mock
@@ -18,6 +20,17 @@ from leap.eip import specs as eipspecs
from leap.eip import exceptions as eipexceptions
from leap.eip.tests import data as testdata
from leap.testing.basetest import BaseLeapTest
+from leap.testing.https_server import BaseHTTPSServerTestCase
+from leap.testing.https_server import where as where_cert
+
+
+class NoLogRequestHandler:
+ def log_message(self, *args):
+ # don't write log msg to stderr
+ pass
+
+ def read(self, n=None):
+ return ''
class EIPCheckTest(BaseLeapTest):
@@ -157,5 +170,163 @@ class EIPCheckTest(BaseLeapTest):
sampleconfig = copy.copy(testdata.EIP_SAMPLE_JSON)
checker.check_complete_eip_config(config=sampleconfig)
+
+class ProviderCertCheckerTest(BaseLeapTest):
+
+ __name__ = "provider_cert_checker_tests"
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ # test methods are there, and can be called from run_all
+
+ def test_checker_should_implement_check_methods(self):
+ checker = eipchecks.ProviderCertChecker()
+
+ # For MVS+
+ self.assertTrue(hasattr(checker, "download_ca_cert"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "download_ca_signature"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "get_ca_signatures"), "missing meth")
+ self.assertTrue(hasattr(checker, "is_there_trust_path"),
+ "missing meth")
+
+ # For MVS
+ self.assertTrue(hasattr(checker, "is_there_provider_ca"),
+ "missing meth")
+ self.assertTrue(hasattr(checker, "is_https_working"), "missing meth")
+ self.assertTrue(hasattr(checker, "check_new_cert_needed"),
+ "missing meth")
+
+ def test_checker_should_actually_call_all_tests(self):
+ checker = eipchecks.ProviderCertChecker()
+
+ mc = Mock()
+ checker.run_all(checker=mc)
+ # XXX MVS+
+ #self.assertTrue(mc.download_ca_cert.called, "not called")
+ #self.assertTrue(mc.download_ca_signature.called, "not called")
+ #self.assertTrue(mc.get_ca_signatures.called, "not called")
+ #self.assertTrue(mc.is_there_trust_path.called, "not called")
+
+ # For MVS
+ self.assertTrue(mc.is_there_provider_ca.called, "not called")
+ self.assertTrue(mc.is_https_working.called,
+ "not called")
+ self.assertTrue(mc.check_new_cert_needed.called,
+ "not called")
+
+ # test individual check methods
+
+ def test_is_there_provider_ca(self):
+ checker = eipchecks.ProviderCertChecker()
+ self.assertTrue(
+ checker.is_there_provider_ca())
+
+
+class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase):
+ class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler):
+ responses = {
+ '/': ['OK', ''],
+ '/client.cert': [
+ # XXX get sample cert
+ '-----BEGIN CERTIFICATE-----',
+ '-----END CERTIFICATE-----'],
+ '/badclient.cert': [
+ 'BADCERT']}
+
+ def do_GET(self):
+ path = urlparse.urlparse(self.path)
+ message = '\n'.join(self.responses.get(
+ path.path, None))
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(message)
+
+ def test_is_https_working(self):
+ fetcher = requests
+ uri = "https://%s/" % (self.get_server())
+ # bare requests call. this should just pass (if there is
+ # an https service there).
+ fetcher.get(uri, verify=False)
+ checker = eipchecks.ProviderCertChecker(fetcher=fetcher)
+ self.assertTrue(checker.is_https_working(uri=uri, verify=False))
+
+ # for local debugs, when in doubt
+ #self.assertTrue(checker.is_https_working(uri="https://github.com",
+ #verify=True))
+
+ # for the two checks below, I know they fail because no ca
+ # cert is passed to them, and I know that's the error that
+ # requests return with our implementation.
+ # We're receiving this because our
+ # server is dying prematurely when the handshake is interrupted on the
+ # client side.
+ # Since we have access to the server, we could check that
+ # the error raised has been:
+ # SSL23_READ_BYTES: alert bad certificate
+ with self.assertRaises(requests.exceptions.SSLError) as exc:
+ fetcher.get(uri, verify=True)
+ self.assertTrue(
+ "SSL23_GET_SERVER_HELLO:unknown protocol" in exc.message)
+ with self.assertRaises(requests.exceptions.SSLError) as exc:
+ checker.is_https_working(uri=uri, verify=True)
+ self.assertTrue(
+ "SSL23_GET_SERVER_HELLO:unknown protocol" in exc.message)
+
+ # get cacert from testing.https_server
+ cacert = where_cert('cacert.pem')
+ fetcher.get(uri, verify=cacert)
+ self.assertTrue(checker.is_https_working(uri=uri, verify=cacert))
+
+ # same, but get cacert from leap.custom
+ # XXX TODO!
+
+ def test_download_new_client_cert(self):
+ uri = "https://%s/client.cert" % (self.get_server())
+ cacert = where_cert('cacert.pem')
+ checker = eipchecks.ProviderCertChecker()
+ self.assertTrue(checker.download_new_client_cert(
+ uri=uri, verify=cacert))
+
+ # now download a malformed cert
+ uri = "https://%s/badclient.cert" % (self.get_server())
+ cacert = where_cert('cacert.pem')
+ checker = eipchecks.ProviderCertChecker()
+ with self.assertRaises(ValueError):
+ self.assertTrue(checker.download_new_client_cert(
+ uri=uri, verify=cacert))
+
+ # did we write cert to its path?
+ clientcertfile = eipspecs.client_cert_path()
+ self.assertTrue(os.path.isfile(clientcertfile))
+ certfile = eipspecs.client_cert_path()
+ with open(certfile, 'r') as cf:
+ certcontent = cf.read()
+ self.assertEqual(certcontent,
+ '\n'.join(
+ self.request_handler.responses['/client.cert']))
+ os.remove(clientcertfile)
+
+ def test_is_cert_valid(self):
+ checker = eipchecks.ProviderCertChecker()
+ # TODO: better exception catching
+ with self.assertRaises(Exception) as exc:
+ self.assertFalse(checker.is_cert_valid())
+ exc.message = "missing cert"
+
+ def test_check_new_cert_needed(self):
+ # check: missing cert
+ checker = eipchecks.ProviderCertChecker()
+ self.assertTrue(checker.check_new_cert_needed(skip_download=True))
+ # TODO check: malformed cert
+ # TODO check: expired cert
+ # TODO check: pass test server uri instead of skip
+
+
if __name__ == "__main__":
unittest.main()