summaryrefslogtreecommitdiff
path: root/src/leap/bitmask/provider
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/bitmask/provider')
-rw-r--r--src/leap/bitmask/provider/__init__.py34
-rw-r--r--src/leap/bitmask/provider/providerbootstrapper.py370
-rw-r--r--src/leap/bitmask/provider/tests/__init__.py0
-rw-r--r--src/leap/bitmask/provider/tests/test_providerbootstrapper.py547
4 files changed, 951 insertions, 0 deletions
diff --git a/src/leap/bitmask/provider/__init__.py b/src/leap/bitmask/provider/__init__.py
index e69de29b..53587d65 100644
--- a/src/leap/bitmask/provider/__init__.py
+++ b/src/leap/bitmask/provider/__init__.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# __init.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Module initialization for leap.bitmask.provider
+"""
+import os
+from leap.common.check import leap_assert
+
+
+def get_provider_path(domain):
+ """
+ Returns relative path for provider config.
+
+ :param domain: the domain to which this providerconfig belongs to.
+ :type domain: str
+ :returns: the path
+ :rtype: str
+ """
+ leap_assert(domain is not None, "get_provider_path: We need a domain")
+ return os.path.join("leap", "providers", domain, "provider.json")
diff --git a/src/leap/bitmask/provider/providerbootstrapper.py b/src/leap/bitmask/provider/providerbootstrapper.py
new file mode 100644
index 00000000..1b5947e1
--- /dev/null
+++ b/src/leap/bitmask/provider/providerbootstrapper.py
@@ -0,0 +1,370 @@
+# -*- coding: utf-8 -*-
+# providerbootstrapper.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Provider bootstrapping
+"""
+import logging
+import socket
+import os
+
+import requests
+
+from PySide import QtCore
+
+from leap.bitmask.config.providerconfig import ProviderConfig, MissingCACert
+from leap.bitmask.util.request_helpers import get_content
+from leap.bitmask import util
+from leap.bitmask.util.constants import REQUEST_TIMEOUT
+from leap.bitmask.services.abstractbootstrapper import AbstractBootstrapper
+from leap.bitmask.provider.supportedapis import SupportedAPIs
+from leap.common import ca_bundle
+from leap.common.certs import get_digest
+from leap.common.files import check_and_fix_urw_only, get_mtime, mkdir_p
+from leap.common.check import leap_assert, leap_assert_type, leap_check
+
+logger = logging.getLogger(__name__)
+
+
+class UnsupportedProviderAPI(Exception):
+ """
+ Raised when attempting to use a provider with an incompatible API.
+ """
+ pass
+
+
+class WrongFingerprint(Exception):
+ """
+ Raised when a fingerprint comparison does not match.
+ """
+ pass
+
+
+class ProviderBootstrapper(AbstractBootstrapper):
+ """
+ Given a provider URL performs a series of checks and emits signals
+ after they are passed.
+ If a check fails, the subsequent checks are not executed
+ """
+
+ # All dicts returned are of the form
+ # {"passed": bool, "error": str}
+ name_resolution = QtCore.Signal(dict)
+ https_connection = QtCore.Signal(dict)
+ download_provider_info = QtCore.Signal(dict)
+
+ download_ca_cert = QtCore.Signal(dict)
+ check_ca_fingerprint = QtCore.Signal(dict)
+ check_api_certificate = QtCore.Signal(dict)
+
+ def __init__(self, bypass_checks=False):
+ """
+ Constructor for provider bootstrapper object
+
+ :param bypass_checks: Set to true if the app should bypass
+ first round of checks for CA certificates at bootstrap
+ :type bypass_checks: bool
+ """
+ AbstractBootstrapper.__init__(self, bypass_checks)
+
+ self._domain = None
+ self._provider_config = None
+ self._download_if_needed = False
+
+ @property
+ def verify(self):
+ """
+ Verify parameter for requests.
+
+ :returns: either False, if checks are skipped, or the
+ path to the ca bundle.
+ :rtype: bool or str
+ """
+ if self._bypass_checks:
+ verify = False
+ else:
+ verify = ca_bundle.where()
+ return verify
+
+ def _check_name_resolution(self):
+ """
+ Checks that the name resolution for the provider name works
+ """
+ leap_assert(self._domain, "Cannot check DNS without a domain")
+ logger.debug("Checking name resolution for %s" % (self._domain))
+
+ # We don't skip this check, since it's basic for the whole
+ # system to work
+ # err --- but we can do it after a failure, to diagnose what went
+ # wrong. Right now we're just adding connection overhead. -- kali
+ socket.gethostbyname(self._domain)
+
+ def _check_https(self, *args):
+ """
+ Checks that https is working and that the provided certificate
+ checks out
+ """
+ leap_assert(self._domain, "Cannot check HTTPS without a domain")
+ logger.debug("Checking https for %s" % (self._domain))
+
+ # We don't skip this check, since it's basic for the whole
+ # system to work.
+ # err --- but we can do it after a failure, to diagnose what went
+ # wrong. Right now we're just adding connection overhead. -- kali
+
+ try:
+ res = self._session.get("https://%s" % (self._domain,),
+ verify=self.verify,
+ timeout=REQUEST_TIMEOUT)
+ res.raise_for_status()
+ except requests.exceptions.SSLError as exc:
+ logger.exception(exc)
+ self._err_msg = self.tr("Provider certificate could "
+ "not be verified")
+ raise
+ except Exception as exc:
+ # XXX careful!. The error might be also a SSL handshake
+ # timeout error, in which case we should retry a couple of times
+ # more, for cases where the ssl server gives high latencies.
+ logger.exception(exc)
+ self._err_msg = self.tr("Provider does not support HTTPS")
+ raise
+
+ def _download_provider_info(self, *args):
+ """
+ Downloads the provider.json defition
+ """
+ leap_assert(self._domain,
+ "Cannot download provider info without a domain")
+ logger.debug("Downloading provider info for %s" % (self._domain))
+
+ # --------------------------------------------------------------
+ # TODO factor out with the download routines in services.
+ # Watch out! We're handling the verify paramenter differently here.
+
+ headers = {}
+ provider_json = os.path.join(util.get_path_prefix(),
+ "leap",
+ "providers",
+ self._domain, "provider.json")
+ mtime = get_mtime(provider_json)
+
+ if self._download_if_needed and mtime:
+ headers['if-modified-since'] = mtime
+
+ uri = "https://%s/%s" % (self._domain, "provider.json")
+ verify = self.verify
+
+ if mtime: # the provider.json exists
+ # So, we're getting it from the api.* and checking against
+ # the provider ca.
+ try:
+ provider_config = ProviderConfig()
+ provider_config.load(provider_json)
+ uri = provider_config.get_api_uri() + '/provider.json'
+ verify = provider_config.get_ca_cert_path()
+ except MissingCACert:
+ # no ca? then download from main domain again.
+ pass
+
+ logger.debug("Requesting for provider.json... "
+ "uri: {0}, verify: {1}, headers: {2}".format(
+ uri, verify, headers))
+ res = self._session.get(uri, verify=verify,
+ headers=headers, timeout=REQUEST_TIMEOUT)
+ res.raise_for_status()
+ logger.debug("Request status code: {0}".format(res.status_code))
+
+ # Not modified
+ if res.status_code == 304:
+ logger.debug("Provider definition has not been modified")
+ # --------------------------------------------------------------
+ # end refactor, more or less...
+ # XXX Watch out, have to check the supported api yet.
+ else:
+ provider_definition, mtime = get_content(res)
+
+ provider_config = ProviderConfig()
+ provider_config.load(data=provider_definition, mtime=mtime)
+ provider_config.save(["leap",
+ "providers",
+ self._domain,
+ "provider.json"])
+
+ api_version = provider_config.get_api_version()
+ if SupportedAPIs.supports(api_version):
+ logger.debug("Provider definition has been modified")
+ else:
+ api_supported = ', '.join(SupportedAPIs.SUPPORTED_APIS)
+ error = ('Unsupported provider API version. '
+ 'Supported versions are: {0}. '
+ 'Found: {1}.').format(api_supported, api_version)
+
+ logger.error(error)
+ raise UnsupportedProviderAPI(error)
+
+ def run_provider_select_checks(self, domain, download_if_needed=False):
+ """
+ Populates the check queue.
+
+ :param domain: domain to check
+ :type domain: str
+
+ :param download_if_needed: if True, makes the checks do not
+ overwrite already downloaded data
+ :type download_if_needed: bool
+ """
+ leap_assert(domain and len(domain) > 0, "We need a domain!")
+
+ self._domain = ProviderConfig.sanitize_path_component(domain)
+ self._download_if_needed = download_if_needed
+
+ cb_chain = [
+ (self._check_name_resolution, self.name_resolution),
+ (self._check_https, self.https_connection),
+ (self._download_provider_info, self.download_provider_info)
+ ]
+
+ return self.addCallbackChain(cb_chain)
+
+ def _should_proceed_cert(self):
+ """
+ Returns False if the certificate already exists for the given
+ provider. True otherwise
+
+ :rtype: bool
+ """
+ leap_assert(self._provider_config, "We need a provider config!")
+
+ if not self._download_if_needed:
+ return True
+
+ return not os.path.exists(self._provider_config
+ .get_ca_cert_path(about_to_download=True))
+
+ def _download_ca_cert(self, *args):
+ """
+ Downloads the CA cert that is going to be used for the api URL
+ """
+ # XXX maybe we can skip this step if
+ # we have a fresh one.
+ leap_assert(self._provider_config, "Cannot download the ca cert "
+ "without a provider config!")
+
+ logger.debug("Downloading ca cert for %s at %s" %
+ (self._domain, self._provider_config.get_ca_cert_uri()))
+
+ if not self._should_proceed_cert():
+ check_and_fix_urw_only(
+ self._provider_config
+ .get_ca_cert_path(about_to_download=True))
+ return
+
+ res = self._session.get(self._provider_config.get_ca_cert_uri(),
+ verify=self.verify,
+ timeout=REQUEST_TIMEOUT)
+ res.raise_for_status()
+
+ cert_path = self._provider_config.get_ca_cert_path(
+ about_to_download=True)
+ cert_dir = os.path.dirname(cert_path)
+ mkdir_p(cert_dir)
+ with open(cert_path, "w") as f:
+ f.write(res.content)
+
+ check_and_fix_urw_only(cert_path)
+
+ def _check_ca_fingerprint(self, *args):
+ """
+ Checks the CA cert fingerprint against the one provided in the
+ json definition
+ """
+ leap_assert(self._provider_config, "Cannot check the ca cert "
+ "without a provider config!")
+
+ logger.debug("Checking ca fingerprint for %s and cert %s" %
+ (self._domain,
+ self._provider_config.get_ca_cert_path()))
+
+ if not self._should_proceed_cert():
+ return
+
+ parts = self._provider_config.get_ca_cert_fingerprint().split(":")
+
+ error_msg = "Wrong fingerprint format"
+ leap_check(len(parts) == 2, error_msg, WrongFingerprint)
+
+ method = parts[0].strip()
+ fingerprint = parts[1].strip()
+ cert_data = None
+ with open(self._provider_config.get_ca_cert_path()) as f:
+ cert_data = f.read()
+
+ leap_assert(len(cert_data) > 0, "Could not read certificate data")
+ digest = get_digest(cert_data, method)
+
+ error_msg = "Downloaded certificate has a different fingerprint!"
+ leap_check(digest == fingerprint, error_msg, WrongFingerprint)
+
+ def _check_api_certificate(self, *args):
+ """
+ Tries to make an API call with the downloaded cert and checks
+ if it validates against it
+ """
+ leap_assert(self._provider_config, "Cannot check the ca cert "
+ "without a provider config!")
+
+ logger.debug("Checking api certificate for %s and cert %s" %
+ (self._provider_config.get_api_uri(),
+ self._provider_config.get_ca_cert_path()))
+
+ if not self._should_proceed_cert():
+ return
+
+ test_uri = "%s/%s/cert" % (self._provider_config.get_api_uri(),
+ self._provider_config.get_api_version())
+ res = self._session.get(test_uri,
+ verify=self._provider_config
+ .get_ca_cert_path(),
+ timeout=REQUEST_TIMEOUT)
+ res.raise_for_status()
+
+ def run_provider_setup_checks(self,
+ provider_config,
+ download_if_needed=False):
+ """
+ Starts the checks needed for a new provider setup.
+
+ :param provider_config: Provider configuration
+ :type provider_config: ProviderConfig
+
+ :param download_if_needed: if True, makes the checks do not
+ overwrite already downloaded data.
+ :type download_if_needed: bool
+ """
+ leap_assert(provider_config, "We need a provider config!")
+ leap_assert_type(provider_config, ProviderConfig)
+
+ self._provider_config = provider_config
+ self._download_if_needed = download_if_needed
+
+ cb_chain = [
+ (self._download_ca_cert, self.download_ca_cert),
+ (self._check_ca_fingerprint, self.check_ca_fingerprint),
+ (self._check_api_certificate, self.check_api_certificate)
+ ]
+
+ return self.addCallbackChain(cb_chain)
diff --git a/src/leap/bitmask/provider/tests/__init__.py b/src/leap/bitmask/provider/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/leap/bitmask/provider/tests/__init__.py
diff --git a/src/leap/bitmask/provider/tests/test_providerbootstrapper.py b/src/leap/bitmask/provider/tests/test_providerbootstrapper.py
new file mode 100644
index 00000000..88a4ff0b
--- /dev/null
+++ b/src/leap/bitmask/provider/tests/test_providerbootstrapper.py
@@ -0,0 +1,547 @@
+# -*- coding: utf-8 -*-
+# test_providerbootstrapper.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Tests for the Provider Boostrapper checks
+
+These will be whitebox tests since we want to make sure the private
+implementation is checking what we expect.
+"""
+import os
+import mock
+import socket
+import stat
+import tempfile
+import time
+import requests
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+from nose.twistedtools import deferred, reactor
+from twisted.internet import threads
+from requests.models import Response
+
+from leap.bitmask.config.providerconfig import ProviderConfig
+from leap.bitmask.crypto.tests import fake_provider
+from leap.bitmask.provider.providerbootstrapper import ProviderBootstrapper
+from leap.bitmask.provider.providerbootstrapper import UnsupportedProviderAPI
+from leap.bitmask.provider.providerbootstrapper import WrongFingerprint
+from leap.bitmask.provider.supportedapis import SupportedAPIs
+from leap.bitmask import util
+from leap.common.files import mkdir_p
+from leap.common.testing.https_server import where
+from leap.common.testing.basetest import BaseLeapTest
+
+
+class ProviderBootstrapperTest(BaseLeapTest):
+ def setUp(self):
+ self.pb = ProviderBootstrapper()
+
+ def tearDown(self):
+ pass
+
+ def test_name_resolution_check(self):
+ # Something highly likely to success
+ self.pb._domain = "google.com"
+ self.pb._check_name_resolution()
+ # Something highly likely to fail
+ self.pb._domain = "uquhqweuihowquie.abc.def"
+
+ # In python 2.7.4 raises socket.error
+ # In python 2.7.5 raises socket.gaierror
+ with self.assertRaises((socket.gaierror, socket.error)):
+ self.pb._check_name_resolution()
+
+ @deferred()
+ def test_run_provider_select_checks(self):
+ self.pb._check_name_resolution = mock.MagicMock()
+ self.pb._check_https = mock.MagicMock()
+ self.pb._download_provider_info = mock.MagicMock()
+
+ d = self.pb.run_provider_select_checks("somedomain")
+
+ def check(*args):
+ self.pb._check_name_resolution.assert_called_once_with()
+ self.pb._check_https.assert_called_once_with(None)
+ self.pb._download_provider_info.assert_called_once_with(None)
+ d.addCallback(check)
+ return d
+
+ @deferred()
+ def test_run_provider_setup_checks(self):
+ self.pb._download_ca_cert = mock.MagicMock()
+ self.pb._check_ca_fingerprint = mock.MagicMock()
+ self.pb._check_api_certificate = mock.MagicMock()
+
+ d = self.pb.run_provider_setup_checks(ProviderConfig())
+
+ def check(*args):
+ self.pb._download_ca_cert.assert_called_once_with()
+ self.pb._check_ca_fingerprint.assert_called_once_with(None)
+ self.pb._check_api_certificate.assert_called_once_with(None)
+ d.addCallback(check)
+ return d
+
+ def test_should_proceed_cert(self):
+ self.pb._provider_config = mock.Mock()
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=where("cacert.pem"))
+
+ self.pb._download_if_needed = False
+ self.assertTrue(self.pb._should_proceed_cert())
+
+ self.pb._download_if_needed = True
+ self.assertFalse(self.pb._should_proceed_cert())
+
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=where("somefilethatdoesntexist.pem"))
+ self.assertTrue(self.pb._should_proceed_cert())
+
+ def _check_download_ca_cert(self, should_proceed):
+ """
+ Helper to check different paths easily for the download ca
+ cert check
+
+ :param should_proceed: sets the _should_proceed_cert in the
+ provider bootstrapper being tested
+ :type should_proceed: bool
+
+ :returns: The contents of the certificate, the expected
+ content depending on should_proceed, and the mode of
+ the file to be checked by the caller
+ :rtype: tuple of str, str, int
+ """
+ old_content = "NOT THE NEW CERT"
+ new_content = "NEW CERT"
+ new_cert_path = os.path.join(tempfile.mkdtemp(),
+ "mynewcert.pem")
+
+ with open(new_cert_path, "w") as c:
+ c.write(old_content)
+
+ self.pb._provider_config = mock.Mock()
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=new_cert_path)
+ self.pb._domain = "somedomain"
+
+ self.pb._should_proceed_cert = mock.MagicMock(
+ return_value=should_proceed)
+
+ read = None
+ content_to_check = None
+ mode = None
+
+ with mock.patch('requests.models.Response.content',
+ new_callable=mock.PropertyMock) as \
+ content:
+ content.return_value = new_content
+ response_obj = Response()
+ response_obj.raise_for_status = mock.MagicMock()
+
+ self.pb._session.get = mock.MagicMock(return_value=response_obj)
+ self.pb._download_ca_cert()
+ with open(new_cert_path, "r") as nc:
+ read = nc.read()
+ if should_proceed:
+ content_to_check = new_content
+ else:
+ content_to_check = old_content
+ mode = stat.S_IMODE(os.stat(new_cert_path).st_mode)
+
+ os.unlink(new_cert_path)
+ return read, content_to_check, mode
+
+ def test_download_ca_cert_no_saving(self):
+ read, expected_read, mode = self._check_download_ca_cert(False)
+ self.assertEqual(read, expected_read)
+ self.assertEqual(mode, int("600", 8))
+
+ def test_download_ca_cert_saving(self):
+ read, expected_read, mode = self._check_download_ca_cert(True)
+ self.assertEqual(read, expected_read)
+ self.assertEqual(mode, int("600", 8))
+
+ def test_check_ca_fingerprint_skips(self):
+ self.pb._provider_config = mock.Mock()
+ self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock(
+ return_value="")
+ self.pb._domain = "somedomain"
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=False)
+
+ self.pb._check_ca_fingerprint()
+ self.assertFalse(self.pb._provider_config.
+ get_ca_cert_fingerprint.called)
+
+ def test_check_ca_cert_fingerprint_raises_bad_format(self):
+ self.pb._provider_config = mock.Mock()
+ self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock(
+ return_value="wrongfprformat!!")
+ self.pb._domain = "somedomain"
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=True)
+
+ with self.assertRaises(WrongFingerprint):
+ self.pb._check_ca_fingerprint()
+
+ # This two hashes different in the last byte, but that's good enough
+ # for the tests
+ KNOWN_BAD_HASH = "SHA256: 0f17c033115f6b76ff67871872303ff65034efe" \
+ "7dd1b910062ca323eb4da5c7f"
+ KNOWN_GOOD_HASH = "SHA256: 0f17c033115f6b76ff67871872303ff65034ef" \
+ "e7dd1b910062ca323eb4da5c7e"
+ KNOWN_GOOD_CERT = """
+-----BEGIN CERTIFICATE-----
+MIIFbzCCA1egAwIBAgIBATANBgkqhkiG9w0BAQ0FADBKMRgwFgYDVQQDDA9CaXRt
+YXNrIFJvb3QgQ0ExEDAOBgNVBAoMB0JpdG1hc2sxHDAaBgNVBAsME2h0dHBzOi8v
+Yml0bWFzay5uZXQwHhcNMTIxMTA2MDAwMDAwWhcNMjIxMTA2MDAwMDAwWjBKMRgw
+FgYDVQQDDA9CaXRtYXNrIFJvb3QgQ0ExEDAOBgNVBAoMB0JpdG1hc2sxHDAaBgNV
+BAsME2h0dHBzOi8vYml0bWFzay5uZXQwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAw
+ggIKAoICAQC1eV4YvayaU+maJbWrD4OHo3d7S1BtDlcvkIRS1Fw3iYDjsyDkZxai
+dHp4EUasfNQ+EVtXUvtk6170EmLco6Elg8SJBQ27trE6nielPRPCfX3fQzETRfvB
+7tNvGw4Jn2YKiYoMD79kkjgyZjkJ2r/bEHUSevmR09BRp86syHZerdNGpXYhcQ84
+CA1+V+603GFIHnrP+uQDdssW93rgDNYu+exT+Wj6STfnUkugyjmPRPjL7wh0tzy+
+znCeLl4xiV3g9sjPnc7r2EQKd5uaTe3j71sDPF92KRk0SSUndREz+B1+Dbe/RGk4
+MEqGFuOzrtsgEhPIX0hplhb0Tgz/rtug+yTT7oJjBa3u20AAOQ38/M99EfdeJvc4
+lPFF1XBBLh6X9UKF72an2NuANiX6XPySnJgZ7nZ09RiYZqVwu/qt3DfvLfhboq+0
+bQvLUPXrVDr70onv5UDjpmEA/cLmaIqqrduuTkFZOym65/PfAPvpGnt7crQj/Ibl
+DEDYZQmP7AS+6zBjoOzNjUGE5r40zWAR1RSi7zliXTu+yfsjXUIhUAWmYR6J3KxB
+lfsiHBQ+8dn9kC3YrUexWoOqBiqJOAJzZh5Y1tqgzfh+2nmHSB2dsQRs7rDRRlyy
+YMbkpzL9ZsOUO2eTP1mmar6YjCN+rggYjRrX71K2SpBG6b1zZxOG+wIDAQABo2Aw
+XjAdBgNVHQ4EFgQUuYGDLL2sswnYpHHvProt1JU+D48wDgYDVR0PAQH/BAQDAgIE
+MAwGA1UdEwQFMAMBAf8wHwYDVR0jBBgwFoAUuYGDLL2sswnYpHHvProt1JU+D48w
+DQYJKoZIhvcNAQENBQADggIBADeG67vaFcbITGpi51264kHPYPEWaXUa5XYbtmBl
+cXYyB6hY5hv/YNuVGJ1gWsDmdeXEyj0j2icGQjYdHRfwhrbEri+h1EZOm1cSBDuY
+k/P5+ctHyOXx8IE79DBsZ6IL61UKIaKhqZBfLGYcWu17DVV6+LT+AKtHhOrv3TSj
+RnAcKnCbKqXLhUPXpK0eTjPYS2zQGQGIhIy9sQXVXJJJsGrPgMxna1Xw2JikBOCG
+htD/JKwt6xBmNwktH0GI/LVtVgSp82Clbn9C4eZN9E5YbVYjLkIEDhpByeC71QhX
+EIQ0ZR56bFuJA/CwValBqV/G9gscTPQqd+iETp8yrFpAVHOW+YzSFbxjTEkBte1J
+aF0vmbqdMAWLk+LEFPQRptZh0B88igtx6tV5oVd+p5IVRM49poLhuPNJGPvMj99l
+mlZ4+AeRUnbOOeAEuvpLJbel4rhwFzmUiGoeTVoPZyMevWcVFq6BMkS+jRR2w0jK
+G6b0v5XDHlcFYPOgUrtsOBFJVwbutLvxdk6q37kIFnWCd8L3kmES5q4wjyFK47Co
+Ja8zlx64jmMZPg/t3wWqkZgXZ14qnbyG5/lGsj5CwVtfDljrhN0oCWK1FZaUmW3d
+69db12/g4f6phldhxiWuGC/W6fCW5kre7nmhshcltqAJJuU47iX+DarBFiIj816e
+yV8e
+-----END CERTIFICATE-----
+"""
+
+ def _prepare_provider_config_with(self, cert_path, cert_hash):
+ """
+ Mocks the provider config to give the cert_path and cert_hash
+ specified
+
+ :param cert_path: path for the certificate
+ :type cert_path: str
+ :param cert_hash: hash for the certificate as it would appear
+ in the provider config json
+ :type cert_hash: str
+ """
+ self.pb._provider_config = mock.Mock()
+ self.pb._provider_config.get_ca_cert_fingerprint = mock.MagicMock(
+ return_value=cert_hash)
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=cert_path)
+ self.pb._domain = "somedomain"
+
+ def test_check_ca_fingerprint_checksout(self):
+ cert_path = os.path.join(tempfile.mkdtemp(),
+ "mynewcert.pem")
+
+ with open(cert_path, "w") as c:
+ c.write(self.KNOWN_GOOD_CERT)
+
+ self._prepare_provider_config_with(cert_path, self.KNOWN_GOOD_HASH)
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=True)
+
+ self.pb._check_ca_fingerprint()
+
+ os.unlink(cert_path)
+
+ def test_check_ca_fingerprint_fails(self):
+ cert_path = os.path.join(tempfile.mkdtemp(),
+ "mynewcert.pem")
+
+ with open(cert_path, "w") as c:
+ c.write(self.KNOWN_GOOD_CERT)
+
+ self._prepare_provider_config_with(cert_path, self.KNOWN_BAD_HASH)
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=True)
+
+ with self.assertRaises(WrongFingerprint):
+ self.pb._check_ca_fingerprint()
+
+ os.unlink(cert_path)
+
+
+###############################################################################
+# Tests with a fake provider #
+###############################################################################
+
+class ProviderBootstrapperActiveTest(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ factory = fake_provider.get_provider_factory()
+ http = reactor.listenTCP(8002, factory)
+ https = reactor.listenSSL(
+ 0, factory,
+ fake_provider.OpenSSLServerContextFactory())
+ get_port = lambda p: p.getHost().port
+ cls.http_port = get_port(http)
+ cls.https_port = get_port(https)
+
+ def setUp(self):
+ self.pb = ProviderBootstrapper()
+
+ # At certain points we are going to be replacing these methods
+ # directly in ProviderConfig to be able to catch calls from
+ # new ProviderConfig objects inside the methods tested. We
+ # need to save the old implementation and restore it in
+ # tearDown so we are sure everything is as expected for each
+ # test. If we do it inside each specific test, a failure in
+ # the test will leave the implementation with the mock.
+ self.old_gpp = util.get_path_prefix
+
+ self.old_load = ProviderConfig.load
+ self.old_save = ProviderConfig.save
+ self.old_api_version = ProviderConfig.get_api_version
+ self.old_api_uri = ProviderConfig.get_api_uri
+
+ def tearDown(self):
+ util.get_path_prefix = self.old_gpp
+ ProviderConfig.load = self.old_load
+ ProviderConfig.save = self.old_save
+ ProviderConfig.get_api_version = self.old_api_version
+ ProviderConfig.get_api_uri = self.old_api_uri
+
+ def test_check_https_succeeds(self):
+ # XXX: Need a proper CA signed cert to test this
+ pass
+
+ @deferred()
+ def test_check_https_fails(self):
+ self.pb._domain = "localhost:%s" % (self.https_port,)
+
+ def check(*args):
+ with self.assertRaises(requests.exceptions.SSLError):
+ self.pb._check_https()
+ return threads.deferToThread(check)
+
+ @deferred()
+ def test_second_check_https_fails(self):
+ self.pb._domain = "localhost:1234"
+
+ def check(*args):
+ with self.assertRaises(Exception):
+ self.pb._check_https()
+ return threads.deferToThread(check)
+
+ @deferred()
+ def test_check_https_succeeds_if_danger(self):
+ self.pb._domain = "localhost:%s" % (self.https_port,)
+ self.pb._bypass_checks = True
+
+ def check(*args):
+ self.pb._check_https()
+
+ return threads.deferToThread(check)
+
+ def _setup_provider_config_with(self, api, path_prefix):
+ """
+ Sets up the ProviderConfig with mocks for the path prefix, the
+ api returned and load/save methods.
+ It modifies ProviderConfig directly instead of an object
+ because the object used is created in the method itself and we
+ cannot control that.
+
+ :param api: API to return
+ :type api: str
+ :param path_prefix: path prefix to be used when calculating
+ paths
+ :type path_prefix: str
+ """
+ util.get_path_prefix = mock.MagicMock(return_value=path_prefix)
+ ProviderConfig.get_api_version = mock.MagicMock(
+ return_value=api)
+ ProviderConfig.get_api_uri = mock.MagicMock(
+ return_value="https://localhost:%s" % (self.https_port,))
+ ProviderConfig.load = mock.MagicMock()
+ ProviderConfig.save = mock.MagicMock()
+
+ def _setup_providerbootstrapper(self, ifneeded):
+ """
+ Sets the provider bootstrapper's domain to
+ localhost:https_port, sets it to bypass https checks and sets
+ the download if needed based on the ifneeded value.
+
+ :param ifneeded: Value for _download_if_needed
+ :type ifneeded: bool
+ """
+ self.pb._domain = "localhost:%s" % (self.https_port,)
+ self.pb._bypass_checks = True
+ self.pb._download_if_needed = ifneeded
+
+ def _produce_dummy_provider_json(self):
+ """
+ Creates a dummy provider json on disk in order to test
+ behaviour around it (download if newer online, etc)
+
+ :returns: the provider.json path used
+ :rtype: str
+ """
+ provider_dir = os.path.join(util.get_path_prefix(),
+ "leap", "providers",
+ self.pb._domain)
+ mkdir_p(provider_dir)
+ provider_path = os.path.join(provider_dir,
+ "provider.json")
+
+ with open(provider_path, "w") as p:
+ p.write("A")
+ return provider_path
+
+ @mock.patch(
+ 'leap.bitmask.config.providerconfig.ProviderConfig.get_domain',
+ lambda x: where('testdomain.com'))
+ def test_download_provider_info_new_provider(self):
+ self._setup_provider_config_with("1", tempfile.mkdtemp())
+ self._setup_providerbootstrapper(True)
+
+ self.pb._download_provider_info()
+ self.assertTrue(ProviderConfig.save.called)
+
+ @mock.patch(
+ 'leap.bitmask.config.providerconfig.ProviderConfig.get_ca_cert_path',
+ lambda x: where('cacert.pem'))
+ def test_download_provider_info_not_modified(self):
+ self._setup_provider_config_with("1", tempfile.mkdtemp())
+ self._setup_providerbootstrapper(True)
+ provider_path = self._produce_dummy_provider_json()
+
+ # set mtime to something really new
+ os.utime(provider_path, (-1, time.time()))
+
+ self.pb._download_provider_info()
+ # we check that it doesn't save the provider
+ # config, because it's new enough
+ self.assertFalse(ProviderConfig.save.called)
+
+ @mock.patch(
+ 'leap.bitmask.config.providerconfig.ProviderConfig.get_domain',
+ lambda x: where('testdomain.com'))
+ def test_download_provider_info_not_modified_and_no_cacert(self):
+ self._setup_provider_config_with("1", tempfile.mkdtemp())
+ self._setup_providerbootstrapper(True)
+ provider_path = self._produce_dummy_provider_json()
+
+ # set mtime to something really new
+ os.utime(provider_path, (-1, time.time()))
+
+ self.pb._download_provider_info()
+ # we check that it doesn't save the provider
+ # config, because it's new enough
+ self.assertFalse(ProviderConfig.save.called)
+
+ @mock.patch(
+ 'leap.bitmask.config.providerconfig.ProviderConfig.get_ca_cert_path',
+ lambda x: where('cacert.pem'))
+ def test_download_provider_info_modified(self):
+ self._setup_provider_config_with("1", tempfile.mkdtemp())
+ self._setup_providerbootstrapper(True)
+ provider_path = self._produce_dummy_provider_json()
+
+ # set mtime to something really old
+ os.utime(provider_path, (-1, 100))
+
+ self.pb._download_provider_info()
+ self.assertTrue(ProviderConfig.load.called)
+ self.assertTrue(ProviderConfig.save.called)
+
+ @mock.patch(
+ 'leap.bitmask.config.providerconfig.ProviderConfig.get_ca_cert_path',
+ lambda x: where('cacert.pem'))
+ def test_download_provider_info_unsupported_api_raises(self):
+ self._setup_provider_config_with("9999999", tempfile.mkdtemp())
+ self._setup_providerbootstrapper(False)
+ self._produce_dummy_provider_json()
+
+ with self.assertRaises(UnsupportedProviderAPI):
+ self.pb._download_provider_info()
+
+ @mock.patch(
+ 'leap.bitmask.config.providerconfig.ProviderConfig.get_ca_cert_path',
+ lambda x: where('cacert.pem'))
+ def test_download_provider_info_unsupported_api(self):
+ self._setup_provider_config_with(SupportedAPIs.SUPPORTED_APIS[0],
+ tempfile.mkdtemp())
+ self._setup_providerbootstrapper(False)
+ self._produce_dummy_provider_json()
+
+ self.pb._download_provider_info()
+
+ @mock.patch(
+ 'leap.bitmask.config.providerconfig.ProviderConfig.get_api_uri',
+ lambda x: 'api.uri')
+ @mock.patch(
+ 'leap.bitmask.config.providerconfig.ProviderConfig.get_ca_cert_path',
+ lambda x: '/cert/path')
+ def test_check_api_certificate_skips(self):
+ self.pb._provider_config = ProviderConfig()
+ self.pb._session.get = mock.MagicMock(return_value=Response())
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=False)
+ self.pb._check_api_certificate()
+ self.assertFalse(self.pb._session.get.called)
+
+ @deferred()
+ def test_check_api_certificate_fails(self):
+ self.pb._provider_config = ProviderConfig()
+ self.pb._provider_config.get_api_uri = mock.MagicMock(
+ return_value="https://localhost:%s" % (self.https_port,))
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=os.path.join(
+ os.path.split(__file__)[0],
+ "wrongcert.pem"))
+ self.pb._provider_config.get_api_version = mock.MagicMock(
+ return_value="1")
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=True)
+
+ def check(*args):
+ with self.assertRaises(requests.exceptions.SSLError):
+ self.pb._check_api_certificate()
+ d = threads.deferToThread(check)
+ return d
+
+ @deferred()
+ def test_check_api_certificate_succeeds(self):
+ self.pb._provider_config = ProviderConfig()
+ self.pb._provider_config.get_api_uri = mock.MagicMock(
+ return_value="https://localhost:%s" % (self.https_port,))
+ self.pb._provider_config.get_ca_cert_path = mock.MagicMock(
+ return_value=where('cacert.pem'))
+ self.pb._provider_config.get_api_version = mock.MagicMock(
+ return_value="1")
+
+ self.pb._should_proceed_cert = mock.MagicMock(return_value=True)
+
+ def check(*args):
+ self.pb._check_api_certificate()
+ d = threads.deferToThread(check)
+ return d