summaryrefslogtreecommitdiff
path: root/src/leap/eip/tests/test_checks.py
blob: 09fdaabf8c5dc629488abbdf276d5b7446bc0189 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
from BaseHTTPServer import BaseHTTPRequestHandler
import copy
import json
try:
    import unittest2 as unittest
except ImportError:
    import unittest
import os
import urlparse

from mock import patch, Mock

import requests

from leap.base import config as baseconfig
from leap.base.constants import (DEFAULT_PROVIDER_DEFINITION,
                                 DEFINITION_EXPECTED_PATH)
from leap.eip import checks as eipchecks
from leap.eip import specs as eipspecs
from leap.eip import exceptions as eipexceptions
from leap.eip.tests import data as testdata
from leap.testing.basetest import BaseLeapTest
from leap.testing.https_server import BaseHTTPSServerTestCase
from leap.testing.https_server import where as where_cert


class NoLogRequestHandler:
    """Mixin for test request handlers that keeps them quiet.

    It is combined with ``BaseHTTPRequestHandler`` further down in this
    module; the methods here take precedence over the ones they shadow.
    """

    def log_message(self, *args):
        """Discard the log message instead of writing it to stderr."""
        return None

    def read(self, n=None):
        """Return an empty string, whatever the value of ``n``."""
        nothing = ''
        return nothing


class EIPCheckTest(BaseLeapTest):
    """Test case for ``eipchecks.EIPConfigChecker``.

    The first two tests verify that the checker exposes the expected
    check methods and that ``run_all`` invokes them; the remaining tests
    exercise the individual checks one by one.
    """

    __name__ = "eip_check_tests"

    def setUp(self):
        # nothing to prepare before each test
        return None

    def tearDown(self):
        # nothing to clean up after each test
        return None

    # test methods are there, and can be called from run_all

    def test_checker_should_implement_check_methods(self):
        checker = eipchecks.EIPConfigChecker()
        expected_methods = (
            "check_default_eipconfig",
            "check_is_there_default_provider",
            "fetch_definition",
            "fetch_eip_service_config",
            "check_complete_eip_config",
            "ping_gateway",
        )
        for meth_name in expected_methods:
            self.assertTrue(hasattr(checker, meth_name), "missing meth")

    def test_checker_should_actually_call_all_tests(self):
        checker = eipchecks.EIPConfigChecker()
        mocked_checker = Mock()
        checker.run_all(checker=mocked_checker)

        # ping_gateway is not asserted here (that check is disabled).
        called_methods = (
            "check_default_eipconfig",
            "check_is_there_default_provider",
            "fetch_definition",
            "fetch_eip_service_config",
            "check_complete_eip_config",
        )
        for meth_name in called_methods:
            was_called = getattr(mocked_checker, meth_name).called
            self.assertTrue(was_called, "not called")

    # test individual check methods

    def test_check_default_eipconfig(self):
        checker = eipchecks.EIPConfigChecker()
        config_file = checker.eipconfig.filename

        # no eip config yet (empty home)
        self.assertFalse(os.path.isfile(config_file))
        checker.check_default_eipconfig()
        # the check has written one, so now it should be there.
        self.assertTrue(os.path.isfile(config_file))
        with open(config_file, 'rb') as fp:
            written = json.load(fp)

        # force re-evaluation of the paths
        # (small workaround for evaluating home dirs correctly)
        expected = copy.copy(testdata.EIP_SAMPLE_JSON)
        expected['openvpn_client_certificate'] = (
            eipspecs.client_cert_path())
        expected['openvpn_ca_certificate'] = (
            eipspecs.provider_ca_path())
        self.assertEqual(written, expected)

        # TODO: should ALSO run validation methods.

    def test_check_is_there_default_provider(self):
        checker = eipchecks.EIPConfigChecker()
        missing_provider = eipexceptions.EIPMissingDefaultProvider
        # We dump a sample eip config that lacks a default provider
        # entry. This error will possibly be caught in a different
        # place, when JSONConfig does validation of required fields.

        # passing the config directly
        with self.assertRaises(missing_provider):
            checker.check_is_there_default_provider(config={})

        # ok. now, messing with real files...
        # blank out the default provider
        blanked_config = copy.copy(testdata.EIP_SAMPLE_JSON)
        blanked_config['provider'] = None
        config_file = checker.eipconfig.filename
        with open(config_file, 'w') as fp:
            json.dump(blanked_config, fp)
        with self.assertRaises(missing_provider):
            checker.eipconfig.load(fromfile=config_file)
            checker.check_is_there_default_provider()

        # the unmodified sample config has to pass the check
        with open(config_file, 'w') as fp:
            json.dump(testdata.EIP_SAMPLE_JSON, fp)
        checker.eipconfig.load()
        self.assertTrue(checker.check_is_there_default_provider())

    def test_fetch_definition(self):
        with patch.object(requests, "get") as mocked_get:
            response = mocked_get.return_value
            response.status_code = 200
            response.json = DEFAULT_PROVIDER_DEFINITION
            checker = eipchecks.EIPConfigChecker(fetcher=requests)
            checker.fetch_definition(config=testdata.EIP_SAMPLE_JSON)

        definition_file = os.path.join(
            baseconfig.get_default_provider_path(),
            DEFINITION_EXPECTED_PATH)
        with open(definition_file, 'r') as fp:
            written = json.load(fp)
        self.assertEqual(DEFAULT_PROVIDER_DEFINITION, written)

        # XXX TODO check for ConnectionError, HTTPError, InvalidUrl
        # (and proper EIPExceptions are raised).
        # Look at base.test_config.

    def test_fetch_eip_service_config(self):
        with patch.object(requests, "get") as mocked_get:
            response = mocked_get.return_value
            response.status_code = 200
            response.json = testdata.EIP_SAMPLE_SERVICE
            checker = eipchecks.EIPConfigChecker(fetcher=requests)
            checker.fetch_eip_service_config(
                config=testdata.EIP_SAMPLE_JSON)

    def test_check_complete_eip_config(self):
        checker = eipchecks.EIPConfigChecker()
        config_error = eipexceptions.EIPConfigurationError

        # provider present but set to None
        none_provider = copy.copy(testdata.EIP_SAMPLE_JSON)
        none_provider['provider'] = None
        with self.assertRaises(config_error):
            checker.check_complete_eip_config(config=none_provider)

        # provider key removed altogether
        no_provider = copy.copy(testdata.EIP_SAMPLE_JSON)
        del no_provider['provider']
        with self.assertRaises(config_error):
            checker.check_complete_eip_config(config=no_provider)

        # normal case
        complete_config = copy.copy(testdata.EIP_SAMPLE_JSON)
        checker.check_complete_eip_config(config=complete_config)


class ProviderCertCheckerTest(BaseLeapTest):
    """Test case for ``eipchecks.ProviderCertChecker``.

    Verifies that the checker exposes the expected methods, that
    ``run_all`` invokes the ones currently in use, and that the
    provider CA check passes.
    """

    __name__ = "provider_cert_checker_tests"

    def setUp(self):
        # nothing to prepare before each test
        return None

    def tearDown(self):
        # nothing to clean up after each test
        return None

    # test methods are there, and can be called from run_all

    def test_checker_should_implement_check_methods(self):
        checker = eipchecks.ProviderCertChecker()

        # For MVS+
        mvs_plus_methods = (
            "download_ca_cert",
            "download_ca_signature",
            "get_ca_signatures",
            "is_there_trust_path",
        )
        # For MVS
        mvs_methods = (
            "is_there_provider_ca",
            "is_https_working",
            "check_new_cert_needed",
        )
        for meth_name in mvs_plus_methods + mvs_methods:
            self.assertTrue(hasattr(checker, meth_name), "missing meth")

    def test_checker_should_actually_call_all_tests(self):
        checker = eipchecks.ProviderCertChecker()
        mocked_checker = Mock()
        checker.run_all(checker=mocked_checker)

        # XXX MVS+: download_ca_cert, download_ca_signature,
        # get_ca_signatures and is_there_trust_path are not asserted
        # here (those checks are disabled).

        # For MVS
        called_methods = (
            "is_there_provider_ca",
            "is_https_working",
            "check_new_cert_needed",
        )
        for meth_name in called_methods:
            was_called = getattr(mocked_checker, meth_name).called
            self.assertTrue(was_called, "not called")

    # test individual check methods

    def test_is_there_provider_ca(self):
        checker = eipchecks.ProviderCertChecker()
        ca_found = checker.is_there_provider_ca()
        self.assertTrue(ca_found)


class ProviderCertCheckerHTTPSTests(BaseHTTPSServerTestCase):
    """Exercise ``eipchecks.ProviderCertChecker`` over HTTPS.

    Requests go to the server address returned by ``self.get_server()``
    (inherited from ``BaseHTTPSServerTestCase``) and are answered by the
    nested ``request_handler`` class below.
    """

    class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler):
        # Maps a request path to the list of body lines served for it
        # (the lines are joined with '\n').
        # NOTE: this name shadows BaseHTTPRequestHandler.responses (the
        # status-code table); it is kept because the tests below read
        # the canned bodies through this attribute.
        responses = {
            '/': ['OK', ''],
            '/client.cert': [
                # XXX get sample cert
                '-----BEGIN CERTIFICATE-----',
                '-----END CERTIFICATE-----'],
            '/badclient.cert': [
                'BADCERT']}

        def do_GET(self):
            """Serve the canned body for the requested path.

            Known paths get a 200 response with the joined body lines;
            any other path gets an empty 404 response instead of
            failing inside the handler while joining ``None``.
            """
            path = urlparse.urlparse(self.path)
            body_lines = self.responses.get(path.path)
            if body_lines is None:
                self.send_response(404)
                self.end_headers()
                return
            self.send_response(200)
            self.end_headers()
            self.wfile.write('\n'.join(body_lines))

    def test_is_https_working(self):
        fetcher = requests
        uri = "https://%s/" % (self.get_server())
        # bare requests call. this should just pass (if there is
        # an https service there).
        fetcher.get(uri, verify=False)
        checker = eipchecks.ProviderCertChecker(fetcher=fetcher)
        self.assertTrue(checker.is_https_working(uri=uri, verify=False))

        # The two checks below fail because no ca cert is passed to
        # them, and SSLError is what requests raises with our
        # implementation. We're receiving this because our server is
        # dying prematurely when the handshake is interrupted on the
        # client side.
        # Since we have access to the server, we could check that
        # the error raised has been:
        # SSL23_READ_BYTES: alert bad certificate
        #
        # Only the exception type is verified. The error text (e.g.
        # "SSL23_GET_SERVER_HELLO:unknown protocol") is not asserted:
        # statements placed after the raising call inside the ``with``
        # block would never run.
        with self.assertRaises(requests.exceptions.SSLError):
            fetcher.get(uri, verify=True)
        with self.assertRaises(requests.exceptions.SSLError):
            checker.is_https_working(uri=uri, verify=True)

        # get cacert from testing.https_server
        cacert = where_cert('cacert.pem')
        fetcher.get(uri, verify=cacert)
        self.assertTrue(checker.is_https_working(uri=uri, verify=cacert))

        # same, but get cacert from leap.custom
        # XXX TODO!

    def test_download_new_client_cert(self):
        cacert = where_cert('cacert.pem')
        try:
            uri = "https://%s/client.cert" % (self.get_server())
            checker = eipchecks.ProviderCertChecker()
            self.assertTrue(checker.download_new_client_cert(
                            uri=uri, verify=cacert))

            # now download a malformed cert
            uri = "https://%s/badclient.cert" % (self.get_server())
            checker = eipchecks.ProviderCertChecker()
            with self.assertRaises(ValueError):
                checker.download_new_client_cert(uri=uri, verify=cacert)

            # did we write cert to its path? The file has to hold the
            # content served for /client.cert.
            clientcertfile = eipspecs.client_cert_path()
            self.assertTrue(os.path.isfile(clientcertfile))
            with open(clientcertfile, 'r') as cf:
                certcontent = cf.read()
            self.assertEqual(
                certcontent,
                '\n'.join(self.request_handler.responses['/client.cert']))
        finally:
            # Remove the downloaded cert even when an assertion above
            # fails, so that it is not left behind for other tests.
            leftover = eipspecs.client_cert_path()
            if os.path.isfile(leftover):
                os.remove(leftover)

    def test_is_cert_valid(self):
        checker = eipchecks.ProviderCertChecker()
        # TODO: better exception catching
        # Only the call goes inside the block: wrapping it in an
        # assertion would let that assertion's own AssertionError
        # satisfy assertRaises(Exception).
        with self.assertRaises(Exception):
            checker.is_cert_valid()

    def test_check_new_cert_needed(self):
        # check: missing cert
        checker = eipchecks.ProviderCertChecker()
        self.assertTrue(checker.check_new_cert_needed(skip_download=True))
        # TODO check: malformed cert
        # TODO check: expired cert
        # TODO check: pass test server uri instead of skip


# Run this module's test cases when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()